Skip to main content

mcpkit_rs/transport/streamable_http_server/
tower.rs

1use std::{convert::Infallible, fmt::Display, sync::Arc, time::Duration};
2
3use bytes::Bytes;
4use futures::{StreamExt, future::BoxFuture};
5use http::{Method, Request, Response, header::ALLOW};
6use http_body::Body;
7use http_body_util::{BodyExt, Full, combinators::BoxBody};
8use tokio_stream::wrappers::ReceiverStream;
9use tokio_util::sync::CancellationToken;
10
11use super::session::SessionManager;
12use crate::{
13    RoleServer,
14    model::{ClientJsonRpcMessage, ClientRequest, GetExtensions},
15    serve_server,
16    service::serve_directly,
17    transport::{
18        OneshotTransport, TransportAdapterIdentity,
19        common::{
20            http_header::{
21                EVENT_STREAM_MIME_TYPE, HEADER_LAST_EVENT_ID, HEADER_SESSION_ID, JSON_MIME_TYPE,
22            },
23            server_side_http::{
24                BoxResponse, ServerSseMessage, accepted_response, expect_json,
25                internal_error_response, sse_stream_response, unexpected_message_response,
26            },
27        },
28    },
29};
30
31#[derive(Debug, Clone)]
32pub struct StreamableHttpServerConfig {
33    /// The ping message duration for SSE connections.
34    pub sse_keep_alive: Option<Duration>,
35    /// The retry interval for SSE priming events.
36    pub sse_retry: Option<Duration>,
37    /// If true, the server will create a session for each request and keep it alive.
38    /// When enabled, SSE priming events are sent to enable client reconnection.
39    pub stateful_mode: bool,
40    /// Cancellation token for the Streamable HTTP server.
41    ///
42    /// When this token is cancelled, all active sessions are terminated and
43    /// the server stops accepting new requests.
44    pub cancellation_token: CancellationToken,
45}
46
47impl Default for StreamableHttpServerConfig {
48    fn default() -> Self {
49        Self {
50            sse_keep_alive: Some(Duration::from_secs(15)),
51            sse_retry: Some(Duration::from_secs(3)),
52            stateful_mode: true,
53            cancellation_token: CancellationToken::new(),
54        }
55    }
56}
57
58/// # Streamable Http Server
59///
60/// ## Extract information from raw http request
61///
62/// The http service will consume the request body, however the rest part will be remain and injected into [`crate::model::Extensions`],
63/// which you can get from [`crate::service::RequestContext`].
64/// ```rust
65/// use mcpkit_rs::handler::server::tool::Extension;
66/// use http::request::Parts;
67/// async fn my_tool(Extension(parts): Extension<Parts>) {
68///     tracing::info!("http parts:{parts:?}")
69/// }
70/// ```
71pub struct StreamableHttpService<S, M = super::session::local::LocalSessionManager> {
72    pub config: StreamableHttpServerConfig,
73    session_manager: Arc<M>,
74    service_factory: Arc<dyn Fn() -> Result<S, std::io::Error> + Send + Sync>,
75}
76
77impl<S, M> Clone for StreamableHttpService<S, M> {
78    fn clone(&self) -> Self {
79        Self {
80            config: self.config.clone(),
81            session_manager: self.session_manager.clone(),
82            service_factory: self.service_factory.clone(),
83        }
84    }
85}
86
87impl<RequestBody, S, M> tower_service::Service<Request<RequestBody>> for StreamableHttpService<S, M>
88where
89    RequestBody: Body + Send + 'static,
90    S: crate::Service<RoleServer>,
91    M: SessionManager,
92    RequestBody::Error: Display,
93    RequestBody::Data: Send + 'static,
94{
95    type Response = BoxResponse;
96    type Error = Infallible;
97    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
98    fn call(&mut self, req: http::Request<RequestBody>) -> Self::Future {
99        let service = self.clone();
100        Box::pin(async move {
101            let response = service.handle(req).await;
102            Ok(response)
103        })
104    }
105    fn poll_ready(
106        &mut self,
107        _cx: &mut std::task::Context<'_>,
108    ) -> std::task::Poll<Result<(), Self::Error>> {
109        std::task::Poll::Ready(Ok(()))
110    }
111}
112
113impl<S, M> StreamableHttpService<S, M>
114where
115    S: crate::Service<RoleServer> + Send + 'static,
116    M: SessionManager,
117{
118    pub fn new(
119        service_factory: impl Fn() -> Result<S, std::io::Error> + Send + Sync + 'static,
120        session_manager: Arc<M>,
121        config: StreamableHttpServerConfig,
122    ) -> Self {
123        Self {
124            config,
125            session_manager,
126            service_factory: Arc::new(service_factory),
127        }
128    }
129    fn get_service(&self) -> Result<S, std::io::Error> {
130        (self.service_factory)()
131    }
132    pub async fn handle<B>(&self, request: Request<B>) -> Response<BoxBody<Bytes, Infallible>>
133    where
134        B: Body + Send + 'static,
135        B::Error: Display,
136    {
137        let method = request.method().clone();
138        let allowed_methods = match self.config.stateful_mode {
139            true => "GET, POST, DELETE, OPTIONS",
140            false => "POST, OPTIONS",
141        };
142
143        // Handle OPTIONS for CORS preflight
144        if method == Method::OPTIONS {
145            let response = Response::builder()
146                .status(http::StatusCode::NO_CONTENT)
147                .header(ALLOW, allowed_methods)
148                .header("Access-Control-Allow-Origin", "*")
149                .header("Access-Control-Allow-Methods", allowed_methods)
150                .header(
151                    "Access-Control-Allow-Headers",
152                    "Content-Type, Accept, X-Session-Id, X-Last-Event-Id",
153                )
154                .header("Access-Control-Max-Age", "3600")
155                .body(Full::new(Bytes::new()).boxed())
156                .expect("valid response");
157            return response;
158        }
159
160        let result = match (method, self.config.stateful_mode) {
161            (Method::POST, _) => self.handle_post(request).await,
162            // if we're not in stateful mode, we don't support GET or DELETE because there is no session
163            (Method::GET, true) => self.handle_get(request).await,
164            (Method::DELETE, true) => self.handle_delete(request).await,
165            _ => {
166                // Handle other methods or return an error
167                let response = Response::builder()
168                    .status(http::StatusCode::METHOD_NOT_ALLOWED)
169                    .header(ALLOW, allowed_methods)
170                    .body(Full::new(Bytes::from("Method Not Allowed")).boxed())
171                    .expect("valid response");
172                return response;
173            }
174        };
175        match result {
176            Ok(response) => response,
177            Err(response) => response,
178        }
179    }
180    async fn handle_get<B>(&self, request: Request<B>) -> Result<BoxResponse, BoxResponse>
181    where
182        B: Body + Send + 'static,
183        B::Error: Display,
184    {
185        // check accept header
186        if !request
187            .headers()
188            .get(http::header::ACCEPT)
189            .and_then(|header| header.to_str().ok())
190            .is_some_and(|header| header.contains(EVENT_STREAM_MIME_TYPE))
191        {
192            return Ok(Response::builder()
193                .status(http::StatusCode::NOT_ACCEPTABLE)
194                .body(
195                    Full::new(Bytes::from(
196                        "Not Acceptable: Client must accept text/event-stream",
197                    ))
198                    .boxed(),
199                )
200                .expect("valid response"));
201        }
202        // check session id
203        let session_id = request
204            .headers()
205            .get(HEADER_SESSION_ID)
206            .and_then(|v| v.to_str().ok())
207            .map(|s| s.to_owned().into());
208        let Some(session_id) = session_id else {
209            // unauthorized
210            return Ok(Response::builder()
211                .status(http::StatusCode::UNAUTHORIZED)
212                .body(Full::new(Bytes::from("Unauthorized: Session ID is required")).boxed())
213                .expect("valid response"));
214        };
215        // check if session exists
216        let has_session = self
217            .session_manager
218            .has_session(&session_id)
219            .await
220            .map_err(internal_error_response("check session"))?;
221        if !has_session {
222            // unauthorized
223            return Ok(Response::builder()
224                .status(http::StatusCode::UNAUTHORIZED)
225                .body(Full::new(Bytes::from("Unauthorized: Session not found")).boxed())
226                .expect("valid response"));
227        }
228        // check if last event id is provided
229        let last_event_id = request
230            .headers()
231            .get(HEADER_LAST_EVENT_ID)
232            .and_then(|v| v.to_str().ok())
233            .map(|s| s.to_owned());
234        if let Some(last_event_id) = last_event_id {
235            // check if session has this event id
236            let stream = self
237                .session_manager
238                .resume(&session_id, last_event_id)
239                .await
240                .map_err(internal_error_response("resume session"))?;
241            // Resume doesn't need priming - client already has the event ID
242            Ok(sse_stream_response(
243                stream,
244                self.config.sse_keep_alive,
245                self.config.cancellation_token.child_token(),
246            ))
247        } else {
248            // create standalone stream
249            let stream = self
250                .session_manager
251                .create_standalone_stream(&session_id)
252                .await
253                .map_err(internal_error_response("create standalone stream"))?;
254            // Prepend priming event if sse_retry configured
255            let stream = if let Some(retry) = self.config.sse_retry {
256                let priming = ServerSseMessage {
257                    event_id: Some("0".into()),
258                    message: None,
259                    retry: Some(retry),
260                };
261                futures::stream::once(async move { priming })
262                    .chain(stream)
263                    .left_stream()
264            } else {
265                stream.right_stream()
266            };
267            Ok(sse_stream_response(
268                stream,
269                self.config.sse_keep_alive,
270                self.config.cancellation_token.child_token(),
271            ))
272        }
273    }
274
275    async fn handle_post<B>(&self, request: Request<B>) -> Result<BoxResponse, BoxResponse>
276    where
277        B: Body + Send + 'static,
278        B::Error: Display,
279    {
280        // check accept header - must accept at least text/event-stream
281        // also accept application/json for compatibility
282        if !request
283            .headers()
284            .get(http::header::ACCEPT)
285            .and_then(|header| header.to_str().ok())
286            .is_some_and(|header| {
287                header.contains(EVENT_STREAM_MIME_TYPE)
288                    || header.contains(JSON_MIME_TYPE)
289                    || header.contains("*/*") // Accept all
290            })
291        {
292            return Ok(Response::builder()
293                .status(http::StatusCode::NOT_ACCEPTABLE)
294                .body(
295                    Full::new(Bytes::from(
296                        "Not Acceptable: Client must accept text/event-stream",
297                    ))
298                    .boxed(),
299                )
300                .expect("valid response"));
301        }
302
303        // check content type
304        if !request
305            .headers()
306            .get(http::header::CONTENT_TYPE)
307            .and_then(|header| header.to_str().ok())
308            .is_some_and(|header| header.starts_with(JSON_MIME_TYPE))
309        {
310            return Ok(Response::builder()
311                .status(http::StatusCode::UNSUPPORTED_MEDIA_TYPE)
312                .body(
313                    Full::new(Bytes::from(
314                        "Unsupported Media Type: Content-Type must be application/json",
315                    ))
316                    .boxed(),
317                )
318                .expect("valid response"));
319        }
320
321        // json deserialize request body
322        let (part, body) = request.into_parts();
323        let mut message = match expect_json(body).await {
324            Ok(message) => message,
325            Err(response) => return Ok(response),
326        };
327
328        if self.config.stateful_mode {
329            // do we have a session id?
330            let session_id = part
331                .headers
332                .get(HEADER_SESSION_ID)
333                .and_then(|v| v.to_str().ok());
334            if let Some(session_id) = session_id {
335                let session_id = session_id.to_owned().into();
336                let has_session = self
337                    .session_manager
338                    .has_session(&session_id)
339                    .await
340                    .map_err(internal_error_response("check session"))?;
341                if !has_session {
342                    // unauthorized
343                    return Ok(Response::builder()
344                        .status(http::StatusCode::UNAUTHORIZED)
345                        .body(Full::new(Bytes::from("Unauthorized: Session not found")).boxed())
346                        .expect("valid response"));
347                }
348
349                // inject request part to extensions
350                match &mut message {
351                    ClientJsonRpcMessage::Request(req) => {
352                        req.request.extensions_mut().insert(part);
353                    }
354                    ClientJsonRpcMessage::Notification(not) => {
355                        not.notification.extensions_mut().insert(part);
356                    }
357                    _ => {
358                        // skip
359                    }
360                }
361
362                match message {
363                    ClientJsonRpcMessage::Request(_) => {
364                        let stream = self
365                            .session_manager
366                            .create_stream(&session_id, message)
367                            .await
368                            .map_err(internal_error_response("get session"))?;
369                        // Prepend priming event if sse_retry configured
370                        let stream = if let Some(retry) = self.config.sse_retry {
371                            let priming = ServerSseMessage {
372                                event_id: Some("0".into()),
373                                message: None,
374                                retry: Some(retry),
375                            };
376                            futures::stream::once(async move { priming })
377                                .chain(stream)
378                                .left_stream()
379                        } else {
380                            stream.right_stream()
381                        };
382                        Ok(sse_stream_response(
383                            stream,
384                            self.config.sse_keep_alive,
385                            self.config.cancellation_token.child_token(),
386                        ))
387                    }
388                    ClientJsonRpcMessage::Notification(_)
389                    | ClientJsonRpcMessage::Response(_)
390                    | ClientJsonRpcMessage::Error(_) => {
391                        // handle notification
392                        self.session_manager
393                            .accept_message(&session_id, message)
394                            .await
395                            .map_err(internal_error_response("accept message"))?;
396                        Ok(accepted_response())
397                    }
398                }
399            } else {
400                let (session_id, transport) = self
401                    .session_manager
402                    .create_session()
403                    .await
404                    .map_err(internal_error_response("create session"))?;
405                if let ClientJsonRpcMessage::Request(req) = &mut message {
406                    if !matches!(req.request, ClientRequest::InitializeRequest(_)) {
407                        return Err(unexpected_message_response("initialize request"));
408                    }
409                    // inject request part to extensions
410                    req.request.extensions_mut().insert(part);
411                } else {
412                    return Err(unexpected_message_response("initialize request"));
413                }
414                let service = self
415                    .get_service()
416                    .map_err(internal_error_response("get service"))?;
417                // spawn a task to serve the session
418                tokio::spawn({
419                    let session_manager = self.session_manager.clone();
420                    let session_id = session_id.clone();
421                    async move {
422                        let service = serve_server::<S, M::Transport, _, TransportAdapterIdentity>(
423                            service, transport,
424                        )
425                        .await;
426                        match service {
427                            Ok(service) => {
428                                // on service created
429                                let _ = service.waiting().await;
430                            }
431                            Err(e) => {
432                                tracing::error!("Failed to create service: {e}");
433                            }
434                        }
435                        let _ = session_manager
436                            .close_session(&session_id)
437                            .await
438                            .inspect_err(|e| {
439                                tracing::error!("Failed to close session {session_id}: {e}");
440                            });
441                    }
442                });
443                // get initialize response
444                let response = self
445                    .session_manager
446                    .initialize_session(&session_id, message)
447                    .await
448                    .map_err(internal_error_response("create stream"))?;
449                let stream = futures::stream::once(async move {
450                    ServerSseMessage {
451                        event_id: None,
452                        message: Some(Arc::new(response)),
453                        retry: None,
454                    }
455                });
456                // Prepend priming event if sse_retry configured
457                let stream = if let Some(retry) = self.config.sse_retry {
458                    let priming = ServerSseMessage {
459                        event_id: Some("0".into()),
460                        message: None,
461                        retry: Some(retry),
462                    };
463                    futures::stream::once(async move { priming })
464                        .chain(stream)
465                        .left_stream()
466                } else {
467                    stream.right_stream()
468                };
469                let mut response = sse_stream_response(
470                    stream,
471                    self.config.sse_keep_alive,
472                    self.config.cancellation_token.child_token(),
473                );
474
475                response.headers_mut().insert(
476                    HEADER_SESSION_ID,
477                    session_id
478                        .parse()
479                        .map_err(internal_error_response("create session id header"))?,
480                );
481                Ok(response)
482            }
483        } else {
484            let service = self
485                .get_service()
486                .map_err(internal_error_response("get service"))?;
487            match message {
488                ClientJsonRpcMessage::Request(mut request) => {
489                    request.request.extensions_mut().insert(part);
490                    let (transport, receiver) =
491                        OneshotTransport::<RoleServer>::new(ClientJsonRpcMessage::Request(request));
492                    let service = serve_directly(service, transport, None);
493                    tokio::spawn(async move {
494                        // on service created
495                        let _ = service.waiting().await;
496                    });
497                    // Stateless mode: no priming (no session to resume)
498                    let stream = ReceiverStream::new(receiver).map(|message| {
499                        tracing::info!(?message);
500                        ServerSseMessage {
501                            event_id: None,
502                            message: Some(Arc::new(message)),
503                            retry: None,
504                        }
505                    });
506                    Ok(sse_stream_response(
507                        stream,
508                        self.config.sse_keep_alive,
509                        self.config.cancellation_token.child_token(),
510                    ))
511                }
512                ClientJsonRpcMessage::Notification(_notification) => {
513                    // ignore
514                    Ok(accepted_response())
515                }
516                ClientJsonRpcMessage::Response(_json_rpc_response) => Ok(accepted_response()),
517                ClientJsonRpcMessage::Error(_json_rpc_error) => Ok(accepted_response()),
518            }
519        }
520    }
521
522    async fn handle_delete<B>(&self, request: Request<B>) -> Result<BoxResponse, BoxResponse>
523    where
524        B: Body + Send + 'static,
525        B::Error: Display,
526    {
527        // check session id
528        let session_id = request
529            .headers()
530            .get(HEADER_SESSION_ID)
531            .and_then(|v| v.to_str().ok())
532            .map(|s| s.to_owned().into());
533        let Some(session_id) = session_id else {
534            // unauthorized
535            return Ok(Response::builder()
536                .status(http::StatusCode::UNAUTHORIZED)
537                .body(Full::new(Bytes::from("Unauthorized: Session ID is required")).boxed())
538                .expect("valid response"));
539        };
540        // close session
541        self.session_manager
542            .close_session(&session_id)
543            .await
544            .map_err(internal_error_response("close session"))?;
545        Ok(accepted_response())
546    }
547}