Skip to main content

mcp_kit/transport/
streamable.rs

1//! Streamable HTTP transport for MCP (2025-03-26 spec).
2//!
3//! This transport uses a single HTTP endpoint that can return either:
4//! - JSON response for simple requests
5//! - SSE stream for long-running operations
6//!
7//! # Example
8//!
9//! ```rust,ignore
10//! use mcp_kit::prelude::*;
11//!
12//! let server = McpServer::builder()
13//!     .name("my-server")
14//!     .version("1.0.0")
15//!     .build();
16//!
17//! // Serve on a single endpoint
18//! server.serve_streamable(([127, 0, 0, 1], 3000)).await?;
19//! ```
20//!
21//! # Protocol
22//!
23//! ```text
24//! POST /mcp
25//! Content-Type: application/json
26//! Mcp-Session-Id: <optional session id>
27//!
28//! {"jsonrpc":"2.0","method":"...","id":1}
29//!
30//! Response (JSON for simple requests):
31//! 200 OK
32//! Content-Type: application/json
33//! Mcp-Session-Id: <session id>
34//!
35//! {"jsonrpc":"2.0","result":{...},"id":1}
36//!
37//! Response (SSE for streaming):
38//! 200 OK
39//! Content-Type: text/event-stream
40//! Mcp-Session-Id: <session id>
41//!
42//! data: {"jsonrpc":"2.0","method":"notifications/progress",...}
43//! data: {"jsonrpc":"2.0","result":{...},"id":1}
44//! ```
45
46use std::{convert::Infallible, sync::Arc, time::Duration};
47
48use crate::server::{core::McpServer, session::Session};
49use crate::{error::McpResult, protocol::JsonRpcMessage};
50use axum::{
51    extract::State as AxumState,
52    http::{HeaderMap, HeaderValue, StatusCode},
53    response::{sse::Event, IntoResponse, Response, Sse},
54    routing::{delete, get, post},
55    Json as AxumJson, Router as AxumRouter,
56};
57use dashmap::DashMap;
58use futures_util::stream;
59use std::future::Future;
60use tokio::sync::mpsc;
61use tracing::info;
62use uuid::Uuid;
63
64#[cfg(feature = "auth")]
65use crate::auth::{AuthenticatedIdentity, DynAuthProvider};
66#[cfg(feature = "auth")]
67use crate::transport::auth_layer::{auth_middleware, AuthMiddlewareState};
68
69// ─── Constants ────────────────────────────────────────────────────────────────
70
/// Name of the HTTP header that carries the MCP session identifier.
const MCP_SESSION_ID_HEADER: &str = "Mcp-Session-Id";
/// Session idle timeout, in seconds (one hour).
const SESSION_TIMEOUT_SECS: u64 = 60 * 60;
73
74// ─── Shared State ─────────────────────────────────────────────────────────────
75
76type NotificationTx = mpsc::Sender<JsonRpcMessage>;
77
78pub(crate) struct SessionEntry {
79    session: Arc<tokio::sync::Mutex<Session>>,
80    notification_tx: Option<NotificationTx>,
81    last_active: std::time::Instant,
82}
83
84#[derive(Clone)]
85pub struct StreamableState {
86    pub(crate) server: Arc<McpServer>,
87    pub(crate) sessions: Arc<DashMap<String, SessionEntry>>,
88    #[cfg(feature = "auth")]
89    #[allow(dead_code)]
90    pub(crate) auth: Option<AuthMiddlewareState>,
91}
92
93impl StreamableState {
94    fn get_or_create_session(
95        &self,
96        session_id: Option<&str>,
97    ) -> (String, Arc<tokio::sync::Mutex<Session>>) {
98        if let Some(sid) = session_id {
99            if let Some(mut entry) = self.sessions.get_mut(sid) {
100                entry.last_active = std::time::Instant::now();
101                return (sid.to_string(), entry.session.clone());
102            }
103        }
104
105        // Create new session
106        let session_id = Uuid::new_v4().to_string();
107        let session = Arc::new(tokio::sync::Mutex::new(Session::new()));
108        self.sessions.insert(
109            session_id.clone(),
110            SessionEntry {
111                session: session.clone(),
112                notification_tx: None,
113                last_active: std::time::Instant::now(),
114            },
115        );
116        info!(session_id = %session_id, "Created new session");
117        (session_id, session)
118    }
119
120    fn set_notification_channel(&self, session_id: &str, tx: NotificationTx) {
121        if let Some(mut entry) = self.sessions.get_mut(session_id) {
122            entry.notification_tx = Some(tx);
123        }
124    }
125
126    fn remove_notification_channel(&self, session_id: &str) {
127        if let Some(mut entry) = self.sessions.get_mut(session_id) {
128            entry.notification_tx = None;
129        }
130    }
131
132    fn remove_session(&self, session_id: &str) {
133        self.sessions.remove(session_id);
134        info!(session_id = %session_id, "Session terminated");
135    }
136
137    fn cleanup_expired_sessions(&self) {
138        let timeout = Duration::from_secs(SESSION_TIMEOUT_SECS);
139        let now = std::time::Instant::now();
140
141        self.sessions.retain(|sid, entry| {
142            let expired = now.duration_since(entry.last_active) > timeout;
143            if expired {
144                info!(session_id = %sid, "Session expired");
145            }
146            !expired
147        });
148    }
149}
150
151// ─── StreamableTransport ──────────────────────────────────────────────────────
152
153/// Streamable HTTP transport (MCP 2025-03-26 spec).
154pub struct StreamableTransport {
155    server: McpServer,
156    addr: std::net::SocketAddr,
157    endpoint: String,
158    #[cfg(feature = "auth")]
159    auth: Option<AuthMiddlewareState>,
160}
161
162impl StreamableTransport {
163    /// Create a new Streamable HTTP transport.
164    pub fn new(server: McpServer, addr: impl Into<std::net::SocketAddr>) -> Self {
165        Self {
166            server,
167            addr: addr.into(),
168            endpoint: "/mcp".to_string(),
169            #[cfg(feature = "auth")]
170            auth: None,
171        }
172    }
173
174    /// Set custom endpoint path (default: "/mcp").
175    pub fn with_endpoint(mut self, endpoint: impl Into<String>) -> Self {
176        self.endpoint = endpoint.into();
177        self
178    }
179
180    /// Require authentication on all requests.
181    #[cfg(feature = "auth")]
182    pub fn with_auth(mut self, provider: DynAuthProvider) -> Self {
183        self.auth = Some(AuthMiddlewareState {
184            provider,
185            require_auth: true,
186        });
187        self
188    }
189
190    /// Accept optional authentication.
191    #[cfg(feature = "auth")]
192    pub fn with_optional_auth(mut self, provider: DynAuthProvider) -> Self {
193        self.auth = Some(AuthMiddlewareState {
194            provider,
195            require_auth: false,
196        });
197        self
198    }
199
200    /// Build the Axum router for this transport.
201    pub fn build_router(self) -> (AxumRouter, StreamableState) {
202        let state = StreamableState {
203            server: Arc::new(self.server),
204            sessions: Arc::new(DashMap::new()),
205            #[cfg(feature = "auth")]
206            auth: self.auth.clone(),
207        };
208
209        let routes = AxumRouter::new()
210            .route(&self.endpoint, post(handle_post))
211            .route(&self.endpoint, get(handle_get_sse))
212            .route(&self.endpoint, delete(handle_delete));
213
214        #[cfg(feature = "auth")]
215        let routes = if let Some(auth_state) = self.auth {
216            routes.route_layer(axum::middleware::from_fn_with_state(
217                auth_state,
218                auth_middleware,
219            ))
220        } else {
221            routes
222        };
223
224        (routes.with_state(state.clone()), state)
225    }
226
227    /// Start serving on the configured address.
228    pub async fn serve(self) -> McpResult<()> {
229        let addr = self.addr;
230        let (router, state) = self.build_router();
231
232        // Spawn session cleanup task
233        let cleanup_state = state.clone();
234        tokio::spawn(async move {
235            let mut interval = tokio::time::interval(Duration::from_secs(300)); // Every 5 minutes
236            loop {
237                interval.tick().await;
238                cleanup_state.cleanup_expired_sessions();
239            }
240        });
241
242        info!(addr = %addr, "Streamable HTTP transport listening");
243
244        let listener = tokio::net::TcpListener::bind(addr)
245            .await
246            .map_err(crate::error::McpError::Io)?;
247
248        axum::serve(listener, router)
249            .await
250            .map_err(|e| crate::error::McpError::Transport(e.to_string()))?;
251
252        Ok(())
253    }
254}
255
256// ─── POST Handler ─────────────────────────────────────────────────────────────
257
258async fn handle_post(
259    AxumState(state): AxumState<StreamableState>,
260    headers: HeaderMap,
261    #[cfg(feature = "auth")] identity: Option<axum::Extension<Arc<AuthenticatedIdentity>>>,
262    AxumJson(msg): AxumJson<JsonRpcMessage>,
263) -> Response {
264    // Get or create session
265    let session_id_header = headers
266        .get(MCP_SESSION_ID_HEADER)
267        .and_then(|v| v.to_str().ok());
268
269    let (session_id, session_arc) = state.get_or_create_session(session_id_header);
270
271    let mut session = session_arc.lock().await;
272
273    // Set identity if authenticated
274    #[cfg(feature = "auth")]
275    if let Some(axum::Extension(id)) = identity {
276        session.identity = Some((*id).clone());
277    }
278
279    // Check if this is a method that might need streaming
280    let needs_streaming = matches!(&msg, JsonRpcMessage::Request(req) if
281        req.method == "tools/call" ||
282        req.method == "sampling/createMessage"
283    );
284
285    let server = state.server.clone();
286
287    // Handle the message
288    match server.handle_message(msg, &mut session).await {
289        Some(response) => {
290            if needs_streaming {
291                // For potentially long-running operations, use SSE
292                stream_response(session_id, response)
293            } else {
294                // Simple JSON response
295                json_response(session_id, response)
296            }
297        }
298        None => {
299            // Notification - no response needed
300            (StatusCode::ACCEPTED, [(MCP_SESSION_ID_HEADER, session_id)]).into_response()
301        }
302    }
303}
304
305fn json_response(session_id: String, response: JsonRpcMessage) -> Response {
306    let mut resp = AxumJson(response).into_response();
307    resp.headers_mut().insert(
308        MCP_SESSION_ID_HEADER,
309        HeaderValue::from_str(&session_id).unwrap_or_else(|_| HeaderValue::from_static("")),
310    );
311    resp
312}
313
314fn stream_response(session_id: String, response: JsonRpcMessage) -> Response {
315    let event = Event::default()
316        .event("message")
317        .data(serde_json::to_string(&response).unwrap_or_default());
318
319    let stream = stream::once(async move { Ok::<_, Infallible>(event) });
320
321    let mut resp = Sse::new(stream)
322        .keep_alive(axum::response::sse::KeepAlive::default())
323        .into_response();
324
325    resp.headers_mut().insert(
326        MCP_SESSION_ID_HEADER,
327        HeaderValue::from_str(&session_id).unwrap_or_else(|_| HeaderValue::from_static("")),
328    );
329
330    resp
331}
332
333// ─── GET Handler (SSE stream for server-initiated messages) ───────────────────
334
335async fn handle_get_sse(
336    AxumState(state): AxumState<StreamableState>,
337    headers: HeaderMap,
338    #[cfg(feature = "auth")] _identity: Option<axum::Extension<Arc<AuthenticatedIdentity>>>,
339) -> Response {
340    // Require existing session for SSE
341    let Some(session_id) = headers
342        .get(MCP_SESSION_ID_HEADER)
343        .and_then(|v| v.to_str().ok())
344    else {
345        return (
346            StatusCode::BAD_REQUEST,
347            "Mcp-Session-Id header required for SSE",
348        )
349            .into_response();
350    };
351
352    // Verify session exists
353    if !state.sessions.contains_key(session_id) {
354        return (StatusCode::NOT_FOUND, "Session not found").into_response();
355    }
356
357    let session_id = session_id.to_string();
358
359    // Create notification channel
360    let (tx, rx) = mpsc::channel::<JsonRpcMessage>(64);
361    state.set_notification_channel(&session_id, tx);
362
363    info!(session_id = %session_id, "SSE stream opened");
364
365    let cleanup_state = state.clone();
366    let session_id_for_cleanup = session_id.clone();
367
368    let stream = stream::unfold(rx, move |mut rx| async move {
369        match rx.recv().await {
370            Some(msg) => {
371                let data = serde_json::to_string(&msg).unwrap_or_default();
372                let event = Event::default().event("message").data(data);
373                Some((Ok::<_, Infallible>(event), rx))
374            }
375            None => None,
376        }
377    });
378
379    // Spawn cleanup task for when client disconnects
380    tokio::spawn(async move {
381        // This task monitors connection - cleanup happens when stream ends
382        tokio::time::sleep(Duration::from_secs(SESSION_TIMEOUT_SECS)).await;
383        cleanup_state.remove_notification_channel(&session_id_for_cleanup);
384        info!(session_id = %session_id_for_cleanup, "SSE stream timeout cleanup");
385    });
386
387    let mut resp = Sse::new(stream)
388        .keep_alive(axum::response::sse::KeepAlive::default())
389        .into_response();
390
391    resp.headers_mut().insert(
392        MCP_SESSION_ID_HEADER,
393        HeaderValue::from_str(&session_id).unwrap_or_else(|_| HeaderValue::from_static("")),
394    );
395
396    resp
397}
398
399// ─── DELETE Handler (terminate session) ───────────────────────────────────────
400
401async fn handle_delete(
402    AxumState(state): AxumState<StreamableState>,
403    headers: HeaderMap,
404) -> Response {
405    let Some(session_id) = headers
406        .get(MCP_SESSION_ID_HEADER)
407        .and_then(|v| v.to_str().ok())
408    else {
409        return (StatusCode::BAD_REQUEST, "Mcp-Session-Id header required").into_response();
410    };
411
412    if state.sessions.contains_key(session_id) {
413        state.remove_session(session_id);
414        StatusCode::NO_CONTENT.into_response()
415    } else {
416        (StatusCode::NOT_FOUND, "Session not found").into_response()
417    }
418}
419
420// ─── Extension Trait ──────────────────────────────────────────────────────────
421
422/// Extension trait that adds `.serve_streamable()` to `McpServer`.
423pub trait ServeStreamableExt {
424    /// Serve using the Streamable HTTP transport.
425    fn serve_streamable(
426        self,
427        addr: impl Into<std::net::SocketAddr>,
428    ) -> impl Future<Output = McpResult<()>> + Send;
429}
430
431impl ServeStreamableExt for McpServer {
432    fn serve_streamable(
433        self,
434        addr: impl Into<std::net::SocketAddr>,
435    ) -> impl Future<Output = McpResult<()>> + Send {
436        #[cfg(feature = "auth")]
437        {
438            let transport = StreamableTransport::new(self.clone(), addr);
439            let transport = match (self.auth_provider, self.require_auth) {
440                (Some(provider), true) => transport.with_auth(provider),
441                (Some(provider), false) => transport.with_optional_auth(provider),
442                (None, _) => transport,
443            };
444            transport.serve()
445        }
446        #[cfg(not(feature = "auth"))]
447        StreamableTransport::new(self, addr).serve()
448    }
449}
450
#[cfg(test)]
mod tests {
    use super::*;

    /// Fresh transport state with an empty session map.
    fn test_state() -> StreamableState {
        StreamableState {
            server: Arc::new(McpServer::builder().name("test").version("1.0").build()),
            sessions: Arc::new(DashMap::new()),
            #[cfg(feature = "auth")]
            auth: None,
        }
    }

    #[test]
    fn test_session_creation() {
        let state = test_state();

        // Create new session
        let (sid1, session1) = state.get_or_create_session(None);
        assert!(!sid1.is_empty());
        assert!(state.sessions.contains_key(&sid1));

        // Get existing session: same ID, same shared session, no new entry
        let (sid2, session2) = state.get_or_create_session(Some(&sid1));
        assert_eq!(sid1, sid2);
        assert!(Arc::ptr_eq(&session1, &session2));
        assert_eq!(state.sessions.len(), 1);
    }

    #[test]
    fn test_unknown_session_id_gets_fresh_id() {
        let state = test_state();

        let (sid, _) = state.get_or_create_session(Some("does-not-exist"));
        assert_ne!(sid, "does-not-exist");
        assert!(state.sessions.contains_key(&sid));
        assert!(!state.sessions.contains_key("does-not-exist"));
    }

    #[test]
    fn test_remove_session() {
        let state = test_state();

        let (sid, _) = state.get_or_create_session(None);
        state.remove_session(&sid);
        assert!(!state.sessions.contains_key(&sid));
    }

    #[test]
    fn test_cleanup_keeps_active_sessions() {
        let state = test_state();

        let (sid, _) = state.get_or_create_session(None);
        state.cleanup_expired_sessions();
        assert!(state.sessions.contains_key(&sid));
    }

    #[test]
    fn test_cleanup_removes_idle_sessions() {
        let state = test_state();
        let (sid, _) = state.get_or_create_session(None);

        // Backdating can fail if the monotonic clock has not run that long
        // (e.g. shortly after boot); there is nothing to check in that case.
        let idle_for = Duration::from_secs(SESSION_TIMEOUT_SECS + 1);
        if let Some(stale) = std::time::Instant::now().checked_sub(idle_for) {
            state
                .sessions
                .get_mut(&sid)
                .expect("session was just created")
                .last_active = stale;
            state.cleanup_expired_sessions();
            assert!(!state.sessions.contains_key(&sid));
        }
    }
}