1use futures::FutureExt;
2#[cfg(not(feature = "local"))]
3use futures::future::BoxFuture;
4#[cfg(feature = "local")]
5use futures::future::LocalBoxFuture;
6use thiserror::Error;
7
/// Marker for values that must be shareable across threads in the default
/// (multi-threaded) build: equivalent to `Send + Sync`.
#[cfg(not(feature = "local"))]
#[doc(hidden)]
pub trait MaybeSend: Send + Sync {}
#[cfg(not(feature = "local"))]
impl<T> MaybeSend for T where T: Send + Sync {}

/// Marker for futures that must be movable across threads in the default
/// (multi-threaded) build: equivalent to `Send`.
#[cfg(not(feature = "local"))]
#[doc(hidden)]
pub trait MaybeSendFuture: Send {}
#[cfg(not(feature = "local"))]
impl<T> MaybeSendFuture for T where T: Send {}

/// `local` build: the marker carries no bounds and is implemented for every
/// type, so `!Send`/`!Sync` values are accepted.
#[cfg(feature = "local")]
#[doc(hidden)]
pub trait MaybeSend {}
#[cfg(feature = "local")]
impl<T> MaybeSend for T {}

/// `local` build: the marker carries no bounds and is implemented for every
/// type, so `!Send` futures are accepted.
#[cfg(feature = "local")]
#[doc(hidden)]
pub trait MaybeSendFuture {}
#[cfg(feature = "local")]
impl<T> MaybeSendFuture for T {}
39
40#[cfg(not(feature = "local"))]
41pub(crate) type MaybeBoxFuture<'a, T> = BoxFuture<'a, T>;
42#[cfg(feature = "local")]
43pub(crate) type MaybeBoxFuture<'a, T> = LocalBoxFuture<'a, T>;
44
45#[cfg(feature = "server")]
46use crate::model::ClientNotification;
47#[cfg(feature = "server")]
48use crate::model::ServerJsonRpcMessage;
49#[cfg(feature = "client")]
50use crate::model::ServerNotification;
51use crate::{
52 error::ErrorData as McpError,
53 model::{
54 CancelledNotification, CancelledNotificationParam, Extensions, GetExtensions, GetMeta,
55 JsonRpcError, JsonRpcMessage, JsonRpcNotification, JsonRpcRequest, JsonRpcResponse, Meta,
56 NumberOrString, ProgressToken, RequestId,
57 },
58 transport::{DynamicTransportError, IntoTransport, Transport},
59};
60#[cfg(feature = "client")]
61mod client;
62#[cfg(feature = "client")]
63pub use client::*;
64#[cfg(feature = "server")]
65mod server;
66#[cfg(feature = "server")]
67pub use server::*;
68#[cfg(feature = "tower")]
69mod tower;
70use tokio_util::sync::{CancellationToken, DropGuard};
71#[cfg(feature = "tower")]
72pub use tower::*;
73use tracing::{Instrument as _, instrument};
74#[derive(Error, Debug)]
75#[non_exhaustive]
76pub enum ServiceError {
77 #[error("Mcp error: {0}")]
78 McpError(McpError),
79 #[error("Transport send error: {0}")]
80 TransportSend(DynamicTransportError),
81 #[error("Transport closed")]
82 TransportClosed,
83 #[error("Unexpected response type")]
84 UnexpectedResponse,
85 #[error("task cancelled for reason {}", reason.as_deref().unwrap_or("<unknown>"))]
86 Cancelled { reason: Option<String> },
87 #[error("request timeout after {}", chrono::Duration::from_std(*timeout).unwrap_or_default())]
88 Timeout { timeout: Duration },
89}
90
91trait TransferObject:
92 std::fmt::Debug + Clone + serde::Serialize + serde::de::DeserializeOwned + Send + Sync + 'static
93{
94}
95
96impl<T> TransferObject for T where
97 T: std::fmt::Debug
98 + serde::Serialize
99 + serde::de::DeserializeOwned
100 + Send
101 + Sync
102 + 'static
103 + Clone
104{
105}
106
107#[allow(private_bounds, reason = "there's no the third implementation")]
108pub trait ServiceRole: std::fmt::Debug + Send + Sync + 'static + Copy + Clone {
109 type Req: TransferObject + GetMeta + GetExtensions;
110 type Resp: TransferObject;
111 type Not: TryInto<CancelledNotification, Error = Self::Not>
112 + From<CancelledNotification>
113 + TransferObject;
114 type PeerReq: TransferObject + GetMeta + GetExtensions;
115 type PeerResp: TransferObject;
116 type PeerNot: TryInto<CancelledNotification, Error = Self::PeerNot>
117 + From<CancelledNotification>
118 + TransferObject
119 + GetMeta
120 + GetExtensions;
121 type InitializeError;
122 const IS_CLIENT: bool;
123 type Info: TransferObject;
124 type PeerInfo: TransferObject;
125}
126
127pub type TxJsonRpcMessage<R> =
128 JsonRpcMessage<<R as ServiceRole>::Req, <R as ServiceRole>::Resp, <R as ServiceRole>::Not>;
129pub type RxJsonRpcMessage<R> = JsonRpcMessage<
130 <R as ServiceRole>::PeerReq,
131 <R as ServiceRole>::PeerResp,
132 <R as ServiceRole>::PeerNot,
133>;
134
135#[cfg(not(feature = "local"))]
136pub trait Service<R: ServiceRole>: Send + Sync + 'static {
137 fn handle_request(
138 &self,
139 request: R::PeerReq,
140 context: RequestContext<R>,
141 ) -> impl Future<Output = Result<R::Resp, McpError>> + MaybeSendFuture + '_;
142 fn handle_notification(
143 &self,
144 notification: R::PeerNot,
145 context: NotificationContext<R>,
146 ) -> impl Future<Output = Result<(), McpError>> + MaybeSendFuture + '_;
147 fn get_info(&self) -> R::Info;
148}
149
150#[cfg(feature = "local")]
151pub trait Service<R: ServiceRole>: 'static {
152 fn handle_request(
153 &self,
154 request: R::PeerReq,
155 context: RequestContext<R>,
156 ) -> impl Future<Output = Result<R::Resp, McpError>> + MaybeSendFuture + '_;
157 fn handle_notification(
158 &self,
159 notification: R::PeerNot,
160 context: NotificationContext<R>,
161 ) -> impl Future<Output = Result<(), McpError>> + MaybeSendFuture + '_;
162 fn get_info(&self) -> R::Info;
163}
164
165pub trait ServiceExt<R: ServiceRole>: Service<R> + Sized {
166 fn into_dyn(self) -> Box<dyn DynService<R>> {
170 Box::new(self)
171 }
172 fn serve<T, E, A>(
173 self,
174 transport: T,
175 ) -> impl Future<Output = Result<RunningService<R, Self>, R::InitializeError>> + MaybeSendFuture
176 where
177 T: IntoTransport<R, E, A>,
178 E: std::error::Error + Send + Sync + 'static,
179 Self: Sized,
180 {
181 Self::serve_with_ct(self, transport, Default::default())
182 }
183 fn serve_with_ct<T, E, A>(
184 self,
185 transport: T,
186 ct: CancellationToken,
187 ) -> impl Future<Output = Result<RunningService<R, Self>, R::InitializeError>> + MaybeSendFuture
188 where
189 T: IntoTransport<R, E, A>,
190 E: std::error::Error + Send + Sync + 'static,
191 Self: Sized;
192}
193
194impl<R: ServiceRole> Service<R> for Box<dyn DynService<R>> {
195 fn handle_request(
196 &self,
197 request: R::PeerReq,
198 context: RequestContext<R>,
199 ) -> impl Future<Output = Result<R::Resp, McpError>> + MaybeSendFuture + '_ {
200 DynService::handle_request(self.as_ref(), request, context)
201 }
202
203 fn handle_notification(
204 &self,
205 notification: R::PeerNot,
206 context: NotificationContext<R>,
207 ) -> impl Future<Output = Result<(), McpError>> + MaybeSendFuture + '_ {
208 DynService::handle_notification(self.as_ref(), notification, context)
209 }
210
211 fn get_info(&self) -> R::Info {
212 DynService::get_info(self.as_ref())
213 }
214}
215
216#[cfg(not(feature = "local"))]
217pub trait DynService<R: ServiceRole>: Send + Sync {
218 fn handle_request(
219 &self,
220 request: R::PeerReq,
221 context: RequestContext<R>,
222 ) -> MaybeBoxFuture<'_, Result<R::Resp, McpError>>;
223 fn handle_notification(
224 &self,
225 notification: R::PeerNot,
226 context: NotificationContext<R>,
227 ) -> MaybeBoxFuture<'_, Result<(), McpError>>;
228 fn get_info(&self) -> R::Info;
229}
230
231#[cfg(feature = "local")]
232pub trait DynService<R: ServiceRole> {
233 fn handle_request(
234 &self,
235 request: R::PeerReq,
236 context: RequestContext<R>,
237 ) -> MaybeBoxFuture<'_, Result<R::Resp, McpError>>;
238 fn handle_notification(
239 &self,
240 notification: R::PeerNot,
241 context: NotificationContext<R>,
242 ) -> MaybeBoxFuture<'_, Result<(), McpError>>;
243 fn get_info(&self) -> R::Info;
244}
245
246impl<R: ServiceRole, S: Service<R>> DynService<R> for S {
247 fn handle_request(
248 &self,
249 request: R::PeerReq,
250 context: RequestContext<R>,
251 ) -> MaybeBoxFuture<'_, Result<R::Resp, McpError>> {
252 Box::pin(self.handle_request(request, context))
253 }
254 fn handle_notification(
255 &self,
256 notification: R::PeerNot,
257 context: NotificationContext<R>,
258 ) -> MaybeBoxFuture<'_, Result<(), McpError>> {
259 Box::pin(self.handle_notification(notification, context))
260 }
261 fn get_info(&self) -> R::Info {
262 self.get_info()
263 }
264}
265
266use std::{
267 collections::{HashMap, VecDeque},
268 ops::Deref,
269 sync::{Arc, atomic::AtomicU64},
270 time::Duration,
271};
272
273use tokio::sync::mpsc;
274
275pub trait RequestIdProvider: Send + Sync + 'static {
276 fn next_request_id(&self) -> RequestId;
277}
278
279pub trait ProgressTokenProvider: Send + Sync + 'static {
280 fn next_progress_token(&self) -> ProgressToken;
281}
282
283pub type AtomicU32RequestIdProvider = AtomicU32Provider;
284pub type AtomicU32ProgressTokenProvider = AtomicU32Provider;
285
286#[derive(Debug, Default)]
287pub struct AtomicU32Provider {
288 id: AtomicU64,
289}
290
291impl RequestIdProvider for AtomicU32Provider {
292 fn next_request_id(&self) -> RequestId {
293 let id = self.id.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
294 RequestId::Number(id as i64)
296 }
297}
298
299impl ProgressTokenProvider for AtomicU32Provider {
300 fn next_progress_token(&self) -> ProgressToken {
301 let id = self.id.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
302 ProgressToken(NumberOrString::Number(id as i64))
303 }
304}
305
306#[doc(hidden)]
307pub trait ProgressNotificationToken {
308 fn progress_token(&self) -> Option<&ProgressToken>;
309}
310
311#[cfg(feature = "server")]
312impl ProgressNotificationToken for ClientNotification {
313 fn progress_token(&self) -> Option<&ProgressToken> {
314 match self {
315 ClientNotification::ProgressNotification(notification) => {
316 Some(¬ification.params.progress_token)
317 }
318 _ => None,
319 }
320 }
321}
322
323#[cfg(feature = "client")]
324impl ProgressNotificationToken for ServerNotification {
325 fn progress_token(&self) -> Option<&ProgressToken> {
326 match self {
327 ServerNotification::ProgressNotification(notification) => {
328 Some(¬ification.params.progress_token)
329 }
330 _ => None,
331 }
332 }
333}
334
335type Responder<T> = tokio::sync::oneshot::Sender<T>;
336type ProgressTimeoutWatchers = Arc<tokio::sync::RwLock<HashMap<ProgressToken, mpsc::Sender<()>>>>;
337
338#[derive(Debug)]
344#[non_exhaustive]
345pub struct RequestHandle<R: ServiceRole> {
346 pub rx: tokio::sync::oneshot::Receiver<Result<R::PeerResp, ServiceError>>,
347 pub options: PeerRequestOptions,
348 pub peer: Peer<R>,
349 pub id: RequestId,
350 pub progress_token: ProgressToken,
351 progress_reset_rx: Option<mpsc::Receiver<()>>,
352}
353
354impl<R: ServiceRole> RequestHandle<R> {
355 pub const REQUEST_TIMEOUT_REASON: &str = "request timeout";
356 pub const REQUEST_MAX_TOTAL_TIMEOUT_REASON: &str = "maximum total timeout exceeded";
357
358 pub async fn await_response(mut self) -> Result<R::PeerResp, ServiceError> {
359 let timeout = self.options.timeout;
360 let max_total_timeout = self.options.max_total_timeout;
361 let reset_timeout_on_progress = self.options.reset_timeout_on_progress;
362
363 let has_progress_reset_rx = self.progress_reset_rx.is_some();
364 let progress_token = self.progress_token.clone();
365
366 let result = match (timeout, max_total_timeout, reset_timeout_on_progress) {
367 (Some(timeout), None, false) => match tokio::time::timeout(timeout, &mut self.rx).await
368 {
369 Ok(response) => response.map_err(|_e| ServiceError::TransportClosed)?,
370 Err(_) => {
371 let error = Err(ServiceError::Timeout { timeout });
372 self.send_timeout_cancel_notification(Self::REQUEST_TIMEOUT_REASON)
374 .await;
375 error
376 }
377 },
378 (None, None, _) => (&mut self.rx)
379 .await
380 .map_err(|_e| ServiceError::TransportClosed)?,
381 _ => {
382 self.await_response_with_progress_timeout(
383 timeout,
384 max_total_timeout,
385 reset_timeout_on_progress,
386 )
387 .await
388 }
389 };
390
391 Self::cleanup_progress_timeout_watcher(
392 &self.peer.progress_timeout_watchers,
393 &progress_token,
394 has_progress_reset_rx,
395 )
396 .await;
397 result
398 }
399
400 async fn send_timeout_cancel_notification(&self, reason: &str) {
401 let notification = CancelledNotification {
402 params: CancelledNotificationParam {
403 request_id: Some(self.id.clone()),
404 reason: Some(reason.to_owned()),
405 meta: None,
406 },
407 method: crate::model::CancelledNotificationMethod,
408 extensions: Default::default(),
409 };
410 let _ = self.peer.send_notification(notification.into()).await;
411 }
412
413 async fn await_response_with_progress_timeout(
414 &mut self,
415 timeout: Option<Duration>,
416 max_total_timeout: Option<Duration>,
417 reset_timeout_on_progress: bool,
418 ) -> Result<R::PeerResp, ServiceError> {
419 let mut idle_sleep =
420 timeout.map(|timeout| (timeout, Box::pin(tokio::time::sleep(timeout))));
421 let mut max_total_sleep =
422 max_total_timeout.map(|timeout| (timeout, Box::pin(tokio::time::sleep(timeout))));
423
424 loop {
425 tokio::select! {
426 biased;
427
428 response = &mut self.rx => {
429 return response.map_err(|_e| ServiceError::TransportClosed)?;
430 }
431 _ = async {
432 if let Some((_, sleep)) = idle_sleep.as_mut() {
433 sleep.as_mut().await;
434 }
435 }, if idle_sleep.is_some() => {
436 if let Some((timeout, _)) = idle_sleep.as_ref() {
437 self.send_timeout_cancel_notification(Self::REQUEST_TIMEOUT_REASON).await;
438 return Err(ServiceError::Timeout { timeout: *timeout });
439 }
440 }
441 _ = async {
442 if let Some((_, sleep)) = max_total_sleep.as_mut() {
443 sleep.as_mut().await;
444 }
445 }, if max_total_sleep.is_some() => {
446 if let Some((timeout, _)) = max_total_sleep.as_ref() {
447 self.send_timeout_cancel_notification(Self::REQUEST_MAX_TOTAL_TIMEOUT_REASON).await;
448 return Err(ServiceError::Timeout { timeout: *timeout });
449 }
450 }
451 progress = async {
452 match self.progress_reset_rx.as_mut() {
453 Some(rx) => rx.recv().await,
454 None => None,
455 }
456 }, if reset_timeout_on_progress && idle_sleep.is_some() && self.progress_reset_rx.is_some() => {
457 if progress.is_some() {
458 if let Some((timeout, sleep)) = idle_sleep.as_mut() {
459 sleep.as_mut().reset(tokio::time::Instant::now() + *timeout);
460 }
461 }
462 }
463 }
464 }
465 }
466
467 pub async fn cancel(self, reason: Option<String>) -> Result<(), ServiceError> {
469 Self::cleanup_progress_timeout_watcher(
470 &self.peer.progress_timeout_watchers,
471 &self.progress_token,
472 self.progress_reset_rx.is_some(),
473 )
474 .await;
475 let notification = CancelledNotification {
476 params: CancelledNotificationParam {
477 request_id: Some(self.id),
478 reason,
479 meta: None,
480 },
481 method: crate::model::CancelledNotificationMethod,
482 extensions: Default::default(),
483 };
484 self.peer.send_notification(notification.into()).await?;
485 Ok(())
486 }
487
488 async fn cleanup_progress_timeout_watcher(
489 progress_timeout_watchers: &ProgressTimeoutWatchers,
490 progress_token: &ProgressToken,
491 has_progress_reset_rx: bool,
492 ) {
493 if has_progress_reset_rx {
494 progress_timeout_watchers
495 .write()
496 .await
497 .remove(progress_token);
498 }
499 }
500}
501
502#[derive(Debug)]
503pub(crate) enum PeerSinkMessage<R: ServiceRole> {
504 Request {
505 request: R::Req,
506 id: RequestId,
507 responder: Responder<Result<R::PeerResp, ServiceError>>,
508 },
509 Notification {
510 notification: R::Not,
511 responder: Responder<Result<(), ServiceError>>,
512 },
513}
514
515#[derive(Clone)]
521pub struct Peer<R: ServiceRole> {
522 tx: mpsc::Sender<PeerSinkMessage<R>>,
523 request_id_provider: Arc<dyn RequestIdProvider>,
524 progress_token_provider: Arc<dyn ProgressTokenProvider>,
525 progress_timeout_watchers: ProgressTimeoutWatchers,
526 info: Arc<std::sync::RwLock<Option<Arc<R::PeerInfo>>>>,
527}
528
529impl<R: ServiceRole> std::fmt::Debug for Peer<R> {
530 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
531 f.debug_struct("PeerSink")
532 .field("tx", &self.tx)
533 .field("is_client", &R::IS_CLIENT)
534 .finish()
535 }
536}
537
538type ProxyOutbound<R> = mpsc::Receiver<PeerSinkMessage<R>>;
539
540#[derive(Debug, Default)]
541#[non_exhaustive]
542pub struct PeerRequestOptions {
543 pub timeout: Option<Duration>,
544 pub meta: Option<Meta>,
545 pub reset_timeout_on_progress: bool,
547 pub max_total_timeout: Option<Duration>,
549}
550
551impl PeerRequestOptions {
552 pub fn no_options() -> Self {
553 Self::default()
554 }
555
556 pub fn with_timeout(timeout: Duration) -> Self {
557 Self {
558 timeout: Some(timeout),
559 ..Self::default()
560 }
561 }
562
563 pub fn reset_timeout_on_progress(mut self) -> Self {
564 self.reset_timeout_on_progress = true;
565 self
566 }
567
568 pub fn with_max_total_timeout(mut self, timeout: Duration) -> Self {
569 self.max_total_timeout = Some(timeout);
570 self
571 }
572}
573
574impl<R: ServiceRole> Peer<R> {
575 const CLIENT_CHANNEL_BUFFER_SIZE: usize = 1024;
576 pub(crate) fn new(
577 request_id_provider: Arc<dyn RequestIdProvider>,
578 peer_info: Option<R::PeerInfo>,
579 ) -> (Peer<R>, ProxyOutbound<R>) {
580 let (tx, rx) = mpsc::channel(Self::CLIENT_CHANNEL_BUFFER_SIZE);
581 (
582 Self {
583 tx,
584 request_id_provider,
585 progress_token_provider: Arc::new(AtomicU32ProgressTokenProvider::default()),
586 progress_timeout_watchers: Default::default(),
587 info: Arc::new(std::sync::RwLock::new(peer_info.map(Arc::new))),
588 },
589 rx,
590 )
591 }
592 pub async fn send_notification(&self, notification: R::Not) -> Result<(), ServiceError> {
593 let (responder, receiver) = tokio::sync::oneshot::channel();
594 self.tx
595 .send(PeerSinkMessage::Notification {
596 notification,
597 responder,
598 })
599 .await
600 .map_err(|_m| ServiceError::TransportClosed)?;
601 receiver.await.map_err(|_e| ServiceError::TransportClosed)?
602 }
603 pub async fn send_request(&self, request: R::Req) -> Result<R::PeerResp, ServiceError> {
604 self.send_request_with_option(request, PeerRequestOptions::no_options())
605 .await?
606 .await_response()
607 .await
608 }
609
610 pub async fn send_cancellable_request(
611 &self,
612 request: R::Req,
613 options: PeerRequestOptions,
614 ) -> Result<RequestHandle<R>, ServiceError> {
615 self.send_request_with_option(request, options).await
616 }
617
618 pub async fn send_request_with_option(
619 &self,
620 mut request: R::Req,
621 options: PeerRequestOptions,
622 ) -> Result<RequestHandle<R>, ServiceError> {
623 let id = self.request_id_provider.next_request_id();
624 let progress_token = self.progress_token_provider.next_progress_token();
625 if let Some(meta) = options.meta.clone() {
626 request.get_meta_mut().extend(meta);
627 }
628 request
629 .get_meta_mut()
630 .set_progress_token(progress_token.clone());
631 let (responder, receiver) = tokio::sync::oneshot::channel();
632 let progress_reset_rx = if options.reset_timeout_on_progress && options.timeout.is_some() {
633 let (sender, receiver) = mpsc::channel(1);
634 self.progress_timeout_watchers
635 .write()
636 .await
637 .insert(progress_token.clone(), sender);
638 Some(receiver)
639 } else {
640 None
641 };
642 if self
643 .tx
644 .send(PeerSinkMessage::Request {
645 request,
646 id: id.clone(),
647 responder,
648 })
649 .await
650 .is_err()
651 {
652 if progress_reset_rx.is_some() {
653 self.progress_timeout_watchers
654 .write()
655 .await
656 .remove(&progress_token);
657 }
658 return Err(ServiceError::TransportClosed);
659 }
660 Ok(RequestHandle {
661 id,
662 rx: receiver,
663 progress_token,
664 options,
665 peer: self.clone(),
666 progress_reset_rx,
667 })
668 }
669
670 async fn notify_progress_timeout_watcher(&self, progress_token: &ProgressToken) {
671 let sender = self
672 .progress_timeout_watchers
673 .read()
674 .await
675 .get(progress_token)
676 .cloned();
677 if let Some(sender) = sender {
678 match sender.try_send(()) {
679 Ok(()) => {}
680 Err(mpsc::error::TrySendError::Full(_)) => {
681 tracing::trace!(?progress_token, "progress timeout watcher channel is full");
682 }
683 Err(mpsc::error::TrySendError::Closed(_)) => {
684 self.progress_timeout_watchers
685 .write()
686 .await
687 .remove(progress_token);
688 }
689 }
690 }
691 }
692
693 pub fn peer_info(&self) -> Option<Arc<R::PeerInfo>> {
695 self.info.read().expect("peer info lock poisoned").clone()
696 }
697
698 pub fn set_peer_info(&self, info: R::PeerInfo) {
700 *self.info.write().expect("peer info lock poisoned") = Some(Arc::new(info));
701 }
702
703 pub fn is_transport_closed(&self) -> bool {
704 self.tx.is_closed()
705 }
706}
707
708#[derive(Debug)]
709pub struct RunningService<R: ServiceRole, S: Service<R>> {
710 service: Arc<S>,
711 peer: Peer<R>,
712 handle: Option<tokio::task::JoinHandle<QuitReason>>,
713 cancellation_token: CancellationToken,
714 dg: DropGuard,
715}
716impl<R: ServiceRole, S: Service<R>> Deref for RunningService<R, S> {
717 type Target = Peer<R>;
718
719 fn deref(&self) -> &Self::Target {
720 &self.peer
721 }
722}
723
724impl<R: ServiceRole, S: Service<R>> RunningService<R, S> {
725 #[inline]
726 pub fn peer(&self) -> &Peer<R> {
727 &self.peer
728 }
729 #[inline]
730 pub fn service(&self) -> &S {
731 self.service.as_ref()
732 }
733 #[inline]
734 pub fn cancellation_token(&self) -> RunningServiceCancellationToken {
735 RunningServiceCancellationToken(self.cancellation_token.clone())
736 }
737
738 #[inline]
740 pub fn is_closed(&self) -> bool {
741 self.handle.is_none() || self.cancellation_token.is_cancelled()
742 }
743
744 #[inline]
749 pub async fn waiting(mut self) -> Result<QuitReason, tokio::task::JoinError> {
750 match self.handle.take() {
751 Some(handle) => handle.await,
752 None => Ok(QuitReason::Closed),
753 }
754 }
755
756 pub async fn close(&mut self) -> Result<QuitReason, tokio::task::JoinError> {
774 if let Some(handle) = self.handle.take() {
775 self.cancellation_token.cancel();
778 handle.await
779 } else {
780 Ok(QuitReason::Closed)
782 }
783 }
784
785 pub async fn close_with_timeout(
794 &mut self,
795 timeout: Duration,
796 ) -> Result<Option<QuitReason>, tokio::task::JoinError> {
797 if let Some(handle) = self.handle.take() {
798 self.cancellation_token.cancel();
799 match tokio::time::timeout(timeout, handle).await {
800 Ok(result) => result.map(Some),
801 Err(_elapsed) => {
802 tracing::warn!(
803 "close_with_timeout: cleanup did not complete within {:?}",
804 timeout
805 );
806 Ok(None)
807 }
808 }
809 } else {
810 Ok(Some(QuitReason::Closed))
811 }
812 }
813
814 pub async fn cancel(mut self) -> Result<QuitReason, tokio::task::JoinError> {
819 let _ = std::mem::replace(&mut self.dg, self.cancellation_token.clone().drop_guard());
821 self.close().await
822 }
823}
824
825impl<R: ServiceRole, S: Service<R>> Drop for RunningService<R, S> {
826 fn drop(&mut self) {
827 if self.handle.is_some() && !self.cancellation_token.is_cancelled() {
828 tracing::debug!(
829 "RunningService dropped without explicit close(). \
830 The connection will be closed asynchronously. \
831 For guaranteed cleanup, call close() or cancel() before dropping."
832 );
833 }
834 }
836}
837
838pub struct RunningServiceCancellationToken(CancellationToken);
840
841impl RunningServiceCancellationToken {
842 pub fn cancel(self) {
843 self.0.cancel();
844 }
845}
846
847#[derive(Debug)]
848#[non_exhaustive]
849pub enum QuitReason {
850 Cancelled,
851 Closed,
852 JoinError(tokio::task::JoinError),
853}
854
855#[derive(Debug, Clone)]
857#[non_exhaustive]
858pub struct RequestContext<R: ServiceRole> {
859 pub ct: CancellationToken,
861 pub id: RequestId,
862 pub meta: Meta,
863 pub extensions: Extensions,
864 pub peer: Peer<R>,
866}
867
868impl<R: ServiceRole> RequestContext<R> {
869 pub fn new(id: RequestId, peer: Peer<R>) -> Self {
871 Self {
872 ct: CancellationToken::new(),
873 id,
874 meta: Meta::default(),
875 extensions: Extensions::default(),
876 peer,
877 }
878 }
879}
880
#[cfg(feature = "server")]
impl RequestContext<RoleServer> {
    /// Protocol version reported by the client, if its info is known yet.
    pub fn protocol_version(&self) -> Option<crate::model::ProtocolVersion> {
        let client_info = self.peer.peer_info()?;
        Some(client_info.protocol_version.clone())
    }
}
890
891#[derive(Debug, Clone)]
893#[non_exhaustive]
894pub struct NotificationContext<R: ServiceRole> {
895 pub meta: Meta,
896 pub extensions: Extensions,
897 pub peer: Peer<R>,
899}
900
901pub fn serve_directly<R, S, T, E, A>(
903 service: S,
904 transport: T,
905 peer_info: Option<R::PeerInfo>,
906) -> RunningService<R, S>
907where
908 R: ServiceRole,
909 R::PeerNot: ProgressNotificationToken,
910 S: Service<R>,
911 T: IntoTransport<R, E, A>,
912 E: std::error::Error + Send + Sync + 'static,
913{
914 serve_directly_with_ct(service, transport, peer_info, Default::default())
915}
916
917pub fn serve_directly_with_ct<R, S, T, E, A>(
919 service: S,
920 transport: T,
921 peer_info: Option<R::PeerInfo>,
922 ct: CancellationToken,
923) -> RunningService<R, S>
924where
925 R: ServiceRole,
926 R::PeerNot: ProgressNotificationToken,
927 S: Service<R>,
928 T: IntoTransport<R, E, A>,
929 E: std::error::Error + Send + Sync + 'static,
930{
931 let (peer, peer_rx) = Peer::new(Arc::new(AtomicU32RequestIdProvider::default()), peer_info);
932 serve_inner(service, transport.into_transport(), peer, peer_rx, ct)
933}
934
935#[cfg(not(feature = "local"))]
940fn spawn_service_task<F>(future: F) -> tokio::task::JoinHandle<F::Output>
941where
942 F: Future + Send + 'static,
943 F::Output: Send + 'static,
944{
945 tokio::spawn(future)
946}
947
948#[cfg(feature = "local")]
949fn spawn_service_task<F>(future: F) -> tokio::task::JoinHandle<F::Output>
950where
951 F: Future + 'static,
952 F::Output: 'static,
953{
954 tokio::task::spawn_local(future)
955}
956
957#[instrument(skip_all)]
958fn serve_inner<R, S, T>(
959 service: S,
960 transport: T,
961 peer: Peer<R>,
962 mut peer_rx: tokio::sync::mpsc::Receiver<PeerSinkMessage<R>>,
963 ct: CancellationToken,
964) -> RunningService<R, S>
965where
966 R: ServiceRole,
967 R::PeerNot: ProgressNotificationToken,
968 S: Service<R>,
969 T: Transport<R> + 'static,
970{
971 const SINK_PROXY_BUFFER_SIZE: usize = 64;
972 let (sink_proxy_tx, mut sink_proxy_rx) =
973 tokio::sync::mpsc::channel::<TxJsonRpcMessage<R>>(SINK_PROXY_BUFFER_SIZE);
974 let peer_info = peer.peer_info();
975 if R::IS_CLIENT {
976 tracing::info!(?peer_info, "Service initialized as client");
977 } else {
978 tracing::info!(?peer_info, "Service initialized as server");
979 }
980
981 let mut local_responder_pool =
982 HashMap::<RequestId, Responder<Result<R::PeerResp, ServiceError>>>::new();
983 let mut local_ct_pool = HashMap::<RequestId, CancellationToken>::new();
984 let shared_service = Arc::new(service);
985 let service = shared_service.clone();
987
988 let serve_loop_ct = ct.child_token();
991 let peer_return: Peer<R> = peer.clone();
992 let current_span = tracing::Span::current();
993 let handle = spawn_service_task(async move {
994 let mut transport = transport.into_transport();
995 let mut batch_messages = VecDeque::<RxJsonRpcMessage<R>>::new();
996 let mut send_task_set = tokio::task::JoinSet::<SendTaskResult>::new();
997 let mut response_send_tasks = tokio::task::JoinSet::<()>::new();
998 #[derive(Debug)]
999 enum SendTaskResult {
1000 Request {
1001 id: RequestId,
1002 result: Result<(), DynamicTransportError>,
1003 },
1004 Notification {
1005 responder: Responder<Result<(), ServiceError>>,
1006 cancellation_param: Option<CancelledNotificationParam>,
1007 result: Result<(), DynamicTransportError>,
1008 },
1009 }
1010 #[derive(Debug)]
1011 enum Event<R: ServiceRole> {
1012 ProxyMessage(PeerSinkMessage<R>),
1013 PeerMessage(RxJsonRpcMessage<R>),
1014 ToSink(TxJsonRpcMessage<R>),
1015 SendTaskResult(SendTaskResult),
1016 }
1017
1018 let quit_reason = loop {
1019 let evt = if let Some(m) = batch_messages.pop_front() {
1020 Event::PeerMessage(m)
1021 } else {
1022 tokio::select! {
1023 m = sink_proxy_rx.recv(), if !sink_proxy_rx.is_closed() => {
1024 if let Some(m) = m {
1025 Event::ToSink(m)
1026 } else {
1027 continue
1028 }
1029 }
1030 m = transport.receive() => {
1031 if let Some(m) = m {
1032 Event::PeerMessage(m)
1033 } else {
1034 tracing::info!("input stream terminated");
1036 break QuitReason::Closed
1037 }
1038 }
1039 m = peer_rx.recv(), if !peer_rx.is_closed() => {
1040 if let Some(m) = m {
1041 Event::ProxyMessage(m)
1042 } else {
1043 continue
1044 }
1045 }
1046 m = send_task_set.join_next(), if !send_task_set.is_empty() => {
1047 let Some(result) = m else {
1048 continue
1049 };
1050 match result {
1051 Err(e) => {
1052 tracing::error!(%e, "send request task encounter a tokio join error");
1054 break QuitReason::JoinError(e)
1055 }
1056 Ok(result) => {
1057 Event::SendTaskResult(result)
1058 }
1059 }
1060 }
1061 _ = serve_loop_ct.cancelled() => {
1062 tracing::info!("task cancelled");
1063 break QuitReason::Cancelled
1064 }
1065 }
1066 };
1067
1068 tracing::trace!(?evt, "new event");
1069 match evt {
1070 Event::SendTaskResult(SendTaskResult::Request { id, result }) => {
1071 if let Err(e) = result {
1072 if let Some(responder) = local_responder_pool.remove(&id) {
1073 let _ = responder.send(Err(ServiceError::TransportSend(e)));
1074 }
1075 }
1076 }
1077 Event::SendTaskResult(SendTaskResult::Notification {
1078 responder,
1079 result,
1080 cancellation_param,
1081 }) => {
1082 let response = if let Err(e) = result {
1083 Err(ServiceError::TransportSend(e))
1084 } else {
1085 Ok(())
1086 };
1087 let _ = responder.send(response);
1088 if let Some(param) = cancellation_param {
1089 if let Some(request_id) = ¶m.request_id {
1090 if let Some(responder) = local_responder_pool.remove(request_id) {
1091 tracing::info!(id = %request_id, reason = param.reason, "cancelled");
1092 let _response_result = responder.send(Err(ServiceError::Cancelled {
1093 reason: param.reason.clone(),
1094 }));
1095 }
1096 }
1097 }
1098 }
1099 Event::ToSink(m) => {
1101 if let Some(id) = match &m {
1102 JsonRpcMessage::Response(response) => Some(&response.id),
1103 JsonRpcMessage::Error(error) => error.id.as_ref(),
1104 _ => None,
1105 } {
1106 if let Some(ct) = local_ct_pool.remove(id) {
1107 ct.cancel();
1108 }
1109 let send = transport.send(m);
1110 let current_span = tracing::Span::current();
1111 response_send_tasks.spawn(async move {
1112 let send_result = send.await;
1113 if let Err(error) = send_result {
1114 tracing::error!(%error, "fail to response message");
1115 }
1116 }.instrument(current_span));
1117 }
1118 }
1119 Event::ProxyMessage(PeerSinkMessage::Request {
1120 request,
1121 id,
1122 responder,
1123 }) => {
1124 local_responder_pool.insert(id.clone(), responder);
1125 let send = transport.send(JsonRpcMessage::request(request, id.clone()));
1126 {
1127 let id = id.clone();
1128 let current_span = tracing::Span::current();
1129 send_task_set.spawn(send.map(move |r| SendTaskResult::Request {
1130 id,
1131 result: r.map_err(DynamicTransportError::new::<T, R>),
1132 }).instrument(current_span));
1133 }
1134 }
1135 Event::ProxyMessage(PeerSinkMessage::Notification {
1136 notification,
1137 responder,
1138 }) => {
1139 let mut cancellation_param = None;
1141 let notification = match notification.try_into() {
1142 Ok::<CancelledNotification, _>(cancelled) => {
1143 cancellation_param.replace(cancelled.params.clone());
1144 cancelled.into()
1145 }
1146 Err(notification) => notification,
1147 };
1148 let send = transport.send(JsonRpcMessage::notification(notification));
1149 let current_span = tracing::Span::current();
1150 send_task_set.spawn(send.map(move |result| SendTaskResult::Notification {
1151 responder,
1152 cancellation_param,
1153 result: result.map_err(DynamicTransportError::new::<T, R>),
1154 }).instrument(current_span));
1155 }
1156 Event::PeerMessage(JsonRpcMessage::Request(JsonRpcRequest {
1157 id,
1158 mut request,
1159 ..
1160 })) => {
1161 tracing::debug!(%id, ?request, "received request");
1162 {
1163 let service = shared_service.clone();
1164 let sink = sink_proxy_tx.clone();
1165 let request_ct = serve_loop_ct.child_token();
1166 let context_ct = request_ct.child_token();
1167 local_ct_pool.insert(id.clone(), request_ct);
1168 let mut extensions = Extensions::new();
1169 let mut meta = Meta::new();
1170 std::mem::swap(&mut meta, request.get_meta_mut());
1173 std::mem::swap(&mut extensions, request.extensions_mut());
1174 let context = RequestContext {
1175 ct: context_ct,
1176 id: id.clone(),
1177 peer: peer.clone(),
1178 meta,
1179 extensions,
1180 };
1181 let current_span = tracing::Span::current();
1182 spawn_service_task(async move {
1183 let result = service
1184 .handle_request(request, context)
1185 .await;
1186 let response = match result {
1187 Ok(result) => {
1188 tracing::debug!(%id, ?result, "response message");
1189 JsonRpcMessage::response(result, id)
1190 }
1191 Err(error) => {
1192 tracing::warn!(%id, ?error, "response error");
1193 JsonRpcMessage::error(error, Some(id))
1194 }
1195 };
1196 let _send_result = sink.send(response).await;
1197 }.instrument(current_span));
1198 }
1199 }
1200 Event::PeerMessage(JsonRpcMessage::Notification(JsonRpcNotification {
1201 notification,
1202 ..
1203 })) => {
1204 tracing::info!(?notification, "received notification");
1205 let mut notification = match notification.try_into() {
1207 Ok::<CancelledNotification, _>(cancelled) => {
1208 if let Some(request_id) = &cancelled.params.request_id {
1209 if let Some(ct) = local_ct_pool.remove(request_id) {
1210 tracing::info!(id = %request_id, reason = cancelled.params.reason, "cancelled");
1211 ct.cancel();
1212 }
1213 }
1214 cancelled.into()
1215 }
1216 Err(notification) => notification,
1217 };
1218 if let Some(progress_token) = notification.progress_token() {
1219 peer.notify_progress_timeout_watcher(progress_token).await;
1220 }
1221 {
1222 let service = shared_service.clone();
1223 let mut extensions = Extensions::new();
1224 let mut meta = Meta::new();
1225 std::mem::swap(&mut extensions, notification.extensions_mut());
1227 std::mem::swap(&mut meta, notification.get_meta_mut());
1228 let context = NotificationContext {
1229 peer: peer.clone(),
1230 meta,
1231 extensions,
1232 };
1233 let current_span = tracing::Span::current();
1234 spawn_service_task(async move {
1235 let result = service.handle_notification(notification, context).await;
1236 if let Err(error) = result {
1237 tracing::warn!(%error, "Error sending notification");
1238 }
1239 }.instrument(current_span));
1240 }
1241 }
1242 Event::PeerMessage(JsonRpcMessage::Response(JsonRpcResponse {
1243 result,
1244 id,
1245 ..
1246 })) => {
1247 if let Some(responder) = local_responder_pool.remove(&id) {
1248 let response_result = responder.send(Ok(result));
1249 if let Err(_error) = response_result {
1250 tracing::warn!(%id, "Error sending response");
1251 }
1252 }
1253 }
1254 Event::PeerMessage(JsonRpcMessage::Error(JsonRpcError { error, id, .. })) => {
1255 let Some(id) = id else {
1256 tracing::debug!(?error, "received id-less peer error");
1259 continue;
1260 };
1261 if let Some(responder) = local_responder_pool.remove(&id) {
1262 let _response_result = responder.send(Err(ServiceError::McpError(error)));
1263 if let Err(_error) = _response_result {
1264 tracing::warn!(%id, "Error sending response");
1265 }
1266 }
1267 }
1268 }
1269 };
1270
1271 let drain_timeout = match &quit_reason {
1277 QuitReason::Closed => Some(Duration::from_secs(5)),
1278 QuitReason::Cancelled => Some(Duration::from_secs(2)),
1279 _ => None,
1280 };
1281 if let Some(timeout_duration) = drain_timeout {
1282 drop(sink_proxy_tx);
1285 let drain_result = tokio::time::timeout(timeout_duration, async {
1286 while let Some(result) = response_send_tasks.join_next().await {
1289 if let Err(error) = result {
1290 tracing::error!(%error, "response send task failed during drain");
1291 }
1292 }
1293 while let Some(m) = sink_proxy_rx.recv().await {
1296 if let Err(error) = transport.send(m).await {
1297 tracing::error!(%error, "failed to send pending response during drain");
1298 break;
1299 }
1300 }
1301 })
1302 .await;
1303 if drain_result.is_err() {
1304 tracing::warn!("timed out draining in-flight responses");
1305 }
1306 }
1307
1308 let sink_close_result = transport.close().await;
1309 if let Err(e) = sink_close_result {
1310 tracing::error!(%e, "fail to close sink");
1311 }
1312 tracing::info!(?quit_reason, "serve finished");
1313 quit_reason
1314 }.instrument(current_span));
1315 RunningService {
1316 service,
1317 peer: peer_return,
1318 handle: Some(handle),
1319 cancellation_token: ct.clone(),
1320 dg: ct.drop_guard(),
1321 }
1322}