1use std::{
2 collections::{HashMap, HashSet, VecDeque},
3 num::ParseIntError,
4 time::{Duration, Instant},
5};
6
7use futures::{Stream, StreamExt};
8use thiserror::Error;
9use tokio::sync::{
10 mpsc::{Receiver, Sender},
11 oneshot,
12};
13use tokio_stream::wrappers::ReceiverStream;
14use tracing::instrument;
15
16use crate::{
17 RoleServer,
18 model::{
19 CancelledNotificationParam, ClientJsonRpcMessage, ClientNotification, ClientRequest,
20 JsonRpcNotification, JsonRpcRequest, Notification, ProgressNotificationParam,
21 ProgressToken, RequestId, ServerJsonRpcMessage, ServerNotification,
22 },
23 transport::{
24 WorkerTransport,
25 common::server_side_http::{SessionId, session_id},
26 worker::{Worker, WorkerContext, WorkerQuitReason, WorkerSendRequest},
27 },
28};
29
30#[derive(Debug, Default)]
31#[non_exhaustive]
32pub struct LocalSessionManager {
33 pub sessions: tokio::sync::RwLock<HashMap<SessionId, LocalSessionHandle>>,
34 pub session_config: SessionConfig,
35}
36
37#[derive(Debug, Error)]
38#[non_exhaustive]
39pub enum LocalSessionManagerError {
40 #[error("Session not found: {0}")]
41 SessionNotFound(SessionId),
42 #[error("Session error: {0}")]
43 SessionError(#[from] SessionError),
44 #[error("Invalid event id: {0}")]
45 InvalidEventId(#[from] EventIdParseError),
46}
47impl SessionManager for LocalSessionManager {
48 type Error = LocalSessionManagerError;
49 type Transport = WorkerTransport<LocalSessionWorker>;
50 async fn create_session(&self) -> Result<(SessionId, Self::Transport), Self::Error> {
51 let id = session_id();
52 let (handle, worker) = create_local_session(id.clone(), self.session_config.clone());
53 self.sessions.write().await.insert(id.clone(), handle);
54 Ok((id, WorkerTransport::spawn(worker)))
55 }
56 async fn initialize_session(
57 &self,
58 id: &SessionId,
59 message: ClientJsonRpcMessage,
60 ) -> Result<ServerJsonRpcMessage, Self::Error> {
61 let sessions = self.sessions.read().await;
62 let handle = sessions
63 .get(id)
64 .ok_or(LocalSessionManagerError::SessionNotFound(id.clone()))?;
65 let response = handle.initialize(message).await?;
66 Ok(response)
67 }
68 async fn close_session(&self, id: &SessionId) -> Result<(), Self::Error> {
69 let mut sessions = self.sessions.write().await;
70 if let Some(handle) = sessions.remove(id) {
71 handle.close().await?;
72 }
73 Ok(())
74 }
75 async fn has_session(&self, id: &SessionId) -> Result<bool, Self::Error> {
76 let sessions = self.sessions.read().await;
77 Ok(sessions.contains_key(id))
78 }
79 async fn create_stream(
80 &self,
81 id: &SessionId,
82 message: ClientJsonRpcMessage,
83 ) -> Result<impl Stream<Item = ServerSseMessage> + Send + 'static, Self::Error> {
84 let sessions = self.sessions.read().await;
85 let handle = sessions
86 .get(id)
87 .ok_or(LocalSessionManagerError::SessionNotFound(id.clone()))?;
88 let receiver = handle.establish_request_wise_channel().await?;
89 let http_request_id = receiver.http_request_id;
90 handle.push_message(message, http_request_id).await?;
91
92 let priming = self.session_config.sse_retry.map(|retry| {
93 let event_id = match http_request_id {
94 Some(id) => format!("0/{id}"),
95 None => "0".into(),
96 };
97 ServerSseMessage::priming(event_id, retry)
98 });
99 Ok(futures::stream::iter(priming).chain(ReceiverStream::new(receiver.inner)))
100 }
101
102 async fn create_standalone_stream(
103 &self,
104 id: &SessionId,
105 ) -> Result<impl Stream<Item = ServerSseMessage> + Send + 'static, Self::Error> {
106 let sessions = self.sessions.read().await;
107 let handle = sessions
108 .get(id)
109 .ok_or(LocalSessionManagerError::SessionNotFound(id.clone()))?;
110 let receiver = handle.establish_common_channel().await?;
111 Ok(ReceiverStream::new(receiver.inner))
112 }
113
114 async fn resume(
115 &self,
116 id: &SessionId,
117 last_event_id: String,
118 ) -> Result<impl Stream<Item = ServerSseMessage> + Send + 'static, Self::Error> {
119 let sessions = self.sessions.read().await;
120 let handle = sessions
121 .get(id)
122 .ok_or(LocalSessionManagerError::SessionNotFound(id.clone()))?;
123 let receiver = handle.resume(last_event_id.parse()?).await?;
124 Ok(ReceiverStream::new(receiver.inner))
125 }
126
127 async fn accept_message(
128 &self,
129 id: &SessionId,
130 message: ClientJsonRpcMessage,
131 ) -> Result<(), Self::Error> {
132 let sessions = self.sessions.read().await;
133 let handle = sessions
134 .get(id)
135 .ok_or(LocalSessionManagerError::SessionNotFound(id.clone()))?;
136 handle.push_message(message, None).await?;
137 Ok(())
138 }
139
140 async fn restore_session(
141 &self,
142 id: SessionId,
143 ) -> Result<RestoreOutcome<Self::Transport>, Self::Error> {
144 let mut sessions = self.sessions.write().await;
145 if sessions.contains_key(&id) {
146 return Ok(RestoreOutcome::AlreadyPresent);
148 }
149 let (handle, worker) = create_local_session(id.clone(), self.session_config.clone());
150 sessions.insert(id, handle);
151 Ok(RestoreOutcome::Restored(WorkerTransport::spawn(worker)))
152 }
153}
154
155#[derive(Debug, Clone, PartialEq, Eq, Hash)]
157pub struct EventId {
158 http_request_id: Option<HttpRequestId>,
159 index: usize,
160}
161
162impl std::fmt::Display for EventId {
163 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
164 write!(f, "{}", self.index)?;
165 match &self.http_request_id {
166 Some(http_request_id) => write!(f, "/{http_request_id}"),
167 None => write!(f, ""),
168 }
169 }
170}
171
172#[derive(Debug, Clone, Error)]
173#[non_exhaustive]
174pub enum EventIdParseError {
175 #[error("Invalid index: {0}")]
176 InvalidIndex(ParseIntError),
177 #[error("Invalid numeric request id: {0}")]
178 InvalidNumericRequestId(ParseIntError),
179 #[error("Missing request id type")]
180 InvalidRequestIdType,
181 #[error("Missing request id")]
182 MissingRequestId,
183}
184
185impl std::str::FromStr for EventId {
186 type Err = EventIdParseError;
187 fn from_str(s: &str) -> Result<Self, Self::Err> {
188 if let Some((index, request_id)) = s.split_once("/") {
189 let index = usize::from_str(index).map_err(EventIdParseError::InvalidIndex)?;
190 let request_id = u64::from_str(request_id).map_err(EventIdParseError::InvalidIndex)?;
191 Ok(EventId {
192 http_request_id: Some(request_id),
193 index,
194 })
195 } else {
196 let index = usize::from_str(s).map_err(EventIdParseError::InvalidIndex)?;
197 Ok(EventId {
198 http_request_id: None,
199 index,
200 })
201 }
202 }
203}
204
205use super::{RestoreOutcome, ServerSseMessage, SessionManager};
206
207struct CachedTx {
208 tx: Sender<ServerSseMessage>,
209 cache: VecDeque<ServerSseMessage>,
210 http_request_id: Option<HttpRequestId>,
211 capacity: usize,
212 starting_index: usize,
213}
214
215impl CachedTx {
216 fn new(
217 tx: Sender<ServerSseMessage>,
218 http_request_id: Option<HttpRequestId>,
219 starting_index: usize,
220 ) -> Self {
221 Self {
222 cache: VecDeque::with_capacity(tx.capacity()),
223 capacity: tx.capacity(),
224 tx,
225 http_request_id,
226 starting_index,
227 }
228 }
229 fn new_common(tx: Sender<ServerSseMessage>) -> Self {
230 Self::new(tx, None, 0)
231 }
232
233 fn next_event_id(&self) -> EventId {
234 let index = self.cache.back().map_or(self.starting_index, |m| {
235 m.event_id
236 .as_deref()
237 .unwrap_or_default()
238 .parse::<EventId>()
239 .expect("valid event id")
240 .index
241 + 1
242 });
243 EventId {
244 http_request_id: self.http_request_id,
245 index,
246 }
247 }
248
249 async fn send(&mut self, message: ServerJsonRpcMessage) {
250 let event_id = self.next_event_id();
251 let message = ServerSseMessage::new(event_id.to_string(), message);
252 self.cache_and_send(message).await;
253 }
254
255 async fn send_priming(&mut self, retry: Duration) {
256 let event_id = self.next_event_id();
257 let message = ServerSseMessage::priming(event_id.to_string(), retry);
258 self.cache_and_send(message).await;
259 }
260
261 async fn cache_and_send(&mut self, message: ServerSseMessage) {
262 if self.cache.len() >= self.capacity {
263 self.cache.pop_front();
264 self.cache.push_back(message.clone());
265 } else {
266 self.cache.push_back(message.clone());
267 }
268 let _ = self.tx.send(message).await.inspect_err(|e| {
269 let event_id = &e.0.event_id;
270 tracing::trace!(?event_id, "trying to send message in a closed session")
271 });
272 }
273
274 async fn sync(&mut self, index: usize) -> Result<(), SessionError> {
275 let Some(front) = self.cache.front() else {
276 return Ok(());
277 };
278 let front_event_id = front
279 .event_id
280 .as_deref()
281 .unwrap_or_default()
282 .parse::<EventId>()?;
283 let sync_index = index.saturating_sub(front_event_id.index);
284 if sync_index > self.cache.len() {
285 return Err(SessionError::InvalidEventId);
287 }
288 for message in self.cache.iter().skip(sync_index) {
289 let send_result = self.tx.send(message.clone()).await;
290 if send_result.is_err() {
291 let event_id: EventId = message.event_id.as_deref().unwrap_or_default().parse()?;
292 return Err(SessionError::ChannelClosed(Some(event_id.index as u64)));
293 }
294 }
295 Ok(())
296 }
297}
298
299struct HttpRequestWise {
300 resources: HashSet<ResourceKey>,
301 tx: CachedTx,
302 completed_at: Option<Instant>,
303}
304
305type HttpRequestId = u64;
306#[derive(Debug, Clone, Hash, PartialEq, Eq)]
307enum ResourceKey {
308 McpRequestId(RequestId),
309 ProgressToken(ProgressToken),
310}
311
312pub struct LocalSessionWorker {
313 id: SessionId,
314 next_http_request_id: HttpRequestId,
315 tx_router: HashMap<HttpRequestId, HttpRequestWise>,
316 resource_router: HashMap<ResourceKey, HttpRequestId>,
317 common: CachedTx,
318 shadow_txs: Vec<Sender<ServerSseMessage>>,
324 event_rx: Receiver<SessionEvent>,
325 session_config: SessionConfig,
326}
327
328impl LocalSessionWorker {
329 pub fn id(&self) -> &SessionId {
330 &self.id
331 }
332}
333
334#[derive(Debug, Error)]
335#[non_exhaustive]
336pub enum SessionError {
337 #[error("Invalid request id: {0}")]
338 DuplicatedRequestId(HttpRequestId),
339 #[error("Channel closed: {0:?}")]
340 ChannelClosed(Option<HttpRequestId>),
341 #[error("Cannot parse event id: {0}")]
342 EventIdParseError(#[from] EventIdParseError),
343 #[error("Session service terminated")]
344 SessionServiceTerminated,
345 #[error("Invalid event id")]
346 InvalidEventId,
347 #[error("IO error: {0}")]
348 Io(#[from] std::io::Error),
349}
350
351impl From<SessionError> for std::io::Error {
352 fn from(value: SessionError) -> Self {
353 match value {
354 SessionError::Io(io) => io,
355 _ => std::io::Error::other(format!("Session error: {value}")),
356 }
357 }
358}
359
360enum OutboundChannel {
361 RequestWise { id: HttpRequestId, close: bool },
362 Common,
363}
364#[derive(Debug)]
365#[non_exhaustive]
366pub struct StreamableHttpMessageReceiver {
367 pub http_request_id: Option<HttpRequestId>,
368 pub inner: Receiver<ServerSseMessage>,
369}
370
371impl LocalSessionWorker {
372 fn unregister_resource(&mut self, resource: &ResourceKey) {
373 let Some(http_request_id) = self.resource_router.remove(resource) else {
374 return;
375 };
376 tracing::trace!(?resource, http_request_id, "unregister resource");
377 let Some(channel) = self.tx_router.get_mut(&http_request_id) else {
378 tracing::warn!(http_request_id, "http request wise channel not found");
379 return;
380 };
381 if !channel.resources.is_empty() && !matches!(resource, ResourceKey::McpRequestId(_)) {
382 return;
383 }
384 tracing::debug!(http_request_id, "close http request wise channel");
385 let resources: Vec<_> = channel.resources.drain().collect();
386 channel.completed_at = Some(Instant::now());
387 let (closed_tx, _) = tokio::sync::mpsc::channel(1);
391 channel.tx.tx = closed_tx;
392 for resource in resources {
393 self.resource_router.remove(&resource);
394 }
395 }
396 fn register_resource(&mut self, resource: ResourceKey, http_request_id: HttpRequestId) {
397 tracing::trace!(?resource, http_request_id, "register resource");
398 if let Some(channel) = self.tx_router.get_mut(&http_request_id) {
399 channel.resources.insert(resource.clone());
400 self.resource_router.insert(resource, http_request_id);
401 }
402 }
403 fn register_request(
404 &mut self,
405 request: &JsonRpcRequest<ClientRequest>,
406 http_request_id: HttpRequestId,
407 ) {
408 use crate::model::GetMeta;
409 self.register_resource(
410 ResourceKey::McpRequestId(request.id.clone()),
411 http_request_id,
412 );
413 if let Some(progress_token) = request.request.get_meta().get_progress_token() {
414 self.register_resource(
415 ResourceKey::ProgressToken(progress_token.clone()),
416 http_request_id,
417 );
418 }
419 }
420 fn catch_cancellation_notification(
421 &mut self,
422 notification: &JsonRpcNotification<ClientNotification>,
423 ) {
424 if let ClientNotification::CancelledNotification(n) = ¬ification.notification {
425 let request_id = n.params.request_id.clone();
426 let resource = ResourceKey::McpRequestId(request_id);
427 self.unregister_resource(&resource);
428 }
429 }
430 fn evict_expired_channels(&mut self) {
431 let ttl = self.session_config.completed_cache_ttl;
432 self.tx_router
433 .retain(|_, rw| rw.completed_at.is_none_or(|at| at.elapsed() < ttl));
434 }
435 fn next_http_request_id(&mut self) -> HttpRequestId {
436 let id = self.next_http_request_id;
437 self.next_http_request_id = self.next_http_request_id.wrapping_add(1);
438 id
439 }
440 async fn establish_request_wise_channel(
441 &mut self,
442 ) -> Result<StreamableHttpMessageReceiver, SessionError> {
443 let http_request_id = self.next_http_request_id();
444 let (tx, rx) = tokio::sync::mpsc::channel(self.session_config.channel_capacity);
445 let starting_index = usize::from(self.session_config.sse_retry.is_some());
446 self.tx_router.insert(
447 http_request_id,
448 HttpRequestWise {
449 resources: Default::default(),
450 tx: CachedTx::new(tx, Some(http_request_id), starting_index),
451 completed_at: None,
452 },
453 );
454 tracing::debug!(http_request_id, "establish new request wise channel");
455 Ok(StreamableHttpMessageReceiver {
456 http_request_id: Some(http_request_id),
457 inner: rx,
458 })
459 }
460 fn resolve_outbound_channel(&self, message: &ServerJsonRpcMessage) -> OutboundChannel {
461 match &message {
462 ServerJsonRpcMessage::Request(_) => OutboundChannel::Common,
463 ServerJsonRpcMessage::Notification(JsonRpcNotification {
464 notification:
465 ServerNotification::ProgressNotification(Notification {
466 params: ProgressNotificationParam { progress_token, .. },
467 ..
468 }),
469 ..
470 }) => {
471 let id = self
472 .resource_router
473 .get(&ResourceKey::ProgressToken(progress_token.clone()));
474
475 if let Some(id) = id {
476 OutboundChannel::RequestWise {
477 id: *id,
478 close: false,
479 }
480 } else {
481 OutboundChannel::Common
482 }
483 }
484 ServerJsonRpcMessage::Notification(JsonRpcNotification {
485 notification:
486 ServerNotification::CancelledNotification(Notification {
487 params: CancelledNotificationParam { request_id, .. },
488 ..
489 }),
490 ..
491 }) => {
492 if let Some(id) = self
493 .resource_router
494 .get(&ResourceKey::McpRequestId(request_id.clone()))
495 {
496 OutboundChannel::RequestWise {
497 id: *id,
498 close: false,
499 }
500 } else {
501 OutboundChannel::Common
502 }
503 }
504 ServerJsonRpcMessage::Notification(_) => OutboundChannel::Common,
505 ServerJsonRpcMessage::Response(json_rpc_response) => {
506 if let Some(id) = self
507 .resource_router
508 .get(&ResourceKey::McpRequestId(json_rpc_response.id.clone()))
509 {
510 OutboundChannel::RequestWise {
511 id: *id,
512 close: true,
513 }
514 } else {
515 OutboundChannel::Common
516 }
517 }
518 ServerJsonRpcMessage::Error(json_rpc_error) => {
519 if let Some(id) = self
520 .resource_router
521 .get(&ResourceKey::McpRequestId(json_rpc_error.id.clone()))
522 {
523 OutboundChannel::RequestWise {
524 id: *id,
525 close: true,
526 }
527 } else {
528 OutboundChannel::Common
529 }
530 }
531 }
532 }
533 async fn handle_server_message(
534 &mut self,
535 message: ServerJsonRpcMessage,
536 ) -> Result<(), SessionError> {
537 let outbound_channel = self.resolve_outbound_channel(&message);
538 match outbound_channel {
539 OutboundChannel::RequestWise { id, close } => {
540 if let Some(request_wise) = self.tx_router.get_mut(&id) {
541 request_wise.tx.send(message).await;
542 if close {
543 if let Some(channel) = self.tx_router.remove(&id) {
544 for resource in channel.resources {
545 self.resource_router.remove(&resource);
546 }
547 }
548 }
549 } else {
550 return Err(SessionError::ChannelClosed(Some(id)));
551 }
552 }
553 OutboundChannel::Common => self.common.send(message).await,
554 }
555 Ok(())
556 }
557 async fn resume(
558 &mut self,
559 last_event_id: EventId,
560 ) -> Result<StreamableHttpMessageReceiver, SessionError> {
561 self.shadow_txs.retain(|tx| !tx.is_closed());
563
564 match last_event_id.http_request_id {
565 Some(http_request_id) => {
566 let request_wise = self
567 .tx_router
568 .get_mut(&http_request_id)
569 .ok_or(SessionError::ChannelClosed(Some(http_request_id)))?;
570 let is_completed = request_wise.completed_at.is_some();
571 let (tx, rx) = tokio::sync::mpsc::channel(self.session_config.channel_capacity);
572 request_wise.tx.tx = tx;
573 let index = last_event_id.index;
574 request_wise.tx.sync(index).await?;
575 if is_completed {
576 let (closed_tx, _) = tokio::sync::mpsc::channel(1);
579 request_wise.tx.tx = closed_tx;
580 }
581 Ok(StreamableHttpMessageReceiver {
582 http_request_id: Some(http_request_id),
583 inner: rx,
584 })
585 }
586 None => self.resume_or_shadow_common(last_event_id.index).await,
587 }
588 }
589
590 async fn resume_or_shadow_common(
603 &mut self,
604 last_event_index: usize,
605 ) -> Result<StreamableHttpMessageReceiver, SessionError> {
606 let is_replacing_dead_primary = self.common.tx.is_closed();
607 let capacity = if is_replacing_dead_primary {
608 self.session_config.channel_capacity
609 } else {
610 1 };
612 let (tx, rx) = tokio::sync::mpsc::channel(capacity);
613 if is_replacing_dead_primary {
614 tracing::debug!("Replacing dead common channel with new primary");
616 self.common.tx = tx;
617 self.common.sync(last_event_index).await?;
620 } else {
621 const MAX_SHADOW_STREAMS: usize = 32;
626
627 if self.shadow_txs.len() >= MAX_SHADOW_STREAMS {
628 tracing::warn!(
629 shadow_count = self.shadow_txs.len(),
630 "Shadow stream limit reached, dropping oldest"
631 );
632 self.shadow_txs.remove(0);
633 }
634 tracing::debug!(
635 shadow_count = self.shadow_txs.len(),
636 "Common channel active, creating shadow stream"
637 );
638 self.shadow_txs.push(tx);
639 }
640 Ok(StreamableHttpMessageReceiver {
641 http_request_id: None,
642 inner: rx,
643 })
644 }
645
646 async fn close_sse_stream(
647 &mut self,
648 http_request_id: Option<HttpRequestId>,
649 retry_interval: Option<Duration>,
650 ) -> Result<(), SessionError> {
651 match http_request_id {
652 Some(id) => {
654 let request_wise = self
655 .tx_router
656 .get_mut(&id)
657 .ok_or(SessionError::ChannelClosed(Some(id)))?;
658
659 if let Some(interval) = retry_interval {
661 request_wise.tx.send_priming(interval).await;
662 }
663
664 let (tx, _rx) = tokio::sync::mpsc::channel(1);
666 request_wise.tx.tx = tx;
667
668 tracing::debug!(
669 http_request_id = id,
670 "closed SSE stream for server-initiated disconnection"
671 );
672 Ok(())
673 }
674 None => {
676 if let Some(interval) = retry_interval {
678 self.common.send_priming(interval).await;
679 }
680
681 let (tx, _rx) = tokio::sync::mpsc::channel(1);
683 self.common.tx = tx;
684
685 self.shadow_txs.clear();
687
688 tracing::debug!("closed standalone SSE stream for server-initiated disconnection");
689 Ok(())
690 }
691 }
692 }
693}
694
695#[derive(Debug)]
696#[non_exhaustive]
697pub enum SessionEvent {
698 ClientMessage {
699 message: ClientJsonRpcMessage,
700 http_request_id: Option<HttpRequestId>,
701 },
702 EstablishRequestWiseChannel {
703 responder: oneshot::Sender<Result<StreamableHttpMessageReceiver, SessionError>>,
704 },
705 CloseRequestWiseChannel {
706 id: HttpRequestId,
707 responder: oneshot::Sender<Result<(), SessionError>>,
708 },
709 Resume {
710 last_event_id: EventId,
711 responder: oneshot::Sender<Result<StreamableHttpMessageReceiver, SessionError>>,
712 },
713 InitializeRequest {
714 request: ClientJsonRpcMessage,
715 responder: oneshot::Sender<Result<ServerJsonRpcMessage, SessionError>>,
716 },
717 Close,
718 CloseSseStream {
719 http_request_id: Option<HttpRequestId>,
721 retry_interval: Option<Duration>,
723 responder: oneshot::Sender<Result<(), SessionError>>,
724 },
725}
726
/// Reasons a session may stop.
///
/// NOTE(review): not referenced in the visible part of this module; the
/// per-variant meanings below are inferred from the names.
#[non_exhaustive]
#[derive(Clone, Debug)]
pub enum SessionQuitReason {
    /// The service side terminated.
    ServiceTerminated,
    /// The client terminated the session.
    ClientTerminated,
    /// An initialize request was expected but something else arrived.
    ExpectInitializeRequest,
    /// An initialize response was expected but something else arrived.
    ExpectInitializeResponse,
    /// The session was cancelled.
    Cancelled,
}
736
737#[derive(Debug, Clone)]
738pub struct LocalSessionHandle {
739 id: SessionId,
740 event_tx: Sender<SessionEvent>,
742}
743
744impl LocalSessionHandle {
745 pub fn id(&self) -> &SessionId {
747 &self.id
748 }
749
750 pub async fn close(&self) -> Result<(), SessionError> {
752 self.event_tx
753 .send(SessionEvent::Close)
754 .await
755 .map_err(|_| SessionError::SessionServiceTerminated)?;
756 Ok(())
757 }
758
759 pub async fn push_message(
761 &self,
762 message: ClientJsonRpcMessage,
763 http_request_id: Option<HttpRequestId>,
764 ) -> Result<(), SessionError> {
765 self.event_tx
766 .send(SessionEvent::ClientMessage {
767 message,
768 http_request_id,
769 })
770 .await
771 .map_err(|_| SessionError::SessionServiceTerminated)?;
772 Ok(())
773 }
774
775 pub async fn establish_request_wise_channel(
779 &self,
780 ) -> Result<StreamableHttpMessageReceiver, SessionError> {
781 let (tx, rx) = tokio::sync::oneshot::channel();
782 self.event_tx
783 .send(SessionEvent::EstablishRequestWiseChannel { responder: tx })
784 .await
785 .map_err(|_| SessionError::SessionServiceTerminated)?;
786 rx.await
787 .map_err(|_| SessionError::SessionServiceTerminated)?
788 }
789
790 pub async fn close_request_wise_channel(
792 &self,
793 request_id: HttpRequestId,
794 ) -> Result<(), SessionError> {
795 let (tx, rx) = tokio::sync::oneshot::channel();
796 self.event_tx
797 .send(SessionEvent::CloseRequestWiseChannel {
798 id: request_id,
799 responder: tx,
800 })
801 .await
802 .map_err(|_| SessionError::SessionServiceTerminated)?;
803 rx.await
804 .map_err(|_| SessionError::SessionServiceTerminated)?
805 }
806
807 pub async fn establish_common_channel(
809 &self,
810 ) -> Result<StreamableHttpMessageReceiver, SessionError> {
811 let (tx, rx) = tokio::sync::oneshot::channel();
812 self.event_tx
813 .send(SessionEvent::Resume {
814 last_event_id: EventId {
815 http_request_id: None,
816 index: 0,
817 },
818 responder: tx,
819 })
820 .await
821 .map_err(|_| SessionError::SessionServiceTerminated)?;
822 rx.await
823 .map_err(|_| SessionError::SessionServiceTerminated)?
824 }
825
826 pub async fn resume(
828 &self,
829 last_event_id: EventId,
830 ) -> Result<StreamableHttpMessageReceiver, SessionError> {
831 let (tx, rx) = tokio::sync::oneshot::channel();
832 self.event_tx
833 .send(SessionEvent::Resume {
834 last_event_id,
835 responder: tx,
836 })
837 .await
838 .map_err(|_| SessionError::SessionServiceTerminated)?;
839 rx.await
840 .map_err(|_| SessionError::SessionServiceTerminated)?
841 }
842
843 pub async fn initialize(
847 &self,
848 request: ClientJsonRpcMessage,
849 ) -> Result<ServerJsonRpcMessage, SessionError> {
850 let (tx, rx) = tokio::sync::oneshot::channel();
851 self.event_tx
852 .send(SessionEvent::InitializeRequest {
853 request,
854 responder: tx,
855 })
856 .await
857 .map_err(|_| SessionError::SessionServiceTerminated)?;
858 rx.await
859 .map_err(|_| SessionError::SessionServiceTerminated)?
860 }
861
862 pub async fn close_sse_stream(
873 &self,
874 http_request_id: HttpRequestId,
875 retry_interval: Option<Duration>,
876 ) -> Result<(), SessionError> {
877 let (tx, rx) = tokio::sync::oneshot::channel();
878 self.event_tx
879 .send(SessionEvent::CloseSseStream {
880 http_request_id: Some(http_request_id),
881 retry_interval,
882 responder: tx,
883 })
884 .await
885 .map_err(|_| SessionError::SessionServiceTerminated)?;
886 rx.await
887 .map_err(|_| SessionError::SessionServiceTerminated)?
888 }
889
890 pub async fn close_standalone_sse_stream(
900 &self,
901 retry_interval: Option<Duration>,
902 ) -> Result<(), SessionError> {
903 let (tx, rx) = tokio::sync::oneshot::channel();
904 self.event_tx
905 .send(SessionEvent::CloseSseStream {
906 http_request_id: None,
907 retry_interval,
908 responder: tx,
909 })
910 .await
911 .map_err(|_| SessionError::SessionServiceTerminated)?;
912 rx.await
913 .map_err(|_| SessionError::SessionServiceTerminated)?
914 }
915}
916
917pub type SessionTransport = WorkerTransport<LocalSessionWorker>;
918
919#[allow(clippy::large_enum_variant)]
920#[derive(Debug, Error)]
921#[non_exhaustive]
922pub enum LocalSessionWorkerError {
923 #[error("transport terminated")]
924 TransportTerminated,
925 #[error("unexpected message: {0:?}")]
926 UnexpectedEvent(SessionEvent),
927 #[error("fail to send initialize request {0}")]
928 FailToSendInitializeRequest(SessionError),
929 #[error("fail to handle message: {0}")]
930 FailToHandleMessage(SessionError),
931 #[error("keep alive timeout after {}ms", _0.as_millis())]
932 KeepAliveTimeout(Duration),
933 #[error("Transport closed")]
934 TransportClosed,
935 #[error("Tokio join error {0}")]
936 TokioJoinError(#[from] tokio::task::JoinError),
937}
/// Event loop of a local session.
///
/// The worker first completes the initialize handshake. After that it
/// multiplexes events from the HTTP side with messages from the service
/// handler until closed, cancelled or idle for longer than `keep_alive`.
impl Worker for LocalSessionWorker {
    type Error = LocalSessionWorkerError;
    type Role = RoleServer;
    fn err_closed() -> Self::Error {
        LocalSessionWorkerError::TransportClosed
    }
    fn err_join(e: tokio::task::JoinError) -> Self::Error {
        LocalSessionWorkerError::TokioJoinError(e)
    }
    /// Names the worker after the session id and sizes its buffer with
    /// `channel_capacity`.
    fn config(&self) -> crate::transport::worker::WorkerConfig {
        crate::transport::worker::WorkerConfig {
            name: Some(format!("streamable-http-session-{}", self.id)),
            channel_buffer_capacity: self.session_config.channel_capacity,
        }
    }
    #[instrument(name = "streamable_http_session", skip_all, fields(id = self.id.as_ref()))]
    async fn run(
        mut self,
        mut context: WorkerContext<Self>,
    ) -> Result<(), WorkerQuitReason<Self::Error>> {
        enum InnerEvent {
            FromHttpService(SessionEvent),
            FromHandler(WorkerSendRequest<LocalSessionWorker>),
        }
        // Handshake: the very first event must be an initialize request.
        let evt = self.event_rx.recv().await.ok_or_else(|| {
            WorkerQuitReason::fatal(
                LocalSessionWorkerError::TransportTerminated,
                "get initialize request",
            )
        })?;
        let SessionEvent::InitializeRequest { request, responder } = evt else {
            return Err(WorkerQuitReason::fatal(
                LocalSessionWorkerError::UnexpectedEvent(evt),
                "get initialize request",
            ));
        };
        // Hand the request to the service via `context` and relay its first
        // outgoing message back to the HTTP caller as the initialize response.
        context.send_to_handler(request).await?;
        let send_initialize_response = context.recv_from_handler().await?;
        responder
            .send(Ok(send_initialize_response.message))
            .map_err(|_| {
                WorkerQuitReason::fatal(
                    LocalSessionWorkerError::FailToSendInitializeRequest(
                        SessionError::SessionServiceTerminated,
                    ),
                    "send initialize response",
                )
            })?;
        send_initialize_response
            .responder
            .send(Ok(()))
            .map_err(|_| WorkerQuitReason::HandlerTerminated)?;
        let ct = context.cancellation_token.clone();
        // `None` means no idle limit (an effectively unbounded sleep).
        let keep_alive = self.session_config.keep_alive.unwrap_or(Duration::MAX);
        loop {
            self.evict_expired_channels();
            // Re-created every iteration, so this is an idle timeout measured
            // from the last handled event.
            let keep_alive_timeout = tokio::time::sleep(keep_alive);
            let event = tokio::select! {
                event = self.event_rx.recv() => {
                    if let Some(event) = event {
                        InnerEvent::FromHttpService(event)
                    } else {
                        return Err(WorkerQuitReason::fatal(LocalSessionWorkerError::TransportTerminated, "waiting next session event"))
                    }
                },
                from_handler = context.recv_from_handler() => {
                    InnerEvent::FromHandler(from_handler?)
                }
                _ = ct.cancelled() => {
                    return Err(WorkerQuitReason::Cancelled)
                }
                _ = keep_alive_timeout => {
                    return Err(WorkerQuitReason::fatal(LocalSessionWorkerError::KeepAliveTimeout(keep_alive), "poll next session event"))
                }
            };
            match event {
                // Outgoing server message: deliver it, report the outcome to
                // the handler, then drop the request's routing entry.
                InnerEvent::FromHandler(WorkerSendRequest { message, responder }) => {
                    // Computed before `message` is moved; a response or error
                    // finishes the request with the same id.
                    let to_unregister = match &message {
                        crate::model::JsonRpcMessage::Response(json_rpc_response) => {
                            let request_id = json_rpc_response.id.clone();
                            Some(ResourceKey::McpRequestId(request_id))
                        }
                        crate::model::JsonRpcMessage::Error(json_rpc_error) => {
                            let request_id = json_rpc_error.id.clone();
                            Some(ResourceKey::McpRequestId(request_id))
                        }
                        _ => {
                            None
                        }
                    };
                    let handle_result = self
                        .handle_server_message(message)
                        .await
                        .map_err(LocalSessionWorkerError::FailToHandleMessage);
                    let _ = responder.send(handle_result).inspect_err(|error| {
                        tracing::warn!(?error, "failed to send message to http service handler");
                    });
                    if let Some(to_unregister) = to_unregister {
                        self.unregister_resource(&to_unregister);
                    }
                }
                // Incoming client message: record routing for requests and
                // handle cancellations before forwarding it to the service.
                InnerEvent::FromHttpService(SessionEvent::ClientMessage {
                    message: json_rpc_message,
                    http_request_id,
                }) => {
                    match &json_rpc_message {
                        crate::model::JsonRpcMessage::Request(request) => {
                            if let Some(http_request_id) = http_request_id {
                                self.register_request(request, http_request_id)
                            }
                        }
                        crate::model::JsonRpcMessage::Notification(notification) => {
                            self.catch_cancellation_notification(notification)
                        }
                        _ => {}
                    }
                    context.send_to_handler(json_rpc_message).await?;
                }
                InnerEvent::FromHttpService(SessionEvent::EstablishRequestWiseChannel {
                    responder,
                }) => {
                    let handle_result = self.establish_request_wise_channel().await;
                    let _ = responder.send(handle_result);
                }
                // NOTE(review): only the stream is removed; entries in
                // `resource_router` that point at it are left in place until
                // their request finishes.
                InnerEvent::FromHttpService(SessionEvent::CloseRequestWiseChannel {
                    id,
                    responder,
                }) => {
                    let _handle_result = self.tx_router.remove(&id);
                    let _ = responder.send(Ok(()));
                }
                InnerEvent::FromHttpService(SessionEvent::Resume {
                    last_event_id,
                    responder,
                }) => {
                    let handle_result = self.resume(last_event_id).await;
                    let _ = responder.send(handle_result);
                }
                InnerEvent::FromHttpService(SessionEvent::Close) => {
                    return Err(WorkerQuitReason::TransportClosed);
                }
                InnerEvent::FromHttpService(SessionEvent::CloseSseStream {
                    http_request_id,
                    retry_interval,
                    responder,
                }) => {
                    let handle_result =
                        self.close_sse_stream(http_request_id, retry_interval).await;
                    let _ = responder.send(handle_result);
                }
                // Only a repeated `InitializeRequest` reaches this arm. Its
                // responder is dropped, so the caller sees the oneshot fail.
                _ => {
                }
            }
        }
    }
}
1098
/// Tunables applied to each local session.
#[non_exhaustive]
#[derive(Clone, Debug)]
pub struct SessionConfig {
    /// Capacity of the session's bounded channels: the event queue, the SSE
    /// streams (and therefore their replay caches) and the worker buffer.
    pub channel_capacity: usize,
    /// Longest time the worker waits for the next event before quitting with
    /// `KeepAliveTimeout`. `None` effectively disables the limit.
    pub keep_alive: Option<Duration>,
    /// When set, new request-wise streams start with a priming event carrying
    /// this retry interval, and their message indices begin at 1.
    pub sse_retry: Option<Duration>,
    /// How long a completed request-wise stream, and its replay cache, is
    /// kept before eviction.
    pub completed_cache_ttl: Duration,
}
1126
1127impl SessionConfig {
1128 pub const DEFAULT_CHANNEL_CAPACITY: usize = 16;
1129 pub const DEFAULT_KEEP_ALIVE: Duration = Duration::from_secs(300);
1130 pub const DEFAULT_SSE_RETRY: Duration = Duration::from_secs(3);
1131 pub const DEFAULT_COMPLETED_CACHE_TTL: Duration = Duration::from_secs(60);
1132}
1133
1134impl Default for SessionConfig {
1135 fn default() -> Self {
1136 Self {
1137 channel_capacity: Self::DEFAULT_CHANNEL_CAPACITY,
1138 keep_alive: Some(Self::DEFAULT_KEEP_ALIVE),
1139 sse_retry: Some(Self::DEFAULT_SSE_RETRY),
1140 completed_cache_ttl: Self::DEFAULT_COMPLETED_CACHE_TTL,
1141 }
1142 }
1143}
1144
1145pub fn create_local_session(
1151 id: impl Into<SessionId>,
1152 config: SessionConfig,
1153) -> (LocalSessionHandle, LocalSessionWorker) {
1154 let id = id.into();
1155 let (event_tx, event_rx) = tokio::sync::mpsc::channel(config.channel_capacity);
1156 let (common_tx, _) = tokio::sync::mpsc::channel(config.channel_capacity);
1157 let common = CachedTx::new_common(common_tx);
1158 tracing::info!(session_id = ?id, "create new session");
1159 let handle = LocalSessionHandle {
1160 event_tx,
1161 id: id.clone(),
1162 };
1163 let session_worker = LocalSessionWorker {
1164 next_http_request_id: 0,
1165 id,
1166 tx_router: HashMap::new(),
1167 resource_router: HashMap::new(),
1168 common,
1169 shadow_txs: Vec::new(),
1170 event_rx,
1171 session_config: config.clone(),
1172 };
1173 (handle, session_worker)
1174}