Skip to main content

rmcp/
service.rs

1use futures::FutureExt;
2#[cfg(not(feature = "local"))]
3use futures::future::BoxFuture;
4#[cfg(feature = "local")]
5use futures::future::LocalBoxFuture;
6use thiserror::Error;
7
8// ---------------------------------------------------------------------------
9// Conditional Send helpers
10//
11// `MaybeSend`       – supertrait alias: `Send + Sync` without `local`, empty with `local`
12// `MaybeSendFuture` – future bound alias: `Send` without `local`, empty with `local`
13// `MaybeBoxFuture`  – boxed future type: `BoxFuture` without `local`, `LocalBoxFuture` with `local`
14// ---------------------------------------------------------------------------
15
16#[cfg(not(feature = "local"))]
17#[doc(hidden)]
18pub trait MaybeSend: Send + Sync {}
19#[cfg(not(feature = "local"))]
20impl<T: Send + Sync> MaybeSend for T {}
21
22#[cfg(feature = "local")]
23#[doc(hidden)]
24pub trait MaybeSend {}
25#[cfg(feature = "local")]
26impl<T> MaybeSend for T {}
27
28#[cfg(not(feature = "local"))]
29#[doc(hidden)]
30pub trait MaybeSendFuture: Send {}
31#[cfg(not(feature = "local"))]
32impl<T: Send> MaybeSendFuture for T {}
33
34#[cfg(feature = "local")]
35#[doc(hidden)]
36pub trait MaybeSendFuture {}
37#[cfg(feature = "local")]
38impl<T> MaybeSendFuture for T {}
39
40#[cfg(not(feature = "local"))]
41pub(crate) type MaybeBoxFuture<'a, T> = BoxFuture<'a, T>;
42#[cfg(feature = "local")]
43pub(crate) type MaybeBoxFuture<'a, T> = LocalBoxFuture<'a, T>;
44
45#[cfg(feature = "server")]
46use crate::model::ClientNotification;
47#[cfg(feature = "server")]
48use crate::model::ServerJsonRpcMessage;
49#[cfg(feature = "client")]
50use crate::model::ServerNotification;
51use crate::{
52    error::ErrorData as McpError,
53    model::{
54        CancelledNotification, CancelledNotificationParam, Extensions, GetExtensions, GetMeta,
55        JsonRpcError, JsonRpcMessage, JsonRpcNotification, JsonRpcRequest, JsonRpcResponse, Meta,
56        NumberOrString, ProgressToken, RequestId,
57    },
58    transport::{DynamicTransportError, IntoTransport, Transport},
59};
60#[cfg(feature = "client")]
61mod client;
62#[cfg(feature = "client")]
63pub use client::*;
64#[cfg(feature = "server")]
65mod server;
66#[cfg(feature = "server")]
67pub use server::*;
68#[cfg(feature = "tower")]
69mod tower;
70use tokio_util::sync::{CancellationToken, DropGuard};
71#[cfg(feature = "tower")]
72pub use tower::*;
73use tracing::{Instrument as _, instrument};
/// Errors produced while talking to the remote peer through a [`Peer`] or a
/// [`RequestHandle`].
#[derive(Error, Debug)]
#[non_exhaustive]
pub enum ServiceError {
    /// Wraps an MCP protocol error ([`McpError`]).
    #[error("Mcp error: {0}")]
    McpError(McpError),
    /// The transport reported an error while sending a message.
    #[error("Transport send error: {0}")]
    TransportSend(DynamicTransportError),
    /// The outbound queue or the response channel is gone, so the message could
    /// not be delivered or no reply will ever arrive.
    #[error("Transport closed")]
    TransportClosed,
    /// The peer replied with a response of an unexpected type (presumably
    /// raised by typed request helpers elsewhere in the crate — confirm).
    #[error("Unexpected response type")]
    UnexpectedResponse,
    /// The task was cancelled; `reason` is displayed as `<unknown>` when absent.
    #[error("task cancelled for reason {}", reason.as_deref().unwrap_or("<unknown>"))]
    Cancelled { reason: Option<String> },
    /// A request timeout elapsed; `timeout` is the duration that expired. It is
    /// displayed via `chrono::Duration` formatting (falls back to zero if the
    /// conversion fails).
    #[error("request timeout after {}", chrono::Duration::from_std(*timeout).unwrap_or_default())]
    Timeout { timeout: Duration },
}
90
91trait TransferObject:
92    std::fmt::Debug + Clone + serde::Serialize + serde::de::DeserializeOwned + Send + Sync + 'static
93{
94}
95
96impl<T> TransferObject for T where
97    T: std::fmt::Debug
98        + serde::Serialize
99        + serde::de::DeserializeOwned
100        + Send
101        + Sync
102        + 'static
103        + Clone
104{
105}
106
/// Describes one side (role) of an MCP connection: the message types it sends
/// and receives, and the info exchanged about each side.
///
/// Outgoing traffic uses `Req`/`Resp`/`Not` ([`TxJsonRpcMessage`]); incoming
/// traffic uses the `Peer*` types ([`RxJsonRpcMessage`]).
#[allow(private_bounds, reason = "there's no the third implementation")]
pub trait ServiceRole: std::fmt::Debug + Send + Sync + 'static + Copy + Clone {
    /// Requests this role sends to its peer (see [`Peer::send_request`]).
    type Req: TransferObject + GetMeta + GetExtensions;
    /// Responses this role's [`Service::handle_request`] returns.
    type Resp: TransferObject;
    /// Notifications this role sends (see [`Peer::send_notification`]).
    ///
    /// `TryInto<CancelledNotification>` has `Self::Not` as its error type, so a
    /// failed conversion hands the original value back; `From` lets a
    /// [`CancelledNotification`] be sent as one.
    type Not: TryInto<CancelledNotification, Error = Self::Not>
        + From<CancelledNotification>
        + TransferObject;
    /// Requests received from the peer (passed to [`Service::handle_request`]).
    type PeerReq: TransferObject + GetMeta + GetExtensions;
    /// Responses received from the peer for requests this role sent.
    type PeerResp: TransferObject;
    /// Notifications received from the peer (passed to
    /// [`Service::handle_notification`]).
    type PeerNot: TryInto<CancelledNotification, Error = Self::PeerNot>
        + From<CancelledNotification>
        + TransferObject
        + GetMeta
        + GetExtensions;
    /// Error returned when starting a service fails (see [`ServiceExt::serve`]).
    type InitializeError;
    /// Whether this role is the client side — presumably `true` only for the
    /// client role defined in the client module. Shown in `Peer`'s `Debug` output.
    const IS_CLIENT: bool;
    /// This side's own info, as returned by [`Service::get_info`].
    type Info: TransferObject;
    /// The peer's info, as stored in a [`Peer`] and returned by [`Peer::peer_info`].
    type PeerInfo: TransferObject;
}
126
/// A JSON-RPC message sent by role `R`: its requests, responses and notifications.
pub type TxJsonRpcMessage<R> =
    JsonRpcMessage<<R as ServiceRole>::Req, <R as ServiceRole>::Resp, <R as ServiceRole>::Not>;
/// A JSON-RPC message received by role `R`: its peer's requests, responses and
/// notifications.
pub type RxJsonRpcMessage<R> = JsonRpcMessage<
    <R as ServiceRole>::PeerReq,
    <R as ServiceRole>::PeerResp,
    <R as ServiceRole>::PeerNot,
>;
134
135#[cfg(not(feature = "local"))]
136pub trait Service<R: ServiceRole>: Send + Sync + 'static {
137    fn handle_request(
138        &self,
139        request: R::PeerReq,
140        context: RequestContext<R>,
141    ) -> impl Future<Output = Result<R::Resp, McpError>> + MaybeSendFuture + '_;
142    fn handle_notification(
143        &self,
144        notification: R::PeerNot,
145        context: NotificationContext<R>,
146    ) -> impl Future<Output = Result<(), McpError>> + MaybeSendFuture + '_;
147    fn get_info(&self) -> R::Info;
148}
149
150#[cfg(feature = "local")]
151pub trait Service<R: ServiceRole>: 'static {
152    fn handle_request(
153        &self,
154        request: R::PeerReq,
155        context: RequestContext<R>,
156    ) -> impl Future<Output = Result<R::Resp, McpError>> + MaybeSendFuture + '_;
157    fn handle_notification(
158        &self,
159        notification: R::PeerNot,
160        context: NotificationContext<R>,
161    ) -> impl Future<Output = Result<(), McpError>> + MaybeSendFuture + '_;
162    fn get_info(&self) -> R::Info;
163}
164
/// Convenience methods for starting a [`Service`] on a transport.
pub trait ServiceExt<R: ServiceRole>: Service<R> + Sized {
    /// Box this service as a [`DynService`] trait object.
    ///
    /// Useful for storing services of different concrete types in one collection.
    fn into_dyn(self) -> Box<dyn DynService<R>> {
        Box::new(self)
    }
    /// Start serving over `transport` with a fresh cancellation token.
    ///
    /// Equivalent to [`ServiceExt::serve_with_ct`] with `CancellationToken::default()`.
    fn serve<T, E, A>(
        self,
        transport: T,
    ) -> impl Future<Output = Result<RunningService<R, Self>, R::InitializeError>> + MaybeSendFuture
    where
        T: IntoTransport<R, E, A>,
        E: std::error::Error + Send + Sync + 'static,
        Self: Sized,
    {
        Self::serve_with_ct(self, transport, Default::default())
    }
    /// Start serving over `transport`, using `ct` as the running service's
    /// cancellation token.
    ///
    /// Implemented per role (presumably in the client/server modules — confirm);
    /// start-up failures surface as `R::InitializeError`.
    fn serve_with_ct<T, E, A>(
        self,
        transport: T,
        ct: CancellationToken,
    ) -> impl Future<Output = Result<RunningService<R, Self>, R::InitializeError>> + MaybeSendFuture
    where
        T: IntoTransport<R, E, A>,
        E: std::error::Error + Send + Sync + 'static,
        Self: Sized;
}
193
194impl<R: ServiceRole> Service<R> for Box<dyn DynService<R>> {
195    fn handle_request(
196        &self,
197        request: R::PeerReq,
198        context: RequestContext<R>,
199    ) -> impl Future<Output = Result<R::Resp, McpError>> + MaybeSendFuture + '_ {
200        DynService::handle_request(self.as_ref(), request, context)
201    }
202
203    fn handle_notification(
204        &self,
205        notification: R::PeerNot,
206        context: NotificationContext<R>,
207    ) -> impl Future<Output = Result<(), McpError>> + MaybeSendFuture + '_ {
208        DynService::handle_notification(self.as_ref(), notification, context)
209    }
210
211    fn get_info(&self) -> R::Info {
212        DynService::get_info(self.as_ref())
213    }
214}
215
216#[cfg(not(feature = "local"))]
217pub trait DynService<R: ServiceRole>: Send + Sync {
218    fn handle_request(
219        &self,
220        request: R::PeerReq,
221        context: RequestContext<R>,
222    ) -> MaybeBoxFuture<'_, Result<R::Resp, McpError>>;
223    fn handle_notification(
224        &self,
225        notification: R::PeerNot,
226        context: NotificationContext<R>,
227    ) -> MaybeBoxFuture<'_, Result<(), McpError>>;
228    fn get_info(&self) -> R::Info;
229}
230
231#[cfg(feature = "local")]
232pub trait DynService<R: ServiceRole> {
233    fn handle_request(
234        &self,
235        request: R::PeerReq,
236        context: RequestContext<R>,
237    ) -> MaybeBoxFuture<'_, Result<R::Resp, McpError>>;
238    fn handle_notification(
239        &self,
240        notification: R::PeerNot,
241        context: NotificationContext<R>,
242    ) -> MaybeBoxFuture<'_, Result<(), McpError>>;
243    fn get_info(&self) -> R::Info;
244}
245
246impl<R: ServiceRole, S: Service<R>> DynService<R> for S {
247    fn handle_request(
248        &self,
249        request: R::PeerReq,
250        context: RequestContext<R>,
251    ) -> MaybeBoxFuture<'_, Result<R::Resp, McpError>> {
252        Box::pin(self.handle_request(request, context))
253    }
254    fn handle_notification(
255        &self,
256        notification: R::PeerNot,
257        context: NotificationContext<R>,
258    ) -> MaybeBoxFuture<'_, Result<(), McpError>> {
259        Box::pin(self.handle_notification(notification, context))
260    }
261    fn get_info(&self) -> R::Info {
262        self.get_info()
263    }
264}
265
266use std::{
267    collections::{HashMap, VecDeque},
268    ops::Deref,
269    sync::{Arc, atomic::AtomicU64},
270    time::Duration,
271};
272
273use tokio::sync::mpsc;
274
/// Source of ids for outgoing requests.
pub trait RequestIdProvider: Send + Sync + 'static {
    /// Return the id to use for the next outgoing request.
    fn next_request_id(&self) -> RequestId;
}

/// Source of progress tokens attached to outgoing requests.
pub trait ProgressTokenProvider: Send + Sync + 'static {
    /// Return the progress token to use for the next outgoing request.
    fn next_progress_token(&self) -> ProgressToken;
}

/// Counter-based [`RequestIdProvider`]; used by default for request ids.
pub type AtomicU32RequestIdProvider = AtomicU32Provider;
/// Counter-based [`ProgressTokenProvider`]; used by default for progress tokens.
pub type AtomicU32ProgressTokenProvider = AtomicU32Provider;

/// Counter-based id/token provider that starts at 0.
///
/// Despite the name, the counter is an `AtomicU64`, and its values are exposed
/// as `i64`. Both provider impls advance the same counter, so one instance
/// yields a single increasing sequence; use separate instances for request ids
/// and progress tokens.
#[derive(Debug, Default)]
pub struct AtomicU32Provider {
    // Next value to hand out; `Default` starts it at 0.
    id: AtomicU64,
}
290
291impl RequestIdProvider for AtomicU32Provider {
292    fn next_request_id(&self) -> RequestId {
293        let id = self.id.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
294        // Safe conversion: we start at 0 and increment by 1, so we won't overflow i64::MAX in practice
295        RequestId::Number(id as i64)
296    }
297}
298
299impl ProgressTokenProvider for AtomicU32Provider {
300    fn next_progress_token(&self) -> ProgressToken {
301        let id = self.id.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
302        ProgressToken(NumberOrString::Number(id as i64))
303    }
304}
305
306#[doc(hidden)]
307pub trait ProgressNotificationToken {
308    fn progress_token(&self) -> Option<&ProgressToken>;
309}
310
311#[cfg(feature = "server")]
312impl ProgressNotificationToken for ClientNotification {
313    fn progress_token(&self) -> Option<&ProgressToken> {
314        match self {
315            ClientNotification::ProgressNotification(notification) => {
316                Some(&notification.params.progress_token)
317            }
318            _ => None,
319        }
320    }
321}
322
323#[cfg(feature = "client")]
324impl ProgressNotificationToken for ServerNotification {
325    fn progress_token(&self) -> Option<&ProgressToken> {
326        match self {
327            ServerNotification::ProgressNotification(notification) => {
328                Some(&notification.params.progress_token)
329            }
330            _ => None,
331        }
332    }
333}
334
/// One-shot sender used to deliver a single result back to a waiting caller.
type Responder<T> = tokio::sync::oneshot::Sender<T>;
/// Per-progress-token senders that wake a [`RequestHandle`] waiting with
/// `reset_timeout_on_progress`; the `Arc` makes the map shared by all clones of a
/// [`Peer`].
type ProgressTimeoutWatchers = Arc<tokio::sync::RwLock<HashMap<ProgressToken, mpsc::Sender<()>>>>;
337
/// A handle to a request that has been sent to the remote peer.
///
/// Wait for the response with [`RequestHandle::await_response`], or cancel the
/// request (notifying the peer, with an optional reason) with
/// [`RequestHandle::cancel`].
#[derive(Debug)]
#[non_exhaustive]
pub struct RequestHandle<R: ServiceRole> {
    /// Receives the peer's response (or an error) for this request.
    pub rx: tokio::sync::oneshot::Receiver<Result<R::PeerResp, ServiceError>>,
    /// The options the request was sent with (timeouts, meta).
    pub options: PeerRequestOptions,
    /// The peer the request was sent to.
    pub peer: Peer<R>,
    /// The id assigned to the request.
    pub id: RequestId,
    /// The progress token set in the request's meta.
    pub progress_token: ProgressToken,
    // Receives `()` whenever progress for `progress_token` is reported; only
    // present when both `reset_timeout_on_progress` and `timeout` were set.
    progress_reset_rx: Option<mpsc::Receiver<()>>,
}
353
354impl<R: ServiceRole> RequestHandle<R> {
355    pub const REQUEST_TIMEOUT_REASON: &str = "request timeout";
356    pub const REQUEST_MAX_TOTAL_TIMEOUT_REASON: &str = "maximum total timeout exceeded";
357
358    pub async fn await_response(mut self) -> Result<R::PeerResp, ServiceError> {
359        let timeout = self.options.timeout;
360        let max_total_timeout = self.options.max_total_timeout;
361        let reset_timeout_on_progress = self.options.reset_timeout_on_progress;
362
363        let has_progress_reset_rx = self.progress_reset_rx.is_some();
364        let progress_token = self.progress_token.clone();
365
366        let result = match (timeout, max_total_timeout, reset_timeout_on_progress) {
367            (Some(timeout), None, false) => match tokio::time::timeout(timeout, &mut self.rx).await
368            {
369                Ok(response) => response.map_err(|_e| ServiceError::TransportClosed)?,
370                Err(_) => {
371                    let error = Err(ServiceError::Timeout { timeout });
372                    // cancel this request
373                    self.send_timeout_cancel_notification(Self::REQUEST_TIMEOUT_REASON)
374                        .await;
375                    error
376                }
377            },
378            (None, None, _) => (&mut self.rx)
379                .await
380                .map_err(|_e| ServiceError::TransportClosed)?,
381            _ => {
382                self.await_response_with_progress_timeout(
383                    timeout,
384                    max_total_timeout,
385                    reset_timeout_on_progress,
386                )
387                .await
388            }
389        };
390
391        Self::cleanup_progress_timeout_watcher(
392            &self.peer.progress_timeout_watchers,
393            &progress_token,
394            has_progress_reset_rx,
395        )
396        .await;
397        result
398    }
399
400    async fn send_timeout_cancel_notification(&self, reason: &str) {
401        let notification = CancelledNotification {
402            params: CancelledNotificationParam {
403                request_id: Some(self.id.clone()),
404                reason: Some(reason.to_owned()),
405                meta: None,
406            },
407            method: crate::model::CancelledNotificationMethod,
408            extensions: Default::default(),
409        };
410        let _ = self.peer.send_notification(notification.into()).await;
411    }
412
413    async fn await_response_with_progress_timeout(
414        &mut self,
415        timeout: Option<Duration>,
416        max_total_timeout: Option<Duration>,
417        reset_timeout_on_progress: bool,
418    ) -> Result<R::PeerResp, ServiceError> {
419        let mut idle_sleep =
420            timeout.map(|timeout| (timeout, Box::pin(tokio::time::sleep(timeout))));
421        let mut max_total_sleep =
422            max_total_timeout.map(|timeout| (timeout, Box::pin(tokio::time::sleep(timeout))));
423
424        loop {
425            tokio::select! {
426                biased;
427
428                response = &mut self.rx => {
429                    return response.map_err(|_e| ServiceError::TransportClosed)?;
430                }
431                _ = async {
432                    if let Some((_, sleep)) = idle_sleep.as_mut() {
433                        sleep.as_mut().await;
434                    }
435                }, if idle_sleep.is_some() => {
436                    if let Some((timeout, _)) = idle_sleep.as_ref() {
437                        self.send_timeout_cancel_notification(Self::REQUEST_TIMEOUT_REASON).await;
438                        return Err(ServiceError::Timeout { timeout: *timeout });
439                    }
440                }
441                _ = async {
442                    if let Some((_, sleep)) = max_total_sleep.as_mut() {
443                        sleep.as_mut().await;
444                    }
445                }, if max_total_sleep.is_some() => {
446                    if let Some((timeout, _)) = max_total_sleep.as_ref() {
447                        self.send_timeout_cancel_notification(Self::REQUEST_MAX_TOTAL_TIMEOUT_REASON).await;
448                        return Err(ServiceError::Timeout { timeout: *timeout });
449                    }
450                }
451                progress = async {
452                    match self.progress_reset_rx.as_mut() {
453                        Some(rx) => rx.recv().await,
454                        None => None,
455                    }
456                }, if reset_timeout_on_progress && idle_sleep.is_some() && self.progress_reset_rx.is_some() => {
457                    if progress.is_some() {
458                        if let Some((timeout, sleep)) = idle_sleep.as_mut() {
459                            sleep.as_mut().reset(tokio::time::Instant::now() + *timeout);
460                        }
461                    }
462                }
463            }
464        }
465    }
466
467    /// Cancel this request
468    pub async fn cancel(self, reason: Option<String>) -> Result<(), ServiceError> {
469        Self::cleanup_progress_timeout_watcher(
470            &self.peer.progress_timeout_watchers,
471            &self.progress_token,
472            self.progress_reset_rx.is_some(),
473        )
474        .await;
475        let notification = CancelledNotification {
476            params: CancelledNotificationParam {
477                request_id: Some(self.id),
478                reason,
479                meta: None,
480            },
481            method: crate::model::CancelledNotificationMethod,
482            extensions: Default::default(),
483        };
484        self.peer.send_notification(notification.into()).await?;
485        Ok(())
486    }
487
488    async fn cleanup_progress_timeout_watcher(
489        progress_timeout_watchers: &ProgressTimeoutWatchers,
490        progress_token: &ProgressToken,
491        has_progress_reset_rx: bool,
492    ) {
493        if has_progress_reset_rx {
494            progress_timeout_watchers
495                .write()
496                .await
497                .remove(progress_token);
498        }
499    }
500}
501
/// A message queued by a [`Peer`] for delivery to the remote side.
///
/// The messages are read from the peer's `ProxyOutbound` receiver, which is
/// handed to `serve_inner`.
#[derive(Debug)]
pub(crate) enum PeerSinkMessage<R: ServiceRole> {
    /// An outgoing request; `responder` receives the peer's response or an error.
    Request {
        request: R::Req,
        id: RequestId,
        responder: Responder<Result<R::PeerResp, ServiceError>>,
    },
    /// An outgoing notification; `responder` receives the outcome of sending it.
    Notification {
        notification: R::Not,
        responder: Responder<Result<(), ServiceError>>,
    },
}
514
/// A handle for talking to the remote client or server.
///
/// For general purposes, call [`Peer::send_request`] or [`Peer::send_notification`] to send a message to the remote peer.
///
/// To create a cancellable request, call [`Peer::send_request_with_option`].
///
/// Cloning is cheap. All clones share the same outbound queue, id/token
/// providers, progress watchers and peer info.
#[derive(Clone)]
pub struct Peer<R: ServiceRole> {
    // Outbound queue; its receiving end (`ProxyOutbound`) is returned by `Peer::new`.
    tx: mpsc::Sender<PeerSinkMessage<R>>,
    // Generates ids for outgoing requests.
    request_id_provider: Arc<dyn RequestIdProvider>,
    // Generates the progress token attached to each outgoing request.
    progress_token_provider: Arc<dyn ProgressTokenProvider>,
    // Wakes requests waiting with `reset_timeout_on_progress` when progress arrives.
    progress_timeout_watchers: ProgressTimeoutWatchers,
    // The peer's handshake info, if known. A `std::sync::RwLock`; it is only
    // locked inside the synchronous `peer_info` / `set_peer_info` methods.
    info: Arc<std::sync::RwLock<Option<Arc<R::PeerInfo>>>>,
}
528
529impl<R: ServiceRole> std::fmt::Debug for Peer<R> {
530    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
531        f.debug_struct("PeerSink")
532            .field("tx", &self.tx)
533            .field("is_client", &R::IS_CLIENT)
534            .finish()
535    }
536}
537
/// Receiving end of a [`Peer`]'s outbound queue, returned by `Peer::new`.
type ProxyOutbound<R> = mpsc::Receiver<PeerSinkMessage<R>>;
539
/// Per-request options for [`Peer::send_request_with_option`].
///
/// `Default` means no timeouts and no extra meta.
#[derive(Debug, Default)]
#[non_exhaustive]
pub struct PeerRequestOptions {
    /// How long to wait for the response. `None` means no idle timeout.
    pub timeout: Option<Duration>,
    /// Extra metadata merged into the request's meta before sending. The
    /// generated progress token is set on the meta afterwards.
    pub meta: Option<Meta>,
    /// Reset the request timeout when a matching progress notification is received.
    ///
    /// Has no effect unless `timeout` is also set.
    pub reset_timeout_on_progress: bool,
    /// Maximum total time to wait for the request, regardless of progress notifications.
    pub max_total_timeout: Option<Duration>,
}
550
551impl PeerRequestOptions {
552    pub fn no_options() -> Self {
553        Self::default()
554    }
555
556    pub fn with_timeout(timeout: Duration) -> Self {
557        Self {
558            timeout: Some(timeout),
559            ..Self::default()
560        }
561    }
562
563    pub fn reset_timeout_on_progress(mut self) -> Self {
564        self.reset_timeout_on_progress = true;
565        self
566    }
567
568    pub fn with_max_total_timeout(mut self, timeout: Duration) -> Self {
569        self.max_total_timeout = Some(timeout);
570        self
571    }
572}
573
574impl<R: ServiceRole> Peer<R> {
575    const CLIENT_CHANNEL_BUFFER_SIZE: usize = 1024;
576    pub(crate) fn new(
577        request_id_provider: Arc<dyn RequestIdProvider>,
578        peer_info: Option<R::PeerInfo>,
579    ) -> (Peer<R>, ProxyOutbound<R>) {
580        let (tx, rx) = mpsc::channel(Self::CLIENT_CHANNEL_BUFFER_SIZE);
581        (
582            Self {
583                tx,
584                request_id_provider,
585                progress_token_provider: Arc::new(AtomicU32ProgressTokenProvider::default()),
586                progress_timeout_watchers: Default::default(),
587                info: Arc::new(std::sync::RwLock::new(peer_info.map(Arc::new))),
588            },
589            rx,
590        )
591    }
592    pub async fn send_notification(&self, notification: R::Not) -> Result<(), ServiceError> {
593        let (responder, receiver) = tokio::sync::oneshot::channel();
594        self.tx
595            .send(PeerSinkMessage::Notification {
596                notification,
597                responder,
598            })
599            .await
600            .map_err(|_m| ServiceError::TransportClosed)?;
601        receiver.await.map_err(|_e| ServiceError::TransportClosed)?
602    }
603    pub async fn send_request(&self, request: R::Req) -> Result<R::PeerResp, ServiceError> {
604        self.send_request_with_option(request, PeerRequestOptions::no_options())
605            .await?
606            .await_response()
607            .await
608    }
609
610    pub async fn send_cancellable_request(
611        &self,
612        request: R::Req,
613        options: PeerRequestOptions,
614    ) -> Result<RequestHandle<R>, ServiceError> {
615        self.send_request_with_option(request, options).await
616    }
617
618    pub async fn send_request_with_option(
619        &self,
620        mut request: R::Req,
621        options: PeerRequestOptions,
622    ) -> Result<RequestHandle<R>, ServiceError> {
623        let id = self.request_id_provider.next_request_id();
624        let progress_token = self.progress_token_provider.next_progress_token();
625        if let Some(meta) = options.meta.clone() {
626            request.get_meta_mut().extend(meta);
627        }
628        request
629            .get_meta_mut()
630            .set_progress_token(progress_token.clone());
631        let (responder, receiver) = tokio::sync::oneshot::channel();
632        let progress_reset_rx = if options.reset_timeout_on_progress && options.timeout.is_some() {
633            let (sender, receiver) = mpsc::channel(1);
634            self.progress_timeout_watchers
635                .write()
636                .await
637                .insert(progress_token.clone(), sender);
638            Some(receiver)
639        } else {
640            None
641        };
642        if self
643            .tx
644            .send(PeerSinkMessage::Request {
645                request,
646                id: id.clone(),
647                responder,
648            })
649            .await
650            .is_err()
651        {
652            if progress_reset_rx.is_some() {
653                self.progress_timeout_watchers
654                    .write()
655                    .await
656                    .remove(&progress_token);
657            }
658            return Err(ServiceError::TransportClosed);
659        }
660        Ok(RequestHandle {
661            id,
662            rx: receiver,
663            progress_token,
664            options,
665            peer: self.clone(),
666            progress_reset_rx,
667        })
668    }
669
670    async fn notify_progress_timeout_watcher(&self, progress_token: &ProgressToken) {
671        let sender = self
672            .progress_timeout_watchers
673            .read()
674            .await
675            .get(progress_token)
676            .cloned();
677        if let Some(sender) = sender {
678            match sender.try_send(()) {
679                Ok(()) => {}
680                Err(mpsc::error::TrySendError::Full(_)) => {
681                    tracing::trace!(?progress_token, "progress timeout watcher channel is full");
682                }
683                Err(mpsc::error::TrySendError::Closed(_)) => {
684                    self.progress_timeout_watchers
685                        .write()
686                        .await
687                        .remove(progress_token);
688                }
689            }
690        }
691    }
692
693    /// Snapshot of the peer's handshake info.
694    pub fn peer_info(&self) -> Option<Arc<R::PeerInfo>> {
695        self.info.read().expect("peer info lock poisoned").clone()
696    }
697
698    /// Stores the peer's handshake info, overwriting any previous value.
699    pub fn set_peer_info(&self, info: R::PeerInfo) {
700        *self.info.write().expect("peer info lock poisoned") = Some(Arc::new(info));
701    }
702
703    pub fn is_transport_closed(&self) -> bool {
704        self.tx.is_closed()
705    }
706}
707
/// A service running in a background task, together with a [`Peer`] handle
/// for talking to the remote side (also reachable through `Deref`).
///
/// Call [`RunningService::close`] or [`RunningService::cancel`] to stop the
/// service and wait for the background task to finish. When it is dropped
/// instead, cancellation is left to the `dg` drop guard.
#[derive(Debug)]
pub struct RunningService<R: ServiceRole, S: Service<R>> {
    // The service being run (in an `Arc`; presumably shared with the background task).
    service: Arc<S>,
    // Handle to the remote peer.
    peer: Peer<R>,
    // Background task; `None` once awaited by `waiting`, `close` or `close_with_timeout`.
    handle: Option<tokio::task::JoinHandle<QuitReason>>,
    // Cancelled by `close` / `close_with_timeout` to stop the background task.
    cancellation_token: CancellationToken,
    // Cancels its token when this struct is dropped.
    dg: DropGuard,
}
716impl<R: ServiceRole, S: Service<R>> Deref for RunningService<R, S> {
717    type Target = Peer<R>;
718
719    fn deref(&self) -> &Self::Target {
720        &self.peer
721    }
722}
723
724impl<R: ServiceRole, S: Service<R>> RunningService<R, S> {
725    #[inline]
726    pub fn peer(&self) -> &Peer<R> {
727        &self.peer
728    }
729    #[inline]
730    pub fn service(&self) -> &S {
731        self.service.as_ref()
732    }
733    #[inline]
734    pub fn cancellation_token(&self) -> RunningServiceCancellationToken {
735        RunningServiceCancellationToken(self.cancellation_token.clone())
736    }
737
738    /// Returns true if the service has been closed or cancelled.
739    #[inline]
740    pub fn is_closed(&self) -> bool {
741        self.handle.is_none() || self.cancellation_token.is_cancelled()
742    }
743
744    /// Wait for the service to complete.
745    ///
746    /// This will block until the service loop terminates (either due to
747    /// cancellation, transport closure, or an error).
748    #[inline]
749    pub async fn waiting(mut self) -> Result<QuitReason, tokio::task::JoinError> {
750        match self.handle.take() {
751            Some(handle) => handle.await,
752            None => Ok(QuitReason::Closed),
753        }
754    }
755
756    /// Gracefully close the connection and wait for cleanup to complete.
757    ///
758    /// This method cancels the service, waits for the background task to finish
759    /// (which includes calling `transport.close()`), and ensures all cleanup
760    /// operations complete before returning.
761    ///
762    /// Unlike [`cancel`](Self::cancel), this method takes `&mut self` and can be
763    /// called without consuming the `RunningService`. After calling this method,
764    /// the service is considered closed and subsequent operations will fail.
765    ///
766    /// # Example
767    ///
768    /// ```rust,ignore
769    /// let mut client = ().serve(transport).await?;
770    /// // ... use the client ...
771    /// client.close().await?;
772    /// ```
773    pub async fn close(&mut self) -> Result<QuitReason, tokio::task::JoinError> {
774        if let Some(handle) = self.handle.take() {
775            // Disarm the drop guard so it doesn't try to cancel again
776            // We need to cancel manually and wait for completion
777            self.cancellation_token.cancel();
778            handle.await
779        } else {
780            // Already closed
781            Ok(QuitReason::Closed)
782        }
783    }
784
785    /// Gracefully close the connection with a timeout.
786    ///
787    /// Similar to [`close`](Self::close), but returns after the specified timeout
788    /// if the cleanup doesn't complete in time. This is useful for ensuring
789    /// a bounded shutdown time.
790    ///
791    /// Returns `Ok(Some(reason))` if shutdown completed within the timeout,
792    /// `Ok(None)` if the timeout was reached, or `Err` if there was a join error.
793    pub async fn close_with_timeout(
794        &mut self,
795        timeout: Duration,
796    ) -> Result<Option<QuitReason>, tokio::task::JoinError> {
797        if let Some(handle) = self.handle.take() {
798            self.cancellation_token.cancel();
799            match tokio::time::timeout(timeout, handle).await {
800                Ok(result) => result.map(Some),
801                Err(_elapsed) => {
802                    tracing::warn!(
803                        "close_with_timeout: cleanup did not complete within {:?}",
804                        timeout
805                    );
806                    Ok(None)
807                }
808            }
809        } else {
810            Ok(Some(QuitReason::Closed))
811        }
812    }
813
814    /// Cancel the service and wait for cleanup to complete.
815    ///
816    /// This consumes the `RunningService` and ensures the connection is properly
817    /// closed. For a non-consuming alternative, see [`close`](Self::close).
818    pub async fn cancel(mut self) -> Result<QuitReason, tokio::task::JoinError> {
819        // Disarm the drop guard since we're handling cancellation explicitly
820        let _ = std::mem::replace(&mut self.dg, self.cancellation_token.clone().drop_guard());
821        self.close().await
822    }
823}
824
825impl<R: ServiceRole, S: Service<R>> Drop for RunningService<R, S> {
826    fn drop(&mut self) {
827        if self.handle.is_some() && !self.cancellation_token.is_cancelled() {
828            tracing::debug!(
829                "RunningService dropped without explicit close(). \
830                 The connection will be closed asynchronously. \
831                 For guaranteed cleanup, call close() or cancel() before dropping."
832            );
833        }
834        // The DropGuard will handle cancellation
835    }
836}
837
838// use a wrapper type so we can tweak the implementation if needed
839pub struct RunningServiceCancellationToken(CancellationToken);
840
841impl RunningServiceCancellationToken {
842    pub fn cancel(self) {
843        self.0.cancel();
844    }
845}
846
/// Why a [`RunningService`]'s background task ended.
#[derive(Debug)]
#[non_exhaustive]
pub enum QuitReason {
    /// The service was cancelled (presumably reported by the service loop when
    /// its cancellation token fires — confirm in `serve_inner`).
    Cancelled,
    /// The connection was closed. Also returned by `waiting`, `close` and
    /// `close_with_timeout` when the background task had already been awaited.
    Closed,
    /// Joining a task failed (presumably a task spawned by the service loop
    /// panicked or was aborted — confirm).
    JoinError(tokio::task::JoinError),
}
854
/// Request execution context
///
/// Passed to [`Service::handle_request`] alongside each incoming request.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct RequestContext<R: ServiceRole> {
    /// this token will be cancelled when the [`CancelledNotification`] is received.
    pub ct: CancellationToken,
    /// The id of the request being handled.
    pub id: RequestId,
    /// Metadata attached to the request.
    pub meta: Meta,
    /// Extensions attached to the request.
    pub extensions: Extensions,
    /// An interface to fetch the remote client or server
    pub peer: Peer<R>,
}
867
868impl<R: ServiceRole> RequestContext<R> {
869    /// Create a new RequestContext.
870    pub fn new(id: RequestId, peer: Peer<R>) -> Self {
871        Self {
872            ct: CancellationToken::new(),
873            id,
874            meta: Meta::default(),
875            extensions: Extensions::default(),
876            peer,
877        }
878    }
879}
880
#[cfg(feature = "server")]
impl RequestContext<RoleServer> {
    /// The protocol version the client negotiated, or `None` before peer info is recorded.
    pub fn protocol_version(&self) -> Option<crate::model::ProtocolVersion> {
        let client_info = self.peer.peer_info()?;
        Some(client_info.protocol_version.clone())
    }
}
890
891/// Request execution context
892#[derive(Debug, Clone)]
893#[non_exhaustive]
894pub struct NotificationContext<R: ServiceRole> {
895    pub meta: Meta,
896    pub extensions: Extensions,
897    /// An interface to fetch the remote client or server
898    pub peer: Peer<R>,
899}
900
901/// Use this function to skip initialization process
902pub fn serve_directly<R, S, T, E, A>(
903    service: S,
904    transport: T,
905    peer_info: Option<R::PeerInfo>,
906) -> RunningService<R, S>
907where
908    R: ServiceRole,
909    R::PeerNot: ProgressNotificationToken,
910    S: Service<R>,
911    T: IntoTransport<R, E, A>,
912    E: std::error::Error + Send + Sync + 'static,
913{
914    serve_directly_with_ct(service, transport, peer_info, Default::default())
915}
916
917/// Use this function to skip initialization process
918pub fn serve_directly_with_ct<R, S, T, E, A>(
919    service: S,
920    transport: T,
921    peer_info: Option<R::PeerInfo>,
922    ct: CancellationToken,
923) -> RunningService<R, S>
924where
925    R: ServiceRole,
926    R::PeerNot: ProgressNotificationToken,
927    S: Service<R>,
928    T: IntoTransport<R, E, A>,
929    E: std::error::Error + Send + Sync + 'static,
930{
931    let (peer, peer_rx) = Peer::new(Arc::new(AtomicU32RequestIdProvider::default()), peer_info);
932    serve_inner(service, transport.into_transport(), peer, peer_rx, ct)
933}
934
935/// Spawn a task that may hold `!Send` state when the `local` feature is active.
936///
937/// Without the `local` feature this is `tokio::spawn` (requires `Future: Send + 'static`).
938/// With `local` it uses `tokio::task::spawn_local` (requires only `Future: 'static`).
939#[cfg(not(feature = "local"))]
940fn spawn_service_task<F>(future: F) -> tokio::task::JoinHandle<F::Output>
941where
942    F: Future + Send + 'static,
943    F::Output: Send + 'static,
944{
945    tokio::spawn(future)
946}
947
/// `local`-feature variant of `spawn_service_task`.
///
/// Uses [`tokio::task::spawn_local`], so neither the future nor its output
/// has to be `Send`; only `'static` is required.
/// NOTE: per tokio's documentation `spawn_local` panics when called outside a
/// `LocalSet`, so callers enabling `local` must run inside one.
#[cfg(feature = "local")]
fn spawn_service_task<F: Future + 'static>(future: F) -> tokio::task::JoinHandle<F::Output>
where
    F::Output: 'static,
{
    tokio::task::spawn_local(future)
}
956
957#[instrument(skip_all)]
958fn serve_inner<R, S, T>(
959    service: S,
960    transport: T,
961    peer: Peer<R>,
962    mut peer_rx: tokio::sync::mpsc::Receiver<PeerSinkMessage<R>>,
963    ct: CancellationToken,
964) -> RunningService<R, S>
965where
966    R: ServiceRole,
967    R::PeerNot: ProgressNotificationToken,
968    S: Service<R>,
969    T: Transport<R> + 'static,
970{
971    const SINK_PROXY_BUFFER_SIZE: usize = 64;
972    let (sink_proxy_tx, mut sink_proxy_rx) =
973        tokio::sync::mpsc::channel::<TxJsonRpcMessage<R>>(SINK_PROXY_BUFFER_SIZE);
974    let peer_info = peer.peer_info();
975    if R::IS_CLIENT {
976        tracing::info!(?peer_info, "Service initialized as client");
977    } else {
978        tracing::info!(?peer_info, "Service initialized as server");
979    }
980
981    let mut local_responder_pool =
982        HashMap::<RequestId, Responder<Result<R::PeerResp, ServiceError>>>::new();
983    let mut local_ct_pool = HashMap::<RequestId, CancellationToken>::new();
984    let shared_service = Arc::new(service);
985    // for return
986    let service = shared_service.clone();
987
988    // let message_sink = tokio::sync::
989    // let mut stream = std::pin::pin!(stream);
990    let serve_loop_ct = ct.child_token();
991    let peer_return: Peer<R> = peer.clone();
992    let current_span = tracing::Span::current();
993    let handle = spawn_service_task(async move {
994        let mut transport = transport.into_transport();
995        let mut batch_messages = VecDeque::<RxJsonRpcMessage<R>>::new();
996        let mut send_task_set = tokio::task::JoinSet::<SendTaskResult>::new();
997        let mut response_send_tasks = tokio::task::JoinSet::<()>::new();
998        #[derive(Debug)]
999        enum SendTaskResult {
1000            Request {
1001                id: RequestId,
1002                result: Result<(), DynamicTransportError>,
1003            },
1004            Notification {
1005                responder: Responder<Result<(), ServiceError>>,
1006                cancellation_param: Option<CancelledNotificationParam>,
1007                result: Result<(), DynamicTransportError>,
1008            },
1009        }
1010        #[derive(Debug)]
1011        enum Event<R: ServiceRole> {
1012            ProxyMessage(PeerSinkMessage<R>),
1013            PeerMessage(RxJsonRpcMessage<R>),
1014            ToSink(TxJsonRpcMessage<R>),
1015            SendTaskResult(SendTaskResult),
1016        }
1017
1018        let quit_reason = loop {
1019            let evt = if let Some(m) = batch_messages.pop_front() {
1020                Event::PeerMessage(m)
1021            } else {
1022                tokio::select! {
1023                    m = sink_proxy_rx.recv(), if !sink_proxy_rx.is_closed() => {
1024                        if let Some(m) = m {
1025                            Event::ToSink(m)
1026                        } else {
1027                            continue
1028                        }
1029                    }
1030                    m = transport.receive() => {
1031                        if let Some(m) = m {
1032                            Event::PeerMessage(m)
1033                        } else {
1034                            // input stream closed
1035                            tracing::info!("input stream terminated");
1036                            break QuitReason::Closed
1037                        }
1038                    }
1039                    m = peer_rx.recv(), if !peer_rx.is_closed() => {
1040                        if let Some(m) = m {
1041                            Event::ProxyMessage(m)
1042                        } else {
1043                            continue
1044                        }
1045                    }
1046                    m = send_task_set.join_next(), if !send_task_set.is_empty() => {
1047                        let Some(result) = m else {
1048                            continue
1049                        };
1050                        match result {
1051                            Err(e) => {
1052                                // join error, which is serious, we should quit.
1053                                tracing::error!(%e, "send request task encounter a tokio join error");
1054                                break QuitReason::JoinError(e)
1055                            }
1056                            Ok(result) => {
1057                                Event::SendTaskResult(result)
1058                            }
1059                        }
1060                    }
1061                    _ = serve_loop_ct.cancelled() => {
1062                        tracing::info!("task cancelled");
1063                        break QuitReason::Cancelled
1064                    }
1065                }
1066            };
1067
1068            tracing::trace!(?evt, "new event");
1069            match evt {
1070                Event::SendTaskResult(SendTaskResult::Request { id, result }) => {
1071                    if let Err(e) = result {
1072                        if let Some(responder) = local_responder_pool.remove(&id) {
1073                            let _ = responder.send(Err(ServiceError::TransportSend(e)));
1074                        }
1075                    }
1076                }
1077                Event::SendTaskResult(SendTaskResult::Notification {
1078                    responder,
1079                    result,
1080                    cancellation_param,
1081                }) => {
1082                    let response = if let Err(e) = result {
1083                        Err(ServiceError::TransportSend(e))
1084                    } else {
1085                        Ok(())
1086                    };
1087                    let _ = responder.send(response);
1088                    if let Some(param) = cancellation_param {
1089                        if let Some(request_id) = &param.request_id {
1090                            if let Some(responder) = local_responder_pool.remove(request_id) {
1091                                tracing::info!(id = %request_id, reason = param.reason, "cancelled");
1092                                let _response_result = responder.send(Err(ServiceError::Cancelled {
1093                                    reason: param.reason.clone(),
1094                                }));
1095                            }
1096                        }
1097                    }
1098                }
1099                // response and error
1100                Event::ToSink(m) => {
1101                    if let Some(id) = match &m {
1102                        JsonRpcMessage::Response(response) => Some(&response.id),
1103                        JsonRpcMessage::Error(error) => error.id.as_ref(),
1104                        _ => None,
1105                    } {
1106                        if let Some(ct) = local_ct_pool.remove(id) {
1107                            ct.cancel();
1108                        }
1109                        let send = transport.send(m);
1110                        let current_span = tracing::Span::current();
1111                        response_send_tasks.spawn(async move {
1112                            let send_result = send.await;
1113                            if let Err(error) = send_result {
1114                                tracing::error!(%error, "fail to response message");
1115                            }
1116                        }.instrument(current_span));
1117                    }
1118                }
1119                Event::ProxyMessage(PeerSinkMessage::Request {
1120                    request,
1121                    id,
1122                    responder,
1123                }) => {
1124                    local_responder_pool.insert(id.clone(), responder);
1125                    let send = transport.send(JsonRpcMessage::request(request, id.clone()));
1126                    {
1127                        let id = id.clone();
1128                        let current_span = tracing::Span::current();
1129                        send_task_set.spawn(send.map(move |r| SendTaskResult::Request {
1130                            id,
1131                            result: r.map_err(DynamicTransportError::new::<T, R>),
1132                        }).instrument(current_span));
1133                    }
1134                }
1135                Event::ProxyMessage(PeerSinkMessage::Notification {
1136                    notification,
1137                    responder,
1138                }) => {
1139                    // catch cancellation notification
1140                    let mut cancellation_param = None;
1141                    let notification = match notification.try_into() {
1142                        Ok::<CancelledNotification, _>(cancelled) => {
1143                            cancellation_param.replace(cancelled.params.clone());
1144                            cancelled.into()
1145                        }
1146                        Err(notification) => notification,
1147                    };
1148                    let send = transport.send(JsonRpcMessage::notification(notification));
1149                    let current_span = tracing::Span::current();
1150                    send_task_set.spawn(send.map(move |result| SendTaskResult::Notification {
1151                        responder,
1152                        cancellation_param,
1153                        result: result.map_err(DynamicTransportError::new::<T, R>),
1154                    }).instrument(current_span));
1155                }
1156                Event::PeerMessage(JsonRpcMessage::Request(JsonRpcRequest {
1157                    id,
1158                    mut request,
1159                    ..
1160                })) => {
1161                    tracing::debug!(%id, ?request, "received request");
1162                    {
1163                        let service = shared_service.clone();
1164                        let sink = sink_proxy_tx.clone();
1165                        let request_ct = serve_loop_ct.child_token();
1166                        let context_ct = request_ct.child_token();
1167                        local_ct_pool.insert(id.clone(), request_ct);
1168                        let mut extensions = Extensions::new();
1169                        let mut meta = Meta::new();
1170                        // avoid clone
1171                        // swap meta firstly, otherwise progress token will be lost
1172                        std::mem::swap(&mut meta, request.get_meta_mut());
1173                        std::mem::swap(&mut extensions, request.extensions_mut());
1174                        let context = RequestContext {
1175                            ct: context_ct,
1176                            id: id.clone(),
1177                            peer: peer.clone(),
1178                            meta,
1179                            extensions,
1180                        };
1181                        let current_span = tracing::Span::current();
1182                        spawn_service_task(async move {
1183                            let result = service
1184                                .handle_request(request, context)
1185                                .await;
1186                            let response = match result {
1187                                Ok(result) => {
1188                                    tracing::debug!(%id, ?result, "response message");
1189                                    JsonRpcMessage::response(result, id)
1190                                }
1191                                Err(error) => {
1192                                    tracing::warn!(%id, ?error, "response error");
1193                                    JsonRpcMessage::error(error, Some(id))
1194                                }
1195                            };
1196                            let _send_result = sink.send(response).await;
1197                        }.instrument(current_span));
1198                    }
1199                }
1200                Event::PeerMessage(JsonRpcMessage::Notification(JsonRpcNotification {
1201                    notification,
1202                    ..
1203                })) => {
1204                    tracing::info!(?notification, "received notification");
1205                    // catch cancelled notification
1206                    let mut notification = match notification.try_into() {
1207                        Ok::<CancelledNotification, _>(cancelled) => {
1208                            if let Some(request_id) = &cancelled.params.request_id {
1209                                if let Some(ct) = local_ct_pool.remove(request_id) {
1210                                    tracing::info!(id = %request_id, reason = cancelled.params.reason, "cancelled");
1211                                    ct.cancel();
1212                                }
1213                            }
1214                            cancelled.into()
1215                        }
1216                        Err(notification) => notification,
1217                    };
1218                    if let Some(progress_token) = notification.progress_token() {
1219                        peer.notify_progress_timeout_watcher(progress_token).await;
1220                    }
1221                    {
1222                        let service = shared_service.clone();
1223                        let mut extensions = Extensions::new();
1224                        let mut meta = Meta::new();
1225                        // avoid clone
1226                        std::mem::swap(&mut extensions, notification.extensions_mut());
1227                        std::mem::swap(&mut meta, notification.get_meta_mut());
1228                        let context = NotificationContext {
1229                            peer: peer.clone(),
1230                            meta,
1231                            extensions,
1232                        };
1233                        let current_span = tracing::Span::current();
1234                        spawn_service_task(async move {
1235                            let result = service.handle_notification(notification, context).await;
1236                            if let Err(error) = result {
1237                                tracing::warn!(%error, "Error sending notification");
1238                            }
1239                        }.instrument(current_span));
1240                    }
1241                }
1242                Event::PeerMessage(JsonRpcMessage::Response(JsonRpcResponse {
1243                    result,
1244                    id,
1245                    ..
1246                })) => {
1247                    if let Some(responder) = local_responder_pool.remove(&id) {
1248                        let response_result = responder.send(Ok(result));
1249                        if let Err(_error) = response_result {
1250                            tracing::warn!(%id, "Error sending response");
1251                        }
1252                    }
1253                }
1254                Event::PeerMessage(JsonRpcMessage::Error(JsonRpcError { error, id, .. })) => {
1255                    let Some(id) = id else {
1256                        // MCP error responses without an id (e.g. Parse error / Invalid Request)
1257                        // can't be routed back to a pending request — log and drop.
1258                        tracing::debug!(?error, "received id-less peer error");
1259                        continue;
1260                    };
1261                    if let Some(responder) = local_responder_pool.remove(&id) {
1262                        let _response_result = responder.send(Err(ServiceError::McpError(error)));
1263                        if let Err(_error) = _response_result {
1264                            tracing::warn!(%id, "Error sending response");
1265                        }
1266                    }
1267                }
1268            }
1269        };
1270
1271        // Drain in-flight handler responses before closing the transport.
1272        // When stdin EOF or cancellation arrives, spawned handler tasks may still
1273        // be finishing. We need to:
1274        // 1. Wait for response sends that were already spawned in the main loop
1275        // 2. Drain any remaining handler responses from the channel
1276        let drain_timeout = match &quit_reason {
1277            QuitReason::Closed => Some(Duration::from_secs(5)),
1278            QuitReason::Cancelled => Some(Duration::from_secs(2)),
1279            _ => None,
1280        };
1281        if let Some(timeout_duration) = drain_timeout {
1282            // Drop our sender so the channel closes once all handler task
1283            // clones finish sending their responses (or are dropped).
1284            drop(sink_proxy_tx);
1285            let drain_result = tokio::time::timeout(timeout_duration, async {
1286                // First, wait for any response sends already dispatched by the
1287                // main loop (these hold transport write futures).
1288                while let Some(result) = response_send_tasks.join_next().await {
1289                    if let Err(error) = result {
1290                        tracing::error!(%error, "response send task failed during drain");
1291                    }
1292                }
1293                // Then drain any handler responses still in the channel
1294                // (handlers that finished after the loop broke).
1295                while let Some(m) = sink_proxy_rx.recv().await {
1296                    if let Err(error) = transport.send(m).await {
1297                        tracing::error!(%error, "failed to send pending response during drain");
1298                        break;
1299                    }
1300                }
1301            })
1302            .await;
1303            if drain_result.is_err() {
1304                tracing::warn!("timed out draining in-flight responses");
1305            }
1306        }
1307
1308        let sink_close_result = transport.close().await;
1309        if let Err(e) = sink_close_result {
1310            tracing::error!(%e, "fail to close sink");
1311        }
1312        tracing::info!(?quit_reason, "serve finished");
1313        quit_reason
1314    }.instrument(current_span));
1315    RunningService {
1316        service,
1317        peer: peer_return,
1318        handle: Some(handle),
1319        cancellation_token: ct.clone(),
1320        dg: ct.drop_guard(),
1321    }
1322}