1use anyhow::{Context, Result, bail};
26use async_trait::async_trait;
27use std::sync::Arc;
28use std::sync::atomic::{AtomicU64, Ordering};
29use tokio::sync::RwLock;
30
31use super::protocol::{JsonRpcRequest, JsonRpcResponse, RequestId};
32use super::transport::McpTransport;
33
/// HTTP header carrying the MCP session id: read from server replies and echoed back on
/// subsequent requests.
const SESSION_ID_HEADER: &str = "Mcp-Session-Id";
/// HTTP header carrying the protocol version recorded via `set_protocol_version`.
const PROTOCOL_VERSION_HEADER: &str = "MCP-Protocol-Version";
38
/// A raw HTTP reply from an MCP server, as produced by an `HttpPoster`.
#[derive(Clone, Debug)]
pub struct HttpReply {
    /// Media type of `body`. The transport treats it as an SSE stream when it contains
    /// `text/event-stream`, and as a single JSON document otherwise.
    pub content_type: String,
    /// Raw response body.
    pub body: String,
    /// Session id returned by the server (the `Mcp-Session-Id` header), if any.
    pub session_id: Option<String>,
}

impl HttpReply {
    /// Shared constructor: a reply of the given media type that carries no session id.
    fn from_parts(content_type: &str, body: impl Into<String>) -> Self {
        Self {
            content_type: content_type.to_owned(),
            body: body.into(),
            session_id: None,
        }
    }

    /// Builds an `application/json` reply with no session id.
    #[must_use]
    pub fn json(body: impl Into<String>) -> Self {
        Self::from_parts("application/json", body)
    }

    /// Builds a `text/event-stream` (SSE) reply with no session id.
    #[must_use]
    pub fn event_stream(body: impl Into<String>) -> Self {
        Self::from_parts("text/event-stream", body)
    }

    /// Returns this reply with `session_id` attached, replacing any previous value.
    #[must_use]
    pub fn with_session_id(self, session_id: impl Into<String>) -> Self {
        Self {
            session_id: Some(session_id.into()),
            ..self
        }
    }
}
79
/// An outgoing MCP HTTP request, handed to an `HttpPoster` for delivery.
#[derive(Clone, Debug)]
pub struct HttpRequest {
    /// Serialized JSON-RPC message to POST.
    pub body: String,
    /// Complete `Authorization` header value as rendered by `McpAuth` (e.g. `Bearer <token>`),
    /// or `None` when no auth is configured.
    pub authorization: Option<String>,
    /// Session id previously captured from a server reply, if any
    /// (`ReqwestPoster` sends it as `Mcp-Session-Id`).
    pub session_id: Option<String>,
    /// Protocol version to advertise, if one has been set
    /// (`ReqwestPoster` sends it as `MCP-Protocol-Version`).
    pub protocol_version: Option<String>,
    /// Additional `(name, value)` headers to send as-is.
    pub extra_headers: Vec<(String, String)>,
}
94
/// The HTTP client used by `StreamableHttpTransport`.
///
/// The transport accepts any implementation through `with_poster`, so it can be driven by
/// something other than `reqwest` (such as a test double).
#[async_trait]
pub trait HttpPoster: Send + Sync {
    /// Delivers `request` and returns the server's reply.
    ///
    /// The transport parses the reply as SSE when `content_type` contains
    /// `text/event-stream`, and as one JSON body otherwise. It also adopts any returned
    /// `session_id`. `ReqwestPoster` returns an error for non-2xx statuses.
    /// NOTE(review): other implementations presumably should do the same; confirm.
    async fn post(&self, request: HttpRequest) -> Result<HttpReply>;
}
109
/// Authentication attached to every MCP HTTP request.
#[derive(Clone, Debug, Default)]
pub enum McpAuth {
    /// No `Authorization` header is sent (the default).
    #[default]
    None,
    /// Sends `Authorization: Bearer <token>` with the contained token.
    Bearer(String),
}

impl McpAuth {
    /// Renders the `Authorization` header value, or `None` when authentication is disabled.
    #[must_use]
    fn header_value(&self) -> Option<String> {
        match self {
            Self::Bearer(token) => Some(["Bearer ", token.as_str()].concat()),
            Self::None => None,
        }
    }
}
130
/// MCP transport that sends each JSON-RPC message as one HTTP POST through an `HttpPoster`.
/// Replies may be a single JSON body or an SSE stream.
pub struct StreamableHttpTransport {
    /// Performs the HTTP POSTs (a `ReqwestPoster` when constructed via `new`).
    poster: Arc<dyn HttpPoster>,
    /// Authentication rendered into every outgoing request.
    auth: McpAuth,
    /// Extra `(name, value)` headers copied into every outgoing request.
    extra_headers: Vec<(String, String)>,
    /// Counter for outgoing JSON-RPC ids; starts at 1 and is bumped once per message.
    next_id: AtomicU64,
    /// Most recent session id seen in a server reply; attached to later requests.
    session_id: RwLock<Option<String>>,
    /// Protocol version recorded via `set_protocol_version`; attached to later requests.
    protocol_version: RwLock<Option<String>>,
}
146
147impl StreamableHttpTransport {
148 pub fn new(endpoint: impl Into<String>, auth: McpAuth) -> Result<Arc<Self>> {
154 let poster = ReqwestPoster::new(endpoint)?;
155 Ok(Self::with_poster(Arc::new(poster), auth))
156 }
157
158 #[must_use]
163 pub fn with_poster(poster: Arc<dyn HttpPoster>, auth: McpAuth) -> Arc<Self> {
164 Arc::new(Self {
165 poster,
166 auth,
167 extra_headers: Vec::new(),
168 next_id: AtomicU64::new(1),
169 session_id: RwLock::new(None),
170 protocol_version: RwLock::new(None),
171 })
172 }
173
174 #[must_use]
176 pub fn with_header(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
177 self.extra_headers.push((name.into(), value.into()));
178 self
179 }
180
181 fn next_request_id(&self) -> u64 {
182 self.next_id.fetch_add(1, Ordering::SeqCst)
183 }
184
185 async fn build_http_request(&self, body: String) -> HttpRequest {
186 HttpRequest {
187 body,
188 authorization: self.auth.header_value(),
189 session_id: self.session_id.read().await.clone(),
190 protocol_version: self.protocol_version.read().await.clone(),
191 extra_headers: self.extra_headers.clone(),
192 }
193 }
194
195 async fn capture_session_id(&self, reply: &HttpReply) {
197 if let Some(ref sid) = reply.session_id {
198 let mut guard = self.session_id.write().await;
199 if guard.as_deref() != Some(sid.as_str()) {
200 *guard = Some(sid.clone());
201 }
202 }
203 }
204}
205
206fn parse_reply(reply: &HttpReply, id: &RequestId) -> Result<JsonRpcResponse> {
213 if reply.content_type.contains("text/event-stream") {
214 parse_sse_response(&reply.body, id)
215 } else {
216 serde_json::from_str::<JsonRpcResponse>(reply.body.trim())
217 .context("failed to parse JSON MCP response body")
218 }
219}
220
221fn parse_sse_response(body: &str, id: &RequestId) -> Result<JsonRpcResponse> {
223 let mut data_buf = String::new();
224 let mut last_parsed: Option<JsonRpcResponse> = None;
225
226 let flush =
227 |data: &mut String, last: &mut Option<JsonRpcResponse>| -> Option<JsonRpcResponse> {
228 if data.is_empty() {
229 return None;
230 }
231 let raw = std::mem::take(data);
232 if let Ok(resp) = serde_json::from_str::<JsonRpcResponse>(raw.trim()) {
233 if &resp.id == id {
234 return Some(resp);
235 }
236 *last = Some(resp);
237 }
238 None
239 };
240
241 for line in body.lines() {
242 let line = line.trim_end_matches('\r');
243 if line.is_empty() {
244 if let Some(resp) = flush(&mut data_buf, &mut last_parsed) {
246 return Ok(resp);
247 }
248 continue;
249 }
250 if let Some(rest) = line.strip_prefix("data:") {
252 let rest = rest.strip_prefix(' ').unwrap_or(rest);
253 if !data_buf.is_empty() {
254 data_buf.push('\n');
255 }
256 data_buf.push_str(rest);
257 }
258 }
260 if let Some(resp) = flush(&mut data_buf, &mut last_parsed) {
262 return Ok(resp);
263 }
264
265 last_parsed.context("SSE stream contained no JSON-RPC response matching the request id")
266}
267
268#[async_trait]
269impl McpTransport for StreamableHttpTransport {
270 async fn send(&self, mut request: JsonRpcRequest) -> Result<JsonRpcResponse> {
271 let id = self.next_request_id();
272 request.id = RequestId::Number(id);
273 let request_id = request.id.clone();
274
275 let body = serde_json::to_string(&request).context("failed to serialize MCP request")?;
276 let http_request = self.build_http_request(body).await;
277 let reply = self.poster.post(http_request).await?;
278 self.capture_session_id(&reply).await;
279
280 let response = parse_reply(&reply, &request_id)?;
281
282 if let Some(ref error) = response.error {
283 bail!("JSON-RPC error {}: {}", error.code, error.message);
284 }
285 Ok(response)
286 }
287
288 async fn send_notification(&self, mut request: JsonRpcRequest) -> Result<()> {
289 let id = self.next_request_id();
292 request.id = RequestId::Number(id);
293 let body = serde_json::to_string(&request).context("failed to serialize MCP request")?;
294 let http_request = self.build_http_request(body).await;
295 let reply = self.poster.post(http_request).await?;
296 self.capture_session_id(&reply).await;
297 Ok(())
298 }
299
300 async fn set_protocol_version(&self, version: &str) {
301 let mut guard = self.protocol_version.write().await;
302 *guard = Some(version.to_string());
303 }
304
305 async fn close(&self) -> Result<()> {
306 Ok(())
307 }
308}
309
/// Default `HttpPoster`, backed by a `reqwest::Client` that POSTs to a fixed endpoint URL.
pub struct ReqwestPoster {
    /// HTTP client used for every request.
    client: reqwest::Client,
    /// URL that every request is POSTed to.
    endpoint: String,
}
315
316impl ReqwestPoster {
317 pub fn new(endpoint: impl Into<String>) -> Result<Self> {
323 let client = reqwest::Client::builder()
324 .build()
325 .context("failed to build MCP HTTP client")?;
326 Ok(Self {
327 client,
328 endpoint: endpoint.into(),
329 })
330 }
331
332 #[must_use]
334 pub fn with_client(client: reqwest::Client, endpoint: impl Into<String>) -> Self {
335 Self {
336 client,
337 endpoint: endpoint.into(),
338 }
339 }
340}
341
342#[async_trait]
343impl HttpPoster for ReqwestPoster {
344 async fn post(&self, request: HttpRequest) -> Result<HttpReply> {
345 let mut builder = self
346 .client
347 .post(&self.endpoint)
348 .header(
350 reqwest::header::ACCEPT,
351 "application/json, text/event-stream",
352 )
353 .header(reqwest::header::CONTENT_TYPE, "application/json")
354 .body(request.body);
355
356 if let Some(auth) = request.authorization {
357 builder = builder.header(reqwest::header::AUTHORIZATION, auth);
358 }
359 if let Some(sid) = request.session_id {
360 builder = builder.header(SESSION_ID_HEADER, sid);
361 }
362 if let Some(version) = request.protocol_version {
363 builder = builder.header(PROTOCOL_VERSION_HEADER, version);
364 }
365 for (name, value) in request.extra_headers {
366 builder = builder.header(name, value);
367 }
368
369 let response = builder
370 .send()
371 .await
372 .context("MCP HTTP request failed to send")?;
373
374 let status = response.status();
375 let session_id = response
376 .headers()
377 .get(SESSION_ID_HEADER)
378 .and_then(|v| v.to_str().ok())
379 .map(ToString::to_string);
380 let content_type = response
381 .headers()
382 .get(reqwest::header::CONTENT_TYPE)
383 .and_then(|v| v.to_str().ok())
384 .map_or_else(
385 || "application/json".to_string(),
386 |s| s.split(';').next().unwrap_or(s).trim().to_lowercase(),
387 );
388
389 let body = response
390 .text()
391 .await
392 .context("failed to read MCP HTTP response body")?;
393
394 if !status.is_success() {
395 bail!("MCP HTTP request returned status {status}: {body}");
396 }
397
398 Ok(HttpReply {
399 content_type,
400 body,
401 session_id,
402 })
403 }
404}
405
#[cfg(test)]
mod tests {
    use super::*;

    /// Serializes a successful JSON-RPC 2.0 response with numeric `id` and the given `result`.
    fn ok_response(id: u64, result: &serde_json::Value) -> String {
        let envelope = serde_json::json!({
            "jsonrpc": "2.0",
            "id": id,
            "result": result,
        });
        envelope.to_string()
    }

    /// Runs `parse_reply` for numeric request `id`, panicking if parsing fails.
    fn parse_expecting(reply: &HttpReply, id: u64) -> JsonRpcResponse {
        parse_reply(reply, &RequestId::Number(id)).expect("parse")
    }

    #[test]
    fn parse_json_body() {
        let reply = HttpReply::json(ok_response(1, &serde_json::json!({"ok": true})));
        let resp = parse_expecting(&reply, 1);
        assert!(!resp.is_error());
        assert!(resp.result().is_some());
    }

    #[test]
    fn parse_sse_single_event() {
        let payload = ok_response(2, &serde_json::json!({}));
        let reply = HttpReply::event_stream(format!("event: message\ndata: {payload}\n\n"));
        assert_eq!(parse_expecting(&reply, 2).id, RequestId::Number(2));
    }

    #[test]
    fn parse_sse_skips_non_matching_then_matches() {
        // An unrelated response arrives first; the parser must keep looking for id 3.
        let unrelated = ok_response(99, &serde_json::json!({"unrelated": true}));
        let wanted = ok_response(3, &serde_json::json!({"answer": 42}));
        let reply = HttpReply::event_stream(format!("data: {unrelated}\n\ndata: {wanted}\n\n"));
        assert_eq!(parse_expecting(&reply, 3).id, RequestId::Number(3));
    }

    #[test]
    fn parse_sse_multiline_data() {
        // A single JSON object split across three `data:` lines of one event.
        let body = "data: {\"jsonrpc\":\"2.0\",\ndata: \"id\":4,\ndata: \"result\":{}}\n\n";
        let reply = HttpReply::event_stream(body.to_string());
        assert_eq!(parse_expecting(&reply, 4).id, RequestId::Number(4));
    }

    #[test]
    fn bearer_auth_header_value() {
        assert_eq!(McpAuth::None.header_value(), None);
        let bearer = McpAuth::Bearer("tok".to_string());
        assert_eq!(bearer.header_value().as_deref(), Some("Bearer tok"));
    }
}