Skip to main content

rmcp_soddygo/transport/streamable_http_server/session/
local.rs

1use std::{
2    collections::{HashMap, HashSet, VecDeque},
3    num::ParseIntError,
4    time::{Duration, Instant},
5};
6
7use futures::{Stream, StreamExt};
8use thiserror::Error;
9use tokio::sync::{
10    mpsc::{Receiver, Sender},
11    oneshot,
12};
13use tokio_stream::wrappers::ReceiverStream;
14use tracing::instrument;
15
16use crate::{
17    RoleServer,
18    model::{
19        CancelledNotificationParam, ClientJsonRpcMessage, ClientNotification, ClientRequest,
20        JsonRpcNotification, JsonRpcRequest, Notification, ProgressNotificationParam,
21        ProgressToken, RequestId, ServerJsonRpcMessage, ServerNotification,
22    },
23    transport::{
24        WorkerTransport,
25        common::server_side_http::{SessionId, session_id},
26        worker::{Worker, WorkerContext, WorkerQuitReason, WorkerSendRequest},
27    },
28};
29
/// Session manager that keeps every session in the memory of this process.
///
/// Each session is a [`LocalSessionHandle`] stored in `sessions`; the handle is
/// created together with a [`LocalSessionWorker`] by `create_local_session`.
#[derive(Debug, Default)]
#[non_exhaustive]
pub struct LocalSessionManager {
    /// Live sessions keyed by session id.
    pub sessions: tokio::sync::RwLock<HashMap<SessionId, LocalSessionHandle>>,
    /// Configuration cloned into every session this manager creates or restores.
    pub session_config: SessionConfig,
}
36
/// Errors returned by [`LocalSessionManager`]'s `SessionManager` implementation.
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum LocalSessionManagerError {
    /// No session with the given id is registered.
    #[error("Session not found: {0}")]
    SessionNotFound(SessionId),
    /// The session's handle or worker reported an error.
    #[error("Session error: {0}")]
    SessionError(#[from] SessionError),
    /// A `Last-Event-ID` value could not be parsed as an [`EventId`].
    #[error("Invalid event id: {0}")]
    InvalidEventId(#[from] EventIdParseError),
}
47impl SessionManager for LocalSessionManager {
48    type Error = LocalSessionManagerError;
49    type Transport = WorkerTransport<LocalSessionWorker>;
50    async fn create_session(&self) -> Result<(SessionId, Self::Transport), Self::Error> {
51        let id = session_id();
52        let (handle, worker) = create_local_session(id.clone(), self.session_config.clone());
53        self.sessions.write().await.insert(id.clone(), handle);
54        Ok((id, WorkerTransport::spawn(worker)))
55    }
56    async fn initialize_session(
57        &self,
58        id: &SessionId,
59        message: ClientJsonRpcMessage,
60    ) -> Result<ServerJsonRpcMessage, Self::Error> {
61        let sessions = self.sessions.read().await;
62        let handle = sessions
63            .get(id)
64            .ok_or(LocalSessionManagerError::SessionNotFound(id.clone()))?;
65        let response = handle.initialize(message).await?;
66        Ok(response)
67    }
68    async fn close_session(&self, id: &SessionId) -> Result<(), Self::Error> {
69        let mut sessions = self.sessions.write().await;
70        if let Some(handle) = sessions.remove(id) {
71            handle.close().await?;
72        }
73        Ok(())
74    }
75    async fn has_session(&self, id: &SessionId) -> Result<bool, Self::Error> {
76        let sessions = self.sessions.read().await;
77        Ok(sessions.contains_key(id))
78    }
79    async fn create_stream(
80        &self,
81        id: &SessionId,
82        message: ClientJsonRpcMessage,
83    ) -> Result<impl Stream<Item = ServerSseMessage> + Send + 'static, Self::Error> {
84        let sessions = self.sessions.read().await;
85        let handle = sessions
86            .get(id)
87            .ok_or(LocalSessionManagerError::SessionNotFound(id.clone()))?;
88        let receiver = handle.establish_request_wise_channel().await?;
89        let http_request_id = receiver.http_request_id;
90        handle.push_message(message, http_request_id).await?;
91
92        let priming = self.session_config.sse_retry.map(|retry| {
93            let event_id = match http_request_id {
94                Some(id) => format!("0/{id}"),
95                None => "0".into(),
96            };
97            ServerSseMessage::priming(event_id, retry)
98        });
99        Ok(futures::stream::iter(priming).chain(ReceiverStream::new(receiver.inner)))
100    }
101
102    async fn create_standalone_stream(
103        &self,
104        id: &SessionId,
105    ) -> Result<impl Stream<Item = ServerSseMessage> + Send + 'static, Self::Error> {
106        let sessions = self.sessions.read().await;
107        let handle = sessions
108            .get(id)
109            .ok_or(LocalSessionManagerError::SessionNotFound(id.clone()))?;
110        let receiver = handle.establish_common_channel().await?;
111        Ok(ReceiverStream::new(receiver.inner))
112    }
113
114    async fn resume(
115        &self,
116        id: &SessionId,
117        last_event_id: String,
118    ) -> Result<impl Stream<Item = ServerSseMessage> + Send + 'static, Self::Error> {
119        let sessions = self.sessions.read().await;
120        let handle = sessions
121            .get(id)
122            .ok_or(LocalSessionManagerError::SessionNotFound(id.clone()))?;
123        let receiver = handle.resume(last_event_id.parse()?).await?;
124        Ok(ReceiverStream::new(receiver.inner))
125    }
126
127    async fn accept_message(
128        &self,
129        id: &SessionId,
130        message: ClientJsonRpcMessage,
131    ) -> Result<(), Self::Error> {
132        let sessions = self.sessions.read().await;
133        let handle = sessions
134            .get(id)
135            .ok_or(LocalSessionManagerError::SessionNotFound(id.clone()))?;
136        handle.push_message(message, None).await?;
137        Ok(())
138    }
139
140    async fn restore_session(
141        &self,
142        id: SessionId,
143    ) -> Result<RestoreOutcome<Self::Transport>, Self::Error> {
144        let mut sessions = self.sessions.write().await;
145        if sessions.contains_key(&id) {
146            // A concurrent request already restored this session.
147            return Ok(RestoreOutcome::AlreadyPresent);
148        }
149        let (handle, worker) = create_local_session(id.clone(), self.session_config.clone());
150        sessions.insert(id, handle);
151        Ok(RestoreOutcome::Restored(WorkerTransport::spawn(worker)))
152    }
153}
154
155/// `<index>/request_id>`
156#[derive(Debug, Clone, PartialEq, Eq, Hash)]
157pub struct EventId {
158    http_request_id: Option<HttpRequestId>,
159    index: usize,
160}
161
162impl std::fmt::Display for EventId {
163    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
164        write!(f, "{}", self.index)?;
165        match &self.http_request_id {
166            Some(http_request_id) => write!(f, "/{http_request_id}"),
167            None => write!(f, ""),
168        }
169    }
170}
171
172#[derive(Debug, Clone, Error)]
173#[non_exhaustive]
174pub enum EventIdParseError {
175    #[error("Invalid index: {0}")]
176    InvalidIndex(ParseIntError),
177    #[error("Invalid numeric request id: {0}")]
178    InvalidNumericRequestId(ParseIntError),
179    #[error("Missing request id type")]
180    InvalidRequestIdType,
181    #[error("Missing request id")]
182    MissingRequestId,
183}
184
185impl std::str::FromStr for EventId {
186    type Err = EventIdParseError;
187    fn from_str(s: &str) -> Result<Self, Self::Err> {
188        if let Some((index, request_id)) = s.split_once("/") {
189            let index = usize::from_str(index).map_err(EventIdParseError::InvalidIndex)?;
190            let request_id = u64::from_str(request_id).map_err(EventIdParseError::InvalidIndex)?;
191            Ok(EventId {
192                http_request_id: Some(request_id),
193                index,
194            })
195        } else {
196            let index = usize::from_str(s).map_err(EventIdParseError::InvalidIndex)?;
197            Ok(EventId {
198                http_request_id: None,
199                index,
200            })
201        }
202    }
203}
204
205use super::{RestoreOutcome, ServerSseMessage, SessionManager};
206
207struct CachedTx {
208    tx: Sender<ServerSseMessage>,
209    cache: VecDeque<ServerSseMessage>,
210    http_request_id: Option<HttpRequestId>,
211    capacity: usize,
212    starting_index: usize,
213}
214
215impl CachedTx {
216    fn new(
217        tx: Sender<ServerSseMessage>,
218        http_request_id: Option<HttpRequestId>,
219        starting_index: usize,
220    ) -> Self {
221        Self {
222            cache: VecDeque::with_capacity(tx.capacity()),
223            capacity: tx.capacity(),
224            tx,
225            http_request_id,
226            starting_index,
227        }
228    }
229    fn new_common(tx: Sender<ServerSseMessage>) -> Self {
230        Self::new(tx, None, 0)
231    }
232
233    fn next_event_id(&self) -> EventId {
234        let index = self.cache.back().map_or(self.starting_index, |m| {
235            m.event_id
236                .as_deref()
237                .unwrap_or_default()
238                .parse::<EventId>()
239                .expect("valid event id")
240                .index
241                + 1
242        });
243        EventId {
244            http_request_id: self.http_request_id,
245            index,
246        }
247    }
248
249    async fn send(&mut self, message: ServerJsonRpcMessage) {
250        let event_id = self.next_event_id();
251        let message = ServerSseMessage::new(event_id.to_string(), message);
252        self.cache_and_send(message).await;
253    }
254
255    async fn send_priming(&mut self, retry: Duration) {
256        let event_id = self.next_event_id();
257        let message = ServerSseMessage::priming(event_id.to_string(), retry);
258        self.cache_and_send(message).await;
259    }
260
261    async fn cache_and_send(&mut self, message: ServerSseMessage) {
262        if self.cache.len() >= self.capacity {
263            self.cache.pop_front();
264            self.cache.push_back(message.clone());
265        } else {
266            self.cache.push_back(message.clone());
267        }
268        let _ = self.tx.send(message).await.inspect_err(|e| {
269            let event_id = &e.0.event_id;
270            tracing::trace!(?event_id, "trying to send message in a closed session")
271        });
272    }
273
274    async fn sync(&mut self, index: usize) -> Result<(), SessionError> {
275        let Some(front) = self.cache.front() else {
276            return Ok(());
277        };
278        let front_event_id = front
279            .event_id
280            .as_deref()
281            .unwrap_or_default()
282            .parse::<EventId>()?;
283        let sync_index = index.saturating_sub(front_event_id.index);
284        if sync_index > self.cache.len() {
285            // invalid index
286            return Err(SessionError::InvalidEventId);
287        }
288        for message in self.cache.iter().skip(sync_index) {
289            let send_result = self.tx.send(message.clone()).await;
290            if send_result.is_err() {
291                let event_id: EventId = message.event_id.as_deref().unwrap_or_default().parse()?;
292                return Err(SessionError::ChannelClosed(Some(event_id.index as u64)));
293            }
294        }
295        Ok(())
296    }
297}
298
299struct HttpRequestWise {
300    resources: HashSet<ResourceKey>,
301    tx: CachedTx,
302    completed_at: Option<Instant>,
303}
304
305type HttpRequestId = u64;
306#[derive(Debug, Clone, Hash, PartialEq, Eq)]
307enum ResourceKey {
308    McpRequestId(RequestId),
309    ProgressToken(ProgressToken),
310}
311
312pub struct LocalSessionWorker {
313    id: SessionId,
314    next_http_request_id: HttpRequestId,
315    tx_router: HashMap<HttpRequestId, HttpRequestWise>,
316    resource_router: HashMap<ResourceKey, HttpRequestId>,
317    common: CachedTx,
318    /// Shadow senders for secondary SSE streams (e.g. from POST EventSource
319    /// reconnections). These keep the HTTP connections alive via SSE keep-alive
320    /// without receiving notifications, preventing MCP clients from entering
321    /// infinite reconnect loops when multiple EventSource connections compete
322    /// to replace the common channel.
323    shadow_txs: Vec<Sender<ServerSseMessage>>,
324    event_rx: Receiver<SessionEvent>,
325    session_config: SessionConfig,
326}
327
328impl LocalSessionWorker {
329    pub fn id(&self) -> &SessionId {
330        &self.id
331    }
332}
333
334#[derive(Debug, Error)]
335#[non_exhaustive]
336pub enum SessionError {
337    #[error("Invalid request id: {0}")]
338    DuplicatedRequestId(HttpRequestId),
339    #[error("Channel closed: {0:?}")]
340    ChannelClosed(Option<HttpRequestId>),
341    #[error("Cannot parse event id: {0}")]
342    EventIdParseError(#[from] EventIdParseError),
343    #[error("Session service terminated")]
344    SessionServiceTerminated,
345    #[error("Invalid event id")]
346    InvalidEventId,
347    #[error("IO error: {0}")]
348    Io(#[from] std::io::Error),
349}
350
351impl From<SessionError> for std::io::Error {
352    fn from(value: SessionError) -> Self {
353        match value {
354            SessionError::Io(io) => io,
355            _ => std::io::Error::other(format!("Session error: {value}")),
356        }
357    }
358}
359
360enum OutboundChannel {
361    RequestWise { id: HttpRequestId, close: bool },
362    Common,
363}
364#[derive(Debug)]
365#[non_exhaustive]
366pub struct StreamableHttpMessageReceiver {
367    pub http_request_id: Option<HttpRequestId>,
368    pub inner: Receiver<ServerSseMessage>,
369}
370
371impl LocalSessionWorker {
372    fn unregister_resource(&mut self, resource: &ResourceKey) {
373        let Some(http_request_id) = self.resource_router.remove(resource) else {
374            return;
375        };
376        tracing::trace!(?resource, http_request_id, "unregister resource");
377        let Some(channel) = self.tx_router.get_mut(&http_request_id) else {
378            tracing::warn!(http_request_id, "http request wise channel not found");
379            return;
380        };
381        if !channel.resources.is_empty() && !matches!(resource, ResourceKey::McpRequestId(_)) {
382            return;
383        }
384        tracing::debug!(http_request_id, "close http request wise channel");
385        let resources: Vec<_> = channel.resources.drain().collect();
386        channel.completed_at = Some(Instant::now());
387        // Close the sender so the client's SSE stream ends,
388        // but keep the entry so the cache is available for
389        // late resume requests.
390        let (closed_tx, _) = tokio::sync::mpsc::channel(1);
391        channel.tx.tx = closed_tx;
392        for resource in resources {
393            self.resource_router.remove(&resource);
394        }
395    }
396    fn register_resource(&mut self, resource: ResourceKey, http_request_id: HttpRequestId) {
397        tracing::trace!(?resource, http_request_id, "register resource");
398        if let Some(channel) = self.tx_router.get_mut(&http_request_id) {
399            channel.resources.insert(resource.clone());
400            self.resource_router.insert(resource, http_request_id);
401        }
402    }
403    fn register_request(
404        &mut self,
405        request: &JsonRpcRequest<ClientRequest>,
406        http_request_id: HttpRequestId,
407    ) {
408        use crate::model::GetMeta;
409        self.register_resource(
410            ResourceKey::McpRequestId(request.id.clone()),
411            http_request_id,
412        );
413        if let Some(progress_token) = request.request.get_meta().get_progress_token() {
414            self.register_resource(
415                ResourceKey::ProgressToken(progress_token.clone()),
416                http_request_id,
417            );
418        }
419    }
420    fn catch_cancellation_notification(
421        &mut self,
422        notification: &JsonRpcNotification<ClientNotification>,
423    ) {
424        if let ClientNotification::CancelledNotification(n) = &notification.notification {
425            let request_id = n.params.request_id.clone();
426            let resource = ResourceKey::McpRequestId(request_id);
427            self.unregister_resource(&resource);
428        }
429    }
430    fn evict_expired_channels(&mut self) {
431        let ttl = self.session_config.completed_cache_ttl;
432        self.tx_router
433            .retain(|_, rw| rw.completed_at.is_none_or(|at| at.elapsed() < ttl));
434    }
435    fn next_http_request_id(&mut self) -> HttpRequestId {
436        let id = self.next_http_request_id;
437        self.next_http_request_id = self.next_http_request_id.wrapping_add(1);
438        id
439    }
440    async fn establish_request_wise_channel(
441        &mut self,
442    ) -> Result<StreamableHttpMessageReceiver, SessionError> {
443        let http_request_id = self.next_http_request_id();
444        let (tx, rx) = tokio::sync::mpsc::channel(self.session_config.channel_capacity);
445        let starting_index = usize::from(self.session_config.sse_retry.is_some());
446        self.tx_router.insert(
447            http_request_id,
448            HttpRequestWise {
449                resources: Default::default(),
450                tx: CachedTx::new(tx, Some(http_request_id), starting_index),
451                completed_at: None,
452            },
453        );
454        tracing::debug!(http_request_id, "establish new request wise channel");
455        Ok(StreamableHttpMessageReceiver {
456            http_request_id: Some(http_request_id),
457            inner: rx,
458        })
459    }
460    fn resolve_outbound_channel(&self, message: &ServerJsonRpcMessage) -> OutboundChannel {
461        match &message {
462            ServerJsonRpcMessage::Request(_) => OutboundChannel::Common,
463            ServerJsonRpcMessage::Notification(JsonRpcNotification {
464                notification:
465                    ServerNotification::ProgressNotification(Notification {
466                        params: ProgressNotificationParam { progress_token, .. },
467                        ..
468                    }),
469                ..
470            }) => {
471                let id = self
472                    .resource_router
473                    .get(&ResourceKey::ProgressToken(progress_token.clone()));
474
475                if let Some(id) = id {
476                    OutboundChannel::RequestWise {
477                        id: *id,
478                        close: false,
479                    }
480                } else {
481                    OutboundChannel::Common
482                }
483            }
484            ServerJsonRpcMessage::Notification(JsonRpcNotification {
485                notification:
486                    ServerNotification::CancelledNotification(Notification {
487                        params: CancelledNotificationParam { request_id, .. },
488                        ..
489                    }),
490                ..
491            }) => {
492                if let Some(id) = self
493                    .resource_router
494                    .get(&ResourceKey::McpRequestId(request_id.clone()))
495                {
496                    OutboundChannel::RequestWise {
497                        id: *id,
498                        close: false,
499                    }
500                } else {
501                    OutboundChannel::Common
502                }
503            }
504            ServerJsonRpcMessage::Notification(_) => OutboundChannel::Common,
505            ServerJsonRpcMessage::Response(json_rpc_response) => {
506                if let Some(id) = self
507                    .resource_router
508                    .get(&ResourceKey::McpRequestId(json_rpc_response.id.clone()))
509                {
510                    OutboundChannel::RequestWise {
511                        id: *id,
512                        close: true,
513                    }
514                } else {
515                    OutboundChannel::Common
516                }
517            }
518            ServerJsonRpcMessage::Error(json_rpc_error) => {
519                if let Some(id) = self
520                    .resource_router
521                    .get(&ResourceKey::McpRequestId(json_rpc_error.id.clone()))
522                {
523                    OutboundChannel::RequestWise {
524                        id: *id,
525                        close: true,
526                    }
527                } else {
528                    OutboundChannel::Common
529                }
530            }
531        }
532    }
533    async fn handle_server_message(
534        &mut self,
535        message: ServerJsonRpcMessage,
536    ) -> Result<(), SessionError> {
537        let outbound_channel = self.resolve_outbound_channel(&message);
538        match outbound_channel {
539            OutboundChannel::RequestWise { id, close } => {
540                if let Some(request_wise) = self.tx_router.get_mut(&id) {
541                    request_wise.tx.send(message).await;
542                    if close {
543                        if let Some(channel) = self.tx_router.remove(&id) {
544                            for resource in channel.resources {
545                                self.resource_router.remove(&resource);
546                            }
547                        }
548                    }
549                } else {
550                    return Err(SessionError::ChannelClosed(Some(id)));
551                }
552            }
553            OutboundChannel::Common => self.common.send(message).await,
554        }
555        Ok(())
556    }
557    async fn resume(
558        &mut self,
559        last_event_id: EventId,
560    ) -> Result<StreamableHttpMessageReceiver, SessionError> {
561        // Clean up closed shadow senders before processing
562        self.shadow_txs.retain(|tx| !tx.is_closed());
563
564        match last_event_id.http_request_id {
565            Some(http_request_id) => {
566                let request_wise = self
567                    .tx_router
568                    .get_mut(&http_request_id)
569                    .ok_or(SessionError::ChannelClosed(Some(http_request_id)))?;
570                let is_completed = request_wise.completed_at.is_some();
571                let (tx, rx) = tokio::sync::mpsc::channel(self.session_config.channel_capacity);
572                request_wise.tx.tx = tx;
573                let index = last_event_id.index;
574                request_wise.tx.sync(index).await?;
575                if is_completed {
576                    // Drop the sender after replaying so the stream ends
577                    // instead of hanging indefinitely.
578                    let (closed_tx, _) = tokio::sync::mpsc::channel(1);
579                    request_wise.tx.tx = closed_tx;
580                }
581                Ok(StreamableHttpMessageReceiver {
582                    http_request_id: Some(http_request_id),
583                    inner: rx,
584                })
585            }
586            None => self.resume_or_shadow_common(last_event_id.index).await,
587        }
588    }
589
590    /// Resume the common channel, or create a shadow stream if the primary is
591    /// still active.
592    ///
593    /// When the primary common channel is dead (receiver dropped), replace it
594    /// so this stream becomes the new primary notification channel. Cached
595    /// messages are replayed from `last_event_index` so the client receives
596    /// any events it missed (including server-initiated requests).
597    ///
598    /// When the primary is still active, create a "shadow" stream — an idle SSE
599    /// connection kept alive by keep-alive pings. This prevents multiple
600    /// EventSource connections (e.g. from POST response reconnections) from
601    /// killing each other by repeatedly replacing the common channel sender.
602    async fn resume_or_shadow_common(
603        &mut self,
604        last_event_index: usize,
605    ) -> Result<StreamableHttpMessageReceiver, SessionError> {
606        let is_replacing_dead_primary = self.common.tx.is_closed();
607        let capacity = if is_replacing_dead_primary {
608            self.session_config.channel_capacity
609        } else {
610            1 // Shadow streams only need keep-alive pings
611        };
612        let (tx, rx) = tokio::sync::mpsc::channel(capacity);
613        if is_replacing_dead_primary {
614            // Primary common channel is dead — replace it.
615            tracing::debug!("Replacing dead common channel with new primary");
616            self.common.tx = tx;
617            // Replay cached messages from where the client left off so
618            // server-initiated requests and notifications are not lost.
619            self.common.sync(last_event_index).await?;
620        } else {
621            // Primary common channel is still active. Create a shadow stream
622            // that stays alive via SSE keep-alive but doesn't receive
623            // notifications. This prevents competing EventSource connections
624            // from killing each other's channels.
625            const MAX_SHADOW_STREAMS: usize = 32;
626
627            if self.shadow_txs.len() >= MAX_SHADOW_STREAMS {
628                tracing::warn!(
629                    shadow_count = self.shadow_txs.len(),
630                    "Shadow stream limit reached, dropping oldest"
631                );
632                self.shadow_txs.remove(0);
633            }
634            tracing::debug!(
635                shadow_count = self.shadow_txs.len(),
636                "Common channel active, creating shadow stream"
637            );
638            self.shadow_txs.push(tx);
639        }
640        Ok(StreamableHttpMessageReceiver {
641            http_request_id: None,
642            inner: rx,
643        })
644    }
645
646    async fn close_sse_stream(
647        &mut self,
648        http_request_id: Option<HttpRequestId>,
649        retry_interval: Option<Duration>,
650    ) -> Result<(), SessionError> {
651        match http_request_id {
652            // Close a request-wise stream
653            Some(id) => {
654                let request_wise = self
655                    .tx_router
656                    .get_mut(&id)
657                    .ok_or(SessionError::ChannelClosed(Some(id)))?;
658
659                // Send priming event if retry interval is specified
660                if let Some(interval) = retry_interval {
661                    request_wise.tx.send_priming(interval).await;
662                }
663
664                // Close the stream by dropping the sender
665                let (tx, _rx) = tokio::sync::mpsc::channel(1);
666                request_wise.tx.tx = tx;
667
668                tracing::debug!(
669                    http_request_id = id,
670                    "closed SSE stream for server-initiated disconnection"
671                );
672                Ok(())
673            }
674            // Close the standalone (common) stream
675            None => {
676                // Send priming event if retry interval is specified
677                if let Some(interval) = retry_interval {
678                    self.common.send_priming(interval).await;
679                }
680
681                // Close the stream by dropping the sender
682                let (tx, _rx) = tokio::sync::mpsc::channel(1);
683                self.common.tx = tx;
684
685                // Also close all shadow streams
686                self.shadow_txs.clear();
687
688                tracing::debug!("closed standalone SSE stream for server-initiated disconnection");
689                Ok(())
690            }
691        }
692    }
693}
694
695#[derive(Debug)]
696#[non_exhaustive]
697pub enum SessionEvent {
698    ClientMessage {
699        message: ClientJsonRpcMessage,
700        http_request_id: Option<HttpRequestId>,
701    },
702    EstablishRequestWiseChannel {
703        responder: oneshot::Sender<Result<StreamableHttpMessageReceiver, SessionError>>,
704    },
705    CloseRequestWiseChannel {
706        id: HttpRequestId,
707        responder: oneshot::Sender<Result<(), SessionError>>,
708    },
709    Resume {
710        last_event_id: EventId,
711        responder: oneshot::Sender<Result<StreamableHttpMessageReceiver, SessionError>>,
712    },
713    InitializeRequest {
714        request: ClientJsonRpcMessage,
715        responder: oneshot::Sender<Result<ServerJsonRpcMessage, SessionError>>,
716    },
717    Close,
718    CloseSseStream {
719        /// The HTTP request ID to close. If `None`, closes the standalone (common) stream.
720        http_request_id: Option<HttpRequestId>,
721        /// Optional retry interval. If provided, a priming event is sent before closing.
722        retry_interval: Option<Duration>,
723        responder: oneshot::Sender<Result<(), SessionError>>,
724    },
725}
726
727#[derive(Debug, Clone)]
728#[non_exhaustive]
729pub enum SessionQuitReason {
730    ServiceTerminated,
731    ClientTerminated,
732    ExpectInitializeRequest,
733    ExpectInitializeResponse,
734    Cancelled,
735}
736
737#[derive(Debug, Clone)]
738pub struct LocalSessionHandle {
739    id: SessionId,
740    // after all event_tx drop, inner task will be terminated
741    event_tx: Sender<SessionEvent>,
742}
743
744impl LocalSessionHandle {
745    /// Get the session id
746    pub fn id(&self) -> &SessionId {
747        &self.id
748    }
749
750    /// Close the session
751    pub async fn close(&self) -> Result<(), SessionError> {
752        self.event_tx
753            .send(SessionEvent::Close)
754            .await
755            .map_err(|_| SessionError::SessionServiceTerminated)?;
756        Ok(())
757    }
758
759    /// Send a message to the session
760    pub async fn push_message(
761        &self,
762        message: ClientJsonRpcMessage,
763        http_request_id: Option<HttpRequestId>,
764    ) -> Result<(), SessionError> {
765        self.event_tx
766            .send(SessionEvent::ClientMessage {
767                message,
768                http_request_id,
769            })
770            .await
771            .map_err(|_| SessionError::SessionServiceTerminated)?;
772        Ok(())
773    }
774
775    /// establish a channel for a http-request, the corresponded message from server will be
776    /// sent through this channel. The channel will be closed when the request is completed,
777    /// or you can close it manually by calling [`LocalSessionHandle::close_request_wise_channel`].
778    pub async fn establish_request_wise_channel(
779        &self,
780    ) -> Result<StreamableHttpMessageReceiver, SessionError> {
781        let (tx, rx) = tokio::sync::oneshot::channel();
782        self.event_tx
783            .send(SessionEvent::EstablishRequestWiseChannel { responder: tx })
784            .await
785            .map_err(|_| SessionError::SessionServiceTerminated)?;
786        rx.await
787            .map_err(|_| SessionError::SessionServiceTerminated)?
788    }
789
790    /// close the http-request wise channel.
791    pub async fn close_request_wise_channel(
792        &self,
793        request_id: HttpRequestId,
794    ) -> Result<(), SessionError> {
795        let (tx, rx) = tokio::sync::oneshot::channel();
796        self.event_tx
797            .send(SessionEvent::CloseRequestWiseChannel {
798                id: request_id,
799                responder: tx,
800            })
801            .await
802            .map_err(|_| SessionError::SessionServiceTerminated)?;
803        rx.await
804            .map_err(|_| SessionError::SessionServiceTerminated)?
805    }
806
807    /// Establish a common channel for general purpose messages.
808    pub async fn establish_common_channel(
809        &self,
810    ) -> Result<StreamableHttpMessageReceiver, SessionError> {
811        let (tx, rx) = tokio::sync::oneshot::channel();
812        self.event_tx
813            .send(SessionEvent::Resume {
814                last_event_id: EventId {
815                    http_request_id: None,
816                    index: 0,
817                },
818                responder: tx,
819            })
820            .await
821            .map_err(|_| SessionError::SessionServiceTerminated)?;
822        rx.await
823            .map_err(|_| SessionError::SessionServiceTerminated)?
824    }
825
826    /// Resume streaming response by the last event id. This is suitable for both request wise and common channel.
827    pub async fn resume(
828        &self,
829        last_event_id: EventId,
830    ) -> Result<StreamableHttpMessageReceiver, SessionError> {
831        let (tx, rx) = tokio::sync::oneshot::channel();
832        self.event_tx
833            .send(SessionEvent::Resume {
834                last_event_id,
835                responder: tx,
836            })
837            .await
838            .map_err(|_| SessionError::SessionServiceTerminated)?;
839        rx.await
840            .map_err(|_| SessionError::SessionServiceTerminated)?
841    }
842
843    /// Send an initialize request to the session. And wait for the initialized response.
844    ///
845    /// This is used to establish a session with the server.
846    pub async fn initialize(
847        &self,
848        request: ClientJsonRpcMessage,
849    ) -> Result<ServerJsonRpcMessage, SessionError> {
850        let (tx, rx) = tokio::sync::oneshot::channel();
851        self.event_tx
852            .send(SessionEvent::InitializeRequest {
853                request,
854                responder: tx,
855            })
856            .await
857            .map_err(|_| SessionError::SessionServiceTerminated)?;
858        rx.await
859            .map_err(|_| SessionError::SessionServiceTerminated)?
860    }
861
862    /// Close an SSE stream for a specific request.
863    ///
864    /// This closes the SSE connection for a POST request stream, but keeps the session
865    /// and message cache active. Clients can reconnect using the `Last-Event-ID` header
866    /// via a GET request to resume receiving messages.
867    ///
868    /// # Arguments
869    ///
870    /// * `http_request_id` - The HTTP request ID of the stream to close
871    /// * `retry_interval` - Optional retry interval. If provided, a priming event is sent
872    pub async fn close_sse_stream(
873        &self,
874        http_request_id: HttpRequestId,
875        retry_interval: Option<Duration>,
876    ) -> Result<(), SessionError> {
877        let (tx, rx) = tokio::sync::oneshot::channel();
878        self.event_tx
879            .send(SessionEvent::CloseSseStream {
880                http_request_id: Some(http_request_id),
881                retry_interval,
882                responder: tx,
883            })
884            .await
885            .map_err(|_| SessionError::SessionServiceTerminated)?;
886        rx.await
887            .map_err(|_| SessionError::SessionServiceTerminated)?
888    }
889
890    /// Close the standalone SSE stream.
891    ///
892    /// This closes the standalone SSE connection (established via GET request),
893    /// but keeps the session and message cache active. Clients can reconnect using
894    /// the `Last-Event-ID` header via a GET request to resume receiving messages.
895    ///
896    /// # Arguments
897    ///
898    /// * `retry_interval` - Optional retry interval. If provided, a priming event is sent
899    pub async fn close_standalone_sse_stream(
900        &self,
901        retry_interval: Option<Duration>,
902    ) -> Result<(), SessionError> {
903        let (tx, rx) = tokio::sync::oneshot::channel();
904        self.event_tx
905            .send(SessionEvent::CloseSseStream {
906                http_request_id: None,
907                retry_interval,
908                responder: tx,
909            })
910            .await
911            .map_err(|_| SessionError::SessionServiceTerminated)?;
912        rx.await
913            .map_err(|_| SessionError::SessionServiceTerminated)?
914    }
915}
916
917pub type SessionTransport = WorkerTransport<LocalSessionWorker>;
918
919#[allow(clippy::large_enum_variant)]
920#[derive(Debug, Error)]
921#[non_exhaustive]
922pub enum LocalSessionWorkerError {
923    #[error("transport terminated")]
924    TransportTerminated,
925    #[error("unexpected message: {0:?}")]
926    UnexpectedEvent(SessionEvent),
927    #[error("fail to send initialize request {0}")]
928    FailToSendInitializeRequest(SessionError),
929    #[error("fail to handle message: {0}")]
930    FailToHandleMessage(SessionError),
931    #[error("keep alive timeout after {}ms", _0.as_millis())]
932    KeepAliveTimeout(Duration),
933    #[error("Transport closed")]
934    TransportClosed,
935    #[error("Tokio join error {0}")]
936    TokioJoinError(#[from] tokio::task::JoinError),
937}
938impl Worker for LocalSessionWorker {
939    type Error = LocalSessionWorkerError;
940    type Role = RoleServer;
941    fn err_closed() -> Self::Error {
942        LocalSessionWorkerError::TransportClosed
943    }
944    fn err_join(e: tokio::task::JoinError) -> Self::Error {
945        LocalSessionWorkerError::TokioJoinError(e)
946    }
947    fn config(&self) -> crate::transport::worker::WorkerConfig {
948        crate::transport::worker::WorkerConfig {
949            name: Some(format!("streamable-http-session-{}", self.id)),
950            channel_buffer_capacity: self.session_config.channel_capacity,
951        }
952    }
953    #[instrument(name = "streamable_http_session", skip_all, fields(id = self.id.as_ref()))]
954    async fn run(
955        mut self,
956        mut context: WorkerContext<Self>,
957    ) -> Result<(), WorkerQuitReason<Self::Error>> {
958        enum InnerEvent {
959            FromHttpService(SessionEvent),
960            FromHandler(WorkerSendRequest<LocalSessionWorker>),
961        }
962        // waiting for initialize request
963        let evt = self.event_rx.recv().await.ok_or_else(|| {
964            WorkerQuitReason::fatal(
965                LocalSessionWorkerError::TransportTerminated,
966                "get initialize request",
967            )
968        })?;
969        let SessionEvent::InitializeRequest { request, responder } = evt else {
970            return Err(WorkerQuitReason::fatal(
971                LocalSessionWorkerError::UnexpectedEvent(evt),
972                "get initialize request",
973            ));
974        };
975        context.send_to_handler(request).await?;
976        let send_initialize_response = context.recv_from_handler().await?;
977        responder
978            .send(Ok(send_initialize_response.message))
979            .map_err(|_| {
980                WorkerQuitReason::fatal(
981                    LocalSessionWorkerError::FailToSendInitializeRequest(
982                        SessionError::SessionServiceTerminated,
983                    ),
984                    "send initialize response",
985                )
986            })?;
987        send_initialize_response
988            .responder
989            .send(Ok(()))
990            .map_err(|_| WorkerQuitReason::HandlerTerminated)?;
991        let ct = context.cancellation_token.clone();
992        let keep_alive = self.session_config.keep_alive.unwrap_or(Duration::MAX);
993        loop {
994            self.evict_expired_channels();
995            let keep_alive_timeout = tokio::time::sleep(keep_alive);
996            let event = tokio::select! {
997                event = self.event_rx.recv() => {
998                    if let Some(event) = event {
999                        InnerEvent::FromHttpService(event)
1000                    } else {
1001                        return Err(WorkerQuitReason::fatal(LocalSessionWorkerError::TransportTerminated, "waiting next session event"))
1002                    }
1003                },
1004                from_handler = context.recv_from_handler() => {
1005                    InnerEvent::FromHandler(from_handler?)
1006                }
1007                _ = ct.cancelled() => {
1008                    return Err(WorkerQuitReason::Cancelled)
1009                }
1010                _ = keep_alive_timeout => {
1011                    return Err(WorkerQuitReason::fatal(LocalSessionWorkerError::KeepAliveTimeout(keep_alive), "poll next session event"))
1012                }
1013            };
1014            match event {
1015                InnerEvent::FromHandler(WorkerSendRequest { message, responder }) => {
1016                    // catch response
1017                    let to_unregister = match &message {
1018                        crate::model::JsonRpcMessage::Response(json_rpc_response) => {
1019                            let request_id = json_rpc_response.id.clone();
1020                            Some(ResourceKey::McpRequestId(request_id))
1021                        }
1022                        crate::model::JsonRpcMessage::Error(json_rpc_error) => {
1023                            let request_id = json_rpc_error.id.clone();
1024                            Some(ResourceKey::McpRequestId(request_id))
1025                        }
1026                        _ => {
1027                            None
1028                            // no need to unregister resource
1029                        }
1030                    };
1031                    let handle_result = self
1032                        .handle_server_message(message)
1033                        .await
1034                        .map_err(LocalSessionWorkerError::FailToHandleMessage);
1035                    let _ = responder.send(handle_result).inspect_err(|error| {
1036                        tracing::warn!(?error, "failed to send message to http service handler");
1037                    });
1038                    if let Some(to_unregister) = to_unregister {
1039                        self.unregister_resource(&to_unregister);
1040                    }
1041                }
1042                InnerEvent::FromHttpService(SessionEvent::ClientMessage {
1043                    message: json_rpc_message,
1044                    http_request_id,
1045                }) => {
1046                    match &json_rpc_message {
1047                        crate::model::JsonRpcMessage::Request(request) => {
1048                            if let Some(http_request_id) = http_request_id {
1049                                self.register_request(request, http_request_id)
1050                            }
1051                        }
1052                        crate::model::JsonRpcMessage::Notification(notification) => {
1053                            self.catch_cancellation_notification(notification)
1054                        }
1055                        _ => {}
1056                    }
1057                    context.send_to_handler(json_rpc_message).await?;
1058                }
1059                InnerEvent::FromHttpService(SessionEvent::EstablishRequestWiseChannel {
1060                    responder,
1061                }) => {
1062                    let handle_result = self.establish_request_wise_channel().await;
1063                    let _ = responder.send(handle_result);
1064                }
1065                InnerEvent::FromHttpService(SessionEvent::CloseRequestWiseChannel {
1066                    id,
1067                    responder,
1068                }) => {
1069                    let _handle_result = self.tx_router.remove(&id);
1070                    let _ = responder.send(Ok(()));
1071                }
1072                InnerEvent::FromHttpService(SessionEvent::Resume {
1073                    last_event_id,
1074                    responder,
1075                }) => {
1076                    let handle_result = self.resume(last_event_id).await;
1077                    let _ = responder.send(handle_result);
1078                }
1079                InnerEvent::FromHttpService(SessionEvent::Close) => {
1080                    return Err(WorkerQuitReason::TransportClosed);
1081                }
1082                InnerEvent::FromHttpService(SessionEvent::CloseSseStream {
1083                    http_request_id,
1084                    retry_interval,
1085                    responder,
1086                }) => {
1087                    let handle_result =
1088                        self.close_sse_stream(http_request_id, retry_interval).await;
1089                    let _ = responder.send(handle_result);
1090                }
1091                _ => {
1092                    // ignore
1093                }
1094            }
1095        }
1096    }
1097}
1098
1099#[derive(Debug, Clone)]
1100#[non_exhaustive]
1101pub struct SessionConfig {
1102    /// the capacity of the channel for the session. Default is 16.
1103    pub channel_capacity: usize,
1104    /// The session will be closed after this duration of inactivity.
1105    ///
1106    /// This serves as a safety net for cleaning up sessions whose HTTP
1107    /// connections have silently dropped (e.g., due to an HTTP/2
1108    /// `RST_STREAM`). Without a timeout, such sessions become zombies:
1109    /// the session worker keeps running indefinitely because the session
1110    /// handle's sender is still held in the session manager, preventing
1111    /// the worker's event channel from closing.
1112    ///
1113    /// Defaults to 5 minutes. Set to `None` to disable (not recommended
1114    /// for long-running servers behind proxies).
1115    pub keep_alive: Option<Duration>,
1116    /// SSE retry interval for priming events on request-wise streams.
1117    /// When set, the session layer prepends a priming event with the correct
1118    /// stream-identifying event ID to each request-wise SSE stream.
1119    /// Default is 3 seconds, matching `StreamableHttpServerConfig::default()`.
1120    pub sse_retry: Option<Duration>,
1121    /// How long to retain completed request-wise channel caches for late
1122    /// resume requests. After this duration, completed entries are evicted
1123    /// and resume will return an error. Default is 60 seconds.
1124    pub completed_cache_ttl: Duration,
1125}
1126
1127impl SessionConfig {
1128    pub const DEFAULT_CHANNEL_CAPACITY: usize = 16;
1129    pub const DEFAULT_KEEP_ALIVE: Duration = Duration::from_secs(300);
1130    pub const DEFAULT_SSE_RETRY: Duration = Duration::from_secs(3);
1131    pub const DEFAULT_COMPLETED_CACHE_TTL: Duration = Duration::from_secs(60);
1132}
1133
1134impl Default for SessionConfig {
1135    fn default() -> Self {
1136        Self {
1137            channel_capacity: Self::DEFAULT_CHANNEL_CAPACITY,
1138            keep_alive: Some(Self::DEFAULT_KEEP_ALIVE),
1139            sse_retry: Some(Self::DEFAULT_SSE_RETRY),
1140            completed_cache_ttl: Self::DEFAULT_COMPLETED_CACHE_TTL,
1141        }
1142    }
1143}
1144
1145/// Create a new session with the given id and configuration.
1146///
1147/// This function will return a pair of [`LocalSessionHandle`] and [`LocalSessionWorker`].
1148///
1149/// You can run the [`LocalSessionWorker`] as a transport for mcp server. And use the [`LocalSessionHandle`] operate the session.
1150pub fn create_local_session(
1151    id: impl Into<SessionId>,
1152    config: SessionConfig,
1153) -> (LocalSessionHandle, LocalSessionWorker) {
1154    let id = id.into();
1155    let (event_tx, event_rx) = tokio::sync::mpsc::channel(config.channel_capacity);
1156    let (common_tx, _) = tokio::sync::mpsc::channel(config.channel_capacity);
1157    let common = CachedTx::new_common(common_tx);
1158    tracing::info!(session_id = ?id, "create new session");
1159    let handle = LocalSessionHandle {
1160        event_tx,
1161        id: id.clone(),
1162    };
1163    let session_worker = LocalSessionWorker {
1164        next_http_request_id: 0,
1165        id,
1166        tx_router: HashMap::new(),
1167        resource_router: HashMap::new(),
1168        common,
1169        shadow_txs: Vec::new(),
1170        event_rx,
1171        session_config: config.clone(),
1172    };
1173    (handle, session_worker)
1174}