1use std::{convert::Infallible, fmt::Display, sync::Arc, time::Duration};
2
3use bytes::Bytes;
4use futures::{StreamExt, future::BoxFuture};
5use http::{HeaderMap, Method, Request, Response, header::ALLOW};
6use http_body::Body;
7use http_body_util::{BodyExt, Full, combinators::BoxBody};
8use tokio_stream::wrappers::ReceiverStream;
9use tokio_util::sync::CancellationToken;
10
11use super::session::SessionManager;
12use crate::{
13 RoleServer,
14 model::{ClientJsonRpcMessage, ClientRequest, GetExtensions, ProtocolVersion},
15 serve_server,
16 service::serve_directly,
17 transport::{
18 OneshotTransport, TransportAdapterIdentity,
19 common::{
20 http_header::{
21 EVENT_STREAM_MIME_TYPE, HEADER_LAST_EVENT_ID, HEADER_MCP_PROTOCOL_VERSION,
22 HEADER_SESSION_ID, JSON_MIME_TYPE,
23 },
24 server_side_http::{
25 BoxResponse, ServerSseMessage, accepted_response, expect_json,
26 internal_error_response, sse_stream_response, unexpected_message_response,
27 },
28 },
29 },
30};
31
/// Configuration for [`StreamableHttpService`].
///
/// Start from [`Default::default`] and adjust with the `with_*` builder methods.
#[non_exhaustive]
#[derive(Debug, Clone)]
pub struct StreamableHttpServerConfig {
    /// Keep-alive interval handed to `sse_stream_response` for every SSE response this
    /// service produces. Defaults to 15 seconds.
    /// NOTE(review): `None` presumably disables keep-alive events — confirm in
    /// `sse_stream_response`.
    pub sse_keep_alive: Option<Duration>,
    /// When `Some`, stateful-mode SSE responses (standalone GET streams and POST responses)
    /// begin with a priming event built by `ServerSseMessage::priming("0", retry)`.
    /// Resumed streams (`Last-Event-ID`) and stateless-mode responses are not primed.
    /// Defaults to 3 seconds.
    pub sse_retry: Option<Duration>,
    /// `true`: requests are tied to sessions and GET, POST and DELETE are accepted.
    /// `false`: only POST is accepted and each request is served by a freshly built server
    /// instance over a one-shot transport. Defaults to `true`.
    pub stateful_mode: bool,
    /// Only consulted in stateless mode: when `true`, the reply to a POSTed request is
    /// returned as a single `application/json` body instead of an SSE stream.
    /// Defaults to `false`.
    pub json_response: bool,
    /// Parent cancellation token. Every SSE response receives a child token. In
    /// JSON-response mode, cancelling it stops waiting for the handler's reply and an error
    /// response is returned instead.
    pub cancellation_token: CancellationToken,
    /// Allow-list checked against the request `Host` header (DNS-rebinding protection).
    /// Hosts are compared ASCII-case-insensitively with IPv6 brackets removed. An entry
    /// without a port matches any port; an entry with a port requires that exact port.
    /// An empty list disables the check. Defaults to `localhost`, `127.0.0.1` and `::1`.
    pub allowed_hosts: Vec<String>,
}
63
64impl Default for StreamableHttpServerConfig {
65 fn default() -> Self {
66 Self {
67 sse_keep_alive: Some(Duration::from_secs(15)),
68 sse_retry: Some(Duration::from_secs(3)),
69 stateful_mode: true,
70 json_response: false,
71 cancellation_token: CancellationToken::new(),
72 allowed_hosts: vec!["localhost".into(), "127.0.0.1".into(), "::1".into()],
73 }
74 }
75}
76
77impl StreamableHttpServerConfig {
78 pub fn with_allowed_hosts(
79 mut self,
80 allowed_hosts: impl IntoIterator<Item = impl Into<String>>,
81 ) -> Self {
82 self.allowed_hosts = allowed_hosts.into_iter().map(Into::into).collect();
83 self
84 }
85 pub fn disable_allowed_hosts(mut self) -> Self {
87 self.allowed_hosts.clear();
88 self
89 }
90 pub fn with_sse_keep_alive(mut self, duration: Option<Duration>) -> Self {
91 self.sse_keep_alive = duration;
92 self
93 }
94
95 pub fn with_sse_retry(mut self, duration: Option<Duration>) -> Self {
96 self.sse_retry = duration;
97 self
98 }
99
100 pub fn with_stateful_mode(mut self, stateful: bool) -> Self {
101 self.stateful_mode = stateful;
102 self
103 }
104
105 pub fn with_json_response(mut self, json_response: bool) -> Self {
106 self.json_response = json_response;
107 self
108 }
109
110 pub fn with_cancellation_token(mut self, token: CancellationToken) -> Self {
111 self.cancellation_token = token;
112 self
113 }
114}
115
116#[expect(
117 clippy::result_large_err,
118 reason = "BoxResponse is intentionally large; matches other handlers in this file"
119)]
120fn validate_protocol_version_header(headers: &http::HeaderMap) -> Result<(), BoxResponse> {
126 if let Some(value) = headers.get(HEADER_MCP_PROTOCOL_VERSION) {
127 let version_str = value.to_str().map_err(|_| {
128 Response::builder()
129 .status(http::StatusCode::BAD_REQUEST)
130 .body(
131 Full::new(Bytes::from(
132 "Bad Request: Invalid MCP-Protocol-Version header encoding",
133 ))
134 .boxed(),
135 )
136 .expect("valid response")
137 })?;
138 let is_known = ProtocolVersion::KNOWN_VERSIONS
139 .iter()
140 .any(|v| v.as_str() == version_str);
141 if !is_known {
142 return Err(Response::builder()
143 .status(http::StatusCode::BAD_REQUEST)
144 .body(
145 Full::new(Bytes::from(format!(
146 "Bad Request: Unsupported MCP-Protocol-Version: {version_str}"
147 )))
148 .boxed(),
149 )
150 .expect("valid response"));
151 }
152 }
153 Ok(())
154}
155
156fn forbidden_response(message: impl Into<String>) -> BoxResponse {
157 Response::builder()
158 .status(http::StatusCode::FORBIDDEN)
159 .body(Full::new(Bytes::from(message.into())).boxed())
160 .expect("valid response")
161}
162
/// Canonicalises a host for comparison.
///
/// Leading/trailing `[` characters are trimmed first, then leading/trailing `]`, which turns
/// a bracketed IPv6 literal such as `[::1]` into `::1`. The result is ASCII-lowercased.
fn normalize_host(host: &str) -> String {
    let without_open = host.trim_matches('[');
    let bare = without_open.trim_matches(']');
    bare.to_ascii_lowercase()
}
168
/// A host/port pair prepared for allow-list comparison.
///
/// In this file it is only built by `normalize_authority`, so `host` has already been passed
/// through `normalize_host` (ASCII-lowercased, surrounding `[`/`]` removed).
#[derive(Debug, Clone, PartialEq, Eq)]
struct NormalizedAuthority {
    // Lowercased host with IPv6 brackets stripped.
    host: String,
    // Explicit port, or `None` when the authority carried no port.
    port: Option<u16>,
}
174
175fn normalize_authority(host: &str, port: Option<u16>) -> NormalizedAuthority {
176 NormalizedAuthority {
177 host: normalize_host(host),
178 port,
179 }
180}
181
182fn parse_allowed_authority(allowed: &str) -> Option<NormalizedAuthority> {
183 let allowed = allowed.trim();
184 if allowed.is_empty() {
185 return None;
186 }
187
188 if let Ok(authority) = http::uri::Authority::try_from(allowed) {
189 return Some(normalize_authority(authority.host(), authority.port_u16()));
190 }
191
192 Some(normalize_authority(allowed, None))
193}
194
195fn host_is_allowed(host: &NormalizedAuthority, allowed_hosts: &[String]) -> bool {
196 if allowed_hosts.is_empty() {
197 return true;
199 }
200 allowed_hosts
201 .iter()
202 .filter_map(|allowed| parse_allowed_authority(allowed))
203 .any(|allowed| {
204 allowed.host == host.host
205 && match allowed.port {
206 Some(port) => host.port == Some(port),
207 None => true,
208 }
209 })
210}
211
212fn bad_request_response(message: &str) -> BoxResponse {
213 let body = Full::from(message.to_string()).boxed();
214
215 http::Response::builder()
216 .status(http::StatusCode::BAD_REQUEST)
217 .header(http::header::CONTENT_TYPE, "text/plain; charset=utf-8")
218 .body(body)
219 .expect("failed to build bad request response")
220}
221
222fn parse_host_header(headers: &HeaderMap) -> Result<NormalizedAuthority, BoxResponse> {
223 let Some(host) = headers.get(http::header::HOST) else {
224 return Err(bad_request_response("Bad Request: missing Host header"));
225 };
226
227 let host = host
228 .to_str()
229 .map_err(|_| bad_request_response("Bad Request: Invalid Host header encoding"))?;
230 let authority = http::uri::Authority::try_from(host)
231 .map_err(|_| bad_request_response("Bad Request: Invalid Host header"))?;
232 Ok(normalize_authority(authority.host(), authority.port_u16()))
233}
234
235fn validate_dns_rebinding_headers(
236 headers: &HeaderMap,
237 config: &StreamableHttpServerConfig,
238) -> Result<(), BoxResponse> {
239 let host = parse_host_header(headers)?;
240 if !host_is_allowed(&host, &config.allowed_hosts) {
241 return Err(forbidden_response("Forbidden: Host header is not allowed"));
242 }
243
244 Ok(())
245}
246
/// An HTTP service (implements `tower_service::Service`) for the MCP Streamable HTTP
/// server transport.
///
/// `S` is the MCP server service type and `M` the [`SessionManager`] that tracks sessions.
pub struct StreamableHttpService<S, M> {
    /// Transport configuration.
    pub config: StreamableHttpServerConfig,
    // Session store; clones of the service share it.
    session_manager: Arc<M>,
    // Factory for fresh server instances, invoked via `get_service` when a new session is
    // opened or a stateless-mode request is served.
    service_factory: Arc<dyn Fn() -> Result<S, std::io::Error> + Send + Sync>,
}
335
336impl<S, M> Clone for StreamableHttpService<S, M> {
337 fn clone(&self) -> Self {
338 Self {
339 config: self.config.clone(),
340 session_manager: self.session_manager.clone(),
341 service_factory: self.service_factory.clone(),
342 }
343 }
344}
345
346impl<RequestBody, S, M> tower_service::Service<Request<RequestBody>> for StreamableHttpService<S, M>
347where
348 RequestBody: Body + Send + 'static,
349 S: crate::Service<RoleServer> + Send + 'static,
350 M: SessionManager,
351 RequestBody::Error: Display,
352 RequestBody::Data: Send + 'static,
353{
354 type Response = BoxResponse;
355 type Error = Infallible;
356 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
357 fn call(&mut self, req: http::Request<RequestBody>) -> Self::Future {
358 let service = self.clone();
359 Box::pin(async move {
360 let response = service.handle(req).await;
361 Ok(response)
362 })
363 }
364 fn poll_ready(
365 &mut self,
366 _cx: &mut std::task::Context<'_>,
367 ) -> std::task::Poll<Result<(), Self::Error>> {
368 std::task::Poll::Ready(Ok(()))
369 }
370}
371
372impl<S, M> StreamableHttpService<S, M>
373where
374 S: crate::Service<RoleServer> + Send + 'static,
375 M: SessionManager,
376{
377 pub fn new(
378 service_factory: impl Fn() -> Result<S, std::io::Error> + Send + Sync + 'static,
379 session_manager: Arc<M>,
380 config: StreamableHttpServerConfig,
381 ) -> Self {
382 Self {
383 config,
384 session_manager,
385 service_factory: Arc::new(service_factory),
386 }
387 }
388 fn get_service(&self) -> Result<S, std::io::Error> {
389 (self.service_factory)()
390 }
391 pub async fn handle<B>(&self, request: Request<B>) -> Response<BoxBody<Bytes, Infallible>>
392 where
393 B: Body + Send + 'static,
394 B::Error: Display,
395 {
396 if let Err(response) = validate_dns_rebinding_headers(request.headers(), &self.config) {
397 return response;
398 }
399 let method = request.method().clone();
400 let allowed_methods = match self.config.stateful_mode {
401 true => "GET, POST, DELETE",
402 false => "POST",
403 };
404 let result = match (method, self.config.stateful_mode) {
405 (Method::POST, _) => self.handle_post(request).await,
406 (Method::GET, true) => self.handle_get(request).await,
408 (Method::DELETE, true) => self.handle_delete(request).await,
409 _ => {
410 let response = Response::builder()
412 .status(http::StatusCode::METHOD_NOT_ALLOWED)
413 .header(ALLOW, allowed_methods)
414 .body(Full::new(Bytes::from("Method Not Allowed")).boxed())
415 .expect("valid response");
416 return response;
417 }
418 };
419 match result {
420 Ok(response) => response,
421 Err(response) => response,
422 }
423 }
424 async fn handle_get<B>(&self, request: Request<B>) -> Result<BoxResponse, BoxResponse>
425 where
426 B: Body + Send + 'static,
427 B::Error: Display,
428 {
429 if !request
431 .headers()
432 .get(http::header::ACCEPT)
433 .and_then(|header| header.to_str().ok())
434 .is_some_and(|header| header.contains(EVENT_STREAM_MIME_TYPE))
435 {
436 return Ok(Response::builder()
437 .status(http::StatusCode::NOT_ACCEPTABLE)
438 .body(
439 Full::new(Bytes::from(
440 "Not Acceptable: Client must accept text/event-stream",
441 ))
442 .boxed(),
443 )
444 .expect("valid response"));
445 }
446 let session_id = request
448 .headers()
449 .get(HEADER_SESSION_ID)
450 .and_then(|v| v.to_str().ok())
451 .map(|s| s.to_owned().into());
452 let Some(session_id) = session_id else {
453 return Ok(Response::builder()
455 .status(http::StatusCode::BAD_REQUEST)
456 .body(Full::new(Bytes::from("Bad Request: Session ID is required")).boxed())
457 .expect("valid response"));
458 };
459 let has_session = self
461 .session_manager
462 .has_session(&session_id)
463 .await
464 .map_err(internal_error_response("check session"))?;
465 if !has_session {
466 return Ok(Response::builder()
468 .status(http::StatusCode::NOT_FOUND)
469 .body(Full::new(Bytes::from("Not Found: Session not found")).boxed())
470 .expect("valid response"));
471 }
472 validate_protocol_version_header(request.headers())?;
474 let last_event_id = request
476 .headers()
477 .get(HEADER_LAST_EVENT_ID)
478 .and_then(|v| v.to_str().ok())
479 .map(|s| s.to_owned());
480 if let Some(last_event_id) = last_event_id {
481 let stream = self
483 .session_manager
484 .resume(&session_id, last_event_id)
485 .await
486 .map_err(internal_error_response("resume session"))?;
487 Ok(sse_stream_response(
489 stream,
490 self.config.sse_keep_alive,
491 self.config.cancellation_token.child_token(),
492 ))
493 } else {
494 let stream = self
496 .session_manager
497 .create_standalone_stream(&session_id)
498 .await
499 .map_err(internal_error_response("create standalone stream"))?;
500 let stream = if let Some(retry) = self.config.sse_retry {
502 let priming = ServerSseMessage::priming("0", retry);
503 futures::stream::once(async move { priming })
504 .chain(stream)
505 .left_stream()
506 } else {
507 stream.right_stream()
508 };
509 Ok(sse_stream_response(
510 stream,
511 self.config.sse_keep_alive,
512 self.config.cancellation_token.child_token(),
513 ))
514 }
515 }
516
517 async fn handle_post<B>(&self, request: Request<B>) -> Result<BoxResponse, BoxResponse>
518 where
519 B: Body + Send + 'static,
520 B::Error: Display,
521 {
522 if !request
524 .headers()
525 .get(http::header::ACCEPT)
526 .and_then(|header| header.to_str().ok())
527 .is_some_and(|header| {
528 header.contains(JSON_MIME_TYPE) && header.contains(EVENT_STREAM_MIME_TYPE)
529 })
530 {
531 return Ok(Response::builder()
532 .status(http::StatusCode::NOT_ACCEPTABLE)
533 .body(Full::new(Bytes::from("Not Acceptable: Client must accept both application/json and text/event-stream")).boxed())
534 .expect("valid response"));
535 }
536
537 if !request
539 .headers()
540 .get(http::header::CONTENT_TYPE)
541 .and_then(|header| header.to_str().ok())
542 .is_some_and(|header| header.starts_with(JSON_MIME_TYPE))
543 {
544 return Ok(Response::builder()
545 .status(http::StatusCode::UNSUPPORTED_MEDIA_TYPE)
546 .body(
547 Full::new(Bytes::from(
548 "Unsupported Media Type: Content-Type must be application/json",
549 ))
550 .boxed(),
551 )
552 .expect("valid response"));
553 }
554
555 let (part, body) = request.into_parts();
557 let mut message = match expect_json(body).await {
558 Ok(message) => message,
559 Err(response) => return Ok(response),
560 };
561
562 if self.config.stateful_mode {
563 let session_id = part
565 .headers
566 .get(HEADER_SESSION_ID)
567 .and_then(|v| v.to_str().ok());
568 if let Some(session_id) = session_id {
569 let session_id = session_id.to_owned().into();
570 let has_session = self
571 .session_manager
572 .has_session(&session_id)
573 .await
574 .map_err(internal_error_response("check session"))?;
575 if !has_session {
576 return Ok(Response::builder()
578 .status(http::StatusCode::NOT_FOUND)
579 .body(Full::new(Bytes::from("Not Found: Session not found")).boxed())
580 .expect("valid response"));
581 }
582
583 validate_protocol_version_header(&part.headers)?;
585
586 match &mut message {
588 ClientJsonRpcMessage::Request(req) => {
589 req.request.extensions_mut().insert(part);
590 }
591 ClientJsonRpcMessage::Notification(not) => {
592 not.notification.extensions_mut().insert(part);
593 }
594 _ => {
595 }
597 }
598
599 match message {
600 ClientJsonRpcMessage::Request(_) => {
601 let stream = self
602 .session_manager
603 .create_stream(&session_id, message)
604 .await
605 .map_err(internal_error_response("get session"))?;
606 let stream = if let Some(retry) = self.config.sse_retry {
608 let priming = ServerSseMessage::priming("0", retry);
609 futures::stream::once(async move { priming })
610 .chain(stream)
611 .left_stream()
612 } else {
613 stream.right_stream()
614 };
615 Ok(sse_stream_response(
616 stream,
617 self.config.sse_keep_alive,
618 self.config.cancellation_token.child_token(),
619 ))
620 }
621 ClientJsonRpcMessage::Notification(_)
622 | ClientJsonRpcMessage::Response(_)
623 | ClientJsonRpcMessage::Error(_) => {
624 self.session_manager
626 .accept_message(&session_id, message)
627 .await
628 .map_err(internal_error_response("accept message"))?;
629 Ok(accepted_response())
630 }
631 }
632 } else {
633 let (session_id, transport) = self
634 .session_manager
635 .create_session()
636 .await
637 .map_err(internal_error_response("create session"))?;
638 if let ClientJsonRpcMessage::Request(req) = &mut message {
639 if !matches!(req.request, ClientRequest::InitializeRequest(_)) {
640 return Err(unexpected_message_response("initialize request"));
641 }
642 req.request.extensions_mut().insert(part);
644 } else {
645 return Err(unexpected_message_response("initialize request"));
646 }
647 let service = self
648 .get_service()
649 .map_err(internal_error_response("get service"))?;
650 tokio::spawn({
652 let session_manager = self.session_manager.clone();
653 let session_id = session_id.clone();
654 async move {
655 let service = serve_server::<S, M::Transport, _, TransportAdapterIdentity>(
656 service, transport,
657 )
658 .await;
659 match service {
660 Ok(service) => {
661 let _ = service.waiting().await;
663 }
664 Err(e) => {
665 tracing::error!("Failed to create service: {e}");
666 }
667 }
668 let _ = session_manager
669 .close_session(&session_id)
670 .await
671 .inspect_err(|e| {
672 tracing::error!("Failed to close session {session_id}: {e}");
673 });
674 }
675 });
676 let response = self
678 .session_manager
679 .initialize_session(&session_id, message)
680 .await
681 .map_err(internal_error_response("create stream"))?;
682 let stream =
683 futures::stream::once(async move { ServerSseMessage::from_message(response) });
684 let stream = if let Some(retry) = self.config.sse_retry {
686 let priming = ServerSseMessage::priming("0", retry);
687 futures::stream::once(async move { priming })
688 .chain(stream)
689 .left_stream()
690 } else {
691 stream.right_stream()
692 };
693 let mut response = sse_stream_response(
694 stream,
695 self.config.sse_keep_alive,
696 self.config.cancellation_token.child_token(),
697 );
698
699 response.headers_mut().insert(
700 HEADER_SESSION_ID,
701 session_id
702 .parse()
703 .map_err(internal_error_response("create session id header"))?,
704 );
705 Ok(response)
706 }
707 } else {
708 let is_init = matches!(
710 &message,
711 ClientJsonRpcMessage::Request(req) if matches!(req.request, ClientRequest::InitializeRequest(_))
712 );
713 if !is_init {
714 validate_protocol_version_header(&part.headers)?;
715 }
716 let service = self
717 .get_service()
718 .map_err(internal_error_response("get service"))?;
719 match message {
720 ClientJsonRpcMessage::Request(mut request) => {
721 request.request.extensions_mut().insert(part);
722 let (transport, mut receiver) =
723 OneshotTransport::<RoleServer>::new(ClientJsonRpcMessage::Request(request));
724 let service = serve_directly(service, transport, None);
725 tokio::spawn(async move {
726 let _ = service.waiting().await;
728 });
729 if self.config.json_response {
730 let cancel = self.config.cancellation_token.child_token();
734 match tokio::select! {
735 res = receiver.recv() => res,
736 _ = cancel.cancelled() => None,
737 } {
738 Some(message) => {
739 tracing::trace!(?message);
740 let body = serde_json::to_vec(&message).map_err(|e| {
741 internal_error_response("serialize json response")(e)
742 })?;
743 Ok(Response::builder()
744 .status(http::StatusCode::OK)
745 .header(http::header::CONTENT_TYPE, JSON_MIME_TYPE)
746 .body(Full::new(Bytes::from(body)).boxed())
747 .expect("valid response"))
748 }
749 None => Err(internal_error_response("empty response")(
750 std::io::Error::new(
751 std::io::ErrorKind::UnexpectedEof,
752 "no response message received from handler",
753 ),
754 )),
755 }
756 } else {
757 let stream = ReceiverStream::new(receiver).map(|message| {
759 tracing::trace!(?message);
760 ServerSseMessage::from_message(message)
761 });
762 Ok(sse_stream_response(
763 stream,
764 self.config.sse_keep_alive,
765 self.config.cancellation_token.child_token(),
766 ))
767 }
768 }
769 ClientJsonRpcMessage::Notification(_notification) => {
770 Ok(accepted_response())
772 }
773 ClientJsonRpcMessage::Response(_json_rpc_response) => Ok(accepted_response()),
774 ClientJsonRpcMessage::Error(_json_rpc_error) => Ok(accepted_response()),
775 }
776 }
777 }
778
779 async fn handle_delete<B>(&self, request: Request<B>) -> Result<BoxResponse, BoxResponse>
780 where
781 B: Body + Send + 'static,
782 B::Error: Display,
783 {
784 let session_id = request
786 .headers()
787 .get(HEADER_SESSION_ID)
788 .and_then(|v| v.to_str().ok())
789 .map(|s| s.to_owned().into());
790 let Some(session_id) = session_id else {
791 return Ok(Response::builder()
793 .status(http::StatusCode::BAD_REQUEST)
794 .body(Full::new(Bytes::from("Bad Request: Session ID is required")).boxed())
795 .expect("valid response"));
796 };
797 validate_protocol_version_header(request.headers())?;
799 self.session_manager
801 .close_session(&session_id)
802 .await
803 .map_err(internal_error_response("close session"))?;
804 Ok(accepted_response())
805 }
806}