Skip to main content

bamboo_engine/mcp/transports/
streamable_http.rs

1//! MCP Streamable HTTP transport (MCP protocol version 2025-03-26).
2//!
3//! Implements the "Streamable HTTP" transport where a single HTTP endpoint
4//! handles POST (send), GET (server-initiated SSE), and DELETE (session termination).
5//! See <https://modelcontextprotocol.io/specification/2025-03-26/basic/transports>
6
7use async_trait::async_trait;
8use eventsource_stream::Eventsource;
9use futures::StreamExt;
10use reqwest::header::{HeaderMap, HeaderValue};
11use reqwest::Client;
12use std::sync::atomic::{AtomicBool, Ordering};
13use std::sync::Arc;
14use tokio::sync::{mpsc, Mutex};
15use tracing::{debug, trace, warn};
16
17use crate::mcp::config::{HeaderConfig, StreamableHttpConfig};
18use crate::mcp::error::{McpError, Result};
19use crate::mcp::protocol::client::McpTransport;
20
/// Default `Accept` value: the server may answer with a JSON body or an SSE stream.
const ACCEPT_HEADER: &str = "application/json, text/event-stream";
/// Header carrying the MCP session id, sent on requests and read from responses.
const MCP_SESSION_ID_HEADER: &str = "mcp-session-id";
23
24pub struct StreamableHttpTransport {
25    config: StreamableHttpConfig,
26    client: Client,
27    session_id: Arc<Mutex<Option<String>>>,
28    connected: Arc<AtomicBool>,
29    message_tx: mpsc::Sender<String>,
30    message_rx: Mutex<mpsc::Receiver<String>>,
31    get_sse_handle: Mutex<Option<tokio::task::JoinHandle<()>>>,
32}
33
34impl StreamableHttpTransport {
35    pub fn new(config: StreamableHttpConfig) -> Self {
36        Self::new_with_client(config, Client::new())
37    }
38
39    pub fn new_with_client(config: StreamableHttpConfig, client: Client) -> Self {
40        let (message_tx, message_rx) = mpsc::channel(256);
41        Self {
42            config,
43            client,
44            session_id: Arc::new(Mutex::new(None)),
45            connected: Arc::new(AtomicBool::new(false)),
46            message_tx,
47            message_rx: Mutex::new(message_rx),
48            get_sse_handle: Mutex::new(None),
49        }
50    }
51
52    fn build_headers(&self, include_session_id: bool) -> Result<HeaderMap> {
53        let mut headers = HeaderMap::new();
54        headers.insert(
55            reqwest::header::ACCEPT,
56            HeaderValue::from_static(ACCEPT_HEADER),
57        );
58        headers.insert(
59            reqwest::header::CONTENT_TYPE,
60            HeaderValue::from_static("application/json"),
61        );
62
63        for HeaderConfig { name, value, .. } in &self.config.headers {
64            let header_name = reqwest::header::HeaderName::from_bytes(name.as_bytes())
65                .map_err(|e| McpError::InvalidConfig(format!("Invalid header name: {}", e)))?;
66            let header_value = value
67                .parse()
68                .map_err(|e| McpError::InvalidConfig(format!("Invalid header value: {}", e)))?;
69            headers.insert(header_name, header_value);
70        }
71
72        if include_session_id {
73            // Session id is added per-request in send() after reading the lock.
74        }
75
76        Ok(headers)
77    }
78
79    fn redact_url_for_log(url: &str) -> String {
80        match reqwest::Url::parse(url) {
81            Ok(mut parsed) => {
82                parsed.set_query(None);
83                parsed.set_fragment(None);
84                parsed.to_string()
85            }
86            Err(_) => url.to_string(),
87        }
88    }
89
90    /// POST a message to the MCP endpoint and route any response(s) to the
91    /// message channel. Returns Ok(()) if the POST was accepted (202) or a
92    /// response was successfully forwarded.
93    async fn post_and_route_response(
94        &self,
95        message: String,
96        session_id: Option<String>,
97    ) -> Result<()> {
98        let mut headers = self.build_headers(true)?;
99
100        if let Some(sid) = session_id {
101            let value = HeaderValue::from_str(&sid)
102                .map_err(|e| McpError::Transport(format!("Invalid session id: {}", e)))?;
103            headers.insert(MCP_SESSION_ID_HEADER, value);
104        }
105
106        trace!(
107            "MCP StreamableHTTP POST (url={}, bytes={})",
108            Self::redact_url_for_log(&self.config.url),
109            message.len()
110        );
111
112        let response = tokio::time::timeout(
113            tokio::time::Duration::from_secs(60),
114            self.client
115                .post(&self.config.url)
116                .headers(headers)
117                .body(message)
118                .send(),
119        )
120        .await
121        .map_err(|_| McpError::Timeout("POST request timed out".to_string()))??;
122
123        let status = response.status();
124
125        // Extract session id from response if present.
126        if let Some(sid) = response.headers().get(MCP_SESSION_ID_HEADER) {
127            let sid_str = sid
128                .to_str()
129                .map_err(|e| McpError::Transport(format!("Invalid session id header: {}", e)))?;
130            let mut guard = self.session_id.lock().await;
131            guard.get_or_insert_with(|| sid_str.to_string());
132        }
133
134        if status == reqwest::StatusCode::ACCEPTED {
135            // Server accepted notification/response, no body.
136            trace!("MCP StreamableHTTP POST accepted (202)");
137            return Ok(());
138        }
139
140        if !status.is_success() {
141            let body = response.text().await.unwrap_or_default();
142            return Err(McpError::Transport(format!(
143                "POST failed: {} - {}",
144                status, body
145            )));
146        }
147
148        let content_type = response
149            .headers()
150            .get(reqwest::header::CONTENT_TYPE)
151            .and_then(|v| v.to_str().ok())
152            .unwrap_or("");
153
154        if content_type.contains("text/event-stream") {
155            // SSE response — parse events and forward each to channel.
156            trace!("MCP StreamableHTTP POST response is SSE stream");
157            let tx = self.message_tx.clone();
158            let url = self.config.url.clone();
159            let connected = self.connected.clone();
160
161            // We need to consume the response body in a spawned task to avoid
162            // blocking the caller. Events from this POST's SSE response are
163            // forwarded to the channel so receive() can pick them up.
164            tokio::spawn(async move {
165                let mut stream = response.bytes_stream().eventsource();
166                while let Some(event) = stream.next().await {
167                    match event {
168                        Ok(evt) => {
169                            if !evt.data.trim().is_empty() {
170                                trace!(
171                                    "MCP StreamableHTTP POST SSE event (event='{}', data_len={})",
172                                    evt.event,
173                                    evt.data.len()
174                                );
175                                if tx.send(evt.data).await.is_err() {
176                                    break;
177                                }
178                            }
179                        }
180                        Err(e) => {
181                            warn!("MCP StreamableHTTP POST SSE error: {}", e);
182                            break;
183                        }
184                    }
185                }
186                let _ = (url, connected); // suppress unused warnings
187            });
188        } else {
189            // JSON response — forward the body directly.
190            let body = response.text().await?;
191            if !body.trim().is_empty() {
192                trace!(
193                    "MCP StreamableHTTP POST response is JSON (bytes={})",
194                    body.len()
195                );
196                if self.message_tx.send(body).await.is_err() {
197                    warn!("MCP StreamableHTTP: message channel closed");
198                }
199            }
200        }
201
202        Ok(())
203    }
204
205    /// Attempt to open a GET SSE stream for server-initiated messages.
206    /// Per spec, the server MAY return 405 if it doesn't support this.
207    async fn start_get_sse_stream(&self) {
208        let mut headers = self.build_headers(true).unwrap_or_default();
209        headers.insert(
210            reqwest::header::ACCEPT,
211            HeaderValue::from_static("text/event-stream"),
212        );
213
214        // Add session id if available.
215        {
216            let sid = self.session_id.lock().await;
217            if let Some(sid) = sid.as_ref() {
218                if let Ok(value) = HeaderValue::from_str(sid) {
219                    headers.insert(MCP_SESSION_ID_HEADER, value);
220                }
221            }
222        }
223
224        trace!(
225            "MCP StreamableHTTP GET SSE stream (url={})",
226            Self::redact_url_for_log(&self.config.url)
227        );
228
229        let response = match self
230            .client
231            .get(&self.config.url)
232            .headers(headers)
233            .send()
234            .await
235        {
236            Ok(r) => r,
237            Err(e) => {
238                debug!("MCP StreamableHTTP GET SSE stream failed: {}", e);
239                return;
240            }
241        };
242
243        if response.status() == reqwest::StatusCode::METHOD_NOT_ALLOWED {
244            debug!("MCP StreamableHTTP server does not support GET SSE stream (405)");
245            return;
246        }
247
248        if !response.status().is_success() {
249            debug!(
250                "MCP StreamableHTTP GET SSE stream returned: {}",
251                response.status()
252            );
253            return;
254        }
255
256        // Extract session id from GET response if present.
257        if let Some(sid) = response.headers().get(MCP_SESSION_ID_HEADER) {
258            if let Ok(sid_str) = sid.to_str() {
259                let mut guard = self.session_id.lock().await;
260                guard.get_or_insert_with(|| sid_str.to_string());
261            }
262        }
263
264        debug!("MCP StreamableHTTP GET SSE stream opened");
265
266        let tx = self.message_tx.clone();
267        let connected = self.connected.clone();
268
269        let handle = tokio::spawn(async move {
270            let mut stream = response.bytes_stream().eventsource();
271            while let Some(event) = stream.next().await {
272                match event {
273                    Ok(evt) => {
274                        if !evt.data.trim().is_empty() {
275                            trace!(
276                                "MCP StreamableHTTP GET SSE event (event='{}', data_len={})",
277                                evt.event,
278                                evt.data.len()
279                            );
280                            if tx.send(evt.data).await.is_err() {
281                                break;
282                            }
283                        }
284                    }
285                    Err(e) => {
286                        warn!("MCP StreamableHTTP GET SSE error: {}", e);
287                        break;
288                    }
289                }
290            }
291            connected.store(false, Ordering::SeqCst);
292        });
293
294        let mut guard = self.get_sse_handle.lock().await;
295        *guard = Some(handle);
296    }
297}
298
299#[async_trait]
300impl McpTransport for StreamableHttpTransport {
301    async fn connect(&mut self) -> Result<()> {
302        debug!(
303            "Connecting to MCP StreamableHTTP endpoint: {} (connect_timeout_ms={})",
304            Self::redact_url_for_log(&self.config.url),
305            self.config.connect_timeout_ms
306        );
307
308        // Streamable HTTP doesn't have a separate "connect" step in the traditional
309        // sense. The first request (initialize) will be sent via send(). Here we
310        // just mark the transport as ready and optionally open a GET SSE stream.
311        self.connected.store(true, Ordering::SeqCst);
312
313        debug!("MCP StreamableHTTP transport ready");
314        Ok(())
315    }
316
317    async fn disconnect(&mut self) -> Result<()> {
318        debug!("Disconnecting MCP StreamableHTTP transport");
319
320        self.connected.store(false, Ordering::SeqCst);
321
322        // Cancel the GET SSE stream background task.
323        {
324            let mut guard = self.get_sse_handle.lock().await;
325            if let Some(handle) = guard.take() {
326                handle.abort();
327            }
328        }
329
330        // Send DELETE to terminate session (best-effort).
331        {
332            let sid = self.session_id.lock().await;
333            if let Some(session_id) = sid.as_ref() {
334                let mut headers = self.build_headers(false)?;
335                if let Ok(value) = HeaderValue::from_str(session_id) {
336                    headers.insert(MCP_SESSION_ID_HEADER, value);
337                }
338
339                trace!(
340                    "MCP StreamableHTTP DELETE session (url={})",
341                    Self::redact_url_for_log(&self.config.url)
342                );
343                let _ = self
344                    .client
345                    .delete(&self.config.url)
346                    .headers(headers)
347                    .send()
348                    .await;
349            }
350        }
351
352        // Clear session id.
353        {
354            let mut guard = self.session_id.lock().await;
355            *guard = None;
356        }
357
358        debug!("MCP StreamableHTTP transport disconnected");
359        Ok(())
360    }
361
362    async fn send(&self, message: String) -> Result<()> {
363        if !self.is_connected() {
364            return Err(McpError::Disconnected);
365        }
366
367        let session_id = self.session_id.lock().await.clone();
368
369        self.post_and_route_response(message, session_id).await?;
370
371        // After the first successful exchange, try to open the GET SSE stream
372        // for server-initiated messages (if not already opened).
373        {
374            let guard = self.get_sse_handle.lock().await;
375            if guard.is_none() {
376                // Don't hold the lock while starting the stream.
377                drop(guard);
378                self.start_get_sse_stream().await;
379            }
380        }
381
382        Ok(())
383    }
384
385    async fn receive(&self) -> Result<Option<String>> {
386        if !self.is_connected() {
387            return Err(McpError::Disconnected);
388        }
389
390        let mut rx = self.message_rx.lock().await;
391        match tokio::time::timeout(tokio::time::Duration::from_millis(100), rx.recv()).await {
392            Ok(Some(message)) => {
393                trace!(
394                    "MCP StreamableHTTP received message (bytes={})",
395                    message.len()
396                );
397                Ok(Some(message))
398            }
399            Ok(None) => {
400                warn!("MCP StreamableHTTP message channel closed");
401                Err(McpError::Disconnected)
402            }
403            Err(_) => Ok(None), // Timeout, no message available
404        }
405    }
406
407    fn is_connected(&self) -> bool {
408        self.connected.load(Ordering::SeqCst)
409    }
410}
411
#[cfg(test)]
mod tests {
    use super::*;

    /// Config for a local endpoint with no custom headers.
    fn create_test_config() -> StreamableHttpConfig {
        StreamableHttpConfig {
            url: "http://localhost:3000/mcp".to_string(),
            headers: vec![],
            connect_timeout_ms: 5000,
        }
    }

    /// Config for a local endpoint whose only custom header is `name: value`.
    fn config_with_header(name: &str, value: &str) -> StreamableHttpConfig {
        StreamableHttpConfig {
            url: "http://localhost:3000/mcp".to_string(),
            headers: vec![HeaderConfig {
                name: name.to_string(),
                value: value.to_string(),
                value_encrypted: None,
            }],
            connect_timeout_ms: 5000,
        }
    }

    #[test]
    fn test_transport_new() {
        let transport = StreamableHttpTransport::new(create_test_config());
        assert!(!transport.is_connected());
    }

    #[test]
    fn test_build_headers_basic() {
        let transport = StreamableHttpTransport::new(create_test_config());
        let headers = transport.build_headers(false).unwrap();

        assert_eq!(headers.get(reqwest::header::ACCEPT).unwrap(), ACCEPT_HEADER);
        assert_eq!(
            headers.get(reqwest::header::CONTENT_TYPE).unwrap(),
            "application/json"
        );
    }

    #[test]
    fn test_build_headers_with_custom() {
        let transport =
            StreamableHttpTransport::new(config_with_header("Authorization", "Bearer token123"));
        let headers = transport.build_headers(false).unwrap();

        assert_eq!(headers.get("authorization").unwrap(), "Bearer token123");
    }

    #[test]
    fn test_build_headers_custom_overrides_default() {
        let transport = StreamableHttpTransport::new(config_with_header("Accept", "text/plain"));
        let headers = transport.build_headers(false).unwrap();

        assert_eq!(headers.get(reqwest::header::ACCEPT).unwrap(), "text/plain");
    }

    #[test]
    fn test_build_headers_invalid_name() {
        let transport = StreamableHttpTransport::new(config_with_header("Invalid\nName", "test"));
        assert!(matches!(
            transport.build_headers(false),
            Err(McpError::InvalidConfig(_))
        ));
    }

    #[test]
    fn test_build_headers_invalid_value() {
        let transport = StreamableHttpTransport::new(config_with_header("X-Test", "bad\nvalue"));
        assert!(matches!(
            transport.build_headers(false),
            Err(McpError::InvalidConfig(_))
        ));
    }

    #[test]
    fn test_redact_url() {
        assert_eq!(
            StreamableHttpTransport::redact_url_for_log("http://example.com/mcp?token=secret"),
            "http://example.com/mcp"
        );
    }

    #[tokio::test]
    async fn test_send_disconnected() {
        let transport = StreamableHttpTransport::new(create_test_config());

        match transport.send("{}".to_string()).await {
            Err(McpError::Disconnected) => {}
            other => panic!("Expected Disconnected, got: {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_receive_disconnected() {
        let transport = StreamableHttpTransport::new(create_test_config());

        match transport.receive().await {
            Err(McpError::Disconnected) => {}
            other => panic!("Expected Disconnected, got: {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_connect_disconnect() {
        let mut transport = StreamableHttpTransport::new(create_test_config());

        transport.connect().await.unwrap();
        assert!(transport.is_connected());

        transport.disconnect().await.unwrap();
        assert!(!transport.is_connected());
    }

    #[tokio::test]
    async fn test_receive_timeout() {
        let transport = StreamableHttpTransport::new(create_test_config());
        transport.connected.store(true, Ordering::SeqCst);

        let result = transport.receive().await;
        assert!(result.is_ok());
        assert!(result.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_receive_returns_queued_message() {
        let transport = StreamableHttpTransport::new(create_test_config());
        transport.connected.store(true, Ordering::SeqCst);

        transport
            .message_tx
            .send(r#"{"jsonrpc":"2.0"}"#.to_string())
            .await
            .unwrap();

        let message = transport.receive().await.unwrap();
        assert_eq!(message.as_deref(), Some(r#"{"jsonrpc":"2.0"}"#));
    }

    /// Only exercises the session-id slot itself; no HTTP response is involved.
    #[tokio::test]
    async fn test_session_id_slot_round_trip() {
        let transport = StreamableHttpTransport::new(create_test_config());
        transport.connected.store(true, Ordering::SeqCst);

        *transport.session_id.lock().await = Some("test-session-123".to_string());

        let sid = transport.session_id.lock().await;
        assert_eq!(sid.as_deref(), Some("test-session-123"));
    }

    #[tokio::test]
    async fn test_disconnect_clears_session() {
        let mut transport = StreamableHttpTransport::new(create_test_config());
        transport.connect().await.unwrap();

        *transport.session_id.lock().await = Some("test-session".to_string());

        transport.disconnect().await.unwrap();

        let sid = transport.session_id.lock().await;
        assert!(sid.is_none());
    }
}