// rmcp/transport/streamable_http_server/session/local.rs

1use std::{
2    collections::{HashMap, HashSet, VecDeque},
3    num::ParseIntError,
4    time::Duration,
5};
6
7use futures::Stream;
8use thiserror::Error;
9use tokio::sync::{
10    mpsc::{Receiver, Sender},
11    oneshot,
12};
13use tokio_stream::wrappers::ReceiverStream;
14use tracing::instrument;
15
16use crate::{
17    RoleServer,
18    model::{
19        CancelledNotificationParam, ClientJsonRpcMessage, ClientNotification, ClientRequest,
20        JsonRpcNotification, JsonRpcRequest, Notification, ProgressNotificationParam,
21        ProgressToken, RequestId, ServerJsonRpcMessage, ServerNotification,
22    },
23    transport::{
24        WorkerTransport,
25        common::server_side_http::{SessionId, session_id},
26        worker::{Worker, WorkerContext, WorkerQuitReason, WorkerSendRequest},
27    },
28};
29
/// A session manager that keeps every session in process memory.
///
/// Each session pairs a [`LocalSessionHandle`] (kept in the map) with a
/// [`LocalSessionWorker`] task spawned as a [`WorkerTransport`].
#[derive(Debug, Default)]
#[non_exhaustive]
pub struct LocalSessionManager {
    /// Live sessions, keyed by their generated [`SessionId`].
    pub sessions: tokio::sync::RwLock<HashMap<SessionId, LocalSessionHandle>>,
    /// Configuration applied to every newly created session.
    pub session_config: SessionConfig,
}
36
/// Errors surfaced by [`LocalSessionManager`] operations.
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum LocalSessionManagerError {
    /// No session is registered under the given id.
    #[error("Session not found: {0}")]
    SessionNotFound(SessionId),
    /// The session worker reported an error (or has terminated).
    #[error("Session error: {0}")]
    SessionError(#[from] SessionError),
    /// The `Last-Event-ID` value supplied on resume could not be parsed.
    #[error("Invalid event id: {0}")]
    InvalidEventId(#[from] EventIdParseError),
}
47impl SessionManager for LocalSessionManager {
48    type Error = LocalSessionManagerError;
49    type Transport = WorkerTransport<LocalSessionWorker>;
50    async fn create_session(&self) -> Result<(SessionId, Self::Transport), Self::Error> {
51        let id = session_id();
52        let (handle, worker) = create_local_session(id.clone(), self.session_config.clone());
53        self.sessions.write().await.insert(id.clone(), handle);
54        Ok((id, WorkerTransport::spawn(worker)))
55    }
56    async fn initialize_session(
57        &self,
58        id: &SessionId,
59        message: ClientJsonRpcMessage,
60    ) -> Result<ServerJsonRpcMessage, Self::Error> {
61        let sessions = self.sessions.read().await;
62        let handle = sessions
63            .get(id)
64            .ok_or(LocalSessionManagerError::SessionNotFound(id.clone()))?;
65        let response = handle.initialize(message).await?;
66        Ok(response)
67    }
68    async fn close_session(&self, id: &SessionId) -> Result<(), Self::Error> {
69        let mut sessions = self.sessions.write().await;
70        if let Some(handle) = sessions.remove(id) {
71            handle.close().await?;
72        }
73        Ok(())
74    }
75    async fn has_session(&self, id: &SessionId) -> Result<bool, Self::Error> {
76        let sessions = self.sessions.read().await;
77        Ok(sessions.contains_key(id))
78    }
79    async fn create_stream(
80        &self,
81        id: &SessionId,
82        message: ClientJsonRpcMessage,
83    ) -> Result<impl Stream<Item = ServerSseMessage> + Send + 'static, Self::Error> {
84        let sessions = self.sessions.read().await;
85        let handle = sessions
86            .get(id)
87            .ok_or(LocalSessionManagerError::SessionNotFound(id.clone()))?;
88        let receiver = handle.establish_request_wise_channel().await?;
89        handle
90            .push_message(message, receiver.http_request_id)
91            .await?;
92        Ok(ReceiverStream::new(receiver.inner))
93    }
94
95    async fn create_standalone_stream(
96        &self,
97        id: &SessionId,
98    ) -> Result<impl Stream<Item = ServerSseMessage> + Send + 'static, Self::Error> {
99        let sessions = self.sessions.read().await;
100        let handle = sessions
101            .get(id)
102            .ok_or(LocalSessionManagerError::SessionNotFound(id.clone()))?;
103        let receiver = handle.establish_common_channel().await?;
104        Ok(ReceiverStream::new(receiver.inner))
105    }
106
107    async fn resume(
108        &self,
109        id: &SessionId,
110        last_event_id: String,
111    ) -> Result<impl Stream<Item = ServerSseMessage> + Send + 'static, Self::Error> {
112        let sessions = self.sessions.read().await;
113        let handle = sessions
114            .get(id)
115            .ok_or(LocalSessionManagerError::SessionNotFound(id.clone()))?;
116        let receiver = handle.resume(last_event_id.parse()?).await?;
117        Ok(ReceiverStream::new(receiver.inner))
118    }
119
120    async fn accept_message(
121        &self,
122        id: &SessionId,
123        message: ClientJsonRpcMessage,
124    ) -> Result<(), Self::Error> {
125        let sessions = self.sessions.read().await;
126        let handle = sessions
127            .get(id)
128            .ok_or(LocalSessionManagerError::SessionNotFound(id.clone()))?;
129        handle.push_message(message, None).await?;
130        Ok(())
131    }
132}
133
/// SSE event id, rendered by `Display` as `<index>` for the common channel or
/// `<index>/<http_request_id>` for a request-wise channel.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct EventId {
    /// Request-wise channel the event belongs to; `None` for the common channel.
    http_request_id: Option<HttpRequestId>,
    /// Monotonic position of the event within its channel's replay cache.
    index: usize,
}
140
141impl std::fmt::Display for EventId {
142    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
143        write!(f, "{}", self.index)?;
144        match &self.http_request_id {
145            Some(http_request_id) => write!(f, "/{http_request_id}"),
146            None => write!(f, ""),
147        }
148    }
149}
150
/// Errors produced while parsing an [`EventId`] from a `Last-Event-ID` value.
#[derive(Debug, Clone, Error)]
#[non_exhaustive]
pub enum EventIdParseError {
    /// The leading `<index>` portion was not a valid integer.
    #[error("Invalid index: {0}")]
    InvalidIndex(ParseIntError),
    /// The `<http_request_id>` portion after `/` was not a valid integer.
    #[error("Invalid numeric request id: {0}")]
    InvalidNumericRequestId(ParseIntError),
    // NOTE(review): the two variants below are not constructed anywhere in
    // this file — presumably kept for API stability; confirm before removing.
    #[error("Missing request id type")]
    InvalidRequestIdType,
    #[error("Missing request id")]
    MissingRequestId,
}
163
164impl std::str::FromStr for EventId {
165    type Err = EventIdParseError;
166    fn from_str(s: &str) -> Result<Self, Self::Err> {
167        if let Some((index, request_id)) = s.split_once("/") {
168            let index = usize::from_str(index).map_err(EventIdParseError::InvalidIndex)?;
169            let request_id = u64::from_str(request_id).map_err(EventIdParseError::InvalidIndex)?;
170            Ok(EventId {
171                http_request_id: Some(request_id),
172                index,
173            })
174        } else {
175            let index = usize::from_str(s).map_err(EventIdParseError::InvalidIndex)?;
176            Ok(EventId {
177                http_request_id: None,
178                index,
179            })
180        }
181    }
182}
183
184use super::{ServerSseMessage, SessionManager};
185
/// A sender paired with a bounded replay cache of recently sent SSE messages,
/// enabling a stream to be resumed after the client reconnects with
/// `Last-Event-ID`.
struct CachedTx {
    /// Live sender for the current SSE connection (replaced on resume).
    tx: Sender<ServerSseMessage>,
    /// Most recent messages, bounded by `capacity`, kept for replay.
    cache: VecDeque<ServerSseMessage>,
    /// `Some` for request-wise channels, `None` for the common channel.
    http_request_id: Option<HttpRequestId>,
    /// Maximum cache size; mirrors the mpsc channel capacity at construction.
    capacity: usize,
}
192
impl CachedTx {
    /// Wrap `tx` with a replay cache sized to the channel's capacity.
    fn new(tx: Sender<ServerSseMessage>, http_request_id: Option<HttpRequestId>) -> Self {
        Self {
            cache: VecDeque::with_capacity(tx.capacity()),
            capacity: tx.capacity(),
            tx,
            http_request_id,
        }
    }

    /// Convenience constructor for the common (standalone) channel.
    fn new_common(tx: Sender<ServerSseMessage>) -> Self {
        Self::new(tx, None)
    }

    /// Next event id: one past the newest cached event, or 0 when the cache
    /// is empty.
    fn next_event_id(&self) -> EventId {
        let index = self.cache.back().map_or(0, |m| {
            m.event_id
                .as_deref()
                .unwrap_or_default()
                .parse::<EventId>()
                // Cached ids were produced by `EventId::to_string`, so they
                // always round-trip; a failure here is a bug in this module.
                .expect("valid event id")
                .index
                + 1
        });
        EventId {
            http_request_id: self.http_request_id,
            index,
        }
    }

    /// Tag `message` with a fresh event id, then cache and send it.
    async fn send(&mut self, message: ServerJsonRpcMessage) {
        let event_id = self.next_event_id();
        let message = ServerSseMessage::new(event_id.to_string(), message);
        self.cache_and_send(message).await;
    }

    /// Send a data-less "priming" event carrying only a retry interval, used
    /// before a server-initiated disconnect to tell the client when to retry.
    async fn send_priming(&mut self, retry: Duration) {
        let event_id = self.next_event_id();
        let message = ServerSseMessage::priming(event_id.to_string(), retry);
        self.cache_and_send(message).await;
    }

    /// Append to the bounded cache (evicting the oldest entry when full) and
    /// forward to the live receiver. A closed receiver is not an error: the
    /// message stays cached so it can be replayed on resume.
    async fn cache_and_send(&mut self, message: ServerSseMessage) {
        if self.cache.len() >= self.capacity {
            self.cache.pop_front();
            self.cache.push_back(message.clone());
        } else {
            self.cache.push_back(message.clone());
        }
        let _ = self.tx.send(message).await.inspect_err(|e| {
            let event_id = &e.0.event_id;
            tracing::trace!(?event_id, "trying to send message in a closed session")
        });
    }

    /// Replay cached messages starting at absolute event `index` into the
    /// (freshly replaced) sender.
    ///
    /// Fails with `InvalidEventId` when `index` is beyond everything cached,
    /// or `ChannelClosed` if the receiver disappears mid-replay.
    async fn sync(&mut self, index: usize) -> Result<(), SessionError> {
        let Some(front) = self.cache.front() else {
            return Ok(());
        };
        let front_event_id = front
            .event_id
            .as_deref()
            .unwrap_or_default()
            .parse::<EventId>()?;
        // Translate the absolute index into an offset within the cache.
        let sync_index = index.saturating_sub(front_event_id.index);
        if sync_index > self.cache.len() {
            // invalid index
            return Err(SessionError::InvalidEventId);
        }
        for message in self.cache.iter().skip(sync_index) {
            let send_result = self.tx.send(message.clone()).await;
            if send_result.is_err() {
                let event_id: EventId = message.event_id.as_deref().unwrap_or_default().parse()?;
                return Err(SessionError::ChannelClosed(Some(event_id.index as u64)));
            }
        }
        Ok(())
    }
}
271
/// State for one request-wise SSE channel: the resources (request id,
/// progress token) routed to it, plus its cached sender.
struct HttpRequestWise {
    resources: HashSet<ResourceKey>,
    tx: CachedTx,
}
276
/// Identifier of one HTTP request's SSE channel within a session.
type HttpRequestId = u64;
/// Key used to route a server message back to the HTTP request it belongs to.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
enum ResourceKey {
    /// Routed by the JSON-RPC request id (responses, errors, cancellations).
    McpRequestId(RequestId),
    /// Routed by the progress token (progress notifications).
    ProgressToken(ProgressToken),
}
283
/// The state machine backing a single session, run as a worker task.
///
/// It routes server messages to the correct SSE channel (request-wise or
/// common) and reacts to [`SessionEvent`]s coming from the HTTP side.
pub struct LocalSessionWorker {
    /// Identifier of the session this worker serves.
    id: SessionId,
    /// Next id to hand out for request-wise channels (wraps on overflow).
    next_http_request_id: HttpRequestId,
    /// Per-HTTP-request SSE channels, keyed by their request id.
    tx_router: HashMap<HttpRequestId, HttpRequestWise>,
    /// Maps MCP request ids / progress tokens to the owning HTTP request.
    resource_router: HashMap<ResourceKey, HttpRequestId>,
    /// The standalone (GET) channel for server-initiated messages.
    common: CachedTx,
    /// Shadow senders for secondary SSE streams (e.g. from POST EventSource
    /// reconnections). These keep the HTTP connections alive via SSE keep-alive
    /// without receiving notifications, preventing MCP clients from entering
    /// infinite reconnect loops when multiple EventSource connections compete
    /// to replace the common channel.
    shadow_txs: Vec<Sender<ServerSseMessage>>,
    /// Inbound events from the session handle(s).
    event_rx: Receiver<SessionEvent>,
    /// Per-session configuration (e.g. channel capacity).
    session_config: SessionConfig,
}
299
impl LocalSessionWorker {
    /// The id of the session this worker serves.
    pub fn id(&self) -> &SessionId {
        &self.id
    }
}
305
306#[derive(Debug, Error)]
307#[non_exhaustive]
308pub enum SessionError {
309    #[error("Invalid request id: {0}")]
310    DuplicatedRequestId(HttpRequestId),
311    #[error("Channel closed: {0:?}")]
312    ChannelClosed(Option<HttpRequestId>),
313    #[error("Cannot parse event id: {0}")]
314    EventIdParseError(#[from] EventIdParseError),
315    #[error("Session service terminated")]
316    SessionServiceTerminated,
317    #[error("Invalid event id")]
318    InvalidEventId,
319    #[error("IO error: {0}")]
320    Io(#[from] std::io::Error),
321}
322
323impl From<SessionError> for std::io::Error {
324    fn from(value: SessionError) -> Self {
325        match value {
326            SessionError::Io(io) => io,
327            _ => std::io::Error::other(format!("Session error: {value}")),
328        }
329    }
330}
331
/// Routing decision for an outbound server message.
enum OutboundChannel {
    /// Deliver on a request-wise channel; `close` tears it down afterwards.
    RequestWise { id: HttpRequestId, close: bool },
    /// Deliver on the standalone (common) channel.
    Common,
}
/// Receiving end of an SSE channel handed back to the HTTP layer.
#[derive(Debug)]
#[non_exhaustive]
pub struct StreamableHttpMessageReceiver {
    /// `Some` for request-wise channels, `None` for the common channel.
    pub http_request_id: Option<HttpRequestId>,
    /// Stream of SSE messages for this channel.
    pub inner: Receiver<ServerSseMessage>,
}
342
impl LocalSessionWorker {
    /// Drop the route for `resource`; when that leaves the owning request-wise
    /// channel without resources — or the resource was the MCP request id
    /// itself — tear the whole channel down and purge its remaining routes.
    fn unregister_resource(&mut self, resource: &ResourceKey) {
        if let Some(http_request_id) = self.resource_router.remove(resource) {
            tracing::trace!(?resource, http_request_id, "unregister resource");
            if let Some(channel) = self.tx_router.get_mut(&http_request_id) {
                // It's okay to do so, since we don't handle batch json rpc request anymore
                // and this can be refactored after the batch request is removed in the coming version.
                if channel.resources.is_empty() || matches!(resource, ResourceKey::McpRequestId(_))
                {
                    tracing::debug!(http_request_id, "close http request wise channel");
                    if let Some(channel) = self.tx_router.remove(&http_request_id) {
                        // Remove sibling routes (e.g. the progress token) so
                        // nothing still points at the dropped channel.
                        for resource in channel.resources {
                            self.resource_router.remove(&resource);
                        }
                    }
                }
            } else {
                tracing::warn!(http_request_id, "http request wise channel not found");
            }
        }
    }

    /// Route `resource` to the request-wise channel `http_request_id`.
    /// Silently ignored when that channel no longer exists.
    fn register_resource(&mut self, resource: ResourceKey, http_request_id: HttpRequestId) {
        tracing::trace!(?resource, http_request_id, "register resource");
        if let Some(channel) = self.tx_router.get_mut(&http_request_id) {
            channel.resources.insert(resource.clone());
            self.resource_router.insert(resource, http_request_id);
        }
    }

    /// Register an inbound client request with a request-wise channel: always
    /// by its request id, and additionally by its progress token when the
    /// request's meta carries one.
    fn register_request(
        &mut self,
        request: &JsonRpcRequest<ClientRequest>,
        http_request_id: HttpRequestId,
    ) {
        use crate::model::GetMeta;
        self.register_resource(
            ResourceKey::McpRequestId(request.id.clone()),
            http_request_id,
        );
        if let Some(progress_token) = request.request.get_meta().get_progress_token() {
            self.register_resource(
                ResourceKey::ProgressToken(progress_token.clone()),
                http_request_id,
            );
        }
    }

    /// When the client cancels a request, unregister its request id so the
    /// associated request-wise channel can be closed.
    fn catch_cancellation_notification(
        &mut self,
        notification: &JsonRpcNotification<ClientNotification>,
    ) {
        if let ClientNotification::CancelledNotification(n) = &notification.notification {
            let request_id = n.params.request_id.clone();
            let resource = ResourceKey::McpRequestId(request_id);
            self.unregister_resource(&resource);
        }
    }

    /// Hand out the next HTTP request id (wrapping on overflow).
    fn next_http_request_id(&mut self) -> HttpRequestId {
        let id = self.next_http_request_id;
        self.next_http_request_id = self.next_http_request_id.wrapping_add(1);
        id
    }

    /// Create a new request-wise SSE channel and return its receiving end.
    async fn establish_request_wise_channel(
        &mut self,
    ) -> Result<StreamableHttpMessageReceiver, SessionError> {
        let http_request_id = self.next_http_request_id();
        let (tx, rx) = tokio::sync::mpsc::channel(self.session_config.channel_capacity);
        self.tx_router.insert(
            http_request_id,
            HttpRequestWise {
                resources: Default::default(),
                tx: CachedTx::new(tx, Some(http_request_id)),
            },
        );
        tracing::debug!(http_request_id, "establish new request wise channel");
        Ok(StreamableHttpMessageReceiver {
            http_request_id: Some(http_request_id),
            inner: rx,
        })
    }

    /// Decide which channel an outbound server message belongs to.
    ///
    /// Responses, errors, cancellations and progress notifications are routed
    /// to the request-wise channel registered for their request id / progress
    /// token; everything else (including all server-initiated requests) goes
    /// to the common channel.
    fn resolve_outbound_channel(&self, message: &ServerJsonRpcMessage) -> OutboundChannel {
        match &message {
            // Server-initiated requests always use the standalone stream.
            ServerJsonRpcMessage::Request(_) => OutboundChannel::Common,
            ServerJsonRpcMessage::Notification(JsonRpcNotification {
                notification:
                    ServerNotification::ProgressNotification(Notification {
                        params: ProgressNotificationParam { progress_token, .. },
                        ..
                    }),
                ..
            }) => {
                let id = self
                    .resource_router
                    .get(&ResourceKey::ProgressToken(progress_token.clone()));

                if let Some(id) = id {
                    OutboundChannel::RequestWise {
                        id: *id,
                        close: false,
                    }
                } else {
                    OutboundChannel::Common
                }
            }
            ServerJsonRpcMessage::Notification(JsonRpcNotification {
                notification:
                    ServerNotification::CancelledNotification(Notification {
                        params: CancelledNotificationParam { request_id, .. },
                        ..
                    }),
                ..
            }) => {
                if let Some(id) = self
                    .resource_router
                    .get(&ResourceKey::McpRequestId(request_id.clone()))
                {
                    OutboundChannel::RequestWise {
                        id: *id,
                        close: false,
                    }
                } else {
                    OutboundChannel::Common
                }
            }
            // All other notifications are broadcast on the common channel.
            ServerJsonRpcMessage::Notification(_) => OutboundChannel::Common,
            ServerJsonRpcMessage::Response(json_rpc_response) => {
                if let Some(id) = self
                    .resource_router
                    .get(&ResourceKey::McpRequestId(json_rpc_response.id.clone()))
                {
                    OutboundChannel::RequestWise {
                        id: *id,
                        close: false,
                    }
                } else {
                    OutboundChannel::Common
                }
            }
            ServerJsonRpcMessage::Error(json_rpc_error) => {
                if let Some(id) = self
                    .resource_router
                    .get(&ResourceKey::McpRequestId(json_rpc_error.id.clone()))
                {
                    OutboundChannel::RequestWise {
                        id: *id,
                        close: false,
                    }
                } else {
                    OutboundChannel::Common
                }
            }
        }
    }

    /// Deliver a server message on the channel chosen by
    /// [`Self::resolve_outbound_channel`], optionally closing a request-wise
    /// channel afterwards.
    async fn handle_server_message(
        &mut self,
        message: ServerJsonRpcMessage,
    ) -> Result<(), SessionError> {
        let outbound_channel = self.resolve_outbound_channel(&message);
        match outbound_channel {
            OutboundChannel::RequestWise { id, close } => {
                if let Some(request_wise) = self.tx_router.get_mut(&id) {
                    request_wise.tx.send(message).await;
                    if close {
                        self.tx_router.remove(&id);
                    }
                } else {
                    return Err(SessionError::ChannelClosed(Some(id)));
                }
            }
            OutboundChannel::Common => self.common.send(message).await,
        }
        Ok(())
    }

    /// Re-attach an SSE stream after a client reconnects with `Last-Event-ID`,
    /// replaying any cached messages the client missed.
    async fn resume(
        &mut self,
        last_event_id: EventId,
    ) -> Result<StreamableHttpMessageReceiver, SessionError> {
        // Clean up closed shadow senders before processing
        self.shadow_txs.retain(|tx| !tx.is_closed());

        match last_event_id.http_request_id {
            Some(http_request_id) => {
                if let Some(request_wise) = self.tx_router.get_mut(&http_request_id) {
                    // Resume existing request-wise channel
                    let channel = tokio::sync::mpsc::channel(self.session_config.channel_capacity);
                    let (tx, rx) = channel;
                    request_wise.tx.tx = tx;
                    let index = last_event_id.index;
                    // sync messages after index
                    request_wise.tx.sync(index).await?;
                    Ok(StreamableHttpMessageReceiver {
                        http_request_id: Some(http_request_id),
                        inner: rx,
                    })
                } else {
                    // Request-wise channel completed (POST response already delivered).
                    // The client's EventSource is reconnecting after the POST SSE stream
                    // ended. Fall through to common channel handling below.
                    tracing::debug!(
                        http_request_id,
                        "Request-wise channel completed, falling back to common channel"
                    );
                    self.resume_or_shadow_common(last_event_id.index).await
                }
            }
            None => self.resume_or_shadow_common(last_event_id.index).await,
        }
    }

    /// Resume the common channel, or create a shadow stream if the primary is
    /// still active.
    ///
    /// When the primary common channel is dead (receiver dropped), replace it
    /// so this stream becomes the new primary notification channel. Cached
    /// messages are replayed from `last_event_index` so the client receives
    /// any events it missed (including server-initiated requests).
    ///
    /// When the primary is still active, create a "shadow" stream — an idle SSE
    /// connection kept alive by keep-alive pings. This prevents multiple
    /// EventSource connections (e.g. from POST response reconnections) from
    /// killing each other by repeatedly replacing the common channel sender.
    async fn resume_or_shadow_common(
        &mut self,
        last_event_index: usize,
    ) -> Result<StreamableHttpMessageReceiver, SessionError> {
        let is_replacing_dead_primary = self.common.tx.is_closed();
        let capacity = if is_replacing_dead_primary {
            self.session_config.channel_capacity
        } else {
            1 // Shadow streams only need keep-alive pings
        };
        let (tx, rx) = tokio::sync::mpsc::channel(capacity);
        if is_replacing_dead_primary {
            // Primary common channel is dead — replace it.
            tracing::debug!("Replacing dead common channel with new primary");
            self.common.tx = tx;
            // Replay cached messages from where the client left off so
            // server-initiated requests and notifications are not lost.
            self.common.sync(last_event_index).await?;
        } else {
            // Primary common channel is still active. Create a shadow stream
            // that stays alive via SSE keep-alive but doesn't receive
            // notifications. This prevents competing EventSource connections
            // from killing each other's channels.
            const MAX_SHADOW_STREAMS: usize = 32;

            if self.shadow_txs.len() >= MAX_SHADOW_STREAMS {
                tracing::warn!(
                    shadow_count = self.shadow_txs.len(),
                    "Shadow stream limit reached, dropping oldest"
                );
                self.shadow_txs.remove(0);
            }
            tracing::debug!(
                shadow_count = self.shadow_txs.len(),
                "Common channel active, creating shadow stream"
            );
            self.shadow_txs.push(tx);
        }
        Ok(StreamableHttpMessageReceiver {
            http_request_id: None,
            inner: rx,
        })
    }

    /// Server-initiated disconnect of one SSE stream (request-wise when
    /// `http_request_id` is `Some`, otherwise the standalone stream).
    ///
    /// The session and its replay caches stay alive, so the client can
    /// reconnect with `Last-Event-ID`; a priming event advertising
    /// `retry_interval` is sent first when one is provided.
    async fn close_sse_stream(
        &mut self,
        http_request_id: Option<HttpRequestId>,
        retry_interval: Option<Duration>,
    ) -> Result<(), SessionError> {
        match http_request_id {
            // Close a request-wise stream
            Some(id) => {
                let request_wise = self
                    .tx_router
                    .get_mut(&id)
                    .ok_or(SessionError::ChannelClosed(Some(id)))?;

                // Send priming event if retry interval is specified
                if let Some(interval) = retry_interval {
                    request_wise.tx.send_priming(interval).await;
                }

                // Close the stream by dropping the sender
                let (tx, _rx) = tokio::sync::mpsc::channel(1);
                request_wise.tx.tx = tx;

                tracing::debug!(
                    http_request_id = id,
                    "closed SSE stream for server-initiated disconnection"
                );
                Ok(())
            }
            // Close the standalone (common) stream
            None => {
                // Send priming event if retry interval is specified
                if let Some(interval) = retry_interval {
                    self.common.send_priming(interval).await;
                }

                // Close the stream by dropping the sender
                let (tx, _rx) = tokio::sync::mpsc::channel(1);
                self.common.tx = tx;

                // Also close all shadow streams
                self.shadow_txs.clear();

                tracing::debug!("closed standalone SSE stream for server-initiated disconnection");
                Ok(())
            }
        }
    }
}
654
/// Events sent from [`LocalSessionHandle`]s to the session worker task.
#[derive(Debug)]
#[non_exhaustive]
pub enum SessionEvent {
    /// A JSON-RPC message arrived from the client, optionally tied to a
    /// request-wise SSE channel.
    ClientMessage {
        message: ClientJsonRpcMessage,
        http_request_id: Option<HttpRequestId>,
    },
    /// Create a new request-wise SSE channel and reply with its receiver.
    EstablishRequestWiseChannel {
        responder: oneshot::Sender<Result<StreamableHttpMessageReceiver, SessionError>>,
    },
    /// Tear down the request-wise channel with the given id.
    CloseRequestWiseChannel {
        id: HttpRequestId,
        responder: oneshot::Sender<Result<(), SessionError>>,
    },
    /// Resume a stream after reconnection, starting after `last_event_id`.
    Resume {
        last_event_id: EventId,
        responder: oneshot::Sender<Result<StreamableHttpMessageReceiver, SessionError>>,
    },
    /// Forward the client's `initialize` request and await the response.
    InitializeRequest {
        request: ClientJsonRpcMessage,
        responder: oneshot::Sender<Result<ServerJsonRpcMessage, SessionError>>,
    },
    /// Shut the session worker down.
    Close,
    /// Server-initiated close of a single SSE stream (session stays alive).
    CloseSseStream {
        /// The HTTP request ID to close. If `None`, closes the standalone (common) stream.
        http_request_id: Option<HttpRequestId>,
        /// Optional retry interval. If provided, a priming event is sent before closing.
        retry_interval: Option<Duration>,
        responder: oneshot::Sender<Result<(), SessionError>>,
    },
}
686
/// Why the session worker's event loop terminated.
///
/// NOTE(review): the worker loop constructing these values is outside this
/// chunk; the per-variant notes below are inferred from the names — confirm
/// against the loop.
#[derive(Debug, Clone)]
pub enum SessionQuitReason {
    /// The server-side service stopped.
    ServiceTerminated,
    /// The client terminated the session.
    ClientTerminated,
    /// Presumably: the first client message was not an `initialize` request.
    ExpectInitializeRequest,
    /// Presumably: the service's reply to `initialize` was not a response.
    ExpectInitializeResponse,
    /// The worker was cancelled.
    Cancelled,
}
696
/// Cloneable client-side handle to a running session worker.
#[derive(Debug, Clone)]
pub struct LocalSessionHandle {
    id: SessionId,
    // after all event_tx drop, inner task will be terminated
    event_tx: Sender<SessionEvent>,
}
703
704impl LocalSessionHandle {
    /// Get the session id.
    pub fn id(&self) -> &SessionId {
        &self.id
    }
709
710    /// Close the session
711    pub async fn close(&self) -> Result<(), SessionError> {
712        self.event_tx
713            .send(SessionEvent::Close)
714            .await
715            .map_err(|_| SessionError::SessionServiceTerminated)?;
716        Ok(())
717    }
718
719    /// Send a message to the session
720    pub async fn push_message(
721        &self,
722        message: ClientJsonRpcMessage,
723        http_request_id: Option<HttpRequestId>,
724    ) -> Result<(), SessionError> {
725        self.event_tx
726            .send(SessionEvent::ClientMessage {
727                message,
728                http_request_id,
729            })
730            .await
731            .map_err(|_| SessionError::SessionServiceTerminated)?;
732        Ok(())
733    }
734
735    /// establish a channel for a http-request, the corresponded message from server will be
736    /// sent through this channel. The channel will be closed when the request is completed,
737    /// or you can close it manually by calling [`LocalSessionHandle::close_request_wise_channel`].
738    pub async fn establish_request_wise_channel(
739        &self,
740    ) -> Result<StreamableHttpMessageReceiver, SessionError> {
741        let (tx, rx) = tokio::sync::oneshot::channel();
742        self.event_tx
743            .send(SessionEvent::EstablishRequestWiseChannel { responder: tx })
744            .await
745            .map_err(|_| SessionError::SessionServiceTerminated)?;
746        rx.await
747            .map_err(|_| SessionError::SessionServiceTerminated)?
748    }
749
750    /// close the http-request wise channel.
751    pub async fn close_request_wise_channel(
752        &self,
753        request_id: HttpRequestId,
754    ) -> Result<(), SessionError> {
755        let (tx, rx) = tokio::sync::oneshot::channel();
756        self.event_tx
757            .send(SessionEvent::CloseRequestWiseChannel {
758                id: request_id,
759                responder: tx,
760            })
761            .await
762            .map_err(|_| SessionError::SessionServiceTerminated)?;
763        rx.await
764            .map_err(|_| SessionError::SessionServiceTerminated)?
765    }
766
767    /// Establish a common channel for general purpose messages.
768    pub async fn establish_common_channel(
769        &self,
770    ) -> Result<StreamableHttpMessageReceiver, SessionError> {
771        let (tx, rx) = tokio::sync::oneshot::channel();
772        self.event_tx
773            .send(SessionEvent::Resume {
774                last_event_id: EventId {
775                    http_request_id: None,
776                    index: 0,
777                },
778                responder: tx,
779            })
780            .await
781            .map_err(|_| SessionError::SessionServiceTerminated)?;
782        rx.await
783            .map_err(|_| SessionError::SessionServiceTerminated)?
784    }
785
786    /// Resume streaming response by the last event id. This is suitable for both request wise and common channel.
787    pub async fn resume(
788        &self,
789        last_event_id: EventId,
790    ) -> Result<StreamableHttpMessageReceiver, SessionError> {
791        let (tx, rx) = tokio::sync::oneshot::channel();
792        self.event_tx
793            .send(SessionEvent::Resume {
794                last_event_id,
795                responder: tx,
796            })
797            .await
798            .map_err(|_| SessionError::SessionServiceTerminated)?;
799        rx.await
800            .map_err(|_| SessionError::SessionServiceTerminated)?
801    }
802
803    /// Send an initialize request to the session. And wait for the initialized response.
804    ///
805    /// This is used to establish a session with the server.
806    pub async fn initialize(
807        &self,
808        request: ClientJsonRpcMessage,
809    ) -> Result<ServerJsonRpcMessage, SessionError> {
810        let (tx, rx) = tokio::sync::oneshot::channel();
811        self.event_tx
812            .send(SessionEvent::InitializeRequest {
813                request,
814                responder: tx,
815            })
816            .await
817            .map_err(|_| SessionError::SessionServiceTerminated)?;
818        rx.await
819            .map_err(|_| SessionError::SessionServiceTerminated)?
820    }
821
822    /// Close an SSE stream for a specific request.
823    ///
824    /// This closes the SSE connection for a POST request stream, but keeps the session
825    /// and message cache active. Clients can reconnect using the `Last-Event-ID` header
826    /// via a GET request to resume receiving messages.
827    ///
828    /// # Arguments
829    ///
830    /// * `http_request_id` - The HTTP request ID of the stream to close
831    /// * `retry_interval` - Optional retry interval. If provided, a priming event is sent
832    pub async fn close_sse_stream(
833        &self,
834        http_request_id: HttpRequestId,
835        retry_interval: Option<Duration>,
836    ) -> Result<(), SessionError> {
837        let (tx, rx) = tokio::sync::oneshot::channel();
838        self.event_tx
839            .send(SessionEvent::CloseSseStream {
840                http_request_id: Some(http_request_id),
841                retry_interval,
842                responder: tx,
843            })
844            .await
845            .map_err(|_| SessionError::SessionServiceTerminated)?;
846        rx.await
847            .map_err(|_| SessionError::SessionServiceTerminated)?
848    }
849
850    /// Close the standalone SSE stream.
851    ///
852    /// This closes the standalone SSE connection (established via GET request),
853    /// but keeps the session and message cache active. Clients can reconnect using
854    /// the `Last-Event-ID` header via a GET request to resume receiving messages.
855    ///
856    /// # Arguments
857    ///
858    /// * `retry_interval` - Optional retry interval. If provided, a priming event is sent
859    pub async fn close_standalone_sse_stream(
860        &self,
861        retry_interval: Option<Duration>,
862    ) -> Result<(), SessionError> {
863        let (tx, rx) = tokio::sync::oneshot::channel();
864        self.event_tx
865            .send(SessionEvent::CloseSseStream {
866                http_request_id: None,
867                retry_interval,
868                responder: tx,
869            })
870            .await
871            .map_err(|_| SessionError::SessionServiceTerminated)?;
872        rx.await
873            .map_err(|_| SessionError::SessionServiceTerminated)?
874    }
875}
876
/// Transport type produced for a local session (a [`WorkerTransport`] driving a [`LocalSessionWorker`]).
pub type SessionTransport = WorkerTransport<LocalSessionWorker>;
878
/// Errors produced by [`LocalSessionWorker`] while driving a session's event loop.
#[allow(clippy::large_enum_variant)]
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum LocalSessionWorkerError {
    /// The session's event channel closed while the worker was still waiting for events.
    #[error("transport terminated")]
    TransportTerminated,
    /// Received a session event that is not valid in the current state
    /// (e.g. something other than an initialize request during the handshake).
    #[error("unexpected message: {0:?}")]
    UnexpectedEvent(SessionEvent),
    /// The initialize response could not be delivered back to the HTTP service.
    #[error("fail to send initialize request {0}")]
    FailToSendInitializeRequest(SessionError),
    /// Forwarding a server message into the session failed.
    #[error("fail to handle message: {0}")]
    FailToHandleMessage(SessionError),
    /// No session activity was observed within the configured keep-alive window.
    #[error("keep alive timeout after {}ms", _0.as_millis())]
    KeepAliveTimeout(Duration),
    /// The transport side was closed.
    #[error("Transport closed")]
    TransportClosed,
    /// The underlying tokio task failed to join.
    #[error("Tokio join error {0}")]
    TokioJoinError(#[from] tokio::task::JoinError),
}
impl Worker for LocalSessionWorker {
    type Error = LocalSessionWorkerError;
    type Role = RoleServer;
    // Error reported when the worker's channel to the transport is closed.
    fn err_closed() -> Self::Error {
        LocalSessionWorkerError::TransportClosed
    }
    // Wrap a tokio task join failure into this worker's error type.
    fn err_join(e: tokio::task::JoinError) -> Self::Error {
        LocalSessionWorkerError::TokioJoinError(e)
    }
    fn config(&self) -> crate::transport::worker::WorkerConfig {
        crate::transport::worker::WorkerConfig {
            // Name the worker after the session id for tracing/debugging.
            name: Some(format!("streamable-http-session-{}", self.id)),
            channel_buffer_capacity: self.session_config.channel_capacity,
        }
    }
    /// Session event loop.
    ///
    /// Phase 1 performs the initialize handshake: the first event received MUST
    /// be `SessionEvent::InitializeRequest`; it is forwarded to the handler and
    /// the handler's first message is sent back as the initialize response.
    /// Phase 2 then multiplexes, via `select!`, events from the HTTP service
    /// (`event_rx`), messages from the MCP handler, cancellation, and a
    /// keep-alive inactivity timeout.
    #[instrument(name = "streamable_http_session", skip_all, fields(id = self.id.as_ref()))]
    async fn run(
        mut self,
        mut context: WorkerContext<Self>,
    ) -> Result<(), WorkerQuitReason<Self::Error>> {
        // Internal tag so one `select!` result type covers both event sources.
        enum InnerEvent {
            FromHttpService(SessionEvent),
            FromHandler(WorkerSendRequest<LocalSessionWorker>),
        }
        // waiting for initialize request
        let evt = self.event_rx.recv().await.ok_or_else(|| {
            WorkerQuitReason::fatal(
                LocalSessionWorkerError::TransportTerminated,
                "get initialize request",
            )
        })?;
        // Anything other than an initialize request at this point is a protocol error.
        let SessionEvent::InitializeRequest { request, responder } = evt else {
            return Err(WorkerQuitReason::fatal(
                LocalSessionWorkerError::UnexpectedEvent(evt),
                "get initialize request",
            ));
        };
        // Forward the initialize request to the MCP handler and relay its
        // first message back to the HTTP service as the initialize response.
        context.send_to_handler(request).await?;
        let send_initialize_response = context.recv_from_handler().await?;
        responder
            .send(Ok(send_initialize_response.message))
            .map_err(|_| {
                WorkerQuitReason::fatal(
                    LocalSessionWorkerError::FailToSendInitializeRequest(
                        SessionError::SessionServiceTerminated,
                    ),
                    "send initialize response",
                )
            })?;
        // Acknowledge the handler's send request so it can proceed.
        send_initialize_response
            .responder
            .send(Ok(()))
            .map_err(|_| WorkerQuitReason::HandlerTerminated)?;
        let ct = context.cancellation_token.clone();
        // `None` keep-alive disables the timeout by sleeping effectively forever.
        let keep_alive = self.session_config.keep_alive.unwrap_or(Duration::MAX);
        loop {
            // Recreated each iteration: any handled event resets the inactivity window.
            let keep_alive_timeout = tokio::time::sleep(keep_alive);
            let event = tokio::select! {
                event = self.event_rx.recv() => {
                    if let Some(event) = event {
                        InnerEvent::FromHttpService(event)
                    } else {
                        // All session handles dropped their senders: session is gone.
                        return Err(WorkerQuitReason::fatal(LocalSessionWorkerError::TransportTerminated, "waiting next session event"))
                    }
                },
                from_handler = context.recv_from_handler() => {
                    InnerEvent::FromHandler(from_handler?)
                }
                _ = ct.cancelled() => {
                    return Err(WorkerQuitReason::Cancelled)
                }
                _ = keep_alive_timeout => {
                    return Err(WorkerQuitReason::fatal(LocalSessionWorkerError::KeepAliveTimeout(keep_alive), "poll next session event"))
                }
            };
            match event {
                // Outbound: the MCP handler wants to send a message to the client.
                InnerEvent::FromHandler(WorkerSendRequest { message, responder }) => {
                    // catch response
                    // A response/error terminates its request, so remember the
                    // request id to unregister its resources after delivery.
                    let to_unregister = match &message {
                        crate::model::JsonRpcMessage::Response(json_rpc_response) => {
                            let request_id = json_rpc_response.id.clone();
                            Some(ResourceKey::McpRequestId(request_id))
                        }
                        crate::model::JsonRpcMessage::Error(json_rpc_error) => {
                            let request_id = json_rpc_error.id.clone();
                            Some(ResourceKey::McpRequestId(request_id))
                        }
                        _ => {
                            None
                            // no need to unregister resource
                        }
                    };
                    let handle_result = self
                        .handle_server_message(message)
                        .await
                        .map_err(LocalSessionWorkerError::FailToHandleMessage);
                    // Best-effort ack back to the handler; a dropped receiver is only logged.
                    let _ = responder.send(handle_result).inspect_err(|error| {
                        tracing::warn!(?error, "failed to send message to http service handler");
                    });
                    // Unregister only after the message has been handled and acked.
                    if let Some(to_unregister) = to_unregister {
                        self.unregister_resource(&to_unregister);
                    }
                }
                // Inbound: a JSON-RPC message arrived from the HTTP service.
                InnerEvent::FromHttpService(SessionEvent::ClientMessage {
                    message: json_rpc_message,
                    http_request_id,
                }) => {
                    match &json_rpc_message {
                        crate::model::JsonRpcMessage::Request(request) => {
                            // Associate the MCP request with its HTTP stream for routing.
                            if let Some(http_request_id) = http_request_id {
                                self.register_request(request, http_request_id)
                            }
                        }
                        crate::model::JsonRpcMessage::Notification(notification) => {
                            // Intercept cancellation notifications for bookkeeping.
                            self.catch_cancellation_notification(notification)
                        }
                        _ => {}
                    }
                    context.send_to_handler(json_rpc_message).await?;
                }
                InnerEvent::FromHttpService(SessionEvent::EstablishRequestWiseChannel {
                    responder,
                }) => {
                    let handle_result = self.establish_request_wise_channel().await;
                    let _ = responder.send(handle_result);
                }
                InnerEvent::FromHttpService(SessionEvent::CloseRequestWiseChannel {
                    id,
                    responder,
                }) => {
                    // Dropping the routed sender closes that request-wise channel.
                    let _handle_result = self.tx_router.remove(&id);
                    let _ = responder.send(Ok(()));
                }
                InnerEvent::FromHttpService(SessionEvent::Resume {
                    last_event_id,
                    responder,
                }) => {
                    let handle_result = self.resume(last_event_id).await;
                    let _ = responder.send(handle_result);
                }
                // Explicit close: terminate the worker.
                InnerEvent::FromHttpService(SessionEvent::Close) => {
                    return Err(WorkerQuitReason::TransportClosed);
                }
                InnerEvent::FromHttpService(SessionEvent::CloseSseStream {
                    http_request_id,
                    retry_interval,
                    responder,
                }) => {
                    let handle_result =
                        self.close_sse_stream(http_request_id, retry_interval).await;
                    let _ = responder.send(handle_result);
                }
                _ => {
                    // ignore
                }
            }
        }
    }
}
1057
/// Configuration for a local streamable-HTTP session.
///
/// Defaults are provided via [`Default`]; see [`SessionConfig::DEFAULT_CHANNEL_CAPACITY`]
/// and [`SessionConfig::DEFAULT_KEEP_ALIVE`].
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct SessionConfig {
    /// the capacity of the channel for the session. Default is 16.
    pub channel_capacity: usize,
    /// The session will be closed after this duration of inactivity.
    ///
    /// This serves as a safety net for cleaning up sessions whose HTTP
    /// connections have silently dropped (e.g., due to an HTTP/2
    /// `RST_STREAM`). Without a timeout, such sessions become zombies:
    /// the session worker keeps running indefinitely because the session
    /// handle's sender is still held in the session manager, preventing
    /// the worker's event channel from closing.
    ///
    /// Defaults to 5 minutes. Set to `None` to disable (not recommended
    /// for long-running servers behind proxies).
    pub keep_alive: Option<Duration>,
}
1076
impl SessionConfig {
    /// Default capacity for the session's channels; see [`SessionConfig::channel_capacity`].
    pub const DEFAULT_CHANNEL_CAPACITY: usize = 16;
    /// Default inactivity timeout (5 minutes); see [`SessionConfig::keep_alive`].
    pub const DEFAULT_KEEP_ALIVE: Duration = Duration::from_secs(300);
}
1081
1082impl Default for SessionConfig {
1083    fn default() -> Self {
1084        Self {
1085            channel_capacity: Self::DEFAULT_CHANNEL_CAPACITY,
1086            keep_alive: Some(Self::DEFAULT_KEEP_ALIVE),
1087        }
1088    }
1089}
1090
1091/// Create a new session with the given id and configuration.
1092///
1093/// This function will return a pair of [`LocalSessionHandle`] and [`LocalSessionWorker`].
1094///
1095/// You can run the [`LocalSessionWorker`] as a transport for mcp server. And use the [`LocalSessionHandle`] operate the session.
1096pub fn create_local_session(
1097    id: impl Into<SessionId>,
1098    config: SessionConfig,
1099) -> (LocalSessionHandle, LocalSessionWorker) {
1100    let id = id.into();
1101    let (event_tx, event_rx) = tokio::sync::mpsc::channel(config.channel_capacity);
1102    let (common_tx, _) = tokio::sync::mpsc::channel(config.channel_capacity);
1103    let common = CachedTx::new_common(common_tx);
1104    tracing::info!(session_id = ?id, "create new session");
1105    let handle = LocalSessionHandle {
1106        event_tx,
1107        id: id.clone(),
1108    };
1109    let session_worker = LocalSessionWorker {
1110        next_http_request_id: 0,
1111        id,
1112        tx_router: HashMap::new(),
1113        resource_router: HashMap::new(),
1114        common,
1115        shadow_txs: Vec::new(),
1116        event_rx,
1117        session_config: config.clone(),
1118    };
1119    (handle, session_worker)
1120}