1use std::{
2 collections::{HashMap, HashSet, VecDeque},
3 num::ParseIntError,
4 sync::Arc,
5 time::Duration,
6};
7
8use futures::Stream;
9use thiserror::Error;
10use tokio::sync::{
11 mpsc::{Receiver, Sender},
12 oneshot,
13};
14use tokio_stream::wrappers::ReceiverStream;
15use tracing::instrument;
16
17use crate::{
18 RoleServer,
19 model::{
20 CancelledNotificationParam, ClientJsonRpcMessage, ClientNotification, ClientRequest,
21 JsonRpcNotification, JsonRpcRequest, Notification, ProgressNotificationParam,
22 ProgressToken, RequestId, ServerJsonRpcMessage, ServerNotification,
23 },
24 transport::{
25 WorkerTransport,
26 common::server_side_http::{SessionId, session_id},
27 worker::{Worker, WorkerContext, WorkerQuitReason, WorkerSendRequest},
28 },
29};
30
/// Session manager that keeps every session inside the current process.
///
/// Sessions are stored in a map keyed by [`SessionId`]; each value is the
/// [`LocalSessionHandle`] used to send events to that session's worker.
#[derive(Debug, Default)]
pub struct LocalSessionManager {
    /// Registered sessions, keyed by session id.
    pub sessions: tokio::sync::RwLock<HashMap<SessionId, LocalSessionHandle>>,
    /// Configuration cloned into every session this manager creates.
    pub session_config: SessionConfig,
}
36
/// Errors returned by the [`SessionManager`] implementation of [`LocalSessionManager`].
#[derive(Debug, Error)]
pub enum LocalSessionManagerError {
    /// No session with the given id is registered in the manager.
    #[error("Session not found: {0}")]
    SessionNotFound(SessionId),
    /// The session itself reported an error.
    #[error("Session error: {0}")]
    SessionError(#[from] SessionError),
    /// The `Last-Event-ID` style string handed to `resume` could not be parsed.
    #[error("Invalid event id: {0}")]
    InvalidEventId(#[from] EventIdParseError),
}
46impl SessionManager for LocalSessionManager {
47 type Error = LocalSessionManagerError;
48 type Transport = WorkerTransport<LocalSessionWorker>;
49 async fn create_session(&self) -> Result<(SessionId, Self::Transport), Self::Error> {
50 let id = session_id();
51 let (handle, worker) = create_local_session(id.clone(), self.session_config.clone());
52 self.sessions.write().await.insert(id.clone(), handle);
53 Ok((id, WorkerTransport::spawn(worker)))
54 }
55 async fn initialize_session(
56 &self,
57 id: &SessionId,
58 message: ClientJsonRpcMessage,
59 ) -> Result<ServerJsonRpcMessage, Self::Error> {
60 let sessions = self.sessions.read().await;
61 let handle = sessions
62 .get(id)
63 .ok_or(LocalSessionManagerError::SessionNotFound(id.clone()))?;
64 let response = handle.initialize(message).await?;
65 Ok(response)
66 }
67 async fn close_session(&self, id: &SessionId) -> Result<(), Self::Error> {
68 let mut sessions = self.sessions.write().await;
69 if let Some(handle) = sessions.remove(id) {
70 handle.close().await?;
71 }
72 Ok(())
73 }
74 async fn has_session(&self, id: &SessionId) -> Result<bool, Self::Error> {
75 let sessions = self.sessions.read().await;
76 Ok(sessions.contains_key(id))
77 }
78 async fn create_stream(
79 &self,
80 id: &SessionId,
81 message: ClientJsonRpcMessage,
82 ) -> Result<impl Stream<Item = ServerSseMessage> + Send + 'static, Self::Error> {
83 let sessions = self.sessions.read().await;
84 let handle = sessions
85 .get(id)
86 .ok_or(LocalSessionManagerError::SessionNotFound(id.clone()))?;
87 let receiver = handle.establish_request_wise_channel().await?;
88 handle
89 .push_message(message, receiver.http_request_id)
90 .await?;
91 Ok(ReceiverStream::new(receiver.inner))
92 }
93
94 async fn create_standalone_stream(
95 &self,
96 id: &SessionId,
97 ) -> Result<impl Stream<Item = ServerSseMessage> + Send + 'static, Self::Error> {
98 let sessions = self.sessions.read().await;
99 let handle = sessions
100 .get(id)
101 .ok_or(LocalSessionManagerError::SessionNotFound(id.clone()))?;
102 let receiver = handle.establish_common_channel().await?;
103 Ok(ReceiverStream::new(receiver.inner))
104 }
105
106 async fn resume(
107 &self,
108 id: &SessionId,
109 last_event_id: String,
110 ) -> Result<impl Stream<Item = ServerSseMessage> + Send + 'static, Self::Error> {
111 let sessions = self.sessions.read().await;
112 let handle = sessions
113 .get(id)
114 .ok_or(LocalSessionManagerError::SessionNotFound(id.clone()))?;
115 let receiver = handle.resume(last_event_id.parse()?).await?;
116 Ok(ReceiverStream::new(receiver.inner))
117 }
118
119 async fn accept_message(
120 &self,
121 id: &SessionId,
122 message: ClientJsonRpcMessage,
123 ) -> Result<(), Self::Error> {
124 let sessions = self.sessions.read().await;
125 let handle = sessions
126 .get(id)
127 .ok_or(LocalSessionManagerError::SessionNotFound(id.clone()))?;
128 handle.push_message(message, None).await?;
129 Ok(())
130 }
131}
132
/// Identifier stamped on every SSE message so a stream can be resumed later.
///
/// Its text form is `index` when there is no request id and
/// `index/http_request_id` otherwise (see the `Display` and `FromStr` impls).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct EventId {
    /// Request-wise channel the event belongs to; `None` for the common channel.
    http_request_id: Option<HttpRequestId>,
    /// Position of the event within its channel.
    index: usize,
}
139
140impl std::fmt::Display for EventId {
141 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
142 write!(f, "{}", self.index)?;
143 match &self.http_request_id {
144 Some(http_request_id) => write!(f, "/{http_request_id}"),
145 None => write!(f, ""),
146 }
147 }
148}
149
/// Reasons a string could not be parsed into an [`EventId`].
#[derive(Debug, Clone, Error)]
pub enum EventIdParseError {
    /// The index part is not a valid `usize`.
    #[error("Invalid index: {0}")]
    InvalidIndex(ParseIntError),
    /// The request-id part is not a valid number.
    #[error("Invalid numeric request id: {0}")]
    InvalidNumericRequestId(ParseIntError),
    // NOTE(review): not constructed anywhere in the visible part of this file.
    #[error("Missing request id type")]
    InvalidRequestIdType,
    // NOTE(review): not constructed anywhere in the visible part of this file.
    #[error("Missing request id")]
    MissingRequestId,
}
161
162impl std::str::FromStr for EventId {
163 type Err = EventIdParseError;
164 fn from_str(s: &str) -> Result<Self, Self::Err> {
165 if let Some((index, request_id)) = s.split_once("/") {
166 let index = usize::from_str(index).map_err(EventIdParseError::InvalidIndex)?;
167 let request_id = u64::from_str(request_id).map_err(EventIdParseError::InvalidIndex)?;
168 Ok(EventId {
169 http_request_id: Some(request_id),
170 index,
171 })
172 } else {
173 let index = usize::from_str(s).map_err(EventIdParseError::InvalidIndex)?;
174 Ok(EventId {
175 http_request_id: None,
176 index,
177 })
178 }
179 }
180}
181
182use super::{ServerSseMessage, SessionManager};
183
/// Outgoing SSE sender paired with a bounded cache of the messages it sent,
/// so they can be replayed when the stream is resumed.
struct CachedTx {
    /// Sender for the currently attached stream; replaced on resume and close.
    tx: Sender<ServerSseMessage>,
    /// Most recent messages, oldest first, holding at most `capacity` entries.
    cache: VecDeque<ServerSseMessage>,
    /// Request-wise channel id stamped into event ids; `None` for the common channel.
    http_request_id: Option<HttpRequestId>,
    /// Maximum number of cached messages, taken from `tx.capacity()` at construction.
    capacity: usize,
}
190
191impl CachedTx {
192 fn new(tx: Sender<ServerSseMessage>, http_request_id: Option<HttpRequestId>) -> Self {
193 Self {
194 cache: VecDeque::with_capacity(tx.capacity()),
195 capacity: tx.capacity(),
196 tx,
197 http_request_id,
198 }
199 }
200 fn new_common(tx: Sender<ServerSseMessage>) -> Self {
201 Self::new(tx, None)
202 }
203
204 fn next_event_id(&self) -> EventId {
205 let index = self.cache.back().map_or(0, |m| {
206 m.event_id
207 .as_deref()
208 .unwrap_or_default()
209 .parse::<EventId>()
210 .expect("valid event id")
211 .index
212 + 1
213 });
214 EventId {
215 http_request_id: self.http_request_id,
216 index,
217 }
218 }
219
220 async fn send(&mut self, message: ServerJsonRpcMessage) {
221 let event_id = self.next_event_id();
222 let message = ServerSseMessage {
223 event_id: Some(event_id.to_string()),
224 message: Some(Arc::new(message)),
225 retry: None,
226 };
227 self.cache_and_send(message).await;
228 }
229
230 async fn send_priming(&mut self, retry: Duration) {
231 let event_id = self.next_event_id();
232 let message = ServerSseMessage {
233 event_id: Some(event_id.to_string()),
234 message: None,
235 retry: Some(retry),
236 };
237 self.cache_and_send(message).await;
238 }
239
240 async fn cache_and_send(&mut self, message: ServerSseMessage) {
241 if self.cache.len() >= self.capacity {
242 self.cache.pop_front();
243 self.cache.push_back(message.clone());
244 } else {
245 self.cache.push_back(message.clone());
246 }
247 let _ = self.tx.send(message).await.inspect_err(|e| {
248 let event_id = &e.0.event_id;
249 tracing::trace!(?event_id, "trying to send message in a closed session")
250 });
251 }
252
253 async fn sync(&mut self, index: usize) -> Result<(), SessionError> {
254 let Some(front) = self.cache.front() else {
255 return Ok(());
256 };
257 let front_event_id = front
258 .event_id
259 .as_deref()
260 .unwrap_or_default()
261 .parse::<EventId>()?;
262 let sync_index = index.saturating_sub(front_event_id.index);
263 if sync_index > self.cache.len() {
264 return Err(SessionError::InvalidEventId);
266 }
267 for message in self.cache.iter().skip(sync_index) {
268 let send_result = self.tx.send(message.clone()).await;
269 if send_result.is_err() {
270 let event_id: EventId = message.event_id.as_deref().unwrap_or_default().parse()?;
271 return Err(SessionError::ChannelClosed(Some(event_id.index as u64)));
272 }
273 }
274 Ok(())
275 }
276}
277
/// State kept for one request-wise channel.
struct HttpRequestWise {
    /// Resource keys registered against this channel.
    resources: HashSet<ResourceKey>,
    /// Cached sender feeding the channel's SSE stream.
    tx: CachedTx,
}
282
/// Identifier of a request-wise channel, allocated by the session worker.
type HttpRequestId = u64;
/// Key used to route a server message back to the request-wise channel that
/// registered it.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
enum ResourceKey {
    /// Id of a client JSON-RPC request.
    McpRequestId(RequestId),
    /// Progress token taken from a client request's metadata.
    ProgressToken(ProgressToken),
}
289
/// Worker that owns the state of a single session and processes its events.
pub struct LocalSessionWorker {
    /// Id of the session this worker serves.
    id: SessionId,
    /// Next id to hand out for a request-wise channel.
    next_http_request_id: HttpRequestId,
    /// Open request-wise channels, keyed by their id.
    tx_router: HashMap<HttpRequestId, HttpRequestWise>,
    /// Maps a request id or progress token to its request-wise channel.
    resource_router: HashMap<ResourceKey, HttpRequestId>,
    /// The common (standalone) channel.
    common: CachedTx,
    /// Senders for extra standalone streams opened while the common channel is
    /// still alive. Nothing in this file sends on them; they are only stored,
    /// pruned when closed, evicted when over the limit, and cleared.
    shadow_txs: Vec<Sender<ServerSseMessage>>,
    /// Events sent by the session's [`LocalSessionHandle`].
    event_rx: Receiver<SessionEvent>,
    /// Channel capacity and keep-alive settings for this session.
    session_config: SessionConfig,
}
305
impl LocalSessionWorker {
    /// Returns the id of the session this worker serves.
    pub fn id(&self) -> &SessionId {
        &self.id
    }
}
311
/// Errors produced while operating on a single session.
#[derive(Debug, Error)]
pub enum SessionError {
    // NOTE(review): not constructed in the visible part of this file, and the
    // message says "Invalid" although the variant is named "Duplicated".
    #[error("Invalid request id: {0}")]
    DuplicatedRequestId(HttpRequestId),
    /// The target channel does not exist or its receiver has been dropped.
    #[error("Channel closed: {0:?}")]
    ChannelClosed(Option<HttpRequestId>),
    /// An event id string could not be parsed.
    #[error("Cannot parse event id: {0}")]
    EventIdParseError(#[from] EventIdParseError),
    /// The session worker is gone: its event channel or a reply channel closed.
    #[error("Session service terminated")]
    SessionServiceTerminated,
    /// The requested event index lies beyond the cached events.
    #[error("Invalid event id")]
    InvalidEventId,
    /// Wrapped I/O error.
    #[error("IO error: {0}")]
    Io(#[from] std::io::Error),
}
327
328impl From<SessionError> for std::io::Error {
329 fn from(value: SessionError) -> Self {
330 match value {
331 SessionError::Io(io) => io,
332 _ => std::io::Error::other(format!("Session error: {value}")),
333 }
334 }
335}
336
/// Destination chosen for a server-to-client message.
enum OutboundChannel {
    /// Send on request-wise channel `id`; when `close` is set the channel is
    /// removed from the router after the send.
    RequestWise { id: HttpRequestId, close: bool },
    /// Send on the common channel.
    Common,
}
/// Receiving end of an SSE channel handed back to the HTTP layer.
#[derive(Debug)]
pub struct StreamableHttpMessageReceiver {
    /// Id of the request-wise channel, or `None` for a common/standalone stream.
    pub http_request_id: Option<HttpRequestId>,
    /// Receiver yielding the messages to emit on the stream.
    pub inner: Receiver<ServerSseMessage>,
}
346
347impl LocalSessionWorker {
348 fn unregister_resource(&mut self, resource: &ResourceKey) {
349 if let Some(http_request_id) = self.resource_router.remove(resource) {
350 tracing::trace!(?resource, http_request_id, "unregister resource");
351 if let Some(channel) = self.tx_router.get_mut(&http_request_id) {
352 if channel.resources.is_empty() || matches!(resource, ResourceKey::McpRequestId(_))
355 {
356 tracing::debug!(http_request_id, "close http request wise channel");
357 if let Some(channel) = self.tx_router.remove(&http_request_id) {
358 for resource in channel.resources {
359 self.resource_router.remove(&resource);
360 }
361 }
362 }
363 } else {
364 tracing::warn!(http_request_id, "http request wise channel not found");
365 }
366 }
367 }
368 fn register_resource(&mut self, resource: ResourceKey, http_request_id: HttpRequestId) {
369 tracing::trace!(?resource, http_request_id, "register resource");
370 if let Some(channel) = self.tx_router.get_mut(&http_request_id) {
371 channel.resources.insert(resource.clone());
372 self.resource_router.insert(resource, http_request_id);
373 }
374 }
375 fn register_request(
376 &mut self,
377 request: &JsonRpcRequest<ClientRequest>,
378 http_request_id: HttpRequestId,
379 ) {
380 use crate::model::GetMeta;
381 self.register_resource(
382 ResourceKey::McpRequestId(request.id.clone()),
383 http_request_id,
384 );
385 if let Some(progress_token) = request.request.get_meta().get_progress_token() {
386 self.register_resource(
387 ResourceKey::ProgressToken(progress_token.clone()),
388 http_request_id,
389 );
390 }
391 }
392 fn catch_cancellation_notification(
393 &mut self,
394 notification: &JsonRpcNotification<ClientNotification>,
395 ) {
396 if let ClientNotification::CancelledNotification(n) = ¬ification.notification {
397 let request_id = n.params.request_id.clone();
398 let resource = ResourceKey::McpRequestId(request_id);
399 self.unregister_resource(&resource);
400 }
401 }
402 fn next_http_request_id(&mut self) -> HttpRequestId {
403 let id = self.next_http_request_id;
404 self.next_http_request_id = self.next_http_request_id.wrapping_add(1);
405 id
406 }
407 async fn establish_request_wise_channel(
408 &mut self,
409 ) -> Result<StreamableHttpMessageReceiver, SessionError> {
410 let http_request_id = self.next_http_request_id();
411 let (tx, rx) = tokio::sync::mpsc::channel(self.session_config.channel_capacity);
412 self.tx_router.insert(
413 http_request_id,
414 HttpRequestWise {
415 resources: Default::default(),
416 tx: CachedTx::new(tx, Some(http_request_id)),
417 },
418 );
419 tracing::debug!(http_request_id, "establish new request wise channel");
420 Ok(StreamableHttpMessageReceiver {
421 http_request_id: Some(http_request_id),
422 inner: rx,
423 })
424 }
425 fn resolve_outbound_channel(&self, message: &ServerJsonRpcMessage) -> OutboundChannel {
426 match &message {
427 ServerJsonRpcMessage::Request(_) => OutboundChannel::Common,
428 ServerJsonRpcMessage::Notification(JsonRpcNotification {
429 notification:
430 ServerNotification::ProgressNotification(Notification {
431 params: ProgressNotificationParam { progress_token, .. },
432 ..
433 }),
434 ..
435 }) => {
436 let id = self
437 .resource_router
438 .get(&ResourceKey::ProgressToken(progress_token.clone()));
439
440 if let Some(id) = id {
441 OutboundChannel::RequestWise {
442 id: *id,
443 close: false,
444 }
445 } else {
446 OutboundChannel::Common
447 }
448 }
449 ServerJsonRpcMessage::Notification(JsonRpcNotification {
450 notification:
451 ServerNotification::CancelledNotification(Notification {
452 params: CancelledNotificationParam { request_id, .. },
453 ..
454 }),
455 ..
456 }) => {
457 if let Some(id) = self
458 .resource_router
459 .get(&ResourceKey::McpRequestId(request_id.clone()))
460 {
461 OutboundChannel::RequestWise {
462 id: *id,
463 close: false,
464 }
465 } else {
466 OutboundChannel::Common
467 }
468 }
469 ServerJsonRpcMessage::Notification(_) => OutboundChannel::Common,
470 ServerJsonRpcMessage::Response(json_rpc_response) => {
471 if let Some(id) = self
472 .resource_router
473 .get(&ResourceKey::McpRequestId(json_rpc_response.id.clone()))
474 {
475 OutboundChannel::RequestWise {
476 id: *id,
477 close: false,
478 }
479 } else {
480 OutboundChannel::Common
481 }
482 }
483 ServerJsonRpcMessage::Error(json_rpc_error) => {
484 if let Some(id) = self
485 .resource_router
486 .get(&ResourceKey::McpRequestId(json_rpc_error.id.clone()))
487 {
488 OutboundChannel::RequestWise {
489 id: *id,
490 close: false,
491 }
492 } else {
493 OutboundChannel::Common
494 }
495 }
496 }
497 }
498 async fn handle_server_message(
499 &mut self,
500 message: ServerJsonRpcMessage,
501 ) -> Result<(), SessionError> {
502 let outbound_channel = self.resolve_outbound_channel(&message);
503 match outbound_channel {
504 OutboundChannel::RequestWise { id, close } => {
505 if let Some(request_wise) = self.tx_router.get_mut(&id) {
506 request_wise.tx.send(message).await;
507 if close {
508 self.tx_router.remove(&id);
509 }
510 } else {
511 return Err(SessionError::ChannelClosed(Some(id)));
512 }
513 }
514 OutboundChannel::Common => self.common.send(message).await,
515 }
516 Ok(())
517 }
518 async fn resume(
519 &mut self,
520 last_event_id: EventId,
521 ) -> Result<StreamableHttpMessageReceiver, SessionError> {
522 self.shadow_txs.retain(|tx| !tx.is_closed());
524
525 match last_event_id.http_request_id {
526 Some(http_request_id) => {
527 if let Some(request_wise) = self.tx_router.get_mut(&http_request_id) {
528 let channel = tokio::sync::mpsc::channel(self.session_config.channel_capacity);
530 let (tx, rx) = channel;
531 request_wise.tx.tx = tx;
532 let index = last_event_id.index;
533 request_wise.tx.sync(index).await?;
535 Ok(StreamableHttpMessageReceiver {
536 http_request_id: Some(http_request_id),
537 inner: rx,
538 })
539 } else {
540 tracing::debug!(
544 http_request_id,
545 "Request-wise channel completed, falling back to common channel"
546 );
547 self.resume_or_shadow_common(last_event_id.index).await
548 }
549 }
550 None => self.resume_or_shadow_common(last_event_id.index).await,
551 }
552 }
553
554 async fn resume_or_shadow_common(
567 &mut self,
568 last_event_index: usize,
569 ) -> Result<StreamableHttpMessageReceiver, SessionError> {
570 let is_replacing_dead_primary = self.common.tx.is_closed();
571 let capacity = if is_replacing_dead_primary {
572 self.session_config.channel_capacity
573 } else {
574 1 };
576 let (tx, rx) = tokio::sync::mpsc::channel(capacity);
577 if is_replacing_dead_primary {
578 tracing::debug!("Replacing dead common channel with new primary");
580 self.common.tx = tx;
581 self.common.sync(last_event_index).await?;
584 } else {
585 const MAX_SHADOW_STREAMS: usize = 32;
590
591 if self.shadow_txs.len() >= MAX_SHADOW_STREAMS {
592 tracing::warn!(
593 shadow_count = self.shadow_txs.len(),
594 "Shadow stream limit reached, dropping oldest"
595 );
596 self.shadow_txs.remove(0);
597 }
598 tracing::debug!(
599 shadow_count = self.shadow_txs.len(),
600 "Common channel active, creating shadow stream"
601 );
602 self.shadow_txs.push(tx);
603 }
604 Ok(StreamableHttpMessageReceiver {
605 http_request_id: None,
606 inner: rx,
607 })
608 }
609
610 async fn close_sse_stream(
611 &mut self,
612 http_request_id: Option<HttpRequestId>,
613 retry_interval: Option<Duration>,
614 ) -> Result<(), SessionError> {
615 match http_request_id {
616 Some(id) => {
618 let request_wise = self
619 .tx_router
620 .get_mut(&id)
621 .ok_or(SessionError::ChannelClosed(Some(id)))?;
622
623 if let Some(interval) = retry_interval {
625 request_wise.tx.send_priming(interval).await;
626 }
627
628 let (tx, _rx) = tokio::sync::mpsc::channel(1);
630 request_wise.tx.tx = tx;
631
632 tracing::debug!(
633 http_request_id = id,
634 "closed SSE stream for server-initiated disconnection"
635 );
636 Ok(())
637 }
638 None => {
640 if let Some(interval) = retry_interval {
642 self.common.send_priming(interval).await;
643 }
644
645 let (tx, _rx) = tokio::sync::mpsc::channel(1);
647 self.common.tx = tx;
648
649 self.shadow_txs.clear();
651
652 tracing::debug!("closed standalone SSE stream for server-initiated disconnection");
653 Ok(())
654 }
655 }
656 }
657}
658
/// Events a [`LocalSessionHandle`] sends to its [`LocalSessionWorker`].
///
/// Variants with a `responder` expect the worker to reply on it.
#[derive(Debug)]
pub enum SessionEvent {
    /// A message from the client, optionally tied to a request-wise channel.
    ClientMessage {
        message: ClientJsonRpcMessage,
        http_request_id: Option<HttpRequestId>,
    },
    /// Open a new request-wise channel and reply with its receiver.
    EstablishRequestWiseChannel {
        responder: oneshot::Sender<Result<StreamableHttpMessageReceiver, SessionError>>,
    },
    /// Remove the request-wise channel `id`.
    CloseRequestWiseChannel {
        id: HttpRequestId,
        responder: oneshot::Sender<Result<(), SessionError>>,
    },
    /// Re-attach a stream after `last_event_id` and reply with its receiver.
    Resume {
        last_event_id: EventId,
        responder: oneshot::Sender<Result<StreamableHttpMessageReceiver, SessionError>>,
    },
    /// The initialize request; the worker expects this as its first event.
    InitializeRequest {
        request: ClientJsonRpcMessage,
        responder: oneshot::Sender<Result<ServerJsonRpcMessage, SessionError>>,
    },
    /// Stop the worker.
    Close,
    /// Disconnect one SSE stream from the server side.
    CloseSseStream {
        /// Request-wise channel to disconnect; `None` means the standalone stream.
        http_request_id: Option<HttpRequestId>,
        /// When set, a priming event carrying this retry interval is sent first.
        retry_interval: Option<Duration>,
        responder: oneshot::Sender<Result<(), SessionError>>,
    },
}
689
/// Reasons a session may stop.
// NOTE(review): nothing in the visible part of this file constructs or matches
// on this type, so the exact meaning of each variant must be confirmed at its
// use sites.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum SessionQuitReason {
    ServiceTerminated,
    ClientTerminated,
    ExpectInitializeRequest,
    ExpectInitializeResponse,
    Cancelled,
}
699
/// Cloneable handle used to send [`SessionEvent`]s to one session's worker.
#[derive(Debug, Clone)]
pub struct LocalSessionHandle {
    /// Id of the session this handle controls.
    id: SessionId,
    /// Sending half of the worker's event channel.
    event_tx: Sender<SessionEvent>,
}
706
707impl LocalSessionHandle {
708 pub fn id(&self) -> &SessionId {
710 &self.id
711 }
712
713 pub async fn close(&self) -> Result<(), SessionError> {
715 self.event_tx
716 .send(SessionEvent::Close)
717 .await
718 .map_err(|_| SessionError::SessionServiceTerminated)?;
719 Ok(())
720 }
721
722 pub async fn push_message(
724 &self,
725 message: ClientJsonRpcMessage,
726 http_request_id: Option<HttpRequestId>,
727 ) -> Result<(), SessionError> {
728 self.event_tx
729 .send(SessionEvent::ClientMessage {
730 message,
731 http_request_id,
732 })
733 .await
734 .map_err(|_| SessionError::SessionServiceTerminated)?;
735 Ok(())
736 }
737
738 pub async fn establish_request_wise_channel(
742 &self,
743 ) -> Result<StreamableHttpMessageReceiver, SessionError> {
744 let (tx, rx) = tokio::sync::oneshot::channel();
745 self.event_tx
746 .send(SessionEvent::EstablishRequestWiseChannel { responder: tx })
747 .await
748 .map_err(|_| SessionError::SessionServiceTerminated)?;
749 rx.await
750 .map_err(|_| SessionError::SessionServiceTerminated)?
751 }
752
753 pub async fn close_request_wise_channel(
755 &self,
756 request_id: HttpRequestId,
757 ) -> Result<(), SessionError> {
758 let (tx, rx) = tokio::sync::oneshot::channel();
759 self.event_tx
760 .send(SessionEvent::CloseRequestWiseChannel {
761 id: request_id,
762 responder: tx,
763 })
764 .await
765 .map_err(|_| SessionError::SessionServiceTerminated)?;
766 rx.await
767 .map_err(|_| SessionError::SessionServiceTerminated)?
768 }
769
770 pub async fn establish_common_channel(
772 &self,
773 ) -> Result<StreamableHttpMessageReceiver, SessionError> {
774 let (tx, rx) = tokio::sync::oneshot::channel();
775 self.event_tx
776 .send(SessionEvent::Resume {
777 last_event_id: EventId {
778 http_request_id: None,
779 index: 0,
780 },
781 responder: tx,
782 })
783 .await
784 .map_err(|_| SessionError::SessionServiceTerminated)?;
785 rx.await
786 .map_err(|_| SessionError::SessionServiceTerminated)?
787 }
788
789 pub async fn resume(
791 &self,
792 last_event_id: EventId,
793 ) -> Result<StreamableHttpMessageReceiver, SessionError> {
794 let (tx, rx) = tokio::sync::oneshot::channel();
795 self.event_tx
796 .send(SessionEvent::Resume {
797 last_event_id,
798 responder: tx,
799 })
800 .await
801 .map_err(|_| SessionError::SessionServiceTerminated)?;
802 rx.await
803 .map_err(|_| SessionError::SessionServiceTerminated)?
804 }
805
806 pub async fn initialize(
810 &self,
811 request: ClientJsonRpcMessage,
812 ) -> Result<ServerJsonRpcMessage, SessionError> {
813 let (tx, rx) = tokio::sync::oneshot::channel();
814 self.event_tx
815 .send(SessionEvent::InitializeRequest {
816 request,
817 responder: tx,
818 })
819 .await
820 .map_err(|_| SessionError::SessionServiceTerminated)?;
821 rx.await
822 .map_err(|_| SessionError::SessionServiceTerminated)?
823 }
824
825 pub async fn close_sse_stream(
836 &self,
837 http_request_id: HttpRequestId,
838 retry_interval: Option<Duration>,
839 ) -> Result<(), SessionError> {
840 let (tx, rx) = tokio::sync::oneshot::channel();
841 self.event_tx
842 .send(SessionEvent::CloseSseStream {
843 http_request_id: Some(http_request_id),
844 retry_interval,
845 responder: tx,
846 })
847 .await
848 .map_err(|_| SessionError::SessionServiceTerminated)?;
849 rx.await
850 .map_err(|_| SessionError::SessionServiceTerminated)?
851 }
852
853 pub async fn close_standalone_sse_stream(
863 &self,
864 retry_interval: Option<Duration>,
865 ) -> Result<(), SessionError> {
866 let (tx, rx) = tokio::sync::oneshot::channel();
867 self.event_tx
868 .send(SessionEvent::CloseSseStream {
869 http_request_id: None,
870 retry_interval,
871 responder: tx,
872 })
873 .await
874 .map_err(|_| SessionError::SessionServiceTerminated)?;
875 rx.await
876 .map_err(|_| SessionError::SessionServiceTerminated)?
877 }
878}
879
/// Transport type produced by spawning a [`LocalSessionWorker`].
pub type SessionTransport = WorkerTransport<LocalSessionWorker>;
881
/// Errors with which a [`LocalSessionWorker`] can terminate or fail a request.
// NOTE(review): the lint allowance is presumably needed because
// `UnexpectedEvent` embeds a whole `SessionEvent` — confirm before removing.
#[allow(clippy::large_enum_variant)]
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum LocalSessionWorkerError {
    /// The session event channel closed: every handle was dropped.
    #[error("transport terminated")]
    TransportTerminated,
    /// The first event received was not an initialize request.
    #[error("unexpected message: {0:?}")]
    UnexpectedEvent(SessionEvent),
    /// The initialize response could not be delivered to the requester.
    #[error("fail to send initialize request {0}")]
    FailToSendInitializeRequest(SessionError),
    /// Routing a server message to its channel failed.
    #[error("fail to handle message: {0}")]
    FailToHandleMessage(SessionError),
    /// No event arrived within the configured keep-alive duration.
    #[error("keep alive timeout after {}ms", _0.as_millis())]
    KeepAliveTimeout(Duration),
    /// Returned by `Worker::err_closed`.
    #[error("Transport closed")]
    TransportClosed,
    /// Returned by `Worker::err_join`, wrapping the task join error.
    #[error("Tokio join error {0}")]
    TokioJoinError(#[from] tokio::task::JoinError),
}
/// Runs the session as a server-role worker: performs the initialize handshake,
/// then shuttles messages between the HTTP layer and the handler.
impl Worker for LocalSessionWorker {
    type Error = LocalSessionWorkerError;
    type Role = RoleServer;
    /// Error reported when the transport is closed.
    fn err_closed() -> Self::Error {
        LocalSessionWorkerError::TransportClosed
    }
    /// Error reported when joining the worker task fails.
    fn err_join(e: tokio::task::JoinError) -> Self::Error {
        LocalSessionWorkerError::TokioJoinError(e)
    }
    /// Names the worker after the session id and reuses the session's channel
    /// capacity for the worker's buffer.
    fn config(&self) -> crate::transport::worker::WorkerConfig {
        crate::transport::worker::WorkerConfig {
            name: Some(format!("streamable-http-session-{}", self.id)),
            channel_buffer_capacity: self.session_config.channel_capacity,
        }
    }
    /// Main loop of the session.
    ///
    /// 1. Waits for the first event, which must be `InitializeRequest`; forwards
    ///    it to the handler and relays the handler's reply to the requester.
    /// 2. Loops over whichever comes first: a session event, a message from the
    ///    handler, cancellation, or the keep-alive timeout.
    ///
    /// The function only ever returns `Err`: every exit path is a
    /// `WorkerQuitReason`.
    #[instrument(name = "streamable_http_session", skip_all, fields(id = self.id.as_ref()))]
    async fn run(
        mut self,
        mut context: WorkerContext<Self>,
    ) -> Result<(), WorkerQuitReason<Self::Error>> {
        // Unifies the two event sources handled by the main loop.
        enum InnerEvent {
            FromHttpService(SessionEvent),
            FromHandler(WorkerSendRequest<LocalSessionWorker>),
        }
        // --- Initialize handshake -------------------------------------------
        let evt = self.event_rx.recv().await.ok_or_else(|| {
            WorkerQuitReason::fatal(
                LocalSessionWorkerError::TransportTerminated,
                "get initialize request",
            )
        })?;
        let SessionEvent::InitializeRequest { request, responder } = evt else {
            return Err(WorkerQuitReason::fatal(
                LocalSessionWorkerError::UnexpectedEvent(evt),
                "get initialize request",
            ));
        };
        context.send_to_handler(request).await?;
        let send_initialize_response = context.recv_from_handler().await?;
        // Relay the handler's initialize response to whoever sent the request.
        responder
            .send(Ok(send_initialize_response.message))
            .map_err(|_| {
                WorkerQuitReason::fatal(
                    LocalSessionWorkerError::FailToSendInitializeRequest(
                        SessionError::SessionServiceTerminated,
                    ),
                    "send initialize response",
                )
            })?;
        // Acknowledge to the handler that its response was delivered.
        send_initialize_response
            .responder
            .send(Ok(()))
            .map_err(|_| WorkerQuitReason::HandlerTerminated)?;
        let ct = context.cancellation_token.clone();
        // Without a configured keep-alive the timeout is effectively disabled.
        let keep_alive = self.session_config.keep_alive.unwrap_or(Duration::MAX);
        // --- Main loop ------------------------------------------------------
        loop {
            // A fresh timer per iteration: any handled event resets the keep-alive.
            let keep_alive_timeout = tokio::time::sleep(keep_alive);
            let event = tokio::select! {
                event = self.event_rx.recv() => {
                    if let Some(event) = event {
                        InnerEvent::FromHttpService(event)
                    } else {
                        return Err(WorkerQuitReason::fatal(LocalSessionWorkerError::TransportTerminated, "waiting next session event"))
                    }
                },
                from_handler = context.recv_from_handler() => {
                    InnerEvent::FromHandler(from_handler?)
                }
                _ = ct.cancelled() => {
                    return Err(WorkerQuitReason::Cancelled)
                }
                _ = keep_alive_timeout => {
                    return Err(WorkerQuitReason::fatal(LocalSessionWorkerError::KeepAliveTimeout(keep_alive), "poll next session event"))
                }
            };
            match event {
                InnerEvent::FromHandler(WorkerSendRequest { message, responder }) => {
                    // A response or error finishes its request: remember the key
                    // so routing state is released after the message is sent.
                    let to_unregister = match &message {
                        crate::model::JsonRpcMessage::Response(json_rpc_response) => {
                            let request_id = json_rpc_response.id.clone();
                            Some(ResourceKey::McpRequestId(request_id))
                        }
                        crate::model::JsonRpcMessage::Error(json_rpc_error) => {
                            let request_id = json_rpc_error.id.clone();
                            Some(ResourceKey::McpRequestId(request_id))
                        }
                        _ => {
                            None
                        }
                    };
                    let handle_result = self
                        .handle_server_message(message)
                        .await
                        .map_err(LocalSessionWorkerError::FailToHandleMessage);
                    let _ = responder.send(handle_result).inspect_err(|error| {
                        tracing::warn!(?error, "failed to send message to http service handler");
                    });
                    // Unregister only after sending: routing needs the entry.
                    if let Some(to_unregister) = to_unregister {
                        self.unregister_resource(&to_unregister);
                    }
                }
                InnerEvent::FromHttpService(SessionEvent::ClientMessage {
                    message: json_rpc_message,
                    http_request_id,
                }) => {
                    match &json_rpc_message {
                        crate::model::JsonRpcMessage::Request(request) => {
                            // Only requests that arrived with a request-wise
                            // channel are registered for routing.
                            if let Some(http_request_id) = http_request_id {
                                self.register_request(request, http_request_id)
                            }
                        }
                        crate::model::JsonRpcMessage::Notification(notification) => {
                            self.catch_cancellation_notification(notification)
                        }
                        _ => {}
                    }
                    context.send_to_handler(json_rpc_message).await?;
                }
                InnerEvent::FromHttpService(SessionEvent::EstablishRequestWiseChannel {
                    responder,
                }) => {
                    let handle_result = self.establish_request_wise_channel().await;
                    let _ = responder.send(handle_result);
                }
                InnerEvent::FromHttpService(SessionEvent::CloseRequestWiseChannel {
                    id,
                    responder,
                }) => {
                    // Removing an unknown id is not an error; always reply Ok.
                    let _handle_result = self.tx_router.remove(&id);
                    let _ = responder.send(Ok(()));
                }
                InnerEvent::FromHttpService(SessionEvent::Resume {
                    last_event_id,
                    responder,
                }) => {
                    let handle_result = self.resume(last_event_id).await;
                    let _ = responder.send(handle_result);
                }
                InnerEvent::FromHttpService(SessionEvent::Close) => {
                    return Err(WorkerQuitReason::TransportClosed);
                }
                InnerEvent::FromHttpService(SessionEvent::CloseSseStream {
                    http_request_id,
                    retry_interval,
                    responder,
                }) => {
                    let handle_result =
                        self.close_sse_stream(http_request_id, retry_interval).await;
                    let _ = responder.send(handle_result);
                }
                _ => {
                    // Any other event (e.g. a second initialize request) is ignored;
                    // its responder is dropped along with the event.
                }
            }
        }
    }
}
1060
/// Per-session settings.
#[derive(Debug, Clone)]
pub struct SessionConfig {
    /// Capacity used for the session's event channel, its SSE channels and the
    /// worker's channel buffer.
    pub channel_capacity: usize,
    /// If set, the worker quits with a keep-alive timeout after this long
    /// without any event; `None` disables the timeout.
    pub keep_alive: Option<Duration>,
}
1068
impl SessionConfig {
    /// Channel capacity used by [`SessionConfig::default`].
    pub const DEFAULT_CHANNEL_CAPACITY: usize = 16;
}
1072
impl Default for SessionConfig {
    /// Default channel capacity and no keep-alive timeout.
    fn default() -> Self {
        Self {
            channel_capacity: Self::DEFAULT_CHANNEL_CAPACITY,
            keep_alive: None,
        }
    }
}
1081
1082pub fn create_local_session(
1088 id: impl Into<SessionId>,
1089 config: SessionConfig,
1090) -> (LocalSessionHandle, LocalSessionWorker) {
1091 let id = id.into();
1092 let (event_tx, event_rx) = tokio::sync::mpsc::channel(config.channel_capacity);
1093 let (common_tx, _) = tokio::sync::mpsc::channel(config.channel_capacity);
1094 let common = CachedTx::new_common(common_tx);
1095 tracing::info!(session_id = ?id, "create new session");
1096 let handle = LocalSessionHandle {
1097 event_tx,
1098 id: id.clone(),
1099 };
1100 let session_worker = LocalSessionWorker {
1101 next_http_request_id: 0,
1102 id,
1103 tx_router: HashMap::new(),
1104 resource_router: HashMap::new(),
1105 common,
1106 shadow_txs: Vec::new(),
1107 event_rx,
1108 session_config: config.clone(),
1109 };
1110 (handle, session_worker)
1111}