1use anyhow::{Context, Result, bail};
26use async_trait::async_trait;
27use std::sync::Arc;
28use std::sync::atomic::{AtomicU64, Ordering};
29use std::time::Duration;
30use tokio::sync::RwLock;
31
32use super::protocol::{JsonRpcRequest, JsonRpcResponse, RequestId};
33use super::transport::{McpTransport, notification_body};
34
/// HTTP header carrying the session id: echoed on requests once captured and
/// read back from every response.
const SESSION_ID_HEADER: &str = "Mcp-Session-Id";
/// HTTP header announcing the protocol version, sent once a version has been set.
const PROTOCOL_VERSION_HEADER: &str = "MCP-Protocol-Version";

/// Total request timeout configured on the client built by `ReqwestPoster::new`.
const DEFAULT_HTTP_TIMEOUT: Duration = Duration::from_secs(60);

/// Outer deadline wrapped around every `HttpPoster::post` call made by the
/// transport, regardless of what timeout the poster itself enforces.
const SEND_DEADLINE: Duration = Duration::from_secs(60);

/// Maximum number of response-body bytes buffered in memory (16 MiB).
const MAX_RESPONSE_BODY_BYTES: usize = 16 << 20;
55
/// A fully buffered HTTP reply, as returned by an [`HttpPoster`].
#[derive(Clone, Debug)]
pub struct HttpReply {
    /// Media type of `body`, e.g. `application/json` or `text/event-stream`.
    pub content_type: String,
    /// The complete response body as text.
    pub body: String,
    /// Session id reported by the server, if any (for [`ReqwestPoster`] this
    /// is the value of the `Mcp-Session-Id` response header).
    pub session_id: Option<String>,
}

impl HttpReply {
    /// Shared constructor: a reply with the given media type and no session id.
    fn from_parts(content_type: &str, body: String) -> Self {
        Self {
            content_type: content_type.to_owned(),
            body,
            session_id: None,
        }
    }

    /// Builds an `application/json` reply without a session id.
    #[must_use]
    pub fn json(body: impl Into<String>) -> Self {
        Self::from_parts("application/json", body.into())
    }

    /// Builds a `text/event-stream` (SSE) reply without a session id.
    #[must_use]
    pub fn event_stream(body: impl Into<String>) -> Self {
        Self::from_parts("text/event-stream", body.into())
    }

    /// Returns this reply with its session id replaced by `session_id`.
    #[must_use]
    pub fn with_session_id(self, session_id: impl Into<String>) -> Self {
        Self {
            session_id: Some(session_id.into()),
            ..self
        }
    }
}
96
/// An outgoing MCP HTTP request: the serialized JSON-RPC message plus the
/// header values the transport wants attached.
///
/// `Debug` is implemented by hand so that formatting a request (e.g. in a
/// log line) never prints the `Authorization` credential.
#[derive(Clone)]
pub struct HttpRequest {
    /// Serialized JSON-RPC message to POST.
    pub body: String,
    /// Complete `Authorization` header value (e.g. `Bearer <token>`), if any.
    pub authorization: Option<String>,
    /// Session id to send back to the server, if one has been captured.
    pub session_id: Option<String>,
    /// Protocol version to announce, if one has been set.
    pub protocol_version: Option<String>,
    /// Additional `(name, value)` headers forwarded as-is.
    pub extra_headers: Vec<(String, String)>,
}

impl std::fmt::Debug for HttpRequest {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("HttpRequest")
            .field("body", &self.body)
            // Only reveal whether a credential is present, never its value.
            .field(
                "authorization",
                &self.authorization.as_ref().map(|_| "<redacted>"),
            )
            .field("session_id", &self.session_id)
            .field("protocol_version", &self.protocol_version)
            // NOTE(review): extra headers are printed verbatim; if callers put
            // secrets there (API keys etc.) consider redacting them as well.
            .field("extra_headers", &self.extra_headers)
            .finish()
    }
}
111
/// Abstraction over "POST one MCP message and hand back the buffered reply".
///
/// [`StreamableHttpTransport`] performs all HTTP traffic through this trait,
/// which lets tests substitute an in-memory implementation;
/// [`ReqwestPoster`] is the default, network-backed implementation.
#[async_trait]
pub trait HttpPoster: Send + Sync {
    /// Sends `request` and returns the server's reply.
    ///
    /// # Errors
    ///
    /// Implementations return an error when the request cannot be delivered or
    /// its reply cannot be read; [`ReqwestPoster`] additionally fails on
    /// non-success HTTP statuses.
    async fn post(&self, request: HttpRequest) -> Result<HttpReply>;
}
126
/// Credentials attached to every MCP HTTP request.
///
/// `Debug` is implemented by hand so the bearer token never ends up in logs.
#[derive(Clone, Default)]
pub enum McpAuth {
    /// Send no `Authorization` header.
    #[default]
    None,
    /// Send `Authorization: Bearer <token>`.
    Bearer(String),
}

impl std::fmt::Debug for McpAuth {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::None => f.write_str("None"),
            // Show the variant but never the secret itself.
            Self::Bearer(_) => f.debug_tuple("Bearer").field(&"<redacted>").finish(),
        }
    }
}

impl McpAuth {
    /// Renders the `Authorization` header value, or `None` when no credential
    /// is configured.
    #[must_use]
    fn header_value(&self) -> Option<String> {
        match self {
            Self::None => None,
            Self::Bearer(token) => Some(format!("Bearer {token}")),
        }
    }
}
147
148pub struct StreamableHttpTransport {
154 poster: Arc<dyn HttpPoster>,
155 auth: McpAuth,
156 extra_headers: Vec<(String, String)>,
157 next_id: AtomicU64,
158 session_id: RwLock<Option<String>>,
160 protocol_version: RwLock<Option<String>>,
162}
163
164impl StreamableHttpTransport {
165 pub fn new(endpoint: impl Into<String>, auth: McpAuth) -> Result<Arc<Self>> {
171 Ok(Arc::new(Self::builder(endpoint, auth)?))
172 }
173
174 #[must_use]
179 pub fn with_poster(poster: Arc<dyn HttpPoster>, auth: McpAuth) -> Arc<Self> {
180 Arc::new(Self::with_poster_owned(poster, auth))
181 }
182
183 pub fn builder(endpoint: impl Into<String>, auth: McpAuth) -> Result<Self> {
206 let poster = ReqwestPoster::new(endpoint)?;
207 Ok(Self::with_poster_owned(Arc::new(poster), auth))
208 }
209
210 #[must_use]
213 pub fn with_poster_owned(poster: Arc<dyn HttpPoster>, auth: McpAuth) -> Self {
214 Self {
215 poster,
216 auth,
217 extra_headers: Vec::new(),
218 next_id: AtomicU64::new(1),
219 session_id: RwLock::new(None),
220 protocol_version: RwLock::new(None),
221 }
222 }
223
224 #[must_use]
230 pub fn with_header(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
231 self.extra_headers.push((name.into(), value.into()));
232 self
233 }
234
235 fn next_request_id(&self) -> u64 {
236 self.next_id.fetch_add(1, Ordering::SeqCst)
237 }
238
239 async fn build_http_request(&self, body: String) -> HttpRequest {
240 HttpRequest {
241 body,
242 authorization: self.auth.header_value(),
243 session_id: self.session_id.read().await.clone(),
244 protocol_version: self.protocol_version.read().await.clone(),
245 extra_headers: self.extra_headers.clone(),
246 }
247 }
248
249 async fn capture_session_id(&self, reply: &HttpReply) {
251 if let Some(ref sid) = reply.session_id {
252 let mut guard = self.session_id.write().await;
253 if guard.as_deref() != Some(sid.as_str()) {
254 *guard = Some(sid.clone());
255 }
256 }
257 }
258}
259
260fn parse_reply(reply: &HttpReply, id: &RequestId) -> Result<JsonRpcResponse> {
267 if reply.content_type.contains("text/event-stream") {
268 parse_sse_response(&reply.body, id)
269 } else {
270 serde_json::from_str::<JsonRpcResponse>(reply.body.trim())
271 .context("failed to parse JSON MCP response body")
272 }
273}
274
275fn ids_match(a: &RequestId, b: &RequestId) -> bool {
278 match (a, b) {
279 (RequestId::Number(x), RequestId::Number(y)) => x == y,
280 (RequestId::String(x), RequestId::String(y)) => x == y,
281 (RequestId::Number(n), RequestId::String(s))
282 | (RequestId::String(s), RequestId::Number(n)) => s.parse::<u64>().ok() == Some(*n),
283 }
284}
285
286fn parse_sse_response(body: &str, id: &RequestId) -> Result<JsonRpcResponse> {
295 let mut data_buf = String::new();
296
297 let try_match = |data: &mut String| -> Option<JsonRpcResponse> {
298 if data.is_empty() {
299 return None;
300 }
301 let raw = std::mem::take(data);
302 let trimmed = raw.trim();
303 if let Ok(value) = serde_json::from_str::<serde_json::Value>(trimmed)
307 && value.get("method").is_some()
308 {
309 return None;
310 }
311 if let Ok(resp) = serde_json::from_str::<JsonRpcResponse>(trimmed)
312 && ids_match(&resp.id, id)
313 {
314 return Some(resp);
315 }
316 None
317 };
318
319 for line in body.lines() {
320 let line = line.trim_end_matches('\r');
321 if line.is_empty() {
322 if let Some(resp) = try_match(&mut data_buf) {
324 return Ok(resp);
325 }
326 continue;
327 }
328 if let Some(rest) = line.strip_prefix("data:") {
330 let rest = rest.strip_prefix(' ').unwrap_or(rest);
331 if !data_buf.is_empty() {
332 data_buf.push('\n');
333 }
334 data_buf.push_str(rest);
335 }
336 }
338 if let Some(resp) = try_match(&mut data_buf) {
340 return Ok(resp);
341 }
342
343 bail!("SSE stream contained no JSON-RPC response matching the request id")
344}
345
346#[async_trait]
347impl McpTransport for StreamableHttpTransport {
348 async fn send(&self, mut request: JsonRpcRequest) -> Result<JsonRpcResponse> {
349 let id = self.next_request_id();
350 request.id = RequestId::Number(id);
351 let request_id = request.id.clone();
352
353 let body = serde_json::to_string(&request).context("failed to serialize MCP request")?;
354 let http_request = self.build_http_request(body).await;
355 let reply = tokio::time::timeout(SEND_DEADLINE, self.poster.post(http_request))
357 .await
358 .context("MCP HTTP request timed out")??;
359 self.capture_session_id(&reply).await;
360
361 let response = parse_reply(&reply, &request_id)?;
362
363 if let Some(ref error) = response.error {
364 bail!("JSON-RPC error {}: {}", error.code, error.message);
365 }
366 Ok(response)
367 }
368
369 async fn send_notification(&self, mut request: JsonRpcRequest) -> Result<()> {
370 let id = self.next_request_id();
374 request.id = RequestId::Number(id);
375 let body = notification_body(&request)?;
376 let http_request = self.build_http_request(body).await;
377 let reply = tokio::time::timeout(SEND_DEADLINE, self.poster.post(http_request))
378 .await
379 .context("MCP HTTP request timed out")??;
380 self.capture_session_id(&reply).await;
381 Ok(())
382 }
383
384 async fn set_protocol_version(&self, version: &str) {
385 let mut guard = self.protocol_version.write().await;
386 *guard = Some(version.to_string());
387 }
388
389 async fn close(&self) -> Result<()> {
390 Ok(())
391 }
392}
393
394pub struct ReqwestPoster {
396 client: reqwest::Client,
397 endpoint: String,
398}
399
400impl ReqwestPoster {
401 pub fn new(endpoint: impl Into<String>) -> Result<Self> {
412 let client = reqwest::Client::builder()
413 .timeout(DEFAULT_HTTP_TIMEOUT)
414 .build()
415 .context("failed to build MCP HTTP client")?;
416 Ok(Self {
417 client,
418 endpoint: endpoint.into(),
419 })
420 }
421
422 #[must_use]
424 pub fn with_client(client: reqwest::Client, endpoint: impl Into<String>) -> Self {
425 Self {
426 client,
427 endpoint: endpoint.into(),
428 }
429 }
430}
431
432#[async_trait]
433impl HttpPoster for ReqwestPoster {
434 async fn post(&self, request: HttpRequest) -> Result<HttpReply> {
435 let mut builder = self
436 .client
437 .post(&self.endpoint)
438 .header(
440 reqwest::header::ACCEPT,
441 "application/json, text/event-stream",
442 )
443 .header(reqwest::header::CONTENT_TYPE, "application/json")
444 .body(request.body);
445
446 if let Some(auth) = request.authorization {
447 builder = builder.header(reqwest::header::AUTHORIZATION, auth);
448 }
449 if let Some(sid) = request.session_id {
450 builder = builder.header(SESSION_ID_HEADER, sid);
451 }
452 if let Some(version) = request.protocol_version {
453 builder = builder.header(PROTOCOL_VERSION_HEADER, version);
454 }
455 for (name, value) in request.extra_headers {
456 builder = builder.header(name, value);
457 }
458
459 let mut response = builder
460 .send()
461 .await
462 .context("MCP HTTP request failed to send")?;
463
464 let status = response.status();
465 let session_id = response
466 .headers()
467 .get(SESSION_ID_HEADER)
468 .and_then(|v| v.to_str().ok())
469 .map(ToString::to_string);
470 let content_type = response
471 .headers()
472 .get(reqwest::header::CONTENT_TYPE)
473 .and_then(|v| v.to_str().ok())
474 .map_or_else(
475 || "application/json".to_string(),
476 |s| s.split(';').next().unwrap_or(s).trim().to_lowercase(),
477 );
478
479 let mut body_bytes: Vec<u8> = Vec::new();
482 while let Some(chunk) = response
483 .chunk()
484 .await
485 .context("failed to read MCP HTTP response body")?
486 {
487 if body_bytes.len() + chunk.len() > MAX_RESPONSE_BODY_BYTES {
488 bail!("MCP HTTP response body exceeds {MAX_RESPONSE_BODY_BYTES} bytes");
489 }
490 body_bytes.extend_from_slice(&chunk);
491 }
492 let body = String::from_utf8_lossy(&body_bytes).into_owned();
493
494 if !status.is_success() {
495 bail!("MCP HTTP request returned status {status}: {body}");
496 }
497
498 Ok(HttpReply {
499 content_type,
500 body,
501 session_id,
502 })
503 }
504}
505
#[cfg(test)]
mod tests {
    use super::*;

    /// Serializes a successful JSON-RPC response with the given `id` and `result`.
    fn ok_response(id: u64, result: &serde_json::Value) -> String {
        serde_json::json!({
            "jsonrpc": "2.0",
            "id": id,
            "result": result,
        })
        .to_string()
    }

    /// A plain `application/json` reply is parsed as one successful response.
    #[test]
    fn parse_json_body() {
        let reply = HttpReply::json(ok_response(1, &serde_json::json!({"ok": true})));
        let resp = parse_reply(&reply, &RequestId::Number(1)).expect("parse");
        assert!(!resp.is_error());
        assert!(resp.result().is_some());
    }

    /// A single SSE event is decoded; the `event:` line does not interfere.
    #[test]
    fn parse_sse_single_event() {
        let body = format!(
            "event: message\ndata: {}\n\n",
            ok_response(2, &serde_json::json!({}))
        );
        let reply = HttpReply::event_stream(body);
        let resp = parse_reply(&reply, &RequestId::Number(2)).expect("parse");
        assert_eq!(resp.id, RequestId::Number(2));
    }

    /// Events answering other ids are skipped until the matching one appears.
    #[test]
    fn parse_sse_skips_non_matching_then_matches() {
        let body = format!(
            // The first event answers id 99 and must be skipped; the second
            // answers the requested id 3.
            "data: {}\n\ndata: {}\n\n",
            ok_response(99, &serde_json::json!({"unrelated": true})),
            ok_response(3, &serde_json::json!({"answer": 42})),
        );
        let reply = HttpReply::event_stream(body);
        let resp = parse_reply(&reply, &RequestId::Number(3)).expect("parse");
        assert_eq!(resp.id, RequestId::Number(3));
    }

    /// One JSON document split across several `data:` lines is re-joined.
    #[test]
    fn parse_sse_multiline_data() {
        // The three `data:` lines of a single event concatenate (joined by
        // `\n`) into `{"jsonrpc":"2.0",\n"id":4,\n"result":{}}`.
        let body = "data: {\"jsonrpc\":\"2.0\",\ndata: \"id\":4,\ndata: \"result\":{}}\n\n";
        let reply = HttpReply::event_stream(body.to_string());
        let resp = parse_reply(&reply, &RequestId::Number(4)).expect("parse");
        assert_eq!(resp.id, RequestId::Number(4));
    }

    /// `McpAuth` renders the expected `Authorization` header values.
    #[test]
    fn bearer_auth_header_value() {
        assert_eq!(McpAuth::None.header_value(), None);
        assert_eq!(
            McpAuth::Bearer("tok".to_string()).header_value().as_deref(),
            Some("Bearer tok"),
        );
    }

    /// An SSE stream that never answers the requested id is an error; the
    /// parser must not fall back to some other event.
    #[test]
    fn parse_sse_no_matching_id_is_error() {
        let body = format!(
            "data: {}\n\n",
            ok_response(99, &serde_json::json!({"x": 1}))
        );
        let reply = HttpReply::event_stream(body);
        let result = parse_reply(&reply, &RequestId::Number(3));
        assert!(
            result.is_err(),
            "a stream with no matching id must error rather than return a fallback"
        );
    }

    /// A server-initiated request that happens to reuse our id (it carries a
    /// `method`) must be skipped in favour of the real response.
    #[test]
    fn parse_sse_skips_server_request_with_method() -> Result<()> {
        let server_request = serde_json::json!({
            "jsonrpc": "2.0",
            "id": 3,
            "method": "sampling/createMessage",
            "params": {},
        })
        .to_string();
        let body = format!(
            "data: {server_request}\n\ndata: {}\n\n",
            ok_response(3, &serde_json::json!({"answer": 42})),
        );
        let reply = HttpReply::event_stream(body);
        let resp = parse_reply(&reply, &RequestId::Number(3))?;
        assert_eq!(resp.id, RequestId::Number(3));
        assert!(
            resp.result().is_some(),
            "must return the real reply, not the server request"
        );
        Ok(())
    }

    /// Numeric ids match their decimal string form in either argument order.
    #[test]
    fn ids_match_coerces_numeric_string() {
        assert!(ids_match(&RequestId::Number(5), &RequestId::Number(5)));
        assert!(ids_match(
            &RequestId::Number(5),
            &RequestId::String("5".to_string())
        ));
        assert!(ids_match(
            &RequestId::String("5".to_string()),
            &RequestId::Number(5)
        ));
        assert!(!ids_match(
            &RequestId::Number(5),
            &RequestId::String("six".to_string())
        ));
        assert!(!ids_match(&RequestId::Number(5), &RequestId::Number(6)));
    }

    /// Test poster that records the last request body it was given and always
    /// answers with a canned success response for id 1.
    struct CapturingPoster {
        last_body: std::sync::Mutex<Option<String>>,
    }

    #[async_trait]
    impl HttpPoster for CapturingPoster {
        async fn post(&self, request: HttpRequest) -> Result<HttpReply> {
            // Recover from poisoning so one failed test cannot cascade.
            *self
                .last_body
                .lock()
                .unwrap_or_else(std::sync::PoisonError::into_inner) = Some(request.body);
            Ok(HttpReply::json(ok_response(1, &serde_json::json!({}))))
        }
    }

    /// Notifications go out without an `id`, even though `send_notification`
    /// assigns one internally before serializing.
    #[tokio::test]
    async fn send_notification_omits_id() -> Result<()> {
        let poster = Arc::new(CapturingPoster {
            last_body: std::sync::Mutex::new(None),
        });
        let transport = StreamableHttpTransport::with_poster(poster.clone(), McpAuth::None);

        transport
            .send_notification(JsonRpcRequest::new("notifications/initialized", None, 0))
            .await?;

        let body = poster
            .last_body
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .clone()
            .context("no body captured")?;
        let value: serde_json::Value = serde_json::from_str(&body)?;
        assert!(
            value.get("id").is_none(),
            "notification must not carry an id, got: {body}"
        );
        assert_eq!(
            value.get("method").and_then(serde_json::Value::as_str),
            Some("notifications/initialized")
        );
        Ok(())
    }

    /// Headers added with `with_header` reach the poster on every request.
    #[tokio::test]
    async fn builder_with_header_is_forwarded() -> Result<()> {
        /// Test poster that records the extra headers of the last request.
        struct HeaderCapturingPoster {
            headers: std::sync::Mutex<Vec<(String, String)>>,
        }

        #[async_trait]
        impl HttpPoster for HeaderCapturingPoster {
            async fn post(&self, request: HttpRequest) -> Result<HttpReply> {
                *self
                    .headers
                    .lock()
                    .unwrap_or_else(std::sync::PoisonError::into_inner) = request.extra_headers;
                // The first request on a fresh transport is assigned id 1, so
                // this canned reply matches it.
                Ok(HttpReply::json(ok_response(1, &serde_json::json!({}))))
            }
        }

        let poster = Arc::new(HeaderCapturingPoster {
            headers: std::sync::Mutex::new(Vec::new()),
        });
        let transport = Arc::new(
            StreamableHttpTransport::with_poster_owned(poster.clone(), McpAuth::None)
                .with_header("X-Tenant-Id", "acme"),
        );

        transport.send(JsonRpcRequest::new("ping", None, 0)).await?;

        let headers = poster
            .headers
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .clone();
        assert!(
            headers
                .iter()
                .any(|(k, v)| k == "X-Tenant-Id" && v == "acme"),
            "custom header set via builder must be forwarded, got: {headers:?}"
        );
        Ok(())
    }
}