Skip to main content

rmcp_soddygo/transport/streamable_http_server/
tower.rs

1use std::{convert::Infallible, fmt::Display, sync::Arc, time::Duration};
2
3use bytes::Bytes;
4use futures::{StreamExt, future::BoxFuture};
5use http::{Method, Request, Response, header::ALLOW};
6use http_body::Body;
7use http_body_util::{BodyExt, Full, combinators::BoxBody};
8use tokio_stream::wrappers::ReceiverStream;
9use tokio_util::sync::CancellationToken;
10
11use super::session::SessionManager;
12use crate::{
13    RoleServer,
14    model::{ClientJsonRpcMessage, ClientRequest, GetExtensions, ProtocolVersion},
15    serve_server,
16    service::serve_directly,
17    transport::{
18        OneshotTransport, TransportAdapterIdentity,
19        common::{
20            http_header::{
21                EVENT_STREAM_MIME_TYPE, HEADER_LAST_EVENT_ID, HEADER_MCP_PROTOCOL_VERSION,
22                HEADER_SESSION_ID, JSON_MIME_TYPE,
23            },
24            server_side_http::{
25                BoxResponse, ServerSseMessage, accepted_response, expect_json,
26                internal_error_response, sse_stream_response, unexpected_message_response,
27            },
28        },
29    },
30};
31
32#[derive(Debug, Clone)]
33pub struct StreamableHttpServerConfig {
34    /// The ping message duration for SSE connections.
35    pub sse_keep_alive: Option<Duration>,
36    /// The retry interval for SSE priming events.
37    pub sse_retry: Option<Duration>,
38    /// If true, the server will create a session for each request and keep it alive.
39    /// When enabled, SSE priming events are sent to enable client reconnection.
40    pub stateful_mode: bool,
41    /// When true and `stateful_mode` is false, the server returns
42    /// `Content-Type: application/json` directly instead of `text/event-stream`.
43    /// This eliminates SSE framing overhead for simple request-response tools,
44    /// allowed by the MCP Streamable HTTP spec (2025-06-18).
45    pub json_response: bool,
46    /// Cancellation token for the Streamable HTTP server.
47    ///
48    /// When this token is cancelled, all active sessions are terminated and
49    /// the server stops accepting new requests.
50    pub cancellation_token: CancellationToken,
51}
52
53impl Default for StreamableHttpServerConfig {
54    fn default() -> Self {
55        Self {
56            sse_keep_alive: Some(Duration::from_secs(15)),
57            sse_retry: Some(Duration::from_secs(3)),
58            stateful_mode: true,
59            json_response: false,
60            cancellation_token: CancellationToken::new(),
61        }
62    }
63}
64
65#[expect(
66    clippy::result_large_err,
67    reason = "BoxResponse is intentionally large; matches other handlers in this file"
68)]
69/// Validates the `MCP-Protocol-Version` header on incoming HTTP requests.
70///
71/// Per the MCP 2025-06-18 spec:
72/// - If the header is present but contains an unsupported version, return 400 Bad Request.
73/// - If the header is absent, assume `2025-03-26` for backwards compatibility (no error).
74fn validate_protocol_version_header(headers: &http::HeaderMap) -> Result<(), BoxResponse> {
75    if let Some(value) = headers.get(HEADER_MCP_PROTOCOL_VERSION) {
76        let version_str = value.to_str().map_err(|_| {
77            Response::builder()
78                .status(http::StatusCode::BAD_REQUEST)
79                .body(
80                    Full::new(Bytes::from(
81                        "Bad Request: Invalid MCP-Protocol-Version header encoding",
82                    ))
83                    .boxed(),
84                )
85                .expect("valid response")
86        })?;
87        let is_known = ProtocolVersion::KNOWN_VERSIONS
88            .iter()
89            .any(|v| v.as_str() == version_str);
90        if !is_known {
91            return Err(Response::builder()
92                .status(http::StatusCode::BAD_REQUEST)
93                .body(
94                    Full::new(Bytes::from(format!(
95                        "Bad Request: Unsupported MCP-Protocol-Version: {version_str}"
96                    )))
97                    .boxed(),
98                )
99                .expect("valid response"));
100        }
101    }
102    Ok(())
103}
104
105/// # Streamable HTTP server
106///
107/// An HTTP service that implements the
108/// [Streamable HTTP transport](https://modelcontextprotocol.io/specification/2025-11-25/basic/transports#streamable-http)
109/// for MCP servers.
110///
111/// ## Session management
112///
113/// When [`StreamableHttpServerConfig::stateful_mode`] is `true` (the default),
114/// the server creates a session for each client that sends an `initialize`
115/// request. The session ID is returned in the `Mcp-Session-Id` response header
116/// and the client must include it on all subsequent requests.
117///
118/// Two tool calls carrying the same `Mcp-Session-Id` come from the same logical
119/// session (typically one conversation in an LLM client). Different session IDs
120/// mean different sessions.
121///
122/// The [`SessionManager`] trait controls how sessions are stored and routed:
123///
124/// * [`LocalSessionManager`](super::session::local::LocalSessionManager) —
125///   in-memory session store (default).
126/// * [`NeverSessionManager`](super::session::never::NeverSessionManager) —
127///   disables sessions entirely (stateless mode).
128///
129/// ## Accessing HTTP request data from tool handlers
130///
131/// The service consumes the request body but injects the remaining
132/// [`http::request::Parts`] into [`crate::model::Extensions`], which is
133/// accessible through [`crate::service::RequestContext`].
134///
135/// ### Reading the raw HTTP parts
136///
137/// ```rust
138/// use rmcp::handler::server::tool::Extension;
139/// use http::request::Parts;
140/// async fn my_tool(Extension(parts): Extension<Parts>) {
141///     tracing::info!("http parts:{parts:?}")
142/// }
143/// ```
144///
145/// ### Reading the session ID inside a tool handler
146///
147/// ```rust,ignore
148/// use rmcp::handler::server::tool::Extension;
149/// use rmcp::service::RequestContext;
150/// use rmcp::model::RoleServer;
151///
152/// #[tool(description = "session-aware tool")]
153/// async fn my_tool(
154///     &self,
155///     Extension(parts): Extension<http::request::Parts>,
156/// ) -> Result<CallToolResult, rmcp::ErrorData> {
157///     if let Some(session_id) = parts.headers.get("mcp-session-id") {
158///         tracing::info!(?session_id, "called from session");
159///     }
160///     // ...
161///     # todo!()
162/// }
163/// ```
164///
165/// ### Accessing custom axum/tower extension state
166///
167/// State added via axum's `Extension` layer is available inside
168/// `Parts.extensions`:
169///
170/// ```rust,ignore
171/// use rmcp::service::RequestContext;
172/// use rmcp::model::RoleServer;
173///
174/// #[derive(Clone)]
175/// struct AppState { /* ... */ }
176///
177/// #[tool(description = "example")]
178/// async fn my_tool(
179///     &self,
180///     ctx: RequestContext<RoleServer>,
181/// ) -> Result<CallToolResult, rmcp::ErrorData> {
182///     let parts = ctx.extensions.get::<http::request::Parts>().unwrap();
183///     let state = parts.extensions.get::<AppState>().unwrap();
184///     // use state...
185///     # todo!()
186/// }
187/// ```
188pub struct StreamableHttpService<S, M = super::session::local::LocalSessionManager> {
189    pub config: StreamableHttpServerConfig,
190    session_manager: Arc<M>,
191    service_factory: Arc<dyn Fn() -> Result<S, std::io::Error> + Send + Sync>,
192}
193
194impl<S, M> Clone for StreamableHttpService<S, M> {
195    fn clone(&self) -> Self {
196        Self {
197            config: self.config.clone(),
198            session_manager: self.session_manager.clone(),
199            service_factory: self.service_factory.clone(),
200        }
201    }
202}
203
204impl<RequestBody, S, M> tower_service::Service<Request<RequestBody>> for StreamableHttpService<S, M>
205where
206    RequestBody: Body + Send + 'static,
207    S: crate::Service<RoleServer>,
208    M: SessionManager,
209    RequestBody::Error: Display,
210    RequestBody::Data: Send + 'static,
211{
212    type Response = BoxResponse;
213    type Error = Infallible;
214    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
215    fn call(&mut self, req: http::Request<RequestBody>) -> Self::Future {
216        let service = self.clone();
217        Box::pin(async move {
218            let response = service.handle(req).await;
219            Ok(response)
220        })
221    }
222    fn poll_ready(
223        &mut self,
224        _cx: &mut std::task::Context<'_>,
225    ) -> std::task::Poll<Result<(), Self::Error>> {
226        std::task::Poll::Ready(Ok(()))
227    }
228}
229
230impl<S, M> StreamableHttpService<S, M>
231where
232    S: crate::Service<RoleServer> + Send + 'static,
233    M: SessionManager,
234{
235    pub fn new(
236        service_factory: impl Fn() -> Result<S, std::io::Error> + Send + Sync + 'static,
237        session_manager: Arc<M>,
238        config: StreamableHttpServerConfig,
239    ) -> Self {
240        Self {
241            config,
242            session_manager,
243            service_factory: Arc::new(service_factory),
244        }
245    }
246    fn get_service(&self) -> Result<S, std::io::Error> {
247        (self.service_factory)()
248    }
249    pub async fn handle<B>(&self, request: Request<B>) -> Response<BoxBody<Bytes, Infallible>>
250    where
251        B: Body + Send + 'static,
252        B::Error: Display,
253    {
254        let method = request.method().clone();
255        let allowed_methods = match self.config.stateful_mode {
256            true => "GET, POST, DELETE",
257            false => "POST",
258        };
259        let result = match (method, self.config.stateful_mode) {
260            (Method::POST, _) => self.handle_post(request).await,
261            // if we're not in stateful mode, we don't support GET or DELETE because there is no session
262            (Method::GET, true) => self.handle_get(request).await,
263            (Method::DELETE, true) => self.handle_delete(request).await,
264            _ => {
265                // Handle other methods or return an error
266                let response = Response::builder()
267                    .status(http::StatusCode::METHOD_NOT_ALLOWED)
268                    .header(ALLOW, allowed_methods)
269                    .body(Full::new(Bytes::from("Method Not Allowed")).boxed())
270                    .expect("valid response");
271                return response;
272            }
273        };
274        match result {
275            Ok(response) => response,
276            Err(response) => response,
277        }
278    }
279    async fn handle_get<B>(&self, request: Request<B>) -> Result<BoxResponse, BoxResponse>
280    where
281        B: Body + Send + 'static,
282        B::Error: Display,
283    {
284        // check accept header
285        if !request
286            .headers()
287            .get(http::header::ACCEPT)
288            .and_then(|header| header.to_str().ok())
289            .is_some_and(|header| header.contains(EVENT_STREAM_MIME_TYPE))
290        {
291            return Ok(Response::builder()
292                .status(http::StatusCode::NOT_ACCEPTABLE)
293                .body(
294                    Full::new(Bytes::from(
295                        "Not Acceptable: Client must accept text/event-stream",
296                    ))
297                    .boxed(),
298                )
299                .expect("valid response"));
300        }
301        // check session id
302        let session_id = request
303            .headers()
304            .get(HEADER_SESSION_ID)
305            .and_then(|v| v.to_str().ok())
306            .map(|s| s.to_owned().into());
307        let Some(session_id) = session_id else {
308            // MCP spec: servers that require a session ID SHOULD respond with 400 Bad Request
309            return Ok(Response::builder()
310                .status(http::StatusCode::BAD_REQUEST)
311                .body(Full::new(Bytes::from("Bad Request: Session ID is required")).boxed())
312                .expect("valid response"));
313        };
314        // check if session exists
315        let has_session = self
316            .session_manager
317            .has_session(&session_id)
318            .await
319            .map_err(internal_error_response("check session"))?;
320        if !has_session {
321            // MCP spec: server MUST respond with 404 Not Found for terminated/unknown sessions
322            return Ok(Response::builder()
323                .status(http::StatusCode::NOT_FOUND)
324                .body(Full::new(Bytes::from("Not Found: Session not found")).boxed())
325                .expect("valid response"));
326        }
327        // Validate MCP-Protocol-Version header (per 2025-06-18 spec)
328        validate_protocol_version_header(request.headers())?;
329        // check if last event id is provided
330        let last_event_id = request
331            .headers()
332            .get(HEADER_LAST_EVENT_ID)
333            .and_then(|v| v.to_str().ok())
334            .map(|s| s.to_owned());
335        if let Some(last_event_id) = last_event_id {
336            // check if session has this event id
337            let stream = self
338                .session_manager
339                .resume(&session_id, last_event_id)
340                .await
341                .map_err(internal_error_response("resume session"))?;
342            // Resume doesn't need priming - client already has the event ID
343            Ok(sse_stream_response(
344                stream,
345                self.config.sse_keep_alive,
346                self.config.cancellation_token.child_token(),
347            ))
348        } else {
349            // create standalone stream
350            let stream = self
351                .session_manager
352                .create_standalone_stream(&session_id)
353                .await
354                .map_err(internal_error_response("create standalone stream"))?;
355            // Prepend priming event if sse_retry configured
356            let stream = if let Some(retry) = self.config.sse_retry {
357                let priming = ServerSseMessage {
358                    event_id: Some("0".into()),
359                    message: None,
360                    retry: Some(retry),
361                };
362                futures::stream::once(async move { priming })
363                    .chain(stream)
364                    .left_stream()
365            } else {
366                stream.right_stream()
367            };
368            Ok(sse_stream_response(
369                stream,
370                self.config.sse_keep_alive,
371                self.config.cancellation_token.child_token(),
372            ))
373        }
374    }
375
376    async fn handle_post<B>(&self, request: Request<B>) -> Result<BoxResponse, BoxResponse>
377    where
378        B: Body + Send + 'static,
379        B::Error: Display,
380    {
381        // check accept header
382        if !request
383            .headers()
384            .get(http::header::ACCEPT)
385            .and_then(|header| header.to_str().ok())
386            .is_some_and(|header| {
387                header.contains(JSON_MIME_TYPE) && header.contains(EVENT_STREAM_MIME_TYPE)
388            })
389        {
390            return Ok(Response::builder()
391                .status(http::StatusCode::NOT_ACCEPTABLE)
392                .body(Full::new(Bytes::from("Not Acceptable: Client must accept both application/json and text/event-stream")).boxed())
393                .expect("valid response"));
394        }
395
396        // check content type
397        if !request
398            .headers()
399            .get(http::header::CONTENT_TYPE)
400            .and_then(|header| header.to_str().ok())
401            .is_some_and(|header| header.starts_with(JSON_MIME_TYPE))
402        {
403            return Ok(Response::builder()
404                .status(http::StatusCode::UNSUPPORTED_MEDIA_TYPE)
405                .body(
406                    Full::new(Bytes::from(
407                        "Unsupported Media Type: Content-Type must be application/json",
408                    ))
409                    .boxed(),
410                )
411                .expect("valid response"));
412        }
413
414        // json deserialize request body
415        let (part, body) = request.into_parts();
416        let mut message = match expect_json(body).await {
417            Ok(message) => message,
418            Err(response) => return Ok(response),
419        };
420
421        if self.config.stateful_mode {
422            // do we have a session id?
423            let session_id = part
424                .headers
425                .get(HEADER_SESSION_ID)
426                .and_then(|v| v.to_str().ok());
427            if let Some(session_id) = session_id {
428                let session_id = session_id.to_owned().into();
429                let has_session = self
430                    .session_manager
431                    .has_session(&session_id)
432                    .await
433                    .map_err(internal_error_response("check session"))?;
434                if !has_session {
435                    // MCP spec: server MUST respond with 404 Not Found for terminated/unknown sessions
436                    return Ok(Response::builder()
437                        .status(http::StatusCode::NOT_FOUND)
438                        .body(Full::new(Bytes::from("Not Found: Session not found")).boxed())
439                        .expect("valid response"));
440                }
441
442                // Validate MCP-Protocol-Version header (per 2025-06-18 spec)
443                validate_protocol_version_header(&part.headers)?;
444
445                // inject request part to extensions
446                match &mut message {
447                    ClientJsonRpcMessage::Request(req) => {
448                        req.request.extensions_mut().insert(part);
449                    }
450                    ClientJsonRpcMessage::Notification(not) => {
451                        not.notification.extensions_mut().insert(part);
452                    }
453                    _ => {
454                        // skip
455                    }
456                }
457
458                match message {
459                    ClientJsonRpcMessage::Request(_) => {
460                        let stream = self
461                            .session_manager
462                            .create_stream(&session_id, message)
463                            .await
464                            .map_err(internal_error_response("get session"))?;
465                        // Prepend priming event if sse_retry configured
466                        let stream = if let Some(retry) = self.config.sse_retry {
467                            let priming = ServerSseMessage {
468                                event_id: Some("0".into()),
469                                message: None,
470                                retry: Some(retry),
471                            };
472                            futures::stream::once(async move { priming })
473                                .chain(stream)
474                                .left_stream()
475                        } else {
476                            stream.right_stream()
477                        };
478                        Ok(sse_stream_response(
479                            stream,
480                            self.config.sse_keep_alive,
481                            self.config.cancellation_token.child_token(),
482                        ))
483                    }
484                    ClientJsonRpcMessage::Notification(_)
485                    | ClientJsonRpcMessage::Response(_)
486                    | ClientJsonRpcMessage::Error(_) => {
487                        // handle notification
488                        self.session_manager
489                            .accept_message(&session_id, message)
490                            .await
491                            .map_err(internal_error_response("accept message"))?;
492                        Ok(accepted_response())
493                    }
494                }
495            } else {
496                let (session_id, transport) = self
497                    .session_manager
498                    .create_session()
499                    .await
500                    .map_err(internal_error_response("create session"))?;
501                if let ClientJsonRpcMessage::Request(req) = &mut message {
502                    if !matches!(req.request, ClientRequest::InitializeRequest(_)) {
503                        return Err(unexpected_message_response("initialize request"));
504                    }
505                    // inject request part to extensions
506                    req.request.extensions_mut().insert(part);
507                } else {
508                    return Err(unexpected_message_response("initialize request"));
509                }
510                let service = self
511                    .get_service()
512                    .map_err(internal_error_response("get service"))?;
513                // spawn a task to serve the session
514                tokio::spawn({
515                    let session_manager = self.session_manager.clone();
516                    let session_id = session_id.clone();
517                    async move {
518                        let service = serve_server::<S, M::Transport, _, TransportAdapterIdentity>(
519                            service, transport,
520                        )
521                        .await;
522                        match service {
523                            Ok(service) => {
524                                // on service created
525                                let _ = service.waiting().await;
526                            }
527                            Err(e) => {
528                                tracing::error!("Failed to create service: {e}");
529                            }
530                        }
531                        let _ = session_manager
532                            .close_session(&session_id)
533                            .await
534                            .inspect_err(|e| {
535                                tracing::error!("Failed to close session {session_id}: {e}");
536                            });
537                    }
538                });
539                // get initialize response
540                let response = self
541                    .session_manager
542                    .initialize_session(&session_id, message)
543                    .await
544                    .map_err(internal_error_response("create stream"))?;
545                let stream = futures::stream::once(async move {
546                    ServerSseMessage {
547                        event_id: None,
548                        message: Some(Arc::new(response)),
549                        retry: None,
550                    }
551                });
552                // Prepend priming event if sse_retry configured
553                let stream = if let Some(retry) = self.config.sse_retry {
554                    let priming = ServerSseMessage {
555                        event_id: Some("0".into()),
556                        message: None,
557                        retry: Some(retry),
558                    };
559                    futures::stream::once(async move { priming })
560                        .chain(stream)
561                        .left_stream()
562                } else {
563                    stream.right_stream()
564                };
565                let mut response = sse_stream_response(
566                    stream,
567                    self.config.sse_keep_alive,
568                    self.config.cancellation_token.child_token(),
569                );
570
571                response.headers_mut().insert(
572                    HEADER_SESSION_ID,
573                    session_id
574                        .parse()
575                        .map_err(internal_error_response("create session id header"))?,
576                );
577                Ok(response)
578            }
579        } else {
580            // Stateless mode: validate MCP-Protocol-Version on non-init requests
581            let is_init = matches!(
582                &message,
583                ClientJsonRpcMessage::Request(req) if matches!(req.request, ClientRequest::InitializeRequest(_))
584            );
585            if !is_init {
586                validate_protocol_version_header(&part.headers)?;
587            }
588            let service = self
589                .get_service()
590                .map_err(internal_error_response("get service"))?;
591            match message {
592                ClientJsonRpcMessage::Request(mut request) => {
593                    request.request.extensions_mut().insert(part);
594                    let (transport, mut receiver) =
595                        OneshotTransport::<RoleServer>::new(ClientJsonRpcMessage::Request(request));
596                    let service = serve_directly(service, transport, None);
597                    tokio::spawn(async move {
598                        // on service created
599                        let _ = service.waiting().await;
600                    });
601                    if self.config.json_response {
602                        // JSON-direct mode: await the single response and return as
603                        // application/json, eliminating SSE framing overhead.
604                        // Allowed by MCP Streamable HTTP spec (2025-06-18).
605                        let cancel = self.config.cancellation_token.child_token();
606                        match tokio::select! {
607                            res = receiver.recv() => res,
608                            _ = cancel.cancelled() => None,
609                        } {
610                            Some(message) => {
611                                tracing::trace!(?message);
612                                let body = serde_json::to_vec(&message).map_err(|e| {
613                                    internal_error_response("serialize json response")(e)
614                                })?;
615                                Ok(Response::builder()
616                                    .status(http::StatusCode::OK)
617                                    .header(http::header::CONTENT_TYPE, JSON_MIME_TYPE)
618                                    .body(Full::new(Bytes::from(body)).boxed())
619                                    .expect("valid response"))
620                            }
621                            None => Err(internal_error_response("empty response")(
622                                std::io::Error::new(
623                                    std::io::ErrorKind::UnexpectedEof,
624                                    "no response message received from handler",
625                                ),
626                            )),
627                        }
628                    } else {
629                        // SSE mode (default): original behaviour preserved unchanged
630                        let stream = ReceiverStream::new(receiver).map(|message| {
631                            tracing::trace!(?message);
632                            ServerSseMessage {
633                                event_id: None,
634                                message: Some(Arc::new(message)),
635                                retry: None,
636                            }
637                        });
638                        Ok(sse_stream_response(
639                            stream,
640                            self.config.sse_keep_alive,
641                            self.config.cancellation_token.child_token(),
642                        ))
643                    }
644                }
645                ClientJsonRpcMessage::Notification(_notification) => {
646                    // ignore
647                    Ok(accepted_response())
648                }
649                ClientJsonRpcMessage::Response(_json_rpc_response) => Ok(accepted_response()),
650                ClientJsonRpcMessage::Error(_json_rpc_error) => Ok(accepted_response()),
651            }
652        }
653    }
654
655    async fn handle_delete<B>(&self, request: Request<B>) -> Result<BoxResponse, BoxResponse>
656    where
657        B: Body + Send + 'static,
658        B::Error: Display,
659    {
660        // check session id
661        let session_id = request
662            .headers()
663            .get(HEADER_SESSION_ID)
664            .and_then(|v| v.to_str().ok())
665            .map(|s| s.to_owned().into());
666        let Some(session_id) = session_id else {
667            // MCP spec: servers that require a session ID SHOULD respond with 400 Bad Request
668            return Ok(Response::builder()
669                .status(http::StatusCode::BAD_REQUEST)
670                .body(Full::new(Bytes::from("Bad Request: Session ID is required")).boxed())
671                .expect("valid response"));
672        };
673        // Validate MCP-Protocol-Version header (per 2025-06-18 spec)
674        validate_protocol_version_header(request.headers())?;
675        // close session
676        self.session_manager
677            .close_session(&session_id)
678            .await
679            .map_err(internal_error_response("close session"))?;
680        Ok(accepted_response())
681    }
682}