1use std::{collections::HashMap, convert::Infallible, fmt::Display, sync::Arc, time::Duration};
2
3use bytes::Bytes;
4use futures::{StreamExt, future::BoxFuture};
5use http::{HeaderMap, Method, Request, Response, header::ALLOW};
6use http_body::Body;
7use http_body_util::{BodyExt, Full, combinators::BoxBody};
8use tokio_stream::wrappers::ReceiverStream;
9use tokio_util::sync::CancellationToken;
10
11use super::session::{
12 RestoreOutcome, SessionId, SessionManager, SessionRestoreMarker, SessionState, SessionStore,
13};
14use crate::{
15 RoleServer,
16 model::{
17 ClientJsonRpcMessage, ClientNotification, ClientRequest, GetExtensions, InitializeRequest,
18 InitializedNotification, ProtocolVersion,
19 },
20 serve_server,
21 service::serve_directly,
22 transport::{
23 OneshotTransport, TransportAdapterIdentity,
24 common::{
25 http_header::{
26 EVENT_STREAM_MIME_TYPE, HEADER_LAST_EVENT_ID, HEADER_MCP_PROTOCOL_VERSION,
27 HEADER_SESSION_ID, JSON_MIME_TYPE,
28 },
29 server_side_http::{
30 BoxResponse, ServerSseMessage, accepted_response, expect_json,
31 internal_error_response, sse_stream_response, unexpected_message_response,
32 },
33 },
34 },
35};
36
/// Configuration for [`StreamableHttpService`].
///
/// Build it from [`Default`] and adjust it with the `with_*` builder methods.
#[non_exhaustive]
#[derive(Debug, Clone)]
pub struct StreamableHttpServerConfig {
    /// Keep-alive interval passed to `sse_stream_response` for SSE responses
    /// (presumably `None` disables keep-alive pings — confirm in `sse_stream_response`).
    pub sse_keep_alive: Option<Duration>,
    /// When set, standalone GET streams and initialize responses are prefixed with a
    /// priming SSE event (`ServerSseMessage::priming("0", retry)`) carrying this value.
    pub sse_retry: Option<Duration>,
    /// `true`: sessions are tracked by the session manager, and GET and DELETE are
    /// accepted. `false`: only POST is accepted, and each request is served by a
    /// fresh one-shot service.
    pub stateful_mode: bool,
    /// Stateless mode only: answer a request with a single `application/json` body
    /// holding the first response message, instead of an SSE stream.
    pub json_response: bool,
    /// Parent token. Every SSE response and the stateless JSON wait get a child of it.
    /// Cancelling it aborts the JSON wait with an error response (and presumably
    /// terminates open SSE streams).
    pub cancellation_token: CancellationToken,
    /// Allowlist checked against the request `Host` header (DNS-rebinding protection).
    /// Entries are `host` or `host:port`. An entry without a port matches any port.
    /// Hosts compare case-insensitively, and IPv6 brackets are optional.
    /// An empty list disables the check.
    pub allowed_hosts: Vec<String>,
    /// Optional external session store. When set:
    /// - initialize params of new sessions are persisted;
    /// - sessions unknown to the session manager are restored from it;
    /// - DELETE removes the stored entry.
    pub session_store: Option<Arc<dyn SessionStore>>,
}
90
91impl std::fmt::Debug for dyn SessionStore {
92 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
93 f.write_str("<SessionStore>")
94 }
95}
96
97impl Default for StreamableHttpServerConfig {
98 fn default() -> Self {
99 Self {
100 sse_keep_alive: Some(Duration::from_secs(15)),
101 sse_retry: Some(Duration::from_secs(3)),
102 stateful_mode: true,
103 json_response: false,
104 cancellation_token: CancellationToken::new(),
105 allowed_hosts: vec!["localhost".into(), "127.0.0.1".into(), "::1".into()],
106 session_store: None,
107 }
108 }
109}
110
111impl StreamableHttpServerConfig {
112 pub fn with_allowed_hosts(
113 mut self,
114 allowed_hosts: impl IntoIterator<Item = impl Into<String>>,
115 ) -> Self {
116 self.allowed_hosts = allowed_hosts.into_iter().map(Into::into).collect();
117 self
118 }
119 pub fn disable_allowed_hosts(mut self) -> Self {
121 self.allowed_hosts.clear();
122 self
123 }
124 pub fn with_sse_keep_alive(mut self, duration: Option<Duration>) -> Self {
125 self.sse_keep_alive = duration;
126 self
127 }
128
129 pub fn with_sse_retry(mut self, duration: Option<Duration>) -> Self {
130 self.sse_retry = duration;
131 self
132 }
133
134 pub fn with_stateful_mode(mut self, stateful: bool) -> Self {
135 self.stateful_mode = stateful;
136 self
137 }
138
139 pub fn with_json_response(mut self, json_response: bool) -> Self {
140 self.json_response = json_response;
141 self
142 }
143
144 pub fn with_cancellation_token(mut self, token: CancellationToken) -> Self {
145 self.cancellation_token = token;
146 self
147 }
148}
149
150#[expect(
151 clippy::result_large_err,
152 reason = "BoxResponse is intentionally large; matches other handlers in this file"
153)]
154fn validate_protocol_version_header(headers: &http::HeaderMap) -> Result<(), BoxResponse> {
160 if let Some(value) = headers.get(HEADER_MCP_PROTOCOL_VERSION) {
161 let version_str = value.to_str().map_err(|_| {
162 Response::builder()
163 .status(http::StatusCode::BAD_REQUEST)
164 .body(
165 Full::new(Bytes::from(
166 "Bad Request: Invalid MCP-Protocol-Version header encoding",
167 ))
168 .boxed(),
169 )
170 .expect("valid response")
171 })?;
172 let is_known = ProtocolVersion::KNOWN_VERSIONS
173 .iter()
174 .any(|v| v.as_str() == version_str);
175 if !is_known {
176 return Err(Response::builder()
177 .status(http::StatusCode::BAD_REQUEST)
178 .body(
179 Full::new(Bytes::from(format!(
180 "Bad Request: Unsupported MCP-Protocol-Version: {version_str}"
181 )))
182 .boxed(),
183 )
184 .expect("valid response"));
185 }
186 }
187 Ok(())
188}
189
190fn forbidden_response(message: impl Into<String>) -> BoxResponse {
191 Response::builder()
192 .status(http::StatusCode::FORBIDDEN)
193 .body(Full::new(Bytes::from(message.into())).boxed())
194 .expect("valid response")
195}
196
/// Canonicalizes a host for comparison.
///
/// It strips leading and trailing `[` characters, then leading and trailing `]`
/// characters (removing the brackets around an IPv6 literal). It then lowercases
/// ASCII letters, because host names are case-insensitive.
fn normalize_host(host: &str) -> String {
    let without_open = host.trim_matches('[');
    let unbracketed = without_open.trim_matches(']');
    unbracketed.to_ascii_lowercase()
}
202
/// A `Host` header value or allowlist entry reduced to a comparable form by
/// `normalize_authority`.
#[derive(Debug, Clone, PartialEq, Eq)]
struct NormalizedAuthority {
    /// Host as produced by `normalize_host`: ASCII-lowercased, with surrounding
    /// `[`/`]` characters removed.
    host: String,
    /// Port, if one was written explicitly.
    port: Option<u16>,
}
208
209fn normalize_authority(host: &str, port: Option<u16>) -> NormalizedAuthority {
210 NormalizedAuthority {
211 host: normalize_host(host),
212 port,
213 }
214}
215
216fn parse_allowed_authority(allowed: &str) -> Option<NormalizedAuthority> {
217 let allowed = allowed.trim();
218 if allowed.is_empty() {
219 return None;
220 }
221
222 if let Ok(authority) = http::uri::Authority::try_from(allowed) {
223 return Some(normalize_authority(authority.host(), authority.port_u16()));
224 }
225
226 Some(normalize_authority(allowed, None))
227}
228
229fn host_is_allowed(host: &NormalizedAuthority, allowed_hosts: &[String]) -> bool {
230 if allowed_hosts.is_empty() {
231 return true;
233 }
234 allowed_hosts
235 .iter()
236 .filter_map(|allowed| parse_allowed_authority(allowed))
237 .any(|allowed| {
238 allowed.host == host.host
239 && match allowed.port {
240 Some(port) => host.port == Some(port),
241 None => true,
242 }
243 })
244}
245
246fn bad_request_response(message: &str) -> BoxResponse {
247 let body = Full::from(message.to_string()).boxed();
248
249 http::Response::builder()
250 .status(http::StatusCode::BAD_REQUEST)
251 .header(http::header::CONTENT_TYPE, "text/plain; charset=utf-8")
252 .body(body)
253 .expect("failed to build bad request response")
254}
255
256fn parse_host_header(headers: &HeaderMap) -> Result<NormalizedAuthority, BoxResponse> {
257 let Some(host) = headers.get(http::header::HOST) else {
258 return Err(bad_request_response("Bad Request: missing Host header"));
259 };
260
261 let host = host
262 .to_str()
263 .map_err(|_| bad_request_response("Bad Request: Invalid Host header encoding"))?;
264 let authority = http::uri::Authority::try_from(host)
265 .map_err(|_| bad_request_response("Bad Request: Invalid Host header"))?;
266 Ok(normalize_authority(authority.host(), authority.port_u16()))
267}
268
269fn validate_dns_rebinding_headers(
270 headers: &HeaderMap,
271 config: &StreamableHttpServerConfig,
272) -> Result<(), BoxResponse> {
273 let host = parse_host_header(headers)?;
274 if !host_is_allowed(&host, &config.allowed_hosts) {
275 return Err(forbidden_response("Forbidden: Host header is not allowed"));
276 }
277
278 Ok(())
279}
280
/// Tower service implementing the server side of the MCP streamable HTTP transport.
///
/// Cloning is cheap: all clones share the session manager, the handler factory and
/// the restore table.
pub struct StreamableHttpService<S, M> {
    /// Transport configuration.
    pub config: StreamableHttpServerConfig,
    /// Tracks live sessions in stateful mode.
    session_manager: Arc<M>,
    /// Builds a new handler for each session (stateful) or each request (stateless).
    service_factory: Arc<dyn Fn() -> Result<S, std::io::Error> + Send + Sync>,
    /// Present only when `config.session_store` is set (see `new`).
    /// Maps a session id that is currently being restored to a watch channel.
    /// The channel publishes the outcome (`Some(true)` / `Some(false)`) to
    /// concurrent requests for that session.
    pending_restores: Option<
        Arc<tokio::sync::RwLock<HashMap<SessionId, tokio::sync::watch::Sender<Option<bool>>>>>,
    >,
}
376
377impl<S, M> Clone for StreamableHttpService<S, M> {
378 fn clone(&self) -> Self {
379 Self {
380 config: self.config.clone(),
381 session_manager: self.session_manager.clone(),
382 service_factory: self.service_factory.clone(),
383 pending_restores: self.pending_restores.clone(),
384 }
385 }
386}
387
388impl<RequestBody, S, M> tower_service::Service<Request<RequestBody>> for StreamableHttpService<S, M>
389where
390 RequestBody: Body + Send + 'static,
391 S: crate::Service<RoleServer> + Send + 'static,
392 M: SessionManager,
393 RequestBody::Error: Display,
394 RequestBody::Data: Send + 'static,
395{
396 type Response = BoxResponse;
397 type Error = Infallible;
398 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
399 fn call(&mut self, req: http::Request<RequestBody>) -> Self::Future {
400 let service = self.clone();
401 Box::pin(async move {
402 let response = service.handle(req).await;
403 Ok(response)
404 })
405 }
406 fn poll_ready(
407 &mut self,
408 _cx: &mut std::task::Context<'_>,
409 ) -> std::task::Poll<Result<(), Self::Error>> {
410 std::task::Poll::Ready(Ok(()))
411 }
412}
413
/// Drop guard for an in-flight session restore started by `try_restore_from_store`.
///
/// When dropped, it publishes `result` on `watch_tx` so that other requests waiting
/// on the same restore observe the outcome. It then removes `session_id` from
/// `pending_restores`.
struct PendingRestoreGuard {
    /// Shared table of restores in progress, keyed by session id.
    pending_restores:
        Arc<tokio::sync::RwLock<HashMap<SessionId, tokio::sync::watch::Sender<Option<bool>>>>>,
    /// Session whose restore this guard tracks.
    session_id: SessionId,
    /// Channel that waiters subscribe to. `None` means "still restoring".
    watch_tx: tokio::sync::watch::Sender<Option<bool>>,
    /// Outcome to publish on drop. It starts as `false` and is set to `true` only
    /// after a successful restore, so early returns report failure.
    result: bool,
}
429
430impl Drop for PendingRestoreGuard {
431 fn drop(&mut self) {
432 let _ = self.watch_tx.send(Some(self.result));
434 let pending_restores = self.pending_restores.clone();
436 let session_id = self.session_id.clone();
437 tokio::spawn(async move {
438 pending_restores.write().await.remove(&session_id);
439 });
440 }
441}
442
443impl<S, M> StreamableHttpService<S, M>
444where
445 S: crate::Service<RoleServer> + Send + 'static,
446 M: SessionManager,
447{
448 pub fn new(
449 service_factory: impl Fn() -> Result<S, std::io::Error> + Send + Sync + 'static,
450 session_manager: Arc<M>,
451 config: StreamableHttpServerConfig,
452 ) -> Self {
453 let pending_restores = config.session_store.is_some().then(|| {
454 Arc::new(tokio::sync::RwLock::new(HashMap::<
455 SessionId,
456 tokio::sync::watch::Sender<Option<bool>>,
457 >::new()))
458 });
459 Self {
460 config,
461 session_manager,
462 service_factory: Arc::new(service_factory),
463 pending_restores,
464 }
465 }
466 fn get_service(&self) -> Result<S, std::io::Error> {
467 (self.service_factory)()
468 }
469
470 fn spawn_session_worker(
478 session_manager: Arc<M>,
479 session_id: SessionId,
480 service: S,
481 transport: M::Transport,
482 init_done_tx: Option<tokio::sync::oneshot::Sender<()>>,
483 ) where
484 S: crate::Service<RoleServer> + Send + 'static,
485 M: SessionManager,
486 {
487 tokio::spawn(async move {
488 let svc =
489 serve_server::<S, M::Transport, _, TransportAdapterIdentity>(service, transport)
490 .await;
491 match svc {
492 Ok(svc) => {
493 if let Some(tx) = init_done_tx {
494 let _ = tx.send(());
495 }
496 let _ = svc.waiting().await;
497 }
498 Err(e) => {
499 tracing::error!("Failed to serve session: {e}");
500 }
502 }
503 let _ = session_manager
504 .close_session(&session_id)
505 .await
506 .inspect_err(|e| {
507 tracing::error!("Failed to close session {session_id}: {e}");
508 });
509 });
510 }
511
512 async fn try_restore_from_store(
522 &self,
523 session_id: &SessionId,
524 parts: &http::request::Parts,
525 ) -> Result<bool, std::io::Error>
526 where
527 S: crate::Service<RoleServer> + Send + 'static,
528 M: SessionManager,
529 {
530 let (Some(pending_restores), Some(store)) =
532 (&self.pending_restores, &self.config.session_store)
533 else {
534 return Ok(false);
535 };
536
537 let (watch_tx, _watch_rx) = tokio::sync::watch::channel(None::<bool>);
542 {
543 let mut pending = pending_restores.write().await;
544 if let Some(tx) = pending.get(session_id) {
545 let mut rx = tx.subscribe();
546 drop(pending);
547 let result = rx
549 .wait_for(|r| r.is_some())
550 .await
551 .map(|r| r.unwrap_or(false))
552 .unwrap_or(false);
553 return Ok(result);
554 }
555 pending.insert(session_id.clone(), watch_tx.clone());
556 }
557
558 let mut guard = PendingRestoreGuard {
560 pending_restores: pending_restores.clone(),
561 session_id: session_id.clone(),
562 watch_tx: watch_tx.clone(),
563 result: false,
564 };
565
566 let state = match store.load(session_id.as_ref()).await {
568 Ok(Some(s)) => s,
569 Ok(None) => {
570 return Ok(false);
571 }
572 Err(e) => {
573 tracing::error!(
574 session_id = session_id.as_ref(),
575 error = %e,
576 "session store load failed during restore"
577 );
578 return Err(std::io::Error::other(e));
579 }
580 };
581
582 let transport = match self
584 .session_manager
585 .restore_session(session_id.clone())
586 .await
587 .map_err(|e| std::io::Error::other(e.to_string()))
588 {
589 Ok(RestoreOutcome::Restored(t)) => t,
590 Ok(RestoreOutcome::AlreadyPresent) => {
591 return Err(std::io::Error::other(
594 "restore_session returned AlreadyPresent unexpectedly; session manager might have modified the session store outside of the restore_session API",
595 ));
596 }
597 Ok(RestoreOutcome::NotSupported) => {
598 return Ok(false);
599 }
600 Err(e) => {
601 return Err(e);
602 }
603 };
604
605 let service = match self.get_service() {
607 Ok(s) => s,
608 Err(e) => {
609 return Err(e);
610 }
611 };
612
613 let mut restore_init = ClientJsonRpcMessage::request(
617 ClientRequest::InitializeRequest(InitializeRequest {
618 params: state.initialize_params,
619 ..Default::default()
620 }),
621 crate::model::NumberOrString::Number(0),
622 );
623 restore_init.insert_extension(parts.clone());
624 restore_init.insert_extension(SessionRestoreMarker {
625 id: session_id.clone(),
626 });
627 let mut restore_initialized = ClientJsonRpcMessage::notification(
628 ClientNotification::InitializedNotification(InitializedNotification {
629 ..Default::default()
630 }),
631 );
632 restore_initialized.insert_extension(parts.clone());
633 restore_initialized.insert_extension(SessionRestoreMarker {
634 id: session_id.clone(),
635 });
636 let (init_done_tx, init_done_rx) = tokio::sync::oneshot::channel::<()>();
638
639 Self::spawn_session_worker(
640 self.session_manager.clone(),
641 session_id.clone(),
642 service,
643 transport,
644 Some(init_done_tx),
645 );
646
647 if let Err(e) = self
648 .session_manager
649 .initialize_session(session_id, restore_init)
650 .await
651 .map_err(|e| std::io::Error::other(e.to_string()))
652 {
653 return Err(e);
654 }
655
656 if let Err(e) = self
657 .session_manager
658 .accept_message(session_id, restore_initialized)
659 .await
660 .map_err(|e| std::io::Error::other(e.to_string()))
661 {
662 return Err(e);
663 }
664
665 if init_done_rx.await.is_err() {
666 return Err(std::io::Error::other(
667 "serve_server initialization failed during restore",
668 ));
669 }
670
671 guard.result = true;
673
674 tracing::debug!(
675 session_id = session_id.as_ref(),
676 "session restored from external store"
677 );
678 Ok(true)
679 }
680 pub async fn handle<B>(&self, request: Request<B>) -> Response<BoxBody<Bytes, Infallible>>
681 where
682 B: Body + Send + 'static,
683 B::Error: Display,
684 {
685 if let Err(response) = validate_dns_rebinding_headers(request.headers(), &self.config) {
686 return response;
687 }
688 let method = request.method().clone();
689 let allowed_methods = match self.config.stateful_mode {
690 true => "GET, POST, DELETE",
691 false => "POST",
692 };
693 let result = match (method, self.config.stateful_mode) {
694 (Method::POST, _) => self.handle_post(request).await,
695 (Method::GET, true) => self.handle_get(request).await,
697 (Method::DELETE, true) => self.handle_delete(request).await,
698 _ => {
699 let response = Response::builder()
701 .status(http::StatusCode::METHOD_NOT_ALLOWED)
702 .header(ALLOW, allowed_methods)
703 .body(Full::new(Bytes::from("Method Not Allowed")).boxed())
704 .expect("valid response");
705 return response;
706 }
707 };
708 match result {
709 Ok(response) => response,
710 Err(response) => response,
711 }
712 }
713 async fn handle_get<B>(&self, request: Request<B>) -> Result<BoxResponse, BoxResponse>
714 where
715 B: Body + Send + 'static,
716 B::Error: Display,
717 {
718 if !request
720 .headers()
721 .get(http::header::ACCEPT)
722 .and_then(|header| header.to_str().ok())
723 .is_some_and(|header| header.contains(EVENT_STREAM_MIME_TYPE))
724 {
725 return Ok(Response::builder()
726 .status(http::StatusCode::NOT_ACCEPTABLE)
727 .body(
728 Full::new(Bytes::from(
729 "Not Acceptable: Client must accept text/event-stream",
730 ))
731 .boxed(),
732 )
733 .expect("valid response"));
734 }
735 let session_id = request
737 .headers()
738 .get(HEADER_SESSION_ID)
739 .and_then(|v| v.to_str().ok())
740 .map(|s| s.to_owned().into());
741 let Some(session_id) = session_id else {
742 return Ok(Response::builder()
744 .status(http::StatusCode::BAD_REQUEST)
745 .body(Full::new(Bytes::from("Bad Request: Session ID is required")).boxed())
746 .expect("valid response"));
747 };
748 let has_session = self
750 .session_manager
751 .has_session(&session_id)
752 .await
753 .map_err(internal_error_response("check session"))?;
754 let (parts, _) = request.into_parts();
755 if !has_session {
756 let restored = self
758 .try_restore_from_store(&session_id, &parts)
759 .await
760 .map_err(internal_error_response("restore session"))?;
761 if !restored {
762 return Ok(Response::builder()
764 .status(http::StatusCode::NOT_FOUND)
765 .body(Full::new(Bytes::from("Not Found: Session not found")).boxed())
766 .expect("valid response"));
767 }
768 }
769 validate_protocol_version_header(&parts.headers)?;
771 let last_event_id = parts
773 .headers
774 .get(HEADER_LAST_EVENT_ID)
775 .and_then(|v| v.to_str().ok())
776 .map(|s| s.to_owned());
777 if let Some(last_event_id) = last_event_id {
778 match self
779 .session_manager
780 .resume(&session_id, last_event_id)
781 .await
782 {
783 Ok(stream) => {
784 return Ok(sse_stream_response(
785 stream,
786 self.config.sse_keep_alive,
787 self.config.cancellation_token.child_token(),
788 ));
789 }
790 Err(e) => {
791 tracing::warn!("Resume failed ({e}), returning empty stream");
797 return Ok(sse_stream_response(
798 futures::stream::empty(),
799 None,
800 self.config.cancellation_token.child_token(),
801 ));
802 }
803 }
804 }
805 let stream = self
807 .session_manager
808 .create_standalone_stream(&session_id)
809 .await
810 .map_err(internal_error_response("create standalone stream"))?;
811 let stream = if let Some(retry) = self.config.sse_retry {
812 let priming = ServerSseMessage::priming("0", retry);
813 futures::stream::once(async move { priming })
814 .chain(stream)
815 .left_stream()
816 } else {
817 stream.right_stream()
818 };
819 Ok(sse_stream_response(
820 stream,
821 self.config.sse_keep_alive,
822 self.config.cancellation_token.child_token(),
823 ))
824 }
825
826 async fn handle_post<B>(&self, request: Request<B>) -> Result<BoxResponse, BoxResponse>
827 where
828 B: Body + Send + 'static,
829 B::Error: Display,
830 {
831 if !request
833 .headers()
834 .get(http::header::ACCEPT)
835 .and_then(|header| header.to_str().ok())
836 .is_some_and(|header| {
837 header.contains(JSON_MIME_TYPE) && header.contains(EVENT_STREAM_MIME_TYPE)
838 })
839 {
840 return Ok(Response::builder()
841 .status(http::StatusCode::NOT_ACCEPTABLE)
842 .body(Full::new(Bytes::from("Not Acceptable: Client must accept both application/json and text/event-stream")).boxed())
843 .expect("valid response"));
844 }
845
846 if !request
848 .headers()
849 .get(http::header::CONTENT_TYPE)
850 .and_then(|header| header.to_str().ok())
851 .is_some_and(|header| header.starts_with(JSON_MIME_TYPE))
852 {
853 return Ok(Response::builder()
854 .status(http::StatusCode::UNSUPPORTED_MEDIA_TYPE)
855 .body(
856 Full::new(Bytes::from(
857 "Unsupported Media Type: Content-Type must be application/json",
858 ))
859 .boxed(),
860 )
861 .expect("valid response"));
862 }
863
864 let (part, body) = request.into_parts();
866 let mut message = match expect_json(body).await {
867 Ok(message) => message,
868 Err(response) => return Ok(response),
869 };
870
871 if self.config.stateful_mode {
872 let session_id = part
874 .headers
875 .get(HEADER_SESSION_ID)
876 .and_then(|v| v.to_str().ok());
877 if let Some(session_id) = session_id {
878 let session_id = session_id.to_owned().into();
879 let has_session = self
880 .session_manager
881 .has_session(&session_id)
882 .await
883 .map_err(internal_error_response("check session"))?;
884 if !has_session {
885 let restored = self
887 .try_restore_from_store(&session_id, &part)
888 .await
889 .map_err(internal_error_response("restore session"))?;
890 if !restored {
891 return Ok(Response::builder()
893 .status(http::StatusCode::NOT_FOUND)
894 .body(Full::new(Bytes::from("Not Found: Session not found")).boxed())
895 .expect("valid response"));
896 }
897 }
898
899 validate_protocol_version_header(&part.headers)?;
901
902 match &mut message {
904 ClientJsonRpcMessage::Request(req) => {
905 req.request.extensions_mut().insert(part);
906 }
907 ClientJsonRpcMessage::Notification(not) => {
908 not.notification.extensions_mut().insert(part);
909 }
910 _ => {
911 }
913 }
914
915 match message {
916 ClientJsonRpcMessage::Request(_) => {
917 let stream = self
921 .session_manager
922 .create_stream(&session_id, message)
923 .await
924 .map_err(internal_error_response("get session"))?;
925 Ok(sse_stream_response(
926 stream,
927 self.config.sse_keep_alive,
928 self.config.cancellation_token.child_token(),
929 ))
930 }
931 ClientJsonRpcMessage::Notification(_)
932 | ClientJsonRpcMessage::Response(_)
933 | ClientJsonRpcMessage::Error(_) => {
934 self.session_manager
936 .accept_message(&session_id, message)
937 .await
938 .map_err(internal_error_response("accept message"))?;
939 Ok(accepted_response())
940 }
941 }
942 } else {
943 let (session_id, transport) = self
944 .session_manager
945 .create_session()
946 .await
947 .map_err(internal_error_response("create session"))?;
948 let stored_init_params = if self.config.session_store.is_some() {
951 if let ClientJsonRpcMessage::Request(req) = &message {
952 if let ClientRequest::InitializeRequest(init_req) = &req.request {
953 Some(init_req.params.clone())
954 } else {
955 None
956 }
957 } else {
958 None
959 }
960 } else {
961 None
962 };
963 if let ClientJsonRpcMessage::Request(req) = &mut message {
964 if !matches!(req.request, ClientRequest::InitializeRequest(_)) {
965 return Err(unexpected_message_response("initialize request"));
966 }
967 req.request.extensions_mut().insert(part);
969 } else {
970 return Err(unexpected_message_response("initialize request"));
971 }
972 let service = self
973 .get_service()
974 .map_err(internal_error_response("get service"))?;
975 Self::spawn_session_worker(
977 self.session_manager.clone(),
978 session_id.clone(),
979 service,
980 transport,
981 None,
982 );
983 let response = self
985 .session_manager
986 .initialize_session(&session_id, message)
987 .await
988 .map_err(internal_error_response("create stream"))?;
989 if let (Some(store), Some(params)) =
991 (&self.config.session_store, stored_init_params)
992 {
993 let state = SessionState {
994 initialize_params: params,
995 };
996 let _ = store
997 .store(session_id.as_ref(), &state)
998 .await
999 .inspect_err(|e| {
1000 tracing::warn!(
1001 "Failed to persist session {} to store: {e}",
1002 session_id
1003 );
1004 });
1005 }
1006 let stream =
1007 futures::stream::once(async move { ServerSseMessage::from_message(response) });
1008 let stream = if let Some(retry) = self.config.sse_retry {
1010 let priming = ServerSseMessage::priming("0", retry);
1011 futures::stream::once(async move { priming })
1012 .chain(stream)
1013 .left_stream()
1014 } else {
1015 stream.right_stream()
1016 };
1017 let mut response = sse_stream_response(
1018 stream,
1019 self.config.sse_keep_alive,
1020 self.config.cancellation_token.child_token(),
1021 );
1022
1023 response.headers_mut().insert(
1024 HEADER_SESSION_ID,
1025 session_id
1026 .parse()
1027 .map_err(internal_error_response("create session id header"))?,
1028 );
1029 Ok(response)
1030 }
1031 } else {
1032 let is_init = matches!(
1034 &message,
1035 ClientJsonRpcMessage::Request(req) if matches!(req.request, ClientRequest::InitializeRequest(_))
1036 );
1037 if !is_init {
1038 validate_protocol_version_header(&part.headers)?;
1039 }
1040 let service = self
1041 .get_service()
1042 .map_err(internal_error_response("get service"))?;
1043 match message {
1044 ClientJsonRpcMessage::Request(mut request) => {
1045 request.request.extensions_mut().insert(part);
1046 let (transport, mut receiver) =
1047 OneshotTransport::<RoleServer>::new(ClientJsonRpcMessage::Request(request));
1048 let service = serve_directly(service, transport, None);
1049 tokio::spawn(async move {
1050 let _ = service.waiting().await;
1052 });
1053 if self.config.json_response {
1054 let cancel = self.config.cancellation_token.child_token();
1058 match tokio::select! {
1059 res = receiver.recv() => res,
1060 _ = cancel.cancelled() => None,
1061 } {
1062 Some(message) => {
1063 tracing::trace!(?message);
1064 let body = serde_json::to_vec(&message).map_err(|e| {
1065 internal_error_response("serialize json response")(e)
1066 })?;
1067 Ok(Response::builder()
1068 .status(http::StatusCode::OK)
1069 .header(http::header::CONTENT_TYPE, JSON_MIME_TYPE)
1070 .body(Full::new(Bytes::from(body)).boxed())
1071 .expect("valid response"))
1072 }
1073 None => Err(internal_error_response("empty response")(
1074 std::io::Error::new(
1075 std::io::ErrorKind::UnexpectedEof,
1076 "no response message received from handler",
1077 ),
1078 )),
1079 }
1080 } else {
1081 let stream = ReceiverStream::new(receiver).map(|message| {
1083 tracing::trace!(?message);
1084 ServerSseMessage::from_message(message)
1085 });
1086 Ok(sse_stream_response(
1087 stream,
1088 self.config.sse_keep_alive,
1089 self.config.cancellation_token.child_token(),
1090 ))
1091 }
1092 }
1093 ClientJsonRpcMessage::Notification(_notification) => {
1094 Ok(accepted_response())
1096 }
1097 ClientJsonRpcMessage::Response(_json_rpc_response) => Ok(accepted_response()),
1098 ClientJsonRpcMessage::Error(_json_rpc_error) => Ok(accepted_response()),
1099 }
1100 }
1101 }
1102
1103 async fn handle_delete<B>(&self, request: Request<B>) -> Result<BoxResponse, BoxResponse>
1104 where
1105 B: Body + Send + 'static,
1106 B::Error: Display,
1107 {
1108 let session_id = request
1110 .headers()
1111 .get(HEADER_SESSION_ID)
1112 .and_then(|v| v.to_str().ok())
1113 .map(|s| s.to_owned().into());
1114 let Some(session_id) = session_id else {
1115 return Ok(Response::builder()
1117 .status(http::StatusCode::BAD_REQUEST)
1118 .body(Full::new(Bytes::from("Bad Request: Session ID is required")).boxed())
1119 .expect("valid response"));
1120 };
1121 validate_protocol_version_header(request.headers())?;
1123 self.session_manager
1125 .close_session(&session_id)
1126 .await
1127 .map_err(internal_error_response("close session"))?;
1128 if let Some(store) = &self.config.session_store {
1131 let _ = store.delete(session_id.as_ref()).await.inspect_err(|e| {
1132 tracing::warn!("Failed to delete session {} from store: {e}", session_id);
1133 });
1134 }
1135 Ok(accepted_response())
1136 }
1137}