bamboo_engine/mcp/transports/
streamable_http.rs1use async_trait::async_trait;
8use eventsource_stream::Eventsource;
9use futures::StreamExt;
10use reqwest::header::{HeaderMap, HeaderValue};
11use reqwest::Client;
12use std::sync::atomic::{AtomicBool, Ordering};
13use std::sync::Arc;
14use tokio::sync::{mpsc, Mutex};
15use tracing::{debug, trace, warn};
16
17use crate::mcp::config::{HeaderConfig, StreamableHttpConfig};
18use crate::mcp::error::{McpError, Result};
19use crate::mcp::protocol::client::McpTransport;
20
/// Header that carries the server-assigned session id, read from responses
/// and echoed back on subsequent requests.
const MCP_SESSION_ID_HEADER: &str = "mcp-session-id";

/// `Accept` value sent with POSTs: the response may be a plain JSON body or an
/// SSE (`text/event-stream`) stream.
const ACCEPT_HEADER: &str = "application/json, text/event-stream";
23
24pub struct StreamableHttpTransport {
25 config: StreamableHttpConfig,
26 client: Client,
27 session_id: Arc<Mutex<Option<String>>>,
28 connected: Arc<AtomicBool>,
29 message_tx: mpsc::Sender<String>,
30 message_rx: Mutex<mpsc::Receiver<String>>,
31 get_sse_handle: Mutex<Option<tokio::task::JoinHandle<()>>>,
32}
33
34impl StreamableHttpTransport {
35 pub fn new(config: StreamableHttpConfig) -> Self {
36 Self::new_with_client(config, Client::new())
37 }
38
39 pub fn new_with_client(config: StreamableHttpConfig, client: Client) -> Self {
40 let (message_tx, message_rx) = mpsc::channel(256);
41 Self {
42 config,
43 client,
44 session_id: Arc::new(Mutex::new(None)),
45 connected: Arc::new(AtomicBool::new(false)),
46 message_tx,
47 message_rx: Mutex::new(message_rx),
48 get_sse_handle: Mutex::new(None),
49 }
50 }
51
52 fn build_headers(&self, include_session_id: bool) -> Result<HeaderMap> {
53 let mut headers = HeaderMap::new();
54 headers.insert(
55 reqwest::header::ACCEPT,
56 HeaderValue::from_static(ACCEPT_HEADER),
57 );
58 headers.insert(
59 reqwest::header::CONTENT_TYPE,
60 HeaderValue::from_static("application/json"),
61 );
62
63 for HeaderConfig { name, value, .. } in &self.config.headers {
64 let header_name = reqwest::header::HeaderName::from_bytes(name.as_bytes())
65 .map_err(|e| McpError::InvalidConfig(format!("Invalid header name: {}", e)))?;
66 let header_value = value
67 .parse()
68 .map_err(|e| McpError::InvalidConfig(format!("Invalid header value: {}", e)))?;
69 headers.insert(header_name, header_value);
70 }
71
72 if include_session_id {
73 }
75
76 Ok(headers)
77 }
78
79 fn redact_url_for_log(url: &str) -> String {
80 match reqwest::Url::parse(url) {
81 Ok(mut parsed) => {
82 parsed.set_query(None);
83 parsed.set_fragment(None);
84 parsed.to_string()
85 }
86 Err(_) => url.to_string(),
87 }
88 }
89
90 async fn post_and_route_response(
94 &self,
95 message: String,
96 session_id: Option<String>,
97 ) -> Result<()> {
98 let mut headers = self.build_headers(true)?;
99
100 if let Some(sid) = session_id {
101 let value = HeaderValue::from_str(&sid)
102 .map_err(|e| McpError::Transport(format!("Invalid session id: {}", e)))?;
103 headers.insert(MCP_SESSION_ID_HEADER, value);
104 }
105
106 trace!(
107 "MCP StreamableHTTP POST (url={}, bytes={})",
108 Self::redact_url_for_log(&self.config.url),
109 message.len()
110 );
111
112 let response = tokio::time::timeout(
113 tokio::time::Duration::from_secs(60),
114 self.client
115 .post(&self.config.url)
116 .headers(headers)
117 .body(message)
118 .send(),
119 )
120 .await
121 .map_err(|_| McpError::Timeout("POST request timed out".to_string()))??;
122
123 let status = response.status();
124
125 if let Some(sid) = response.headers().get(MCP_SESSION_ID_HEADER) {
127 let sid_str = sid
128 .to_str()
129 .map_err(|e| McpError::Transport(format!("Invalid session id header: {}", e)))?;
130 let mut guard = self.session_id.lock().await;
131 guard.get_or_insert_with(|| sid_str.to_string());
132 }
133
134 if status == reqwest::StatusCode::ACCEPTED {
135 trace!("MCP StreamableHTTP POST accepted (202)");
137 return Ok(());
138 }
139
140 if !status.is_success() {
141 let body = response.text().await.unwrap_or_default();
142 return Err(McpError::Transport(format!(
143 "POST failed: {} - {}",
144 status, body
145 )));
146 }
147
148 let content_type = response
149 .headers()
150 .get(reqwest::header::CONTENT_TYPE)
151 .and_then(|v| v.to_str().ok())
152 .unwrap_or("");
153
154 if content_type.contains("text/event-stream") {
155 trace!("MCP StreamableHTTP POST response is SSE stream");
157 let tx = self.message_tx.clone();
158 let url = self.config.url.clone();
159 let connected = self.connected.clone();
160
161 tokio::spawn(async move {
165 let mut stream = response.bytes_stream().eventsource();
166 while let Some(event) = stream.next().await {
167 match event {
168 Ok(evt) => {
169 if !evt.data.trim().is_empty() {
170 trace!(
171 "MCP StreamableHTTP POST SSE event (event='{}', data_len={})",
172 evt.event,
173 evt.data.len()
174 );
175 if tx.send(evt.data).await.is_err() {
176 break;
177 }
178 }
179 }
180 Err(e) => {
181 warn!("MCP StreamableHTTP POST SSE error: {}", e);
182 break;
183 }
184 }
185 }
186 let _ = (url, connected); });
188 } else {
189 let body = response.text().await?;
191 if !body.trim().is_empty() {
192 trace!(
193 "MCP StreamableHTTP POST response is JSON (bytes={})",
194 body.len()
195 );
196 if self.message_tx.send(body).await.is_err() {
197 warn!("MCP StreamableHTTP: message channel closed");
198 }
199 }
200 }
201
202 Ok(())
203 }
204
205 async fn start_get_sse_stream(&self) {
208 let mut headers = self.build_headers(true).unwrap_or_default();
209 headers.insert(
210 reqwest::header::ACCEPT,
211 HeaderValue::from_static("text/event-stream"),
212 );
213
214 {
216 let sid = self.session_id.lock().await;
217 if let Some(sid) = sid.as_ref() {
218 if let Ok(value) = HeaderValue::from_str(sid) {
219 headers.insert(MCP_SESSION_ID_HEADER, value);
220 }
221 }
222 }
223
224 trace!(
225 "MCP StreamableHTTP GET SSE stream (url={})",
226 Self::redact_url_for_log(&self.config.url)
227 );
228
229 let response = match self
230 .client
231 .get(&self.config.url)
232 .headers(headers)
233 .send()
234 .await
235 {
236 Ok(r) => r,
237 Err(e) => {
238 debug!("MCP StreamableHTTP GET SSE stream failed: {}", e);
239 return;
240 }
241 };
242
243 if response.status() == reqwest::StatusCode::METHOD_NOT_ALLOWED {
244 debug!("MCP StreamableHTTP server does not support GET SSE stream (405)");
245 return;
246 }
247
248 if !response.status().is_success() {
249 debug!(
250 "MCP StreamableHTTP GET SSE stream returned: {}",
251 response.status()
252 );
253 return;
254 }
255
256 if let Some(sid) = response.headers().get(MCP_SESSION_ID_HEADER) {
258 if let Ok(sid_str) = sid.to_str() {
259 let mut guard = self.session_id.lock().await;
260 guard.get_or_insert_with(|| sid_str.to_string());
261 }
262 }
263
264 debug!("MCP StreamableHTTP GET SSE stream opened");
265
266 let tx = self.message_tx.clone();
267 let connected = self.connected.clone();
268
269 let handle = tokio::spawn(async move {
270 let mut stream = response.bytes_stream().eventsource();
271 while let Some(event) = stream.next().await {
272 match event {
273 Ok(evt) => {
274 if !evt.data.trim().is_empty() {
275 trace!(
276 "MCP StreamableHTTP GET SSE event (event='{}', data_len={})",
277 evt.event,
278 evt.data.len()
279 );
280 if tx.send(evt.data).await.is_err() {
281 break;
282 }
283 }
284 }
285 Err(e) => {
286 warn!("MCP StreamableHTTP GET SSE error: {}", e);
287 break;
288 }
289 }
290 }
291 connected.store(false, Ordering::SeqCst);
292 });
293
294 let mut guard = self.get_sse_handle.lock().await;
295 *guard = Some(handle);
296 }
297}
298
299#[async_trait]
300impl McpTransport for StreamableHttpTransport {
301 async fn connect(&mut self) -> Result<()> {
302 debug!(
303 "Connecting to MCP StreamableHTTP endpoint: {} (connect_timeout_ms={})",
304 Self::redact_url_for_log(&self.config.url),
305 self.config.connect_timeout_ms
306 );
307
308 self.connected.store(true, Ordering::SeqCst);
312
313 debug!("MCP StreamableHTTP transport ready");
314 Ok(())
315 }
316
317 async fn disconnect(&mut self) -> Result<()> {
318 debug!("Disconnecting MCP StreamableHTTP transport");
319
320 self.connected.store(false, Ordering::SeqCst);
321
322 {
324 let mut guard = self.get_sse_handle.lock().await;
325 if let Some(handle) = guard.take() {
326 handle.abort();
327 }
328 }
329
330 {
332 let sid = self.session_id.lock().await;
333 if let Some(session_id) = sid.as_ref() {
334 let mut headers = self.build_headers(false)?;
335 if let Ok(value) = HeaderValue::from_str(session_id) {
336 headers.insert(MCP_SESSION_ID_HEADER, value);
337 }
338
339 trace!(
340 "MCP StreamableHTTP DELETE session (url={})",
341 Self::redact_url_for_log(&self.config.url)
342 );
343 let _ = self
344 .client
345 .delete(&self.config.url)
346 .headers(headers)
347 .send()
348 .await;
349 }
350 }
351
352 {
354 let mut guard = self.session_id.lock().await;
355 *guard = None;
356 }
357
358 debug!("MCP StreamableHTTP transport disconnected");
359 Ok(())
360 }
361
362 async fn send(&self, message: String) -> Result<()> {
363 if !self.is_connected() {
364 return Err(McpError::Disconnected);
365 }
366
367 let session_id = self.session_id.lock().await.clone();
368
369 self.post_and_route_response(message, session_id).await?;
370
371 {
374 let guard = self.get_sse_handle.lock().await;
375 if guard.is_none() {
376 drop(guard);
378 self.start_get_sse_stream().await;
379 }
380 }
381
382 Ok(())
383 }
384
385 async fn receive(&self) -> Result<Option<String>> {
386 if !self.is_connected() {
387 return Err(McpError::Disconnected);
388 }
389
390 let mut rx = self.message_rx.lock().await;
391 match tokio::time::timeout(tokio::time::Duration::from_millis(100), rx.recv()).await {
392 Ok(Some(message)) => {
393 trace!(
394 "MCP StreamableHTTP received message (bytes={})",
395 message.len()
396 );
397 Ok(Some(message))
398 }
399 Ok(None) => {
400 warn!("MCP StreamableHTTP message channel closed");
401 Err(McpError::Disconnected)
402 }
403 Err(_) => Ok(None), }
405 }
406
407 fn is_connected(&self) -> bool {
408 self.connected.load(Ordering::SeqCst)
409 }
410}
411
#[cfg(test)]
mod tests {
    use super::*;

    /// Endpoint used by every test config; no test requires a server there.
    const TEST_URL: &str = "http://localhost:3000/mcp";

    /// Test config pointing at [`TEST_URL`] with the given extra headers.
    fn config_with_headers(headers: Vec<HeaderConfig>) -> StreamableHttpConfig {
        StreamableHttpConfig {
            url: TEST_URL.to_string(),
            headers,
            connect_timeout_ms: 5000,
        }
    }

    /// Test config with no extra headers.
    fn create_test_config() -> StreamableHttpConfig {
        config_with_headers(Vec::new())
    }

    /// A plain (unencrypted) configured header.
    fn plain_header(name: &str, value: &str) -> HeaderConfig {
        HeaderConfig {
            name: name.to_string(),
            value: value.to_string(),
            value_encrypted: None,
        }
    }

    /// Fresh, disconnected transport built from [`create_test_config`].
    fn new_transport() -> StreamableHttpTransport {
        StreamableHttpTransport::new(create_test_config())
    }

    /// Asserts that `result` failed with `McpError::Disconnected`.
    fn assert_disconnected<T: std::fmt::Debug>(result: Result<T>) {
        assert!(result.is_err());
        match result.unwrap_err() {
            McpError::Disconnected => {}
            e => panic!("Expected Disconnected, got: {:?}", e),
        }
    }

    #[test]
    fn test_transport_new() {
        let transport = new_transport();
        assert!(!transport.is_connected());
    }

    #[test]
    fn test_build_headers_basic() {
        let transport = new_transport();
        let headers = transport.build_headers(false).unwrap();

        assert_eq!(headers.get(reqwest::header::ACCEPT).unwrap(), ACCEPT_HEADER);
        assert_eq!(
            headers.get(reqwest::header::CONTENT_TYPE).unwrap(),
            "application/json"
        );
    }

    #[test]
    fn test_build_headers_with_custom() {
        let config = config_with_headers(vec![plain_header("Authorization", "Bearer token123")]);
        let transport = StreamableHttpTransport::new(config);
        let headers = transport.build_headers(false).unwrap();

        assert!(headers.contains_key("authorization"));
    }

    #[test]
    fn test_build_headers_invalid_name() {
        let config = config_with_headers(vec![plain_header("Invalid\nName", "test")]);
        let transport = StreamableHttpTransport::new(config);
        assert!(transport.build_headers(false).is_err());
    }

    #[test]
    fn test_redact_url() {
        assert_eq!(
            StreamableHttpTransport::redact_url_for_log("http://example.com/mcp?token=secret"),
            "http://example.com/mcp"
        );
    }

    #[tokio::test]
    async fn test_send_disconnected() {
        let transport = new_transport();
        let result = transport.send("{}".to_string()).await;
        assert_disconnected(result);
    }

    #[tokio::test]
    async fn test_receive_disconnected() {
        let transport = new_transport();
        let result = transport.receive().await;
        assert_disconnected(result);
    }

    #[tokio::test]
    async fn test_connect_disconnect() {
        let mut transport = new_transport();

        transport.connect().await.unwrap();
        assert!(transport.is_connected());

        transport.disconnect().await.unwrap();
        assert!(!transport.is_connected());
    }

    #[tokio::test]
    async fn test_receive_timeout() {
        let transport = new_transport();
        transport.connected.store(true, Ordering::SeqCst);

        let result = transport.receive().await;
        assert!(result.is_ok());
        assert!(result.unwrap().is_none());
    }

    /// Exercises the session-id mutex directly; no HTTP response is involved.
    #[tokio::test]
    async fn test_session_id_stored_on_response() {
        let transport = new_transport();
        transport.connected.store(true, Ordering::SeqCst);

        *transport.session_id.lock().await = Some("test-session-123".to_string());

        let sid = transport.session_id.lock().await;
        assert_eq!(sid.as_deref(), Some("test-session-123"));
    }

    #[tokio::test]
    async fn test_disconnect_clears_session() {
        let mut transport = new_transport();
        transport.connect().await.unwrap();

        *transport.session_id.lock().await = Some("test-session".to_string());

        transport.disconnect().await.unwrap();

        let sid = transport.session_id.lock().await;
        assert!(sid.is_none());
    }
}
565}