1use std::{
2 collections::{HashMap, HashSet, VecDeque},
3 num::ParseIntError,
4 time::Duration,
5};
6
7use futures::Stream;
8use thiserror::Error;
9use tokio::sync::{
10 mpsc::{Receiver, Sender},
11 oneshot,
12};
13use tokio_stream::wrappers::ReceiverStream;
14use tracing::instrument;
15
16use crate::{
17 RoleServer,
18 model::{
19 CancelledNotificationParam, ClientJsonRpcMessage, ClientNotification, ClientRequest,
20 JsonRpcNotification, JsonRpcRequest, Notification, ProgressNotificationParam,
21 ProgressToken, RequestId, ServerJsonRpcMessage, ServerNotification,
22 },
23 transport::{
24 WorkerTransport,
25 common::server_side_http::{SessionId, session_id},
26 worker::{Worker, WorkerContext, WorkerQuitReason, WorkerSendRequest},
27 },
28};
29
/// A [`SessionManager`] that keeps every session inside this process.
///
/// Each session is served by a [`LocalSessionWorker`] spawned through a
/// [`WorkerTransport`]; the manager itself only stores the [`LocalSessionHandle`]
/// used to send events to that worker.
#[derive(Debug, Default)]
#[non_exhaustive]
pub struct LocalSessionManager {
    /// Live sessions keyed by id: inserted by `create_session`, removed by `close_session`.
    pub sessions: tokio::sync::RwLock<HashMap<SessionId, LocalSessionHandle>>,
    /// Configuration cloned into every session this manager creates.
    pub session_config: SessionConfig,
}
36
/// Errors returned by [`LocalSessionManager`]'s [`SessionManager`] implementation.
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum LocalSessionManagerError {
    /// No session with this id is registered with the manager.
    #[error("Session not found: {0}")]
    SessionNotFound(SessionId),
    /// An error reported by the session's handle or worker.
    #[error("Session error: {0}")]
    SessionError(#[from] SessionError),
    /// The `last_event_id` string passed to `resume` is not a valid [`EventId`].
    #[error("Invalid event id: {0}")]
    InvalidEventId(#[from] EventIdParseError),
}
47impl SessionManager for LocalSessionManager {
48 type Error = LocalSessionManagerError;
49 type Transport = WorkerTransport<LocalSessionWorker>;
50 async fn create_session(&self) -> Result<(SessionId, Self::Transport), Self::Error> {
51 let id = session_id();
52 let (handle, worker) = create_local_session(id.clone(), self.session_config.clone());
53 self.sessions.write().await.insert(id.clone(), handle);
54 Ok((id, WorkerTransport::spawn(worker)))
55 }
56 async fn initialize_session(
57 &self,
58 id: &SessionId,
59 message: ClientJsonRpcMessage,
60 ) -> Result<ServerJsonRpcMessage, Self::Error> {
61 let sessions = self.sessions.read().await;
62 let handle = sessions
63 .get(id)
64 .ok_or(LocalSessionManagerError::SessionNotFound(id.clone()))?;
65 let response = handle.initialize(message).await?;
66 Ok(response)
67 }
68 async fn close_session(&self, id: &SessionId) -> Result<(), Self::Error> {
69 let mut sessions = self.sessions.write().await;
70 if let Some(handle) = sessions.remove(id) {
71 handle.close().await?;
72 }
73 Ok(())
74 }
75 async fn has_session(&self, id: &SessionId) -> Result<bool, Self::Error> {
76 let sessions = self.sessions.read().await;
77 Ok(sessions.contains_key(id))
78 }
79 async fn create_stream(
80 &self,
81 id: &SessionId,
82 message: ClientJsonRpcMessage,
83 ) -> Result<impl Stream<Item = ServerSseMessage> + Send + 'static, Self::Error> {
84 let sessions = self.sessions.read().await;
85 let handle = sessions
86 .get(id)
87 .ok_or(LocalSessionManagerError::SessionNotFound(id.clone()))?;
88 let receiver = handle.establish_request_wise_channel().await?;
89 handle
90 .push_message(message, receiver.http_request_id)
91 .await?;
92 Ok(ReceiverStream::new(receiver.inner))
93 }
94
95 async fn create_standalone_stream(
96 &self,
97 id: &SessionId,
98 ) -> Result<impl Stream<Item = ServerSseMessage> + Send + 'static, Self::Error> {
99 let sessions = self.sessions.read().await;
100 let handle = sessions
101 .get(id)
102 .ok_or(LocalSessionManagerError::SessionNotFound(id.clone()))?;
103 let receiver = handle.establish_common_channel().await?;
104 Ok(ReceiverStream::new(receiver.inner))
105 }
106
107 async fn resume(
108 &self,
109 id: &SessionId,
110 last_event_id: String,
111 ) -> Result<impl Stream<Item = ServerSseMessage> + Send + 'static, Self::Error> {
112 let sessions = self.sessions.read().await;
113 let handle = sessions
114 .get(id)
115 .ok_or(LocalSessionManagerError::SessionNotFound(id.clone()))?;
116 let receiver = handle.resume(last_event_id.parse()?).await?;
117 Ok(ReceiverStream::new(receiver.inner))
118 }
119
120 async fn accept_message(
121 &self,
122 id: &SessionId,
123 message: ClientJsonRpcMessage,
124 ) -> Result<(), Self::Error> {
125 let sessions = self.sessions.read().await;
126 let handle = sessions
127 .get(id)
128 .ok_or(LocalSessionManagerError::SessionNotFound(id.clone()))?;
129 handle.push_message(message, None).await?;
130 Ok(())
131 }
132}
133
/// Identifier attached to every SSE event a session sends.
///
/// `Display` renders it as `"{index}"` (common channel) or
/// `"{index}/{http_request_id}"` (request-wise channel); `FromStr` parses the same forms.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct EventId {
    /// Request-wise channel the event belongs to; `None` for the common channel.
    http_request_id: Option<HttpRequestId>,
    /// Position of the event within its channel, counting from 0.
    index: usize,
}
140
141impl std::fmt::Display for EventId {
142 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
143 write!(f, "{}", self.index)?;
144 match &self.http_request_id {
145 Some(http_request_id) => write!(f, "/{http_request_id}"),
146 None => write!(f, ""),
147 }
148 }
149}
150
/// Reasons a string fails to parse as an [`EventId`].
#[derive(Debug, Clone, Error)]
#[non_exhaustive]
pub enum EventIdParseError {
    /// The index part (before any `/`) is not a valid `usize`.
    #[error("Invalid index: {0}")]
    InvalidIndex(ParseIntError),
    /// Meant for a request id that is present but not a valid number.
    #[error("Invalid numeric request id: {0}")]
    InvalidNumericRequestId(ParseIntError),
    // NOTE(review): the message reads "Missing ..." while the variant name says
    // "Invalid ..."; confirm the intended wording.
    #[error("Missing request id type")]
    InvalidRequestIdType,
    /// Meant for a request id that is expected but absent.
    #[error("Missing request id")]
    MissingRequestId,
}
163
164impl std::str::FromStr for EventId {
165 type Err = EventIdParseError;
166 fn from_str(s: &str) -> Result<Self, Self::Err> {
167 if let Some((index, request_id)) = s.split_once("/") {
168 let index = usize::from_str(index).map_err(EventIdParseError::InvalidIndex)?;
169 let request_id = u64::from_str(request_id).map_err(EventIdParseError::InvalidIndex)?;
170 Ok(EventId {
171 http_request_id: Some(request_id),
172 index,
173 })
174 } else {
175 let index = usize::from_str(s).map_err(EventIdParseError::InvalidIndex)?;
176 Ok(EventId {
177 http_request_id: None,
178 index,
179 })
180 }
181 }
182}
183
184use super::{ServerSseMessage, SessionManager};
185
/// An SSE sender paired with a bounded replay cache of the messages sent through it.
///
/// The cache lets recent events be re-sent after `tx` has been swapped for a new
/// channel (e.g. when a client resumes).
struct CachedTx {
    /// Current destination; replaced on resume and on server-initiated stream close.
    tx: Sender<ServerSseMessage>,
    /// Most recently sent messages, oldest at the front, at most `capacity` entries.
    cache: VecDeque<ServerSseMessage>,
    /// Written into every generated [`EventId`]; `None` for the common channel.
    http_request_id: Option<HttpRequestId>,
    /// Maximum number of messages retained in `cache`.
    capacity: usize,
}
192
193impl CachedTx {
194 fn new(tx: Sender<ServerSseMessage>, http_request_id: Option<HttpRequestId>) -> Self {
195 Self {
196 cache: VecDeque::with_capacity(tx.capacity()),
197 capacity: tx.capacity(),
198 tx,
199 http_request_id,
200 }
201 }
202 fn new_common(tx: Sender<ServerSseMessage>) -> Self {
203 Self::new(tx, None)
204 }
205
206 fn next_event_id(&self) -> EventId {
207 let index = self.cache.back().map_or(0, |m| {
208 m.event_id
209 .as_deref()
210 .unwrap_or_default()
211 .parse::<EventId>()
212 .expect("valid event id")
213 .index
214 + 1
215 });
216 EventId {
217 http_request_id: self.http_request_id,
218 index,
219 }
220 }
221
222 async fn send(&mut self, message: ServerJsonRpcMessage) {
223 let event_id = self.next_event_id();
224 let message = ServerSseMessage::new(event_id.to_string(), message);
225 self.cache_and_send(message).await;
226 }
227
228 async fn send_priming(&mut self, retry: Duration) {
229 let event_id = self.next_event_id();
230 let message = ServerSseMessage::priming(event_id.to_string(), retry);
231 self.cache_and_send(message).await;
232 }
233
234 async fn cache_and_send(&mut self, message: ServerSseMessage) {
235 if self.cache.len() >= self.capacity {
236 self.cache.pop_front();
237 self.cache.push_back(message.clone());
238 } else {
239 self.cache.push_back(message.clone());
240 }
241 let _ = self.tx.send(message).await.inspect_err(|e| {
242 let event_id = &e.0.event_id;
243 tracing::trace!(?event_id, "trying to send message in a closed session")
244 });
245 }
246
247 async fn sync(&mut self, index: usize) -> Result<(), SessionError> {
248 let Some(front) = self.cache.front() else {
249 return Ok(());
250 };
251 let front_event_id = front
252 .event_id
253 .as_deref()
254 .unwrap_or_default()
255 .parse::<EventId>()?;
256 let sync_index = index.saturating_sub(front_event_id.index);
257 if sync_index > self.cache.len() {
258 return Err(SessionError::InvalidEventId);
260 }
261 for message in self.cache.iter().skip(sync_index) {
262 let send_result = self.tx.send(message.clone()).await;
263 if send_result.is_err() {
264 let event_id: EventId = message.event_id.as_deref().unwrap_or_default().parse()?;
265 return Err(SessionError::ChannelClosed(Some(event_id.index as u64)));
266 }
267 }
268 Ok(())
269 }
270}
271
/// State of one request-wise SSE channel, as created by `establish_request_wise_channel`.
struct HttpRequestWise {
    /// Resources (MCP request ids / progress tokens) currently routed to this channel.
    resources: HashSet<ResourceKey>,
    /// Cached sender feeding this channel's stream.
    tx: CachedTx,
}

/// Session-local id of a request-wise channel, handed out sequentially by the worker.
type HttpRequestId = u64;
/// A key by which outbound server messages are routed to a request-wise channel.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
enum ResourceKey {
    /// A client request's JSON-RPC id; used for responses, errors and cancellations.
    McpRequestId(RequestId),
    /// The progress token of a client request; used for progress notifications.
    ProgressToken(ProgressToken),
}
283
/// Worker that owns all mutable state of one streamable-HTTP session.
///
/// It receives [`SessionEvent`]s from [`LocalSessionHandle`]s and outbound messages from
/// the MCP handler, and routes each outbound message to the appropriate SSE channel.
pub struct LocalSessionWorker {
    /// The session's id.
    id: SessionId,
    /// Next id returned by `next_http_request_id` (wraps on overflow).
    next_http_request_id: HttpRequestId,
    /// Open request-wise channels by id.
    tx_router: HashMap<HttpRequestId, HttpRequestWise>,
    /// Request-wise channel each registered resource is routed to.
    resource_router: HashMap<ResourceKey, HttpRequestId>,
    /// Common (standalone) channel for messages not routed to a request-wise channel.
    common: CachedTx,
    /// Senders for extra standalone streams opened while `common` is still alive.
    /// Nothing in this worker sends messages through them; holding a sender only keeps
    /// its receiver open. Closed entries are pruned in `resume`.
    shadow_txs: Vec<Sender<ServerSseMessage>>,
    /// Events sent by this session's handles.
    event_rx: Receiver<SessionEvent>,
    /// Per-session configuration (channel capacity, keep-alive).
    session_config: SessionConfig,
}

impl LocalSessionWorker {
    /// Returns the id of the session this worker serves.
    pub fn id(&self) -> &SessionId {
        &self.id
    }
}
305
/// Errors produced by a session worker or reported through a [`LocalSessionHandle`].
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum SessionError {
    // NOTE(review): not constructed anywhere in this file, and the message says
    // "Invalid request id" rather than "duplicated"; confirm the intended use.
    #[error("Invalid request id: {0}")]
    DuplicatedRequestId(HttpRequestId),
    /// The target channel is missing or closed; the payload identifies the channel where
    /// the construction site knows it.
    #[error("Channel closed: {0:?}")]
    ChannelClosed(Option<HttpRequestId>),
    /// A cached or supplied event id failed to parse.
    #[error("Cannot parse event id: {0}")]
    EventIdParseError(#[from] EventIdParseError),
    /// The session worker is gone: its event channel or a reply channel was closed.
    #[error("Session service terminated")]
    SessionServiceTerminated,
    /// A resume event id lies outside the range of cached events.
    #[error("Invalid event id")]
    InvalidEventId,
    /// An underlying I/O error.
    #[error("IO error: {0}")]
    Io(#[from] std::io::Error),
}
322
323impl From<SessionError> for std::io::Error {
324 fn from(value: SessionError) -> Self {
325 match value {
326 SessionError::Io(io) => io,
327 _ => std::io::Error::other(format!("Session error: {value}")),
328 }
329 }
330}
331
/// Where an outbound server message should be delivered.
enum OutboundChannel {
    /// Request-wise channel `id`; when `close` is set the channel is dropped after
    /// delivery (responses and errors).
    RequestWise { id: HttpRequestId, close: bool },
    /// The common (standalone) channel.
    Common,
}
/// Receiving half of an SSE channel handed back to the HTTP layer.
#[derive(Debug)]
#[non_exhaustive]
pub struct StreamableHttpMessageReceiver {
    /// Request-wise channel this receiver belongs to; `None` for the common stream.
    pub http_request_id: Option<HttpRequestId>,
    /// Messages to forward to the client as SSE events.
    pub inner: Receiver<ServerSseMessage>,
}
342
343impl LocalSessionWorker {
344 fn unregister_resource(&mut self, resource: &ResourceKey) {
345 if let Some(http_request_id) = self.resource_router.remove(resource) {
346 tracing::trace!(?resource, http_request_id, "unregister resource");
347 if let Some(channel) = self.tx_router.get_mut(&http_request_id) {
348 if channel.resources.is_empty() || matches!(resource, ResourceKey::McpRequestId(_))
351 {
352 tracing::debug!(http_request_id, "close http request wise channel");
353 if let Some(channel) = self.tx_router.remove(&http_request_id) {
354 for resource in channel.resources {
355 self.resource_router.remove(&resource);
356 }
357 }
358 }
359 } else {
360 tracing::warn!(http_request_id, "http request wise channel not found");
361 }
362 }
363 }
364 fn register_resource(&mut self, resource: ResourceKey, http_request_id: HttpRequestId) {
365 tracing::trace!(?resource, http_request_id, "register resource");
366 if let Some(channel) = self.tx_router.get_mut(&http_request_id) {
367 channel.resources.insert(resource.clone());
368 self.resource_router.insert(resource, http_request_id);
369 }
370 }
371 fn register_request(
372 &mut self,
373 request: &JsonRpcRequest<ClientRequest>,
374 http_request_id: HttpRequestId,
375 ) {
376 use crate::model::GetMeta;
377 self.register_resource(
378 ResourceKey::McpRequestId(request.id.clone()),
379 http_request_id,
380 );
381 if let Some(progress_token) = request.request.get_meta().get_progress_token() {
382 self.register_resource(
383 ResourceKey::ProgressToken(progress_token.clone()),
384 http_request_id,
385 );
386 }
387 }
388 fn catch_cancellation_notification(
389 &mut self,
390 notification: &JsonRpcNotification<ClientNotification>,
391 ) {
392 if let ClientNotification::CancelledNotification(n) = ¬ification.notification {
393 let request_id = n.params.request_id.clone();
394 let resource = ResourceKey::McpRequestId(request_id);
395 self.unregister_resource(&resource);
396 }
397 }
398 fn next_http_request_id(&mut self) -> HttpRequestId {
399 let id = self.next_http_request_id;
400 self.next_http_request_id = self.next_http_request_id.wrapping_add(1);
401 id
402 }
403 async fn establish_request_wise_channel(
404 &mut self,
405 ) -> Result<StreamableHttpMessageReceiver, SessionError> {
406 let http_request_id = self.next_http_request_id();
407 let (tx, rx) = tokio::sync::mpsc::channel(self.session_config.channel_capacity);
408 self.tx_router.insert(
409 http_request_id,
410 HttpRequestWise {
411 resources: Default::default(),
412 tx: CachedTx::new(tx, Some(http_request_id)),
413 },
414 );
415 tracing::debug!(http_request_id, "establish new request wise channel");
416 Ok(StreamableHttpMessageReceiver {
417 http_request_id: Some(http_request_id),
418 inner: rx,
419 })
420 }
421 fn resolve_outbound_channel(&self, message: &ServerJsonRpcMessage) -> OutboundChannel {
422 match &message {
423 ServerJsonRpcMessage::Request(_) => OutboundChannel::Common,
424 ServerJsonRpcMessage::Notification(JsonRpcNotification {
425 notification:
426 ServerNotification::ProgressNotification(Notification {
427 params: ProgressNotificationParam { progress_token, .. },
428 ..
429 }),
430 ..
431 }) => {
432 let id = self
433 .resource_router
434 .get(&ResourceKey::ProgressToken(progress_token.clone()));
435
436 if let Some(id) = id {
437 OutboundChannel::RequestWise {
438 id: *id,
439 close: false,
440 }
441 } else {
442 OutboundChannel::Common
443 }
444 }
445 ServerJsonRpcMessage::Notification(JsonRpcNotification {
446 notification:
447 ServerNotification::CancelledNotification(Notification {
448 params: CancelledNotificationParam { request_id, .. },
449 ..
450 }),
451 ..
452 }) => {
453 if let Some(id) = self
454 .resource_router
455 .get(&ResourceKey::McpRequestId(request_id.clone()))
456 {
457 OutboundChannel::RequestWise {
458 id: *id,
459 close: false,
460 }
461 } else {
462 OutboundChannel::Common
463 }
464 }
465 ServerJsonRpcMessage::Notification(_) => OutboundChannel::Common,
466 ServerJsonRpcMessage::Response(json_rpc_response) => {
467 if let Some(id) = self
468 .resource_router
469 .get(&ResourceKey::McpRequestId(json_rpc_response.id.clone()))
470 {
471 OutboundChannel::RequestWise {
472 id: *id,
473 close: true,
474 }
475 } else {
476 OutboundChannel::Common
477 }
478 }
479 ServerJsonRpcMessage::Error(json_rpc_error) => {
480 if let Some(id) = self
481 .resource_router
482 .get(&ResourceKey::McpRequestId(json_rpc_error.id.clone()))
483 {
484 OutboundChannel::RequestWise {
485 id: *id,
486 close: true,
487 }
488 } else {
489 OutboundChannel::Common
490 }
491 }
492 }
493 }
494 async fn handle_server_message(
495 &mut self,
496 message: ServerJsonRpcMessage,
497 ) -> Result<(), SessionError> {
498 let outbound_channel = self.resolve_outbound_channel(&message);
499 match outbound_channel {
500 OutboundChannel::RequestWise { id, close } => {
501 if let Some(request_wise) = self.tx_router.get_mut(&id) {
502 request_wise.tx.send(message).await;
503 if close {
504 if let Some(channel) = self.tx_router.remove(&id) {
505 for resource in channel.resources {
506 self.resource_router.remove(&resource);
507 }
508 }
509 }
510 } else {
511 return Err(SessionError::ChannelClosed(Some(id)));
512 }
513 }
514 OutboundChannel::Common => self.common.send(message).await,
515 }
516 Ok(())
517 }
518 async fn resume(
519 &mut self,
520 last_event_id: EventId,
521 ) -> Result<StreamableHttpMessageReceiver, SessionError> {
522 self.shadow_txs.retain(|tx| !tx.is_closed());
524
525 match last_event_id.http_request_id {
526 Some(http_request_id) => {
527 if let Some(request_wise) = self.tx_router.get_mut(&http_request_id) {
528 let channel = tokio::sync::mpsc::channel(self.session_config.channel_capacity);
530 let (tx, rx) = channel;
531 request_wise.tx.tx = tx;
532 let index = last_event_id.index;
533 request_wise.tx.sync(index).await?;
535 Ok(StreamableHttpMessageReceiver {
536 http_request_id: Some(http_request_id),
537 inner: rx,
538 })
539 } else {
540 tracing::debug!(
544 http_request_id,
545 "Request-wise channel completed, falling back to common channel"
546 );
547 self.resume_or_shadow_common(last_event_id.index).await
548 }
549 }
550 None => self.resume_or_shadow_common(last_event_id.index).await,
551 }
552 }
553
554 async fn resume_or_shadow_common(
567 &mut self,
568 last_event_index: usize,
569 ) -> Result<StreamableHttpMessageReceiver, SessionError> {
570 let is_replacing_dead_primary = self.common.tx.is_closed();
571 let capacity = if is_replacing_dead_primary {
572 self.session_config.channel_capacity
573 } else {
574 1 };
576 let (tx, rx) = tokio::sync::mpsc::channel(capacity);
577 if is_replacing_dead_primary {
578 tracing::debug!("Replacing dead common channel with new primary");
580 self.common.tx = tx;
581 self.common.sync(last_event_index).await?;
584 } else {
585 const MAX_SHADOW_STREAMS: usize = 32;
590
591 if self.shadow_txs.len() >= MAX_SHADOW_STREAMS {
592 tracing::warn!(
593 shadow_count = self.shadow_txs.len(),
594 "Shadow stream limit reached, dropping oldest"
595 );
596 self.shadow_txs.remove(0);
597 }
598 tracing::debug!(
599 shadow_count = self.shadow_txs.len(),
600 "Common channel active, creating shadow stream"
601 );
602 self.shadow_txs.push(tx);
603 }
604 Ok(StreamableHttpMessageReceiver {
605 http_request_id: None,
606 inner: rx,
607 })
608 }
609
610 async fn close_sse_stream(
611 &mut self,
612 http_request_id: Option<HttpRequestId>,
613 retry_interval: Option<Duration>,
614 ) -> Result<(), SessionError> {
615 match http_request_id {
616 Some(id) => {
618 let request_wise = self
619 .tx_router
620 .get_mut(&id)
621 .ok_or(SessionError::ChannelClosed(Some(id)))?;
622
623 if let Some(interval) = retry_interval {
625 request_wise.tx.send_priming(interval).await;
626 }
627
628 let (tx, _rx) = tokio::sync::mpsc::channel(1);
630 request_wise.tx.tx = tx;
631
632 tracing::debug!(
633 http_request_id = id,
634 "closed SSE stream for server-initiated disconnection"
635 );
636 Ok(())
637 }
638 None => {
640 if let Some(interval) = retry_interval {
642 self.common.send_priming(interval).await;
643 }
644
645 let (tx, _rx) = tokio::sync::mpsc::channel(1);
647 self.common.tx = tx;
648
649 self.shadow_txs.clear();
651
652 tracing::debug!("closed standalone SSE stream for server-initiated disconnection");
653 Ok(())
654 }
655 }
656 }
657}
658
/// Messages sent from a [`LocalSessionHandle`] to its session worker.
#[derive(Debug)]
#[non_exhaustive]
pub enum SessionEvent {
    /// A client JSON-RPC message for the handler. If it is a request and
    /// `http_request_id` is set, replies are routed to that request-wise channel.
    ClientMessage {
        message: ClientJsonRpcMessage,
        http_request_id: Option<HttpRequestId>,
    },
    /// Open a new request-wise channel and reply with its receiver.
    EstablishRequestWiseChannel {
        responder: oneshot::Sender<Result<StreamableHttpMessageReceiver, SessionError>>,
    },
    /// Forget request-wise channel `id`.
    CloseRequestWiseChannel {
        id: HttpRequestId,
        responder: oneshot::Sender<Result<(), SessionError>>,
    },
    /// Reattach to a channel, replaying cached events from `last_event_id` where possible.
    Resume {
        last_event_id: EventId,
        responder: oneshot::Sender<Result<StreamableHttpMessageReceiver, SessionError>>,
    },
    /// The client's initialize request. The worker requires this as its first event and
    /// replies with the handler's response.
    InitializeRequest {
        request: ClientJsonRpcMessage,
        responder: oneshot::Sender<Result<ServerJsonRpcMessage, SessionError>>,
    },
    /// Stop the worker.
    Close,
    /// Server-initiated close of an SSE stream. This is the request-wise channel when
    /// `http_request_id` is `Some`, otherwise the standalone stream.
    CloseSseStream {
        http_request_id: Option<HttpRequestId>,
        /// When set, a priming event carrying this retry interval is sent before closing.
        retry_interval: Option<Duration>,
        responder: oneshot::Sender<Result<(), SessionError>>,
    },
}
690
/// Reasons a session may end.
///
/// NOTE(review): nothing in this file constructs or matches these variants; confirm
/// their meaning against callers before relying on them.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum SessionQuitReason {
    ServiceTerminated,
    ClientTerminated,
    ExpectInitializeRequest,
    ExpectInitializeResponse,
    Cancelled,
}
700
/// Cloneable handle for talking to a session worker.
///
/// Each method sends a [`SessionEvent`] over `event_tx`; the request-style methods then
/// wait for the worker's reply on a oneshot channel.
#[derive(Debug, Clone)]
pub struct LocalSessionHandle {
    /// The session's id.
    id: SessionId,
    /// Sending half of the worker's event channel.
    event_tx: Sender<SessionEvent>,
}
707
708impl LocalSessionHandle {
709 pub fn id(&self) -> &SessionId {
711 &self.id
712 }
713
714 pub async fn close(&self) -> Result<(), SessionError> {
716 self.event_tx
717 .send(SessionEvent::Close)
718 .await
719 .map_err(|_| SessionError::SessionServiceTerminated)?;
720 Ok(())
721 }
722
723 pub async fn push_message(
725 &self,
726 message: ClientJsonRpcMessage,
727 http_request_id: Option<HttpRequestId>,
728 ) -> Result<(), SessionError> {
729 self.event_tx
730 .send(SessionEvent::ClientMessage {
731 message,
732 http_request_id,
733 })
734 .await
735 .map_err(|_| SessionError::SessionServiceTerminated)?;
736 Ok(())
737 }
738
739 pub async fn establish_request_wise_channel(
743 &self,
744 ) -> Result<StreamableHttpMessageReceiver, SessionError> {
745 let (tx, rx) = tokio::sync::oneshot::channel();
746 self.event_tx
747 .send(SessionEvent::EstablishRequestWiseChannel { responder: tx })
748 .await
749 .map_err(|_| SessionError::SessionServiceTerminated)?;
750 rx.await
751 .map_err(|_| SessionError::SessionServiceTerminated)?
752 }
753
754 pub async fn close_request_wise_channel(
756 &self,
757 request_id: HttpRequestId,
758 ) -> Result<(), SessionError> {
759 let (tx, rx) = tokio::sync::oneshot::channel();
760 self.event_tx
761 .send(SessionEvent::CloseRequestWiseChannel {
762 id: request_id,
763 responder: tx,
764 })
765 .await
766 .map_err(|_| SessionError::SessionServiceTerminated)?;
767 rx.await
768 .map_err(|_| SessionError::SessionServiceTerminated)?
769 }
770
771 pub async fn establish_common_channel(
773 &self,
774 ) -> Result<StreamableHttpMessageReceiver, SessionError> {
775 let (tx, rx) = tokio::sync::oneshot::channel();
776 self.event_tx
777 .send(SessionEvent::Resume {
778 last_event_id: EventId {
779 http_request_id: None,
780 index: 0,
781 },
782 responder: tx,
783 })
784 .await
785 .map_err(|_| SessionError::SessionServiceTerminated)?;
786 rx.await
787 .map_err(|_| SessionError::SessionServiceTerminated)?
788 }
789
790 pub async fn resume(
792 &self,
793 last_event_id: EventId,
794 ) -> Result<StreamableHttpMessageReceiver, SessionError> {
795 let (tx, rx) = tokio::sync::oneshot::channel();
796 self.event_tx
797 .send(SessionEvent::Resume {
798 last_event_id,
799 responder: tx,
800 })
801 .await
802 .map_err(|_| SessionError::SessionServiceTerminated)?;
803 rx.await
804 .map_err(|_| SessionError::SessionServiceTerminated)?
805 }
806
807 pub async fn initialize(
811 &self,
812 request: ClientJsonRpcMessage,
813 ) -> Result<ServerJsonRpcMessage, SessionError> {
814 let (tx, rx) = tokio::sync::oneshot::channel();
815 self.event_tx
816 .send(SessionEvent::InitializeRequest {
817 request,
818 responder: tx,
819 })
820 .await
821 .map_err(|_| SessionError::SessionServiceTerminated)?;
822 rx.await
823 .map_err(|_| SessionError::SessionServiceTerminated)?
824 }
825
826 pub async fn close_sse_stream(
837 &self,
838 http_request_id: HttpRequestId,
839 retry_interval: Option<Duration>,
840 ) -> Result<(), SessionError> {
841 let (tx, rx) = tokio::sync::oneshot::channel();
842 self.event_tx
843 .send(SessionEvent::CloseSseStream {
844 http_request_id: Some(http_request_id),
845 retry_interval,
846 responder: tx,
847 })
848 .await
849 .map_err(|_| SessionError::SessionServiceTerminated)?;
850 rx.await
851 .map_err(|_| SessionError::SessionServiceTerminated)?
852 }
853
854 pub async fn close_standalone_sse_stream(
864 &self,
865 retry_interval: Option<Duration>,
866 ) -> Result<(), SessionError> {
867 let (tx, rx) = tokio::sync::oneshot::channel();
868 self.event_tx
869 .send(SessionEvent::CloseSseStream {
870 http_request_id: None,
871 retry_interval,
872 responder: tx,
873 })
874 .await
875 .map_err(|_| SessionError::SessionServiceTerminated)?;
876 rx.await
877 .map_err(|_| SessionError::SessionServiceTerminated)?
878 }
879}
880
/// Transport type that drives a [`LocalSessionWorker`].
pub type SessionTransport = WorkerTransport<LocalSessionWorker>;

// `UnexpectedEvent` carries a whole `SessionEvent`, making it much larger than the other
// variants; presumably tolerated because this error is built rarely.
/// Reasons a [`LocalSessionWorker`] stops.
#[allow(clippy::large_enum_variant)]
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum LocalSessionWorkerError {
    /// Every handle dropped its end of the event channel.
    #[error("transport terminated")]
    TransportTerminated,
    /// The first event was not an initialize request.
    #[error("unexpected message: {0:?}")]
    UnexpectedEvent(SessionEvent),
    /// The initialize response could not be delivered back to the caller.
    #[error("fail to send initialize request {0}")]
    FailToSendInitializeRequest(SessionError),
    /// Routing a message from the handler failed.
    #[error("fail to handle message: {0}")]
    FailToHandleMessage(SessionError),
    /// No event arrived within the configured keep-alive interval.
    #[error("keep alive timeout after {}ms", _0.as_millis())]
    KeepAliveTimeout(Duration),
    /// Returned by `Worker::err_closed`.
    #[error("Transport closed")]
    TransportClosed,
    /// Returned by `Worker::err_join`.
    #[error("Tokio join error {0}")]
    TokioJoinError(#[from] tokio::task::JoinError),
}
902impl Worker for LocalSessionWorker {
903 type Error = LocalSessionWorkerError;
904 type Role = RoleServer;
905 fn err_closed() -> Self::Error {
906 LocalSessionWorkerError::TransportClosed
907 }
908 fn err_join(e: tokio::task::JoinError) -> Self::Error {
909 LocalSessionWorkerError::TokioJoinError(e)
910 }
911 fn config(&self) -> crate::transport::worker::WorkerConfig {
912 crate::transport::worker::WorkerConfig {
913 name: Some(format!("streamable-http-session-{}", self.id)),
914 channel_buffer_capacity: self.session_config.channel_capacity,
915 }
916 }
917 #[instrument(name = "streamable_http_session", skip_all, fields(id = self.id.as_ref()))]
918 async fn run(
919 mut self,
920 mut context: WorkerContext<Self>,
921 ) -> Result<(), WorkerQuitReason<Self::Error>> {
922 enum InnerEvent {
923 FromHttpService(SessionEvent),
924 FromHandler(WorkerSendRequest<LocalSessionWorker>),
925 }
926 let evt = self.event_rx.recv().await.ok_or_else(|| {
928 WorkerQuitReason::fatal(
929 LocalSessionWorkerError::TransportTerminated,
930 "get initialize request",
931 )
932 })?;
933 let SessionEvent::InitializeRequest { request, responder } = evt else {
934 return Err(WorkerQuitReason::fatal(
935 LocalSessionWorkerError::UnexpectedEvent(evt),
936 "get initialize request",
937 ));
938 };
939 context.send_to_handler(request).await?;
940 let send_initialize_response = context.recv_from_handler().await?;
941 responder
942 .send(Ok(send_initialize_response.message))
943 .map_err(|_| {
944 WorkerQuitReason::fatal(
945 LocalSessionWorkerError::FailToSendInitializeRequest(
946 SessionError::SessionServiceTerminated,
947 ),
948 "send initialize response",
949 )
950 })?;
951 send_initialize_response
952 .responder
953 .send(Ok(()))
954 .map_err(|_| WorkerQuitReason::HandlerTerminated)?;
955 let ct = context.cancellation_token.clone();
956 let keep_alive = self.session_config.keep_alive.unwrap_or(Duration::MAX);
957 loop {
958 let keep_alive_timeout = tokio::time::sleep(keep_alive);
959 let event = tokio::select! {
960 event = self.event_rx.recv() => {
961 if let Some(event) = event {
962 InnerEvent::FromHttpService(event)
963 } else {
964 return Err(WorkerQuitReason::fatal(LocalSessionWorkerError::TransportTerminated, "waiting next session event"))
965 }
966 },
967 from_handler = context.recv_from_handler() => {
968 InnerEvent::FromHandler(from_handler?)
969 }
970 _ = ct.cancelled() => {
971 return Err(WorkerQuitReason::Cancelled)
972 }
973 _ = keep_alive_timeout => {
974 return Err(WorkerQuitReason::fatal(LocalSessionWorkerError::KeepAliveTimeout(keep_alive), "poll next session event"))
975 }
976 };
977 match event {
978 InnerEvent::FromHandler(WorkerSendRequest { message, responder }) => {
979 let to_unregister = match &message {
981 crate::model::JsonRpcMessage::Response(json_rpc_response) => {
982 let request_id = json_rpc_response.id.clone();
983 Some(ResourceKey::McpRequestId(request_id))
984 }
985 crate::model::JsonRpcMessage::Error(json_rpc_error) => {
986 let request_id = json_rpc_error.id.clone();
987 Some(ResourceKey::McpRequestId(request_id))
988 }
989 _ => {
990 None
991 }
993 };
994 let handle_result = self
995 .handle_server_message(message)
996 .await
997 .map_err(LocalSessionWorkerError::FailToHandleMessage);
998 let _ = responder.send(handle_result).inspect_err(|error| {
999 tracing::warn!(?error, "failed to send message to http service handler");
1000 });
1001 if let Some(to_unregister) = to_unregister {
1002 self.unregister_resource(&to_unregister);
1003 }
1004 }
1005 InnerEvent::FromHttpService(SessionEvent::ClientMessage {
1006 message: json_rpc_message,
1007 http_request_id,
1008 }) => {
1009 match &json_rpc_message {
1010 crate::model::JsonRpcMessage::Request(request) => {
1011 if let Some(http_request_id) = http_request_id {
1012 self.register_request(request, http_request_id)
1013 }
1014 }
1015 crate::model::JsonRpcMessage::Notification(notification) => {
1016 self.catch_cancellation_notification(notification)
1017 }
1018 _ => {}
1019 }
1020 context.send_to_handler(json_rpc_message).await?;
1021 }
1022 InnerEvent::FromHttpService(SessionEvent::EstablishRequestWiseChannel {
1023 responder,
1024 }) => {
1025 let handle_result = self.establish_request_wise_channel().await;
1026 let _ = responder.send(handle_result);
1027 }
1028 InnerEvent::FromHttpService(SessionEvent::CloseRequestWiseChannel {
1029 id,
1030 responder,
1031 }) => {
1032 let _handle_result = self.tx_router.remove(&id);
1033 let _ = responder.send(Ok(()));
1034 }
1035 InnerEvent::FromHttpService(SessionEvent::Resume {
1036 last_event_id,
1037 responder,
1038 }) => {
1039 let handle_result = self.resume(last_event_id).await;
1040 let _ = responder.send(handle_result);
1041 }
1042 InnerEvent::FromHttpService(SessionEvent::Close) => {
1043 return Err(WorkerQuitReason::TransportClosed);
1044 }
1045 InnerEvent::FromHttpService(SessionEvent::CloseSseStream {
1046 http_request_id,
1047 retry_interval,
1048 responder,
1049 }) => {
1050 let handle_result =
1051 self.close_sse_stream(http_request_id, retry_interval).await;
1052 let _ = responder.send(handle_result);
1053 }
1054 _ => {
1055 }
1057 }
1058 }
1059 }
1060}
1061
/// Per-session settings used by [`create_local_session`].
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct SessionConfig {
    /// Buffer size of the session's event channel and of each SSE channel. It is also the
    /// number of events each channel keeps for replay. Must be non-zero, because
    /// `tokio::sync::mpsc::channel` panics on 0.
    pub channel_capacity: usize,
    /// How long the worker waits for the next event before ending the session.
    /// `None` means no timeout (the worker waits `Duration::MAX`).
    pub keep_alive: Option<Duration>,
}

impl SessionConfig {
    /// Default for [`SessionConfig::channel_capacity`].
    pub const DEFAULT_CHANNEL_CAPACITY: usize = 16;
    /// Default for [`SessionConfig::keep_alive`] (300 seconds).
    pub const DEFAULT_KEEP_ALIVE: Duration = Duration::from_secs(300);
}

impl Default for SessionConfig {
    /// Uses [`SessionConfig::DEFAULT_CHANNEL_CAPACITY`] and [`SessionConfig::DEFAULT_KEEP_ALIVE`].
    fn default() -> Self {
        Self {
            channel_capacity: Self::DEFAULT_CHANNEL_CAPACITY,
            keep_alive: Some(Self::DEFAULT_KEEP_ALIVE),
        }
    }
}
1094
1095pub fn create_local_session(
1101 id: impl Into<SessionId>,
1102 config: SessionConfig,
1103) -> (LocalSessionHandle, LocalSessionWorker) {
1104 let id = id.into();
1105 let (event_tx, event_rx) = tokio::sync::mpsc::channel(config.channel_capacity);
1106 let (common_tx, _) = tokio::sync::mpsc::channel(config.channel_capacity);
1107 let common = CachedTx::new_common(common_tx);
1108 tracing::info!(session_id = ?id, "create new session");
1109 let handle = LocalSessionHandle {
1110 event_tx,
1111 id: id.clone(),
1112 };
1113 let session_worker = LocalSessionWorker {
1114 next_http_request_id: 0,
1115 id,
1116 tx_router: HashMap::new(),
1117 resource_router: HashMap::new(),
1118 common,
1119 shadow_txs: Vec::new(),
1120 event_rx,
1121 session_config: config.clone(),
1122 };
1123 (handle, session_worker)
1124}