1use std::{convert::Infallible, fmt::Display, sync::Arc, time::Duration};
2
3use bytes::Bytes;
4use futures::{StreamExt, future::BoxFuture};
5use http::{Method, Request, Response, header::ALLOW};
6use http_body::Body;
7use http_body_util::{BodyExt, Full, combinators::BoxBody};
8use tokio_stream::wrappers::ReceiverStream;
9use tokio_util::sync::CancellationToken;
10
11use super::session::SessionManager;
12use crate::{
13 RoleServer,
14 model::{ClientJsonRpcMessage, ClientRequest, GetExtensions, ProtocolVersion},
15 serve_server,
16 service::serve_directly,
17 transport::{
18 OneshotTransport, TransportAdapterIdentity,
19 common::{
20 http_header::{
21 EVENT_STREAM_MIME_TYPE, HEADER_LAST_EVENT_ID, HEADER_MCP_PROTOCOL_VERSION,
22 HEADER_SESSION_ID, JSON_MIME_TYPE,
23 },
24 server_side_http::{
25 BoxResponse, ServerSseMessage, accepted_response, expect_json,
26 internal_error_response, sse_stream_response, unexpected_message_response,
27 },
28 },
29 },
30};
31
32#[derive(Debug, Clone)]
33pub struct StreamableHttpServerConfig {
34 pub sse_keep_alive: Option<Duration>,
36 pub sse_retry: Option<Duration>,
38 pub stateful_mode: bool,
41 pub json_response: bool,
46 pub cancellation_token: CancellationToken,
51}
52
53impl Default for StreamableHttpServerConfig {
54 fn default() -> Self {
55 Self {
56 sse_keep_alive: Some(Duration::from_secs(15)),
57 sse_retry: Some(Duration::from_secs(3)),
58 stateful_mode: true,
59 json_response: false,
60 cancellation_token: CancellationToken::new(),
61 }
62 }
63}
64
65#[expect(
66 clippy::result_large_err,
67 reason = "BoxResponse is intentionally large; matches other handlers in this file"
68)]
69fn validate_protocol_version_header(headers: &http::HeaderMap) -> Result<(), BoxResponse> {
75 if let Some(value) = headers.get(HEADER_MCP_PROTOCOL_VERSION) {
76 let version_str = value.to_str().map_err(|_| {
77 Response::builder()
78 .status(http::StatusCode::BAD_REQUEST)
79 .body(
80 Full::new(Bytes::from(
81 "Bad Request: Invalid MCP-Protocol-Version header encoding",
82 ))
83 .boxed(),
84 )
85 .expect("valid response")
86 })?;
87 let is_known = ProtocolVersion::KNOWN_VERSIONS
88 .iter()
89 .any(|v| v.as_str() == version_str);
90 if !is_known {
91 return Err(Response::builder()
92 .status(http::StatusCode::BAD_REQUEST)
93 .body(
94 Full::new(Bytes::from(format!(
95 "Bad Request: Unsupported MCP-Protocol-Version: {version_str}"
96 )))
97 .boxed(),
98 )
99 .expect("valid response"));
100 }
101 }
102 Ok(())
103}
104
105pub struct StreamableHttpService<S, M = super::session::local::LocalSessionManager> {
189 pub config: StreamableHttpServerConfig,
190 session_manager: Arc<M>,
191 service_factory: Arc<dyn Fn() -> Result<S, std::io::Error> + Send + Sync>,
192}
193
194impl<S, M> Clone for StreamableHttpService<S, M> {
195 fn clone(&self) -> Self {
196 Self {
197 config: self.config.clone(),
198 session_manager: self.session_manager.clone(),
199 service_factory: self.service_factory.clone(),
200 }
201 }
202}
203
204impl<RequestBody, S, M> tower_service::Service<Request<RequestBody>> for StreamableHttpService<S, M>
205where
206 RequestBody: Body + Send + 'static,
207 S: crate::Service<RoleServer>,
208 M: SessionManager,
209 RequestBody::Error: Display,
210 RequestBody::Data: Send + 'static,
211{
212 type Response = BoxResponse;
213 type Error = Infallible;
214 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
215 fn call(&mut self, req: http::Request<RequestBody>) -> Self::Future {
216 let service = self.clone();
217 Box::pin(async move {
218 let response = service.handle(req).await;
219 Ok(response)
220 })
221 }
222 fn poll_ready(
223 &mut self,
224 _cx: &mut std::task::Context<'_>,
225 ) -> std::task::Poll<Result<(), Self::Error>> {
226 std::task::Poll::Ready(Ok(()))
227 }
228}
229
230impl<S, M> StreamableHttpService<S, M>
231where
232 S: crate::Service<RoleServer> + Send + 'static,
233 M: SessionManager,
234{
235 pub fn new(
236 service_factory: impl Fn() -> Result<S, std::io::Error> + Send + Sync + 'static,
237 session_manager: Arc<M>,
238 config: StreamableHttpServerConfig,
239 ) -> Self {
240 Self {
241 config,
242 session_manager,
243 service_factory: Arc::new(service_factory),
244 }
245 }
246 fn get_service(&self) -> Result<S, std::io::Error> {
247 (self.service_factory)()
248 }
249 pub async fn handle<B>(&self, request: Request<B>) -> Response<BoxBody<Bytes, Infallible>>
250 where
251 B: Body + Send + 'static,
252 B::Error: Display,
253 {
254 let method = request.method().clone();
255 let allowed_methods = match self.config.stateful_mode {
256 true => "GET, POST, DELETE",
257 false => "POST",
258 };
259 let result = match (method, self.config.stateful_mode) {
260 (Method::POST, _) => self.handle_post(request).await,
261 (Method::GET, true) => self.handle_get(request).await,
263 (Method::DELETE, true) => self.handle_delete(request).await,
264 _ => {
265 let response = Response::builder()
267 .status(http::StatusCode::METHOD_NOT_ALLOWED)
268 .header(ALLOW, allowed_methods)
269 .body(Full::new(Bytes::from("Method Not Allowed")).boxed())
270 .expect("valid response");
271 return response;
272 }
273 };
274 match result {
275 Ok(response) => response,
276 Err(response) => response,
277 }
278 }
279 async fn handle_get<B>(&self, request: Request<B>) -> Result<BoxResponse, BoxResponse>
280 where
281 B: Body + Send + 'static,
282 B::Error: Display,
283 {
284 if !request
286 .headers()
287 .get(http::header::ACCEPT)
288 .and_then(|header| header.to_str().ok())
289 .is_some_and(|header| header.contains(EVENT_STREAM_MIME_TYPE))
290 {
291 return Ok(Response::builder()
292 .status(http::StatusCode::NOT_ACCEPTABLE)
293 .body(
294 Full::new(Bytes::from(
295 "Not Acceptable: Client must accept text/event-stream",
296 ))
297 .boxed(),
298 )
299 .expect("valid response"));
300 }
301 let session_id = request
303 .headers()
304 .get(HEADER_SESSION_ID)
305 .and_then(|v| v.to_str().ok())
306 .map(|s| s.to_owned().into());
307 let Some(session_id) = session_id else {
308 return Ok(Response::builder()
310 .status(http::StatusCode::BAD_REQUEST)
311 .body(Full::new(Bytes::from("Bad Request: Session ID is required")).boxed())
312 .expect("valid response"));
313 };
314 let has_session = self
316 .session_manager
317 .has_session(&session_id)
318 .await
319 .map_err(internal_error_response("check session"))?;
320 if !has_session {
321 return Ok(Response::builder()
323 .status(http::StatusCode::NOT_FOUND)
324 .body(Full::new(Bytes::from("Not Found: Session not found")).boxed())
325 .expect("valid response"));
326 }
327 validate_protocol_version_header(request.headers())?;
329 let last_event_id = request
331 .headers()
332 .get(HEADER_LAST_EVENT_ID)
333 .and_then(|v| v.to_str().ok())
334 .map(|s| s.to_owned());
335 if let Some(last_event_id) = last_event_id {
336 let stream = self
338 .session_manager
339 .resume(&session_id, last_event_id)
340 .await
341 .map_err(internal_error_response("resume session"))?;
342 Ok(sse_stream_response(
344 stream,
345 self.config.sse_keep_alive,
346 self.config.cancellation_token.child_token(),
347 ))
348 } else {
349 let stream = self
351 .session_manager
352 .create_standalone_stream(&session_id)
353 .await
354 .map_err(internal_error_response("create standalone stream"))?;
355 let stream = if let Some(retry) = self.config.sse_retry {
357 let priming = ServerSseMessage {
358 event_id: Some("0".into()),
359 message: None,
360 retry: Some(retry),
361 };
362 futures::stream::once(async move { priming })
363 .chain(stream)
364 .left_stream()
365 } else {
366 stream.right_stream()
367 };
368 Ok(sse_stream_response(
369 stream,
370 self.config.sse_keep_alive,
371 self.config.cancellation_token.child_token(),
372 ))
373 }
374 }
375
376 async fn handle_post<B>(&self, request: Request<B>) -> Result<BoxResponse, BoxResponse>
377 where
378 B: Body + Send + 'static,
379 B::Error: Display,
380 {
381 if !request
383 .headers()
384 .get(http::header::ACCEPT)
385 .and_then(|header| header.to_str().ok())
386 .is_some_and(|header| {
387 header.contains(JSON_MIME_TYPE) && header.contains(EVENT_STREAM_MIME_TYPE)
388 })
389 {
390 return Ok(Response::builder()
391 .status(http::StatusCode::NOT_ACCEPTABLE)
392 .body(Full::new(Bytes::from("Not Acceptable: Client must accept both application/json and text/event-stream")).boxed())
393 .expect("valid response"));
394 }
395
396 if !request
398 .headers()
399 .get(http::header::CONTENT_TYPE)
400 .and_then(|header| header.to_str().ok())
401 .is_some_and(|header| header.starts_with(JSON_MIME_TYPE))
402 {
403 return Ok(Response::builder()
404 .status(http::StatusCode::UNSUPPORTED_MEDIA_TYPE)
405 .body(
406 Full::new(Bytes::from(
407 "Unsupported Media Type: Content-Type must be application/json",
408 ))
409 .boxed(),
410 )
411 .expect("valid response"));
412 }
413
414 let (part, body) = request.into_parts();
416 let mut message = match expect_json(body).await {
417 Ok(message) => message,
418 Err(response) => return Ok(response),
419 };
420
421 if self.config.stateful_mode {
422 let session_id = part
424 .headers
425 .get(HEADER_SESSION_ID)
426 .and_then(|v| v.to_str().ok());
427 if let Some(session_id) = session_id {
428 let session_id = session_id.to_owned().into();
429 let has_session = self
430 .session_manager
431 .has_session(&session_id)
432 .await
433 .map_err(internal_error_response("check session"))?;
434 if !has_session {
435 return Ok(Response::builder()
437 .status(http::StatusCode::NOT_FOUND)
438 .body(Full::new(Bytes::from("Not Found: Session not found")).boxed())
439 .expect("valid response"));
440 }
441
442 validate_protocol_version_header(&part.headers)?;
444
445 match &mut message {
447 ClientJsonRpcMessage::Request(req) => {
448 req.request.extensions_mut().insert(part);
449 }
450 ClientJsonRpcMessage::Notification(not) => {
451 not.notification.extensions_mut().insert(part);
452 }
453 _ => {
454 }
456 }
457
458 match message {
459 ClientJsonRpcMessage::Request(_) => {
460 let stream = self
461 .session_manager
462 .create_stream(&session_id, message)
463 .await
464 .map_err(internal_error_response("get session"))?;
465 let stream = if let Some(retry) = self.config.sse_retry {
467 let priming = ServerSseMessage {
468 event_id: Some("0".into()),
469 message: None,
470 retry: Some(retry),
471 };
472 futures::stream::once(async move { priming })
473 .chain(stream)
474 .left_stream()
475 } else {
476 stream.right_stream()
477 };
478 Ok(sse_stream_response(
479 stream,
480 self.config.sse_keep_alive,
481 self.config.cancellation_token.child_token(),
482 ))
483 }
484 ClientJsonRpcMessage::Notification(_)
485 | ClientJsonRpcMessage::Response(_)
486 | ClientJsonRpcMessage::Error(_) => {
487 self.session_manager
489 .accept_message(&session_id, message)
490 .await
491 .map_err(internal_error_response("accept message"))?;
492 Ok(accepted_response())
493 }
494 }
495 } else {
496 let (session_id, transport) = self
497 .session_manager
498 .create_session()
499 .await
500 .map_err(internal_error_response("create session"))?;
501 if let ClientJsonRpcMessage::Request(req) = &mut message {
502 if !matches!(req.request, ClientRequest::InitializeRequest(_)) {
503 return Err(unexpected_message_response("initialize request"));
504 }
505 req.request.extensions_mut().insert(part);
507 } else {
508 return Err(unexpected_message_response("initialize request"));
509 }
510 let service = self
511 .get_service()
512 .map_err(internal_error_response("get service"))?;
513 tokio::spawn({
515 let session_manager = self.session_manager.clone();
516 let session_id = session_id.clone();
517 async move {
518 let service = serve_server::<S, M::Transport, _, TransportAdapterIdentity>(
519 service, transport,
520 )
521 .await;
522 match service {
523 Ok(service) => {
524 let _ = service.waiting().await;
526 }
527 Err(e) => {
528 tracing::error!("Failed to create service: {e}");
529 }
530 }
531 let _ = session_manager
532 .close_session(&session_id)
533 .await
534 .inspect_err(|e| {
535 tracing::error!("Failed to close session {session_id}: {e}");
536 });
537 }
538 });
539 let response = self
541 .session_manager
542 .initialize_session(&session_id, message)
543 .await
544 .map_err(internal_error_response("create stream"))?;
545 let stream = futures::stream::once(async move {
546 ServerSseMessage {
547 event_id: None,
548 message: Some(Arc::new(response)),
549 retry: None,
550 }
551 });
552 let stream = if let Some(retry) = self.config.sse_retry {
554 let priming = ServerSseMessage {
555 event_id: Some("0".into()),
556 message: None,
557 retry: Some(retry),
558 };
559 futures::stream::once(async move { priming })
560 .chain(stream)
561 .left_stream()
562 } else {
563 stream.right_stream()
564 };
565 let mut response = sse_stream_response(
566 stream,
567 self.config.sse_keep_alive,
568 self.config.cancellation_token.child_token(),
569 );
570
571 response.headers_mut().insert(
572 HEADER_SESSION_ID,
573 session_id
574 .parse()
575 .map_err(internal_error_response("create session id header"))?,
576 );
577 Ok(response)
578 }
579 } else {
580 let is_init = matches!(
582 &message,
583 ClientJsonRpcMessage::Request(req) if matches!(req.request, ClientRequest::InitializeRequest(_))
584 );
585 if !is_init {
586 validate_protocol_version_header(&part.headers)?;
587 }
588 let service = self
589 .get_service()
590 .map_err(internal_error_response("get service"))?;
591 match message {
592 ClientJsonRpcMessage::Request(mut request) => {
593 request.request.extensions_mut().insert(part);
594 let (transport, mut receiver) =
595 OneshotTransport::<RoleServer>::new(ClientJsonRpcMessage::Request(request));
596 let service = serve_directly(service, transport, None);
597 tokio::spawn(async move {
598 let _ = service.waiting().await;
600 });
601 if self.config.json_response {
602 let cancel = self.config.cancellation_token.child_token();
606 match tokio::select! {
607 res = receiver.recv() => res,
608 _ = cancel.cancelled() => None,
609 } {
610 Some(message) => {
611 tracing::trace!(?message);
612 let body = serde_json::to_vec(&message).map_err(|e| {
613 internal_error_response("serialize json response")(e)
614 })?;
615 Ok(Response::builder()
616 .status(http::StatusCode::OK)
617 .header(http::header::CONTENT_TYPE, JSON_MIME_TYPE)
618 .body(Full::new(Bytes::from(body)).boxed())
619 .expect("valid response"))
620 }
621 None => Err(internal_error_response("empty response")(
622 std::io::Error::new(
623 std::io::ErrorKind::UnexpectedEof,
624 "no response message received from handler",
625 ),
626 )),
627 }
628 } else {
629 let stream = ReceiverStream::new(receiver).map(|message| {
631 tracing::trace!(?message);
632 ServerSseMessage {
633 event_id: None,
634 message: Some(Arc::new(message)),
635 retry: None,
636 }
637 });
638 Ok(sse_stream_response(
639 stream,
640 self.config.sse_keep_alive,
641 self.config.cancellation_token.child_token(),
642 ))
643 }
644 }
645 ClientJsonRpcMessage::Notification(_notification) => {
646 Ok(accepted_response())
648 }
649 ClientJsonRpcMessage::Response(_json_rpc_response) => Ok(accepted_response()),
650 ClientJsonRpcMessage::Error(_json_rpc_error) => Ok(accepted_response()),
651 }
652 }
653 }
654
655 async fn handle_delete<B>(&self, request: Request<B>) -> Result<BoxResponse, BoxResponse>
656 where
657 B: Body + Send + 'static,
658 B::Error: Display,
659 {
660 let session_id = request
662 .headers()
663 .get(HEADER_SESSION_ID)
664 .and_then(|v| v.to_str().ok())
665 .map(|s| s.to_owned().into());
666 let Some(session_id) = session_id else {
667 return Ok(Response::builder()
669 .status(http::StatusCode::BAD_REQUEST)
670 .body(Full::new(Bytes::from("Bad Request: Session ID is required")).boxed())
671 .expect("valid response"));
672 };
673 validate_protocol_version_header(request.headers())?;
675 self.session_manager
677 .close_session(&session_id)
678 .await
679 .map_err(internal_error_response("close session"))?;
680 Ok(accepted_response())
681 }
682}