Skip to main content

rmcp_soddygo/transport/streamable_http_server/
tower.rs

1use std::{collections::HashMap, convert::Infallible, fmt::Display, sync::Arc, time::Duration};
2
3use bytes::Bytes;
4use futures::{StreamExt, future::BoxFuture};
5use http::{HeaderMap, Method, Request, Response, header::ALLOW};
6use http_body::Body;
7use http_body_util::{BodyExt, Full, combinators::BoxBody};
8use tokio_stream::wrappers::ReceiverStream;
9use tokio_util::sync::CancellationToken;
10
11use super::session::{
12    RestoreOutcome, SessionId, SessionManager, SessionRestoreMarker, SessionState, SessionStore,
13};
14use crate::{
15    RoleServer,
16    model::{
17        ClientJsonRpcMessage, ClientNotification, ClientRequest, GetExtensions, InitializeRequest,
18        InitializedNotification, ProtocolVersion,
19    },
20    serve_server,
21    service::serve_directly,
22    transport::{
23        OneshotTransport, TransportAdapterIdentity,
24        common::{
25            http_header::{
26                EVENT_STREAM_MIME_TYPE, HEADER_LAST_EVENT_ID, HEADER_MCP_PROTOCOL_VERSION,
27                HEADER_SESSION_ID, JSON_MIME_TYPE,
28            },
29            server_side_http::{
30                BoxResponse, ServerSseMessage, accepted_response, expect_json,
31                internal_error_response, sse_stream_response, unexpected_message_response,
32            },
33        },
34    },
35};
36
/// Configuration for [`StreamableHttpService`].
///
/// The struct is `#[non_exhaustive]`: outside this crate, start from
/// [`Default::default`] and then use the `with_*` builders or assign the
/// public fields directly.
#[non_exhaustive]
#[derive(Debug, Clone)]
pub struct StreamableHttpServerConfig {
    /// Interval between keep-alive pings on SSE connections.
    ///
    /// Defaults to 15 seconds.
    pub sse_keep_alive: Option<Duration>,
    /// Retry interval advertised in SSE priming events.
    ///
    /// Defaults to 3 seconds. When `None`, the standalone stream opened by a
    /// `GET` request is sent without a priming event.
    pub sse_retry: Option<Duration>,
    /// Whether the server keeps per-client sessions.
    ///
    /// When `true` (the default), the server accepts `GET`, `POST` and
    /// `DELETE`, and SSE priming events are sent to enable client
    /// reconnection. When `false`, only `POST` is accepted.
    pub stateful_mode: bool,
    /// When true and `stateful_mode` is false, the server returns
    /// `Content-Type: application/json` directly instead of `text/event-stream`.
    /// This eliminates SSE framing overhead for simple request-response tools,
    /// as allowed by the MCP Streamable HTTP spec (2025-06-18).
    ///
    /// Defaults to `false`.
    pub json_response: bool,
    /// Cancellation token for the Streamable HTTP server.
    ///
    /// When this token is cancelled, all active sessions are terminated and
    /// the server stops accepting new requests. SSE responses are tied to
    /// child tokens of this token.
    pub cancellation_token: CancellationToken,
    /// Allowed hostnames or `host:port` authorities for inbound `Host` validation.
    ///
    /// By default only the loopback hosts `localhost`, `127.0.0.1` and `::1`
    /// are accepted, to prevent DNS rebinding attacks against locally running
    /// servers. Public deployments should override this list with their own
    /// hostnames.
    ///
    /// Matching rules:
    /// - an entry without a port (e.g. `"example.com"`) matches any port;
    /// - an entry with a port (e.g. `"example.com:8080"`) matches only that port;
    /// - hosts are compared ASCII case-insensitively, with `[`/`]` around IPv6
    ///   addresses stripped on both sides;
    /// - an empty list disables the check entirely (not recommended).
    ///
    /// # Example
    /// ```rust,ignore
    /// let config = StreamableHttpServerConfig::default()
    ///     .with_allowed_hosts(["example.com", "example.com:8080"]);
    /// ```
    pub allowed_hosts: Vec<String>,
    /// Optional external session store for cross-instance recovery.
    ///
    /// When set, [`SessionState`] (the client's `initialize` parameters) is
    /// persisted after a successful handshake and deleted when the session
    /// closes. On any subsequent request that arrives at an instance with no
    /// in-memory session, the store is consulted: if an entry is found the
    /// session is transparently restored so the client does not need to
    /// re-initialize.
    ///
    /// # Example
    /// ```rust,ignore
    /// use std::sync::Arc;
    /// use rmcp::transport::streamable_http_server::StreamableHttpServerConfig;
    ///
    /// // The config is `#[non_exhaustive]`, so struct-literal / `..Default::default()`
    /// // syntax is rejected outside this crate. Start from the default and
    /// // assign the public field instead.
    /// let mut config = StreamableHttpServerConfig::default();
    /// config.session_store = Some(Arc::new(MyRedisStore::new()));
    /// ```
    pub session_store: Option<Arc<dyn SessionStore>>,
}
90
91impl std::fmt::Debug for dyn SessionStore {
92    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
93        f.write_str("<SessionStore>")
94    }
95}
96
97impl Default for StreamableHttpServerConfig {
98    fn default() -> Self {
99        Self {
100            sse_keep_alive: Some(Duration::from_secs(15)),
101            sse_retry: Some(Duration::from_secs(3)),
102            stateful_mode: true,
103            json_response: false,
104            cancellation_token: CancellationToken::new(),
105            allowed_hosts: vec!["localhost".into(), "127.0.0.1".into(), "::1".into()],
106            session_store: None,
107        }
108    }
109}
110
111impl StreamableHttpServerConfig {
112    pub fn with_allowed_hosts(
113        mut self,
114        allowed_hosts: impl IntoIterator<Item = impl Into<String>>,
115    ) -> Self {
116        self.allowed_hosts = allowed_hosts.into_iter().map(Into::into).collect();
117        self
118    }
119    /// Disable allowed hosts. This will allow requests with any `Host` header, which is NOT recommended for public deployments.
120    pub fn disable_allowed_hosts(mut self) -> Self {
121        self.allowed_hosts.clear();
122        self
123    }
124    pub fn with_sse_keep_alive(mut self, duration: Option<Duration>) -> Self {
125        self.sse_keep_alive = duration;
126        self
127    }
128
129    pub fn with_sse_retry(mut self, duration: Option<Duration>) -> Self {
130        self.sse_retry = duration;
131        self
132    }
133
134    pub fn with_stateful_mode(mut self, stateful: bool) -> Self {
135        self.stateful_mode = stateful;
136        self
137    }
138
139    pub fn with_json_response(mut self, json_response: bool) -> Self {
140        self.json_response = json_response;
141        self
142    }
143
144    pub fn with_cancellation_token(mut self, token: CancellationToken) -> Self {
145        self.cancellation_token = token;
146        self
147    }
148}
149
150#[expect(
151    clippy::result_large_err,
152    reason = "BoxResponse is intentionally large; matches other handlers in this file"
153)]
154/// Validates the `MCP-Protocol-Version` header on incoming HTTP requests.
155///
156/// Per the MCP 2025-06-18 spec:
157/// - If the header is present but contains an unsupported version, return 400 Bad Request.
158/// - If the header is absent, assume `2025-03-26` for backwards compatibility (no error).
159fn validate_protocol_version_header(headers: &http::HeaderMap) -> Result<(), BoxResponse> {
160    if let Some(value) = headers.get(HEADER_MCP_PROTOCOL_VERSION) {
161        let version_str = value.to_str().map_err(|_| {
162            Response::builder()
163                .status(http::StatusCode::BAD_REQUEST)
164                .body(
165                    Full::new(Bytes::from(
166                        "Bad Request: Invalid MCP-Protocol-Version header encoding",
167                    ))
168                    .boxed(),
169                )
170                .expect("valid response")
171        })?;
172        let is_known = ProtocolVersion::KNOWN_VERSIONS
173            .iter()
174            .any(|v| v.as_str() == version_str);
175        if !is_known {
176            return Err(Response::builder()
177                .status(http::StatusCode::BAD_REQUEST)
178                .body(
179                    Full::new(Bytes::from(format!(
180                        "Bad Request: Unsupported MCP-Protocol-Version: {version_str}"
181                    )))
182                    .boxed(),
183                )
184                .expect("valid response"));
185        }
186    }
187    Ok(())
188}
189
190fn forbidden_response(message: impl Into<String>) -> BoxResponse {
191    Response::builder()
192        .status(http::StatusCode::FORBIDDEN)
193        .body(Full::new(Bytes::from(message.into())).boxed())
194        .expect("valid response")
195}
196
/// Canonicalise a host for comparison.
///
/// Lowercases ASCII letters, then trims `[` and `]` from both ends, so an
/// IPv6 literal such as `[::1]` compares equal to `::1`.
fn normalize_host(host: &str) -> String {
    // Lowercasing never touches brackets, so it can run before trimming.
    let lowered = host.to_ascii_lowercase();
    lowered.trim_matches('[').trim_matches(']').to_owned()
}
202
/// A `Host`/authority value reduced to a comparable form.
///
/// Built by `normalize_authority`. It is used both for the incoming `Host`
/// header and for each configured `allowed_hosts` entry.
#[derive(Debug, Clone, PartialEq, Eq)]
struct NormalizedAuthority {
    /// Lowercased host with surrounding IPv6 brackets trimmed.
    host: String,
    /// Explicit port, if one was given; `None` means "no port specified".
    port: Option<u16>,
}
208
209fn normalize_authority(host: &str, port: Option<u16>) -> NormalizedAuthority {
210    NormalizedAuthority {
211        host: normalize_host(host),
212        port,
213    }
214}
215
216fn parse_allowed_authority(allowed: &str) -> Option<NormalizedAuthority> {
217    let allowed = allowed.trim();
218    if allowed.is_empty() {
219        return None;
220    }
221
222    if let Ok(authority) = http::uri::Authority::try_from(allowed) {
223        return Some(normalize_authority(authority.host(), authority.port_u16()));
224    }
225
226    Some(normalize_authority(allowed, None))
227}
228
229fn host_is_allowed(host: &NormalizedAuthority, allowed_hosts: &[String]) -> bool {
230    if allowed_hosts.is_empty() {
231        // If the allowed hosts list is empty, allow all hosts (not recommended).
232        return true;
233    }
234    allowed_hosts
235        .iter()
236        .filter_map(|allowed| parse_allowed_authority(allowed))
237        .any(|allowed| {
238            allowed.host == host.host
239                && match allowed.port {
240                    Some(port) => host.port == Some(port),
241                    None => true,
242                }
243        })
244}
245
246fn bad_request_response(message: &str) -> BoxResponse {
247    let body = Full::from(message.to_string()).boxed();
248
249    http::Response::builder()
250        .status(http::StatusCode::BAD_REQUEST)
251        .header(http::header::CONTENT_TYPE, "text/plain; charset=utf-8")
252        .body(body)
253        .expect("failed to build bad request response")
254}
255
256fn parse_host_header(headers: &HeaderMap) -> Result<NormalizedAuthority, BoxResponse> {
257    let Some(host) = headers.get(http::header::HOST) else {
258        return Err(bad_request_response("Bad Request: missing Host header"));
259    };
260
261    let host = host
262        .to_str()
263        .map_err(|_| bad_request_response("Bad Request: Invalid Host header encoding"))?;
264    let authority = http::uri::Authority::try_from(host)
265        .map_err(|_| bad_request_response("Bad Request: Invalid Host header"))?;
266    Ok(normalize_authority(authority.host(), authority.port_u16()))
267}
268
269fn validate_dns_rebinding_headers(
270    headers: &HeaderMap,
271    config: &StreamableHttpServerConfig,
272) -> Result<(), BoxResponse> {
273    let host = parse_host_header(headers)?;
274    if !host_is_allowed(&host, &config.allowed_hosts) {
275        return Err(forbidden_response("Forbidden: Host header is not allowed"));
276    }
277
278    Ok(())
279}
280
/// # Streamable HTTP server
///
/// An HTTP service that implements the
/// [Streamable HTTP transport](https://modelcontextprotocol.io/specification/2025-11-25/basic/transports#streamable-http)
/// for MCP servers.
///
/// Every request is first checked against
/// [`StreamableHttpServerConfig::allowed_hosts`] (DNS-rebinding protection).
/// `POST` is always accepted; `GET` and `DELETE` are accepted only in
/// stateful mode. Other methods receive `405 Method Not Allowed`.
///
/// ## Session management
///
/// When [`StreamableHttpServerConfig::stateful_mode`] is `true` (the default),
/// the server creates a session for each client that sends an `initialize`
/// request. The session ID is returned in the `Mcp-Session-Id` response header
/// and the client must include it on all subsequent requests.
///
/// Two tool calls carrying the same `Mcp-Session-Id` come from the same logical
/// session (typically one conversation in an LLM client). Different session IDs
/// mean different sessions.
///
/// The [`SessionManager`] trait controls how sessions are stored and routed:
///
/// * [`LocalSessionManager`](super::session::local::LocalSessionManager) —
///   in-memory session store (default).
/// * [`NeverSessionManager`](super::session::never::NeverSessionManager) —
///   disables sessions entirely (stateless mode).
///
/// When [`StreamableHttpServerConfig::session_store`] is set, a request for a
/// session ID unknown to this instance triggers a restore from the store. The
/// stored `initialize` parameters are replayed so the client need not
/// re-initialize.
///
/// ## Accessing HTTP request data from tool handlers
///
/// The service consumes the request body but injects the remaining
/// [`http::request::Parts`] into [`crate::model::Extensions`], which is
/// accessible through [`crate::service::RequestContext`].
///
/// ### Reading the raw HTTP parts
///
/// ```rust
/// use rmcp::handler::server::tool::Extension;
/// use http::request::Parts;
/// async fn my_tool(Extension(parts): Extension<Parts>) {
///     tracing::info!("http parts:{parts:?}")
/// }
/// ```
///
/// ### Reading the session ID inside a tool handler
///
/// ```rust,ignore
/// use rmcp::handler::server::tool::Extension;
///
/// #[tool(description = "session-aware tool")]
/// async fn my_tool(
///     &self,
///     Extension(parts): Extension<http::request::Parts>,
/// ) -> Result<CallToolResult, rmcp::ErrorData> {
///     // Header lookup is case-insensitive.
///     if let Some(session_id) = parts.headers.get("mcp-session-id") {
///         tracing::info!(?session_id, "called from session");
///     }
///     // ...
///     # todo!()
/// }
/// ```
///
/// ### Accessing custom axum/tower extension state
///
/// State added via axum's `Extension` layer is available inside
/// `Parts.extensions`:
///
/// ```rust,ignore
/// use rmcp::service::RequestContext;
/// use rmcp::RoleServer;
///
/// #[derive(Clone)]
/// struct AppState { /* ... */ }
///
/// #[tool(description = "example")]
/// async fn my_tool(
///     &self,
///     ctx: RequestContext<RoleServer>,
/// ) -> Result<CallToolResult, rmcp::ErrorData> {
///     let parts = ctx.extensions.get::<http::request::Parts>().unwrap();
///     let state = parts.extensions.get::<AppState>().unwrap();
///     // use state...
///     # todo!()
/// }
/// ```
pub struct StreamableHttpService<S, M> {
    /// Configuration this service was created with.
    pub config: StreamableHttpServerConfig,
    /// Session manager that owns per-session transports and routes messages
    /// (initialize, accept, resume, standalone streams) to them.
    session_manager: Arc<M>,
    /// Factory invoked whenever a fresh MCP service instance is needed.
    service_factory: Arc<dyn Fn() -> Result<S, std::io::Error> + Send + Sync>,
    /// Tracks in-progress session restores so that concurrent requests for the
    /// same unknown session ID wait for the first restore to complete rather
    /// than racing to replay the initialize handshake. `None` when no external
    /// session store is configured (avoids allocating the map).
    pending_restores: Option<
        Arc<tokio::sync::RwLock<HashMap<SessionId, tokio::sync::watch::Sender<Option<bool>>>>>,
    >,
}
376
377impl<S, M> Clone for StreamableHttpService<S, M> {
378    fn clone(&self) -> Self {
379        Self {
380            config: self.config.clone(),
381            session_manager: self.session_manager.clone(),
382            service_factory: self.service_factory.clone(),
383            pending_restores: self.pending_restores.clone(),
384        }
385    }
386}
387
388impl<RequestBody, S, M> tower_service::Service<Request<RequestBody>> for StreamableHttpService<S, M>
389where
390    RequestBody: Body + Send + 'static,
391    S: crate::Service<RoleServer> + Send + 'static,
392    M: SessionManager,
393    RequestBody::Error: Display,
394    RequestBody::Data: Send + 'static,
395{
396    type Response = BoxResponse;
397    type Error = Infallible;
398    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
399    fn call(&mut self, req: http::Request<RequestBody>) -> Self::Future {
400        let service = self.clone();
401        Box::pin(async move {
402            let response = service.handle(req).await;
403            Ok(response)
404        })
405    }
406    fn poll_ready(
407        &mut self,
408        _cx: &mut std::task::Context<'_>,
409    ) -> std::task::Poll<Result<(), Self::Error>> {
410        std::task::Poll::Ready(Ok(()))
411    }
412}
413
/// Guard used inside [`StreamableHttpService::try_restore_from_store`].
///
/// Ensures the `pending_restores` map entry is always cleaned up — even when
/// the future is cancelled mid-await. On drop, it broadcasts `result` to every
/// task waiting on the restore and arranges for `session_id` to be removed
/// from `pending_restores`.
///
/// `result` defaults to `false` (failure / cancellation). Only the success path
/// needs to set it to `true` before returning.
struct PendingRestoreGuard {
    /// Shared map of in-progress restores, keyed by session ID.
    pending_restores:
        Arc<tokio::sync::RwLock<HashMap<SessionId, tokio::sync::watch::Sender<Option<bool>>>>>,
    /// Session whose restore this guard owns.
    session_id: SessionId,
    /// Outcome channel: `None` = in progress, `Some(true)` = restored,
    /// `Some(false)` = not found / failed.
    watch_tx: tokio::sync::watch::Sender<Option<bool>>,
    /// The value that will be broadcast to waiting tasks on drop.
    result: bool,
}
429
430impl Drop for PendingRestoreGuard {
431    fn drop(&mut self) {
432        // `send` is synchronous — unblocks waiters immediately, no lock needed.
433        let _ = self.watch_tx.send(Some(self.result));
434        // Remove the map entry asynchronously (requires the async write lock).
435        let pending_restores = self.pending_restores.clone();
436        let session_id = self.session_id.clone();
437        tokio::spawn(async move {
438            pending_restores.write().await.remove(&session_id);
439        });
440    }
441}
442
443impl<S, M> StreamableHttpService<S, M>
444where
445    S: crate::Service<RoleServer> + Send + 'static,
446    M: SessionManager,
447{
448    pub fn new(
449        service_factory: impl Fn() -> Result<S, std::io::Error> + Send + Sync + 'static,
450        session_manager: Arc<M>,
451        config: StreamableHttpServerConfig,
452    ) -> Self {
453        let pending_restores = config.session_store.is_some().then(|| {
454            Arc::new(tokio::sync::RwLock::new(HashMap::<
455                SessionId,
456                tokio::sync::watch::Sender<Option<bool>>,
457            >::new()))
458        });
459        Self {
460            config,
461            session_manager,
462            service_factory: Arc::new(service_factory),
463            pending_restores,
464        }
465    }
466    fn get_service(&self) -> Result<S, std::io::Error> {
467        (self.service_factory)()
468    }
469
    /// Spawn a task that runs `serve_server` for the given session, waits for
    /// it to finish, and then calls `close_session`.
    ///
    /// `init_done_tx`: when `Some`, the sender is fired after `serve_server`
    /// returns successfully, signalling to the caller that the MCP handshake
    /// is complete. `try_restore_from_store` uses this to synchronise with the
    /// replayed `initialize`. Callers that do not need to wait pass `None`.
    ///
    /// If `serve_server` fails, the sender is never fired. Because it is only
    /// moved out in the success arm, it is dropped when the spawned task ends,
    /// after `close_session` has completed. The receiver therefore observes the
    /// failure only once the session has been closed.
    fn spawn_session_worker(
        session_manager: Arc<M>,
        session_id: SessionId,
        service: S,
        transport: M::Transport,
        init_done_tx: Option<tokio::sync::oneshot::Sender<()>>,
    ) where
        S: crate::Service<RoleServer> + Send + 'static,
        M: SessionManager,
    {
        tokio::spawn(async move {
            let svc =
                serve_server::<S, M::Transport, _, TransportAdapterIdentity>(service, transport)
                    .await;
            match svc {
                Ok(svc) => {
                    if let Some(tx) = init_done_tx {
                        let _ = tx.send(());
                    }
                    // Run until the MCP service shuts down.
                    let _ = svc.waiting().await;
                }
                Err(e) => {
                    tracing::error!("Failed to serve session: {e}");
                    // `init_done_tx` is not fired here; it is dropped when this
                    // task finishes (after `close_session` below), which the
                    // caller observes as failure.
                }
            }
            // Always release the session once the worker stops, whether it ran or failed.
            let _ = session_manager
                .close_session(&session_id)
                .await
                .inspect_err(|e| {
                    tracing::error!("Failed to close session {session_id}: {e}");
                });
        });
    }
511
512    /// Attempt to restore a session from the external store.
513    ///
514    /// Returns `true` when the session is available and ready to serve the
515    /// current request (either just restored or already in memory). Returns
516    /// `false` when no store is configured or the session ID is unknown.
517    ///
518    /// Concurrent requests for the same unknown session ID are serialized: the
519    /// first caller performs the full restore and handshake replay while others
520    /// subscribe to a `watch` channel and wait, avoiding duplicate handshakes.
521    async fn try_restore_from_store(
522        &self,
523        session_id: &SessionId,
524        parts: &http::request::Parts,
525    ) -> Result<bool, std::io::Error>
526    where
527        S: crate::Service<RoleServer> + Send + 'static,
528        M: SessionManager,
529    {
530        // Both fields are Some iff a session store is configured.
531        let (Some(pending_restores), Some(store)) =
532            (&self.pending_restores, &self.config.session_store)
533        else {
534            return Ok(false);
535        };
536
537        // Serialize concurrent restores for the same session ID.
538        // Write-lock once: if another task is already restoring, subscribe and wait;
539        // otherwise, register ourselves as the restoring task.
540        // Channel value: None = in progress, Some(true) = restored, Some(false) = not found/failed.
541        let (watch_tx, _watch_rx) = tokio::sync::watch::channel(None::<bool>);
542        {
543            let mut pending = pending_restores.write().await;
544            if let Some(tx) = pending.get(session_id) {
545                let mut rx = tx.subscribe();
546                drop(pending);
547                // Wait for the restore to finish, then propagate the outcome.
548                let result = rx
549                    .wait_for(|r| r.is_some())
550                    .await
551                    .map(|r| r.unwrap_or(false))
552                    .unwrap_or(false);
553                return Ok(result);
554            }
555            pending.insert(session_id.clone(), watch_tx.clone());
556        }
557
558        // Guard: signals waiters and cleans up the map entry on drop
559        let mut guard = PendingRestoreGuard {
560            pending_restores: pending_restores.clone(),
561            session_id: session_id.clone(),
562            watch_tx: watch_tx.clone(),
563            result: false,
564        };
565
566        // --- Step 3: load from external store ---
567        let state = match store.load(session_id.as_ref()).await {
568            Ok(Some(s)) => s,
569            Ok(None) => {
570                return Ok(false);
571            }
572            Err(e) => {
573                tracing::error!(
574                    session_id = session_id.as_ref(),
575                    error = %e,
576                    "session store load failed during restore"
577                );
578                return Err(std::io::Error::other(e));
579            }
580        };
581
582        // --- Step 4: ask the session manager to allocate an in-memory worker ---
583        let transport = match self
584            .session_manager
585            .restore_session(session_id.clone())
586            .await
587            .map_err(|e| std::io::Error::other(e.to_string()))
588        {
589            Ok(RestoreOutcome::Restored(t)) => t,
590            Ok(RestoreOutcome::AlreadyPresent) => {
591                // Invariant violation: pending_restores ensures only one task can call
592                // restore_session per session ID, so AlreadyPresent is impossible here.
593                return Err(std::io::Error::other(
594                    "restore_session returned AlreadyPresent unexpectedly; session manager might have modified the session store outside of the restore_session API",
595                ));
596            }
597            Ok(RestoreOutcome::NotSupported) => {
598                return Ok(false);
599            }
600            Err(e) => {
601                return Err(e);
602            }
603        };
604
605        // --- Step 5: replay the MCP initialize handshake ---
606        let service = match self.get_service() {
607            Ok(s) => s,
608            Err(e) => {
609                return Err(e);
610            }
611        };
612
613        // `serve_server` requires both the `initialize` request and the
614        // `notifications/initialized` notification before transitioning to
615        // the running state — we must send both before returning.
616        let mut restore_init = ClientJsonRpcMessage::request(
617            ClientRequest::InitializeRequest(InitializeRequest {
618                params: state.initialize_params,
619                ..Default::default()
620            }),
621            crate::model::NumberOrString::Number(0),
622        );
623        restore_init.insert_extension(parts.clone());
624        restore_init.insert_extension(SessionRestoreMarker {
625            id: session_id.clone(),
626        });
627        let mut restore_initialized = ClientJsonRpcMessage::notification(
628            ClientNotification::InitializedNotification(InitializedNotification {
629                ..Default::default()
630            }),
631        );
632        restore_initialized.insert_extension(parts.clone());
633        restore_initialized.insert_extension(SessionRestoreMarker {
634            id: session_id.clone(),
635        });
636        // Signal from the spawned task once serve_server finishes initialising.
637        let (init_done_tx, init_done_rx) = tokio::sync::oneshot::channel::<()>();
638
639        Self::spawn_session_worker(
640            self.session_manager.clone(),
641            session_id.clone(),
642            service,
643            transport,
644            Some(init_done_tx),
645        );
646
647        if let Err(e) = self
648            .session_manager
649            .initialize_session(session_id, restore_init)
650            .await
651            .map_err(|e| std::io::Error::other(e.to_string()))
652        {
653            return Err(e);
654        }
655
656        if let Err(e) = self
657            .session_manager
658            .accept_message(session_id, restore_initialized)
659            .await
660            .map_err(|e| std::io::Error::other(e.to_string()))
661        {
662            return Err(e);
663        }
664
665        if init_done_rx.await.is_err() {
666            return Err(std::io::Error::other(
667                "serve_server initialization failed during restore",
668            ));
669        }
670
671        // Restore complete — wake any waiting concurrent requests.
672        guard.result = true;
673
674        tracing::debug!(
675            session_id = session_id.as_ref(),
676            "session restored from external store"
677        );
678        Ok(true)
679    }
680    pub async fn handle<B>(&self, request: Request<B>) -> Response<BoxBody<Bytes, Infallible>>
681    where
682        B: Body + Send + 'static,
683        B::Error: Display,
684    {
685        if let Err(response) = validate_dns_rebinding_headers(request.headers(), &self.config) {
686            return response;
687        }
688        let method = request.method().clone();
689        let allowed_methods = match self.config.stateful_mode {
690            true => "GET, POST, DELETE",
691            false => "POST",
692        };
693        let result = match (method, self.config.stateful_mode) {
694            (Method::POST, _) => self.handle_post(request).await,
695            // if we're not in stateful mode, we don't support GET or DELETE because there is no session
696            (Method::GET, true) => self.handle_get(request).await,
697            (Method::DELETE, true) => self.handle_delete(request).await,
698            _ => {
699                // Handle other methods or return an error
700                let response = Response::builder()
701                    .status(http::StatusCode::METHOD_NOT_ALLOWED)
702                    .header(ALLOW, allowed_methods)
703                    .body(Full::new(Bytes::from("Method Not Allowed")).boxed())
704                    .expect("valid response");
705                return response;
706            }
707        };
708        match result {
709            Ok(response) => response,
710            Err(response) => response,
711        }
712    }
713    async fn handle_get<B>(&self, request: Request<B>) -> Result<BoxResponse, BoxResponse>
714    where
715        B: Body + Send + 'static,
716        B::Error: Display,
717    {
718        // check accept header
719        if !request
720            .headers()
721            .get(http::header::ACCEPT)
722            .and_then(|header| header.to_str().ok())
723            .is_some_and(|header| header.contains(EVENT_STREAM_MIME_TYPE))
724        {
725            return Ok(Response::builder()
726                .status(http::StatusCode::NOT_ACCEPTABLE)
727                .body(
728                    Full::new(Bytes::from(
729                        "Not Acceptable: Client must accept text/event-stream",
730                    ))
731                    .boxed(),
732                )
733                .expect("valid response"));
734        }
735        // check session id
736        let session_id = request
737            .headers()
738            .get(HEADER_SESSION_ID)
739            .and_then(|v| v.to_str().ok())
740            .map(|s| s.to_owned().into());
741        let Some(session_id) = session_id else {
742            // MCP spec: servers that require a session ID SHOULD respond with 400 Bad Request
743            return Ok(Response::builder()
744                .status(http::StatusCode::BAD_REQUEST)
745                .body(Full::new(Bytes::from("Bad Request: Session ID is required")).boxed())
746                .expect("valid response"));
747        };
748        // check if session exists
749        let has_session = self
750            .session_manager
751            .has_session(&session_id)
752            .await
753            .map_err(internal_error_response("check session"))?;
754        let (parts, _) = request.into_parts();
755        if !has_session {
756            // Attempt transparent cross-instance restore from external store.
757            let restored = self
758                .try_restore_from_store(&session_id, &parts)
759                .await
760                .map_err(internal_error_response("restore session"))?;
761            if !restored {
762                // MCP spec: server MUST respond with 404 Not Found for terminated/unknown sessions
763                return Ok(Response::builder()
764                    .status(http::StatusCode::NOT_FOUND)
765                    .body(Full::new(Bytes::from("Not Found: Session not found")).boxed())
766                    .expect("valid response"));
767            }
768        }
769        // Validate MCP-Protocol-Version header (per 2025-06-18 spec)
770        validate_protocol_version_header(&parts.headers)?;
771        // check if last event id is provided
772        let last_event_id = parts
773            .headers
774            .get(HEADER_LAST_EVENT_ID)
775            .and_then(|v| v.to_str().ok())
776            .map(|s| s.to_owned());
777        if let Some(last_event_id) = last_event_id {
778            match self
779                .session_manager
780                .resume(&session_id, last_event_id)
781                .await
782            {
783                Ok(stream) => {
784                    return Ok(sse_stream_response(
785                        stream,
786                        self.config.sse_keep_alive,
787                        self.config.cancellation_token.child_token(),
788                    ));
789                }
790                Err(e) => {
791                    // Return 200 with an immediately-closed empty stream.
792                    // Returning an HTTP error would cause EventSource to retry
793                    // with the same Last-Event-ID in an infinite loop. An empty
794                    // 200 cleanly terminates the EventSource without delivering
795                    // events from a different stream.
796                    tracing::warn!("Resume failed ({e}), returning empty stream");
797                    return Ok(sse_stream_response(
798                        futures::stream::empty(),
799                        None,
800                        self.config.cancellation_token.child_token(),
801                    ));
802                }
803            }
804        }
805        // No Last-Event-ID — create standalone stream
806        let stream = self
807            .session_manager
808            .create_standalone_stream(&session_id)
809            .await
810            .map_err(internal_error_response("create standalone stream"))?;
811        let stream = if let Some(retry) = self.config.sse_retry {
812            let priming = ServerSseMessage::priming("0", retry);
813            futures::stream::once(async move { priming })
814                .chain(stream)
815                .left_stream()
816        } else {
817            stream.right_stream()
818        };
819        Ok(sse_stream_response(
820            stream,
821            self.config.sse_keep_alive,
822            self.config.cancellation_token.child_token(),
823        ))
824    }
825
826    async fn handle_post<B>(&self, request: Request<B>) -> Result<BoxResponse, BoxResponse>
827    where
828        B: Body + Send + 'static,
829        B::Error: Display,
830    {
831        // check accept header
832        if !request
833            .headers()
834            .get(http::header::ACCEPT)
835            .and_then(|header| header.to_str().ok())
836            .is_some_and(|header| {
837                header.contains(JSON_MIME_TYPE) && header.contains(EVENT_STREAM_MIME_TYPE)
838            })
839        {
840            return Ok(Response::builder()
841                .status(http::StatusCode::NOT_ACCEPTABLE)
842                .body(Full::new(Bytes::from("Not Acceptable: Client must accept both application/json and text/event-stream")).boxed())
843                .expect("valid response"));
844        }
845
846        // check content type
847        if !request
848            .headers()
849            .get(http::header::CONTENT_TYPE)
850            .and_then(|header| header.to_str().ok())
851            .is_some_and(|header| header.starts_with(JSON_MIME_TYPE))
852        {
853            return Ok(Response::builder()
854                .status(http::StatusCode::UNSUPPORTED_MEDIA_TYPE)
855                .body(
856                    Full::new(Bytes::from(
857                        "Unsupported Media Type: Content-Type must be application/json",
858                    ))
859                    .boxed(),
860                )
861                .expect("valid response"));
862        }
863
864        // json deserialize request body
865        let (part, body) = request.into_parts();
866        let mut message = match expect_json(body).await {
867            Ok(message) => message,
868            Err(response) => return Ok(response),
869        };
870
871        if self.config.stateful_mode {
872            // do we have a session id?
873            let session_id = part
874                .headers
875                .get(HEADER_SESSION_ID)
876                .and_then(|v| v.to_str().ok());
877            if let Some(session_id) = session_id {
878                let session_id = session_id.to_owned().into();
879                let has_session = self
880                    .session_manager
881                    .has_session(&session_id)
882                    .await
883                    .map_err(internal_error_response("check session"))?;
884                if !has_session {
885                    // Attempt transparent cross-instance restore from external store.
886                    let restored = self
887                        .try_restore_from_store(&session_id, &part)
888                        .await
889                        .map_err(internal_error_response("restore session"))?;
890                    if !restored {
891                        // MCP spec: server MUST respond with 404 Not Found for terminated/unknown sessions
892                        return Ok(Response::builder()
893                            .status(http::StatusCode::NOT_FOUND)
894                            .body(Full::new(Bytes::from("Not Found: Session not found")).boxed())
895                            .expect("valid response"));
896                    }
897                }
898
899                // Validate MCP-Protocol-Version header (per 2025-06-18 spec)
900                validate_protocol_version_header(&part.headers)?;
901
902                // inject request part to extensions
903                match &mut message {
904                    ClientJsonRpcMessage::Request(req) => {
905                        req.request.extensions_mut().insert(part);
906                    }
907                    ClientJsonRpcMessage::Notification(not) => {
908                        not.notification.extensions_mut().insert(part);
909                    }
910                    _ => {
911                        // skip
912                    }
913                }
914
915                match message {
916                    ClientJsonRpcMessage::Request(_) => {
917                        // Priming for request-wise streams is handled by the
918                        // session layer (SessionManager::create_stream) which
919                        // has access to the http_request_id for correct event IDs.
920                        let stream = self
921                            .session_manager
922                            .create_stream(&session_id, message)
923                            .await
924                            .map_err(internal_error_response("get session"))?;
925                        Ok(sse_stream_response(
926                            stream,
927                            self.config.sse_keep_alive,
928                            self.config.cancellation_token.child_token(),
929                        ))
930                    }
931                    ClientJsonRpcMessage::Notification(_)
932                    | ClientJsonRpcMessage::Response(_)
933                    | ClientJsonRpcMessage::Error(_) => {
934                        // handle notification
935                        self.session_manager
936                            .accept_message(&session_id, message)
937                            .await
938                            .map_err(internal_error_response("accept message"))?;
939                        Ok(accepted_response())
940                    }
941                }
942            } else {
943                let (session_id, transport) = self
944                    .session_manager
945                    .create_session()
946                    .await
947                    .map_err(internal_error_response("create session"))?;
948                // Capture init params for external store persistence before
949                // extensions are injected (which would require Clone).
950                let stored_init_params = if self.config.session_store.is_some() {
951                    if let ClientJsonRpcMessage::Request(req) = &message {
952                        if let ClientRequest::InitializeRequest(init_req) = &req.request {
953                            Some(init_req.params.clone())
954                        } else {
955                            None
956                        }
957                    } else {
958                        None
959                    }
960                } else {
961                    None
962                };
963                if let ClientJsonRpcMessage::Request(req) = &mut message {
964                    if !matches!(req.request, ClientRequest::InitializeRequest(_)) {
965                        return Err(unexpected_message_response("initialize request"));
966                    }
967                    // inject request part to extensions
968                    req.request.extensions_mut().insert(part);
969                } else {
970                    return Err(unexpected_message_response("initialize request"));
971                }
972                let service = self
973                    .get_service()
974                    .map_err(internal_error_response("get service"))?;
975                // spawn a task to serve the session
976                Self::spawn_session_worker(
977                    self.session_manager.clone(),
978                    session_id.clone(),
979                    service,
980                    transport,
981                    None,
982                );
983                // get initialize response
984                let response = self
985                    .session_manager
986                    .initialize_session(&session_id, message)
987                    .await
988                    .map_err(internal_error_response("create stream"))?;
989                // Persist session state to external store after a successful handshake.
990                if let (Some(store), Some(params)) =
991                    (&self.config.session_store, stored_init_params)
992                {
993                    let state = SessionState {
994                        initialize_params: params,
995                    };
996                    let _ = store
997                        .store(session_id.as_ref(), &state)
998                        .await
999                        .inspect_err(|e| {
1000                            tracing::warn!(
1001                                "Failed to persist session {} to store: {e}",
1002                                session_id
1003                            );
1004                        });
1005                }
1006                let stream =
1007                    futures::stream::once(async move { ServerSseMessage::from_message(response) });
1008                // Prepend priming event if sse_retry configured
1009                let stream = if let Some(retry) = self.config.sse_retry {
1010                    let priming = ServerSseMessage::priming("0", retry);
1011                    futures::stream::once(async move { priming })
1012                        .chain(stream)
1013                        .left_stream()
1014                } else {
1015                    stream.right_stream()
1016                };
1017                let mut response = sse_stream_response(
1018                    stream,
1019                    self.config.sse_keep_alive,
1020                    self.config.cancellation_token.child_token(),
1021                );
1022
1023                response.headers_mut().insert(
1024                    HEADER_SESSION_ID,
1025                    session_id
1026                        .parse()
1027                        .map_err(internal_error_response("create session id header"))?,
1028                );
1029                Ok(response)
1030            }
1031        } else {
1032            // Stateless mode: validate MCP-Protocol-Version on non-init requests
1033            let is_init = matches!(
1034                &message,
1035                ClientJsonRpcMessage::Request(req) if matches!(req.request, ClientRequest::InitializeRequest(_))
1036            );
1037            if !is_init {
1038                validate_protocol_version_header(&part.headers)?;
1039            }
1040            let service = self
1041                .get_service()
1042                .map_err(internal_error_response("get service"))?;
1043            match message {
1044                ClientJsonRpcMessage::Request(mut request) => {
1045                    request.request.extensions_mut().insert(part);
1046                    let (transport, mut receiver) =
1047                        OneshotTransport::<RoleServer>::new(ClientJsonRpcMessage::Request(request));
1048                    let service = serve_directly(service, transport, None);
1049                    tokio::spawn(async move {
1050                        // on service created
1051                        let _ = service.waiting().await;
1052                    });
1053                    if self.config.json_response {
1054                        // JSON-direct mode: await the single response and return as
1055                        // application/json, eliminating SSE framing overhead.
1056                        // Allowed by MCP Streamable HTTP spec (2025-06-18).
1057                        let cancel = self.config.cancellation_token.child_token();
1058                        match tokio::select! {
1059                            res = receiver.recv() => res,
1060                            _ = cancel.cancelled() => None,
1061                        } {
1062                            Some(message) => {
1063                                tracing::trace!(?message);
1064                                let body = serde_json::to_vec(&message).map_err(|e| {
1065                                    internal_error_response("serialize json response")(e)
1066                                })?;
1067                                Ok(Response::builder()
1068                                    .status(http::StatusCode::OK)
1069                                    .header(http::header::CONTENT_TYPE, JSON_MIME_TYPE)
1070                                    .body(Full::new(Bytes::from(body)).boxed())
1071                                    .expect("valid response"))
1072                            }
1073                            None => Err(internal_error_response("empty response")(
1074                                std::io::Error::new(
1075                                    std::io::ErrorKind::UnexpectedEof,
1076                                    "no response message received from handler",
1077                                ),
1078                            )),
1079                        }
1080                    } else {
1081                        // SSE mode (default): original behaviour preserved unchanged
1082                        let stream = ReceiverStream::new(receiver).map(|message| {
1083                            tracing::trace!(?message);
1084                            ServerSseMessage::from_message(message)
1085                        });
1086                        Ok(sse_stream_response(
1087                            stream,
1088                            self.config.sse_keep_alive,
1089                            self.config.cancellation_token.child_token(),
1090                        ))
1091                    }
1092                }
1093                ClientJsonRpcMessage::Notification(_notification) => {
1094                    // ignore
1095                    Ok(accepted_response())
1096                }
1097                ClientJsonRpcMessage::Response(_json_rpc_response) => Ok(accepted_response()),
1098                ClientJsonRpcMessage::Error(_json_rpc_error) => Ok(accepted_response()),
1099            }
1100        }
1101    }
1102
1103    async fn handle_delete<B>(&self, request: Request<B>) -> Result<BoxResponse, BoxResponse>
1104    where
1105        B: Body + Send + 'static,
1106        B::Error: Display,
1107    {
1108        // check session id
1109        let session_id = request
1110            .headers()
1111            .get(HEADER_SESSION_ID)
1112            .and_then(|v| v.to_str().ok())
1113            .map(|s| s.to_owned().into());
1114        let Some(session_id) = session_id else {
1115            // MCP spec: servers that require a session ID SHOULD respond with 400 Bad Request
1116            return Ok(Response::builder()
1117                .status(http::StatusCode::BAD_REQUEST)
1118                .body(Full::new(Bytes::from("Bad Request: Session ID is required")).boxed())
1119                .expect("valid response"));
1120        };
1121        // Validate MCP-Protocol-Version header (per 2025-06-18 spec)
1122        validate_protocol_version_header(request.headers())?;
1123        // close session
1124        self.session_manager
1125            .close_session(&session_id)
1126            .await
1127            .map_err(internal_error_response("close session"))?;
1128        // Remove from external store: a DELETE means the client intentionally
1129        // ends the session, so the store entry is no longer needed.
1130        if let Some(store) = &self.config.session_store {
1131            let _ = store.delete(session_id.as_ref()).await.inspect_err(|e| {
1132                tracing::warn!("Failed to delete session {} from store: {e}", session_id);
1133            });
1134        }
1135        Ok(accepted_response())
1136    }
1137}