// rmcp — transport/streamable_http_server/session/local.rs
// (local, in-process session manager for the streamable HTTP server transport)
1use std::{
2    collections::{HashMap, HashSet, VecDeque},
3    num::ParseIntError,
4    sync::Arc,
5    time::Duration,
6};
7
8use futures::Stream;
9use thiserror::Error;
10use tokio::sync::{
11    mpsc::{Receiver, Sender},
12    oneshot,
13};
14use tokio_stream::wrappers::ReceiverStream;
15use tracing::instrument;
16
17use crate::{
18    RoleServer,
19    model::{
20        CancelledNotificationParam, ClientJsonRpcMessage, ClientNotification, ClientRequest,
21        JsonRpcNotification, JsonRpcRequest, Notification, ProgressNotificationParam,
22        ProgressToken, RequestId, ServerJsonRpcMessage, ServerNotification,
23    },
24    transport::{
25        WorkerTransport,
26        common::server_side_http::{SessionId, session_id},
27        worker::{Worker, WorkerContext, WorkerQuitReason, WorkerSendRequest},
28    },
29};
30
/// In-process session manager for the streamable HTTP server transport.
/// Tracks every live session's handle in a map keyed by [`SessionId`].
#[derive(Debug, Default)]
pub struct LocalSessionManager {
    /// Live sessions; an async `RwLock` so concurrent lookups don't serialize.
    pub sessions: tokio::sync::RwLock<HashMap<SessionId, LocalSessionHandle>>,
    /// Configuration cloned into every newly created session.
    pub session_config: SessionConfig,
}
36
/// Errors surfaced by [`LocalSessionManager`] operations.
#[derive(Debug, Error)]
pub enum LocalSessionManagerError {
    /// No session is registered under the given id.
    #[error("Session not found: {0}")]
    SessionNotFound(SessionId),
    /// A session-level failure bubbled up from the worker.
    #[error("Session error: {0}")]
    SessionError(#[from] SessionError),
    /// The `Last-Event-ID` supplied by the client could not be parsed.
    #[error("Invalid event id: {0}")]
    InvalidEventId(#[from] EventIdParseError),
}
46impl SessionManager for LocalSessionManager {
47    type Error = LocalSessionManagerError;
48    type Transport = WorkerTransport<LocalSessionWorker>;
49    async fn create_session(&self) -> Result<(SessionId, Self::Transport), Self::Error> {
50        let id = session_id();
51        let (handle, worker) = create_local_session(id.clone(), self.session_config.clone());
52        self.sessions.write().await.insert(id.clone(), handle);
53        Ok((id, WorkerTransport::spawn(worker)))
54    }
55    async fn initialize_session(
56        &self,
57        id: &SessionId,
58        message: ClientJsonRpcMessage,
59    ) -> Result<ServerJsonRpcMessage, Self::Error> {
60        let sessions = self.sessions.read().await;
61        let handle = sessions
62            .get(id)
63            .ok_or(LocalSessionManagerError::SessionNotFound(id.clone()))?;
64        let response = handle.initialize(message).await?;
65        Ok(response)
66    }
67    async fn close_session(&self, id: &SessionId) -> Result<(), Self::Error> {
68        let mut sessions = self.sessions.write().await;
69        if let Some(handle) = sessions.remove(id) {
70            handle.close().await?;
71        }
72        Ok(())
73    }
74    async fn has_session(&self, id: &SessionId) -> Result<bool, Self::Error> {
75        let sessions = self.sessions.read().await;
76        Ok(sessions.contains_key(id))
77    }
78    async fn create_stream(
79        &self,
80        id: &SessionId,
81        message: ClientJsonRpcMessage,
82    ) -> Result<impl Stream<Item = ServerSseMessage> + Send + 'static, Self::Error> {
83        let sessions = self.sessions.read().await;
84        let handle = sessions
85            .get(id)
86            .ok_or(LocalSessionManagerError::SessionNotFound(id.clone()))?;
87        let receiver = handle.establish_request_wise_channel().await?;
88        handle
89            .push_message(message, receiver.http_request_id)
90            .await?;
91        Ok(ReceiverStream::new(receiver.inner))
92    }
93
94    async fn create_standalone_stream(
95        &self,
96        id: &SessionId,
97    ) -> Result<impl Stream<Item = ServerSseMessage> + Send + 'static, Self::Error> {
98        let sessions = self.sessions.read().await;
99        let handle = sessions
100            .get(id)
101            .ok_or(LocalSessionManagerError::SessionNotFound(id.clone()))?;
102        let receiver = handle.establish_common_channel().await?;
103        Ok(ReceiverStream::new(receiver.inner))
104    }
105
106    async fn resume(
107        &self,
108        id: &SessionId,
109        last_event_id: String,
110    ) -> Result<impl Stream<Item = ServerSseMessage> + Send + 'static, Self::Error> {
111        let sessions = self.sessions.read().await;
112        let handle = sessions
113            .get(id)
114            .ok_or(LocalSessionManagerError::SessionNotFound(id.clone()))?;
115        let receiver = handle.resume(last_event_id.parse()?).await?;
116        Ok(ReceiverStream::new(receiver.inner))
117    }
118
119    async fn accept_message(
120        &self,
121        id: &SessionId,
122        message: ClientJsonRpcMessage,
123    ) -> Result<(), Self::Error> {
124        let sessions = self.sessions.read().await;
125        let handle = sessions
126            .get(id)
127            .ok_or(LocalSessionManagerError::SessionNotFound(id.clone()))?;
128        handle.push_message(message, None).await?;
129        Ok(())
130    }
131}
132
/// SSE event id, serialized as `<index>` or `<index>/<http_request_id>`
/// (see the `Display`/`FromStr` impls).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct EventId {
    // Present when the event belongs to a request-wise channel; `None` for
    // the common (standalone) channel.
    http_request_id: Option<HttpRequestId>,
    // Monotonically increasing position within the channel's event stream.
    index: usize,
}
139
140impl std::fmt::Display for EventId {
141    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
142        write!(f, "{}", self.index)?;
143        match &self.http_request_id {
144            Some(http_request_id) => write!(f, "/{http_request_id}"),
145            None => write!(f, ""),
146        }
147    }
148}
149
/// Failure modes when parsing an [`EventId`] from a `Last-Event-ID` string.
#[derive(Debug, Clone, Error)]
pub enum EventIdParseError {
    /// The leading index segment is not a valid `usize`.
    #[error("Invalid index: {0}")]
    InvalidIndex(ParseIntError),
    /// The request-id segment after `/` is not a valid `u64`.
    #[error("Invalid numeric request id: {0}")]
    InvalidNumericRequestId(ParseIntError),
    #[error("Missing request id type")]
    InvalidRequestIdType,
    #[error("Missing request id")]
    MissingRequestId,
}
161
162impl std::str::FromStr for EventId {
163    type Err = EventIdParseError;
164    fn from_str(s: &str) -> Result<Self, Self::Err> {
165        if let Some((index, request_id)) = s.split_once("/") {
166            let index = usize::from_str(index).map_err(EventIdParseError::InvalidIndex)?;
167            let request_id = u64::from_str(request_id).map_err(EventIdParseError::InvalidIndex)?;
168            Ok(EventId {
169                http_request_id: Some(request_id),
170                index,
171            })
172        } else {
173            let index = usize::from_str(s).map_err(EventIdParseError::InvalidIndex)?;
174            Ok(EventId {
175                http_request_id: None,
176                index,
177            })
178        }
179    }
180}
181
182use super::{ServerSseMessage, SessionManager};
183
/// An SSE sender paired with a bounded replay cache, so a reconnecting client
/// can resume from its `Last-Event-ID`.
struct CachedTx {
    // Live sender for the current HTTP connection; swapped out on resume.
    tx: Sender<ServerSseMessage>,
    // Most recent messages, retained for replay after reconnect.
    cache: VecDeque<ServerSseMessage>,
    // `Some` for request-wise channels; `None` for the common channel.
    // Embedded in every event id this sender generates.
    http_request_id: Option<HttpRequestId>,
    // Maximum number of cached messages (mirrors the channel capacity).
    capacity: usize,
}
190
191impl CachedTx {
192    fn new(tx: Sender<ServerSseMessage>, http_request_id: Option<HttpRequestId>) -> Self {
193        Self {
194            cache: VecDeque::with_capacity(tx.capacity()),
195            capacity: tx.capacity(),
196            tx,
197            http_request_id,
198        }
199    }
200    fn new_common(tx: Sender<ServerSseMessage>) -> Self {
201        Self::new(tx, None)
202    }
203
204    fn next_event_id(&self) -> EventId {
205        let index = self.cache.back().map_or(0, |m| {
206            m.event_id
207                .as_deref()
208                .unwrap_or_default()
209                .parse::<EventId>()
210                .expect("valid event id")
211                .index
212                + 1
213        });
214        EventId {
215            http_request_id: self.http_request_id,
216            index,
217        }
218    }
219
220    async fn send(&mut self, message: ServerJsonRpcMessage) {
221        let event_id = self.next_event_id();
222        let message = ServerSseMessage {
223            event_id: Some(event_id.to_string()),
224            message: Some(Arc::new(message)),
225            retry: None,
226        };
227        self.cache_and_send(message).await;
228    }
229
230    async fn send_priming(&mut self, retry: Duration) {
231        let event_id = self.next_event_id();
232        let message = ServerSseMessage {
233            event_id: Some(event_id.to_string()),
234            message: None,
235            retry: Some(retry),
236        };
237        self.cache_and_send(message).await;
238    }
239
240    async fn cache_and_send(&mut self, message: ServerSseMessage) {
241        if self.cache.len() >= self.capacity {
242            self.cache.pop_front();
243            self.cache.push_back(message.clone());
244        } else {
245            self.cache.push_back(message.clone());
246        }
247        let _ = self.tx.send(message).await.inspect_err(|e| {
248            let event_id = &e.0.event_id;
249            tracing::trace!(?event_id, "trying to send message in a closed session")
250        });
251    }
252
253    async fn sync(&mut self, index: usize) -> Result<(), SessionError> {
254        let Some(front) = self.cache.front() else {
255            return Ok(());
256        };
257        let front_event_id = front
258            .event_id
259            .as_deref()
260            .unwrap_or_default()
261            .parse::<EventId>()?;
262        let sync_index = index.saturating_sub(front_event_id.index);
263        if sync_index > self.cache.len() {
264            // invalid index
265            return Err(SessionError::InvalidEventId);
266        }
267        for message in self.cache.iter().skip(sync_index) {
268            let send_result = self.tx.send(message.clone()).await;
269            if send_result.is_err() {
270                let event_id: EventId = message.event_id.as_deref().unwrap_or_default().parse()?;
271                return Err(SessionError::ChannelClosed(Some(event_id.index as u64)));
272            }
273        }
274        Ok(())
275    }
276}
277
/// State for one request-wise SSE channel: its cached sender plus the MCP
/// resources (request ids / progress tokens) routed to it.
struct HttpRequestWise {
    resources: HashSet<ResourceKey>,
    tx: CachedTx,
}
282
/// Identifier of one server-side HTTP request (one POST/SSE stream).
type HttpRequestId = u64;

/// Key used to route server-originated messages back to the HTTP request
/// whose client message created them.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
enum ResourceKey {
    // Routes responses/errors/cancellations by JSON-RPC request id.
    McpRequestId(RequestId),
    // Routes progress notifications by their progress token.
    ProgressToken(ProgressToken),
}
289
/// The per-session worker state: routes server messages to the right SSE
/// channel and services [`SessionEvent`]s sent through the handle.
pub struct LocalSessionWorker {
    // Session id this worker serves.
    id: SessionId,
    // Next id to hand out; advances with wrapping_add.
    next_http_request_id: HttpRequestId,
    // Request-wise SSE channels, keyed by HTTP request id.
    tx_router: HashMap<HttpRequestId, HttpRequestWise>,
    // Reverse index: which HTTP request each MCP resource belongs to.
    resource_router: HashMap<ResourceKey, HttpRequestId>,
    // The standalone (session-wide) SSE channel.
    common: CachedTx,
    /// Shadow senders for secondary SSE streams (e.g. from POST EventSource
    /// reconnections). These keep the HTTP connections alive via SSE keep-alive
    /// without receiving notifications, preventing MCP clients from entering
    /// infinite reconnect loops when multiple EventSource connections compete
    /// to replace the common channel.
    shadow_txs: Vec<Sender<ServerSseMessage>>,
    // Inbound events from all LocalSessionHandle clones.
    event_rx: Receiver<SessionEvent>,
    session_config: SessionConfig,
}
305
impl LocalSessionWorker {
    /// The id of the session this worker serves.
    pub fn id(&self) -> &SessionId {
        &self.id
    }
}
311
312#[derive(Debug, Error)]
313pub enum SessionError {
314    #[error("Invalid request id: {0}")]
315    DuplicatedRequestId(HttpRequestId),
316    #[error("Channel closed: {0:?}")]
317    ChannelClosed(Option<HttpRequestId>),
318    #[error("Cannot parse event id: {0}")]
319    EventIdParseError(#[from] EventIdParseError),
320    #[error("Session service terminated")]
321    SessionServiceTerminated,
322    #[error("Invalid event id")]
323    InvalidEventId,
324    #[error("IO error: {0}")]
325    Io(#[from] std::io::Error),
326}
327
328impl From<SessionError> for std::io::Error {
329    fn from(value: SessionError) -> Self {
330        match value {
331            SessionError::Io(io) => io,
332            _ => std::io::Error::other(format!("Session error: {value}")),
333        }
334    }
335}
336
/// Which SSE channel a server message should be delivered on.
enum OutboundChannel {
    // The channel of a specific HTTP request; `close` tears the channel down
    // after delivery.
    RequestWise { id: HttpRequestId, close: bool },
    // The standalone, session-wide channel.
    Common,
}
/// Receiving half of an SSE channel handed back to the HTTP layer.
#[derive(Debug)]
pub struct StreamableHttpMessageReceiver {
    /// `Some` for a request-wise stream; `None` for the common stream.
    pub http_request_id: Option<HttpRequestId>,
    /// The stream of SSE messages to forward to the client.
    pub inner: Receiver<ServerSseMessage>,
}
346
impl LocalSessionWorker {
    /// Drop `resource` from the router; if its request-wise channel is left
    /// without resources (or the resource was an MCP request id), close that
    /// channel and unregister all of its sibling resources too.
    fn unregister_resource(&mut self, resource: &ResourceKey) {
        if let Some(http_request_id) = self.resource_router.remove(resource) {
            tracing::trace!(?resource, http_request_id, "unregister resource");
            if let Some(channel) = self.tx_router.get_mut(&http_request_id) {
                // It's okay to do so, since we don't handle batch json rpc request anymore
                // and this can be refactored after the batch request is removed in the coming version.
                if channel.resources.is_empty() || matches!(resource, ResourceKey::McpRequestId(_))
                {
                    tracing::debug!(http_request_id, "close http request wise channel");
                    if let Some(channel) = self.tx_router.remove(&http_request_id) {
                        for resource in channel.resources {
                            self.resource_router.remove(&resource);
                        }
                    }
                }
            } else {
                tracing::warn!(http_request_id, "http request wise channel not found");
            }
        }
    }

    /// Associate `resource` with an existing request-wise channel; silently a
    /// no-op when the channel has already been closed.
    fn register_resource(&mut self, resource: ResourceKey, http_request_id: HttpRequestId) {
        tracing::trace!(?resource, http_request_id, "register resource");
        if let Some(channel) = self.tx_router.get_mut(&http_request_id) {
            channel.resources.insert(resource.clone());
            self.resource_router.insert(resource, http_request_id);
        }
    }

    /// Register an incoming client request's id — and its progress token, if
    /// the request meta carries one — so responses and progress notifications
    /// route back to the originating HTTP request.
    fn register_request(
        &mut self,
        request: &JsonRpcRequest<ClientRequest>,
        http_request_id: HttpRequestId,
    ) {
        use crate::model::GetMeta;
        self.register_resource(
            ResourceKey::McpRequestId(request.id.clone()),
            http_request_id,
        );
        if let Some(progress_token) = request.request.get_meta().get_progress_token() {
            self.register_resource(
                ResourceKey::ProgressToken(progress_token.clone()),
                http_request_id,
            );
        }
    }

    /// When the client cancels a request, drop its routing entry (which may
    /// in turn close the request-wise channel — see `unregister_resource`).
    fn catch_cancellation_notification(
        &mut self,
        notification: &JsonRpcNotification<ClientNotification>,
    ) {
        if let ClientNotification::CancelledNotification(n) = &notification.notification {
            let request_id = n.params.request_id.clone();
            let resource = ResourceKey::McpRequestId(request_id);
            self.unregister_resource(&resource);
        }
    }

    /// Allocate the next HTTP request id, wrapping on overflow.
    fn next_http_request_id(&mut self) -> HttpRequestId {
        let id = self.next_http_request_id;
        self.next_http_request_id = self.next_http_request_id.wrapping_add(1);
        id
    }

    /// Open a fresh request-wise channel and return its receiving half; the
    /// channel (and its replay cache) uses the configured capacity.
    async fn establish_request_wise_channel(
        &mut self,
    ) -> Result<StreamableHttpMessageReceiver, SessionError> {
        let http_request_id = self.next_http_request_id();
        let (tx, rx) = tokio::sync::mpsc::channel(self.session_config.channel_capacity);
        self.tx_router.insert(
            http_request_id,
            HttpRequestWise {
                resources: Default::default(),
                tx: CachedTx::new(tx, Some(http_request_id)),
            },
        );
        tracing::debug!(http_request_id, "establish new request wise channel");
        Ok(StreamableHttpMessageReceiver {
            http_request_id: Some(http_request_id),
            inner: rx,
        })
    }

    /// Decide which channel a server message belongs on: progress
    /// notifications route by progress token; responses, errors, and cancel
    /// notifications route by request id. Anything unrouted — including all
    /// server-initiated requests — falls back to the common channel.
    /// Note: every branch currently produces `close: false`.
    fn resolve_outbound_channel(&self, message: &ServerJsonRpcMessage) -> OutboundChannel {
        match &message {
            ServerJsonRpcMessage::Request(_) => OutboundChannel::Common,
            ServerJsonRpcMessage::Notification(JsonRpcNotification {
                notification:
                    ServerNotification::ProgressNotification(Notification {
                        params: ProgressNotificationParam { progress_token, .. },
                        ..
                    }),
                ..
            }) => {
                let id = self
                    .resource_router
                    .get(&ResourceKey::ProgressToken(progress_token.clone()));

                if let Some(id) = id {
                    OutboundChannel::RequestWise {
                        id: *id,
                        close: false,
                    }
                } else {
                    OutboundChannel::Common
                }
            }
            ServerJsonRpcMessage::Notification(JsonRpcNotification {
                notification:
                    ServerNotification::CancelledNotification(Notification {
                        params: CancelledNotificationParam { request_id, .. },
                        ..
                    }),
                ..
            }) => {
                if let Some(id) = self
                    .resource_router
                    .get(&ResourceKey::McpRequestId(request_id.clone()))
                {
                    OutboundChannel::RequestWise {
                        id: *id,
                        close: false,
                    }
                } else {
                    OutboundChannel::Common
                }
            }
            ServerJsonRpcMessage::Notification(_) => OutboundChannel::Common,
            ServerJsonRpcMessage::Response(json_rpc_response) => {
                if let Some(id) = self
                    .resource_router
                    .get(&ResourceKey::McpRequestId(json_rpc_response.id.clone()))
                {
                    OutboundChannel::RequestWise {
                        id: *id,
                        close: false,
                    }
                } else {
                    OutboundChannel::Common
                }
            }
            ServerJsonRpcMessage::Error(json_rpc_error) => {
                if let Some(id) = self
                    .resource_router
                    .get(&ResourceKey::McpRequestId(json_rpc_error.id.clone()))
                {
                    OutboundChannel::RequestWise {
                        id: *id,
                        close: false,
                    }
                } else {
                    OutboundChannel::Common
                }
            }
        }
    }

    /// Deliver a server message on the channel chosen by
    /// `resolve_outbound_channel`; errors only if the target request-wise
    /// channel no longer exists.
    async fn handle_server_message(
        &mut self,
        message: ServerJsonRpcMessage,
    ) -> Result<(), SessionError> {
        let outbound_channel = self.resolve_outbound_channel(&message);
        match outbound_channel {
            OutboundChannel::RequestWise { id, close } => {
                if let Some(request_wise) = self.tx_router.get_mut(&id) {
                    request_wise.tx.send(message).await;
                    // `close` tears the channel down right after delivery.
                    if close {
                        self.tx_router.remove(&id);
                    }
                } else {
                    return Err(SessionError::ChannelClosed(Some(id)));
                }
            }
            OutboundChannel::Common => self.common.send(message).await,
        }
        Ok(())
    }

    /// Resume an SSE stream from a parsed `Last-Event-ID`: swap in a new
    /// sender on the matching request-wise channel and replay its cache, or
    /// fall through to the common-channel logic when the event id has no
    /// request id (or its channel has already completed).
    async fn resume(
        &mut self,
        last_event_id: EventId,
    ) -> Result<StreamableHttpMessageReceiver, SessionError> {
        // Clean up closed shadow senders before processing
        self.shadow_txs.retain(|tx| !tx.is_closed());

        match last_event_id.http_request_id {
            Some(http_request_id) => {
                if let Some(request_wise) = self.tx_router.get_mut(&http_request_id) {
                    // Resume existing request-wise channel
                    let channel = tokio::sync::mpsc::channel(self.session_config.channel_capacity);
                    let (tx, rx) = channel;
                    request_wise.tx.tx = tx;
                    let index = last_event_id.index;
                    // sync messages after index
                    request_wise.tx.sync(index).await?;
                    Ok(StreamableHttpMessageReceiver {
                        http_request_id: Some(http_request_id),
                        inner: rx,
                    })
                } else {
                    // Request-wise channel completed (POST response already delivered).
                    // The client's EventSource is reconnecting after the POST SSE stream
                    // ended. Fall through to common channel handling below.
                    tracing::debug!(
                        http_request_id,
                        "Request-wise channel completed, falling back to common channel"
                    );
                    self.resume_or_shadow_common(last_event_id.index).await
                }
            }
            None => self.resume_or_shadow_common(last_event_id.index).await,
        }
    }

    /// Resume the common channel, or create a shadow stream if the primary is
    /// still active.
    ///
    /// When the primary common channel is dead (receiver dropped), replace it
    /// so this stream becomes the new primary notification channel. Cached
    /// messages are replayed from `last_event_index` so the client receives
    /// any events it missed (including server-initiated requests).
    ///
    /// When the primary is still active, create a "shadow" stream — an idle SSE
    /// connection kept alive by keep-alive pings. This prevents multiple
    /// EventSource connections (e.g. from POST response reconnections) from
    /// killing each other by repeatedly replacing the common channel sender.
    async fn resume_or_shadow_common(
        &mut self,
        last_event_index: usize,
    ) -> Result<StreamableHttpMessageReceiver, SessionError> {
        let is_replacing_dead_primary = self.common.tx.is_closed();
        let capacity = if is_replacing_dead_primary {
            self.session_config.channel_capacity
        } else {
            1 // Shadow streams only need keep-alive pings
        };
        let (tx, rx) = tokio::sync::mpsc::channel(capacity);
        if is_replacing_dead_primary {
            // Primary common channel is dead — replace it.
            tracing::debug!("Replacing dead common channel with new primary");
            self.common.tx = tx;
            // Replay cached messages from where the client left off so
            // server-initiated requests and notifications are not lost.
            self.common.sync(last_event_index).await?;
        } else {
            // Primary common channel is still active. Create a shadow stream
            // that stays alive via SSE keep-alive but doesn't receive
            // notifications. This prevents competing EventSource connections
            // from killing each other's channels.
            const MAX_SHADOW_STREAMS: usize = 32;

            if self.shadow_txs.len() >= MAX_SHADOW_STREAMS {
                tracing::warn!(
                    shadow_count = self.shadow_txs.len(),
                    "Shadow stream limit reached, dropping oldest"
                );
                self.shadow_txs.remove(0);
            }
            tracing::debug!(
                shadow_count = self.shadow_txs.len(),
                "Common channel active, creating shadow stream"
            );
            self.shadow_txs.push(tx);
        }
        Ok(StreamableHttpMessageReceiver {
            http_request_id: None,
            inner: rx,
        })
    }

    /// Server-initiated close of one SSE stream (request-wise when
    /// `http_request_id` is `Some`, otherwise the standalone stream). The
    /// session and its replay caches stay alive; an optional priming frame
    /// tells the client how soon to retry. Dropping the only sender is what
    /// actually ends the HTTP response body.
    async fn close_sse_stream(
        &mut self,
        http_request_id: Option<HttpRequestId>,
        retry_interval: Option<Duration>,
    ) -> Result<(), SessionError> {
        match http_request_id {
            // Close a request-wise stream
            Some(id) => {
                let request_wise = self
                    .tx_router
                    .get_mut(&id)
                    .ok_or(SessionError::ChannelClosed(Some(id)))?;

                // Send priming event if retry interval is specified
                if let Some(interval) = retry_interval {
                    request_wise.tx.send_priming(interval).await;
                }

                // Close the stream by dropping the sender
                let (tx, _rx) = tokio::sync::mpsc::channel(1);
                request_wise.tx.tx = tx;

                tracing::debug!(
                    http_request_id = id,
                    "closed SSE stream for server-initiated disconnection"
                );
                Ok(())
            }
            // Close the standalone (common) stream
            None => {
                // Send priming event if retry interval is specified
                if let Some(interval) = retry_interval {
                    self.common.send_priming(interval).await;
                }

                // Close the stream by dropping the sender
                let (tx, _rx) = tokio::sync::mpsc::channel(1);
                self.common.tx = tx;

                // Also close all shadow streams
                self.shadow_txs.clear();

                tracing::debug!("closed standalone SSE stream for server-initiated disconnection");
                Ok(())
            }
        }
    }
}
658
/// Commands and inbound traffic consumed by the session worker.
#[derive(Debug)]
pub enum SessionEvent {
    /// A JSON-RPC message from the client, optionally tagged with the HTTP
    /// request it arrived on.
    ClientMessage {
        message: ClientJsonRpcMessage,
        http_request_id: Option<HttpRequestId>,
    },
    /// Open a new request-wise SSE channel and reply with its receiver.
    EstablishRequestWiseChannel {
        responder: oneshot::Sender<Result<StreamableHttpMessageReceiver, SessionError>>,
    },
    /// Tear down the request-wise channel identified by `id`.
    CloseRequestWiseChannel {
        id: HttpRequestId,
        responder: oneshot::Sender<Result<(), SessionError>>,
    },
    /// Resume a stream (request-wise or common) from `last_event_id`.
    Resume {
        last_event_id: EventId,
        responder: oneshot::Sender<Result<StreamableHttpMessageReceiver, SessionError>>,
    },
    /// Forward the client's initialize request and reply with the server's
    /// initialize response.
    InitializeRequest {
        request: ClientJsonRpcMessage,
        responder: oneshot::Sender<Result<ServerJsonRpcMessage, SessionError>>,
    },
    /// Shut the session down.
    Close,
    /// Close one SSE stream without destroying the session.
    CloseSseStream {
        /// The HTTP request ID to close. If `None`, closes the standalone (common) stream.
        http_request_id: Option<HttpRequestId>,
        /// Optional retry interval. If provided, a priming event is sent before closing.
        retry_interval: Option<Duration>,
        responder: oneshot::Sender<Result<(), SessionError>>,
    },
}
689
/// Why a session worker stopped. (Variant semantics follow their names; the
/// worker loop that produces them is defined elsewhere in this file — verify
/// against it when extending.)
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum SessionQuitReason {
    ServiceTerminated,
    ClientTerminated,
    ExpectInitializeRequest,
    ExpectInitializeResponse,
    Cancelled,
}
699
/// Cloneable handle to a running session worker; all operations are relayed
/// over the event channel.
#[derive(Debug, Clone)]
pub struct LocalSessionHandle {
    id: SessionId,
    // after all event_tx drop, inner task will be terminated
    event_tx: Sender<SessionEvent>,
}
706
707impl LocalSessionHandle {
    /// The id of the session this handle controls.
    pub fn id(&self) -> &SessionId {
        &self.id
    }
712
713    /// Close the session
714    pub async fn close(&self) -> Result<(), SessionError> {
715        self.event_tx
716            .send(SessionEvent::Close)
717            .await
718            .map_err(|_| SessionError::SessionServiceTerminated)?;
719        Ok(())
720    }
721
722    /// Send a message to the session
723    pub async fn push_message(
724        &self,
725        message: ClientJsonRpcMessage,
726        http_request_id: Option<HttpRequestId>,
727    ) -> Result<(), SessionError> {
728        self.event_tx
729            .send(SessionEvent::ClientMessage {
730                message,
731                http_request_id,
732            })
733            .await
734            .map_err(|_| SessionError::SessionServiceTerminated)?;
735        Ok(())
736    }
737
738    /// establish a channel for a http-request, the corresponded message from server will be
739    /// sent through this channel. The channel will be closed when the request is completed,
740    /// or you can close it manually by calling [`LocalSessionHandle::close_request_wise_channel`].
741    pub async fn establish_request_wise_channel(
742        &self,
743    ) -> Result<StreamableHttpMessageReceiver, SessionError> {
744        let (tx, rx) = tokio::sync::oneshot::channel();
745        self.event_tx
746            .send(SessionEvent::EstablishRequestWiseChannel { responder: tx })
747            .await
748            .map_err(|_| SessionError::SessionServiceTerminated)?;
749        rx.await
750            .map_err(|_| SessionError::SessionServiceTerminated)?
751    }
752
753    /// close the http-request wise channel.
754    pub async fn close_request_wise_channel(
755        &self,
756        request_id: HttpRequestId,
757    ) -> Result<(), SessionError> {
758        let (tx, rx) = tokio::sync::oneshot::channel();
759        self.event_tx
760            .send(SessionEvent::CloseRequestWiseChannel {
761                id: request_id,
762                responder: tx,
763            })
764            .await
765            .map_err(|_| SessionError::SessionServiceTerminated)?;
766        rx.await
767            .map_err(|_| SessionError::SessionServiceTerminated)?
768    }
769
770    /// Establish a common channel for general purpose messages.
771    pub async fn establish_common_channel(
772        &self,
773    ) -> Result<StreamableHttpMessageReceiver, SessionError> {
774        let (tx, rx) = tokio::sync::oneshot::channel();
775        self.event_tx
776            .send(SessionEvent::Resume {
777                last_event_id: EventId {
778                    http_request_id: None,
779                    index: 0,
780                },
781                responder: tx,
782            })
783            .await
784            .map_err(|_| SessionError::SessionServiceTerminated)?;
785        rx.await
786            .map_err(|_| SessionError::SessionServiceTerminated)?
787    }
788
789    /// Resume streaming response by the last event id. This is suitable for both request wise and common channel.
790    pub async fn resume(
791        &self,
792        last_event_id: EventId,
793    ) -> Result<StreamableHttpMessageReceiver, SessionError> {
794        let (tx, rx) = tokio::sync::oneshot::channel();
795        self.event_tx
796            .send(SessionEvent::Resume {
797                last_event_id,
798                responder: tx,
799            })
800            .await
801            .map_err(|_| SessionError::SessionServiceTerminated)?;
802        rx.await
803            .map_err(|_| SessionError::SessionServiceTerminated)?
804    }
805
806    /// Send an initialize request to the session. And wait for the initialized response.
807    ///
808    /// This is used to establish a session with the server.
809    pub async fn initialize(
810        &self,
811        request: ClientJsonRpcMessage,
812    ) -> Result<ServerJsonRpcMessage, SessionError> {
813        let (tx, rx) = tokio::sync::oneshot::channel();
814        self.event_tx
815            .send(SessionEvent::InitializeRequest {
816                request,
817                responder: tx,
818            })
819            .await
820            .map_err(|_| SessionError::SessionServiceTerminated)?;
821        rx.await
822            .map_err(|_| SessionError::SessionServiceTerminated)?
823    }
824
825    /// Close an SSE stream for a specific request.
826    ///
827    /// This closes the SSE connection for a POST request stream, but keeps the session
828    /// and message cache active. Clients can reconnect using the `Last-Event-ID` header
829    /// via a GET request to resume receiving messages.
830    ///
831    /// # Arguments
832    ///
833    /// * `http_request_id` - The HTTP request ID of the stream to close
834    /// * `retry_interval` - Optional retry interval. If provided, a priming event is sent
835    pub async fn close_sse_stream(
836        &self,
837        http_request_id: HttpRequestId,
838        retry_interval: Option<Duration>,
839    ) -> Result<(), SessionError> {
840        let (tx, rx) = tokio::sync::oneshot::channel();
841        self.event_tx
842            .send(SessionEvent::CloseSseStream {
843                http_request_id: Some(http_request_id),
844                retry_interval,
845                responder: tx,
846            })
847            .await
848            .map_err(|_| SessionError::SessionServiceTerminated)?;
849        rx.await
850            .map_err(|_| SessionError::SessionServiceTerminated)?
851    }
852
853    /// Close the standalone SSE stream.
854    ///
855    /// This closes the standalone SSE connection (established via GET request),
856    /// but keeps the session and message cache active. Clients can reconnect using
857    /// the `Last-Event-ID` header via a GET request to resume receiving messages.
858    ///
859    /// # Arguments
860    ///
861    /// * `retry_interval` - Optional retry interval. If provided, a priming event is sent
862    pub async fn close_standalone_sse_stream(
863        &self,
864        retry_interval: Option<Duration>,
865    ) -> Result<(), SessionError> {
866        let (tx, rx) = tokio::sync::oneshot::channel();
867        self.event_tx
868            .send(SessionEvent::CloseSseStream {
869                http_request_id: None,
870                retry_interval,
871                responder: tx,
872            })
873            .await
874            .map_err(|_| SessionError::SessionServiceTerminated)?;
875        rx.await
876            .map_err(|_| SessionError::SessionServiceTerminated)?
877    }
878}
879
/// The transport type produced for each local session: a [`WorkerTransport`]
/// that drives a [`LocalSessionWorker`].
pub type SessionTransport = WorkerTransport<LocalSessionWorker>;
881
/// Errors that can terminate a [`LocalSessionWorker`]'s run loop.
#[allow(clippy::large_enum_variant)]
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum LocalSessionWorkerError {
    /// The session event channel closed while the worker was still waiting
    /// for events.
    #[error("transport terminated")]
    TransportTerminated,
    /// Received a session event that is invalid in the current state
    /// (e.g. anything other than an initialize request at startup).
    #[error("unexpected message: {0:?}")]
    UnexpectedEvent(SessionEvent),
    /// The HTTP side dropped its responder before the initialize response
    /// could be delivered.
    #[error("fail to send initialize request {0}")]
    FailToSendInitializeRequest(SessionError),
    /// Dispatching a server-side message into the session's streams failed.
    #[error("fail to handle message: {0}")]
    FailToHandleMessage(SessionError),
    /// No session event arrived within the configured keep-alive window.
    #[error("keep alive timeout after {}ms", _0.as_millis())]
    KeepAliveTimeout(Duration),
    /// The transport was closed (see [`Worker::err_closed`]).
    #[error("Transport closed")]
    TransportClosed,
    /// A spawned task backing the worker failed to join.
    #[error("Tokio join error {0}")]
    TokioJoinError(#[from] tokio::task::JoinError),
}
impl Worker for LocalSessionWorker {
    type Error = LocalSessionWorkerError;
    type Role = RoleServer;
    // Error used by the worker framework when the transport channel is closed.
    fn err_closed() -> Self::Error {
        LocalSessionWorkerError::TransportClosed
    }
    // Error used by the worker framework when the worker task fails to join.
    fn err_join(e: tokio::task::JoinError) -> Self::Error {
        LocalSessionWorkerError::TokioJoinError(e)
    }
    fn config(&self) -> crate::transport::worker::WorkerConfig {
        crate::transport::worker::WorkerConfig {
            name: Some(format!("streamable-http-session-{}", self.id)),
            channel_buffer_capacity: self.session_config.channel_capacity,
        }
    }
    /// Main session loop: performs the initialize handshake, then multiplexes
    /// events from the HTTP service and the MCP handler until the session is
    /// closed, cancelled, or times out.
    #[instrument(name = "streamable_http_session", skip_all, fields(id = self.id.as_ref()))]
    async fn run(
        mut self,
        mut context: WorkerContext<Self>,
    ) -> Result<(), WorkerQuitReason<Self::Error>> {
        // Internal tag merging the two event sources polled by the select!
        // below: the HTTP front-end and the MCP handler side.
        enum InnerEvent {
            FromHttpService(SessionEvent),
            FromHandler(WorkerSendRequest<LocalSessionWorker>),
        }
        // waiting for initialize request — the very first event MUST be an
        // InitializeRequest; anything else is fatal for this session.
        let evt = self.event_rx.recv().await.ok_or_else(|| {
            WorkerQuitReason::fatal(
                LocalSessionWorkerError::TransportTerminated,
                "get initialize request",
            )
        })?;
        let SessionEvent::InitializeRequest { request, responder } = evt else {
            return Err(WorkerQuitReason::fatal(
                LocalSessionWorkerError::UnexpectedEvent(evt),
                "get initialize request",
            ));
        };
        // Forward the initialize request to the MCP handler, then relay the
        // handler's response back to the HTTP service waiting on `responder`.
        context.send_to_handler(request).await?;
        let send_initialize_response = context.recv_from_handler().await?;
        responder
            .send(Ok(send_initialize_response.message))
            .map_err(|_| {
                WorkerQuitReason::fatal(
                    LocalSessionWorkerError::FailToSendInitializeRequest(
                        SessionError::SessionServiceTerminated,
                    ),
                    "send initialize response",
                )
            })?;
        // Ack the handler so it knows its response was delivered.
        send_initialize_response
            .responder
            .send(Ok(()))
            .map_err(|_| WorkerQuitReason::HandlerTerminated)?;
        let ct = context.cancellation_token.clone();
        // No configured keep-alive means the timeout arm effectively never
        // fires (Duration::MAX).
        let keep_alive = self.session_config.keep_alive.unwrap_or(Duration::MAX);
        loop {
            // Recreated every iteration, so any session activity resets the
            // keep-alive timer.
            let keep_alive_timeout = tokio::time::sleep(keep_alive);
            let event = tokio::select! {
                event = self.event_rx.recv() => {
                    if let Some(event) = event {
                        InnerEvent::FromHttpService(event)
                    } else {
                        return Err(WorkerQuitReason::fatal(LocalSessionWorkerError::TransportTerminated, "waiting next session event"))
                    }
                },
                from_handler = context.recv_from_handler() => {
                    InnerEvent::FromHandler(from_handler?)
                }
                _ = ct.cancelled() => {
                    return Err(WorkerQuitReason::Cancelled)
                }
                _ = keep_alive_timeout => {
                    return Err(WorkerQuitReason::fatal(LocalSessionWorkerError::KeepAliveTimeout(keep_alive), "poll next session event"))
                }
            };
            match event {
                InnerEvent::FromHandler(WorkerSendRequest { message, responder }) => {
                    // catch response: a Response/Error answers a pending MCP
                    // request, so that request's resource can be unregistered
                    // after the message is dispatched below.
                    let to_unregister = match &message {
                        crate::model::JsonRpcMessage::Response(json_rpc_response) => {
                            let request_id = json_rpc_response.id.clone();
                            Some(ResourceKey::McpRequestId(request_id))
                        }
                        crate::model::JsonRpcMessage::Error(json_rpc_error) => {
                            let request_id = json_rpc_error.id.clone();
                            Some(ResourceKey::McpRequestId(request_id))
                        }
                        _ => {
                            None
                            // no need to unregister resource
                        }
                    };
                    let handle_result = self
                        .handle_server_message(message)
                        .await
                        .map_err(LocalSessionWorkerError::FailToHandleMessage);
                    let _ = responder.send(handle_result).inspect_err(|error| {
                        tracing::warn!(?error, "failed to send message to http service handler");
                    });
                    if let Some(to_unregister) = to_unregister {
                        self.unregister_resource(&to_unregister);
                    }
                }
                InnerEvent::FromHttpService(SessionEvent::ClientMessage {
                    message: json_rpc_message,
                    http_request_id,
                }) => {
                    // Track requests (for routing responses back to the right
                    // HTTP stream) and inspect notifications for cancellation
                    // before handing the message to the MCP handler.
                    match &json_rpc_message {
                        crate::model::JsonRpcMessage::Request(request) => {
                            if let Some(http_request_id) = http_request_id {
                                self.register_request(request, http_request_id)
                            }
                        }
                        crate::model::JsonRpcMessage::Notification(notification) => {
                            self.catch_cancellation_notification(notification)
                        }
                        _ => {}
                    }
                    context.send_to_handler(json_rpc_message).await?;
                }
                InnerEvent::FromHttpService(SessionEvent::EstablishRequestWiseChannel {
                    responder,
                }) => {
                    let handle_result = self.establish_request_wise_channel().await;
                    let _ = responder.send(handle_result);
                }
                InnerEvent::FromHttpService(SessionEvent::CloseRequestWiseChannel {
                    id,
                    responder,
                }) => {
                    // Dropping the routed sender closes that request's stream.
                    let _handle_result = self.tx_router.remove(&id);
                    let _ = responder.send(Ok(()));
                }
                InnerEvent::FromHttpService(SessionEvent::Resume {
                    last_event_id,
                    responder,
                }) => {
                    let handle_result = self.resume(last_event_id).await;
                    let _ = responder.send(handle_result);
                }
                InnerEvent::FromHttpService(SessionEvent::Close) => {
                    return Err(WorkerQuitReason::TransportClosed);
                }
                InnerEvent::FromHttpService(SessionEvent::CloseSseStream {
                    http_request_id,
                    retry_interval,
                    responder,
                }) => {
                    let handle_result =
                        self.close_sse_stream(http_request_id, retry_interval).await;
                    let _ = responder.send(handle_result);
                }
                _ => {
                    // ignore any other session events in the steady state
                    // (e.g. a late InitializeRequest).
                }
            }
        }
    }
}
1060
/// Tuning knobs for a local session, shared by the handle and its worker.
#[derive(Debug, Clone)]
pub struct SessionConfig {
    /// Capacity of the session's mpsc channels. Default is 16
    /// ([`SessionConfig::DEFAULT_CHANNEL_CAPACITY`]).
    pub channel_capacity: usize,
    /// If set, the session will be closed after this duration of inactivity.
    pub keep_alive: Option<Duration>,
}
1068
impl SessionConfig {
    /// Default buffer size for the session's mpsc channels.
    pub const DEFAULT_CHANNEL_CAPACITY: usize = 16;
}
1072
1073impl Default for SessionConfig {
1074    fn default() -> Self {
1075        Self {
1076            channel_capacity: Self::DEFAULT_CHANNEL_CAPACITY,
1077            keep_alive: None,
1078        }
1079    }
1080}
1081
1082/// Create a new session with the given id and configuration.
1083///
1084/// This function will return a pair of [`LocalSessionHandle`] and [`LocalSessionWorker`].
1085///
1086/// You can run the [`LocalSessionWorker`] as a transport for mcp server. And use the [`LocalSessionHandle`] operate the session.
1087pub fn create_local_session(
1088    id: impl Into<SessionId>,
1089    config: SessionConfig,
1090) -> (LocalSessionHandle, LocalSessionWorker) {
1091    let id = id.into();
1092    let (event_tx, event_rx) = tokio::sync::mpsc::channel(config.channel_capacity);
1093    let (common_tx, _) = tokio::sync::mpsc::channel(config.channel_capacity);
1094    let common = CachedTx::new_common(common_tx);
1095    tracing::info!(session_id = ?id, "create new session");
1096    let handle = LocalSessionHandle {
1097        event_tx,
1098        id: id.clone(),
1099    };
1100    let session_worker = LocalSessionWorker {
1101        next_http_request_id: 0,
1102        id,
1103        tx_router: HashMap::new(),
1104        resource_router: HashMap::new(),
1105        common,
1106        shadow_txs: Vec::new(),
1107        event_rx,
1108        session_config: config.clone(),
1109    };
1110    (handle, session_worker)
1111}