1use std::{
2 collections::{HashMap, HashSet, VecDeque},
3 num::ParseIntError,
4 time::Duration,
5};
6
7use futures::Stream;
8use thiserror::Error;
9use tokio::sync::{
10 mpsc::{Receiver, Sender},
11 oneshot,
12};
13use tokio_stream::wrappers::ReceiverStream;
14use tracing::instrument;
15
16use crate::{
17 RoleServer,
18 model::{
19 CancelledNotificationParam, ClientJsonRpcMessage, ClientNotification, ClientRequest,
20 JsonRpcNotification, JsonRpcRequest, Notification, ProgressNotificationParam,
21 ProgressToken, RequestId, ServerJsonRpcMessage, ServerNotification,
22 },
23 transport::{
24 WorkerTransport,
25 common::server_side_http::{SessionId, session_id},
26 worker::{Worker, WorkerContext, WorkerQuitReason, WorkerSendRequest},
27 },
28};
29
30#[derive(Debug, Default)]
31#[non_exhaustive]
32pub struct LocalSessionManager {
33 pub sessions: tokio::sync::RwLock<HashMap<SessionId, LocalSessionHandle>>,
34 pub session_config: SessionConfig,
35}
36
37#[derive(Debug, Error)]
38#[non_exhaustive]
39pub enum LocalSessionManagerError {
40 #[error("Session not found: {0}")]
41 SessionNotFound(SessionId),
42 #[error("Session error: {0}")]
43 SessionError(#[from] SessionError),
44 #[error("Invalid event id: {0}")]
45 InvalidEventId(#[from] EventIdParseError),
46}
47impl SessionManager for LocalSessionManager {
48 type Error = LocalSessionManagerError;
49 type Transport = WorkerTransport<LocalSessionWorker>;
50 async fn create_session(&self) -> Result<(SessionId, Self::Transport), Self::Error> {
51 let id = session_id();
52 let (handle, worker) = create_local_session(id.clone(), self.session_config.clone());
53 self.sessions.write().await.insert(id.clone(), handle);
54 Ok((id, WorkerTransport::spawn(worker)))
55 }
56 async fn initialize_session(
57 &self,
58 id: &SessionId,
59 message: ClientJsonRpcMessage,
60 ) -> Result<ServerJsonRpcMessage, Self::Error> {
61 let sessions = self.sessions.read().await;
62 let handle = sessions
63 .get(id)
64 .ok_or(LocalSessionManagerError::SessionNotFound(id.clone()))?;
65 let response = handle.initialize(message).await?;
66 Ok(response)
67 }
68 async fn close_session(&self, id: &SessionId) -> Result<(), Self::Error> {
69 let mut sessions = self.sessions.write().await;
70 if let Some(handle) = sessions.remove(id) {
71 handle.close().await?;
72 }
73 Ok(())
74 }
75 async fn has_session(&self, id: &SessionId) -> Result<bool, Self::Error> {
76 let sessions = self.sessions.read().await;
77 Ok(sessions.contains_key(id))
78 }
79 async fn create_stream(
80 &self,
81 id: &SessionId,
82 message: ClientJsonRpcMessage,
83 ) -> Result<impl Stream<Item = ServerSseMessage> + Send + 'static, Self::Error> {
84 let sessions = self.sessions.read().await;
85 let handle = sessions
86 .get(id)
87 .ok_or(LocalSessionManagerError::SessionNotFound(id.clone()))?;
88 let receiver = handle.establish_request_wise_channel().await?;
89 handle
90 .push_message(message, receiver.http_request_id)
91 .await?;
92 Ok(ReceiverStream::new(receiver.inner))
93 }
94
95 async fn create_standalone_stream(
96 &self,
97 id: &SessionId,
98 ) -> Result<impl Stream<Item = ServerSseMessage> + Send + 'static, Self::Error> {
99 let sessions = self.sessions.read().await;
100 let handle = sessions
101 .get(id)
102 .ok_or(LocalSessionManagerError::SessionNotFound(id.clone()))?;
103 let receiver = handle.establish_common_channel().await?;
104 Ok(ReceiverStream::new(receiver.inner))
105 }
106
107 async fn resume(
108 &self,
109 id: &SessionId,
110 last_event_id: String,
111 ) -> Result<impl Stream<Item = ServerSseMessage> + Send + 'static, Self::Error> {
112 let sessions = self.sessions.read().await;
113 let handle = sessions
114 .get(id)
115 .ok_or(LocalSessionManagerError::SessionNotFound(id.clone()))?;
116 let receiver = handle.resume(last_event_id.parse()?).await?;
117 Ok(ReceiverStream::new(receiver.inner))
118 }
119
120 async fn accept_message(
121 &self,
122 id: &SessionId,
123 message: ClientJsonRpcMessage,
124 ) -> Result<(), Self::Error> {
125 let sessions = self.sessions.read().await;
126 let handle = sessions
127 .get(id)
128 .ok_or(LocalSessionManagerError::SessionNotFound(id.clone()))?;
129 handle.push_message(message, None).await?;
130 Ok(())
131 }
132}
133
134#[derive(Debug, Clone, PartialEq, Eq, Hash)]
136pub struct EventId {
137 http_request_id: Option<HttpRequestId>,
138 index: usize,
139}
140
141impl std::fmt::Display for EventId {
142 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
143 write!(f, "{}", self.index)?;
144 match &self.http_request_id {
145 Some(http_request_id) => write!(f, "/{http_request_id}"),
146 None => write!(f, ""),
147 }
148 }
149}
150
151#[derive(Debug, Clone, Error)]
152#[non_exhaustive]
153pub enum EventIdParseError {
154 #[error("Invalid index: {0}")]
155 InvalidIndex(ParseIntError),
156 #[error("Invalid numeric request id: {0}")]
157 InvalidNumericRequestId(ParseIntError),
158 #[error("Missing request id type")]
159 InvalidRequestIdType,
160 #[error("Missing request id")]
161 MissingRequestId,
162}
163
164impl std::str::FromStr for EventId {
165 type Err = EventIdParseError;
166 fn from_str(s: &str) -> Result<Self, Self::Err> {
167 if let Some((index, request_id)) = s.split_once("/") {
168 let index = usize::from_str(index).map_err(EventIdParseError::InvalidIndex)?;
169 let request_id = u64::from_str(request_id).map_err(EventIdParseError::InvalidIndex)?;
170 Ok(EventId {
171 http_request_id: Some(request_id),
172 index,
173 })
174 } else {
175 let index = usize::from_str(s).map_err(EventIdParseError::InvalidIndex)?;
176 Ok(EventId {
177 http_request_id: None,
178 index,
179 })
180 }
181 }
182}
183
184use super::{ServerSseMessage, SessionManager};
185
186struct CachedTx {
187 tx: Sender<ServerSseMessage>,
188 cache: VecDeque<ServerSseMessage>,
189 http_request_id: Option<HttpRequestId>,
190 capacity: usize,
191}
192
193impl CachedTx {
194 fn new(tx: Sender<ServerSseMessage>, http_request_id: Option<HttpRequestId>) -> Self {
195 Self {
196 cache: VecDeque::with_capacity(tx.capacity()),
197 capacity: tx.capacity(),
198 tx,
199 http_request_id,
200 }
201 }
202 fn new_common(tx: Sender<ServerSseMessage>) -> Self {
203 Self::new(tx, None)
204 }
205
206 fn next_event_id(&self) -> EventId {
207 let index = self.cache.back().map_or(0, |m| {
208 m.event_id
209 .as_deref()
210 .unwrap_or_default()
211 .parse::<EventId>()
212 .expect("valid event id")
213 .index
214 + 1
215 });
216 EventId {
217 http_request_id: self.http_request_id,
218 index,
219 }
220 }
221
222 async fn send(&mut self, message: ServerJsonRpcMessage) {
223 let event_id = self.next_event_id();
224 let message = ServerSseMessage::new(event_id.to_string(), message);
225 self.cache_and_send(message).await;
226 }
227
228 async fn send_priming(&mut self, retry: Duration) {
229 let event_id = self.next_event_id();
230 let message = ServerSseMessage::priming(event_id.to_string(), retry);
231 self.cache_and_send(message).await;
232 }
233
234 async fn cache_and_send(&mut self, message: ServerSseMessage) {
235 if self.cache.len() >= self.capacity {
236 self.cache.pop_front();
237 self.cache.push_back(message.clone());
238 } else {
239 self.cache.push_back(message.clone());
240 }
241 let _ = self.tx.send(message).await.inspect_err(|e| {
242 let event_id = &e.0.event_id;
243 tracing::trace!(?event_id, "trying to send message in a closed session")
244 });
245 }
246
247 async fn sync(&mut self, index: usize) -> Result<(), SessionError> {
248 let Some(front) = self.cache.front() else {
249 return Ok(());
250 };
251 let front_event_id = front
252 .event_id
253 .as_deref()
254 .unwrap_or_default()
255 .parse::<EventId>()?;
256 let sync_index = index.saturating_sub(front_event_id.index);
257 if sync_index > self.cache.len() {
258 return Err(SessionError::InvalidEventId);
260 }
261 for message in self.cache.iter().skip(sync_index) {
262 let send_result = self.tx.send(message.clone()).await;
263 if send_result.is_err() {
264 let event_id: EventId = message.event_id.as_deref().unwrap_or_default().parse()?;
265 return Err(SessionError::ChannelClosed(Some(event_id.index as u64)));
266 }
267 }
268 Ok(())
269 }
270}
271
272struct HttpRequestWise {
273 resources: HashSet<ResourceKey>,
274 tx: CachedTx,
275}
276
277type HttpRequestId = u64;
278#[derive(Debug, Clone, Hash, PartialEq, Eq)]
279enum ResourceKey {
280 McpRequestId(RequestId),
281 ProgressToken(ProgressToken),
282}
283
284pub struct LocalSessionWorker {
285 id: SessionId,
286 next_http_request_id: HttpRequestId,
287 tx_router: HashMap<HttpRequestId, HttpRequestWise>,
288 resource_router: HashMap<ResourceKey, HttpRequestId>,
289 common: CachedTx,
290 shadow_txs: Vec<Sender<ServerSseMessage>>,
296 event_rx: Receiver<SessionEvent>,
297 session_config: SessionConfig,
298}
299
300impl LocalSessionWorker {
301 pub fn id(&self) -> &SessionId {
302 &self.id
303 }
304}
305
306#[derive(Debug, Error)]
307#[non_exhaustive]
308pub enum SessionError {
309 #[error("Invalid request id: {0}")]
310 DuplicatedRequestId(HttpRequestId),
311 #[error("Channel closed: {0:?}")]
312 ChannelClosed(Option<HttpRequestId>),
313 #[error("Cannot parse event id: {0}")]
314 EventIdParseError(#[from] EventIdParseError),
315 #[error("Session service terminated")]
316 SessionServiceTerminated,
317 #[error("Invalid event id")]
318 InvalidEventId,
319 #[error("IO error: {0}")]
320 Io(#[from] std::io::Error),
321}
322
323impl From<SessionError> for std::io::Error {
324 fn from(value: SessionError) -> Self {
325 match value {
326 SessionError::Io(io) => io,
327 _ => std::io::Error::other(format!("Session error: {value}")),
328 }
329 }
330}
331
332enum OutboundChannel {
333 RequestWise { id: HttpRequestId, close: bool },
334 Common,
335}
336#[derive(Debug)]
337#[non_exhaustive]
338pub struct StreamableHttpMessageReceiver {
339 pub http_request_id: Option<HttpRequestId>,
340 pub inner: Receiver<ServerSseMessage>,
341}
342
343impl LocalSessionWorker {
344 fn unregister_resource(&mut self, resource: &ResourceKey) {
345 if let Some(http_request_id) = self.resource_router.remove(resource) {
346 tracing::trace!(?resource, http_request_id, "unregister resource");
347 if let Some(channel) = self.tx_router.get_mut(&http_request_id) {
348 if channel.resources.is_empty() || matches!(resource, ResourceKey::McpRequestId(_))
351 {
352 tracing::debug!(http_request_id, "close http request wise channel");
353 if let Some(channel) = self.tx_router.remove(&http_request_id) {
354 for resource in channel.resources {
355 self.resource_router.remove(&resource);
356 }
357 }
358 }
359 } else {
360 tracing::warn!(http_request_id, "http request wise channel not found");
361 }
362 }
363 }
364 fn register_resource(&mut self, resource: ResourceKey, http_request_id: HttpRequestId) {
365 tracing::trace!(?resource, http_request_id, "register resource");
366 if let Some(channel) = self.tx_router.get_mut(&http_request_id) {
367 channel.resources.insert(resource.clone());
368 self.resource_router.insert(resource, http_request_id);
369 }
370 }
371 fn register_request(
372 &mut self,
373 request: &JsonRpcRequest<ClientRequest>,
374 http_request_id: HttpRequestId,
375 ) {
376 use crate::model::GetMeta;
377 self.register_resource(
378 ResourceKey::McpRequestId(request.id.clone()),
379 http_request_id,
380 );
381 if let Some(progress_token) = request.request.get_meta().get_progress_token() {
382 self.register_resource(
383 ResourceKey::ProgressToken(progress_token.clone()),
384 http_request_id,
385 );
386 }
387 }
388 fn catch_cancellation_notification(
389 &mut self,
390 notification: &JsonRpcNotification<ClientNotification>,
391 ) {
392 if let ClientNotification::CancelledNotification(n) = ¬ification.notification {
393 let request_id = n.params.request_id.clone();
394 let resource = ResourceKey::McpRequestId(request_id);
395 self.unregister_resource(&resource);
396 }
397 }
398 fn next_http_request_id(&mut self) -> HttpRequestId {
399 let id = self.next_http_request_id;
400 self.next_http_request_id = self.next_http_request_id.wrapping_add(1);
401 id
402 }
403 async fn establish_request_wise_channel(
404 &mut self,
405 ) -> Result<StreamableHttpMessageReceiver, SessionError> {
406 let http_request_id = self.next_http_request_id();
407 let (tx, rx) = tokio::sync::mpsc::channel(self.session_config.channel_capacity);
408 self.tx_router.insert(
409 http_request_id,
410 HttpRequestWise {
411 resources: Default::default(),
412 tx: CachedTx::new(tx, Some(http_request_id)),
413 },
414 );
415 tracing::debug!(http_request_id, "establish new request wise channel");
416 Ok(StreamableHttpMessageReceiver {
417 http_request_id: Some(http_request_id),
418 inner: rx,
419 })
420 }
421 fn resolve_outbound_channel(&self, message: &ServerJsonRpcMessage) -> OutboundChannel {
422 match &message {
423 ServerJsonRpcMessage::Request(_) => OutboundChannel::Common,
424 ServerJsonRpcMessage::Notification(JsonRpcNotification {
425 notification:
426 ServerNotification::ProgressNotification(Notification {
427 params: ProgressNotificationParam { progress_token, .. },
428 ..
429 }),
430 ..
431 }) => {
432 let id = self
433 .resource_router
434 .get(&ResourceKey::ProgressToken(progress_token.clone()));
435
436 if let Some(id) = id {
437 OutboundChannel::RequestWise {
438 id: *id,
439 close: false,
440 }
441 } else {
442 OutboundChannel::Common
443 }
444 }
445 ServerJsonRpcMessage::Notification(JsonRpcNotification {
446 notification:
447 ServerNotification::CancelledNotification(Notification {
448 params: CancelledNotificationParam { request_id, .. },
449 ..
450 }),
451 ..
452 }) => {
453 if let Some(id) = self
454 .resource_router
455 .get(&ResourceKey::McpRequestId(request_id.clone()))
456 {
457 OutboundChannel::RequestWise {
458 id: *id,
459 close: false,
460 }
461 } else {
462 OutboundChannel::Common
463 }
464 }
465 ServerJsonRpcMessage::Notification(_) => OutboundChannel::Common,
466 ServerJsonRpcMessage::Response(json_rpc_response) => {
467 if let Some(id) = self
468 .resource_router
469 .get(&ResourceKey::McpRequestId(json_rpc_response.id.clone()))
470 {
471 OutboundChannel::RequestWise {
472 id: *id,
473 close: false,
474 }
475 } else {
476 OutboundChannel::Common
477 }
478 }
479 ServerJsonRpcMessage::Error(json_rpc_error) => {
480 if let Some(id) = self
481 .resource_router
482 .get(&ResourceKey::McpRequestId(json_rpc_error.id.clone()))
483 {
484 OutboundChannel::RequestWise {
485 id: *id,
486 close: false,
487 }
488 } else {
489 OutboundChannel::Common
490 }
491 }
492 }
493 }
494 async fn handle_server_message(
495 &mut self,
496 message: ServerJsonRpcMessage,
497 ) -> Result<(), SessionError> {
498 let outbound_channel = self.resolve_outbound_channel(&message);
499 match outbound_channel {
500 OutboundChannel::RequestWise { id, close } => {
501 if let Some(request_wise) = self.tx_router.get_mut(&id) {
502 request_wise.tx.send(message).await;
503 if close {
504 self.tx_router.remove(&id);
505 }
506 } else {
507 return Err(SessionError::ChannelClosed(Some(id)));
508 }
509 }
510 OutboundChannel::Common => self.common.send(message).await,
511 }
512 Ok(())
513 }
514 async fn resume(
515 &mut self,
516 last_event_id: EventId,
517 ) -> Result<StreamableHttpMessageReceiver, SessionError> {
518 self.shadow_txs.retain(|tx| !tx.is_closed());
520
521 match last_event_id.http_request_id {
522 Some(http_request_id) => {
523 if let Some(request_wise) = self.tx_router.get_mut(&http_request_id) {
524 let channel = tokio::sync::mpsc::channel(self.session_config.channel_capacity);
526 let (tx, rx) = channel;
527 request_wise.tx.tx = tx;
528 let index = last_event_id.index;
529 request_wise.tx.sync(index).await?;
531 Ok(StreamableHttpMessageReceiver {
532 http_request_id: Some(http_request_id),
533 inner: rx,
534 })
535 } else {
536 tracing::debug!(
540 http_request_id,
541 "Request-wise channel completed, falling back to common channel"
542 );
543 self.resume_or_shadow_common(last_event_id.index).await
544 }
545 }
546 None => self.resume_or_shadow_common(last_event_id.index).await,
547 }
548 }
549
550 async fn resume_or_shadow_common(
563 &mut self,
564 last_event_index: usize,
565 ) -> Result<StreamableHttpMessageReceiver, SessionError> {
566 let is_replacing_dead_primary = self.common.tx.is_closed();
567 let capacity = if is_replacing_dead_primary {
568 self.session_config.channel_capacity
569 } else {
570 1 };
572 let (tx, rx) = tokio::sync::mpsc::channel(capacity);
573 if is_replacing_dead_primary {
574 tracing::debug!("Replacing dead common channel with new primary");
576 self.common.tx = tx;
577 self.common.sync(last_event_index).await?;
580 } else {
581 const MAX_SHADOW_STREAMS: usize = 32;
586
587 if self.shadow_txs.len() >= MAX_SHADOW_STREAMS {
588 tracing::warn!(
589 shadow_count = self.shadow_txs.len(),
590 "Shadow stream limit reached, dropping oldest"
591 );
592 self.shadow_txs.remove(0);
593 }
594 tracing::debug!(
595 shadow_count = self.shadow_txs.len(),
596 "Common channel active, creating shadow stream"
597 );
598 self.shadow_txs.push(tx);
599 }
600 Ok(StreamableHttpMessageReceiver {
601 http_request_id: None,
602 inner: rx,
603 })
604 }
605
606 async fn close_sse_stream(
607 &mut self,
608 http_request_id: Option<HttpRequestId>,
609 retry_interval: Option<Duration>,
610 ) -> Result<(), SessionError> {
611 match http_request_id {
612 Some(id) => {
614 let request_wise = self
615 .tx_router
616 .get_mut(&id)
617 .ok_or(SessionError::ChannelClosed(Some(id)))?;
618
619 if let Some(interval) = retry_interval {
621 request_wise.tx.send_priming(interval).await;
622 }
623
624 let (tx, _rx) = tokio::sync::mpsc::channel(1);
626 request_wise.tx.tx = tx;
627
628 tracing::debug!(
629 http_request_id = id,
630 "closed SSE stream for server-initiated disconnection"
631 );
632 Ok(())
633 }
634 None => {
636 if let Some(interval) = retry_interval {
638 self.common.send_priming(interval).await;
639 }
640
641 let (tx, _rx) = tokio::sync::mpsc::channel(1);
643 self.common.tx = tx;
644
645 self.shadow_txs.clear();
647
648 tracing::debug!("closed standalone SSE stream for server-initiated disconnection");
649 Ok(())
650 }
651 }
652 }
653}
654
655#[derive(Debug)]
656#[non_exhaustive]
657pub enum SessionEvent {
658 ClientMessage {
659 message: ClientJsonRpcMessage,
660 http_request_id: Option<HttpRequestId>,
661 },
662 EstablishRequestWiseChannel {
663 responder: oneshot::Sender<Result<StreamableHttpMessageReceiver, SessionError>>,
664 },
665 CloseRequestWiseChannel {
666 id: HttpRequestId,
667 responder: oneshot::Sender<Result<(), SessionError>>,
668 },
669 Resume {
670 last_event_id: EventId,
671 responder: oneshot::Sender<Result<StreamableHttpMessageReceiver, SessionError>>,
672 },
673 InitializeRequest {
674 request: ClientJsonRpcMessage,
675 responder: oneshot::Sender<Result<ServerJsonRpcMessage, SessionError>>,
676 },
677 Close,
678 CloseSseStream {
679 http_request_id: Option<HttpRequestId>,
681 retry_interval: Option<Duration>,
683 responder: oneshot::Sender<Result<(), SessionError>>,
684 },
685}
686
/// Why a session ended.
///
/// NOTE(review): no code in this file constructs these variants; confirm their exact
/// meaning against the callers before relying on them.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum SessionQuitReason {
    ServiceTerminated,
    ClientTerminated,
    ExpectInitializeRequest,
    ExpectInitializeResponse,
    Cancelled,
}
696
697#[derive(Debug, Clone)]
698pub struct LocalSessionHandle {
699 id: SessionId,
700 event_tx: Sender<SessionEvent>,
702}
703
704impl LocalSessionHandle {
705 pub fn id(&self) -> &SessionId {
707 &self.id
708 }
709
710 pub async fn close(&self) -> Result<(), SessionError> {
712 self.event_tx
713 .send(SessionEvent::Close)
714 .await
715 .map_err(|_| SessionError::SessionServiceTerminated)?;
716 Ok(())
717 }
718
719 pub async fn push_message(
721 &self,
722 message: ClientJsonRpcMessage,
723 http_request_id: Option<HttpRequestId>,
724 ) -> Result<(), SessionError> {
725 self.event_tx
726 .send(SessionEvent::ClientMessage {
727 message,
728 http_request_id,
729 })
730 .await
731 .map_err(|_| SessionError::SessionServiceTerminated)?;
732 Ok(())
733 }
734
735 pub async fn establish_request_wise_channel(
739 &self,
740 ) -> Result<StreamableHttpMessageReceiver, SessionError> {
741 let (tx, rx) = tokio::sync::oneshot::channel();
742 self.event_tx
743 .send(SessionEvent::EstablishRequestWiseChannel { responder: tx })
744 .await
745 .map_err(|_| SessionError::SessionServiceTerminated)?;
746 rx.await
747 .map_err(|_| SessionError::SessionServiceTerminated)?
748 }
749
750 pub async fn close_request_wise_channel(
752 &self,
753 request_id: HttpRequestId,
754 ) -> Result<(), SessionError> {
755 let (tx, rx) = tokio::sync::oneshot::channel();
756 self.event_tx
757 .send(SessionEvent::CloseRequestWiseChannel {
758 id: request_id,
759 responder: tx,
760 })
761 .await
762 .map_err(|_| SessionError::SessionServiceTerminated)?;
763 rx.await
764 .map_err(|_| SessionError::SessionServiceTerminated)?
765 }
766
767 pub async fn establish_common_channel(
769 &self,
770 ) -> Result<StreamableHttpMessageReceiver, SessionError> {
771 let (tx, rx) = tokio::sync::oneshot::channel();
772 self.event_tx
773 .send(SessionEvent::Resume {
774 last_event_id: EventId {
775 http_request_id: None,
776 index: 0,
777 },
778 responder: tx,
779 })
780 .await
781 .map_err(|_| SessionError::SessionServiceTerminated)?;
782 rx.await
783 .map_err(|_| SessionError::SessionServiceTerminated)?
784 }
785
786 pub async fn resume(
788 &self,
789 last_event_id: EventId,
790 ) -> Result<StreamableHttpMessageReceiver, SessionError> {
791 let (tx, rx) = tokio::sync::oneshot::channel();
792 self.event_tx
793 .send(SessionEvent::Resume {
794 last_event_id,
795 responder: tx,
796 })
797 .await
798 .map_err(|_| SessionError::SessionServiceTerminated)?;
799 rx.await
800 .map_err(|_| SessionError::SessionServiceTerminated)?
801 }
802
803 pub async fn initialize(
807 &self,
808 request: ClientJsonRpcMessage,
809 ) -> Result<ServerJsonRpcMessage, SessionError> {
810 let (tx, rx) = tokio::sync::oneshot::channel();
811 self.event_tx
812 .send(SessionEvent::InitializeRequest {
813 request,
814 responder: tx,
815 })
816 .await
817 .map_err(|_| SessionError::SessionServiceTerminated)?;
818 rx.await
819 .map_err(|_| SessionError::SessionServiceTerminated)?
820 }
821
822 pub async fn close_sse_stream(
833 &self,
834 http_request_id: HttpRequestId,
835 retry_interval: Option<Duration>,
836 ) -> Result<(), SessionError> {
837 let (tx, rx) = tokio::sync::oneshot::channel();
838 self.event_tx
839 .send(SessionEvent::CloseSseStream {
840 http_request_id: Some(http_request_id),
841 retry_interval,
842 responder: tx,
843 })
844 .await
845 .map_err(|_| SessionError::SessionServiceTerminated)?;
846 rx.await
847 .map_err(|_| SessionError::SessionServiceTerminated)?
848 }
849
850 pub async fn close_standalone_sse_stream(
860 &self,
861 retry_interval: Option<Duration>,
862 ) -> Result<(), SessionError> {
863 let (tx, rx) = tokio::sync::oneshot::channel();
864 self.event_tx
865 .send(SessionEvent::CloseSseStream {
866 http_request_id: None,
867 retry_interval,
868 responder: tx,
869 })
870 .await
871 .map_err(|_| SessionError::SessionServiceTerminated)?;
872 rx.await
873 .map_err(|_| SessionError::SessionServiceTerminated)?
874 }
875}
876
877pub type SessionTransport = WorkerTransport<LocalSessionWorker>;
878
879#[allow(clippy::large_enum_variant)]
880#[derive(Debug, Error)]
881#[non_exhaustive]
882pub enum LocalSessionWorkerError {
883 #[error("transport terminated")]
884 TransportTerminated,
885 #[error("unexpected message: {0:?}")]
886 UnexpectedEvent(SessionEvent),
887 #[error("fail to send initialize request {0}")]
888 FailToSendInitializeRequest(SessionError),
889 #[error("fail to handle message: {0}")]
890 FailToHandleMessage(SessionError),
891 #[error("keep alive timeout after {}ms", _0.as_millis())]
892 KeepAliveTimeout(Duration),
893 #[error("Transport closed")]
894 TransportClosed,
895 #[error("Tokio join error {0}")]
896 TokioJoinError(#[from] tokio::task::JoinError),
897}
898impl Worker for LocalSessionWorker {
899 type Error = LocalSessionWorkerError;
900 type Role = RoleServer;
901 fn err_closed() -> Self::Error {
902 LocalSessionWorkerError::TransportClosed
903 }
904 fn err_join(e: tokio::task::JoinError) -> Self::Error {
905 LocalSessionWorkerError::TokioJoinError(e)
906 }
907 fn config(&self) -> crate::transport::worker::WorkerConfig {
908 crate::transport::worker::WorkerConfig {
909 name: Some(format!("streamable-http-session-{}", self.id)),
910 channel_buffer_capacity: self.session_config.channel_capacity,
911 }
912 }
913 #[instrument(name = "streamable_http_session", skip_all, fields(id = self.id.as_ref()))]
914 async fn run(
915 mut self,
916 mut context: WorkerContext<Self>,
917 ) -> Result<(), WorkerQuitReason<Self::Error>> {
918 enum InnerEvent {
919 FromHttpService(SessionEvent),
920 FromHandler(WorkerSendRequest<LocalSessionWorker>),
921 }
922 let evt = self.event_rx.recv().await.ok_or_else(|| {
924 WorkerQuitReason::fatal(
925 LocalSessionWorkerError::TransportTerminated,
926 "get initialize request",
927 )
928 })?;
929 let SessionEvent::InitializeRequest { request, responder } = evt else {
930 return Err(WorkerQuitReason::fatal(
931 LocalSessionWorkerError::UnexpectedEvent(evt),
932 "get initialize request",
933 ));
934 };
935 context.send_to_handler(request).await?;
936 let send_initialize_response = context.recv_from_handler().await?;
937 responder
938 .send(Ok(send_initialize_response.message))
939 .map_err(|_| {
940 WorkerQuitReason::fatal(
941 LocalSessionWorkerError::FailToSendInitializeRequest(
942 SessionError::SessionServiceTerminated,
943 ),
944 "send initialize response",
945 )
946 })?;
947 send_initialize_response
948 .responder
949 .send(Ok(()))
950 .map_err(|_| WorkerQuitReason::HandlerTerminated)?;
951 let ct = context.cancellation_token.clone();
952 let keep_alive = self.session_config.keep_alive.unwrap_or(Duration::MAX);
953 loop {
954 let keep_alive_timeout = tokio::time::sleep(keep_alive);
955 let event = tokio::select! {
956 event = self.event_rx.recv() => {
957 if let Some(event) = event {
958 InnerEvent::FromHttpService(event)
959 } else {
960 return Err(WorkerQuitReason::fatal(LocalSessionWorkerError::TransportTerminated, "waiting next session event"))
961 }
962 },
963 from_handler = context.recv_from_handler() => {
964 InnerEvent::FromHandler(from_handler?)
965 }
966 _ = ct.cancelled() => {
967 return Err(WorkerQuitReason::Cancelled)
968 }
969 _ = keep_alive_timeout => {
970 return Err(WorkerQuitReason::fatal(LocalSessionWorkerError::KeepAliveTimeout(keep_alive), "poll next session event"))
971 }
972 };
973 match event {
974 InnerEvent::FromHandler(WorkerSendRequest { message, responder }) => {
975 let to_unregister = match &message {
977 crate::model::JsonRpcMessage::Response(json_rpc_response) => {
978 let request_id = json_rpc_response.id.clone();
979 Some(ResourceKey::McpRequestId(request_id))
980 }
981 crate::model::JsonRpcMessage::Error(json_rpc_error) => {
982 let request_id = json_rpc_error.id.clone();
983 Some(ResourceKey::McpRequestId(request_id))
984 }
985 _ => {
986 None
987 }
989 };
990 let handle_result = self
991 .handle_server_message(message)
992 .await
993 .map_err(LocalSessionWorkerError::FailToHandleMessage);
994 let _ = responder.send(handle_result).inspect_err(|error| {
995 tracing::warn!(?error, "failed to send message to http service handler");
996 });
997 if let Some(to_unregister) = to_unregister {
998 self.unregister_resource(&to_unregister);
999 }
1000 }
1001 InnerEvent::FromHttpService(SessionEvent::ClientMessage {
1002 message: json_rpc_message,
1003 http_request_id,
1004 }) => {
1005 match &json_rpc_message {
1006 crate::model::JsonRpcMessage::Request(request) => {
1007 if let Some(http_request_id) = http_request_id {
1008 self.register_request(request, http_request_id)
1009 }
1010 }
1011 crate::model::JsonRpcMessage::Notification(notification) => {
1012 self.catch_cancellation_notification(notification)
1013 }
1014 _ => {}
1015 }
1016 context.send_to_handler(json_rpc_message).await?;
1017 }
1018 InnerEvent::FromHttpService(SessionEvent::EstablishRequestWiseChannel {
1019 responder,
1020 }) => {
1021 let handle_result = self.establish_request_wise_channel().await;
1022 let _ = responder.send(handle_result);
1023 }
1024 InnerEvent::FromHttpService(SessionEvent::CloseRequestWiseChannel {
1025 id,
1026 responder,
1027 }) => {
1028 let _handle_result = self.tx_router.remove(&id);
1029 let _ = responder.send(Ok(()));
1030 }
1031 InnerEvent::FromHttpService(SessionEvent::Resume {
1032 last_event_id,
1033 responder,
1034 }) => {
1035 let handle_result = self.resume(last_event_id).await;
1036 let _ = responder.send(handle_result);
1037 }
1038 InnerEvent::FromHttpService(SessionEvent::Close) => {
1039 return Err(WorkerQuitReason::TransportClosed);
1040 }
1041 InnerEvent::FromHttpService(SessionEvent::CloseSseStream {
1042 http_request_id,
1043 retry_interval,
1044 responder,
1045 }) => {
1046 let handle_result =
1047 self.close_sse_stream(http_request_id, retry_interval).await;
1048 let _ = responder.send(handle_result);
1049 }
1050 _ => {
1051 }
1053 }
1054 }
1055 }
1056}
1057
/// Tunables for a local session.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct SessionConfig {
    /// Buffer size of the session's event channel and SSE channels. It also bounds how many
    /// messages each stream keeps for replay on resume.
    pub channel_capacity: usize,
    /// How long the worker waits for the next event before ending the session with a
    /// keep-alive timeout; `None` waits indefinitely.
    pub keep_alive: Option<Duration>,
}
1076
1077impl SessionConfig {
1078 pub const DEFAULT_CHANNEL_CAPACITY: usize = 16;
1079 pub const DEFAULT_KEEP_ALIVE: Duration = Duration::from_secs(300);
1080}
1081
1082impl Default for SessionConfig {
1083 fn default() -> Self {
1084 Self {
1085 channel_capacity: Self::DEFAULT_CHANNEL_CAPACITY,
1086 keep_alive: Some(Self::DEFAULT_KEEP_ALIVE),
1087 }
1088 }
1089}
1090
1091pub fn create_local_session(
1097 id: impl Into<SessionId>,
1098 config: SessionConfig,
1099) -> (LocalSessionHandle, LocalSessionWorker) {
1100 let id = id.into();
1101 let (event_tx, event_rx) = tokio::sync::mpsc::channel(config.channel_capacity);
1102 let (common_tx, _) = tokio::sync::mpsc::channel(config.channel_capacity);
1103 let common = CachedTx::new_common(common_tx);
1104 tracing::info!(session_id = ?id, "create new session");
1105 let handle = LocalSessionHandle {
1106 event_tx,
1107 id: id.clone(),
1108 };
1109 let session_worker = LocalSessionWorker {
1110 next_http_request_id: 0,
1111 id,
1112 tx_router: HashMap::new(),
1113 resource_router: HashMap::new(),
1114 common,
1115 shadow_txs: Vec::new(),
1116 event_rx,
1117 session_config: config.clone(),
1118 };
1119 (handle, session_worker)
1120}