Skip to main content

rmcp_soddygo/transport/streamable_http_server/session/
local.rs

1use std::{
2    collections::{HashMap, HashSet, VecDeque},
3    num::ParseIntError,
4    time::Duration,
5};
6
7use futures::Stream;
8use thiserror::Error;
9use tokio::sync::{
10    mpsc::{Receiver, Sender},
11    oneshot,
12};
13use tokio_stream::wrappers::ReceiverStream;
14use tracing::instrument;
15
16use crate::{
17    RoleServer,
18    model::{
19        CancelledNotificationParam, ClientJsonRpcMessage, ClientNotification, ClientRequest,
20        JsonRpcNotification, JsonRpcRequest, Notification, ProgressNotificationParam,
21        ProgressToken, RequestId, ServerJsonRpcMessage, ServerNotification,
22    },
23    transport::{
24        WorkerTransport,
25        common::server_side_http::{SessionId, session_id},
26        worker::{Worker, WorkerContext, WorkerQuitReason, WorkerSendRequest},
27    },
28};
29
30#[derive(Debug, Default)]
31#[non_exhaustive]
32pub struct LocalSessionManager {
33    pub sessions: tokio::sync::RwLock<HashMap<SessionId, LocalSessionHandle>>,
34    pub session_config: SessionConfig,
35}
36
37#[derive(Debug, Error)]
38#[non_exhaustive]
39pub enum LocalSessionManagerError {
40    #[error("Session not found: {0}")]
41    SessionNotFound(SessionId),
42    #[error("Session error: {0}")]
43    SessionError(#[from] SessionError),
44    #[error("Invalid event id: {0}")]
45    InvalidEventId(#[from] EventIdParseError),
46}
47impl SessionManager for LocalSessionManager {
48    type Error = LocalSessionManagerError;
49    type Transport = WorkerTransport<LocalSessionWorker>;
50    async fn create_session(&self) -> Result<(SessionId, Self::Transport), Self::Error> {
51        let id = session_id();
52        let (handle, worker) = create_local_session(id.clone(), self.session_config.clone());
53        self.sessions.write().await.insert(id.clone(), handle);
54        Ok((id, WorkerTransport::spawn(worker)))
55    }
56    async fn initialize_session(
57        &self,
58        id: &SessionId,
59        message: ClientJsonRpcMessage,
60    ) -> Result<ServerJsonRpcMessage, Self::Error> {
61        let sessions = self.sessions.read().await;
62        let handle = sessions
63            .get(id)
64            .ok_or(LocalSessionManagerError::SessionNotFound(id.clone()))?;
65        let response = handle.initialize(message).await?;
66        Ok(response)
67    }
68    async fn close_session(&self, id: &SessionId) -> Result<(), Self::Error> {
69        let mut sessions = self.sessions.write().await;
70        if let Some(handle) = sessions.remove(id) {
71            handle.close().await?;
72        }
73        Ok(())
74    }
75    async fn has_session(&self, id: &SessionId) -> Result<bool, Self::Error> {
76        let sessions = self.sessions.read().await;
77        Ok(sessions.contains_key(id))
78    }
79    async fn create_stream(
80        &self,
81        id: &SessionId,
82        message: ClientJsonRpcMessage,
83    ) -> Result<impl Stream<Item = ServerSseMessage> + Send + 'static, Self::Error> {
84        let sessions = self.sessions.read().await;
85        let handle = sessions
86            .get(id)
87            .ok_or(LocalSessionManagerError::SessionNotFound(id.clone()))?;
88        let receiver = handle.establish_request_wise_channel().await?;
89        handle
90            .push_message(message, receiver.http_request_id)
91            .await?;
92        Ok(ReceiverStream::new(receiver.inner))
93    }
94
95    async fn create_standalone_stream(
96        &self,
97        id: &SessionId,
98    ) -> Result<impl Stream<Item = ServerSseMessage> + Send + 'static, Self::Error> {
99        let sessions = self.sessions.read().await;
100        let handle = sessions
101            .get(id)
102            .ok_or(LocalSessionManagerError::SessionNotFound(id.clone()))?;
103        let receiver = handle.establish_common_channel().await?;
104        Ok(ReceiverStream::new(receiver.inner))
105    }
106
107    async fn resume(
108        &self,
109        id: &SessionId,
110        last_event_id: String,
111    ) -> Result<impl Stream<Item = ServerSseMessage> + Send + 'static, Self::Error> {
112        let sessions = self.sessions.read().await;
113        let handle = sessions
114            .get(id)
115            .ok_or(LocalSessionManagerError::SessionNotFound(id.clone()))?;
116        let receiver = handle.resume(last_event_id.parse()?).await?;
117        Ok(ReceiverStream::new(receiver.inner))
118    }
119
120    async fn accept_message(
121        &self,
122        id: &SessionId,
123        message: ClientJsonRpcMessage,
124    ) -> Result<(), Self::Error> {
125        let sessions = self.sessions.read().await;
126        let handle = sessions
127            .get(id)
128            .ok_or(LocalSessionManagerError::SessionNotFound(id.clone()))?;
129        handle.push_message(message, None).await?;
130        Ok(())
131    }
132}
133
134/// `<index>/request_id>`
135#[derive(Debug, Clone, PartialEq, Eq, Hash)]
136pub struct EventId {
137    http_request_id: Option<HttpRequestId>,
138    index: usize,
139}
140
141impl std::fmt::Display for EventId {
142    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
143        write!(f, "{}", self.index)?;
144        match &self.http_request_id {
145            Some(http_request_id) => write!(f, "/{http_request_id}"),
146            None => write!(f, ""),
147        }
148    }
149}
150
151#[derive(Debug, Clone, Error)]
152#[non_exhaustive]
153pub enum EventIdParseError {
154    #[error("Invalid index: {0}")]
155    InvalidIndex(ParseIntError),
156    #[error("Invalid numeric request id: {0}")]
157    InvalidNumericRequestId(ParseIntError),
158    #[error("Missing request id type")]
159    InvalidRequestIdType,
160    #[error("Missing request id")]
161    MissingRequestId,
162}
163
164impl std::str::FromStr for EventId {
165    type Err = EventIdParseError;
166    fn from_str(s: &str) -> Result<Self, Self::Err> {
167        if let Some((index, request_id)) = s.split_once("/") {
168            let index = usize::from_str(index).map_err(EventIdParseError::InvalidIndex)?;
169            let request_id = u64::from_str(request_id).map_err(EventIdParseError::InvalidIndex)?;
170            Ok(EventId {
171                http_request_id: Some(request_id),
172                index,
173            })
174        } else {
175            let index = usize::from_str(s).map_err(EventIdParseError::InvalidIndex)?;
176            Ok(EventId {
177                http_request_id: None,
178                index,
179            })
180        }
181    }
182}
183
184use super::{ServerSseMessage, SessionManager};
185
186struct CachedTx {
187    tx: Sender<ServerSseMessage>,
188    cache: VecDeque<ServerSseMessage>,
189    http_request_id: Option<HttpRequestId>,
190    capacity: usize,
191}
192
193impl CachedTx {
194    fn new(tx: Sender<ServerSseMessage>, http_request_id: Option<HttpRequestId>) -> Self {
195        Self {
196            cache: VecDeque::with_capacity(tx.capacity()),
197            capacity: tx.capacity(),
198            tx,
199            http_request_id,
200        }
201    }
202    fn new_common(tx: Sender<ServerSseMessage>) -> Self {
203        Self::new(tx, None)
204    }
205
206    fn next_event_id(&self) -> EventId {
207        let index = self.cache.back().map_or(0, |m| {
208            m.event_id
209                .as_deref()
210                .unwrap_or_default()
211                .parse::<EventId>()
212                .expect("valid event id")
213                .index
214                + 1
215        });
216        EventId {
217            http_request_id: self.http_request_id,
218            index,
219        }
220    }
221
222    async fn send(&mut self, message: ServerJsonRpcMessage) {
223        let event_id = self.next_event_id();
224        let message = ServerSseMessage::new(event_id.to_string(), message);
225        self.cache_and_send(message).await;
226    }
227
228    async fn send_priming(&mut self, retry: Duration) {
229        let event_id = self.next_event_id();
230        let message = ServerSseMessage::priming(event_id.to_string(), retry);
231        self.cache_and_send(message).await;
232    }
233
234    async fn cache_and_send(&mut self, message: ServerSseMessage) {
235        if self.cache.len() >= self.capacity {
236            self.cache.pop_front();
237            self.cache.push_back(message.clone());
238        } else {
239            self.cache.push_back(message.clone());
240        }
241        let _ = self.tx.send(message).await.inspect_err(|e| {
242            let event_id = &e.0.event_id;
243            tracing::trace!(?event_id, "trying to send message in a closed session")
244        });
245    }
246
247    async fn sync(&mut self, index: usize) -> Result<(), SessionError> {
248        let Some(front) = self.cache.front() else {
249            return Ok(());
250        };
251        let front_event_id = front
252            .event_id
253            .as_deref()
254            .unwrap_or_default()
255            .parse::<EventId>()?;
256        let sync_index = index.saturating_sub(front_event_id.index);
257        if sync_index > self.cache.len() {
258            // invalid index
259            return Err(SessionError::InvalidEventId);
260        }
261        for message in self.cache.iter().skip(sync_index) {
262            let send_result = self.tx.send(message.clone()).await;
263            if send_result.is_err() {
264                let event_id: EventId = message.event_id.as_deref().unwrap_or_default().parse()?;
265                return Err(SessionError::ChannelClosed(Some(event_id.index as u64)));
266            }
267        }
268        Ok(())
269    }
270}
271
272struct HttpRequestWise {
273    resources: HashSet<ResourceKey>,
274    tx: CachedTx,
275}
276
277type HttpRequestId = u64;
278#[derive(Debug, Clone, Hash, PartialEq, Eq)]
279enum ResourceKey {
280    McpRequestId(RequestId),
281    ProgressToken(ProgressToken),
282}
283
284pub struct LocalSessionWorker {
285    id: SessionId,
286    next_http_request_id: HttpRequestId,
287    tx_router: HashMap<HttpRequestId, HttpRequestWise>,
288    resource_router: HashMap<ResourceKey, HttpRequestId>,
289    common: CachedTx,
290    /// Shadow senders for secondary SSE streams (e.g. from POST EventSource
291    /// reconnections). These keep the HTTP connections alive via SSE keep-alive
292    /// without receiving notifications, preventing MCP clients from entering
293    /// infinite reconnect loops when multiple EventSource connections compete
294    /// to replace the common channel.
295    shadow_txs: Vec<Sender<ServerSseMessage>>,
296    event_rx: Receiver<SessionEvent>,
297    session_config: SessionConfig,
298}
299
300impl LocalSessionWorker {
301    pub fn id(&self) -> &SessionId {
302        &self.id
303    }
304}
305
306#[derive(Debug, Error)]
307#[non_exhaustive]
308pub enum SessionError {
309    #[error("Invalid request id: {0}")]
310    DuplicatedRequestId(HttpRequestId),
311    #[error("Channel closed: {0:?}")]
312    ChannelClosed(Option<HttpRequestId>),
313    #[error("Cannot parse event id: {0}")]
314    EventIdParseError(#[from] EventIdParseError),
315    #[error("Session service terminated")]
316    SessionServiceTerminated,
317    #[error("Invalid event id")]
318    InvalidEventId,
319    #[error("IO error: {0}")]
320    Io(#[from] std::io::Error),
321}
322
323impl From<SessionError> for std::io::Error {
324    fn from(value: SessionError) -> Self {
325        match value {
326            SessionError::Io(io) => io,
327            _ => std::io::Error::other(format!("Session error: {value}")),
328        }
329    }
330}
331
332enum OutboundChannel {
333    RequestWise { id: HttpRequestId, close: bool },
334    Common,
335}
336#[derive(Debug)]
337#[non_exhaustive]
338pub struct StreamableHttpMessageReceiver {
339    pub http_request_id: Option<HttpRequestId>,
340    pub inner: Receiver<ServerSseMessage>,
341}
342
343impl LocalSessionWorker {
344    fn unregister_resource(&mut self, resource: &ResourceKey) {
345        if let Some(http_request_id) = self.resource_router.remove(resource) {
346            tracing::trace!(?resource, http_request_id, "unregister resource");
347            if let Some(channel) = self.tx_router.get_mut(&http_request_id) {
348                // It's okey to do so, since we don't handle batch json rpc request anymore
349                // and this can be refactored after the batch request is removed in the coming version.
350                if channel.resources.is_empty() || matches!(resource, ResourceKey::McpRequestId(_))
351                {
352                    tracing::debug!(http_request_id, "close http request wise channel");
353                    if let Some(channel) = self.tx_router.remove(&http_request_id) {
354                        for resource in channel.resources {
355                            self.resource_router.remove(&resource);
356                        }
357                    }
358                }
359            } else {
360                tracing::warn!(http_request_id, "http request wise channel not found");
361            }
362        }
363    }
364    fn register_resource(&mut self, resource: ResourceKey, http_request_id: HttpRequestId) {
365        tracing::trace!(?resource, http_request_id, "register resource");
366        if let Some(channel) = self.tx_router.get_mut(&http_request_id) {
367            channel.resources.insert(resource.clone());
368            self.resource_router.insert(resource, http_request_id);
369        }
370    }
371    fn register_request(
372        &mut self,
373        request: &JsonRpcRequest<ClientRequest>,
374        http_request_id: HttpRequestId,
375    ) {
376        use crate::model::GetMeta;
377        self.register_resource(
378            ResourceKey::McpRequestId(request.id.clone()),
379            http_request_id,
380        );
381        if let Some(progress_token) = request.request.get_meta().get_progress_token() {
382            self.register_resource(
383                ResourceKey::ProgressToken(progress_token.clone()),
384                http_request_id,
385            );
386        }
387    }
388    fn catch_cancellation_notification(
389        &mut self,
390        notification: &JsonRpcNotification<ClientNotification>,
391    ) {
392        if let ClientNotification::CancelledNotification(n) = &notification.notification {
393            let request_id = n.params.request_id.clone();
394            let resource = ResourceKey::McpRequestId(request_id);
395            self.unregister_resource(&resource);
396        }
397    }
398    fn next_http_request_id(&mut self) -> HttpRequestId {
399        let id = self.next_http_request_id;
400        self.next_http_request_id = self.next_http_request_id.wrapping_add(1);
401        id
402    }
403    async fn establish_request_wise_channel(
404        &mut self,
405    ) -> Result<StreamableHttpMessageReceiver, SessionError> {
406        let http_request_id = self.next_http_request_id();
407        let (tx, rx) = tokio::sync::mpsc::channel(self.session_config.channel_capacity);
408        self.tx_router.insert(
409            http_request_id,
410            HttpRequestWise {
411                resources: Default::default(),
412                tx: CachedTx::new(tx, Some(http_request_id)),
413            },
414        );
415        tracing::debug!(http_request_id, "establish new request wise channel");
416        Ok(StreamableHttpMessageReceiver {
417            http_request_id: Some(http_request_id),
418            inner: rx,
419        })
420    }
421    fn resolve_outbound_channel(&self, message: &ServerJsonRpcMessage) -> OutboundChannel {
422        match &message {
423            ServerJsonRpcMessage::Request(_) => OutboundChannel::Common,
424            ServerJsonRpcMessage::Notification(JsonRpcNotification {
425                notification:
426                    ServerNotification::ProgressNotification(Notification {
427                        params: ProgressNotificationParam { progress_token, .. },
428                        ..
429                    }),
430                ..
431            }) => {
432                let id = self
433                    .resource_router
434                    .get(&ResourceKey::ProgressToken(progress_token.clone()));
435
436                if let Some(id) = id {
437                    OutboundChannel::RequestWise {
438                        id: *id,
439                        close: false,
440                    }
441                } else {
442                    OutboundChannel::Common
443                }
444            }
445            ServerJsonRpcMessage::Notification(JsonRpcNotification {
446                notification:
447                    ServerNotification::CancelledNotification(Notification {
448                        params: CancelledNotificationParam { request_id, .. },
449                        ..
450                    }),
451                ..
452            }) => {
453                if let Some(id) = self
454                    .resource_router
455                    .get(&ResourceKey::McpRequestId(request_id.clone()))
456                {
457                    OutboundChannel::RequestWise {
458                        id: *id,
459                        close: false,
460                    }
461                } else {
462                    OutboundChannel::Common
463                }
464            }
465            ServerJsonRpcMessage::Notification(_) => OutboundChannel::Common,
466            ServerJsonRpcMessage::Response(json_rpc_response) => {
467                if let Some(id) = self
468                    .resource_router
469                    .get(&ResourceKey::McpRequestId(json_rpc_response.id.clone()))
470                {
471                    OutboundChannel::RequestWise {
472                        id: *id,
473                        close: true,
474                    }
475                } else {
476                    OutboundChannel::Common
477                }
478            }
479            ServerJsonRpcMessage::Error(json_rpc_error) => {
480                if let Some(id) = self
481                    .resource_router
482                    .get(&ResourceKey::McpRequestId(json_rpc_error.id.clone()))
483                {
484                    OutboundChannel::RequestWise {
485                        id: *id,
486                        close: true,
487                    }
488                } else {
489                    OutboundChannel::Common
490                }
491            }
492        }
493    }
494    async fn handle_server_message(
495        &mut self,
496        message: ServerJsonRpcMessage,
497    ) -> Result<(), SessionError> {
498        let outbound_channel = self.resolve_outbound_channel(&message);
499        match outbound_channel {
500            OutboundChannel::RequestWise { id, close } => {
501                if let Some(request_wise) = self.tx_router.get_mut(&id) {
502                    request_wise.tx.send(message).await;
503                    if close {
504                        if let Some(channel) = self.tx_router.remove(&id) {
505                            for resource in channel.resources {
506                                self.resource_router.remove(&resource);
507                            }
508                        }
509                    }
510                } else {
511                    return Err(SessionError::ChannelClosed(Some(id)));
512                }
513            }
514            OutboundChannel::Common => self.common.send(message).await,
515        }
516        Ok(())
517    }
518    async fn resume(
519        &mut self,
520        last_event_id: EventId,
521    ) -> Result<StreamableHttpMessageReceiver, SessionError> {
522        // Clean up closed shadow senders before processing
523        self.shadow_txs.retain(|tx| !tx.is_closed());
524
525        match last_event_id.http_request_id {
526            Some(http_request_id) => {
527                if let Some(request_wise) = self.tx_router.get_mut(&http_request_id) {
528                    // Resume existing request-wise channel
529                    let channel = tokio::sync::mpsc::channel(self.session_config.channel_capacity);
530                    let (tx, rx) = channel;
531                    request_wise.tx.tx = tx;
532                    let index = last_event_id.index;
533                    // sync messages after index
534                    request_wise.tx.sync(index).await?;
535                    Ok(StreamableHttpMessageReceiver {
536                        http_request_id: Some(http_request_id),
537                        inner: rx,
538                    })
539                } else {
540                    // Request-wise channel completed (POST response already delivered).
541                    // The client's EventSource is reconnecting after the POST SSE stream
542                    // ended. Fall through to common channel handling below.
543                    tracing::debug!(
544                        http_request_id,
545                        "Request-wise channel completed, falling back to common channel"
546                    );
547                    self.resume_or_shadow_common(last_event_id.index).await
548                }
549            }
550            None => self.resume_or_shadow_common(last_event_id.index).await,
551        }
552    }
553
554    /// Resume the common channel, or create a shadow stream if the primary is
555    /// still active.
556    ///
557    /// When the primary common channel is dead (receiver dropped), replace it
558    /// so this stream becomes the new primary notification channel. Cached
559    /// messages are replayed from `last_event_index` so the client receives
560    /// any events it missed (including server-initiated requests).
561    ///
562    /// When the primary is still active, create a "shadow" stream — an idle SSE
563    /// connection kept alive by keep-alive pings. This prevents multiple
564    /// EventSource connections (e.g. from POST response reconnections) from
565    /// killing each other by repeatedly replacing the common channel sender.
566    async fn resume_or_shadow_common(
567        &mut self,
568        last_event_index: usize,
569    ) -> Result<StreamableHttpMessageReceiver, SessionError> {
570        let is_replacing_dead_primary = self.common.tx.is_closed();
571        let capacity = if is_replacing_dead_primary {
572            self.session_config.channel_capacity
573        } else {
574            1 // Shadow streams only need keep-alive pings
575        };
576        let (tx, rx) = tokio::sync::mpsc::channel(capacity);
577        if is_replacing_dead_primary {
578            // Primary common channel is dead — replace it.
579            tracing::debug!("Replacing dead common channel with new primary");
580            self.common.tx = tx;
581            // Replay cached messages from where the client left off so
582            // server-initiated requests and notifications are not lost.
583            self.common.sync(last_event_index).await?;
584        } else {
585            // Primary common channel is still active. Create a shadow stream
586            // that stays alive via SSE keep-alive but doesn't receive
587            // notifications. This prevents competing EventSource connections
588            // from killing each other's channels.
589            const MAX_SHADOW_STREAMS: usize = 32;
590
591            if self.shadow_txs.len() >= MAX_SHADOW_STREAMS {
592                tracing::warn!(
593                    shadow_count = self.shadow_txs.len(),
594                    "Shadow stream limit reached, dropping oldest"
595                );
596                self.shadow_txs.remove(0);
597            }
598            tracing::debug!(
599                shadow_count = self.shadow_txs.len(),
600                "Common channel active, creating shadow stream"
601            );
602            self.shadow_txs.push(tx);
603        }
604        Ok(StreamableHttpMessageReceiver {
605            http_request_id: None,
606            inner: rx,
607        })
608    }
609
610    async fn close_sse_stream(
611        &mut self,
612        http_request_id: Option<HttpRequestId>,
613        retry_interval: Option<Duration>,
614    ) -> Result<(), SessionError> {
615        match http_request_id {
616            // Close a request-wise stream
617            Some(id) => {
618                let request_wise = self
619                    .tx_router
620                    .get_mut(&id)
621                    .ok_or(SessionError::ChannelClosed(Some(id)))?;
622
623                // Send priming event if retry interval is specified
624                if let Some(interval) = retry_interval {
625                    request_wise.tx.send_priming(interval).await;
626                }
627
628                // Close the stream by dropping the sender
629                let (tx, _rx) = tokio::sync::mpsc::channel(1);
630                request_wise.tx.tx = tx;
631
632                tracing::debug!(
633                    http_request_id = id,
634                    "closed SSE stream for server-initiated disconnection"
635                );
636                Ok(())
637            }
638            // Close the standalone (common) stream
639            None => {
640                // Send priming event if retry interval is specified
641                if let Some(interval) = retry_interval {
642                    self.common.send_priming(interval).await;
643                }
644
645                // Close the stream by dropping the sender
646                let (tx, _rx) = tokio::sync::mpsc::channel(1);
647                self.common.tx = tx;
648
649                // Also close all shadow streams
650                self.shadow_txs.clear();
651
652                tracing::debug!("closed standalone SSE stream for server-initiated disconnection");
653                Ok(())
654            }
655        }
656    }
657}
658
659#[derive(Debug)]
660#[non_exhaustive]
661pub enum SessionEvent {
662    ClientMessage {
663        message: ClientJsonRpcMessage,
664        http_request_id: Option<HttpRequestId>,
665    },
666    EstablishRequestWiseChannel {
667        responder: oneshot::Sender<Result<StreamableHttpMessageReceiver, SessionError>>,
668    },
669    CloseRequestWiseChannel {
670        id: HttpRequestId,
671        responder: oneshot::Sender<Result<(), SessionError>>,
672    },
673    Resume {
674        last_event_id: EventId,
675        responder: oneshot::Sender<Result<StreamableHttpMessageReceiver, SessionError>>,
676    },
677    InitializeRequest {
678        request: ClientJsonRpcMessage,
679        responder: oneshot::Sender<Result<ServerJsonRpcMessage, SessionError>>,
680    },
681    Close,
682    CloseSseStream {
683        /// The HTTP request ID to close. If `None`, closes the standalone (common) stream.
684        http_request_id: Option<HttpRequestId>,
685        /// Optional retry interval. If provided, a priming event is sent before closing.
686        retry_interval: Option<Duration>,
687        responder: oneshot::Sender<Result<(), SessionError>>,
688    },
689}
690
/// Why a session worker stopped.
///
/// NOTE(review): the code that produces these values is not visible here; the
/// per-variant notes are inferred from the names and should be confirmed.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum SessionQuitReason {
    /// Presumably: the server-side service ended.
    ServiceTerminated,
    /// Presumably: the client ended the session.
    ClientTerminated,
    /// Presumably: an initialize request was expected but something else arrived.
    ExpectInitializeRequest,
    /// Presumably: an initialize response was expected but something else arrived.
    ExpectInitializeResponse,
    /// Presumably: the worker was cancelled.
    Cancelled,
}
700
701#[derive(Debug, Clone)]
702pub struct LocalSessionHandle {
703    id: SessionId,
704    // after all event_tx drop, inner task will be terminated
705    event_tx: Sender<SessionEvent>,
706}
707
708impl LocalSessionHandle {
709    /// Get the session id
710    pub fn id(&self) -> &SessionId {
711        &self.id
712    }
713
714    /// Close the session
715    pub async fn close(&self) -> Result<(), SessionError> {
716        self.event_tx
717            .send(SessionEvent::Close)
718            .await
719            .map_err(|_| SessionError::SessionServiceTerminated)?;
720        Ok(())
721    }
722
723    /// Send a message to the session
724    pub async fn push_message(
725        &self,
726        message: ClientJsonRpcMessage,
727        http_request_id: Option<HttpRequestId>,
728    ) -> Result<(), SessionError> {
729        self.event_tx
730            .send(SessionEvent::ClientMessage {
731                message,
732                http_request_id,
733            })
734            .await
735            .map_err(|_| SessionError::SessionServiceTerminated)?;
736        Ok(())
737    }
738
739    /// establish a channel for a http-request, the corresponded message from server will be
740    /// sent through this channel. The channel will be closed when the request is completed,
741    /// or you can close it manually by calling [`LocalSessionHandle::close_request_wise_channel`].
742    pub async fn establish_request_wise_channel(
743        &self,
744    ) -> Result<StreamableHttpMessageReceiver, SessionError> {
745        let (tx, rx) = tokio::sync::oneshot::channel();
746        self.event_tx
747            .send(SessionEvent::EstablishRequestWiseChannel { responder: tx })
748            .await
749            .map_err(|_| SessionError::SessionServiceTerminated)?;
750        rx.await
751            .map_err(|_| SessionError::SessionServiceTerminated)?
752    }
753
754    /// close the http-request wise channel.
755    pub async fn close_request_wise_channel(
756        &self,
757        request_id: HttpRequestId,
758    ) -> Result<(), SessionError> {
759        let (tx, rx) = tokio::sync::oneshot::channel();
760        self.event_tx
761            .send(SessionEvent::CloseRequestWiseChannel {
762                id: request_id,
763                responder: tx,
764            })
765            .await
766            .map_err(|_| SessionError::SessionServiceTerminated)?;
767        rx.await
768            .map_err(|_| SessionError::SessionServiceTerminated)?
769    }
770
771    /// Establish a common channel for general purpose messages.
772    pub async fn establish_common_channel(
773        &self,
774    ) -> Result<StreamableHttpMessageReceiver, SessionError> {
775        let (tx, rx) = tokio::sync::oneshot::channel();
776        self.event_tx
777            .send(SessionEvent::Resume {
778                last_event_id: EventId {
779                    http_request_id: None,
780                    index: 0,
781                },
782                responder: tx,
783            })
784            .await
785            .map_err(|_| SessionError::SessionServiceTerminated)?;
786        rx.await
787            .map_err(|_| SessionError::SessionServiceTerminated)?
788    }
789
790    /// Resume streaming response by the last event id. This is suitable for both request wise and common channel.
791    pub async fn resume(
792        &self,
793        last_event_id: EventId,
794    ) -> Result<StreamableHttpMessageReceiver, SessionError> {
795        let (tx, rx) = tokio::sync::oneshot::channel();
796        self.event_tx
797            .send(SessionEvent::Resume {
798                last_event_id,
799                responder: tx,
800            })
801            .await
802            .map_err(|_| SessionError::SessionServiceTerminated)?;
803        rx.await
804            .map_err(|_| SessionError::SessionServiceTerminated)?
805    }
806
807    /// Send an initialize request to the session. And wait for the initialized response.
808    ///
809    /// This is used to establish a session with the server.
810    pub async fn initialize(
811        &self,
812        request: ClientJsonRpcMessage,
813    ) -> Result<ServerJsonRpcMessage, SessionError> {
814        let (tx, rx) = tokio::sync::oneshot::channel();
815        self.event_tx
816            .send(SessionEvent::InitializeRequest {
817                request,
818                responder: tx,
819            })
820            .await
821            .map_err(|_| SessionError::SessionServiceTerminated)?;
822        rx.await
823            .map_err(|_| SessionError::SessionServiceTerminated)?
824    }
825
826    /// Close an SSE stream for a specific request.
827    ///
828    /// This closes the SSE connection for a POST request stream, but keeps the session
829    /// and message cache active. Clients can reconnect using the `Last-Event-ID` header
830    /// via a GET request to resume receiving messages.
831    ///
832    /// # Arguments
833    ///
834    /// * `http_request_id` - The HTTP request ID of the stream to close
835    /// * `retry_interval` - Optional retry interval. If provided, a priming event is sent
836    pub async fn close_sse_stream(
837        &self,
838        http_request_id: HttpRequestId,
839        retry_interval: Option<Duration>,
840    ) -> Result<(), SessionError> {
841        let (tx, rx) = tokio::sync::oneshot::channel();
842        self.event_tx
843            .send(SessionEvent::CloseSseStream {
844                http_request_id: Some(http_request_id),
845                retry_interval,
846                responder: tx,
847            })
848            .await
849            .map_err(|_| SessionError::SessionServiceTerminated)?;
850        rx.await
851            .map_err(|_| SessionError::SessionServiceTerminated)?
852    }
853
854    /// Close the standalone SSE stream.
855    ///
856    /// This closes the standalone SSE connection (established via GET request),
857    /// but keeps the session and message cache active. Clients can reconnect using
858    /// the `Last-Event-ID` header via a GET request to resume receiving messages.
859    ///
860    /// # Arguments
861    ///
862    /// * `retry_interval` - Optional retry interval. If provided, a priming event is sent
863    pub async fn close_standalone_sse_stream(
864        &self,
865        retry_interval: Option<Duration>,
866    ) -> Result<(), SessionError> {
867        let (tx, rx) = tokio::sync::oneshot::channel();
868        self.event_tx
869            .send(SessionEvent::CloseSseStream {
870                http_request_id: None,
871                retry_interval,
872                responder: tx,
873            })
874            .await
875            .map_err(|_| SessionError::SessionServiceTerminated)?;
876        rx.await
877            .map_err(|_| SessionError::SessionServiceTerminated)?
878    }
879}
880
881pub type SessionTransport = WorkerTransport<LocalSessionWorker>;
882
883#[allow(clippy::large_enum_variant)]
884#[derive(Debug, Error)]
885#[non_exhaustive]
886pub enum LocalSessionWorkerError {
887    #[error("transport terminated")]
888    TransportTerminated,
889    #[error("unexpected message: {0:?}")]
890    UnexpectedEvent(SessionEvent),
891    #[error("fail to send initialize request {0}")]
892    FailToSendInitializeRequest(SessionError),
893    #[error("fail to handle message: {0}")]
894    FailToHandleMessage(SessionError),
895    #[error("keep alive timeout after {}ms", _0.as_millis())]
896    KeepAliveTimeout(Duration),
897    #[error("Transport closed")]
898    TransportClosed,
899    #[error("Tokio join error {0}")]
900    TokioJoinError(#[from] tokio::task::JoinError),
901}
902impl Worker for LocalSessionWorker {
903    type Error = LocalSessionWorkerError;
904    type Role = RoleServer;
905    fn err_closed() -> Self::Error {
906        LocalSessionWorkerError::TransportClosed
907    }
908    fn err_join(e: tokio::task::JoinError) -> Self::Error {
909        LocalSessionWorkerError::TokioJoinError(e)
910    }
911    fn config(&self) -> crate::transport::worker::WorkerConfig {
912        crate::transport::worker::WorkerConfig {
913            name: Some(format!("streamable-http-session-{}", self.id)),
914            channel_buffer_capacity: self.session_config.channel_capacity,
915        }
916    }
917    #[instrument(name = "streamable_http_session", skip_all, fields(id = self.id.as_ref()))]
918    async fn run(
919        mut self,
920        mut context: WorkerContext<Self>,
921    ) -> Result<(), WorkerQuitReason<Self::Error>> {
922        enum InnerEvent {
923            FromHttpService(SessionEvent),
924            FromHandler(WorkerSendRequest<LocalSessionWorker>),
925        }
926        // waiting for initialize request
927        let evt = self.event_rx.recv().await.ok_or_else(|| {
928            WorkerQuitReason::fatal(
929                LocalSessionWorkerError::TransportTerminated,
930                "get initialize request",
931            )
932        })?;
933        let SessionEvent::InitializeRequest { request, responder } = evt else {
934            return Err(WorkerQuitReason::fatal(
935                LocalSessionWorkerError::UnexpectedEvent(evt),
936                "get initialize request",
937            ));
938        };
939        context.send_to_handler(request).await?;
940        let send_initialize_response = context.recv_from_handler().await?;
941        responder
942            .send(Ok(send_initialize_response.message))
943            .map_err(|_| {
944                WorkerQuitReason::fatal(
945                    LocalSessionWorkerError::FailToSendInitializeRequest(
946                        SessionError::SessionServiceTerminated,
947                    ),
948                    "send initialize response",
949                )
950            })?;
951        send_initialize_response
952            .responder
953            .send(Ok(()))
954            .map_err(|_| WorkerQuitReason::HandlerTerminated)?;
955        let ct = context.cancellation_token.clone();
956        let keep_alive = self.session_config.keep_alive.unwrap_or(Duration::MAX);
957        loop {
958            let keep_alive_timeout = tokio::time::sleep(keep_alive);
959            let event = tokio::select! {
960                event = self.event_rx.recv() => {
961                    if let Some(event) = event {
962                        InnerEvent::FromHttpService(event)
963                    } else {
964                        return Err(WorkerQuitReason::fatal(LocalSessionWorkerError::TransportTerminated, "waiting next session event"))
965                    }
966                },
967                from_handler = context.recv_from_handler() => {
968                    InnerEvent::FromHandler(from_handler?)
969                }
970                _ = ct.cancelled() => {
971                    return Err(WorkerQuitReason::Cancelled)
972                }
973                _ = keep_alive_timeout => {
974                    return Err(WorkerQuitReason::fatal(LocalSessionWorkerError::KeepAliveTimeout(keep_alive), "poll next session event"))
975                }
976            };
977            match event {
978                InnerEvent::FromHandler(WorkerSendRequest { message, responder }) => {
979                    // catch response
980                    let to_unregister = match &message {
981                        crate::model::JsonRpcMessage::Response(json_rpc_response) => {
982                            let request_id = json_rpc_response.id.clone();
983                            Some(ResourceKey::McpRequestId(request_id))
984                        }
985                        crate::model::JsonRpcMessage::Error(json_rpc_error) => {
986                            let request_id = json_rpc_error.id.clone();
987                            Some(ResourceKey::McpRequestId(request_id))
988                        }
989                        _ => {
990                            None
991                            // no need to unregister resource
992                        }
993                    };
994                    let handle_result = self
995                        .handle_server_message(message)
996                        .await
997                        .map_err(LocalSessionWorkerError::FailToHandleMessage);
998                    let _ = responder.send(handle_result).inspect_err(|error| {
999                        tracing::warn!(?error, "failed to send message to http service handler");
1000                    });
1001                    if let Some(to_unregister) = to_unregister {
1002                        self.unregister_resource(&to_unregister);
1003                    }
1004                }
1005                InnerEvent::FromHttpService(SessionEvent::ClientMessage {
1006                    message: json_rpc_message,
1007                    http_request_id,
1008                }) => {
1009                    match &json_rpc_message {
1010                        crate::model::JsonRpcMessage::Request(request) => {
1011                            if let Some(http_request_id) = http_request_id {
1012                                self.register_request(request, http_request_id)
1013                            }
1014                        }
1015                        crate::model::JsonRpcMessage::Notification(notification) => {
1016                            self.catch_cancellation_notification(notification)
1017                        }
1018                        _ => {}
1019                    }
1020                    context.send_to_handler(json_rpc_message).await?;
1021                }
1022                InnerEvent::FromHttpService(SessionEvent::EstablishRequestWiseChannel {
1023                    responder,
1024                }) => {
1025                    let handle_result = self.establish_request_wise_channel().await;
1026                    let _ = responder.send(handle_result);
1027                }
1028                InnerEvent::FromHttpService(SessionEvent::CloseRequestWiseChannel {
1029                    id,
1030                    responder,
1031                }) => {
1032                    let _handle_result = self.tx_router.remove(&id);
1033                    let _ = responder.send(Ok(()));
1034                }
1035                InnerEvent::FromHttpService(SessionEvent::Resume {
1036                    last_event_id,
1037                    responder,
1038                }) => {
1039                    let handle_result = self.resume(last_event_id).await;
1040                    let _ = responder.send(handle_result);
1041                }
1042                InnerEvent::FromHttpService(SessionEvent::Close) => {
1043                    return Err(WorkerQuitReason::TransportClosed);
1044                }
1045                InnerEvent::FromHttpService(SessionEvent::CloseSseStream {
1046                    http_request_id,
1047                    retry_interval,
1048                    responder,
1049                }) => {
1050                    let handle_result =
1051                        self.close_sse_stream(http_request_id, retry_interval).await;
1052                    let _ = responder.send(handle_result);
1053                }
1054                _ => {
1055                    // ignore
1056                }
1057            }
1058        }
1059    }
1060}
1061
/// Configuration for a local streamable-HTTP session.
///
/// The struct is `#[non_exhaustive]`, so code outside this crate should start
/// from [`SessionConfig::default`] and then adjust the public fields.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct SessionConfig {
    /// Capacity of the session's bounded message channels.
    ///
    /// Defaults to [`SessionConfig::DEFAULT_CHANNEL_CAPACITY`] (16). Must be
    /// greater than zero, because tokio's bounded `mpsc` channels do not support
    /// a capacity of 0.
    pub channel_capacity: usize,
    /// The session will be closed after this duration of inactivity.
    ///
    /// Inactivity means the session worker received neither an event from the
    /// HTTP service nor a message from the server handler during that period.
    ///
    /// This serves as a safety net for cleaning up sessions whose HTTP
    /// connections have silently dropped (e.g., due to an HTTP/2
    /// `RST_STREAM`). Without a timeout, such sessions become zombies.
    /// The session worker keeps running indefinitely because the session
    /// handle's sender is still held in the session manager, preventing
    /// the worker's event channel from closing.
    ///
    /// Defaults to [`SessionConfig::DEFAULT_KEEP_ALIVE`] (5 minutes). Set to
    /// `None` to disable (not recommended for long-running servers behind proxies).
    pub keep_alive: Option<Duration>,
}

impl SessionConfig {
    /// Default value of [`SessionConfig::channel_capacity`].
    pub const DEFAULT_CHANNEL_CAPACITY: usize = 16;
    /// Default value of [`SessionConfig::keep_alive`]: five minutes.
    pub const DEFAULT_KEEP_ALIVE: Duration = Duration::from_secs(300);
}

impl Default for SessionConfig {
    /// A capacity of [`SessionConfig::DEFAULT_CHANNEL_CAPACITY`] and a keep-alive
    /// of [`SessionConfig::DEFAULT_KEEP_ALIVE`].
    fn default() -> Self {
        Self {
            channel_capacity: Self::DEFAULT_CHANNEL_CAPACITY,
            keep_alive: Some(Self::DEFAULT_KEEP_ALIVE),
        }
    }
}
1094
1095/// Create a new session with the given id and configuration.
1096///
1097/// This function will return a pair of [`LocalSessionHandle`] and [`LocalSessionWorker`].
1098///
1099/// You can run the [`LocalSessionWorker`] as a transport for mcp server. And use the [`LocalSessionHandle`] operate the session.
1100pub fn create_local_session(
1101    id: impl Into<SessionId>,
1102    config: SessionConfig,
1103) -> (LocalSessionHandle, LocalSessionWorker) {
1104    let id = id.into();
1105    let (event_tx, event_rx) = tokio::sync::mpsc::channel(config.channel_capacity);
1106    let (common_tx, _) = tokio::sync::mpsc::channel(config.channel_capacity);
1107    let common = CachedTx::new_common(common_tx);
1108    tracing::info!(session_id = ?id, "create new session");
1109    let handle = LocalSessionHandle {
1110        event_tx,
1111        id: id.clone(),
1112    };
1113    let session_worker = LocalSessionWorker {
1114        next_http_request_id: 0,
1115        id,
1116        tx_router: HashMap::new(),
1117        resource_router: HashMap::new(),
1118        common,
1119        shadow_txs: Vec::new(),
1120        event_rx,
1121        session_config: config.clone(),
1122    };
1123    (handle, session_worker)
1124}