// rmcp/service.rs
1use futures::FutureExt;
2#[cfg(not(feature = "local"))]
3use futures::future::BoxFuture;
4#[cfg(feature = "local")]
5use futures::future::LocalBoxFuture;
6use thiserror::Error;
7
8// ---------------------------------------------------------------------------
9// Conditional Send helpers
10//
11// `MaybeSend`       – supertrait alias: `Send + Sync` without `local`, empty with `local`
12// `MaybeSendFuture` – future bound alias: `Send` without `local`, empty with `local`
13// `MaybeBoxFuture`  – boxed future type: `BoxFuture` without `local`, `LocalBoxFuture` with `local`
14// ---------------------------------------------------------------------------
15
// Without the `local` feature, `MaybeSend` requires `Send + Sync` so services
// can be driven from multi-threaded executors.
#[cfg(not(feature = "local"))]
#[doc(hidden)]
pub trait MaybeSend: Send + Sync {}
// Blanket impl: every `Send + Sync` type is `MaybeSend`.
#[cfg(not(feature = "local"))]
impl<T: Send + Sync> MaybeSend for T {}

// With the `local` feature, `MaybeSend` is an empty marker, so `!Send` types
// (single-threaded usage) are accepted.
#[cfg(feature = "local")]
#[doc(hidden)]
pub trait MaybeSend {}
// Blanket impl: every type is `MaybeSend` in local mode.
#[cfg(feature = "local")]
impl<T> MaybeSend for T {}
27
// Without the `local` feature, futures bound by `MaybeSendFuture` must be
// `Send` so they can be spawned onto a multi-threaded runtime.
#[cfg(not(feature = "local"))]
#[doc(hidden)]
pub trait MaybeSendFuture: Send {}
// Blanket impl: every `Send` type satisfies the bound.
#[cfg(not(feature = "local"))]
impl<T: Send> MaybeSendFuture for T {}

// With the `local` feature, the bound is empty so `!Send` futures are allowed.
#[cfg(feature = "local")]
#[doc(hidden)]
pub trait MaybeSendFuture {}
// Blanket impl: every type satisfies the bound in local mode.
#[cfg(feature = "local")]
impl<T> MaybeSendFuture for T {}
39
// Boxed-future alias matching the `local` feature: `BoxFuture` (requires
// `Send`) normally, `LocalBoxFuture` (no `Send` bound) when `local` is on.
#[cfg(not(feature = "local"))]
pub(crate) type MaybeBoxFuture<'a, T> = BoxFuture<'a, T>;
#[cfg(feature = "local")]
pub(crate) type MaybeBoxFuture<'a, T> = LocalBoxFuture<'a, T>;
44
45#[cfg(feature = "server")]
46use crate::model::ServerJsonRpcMessage;
47use crate::{
48    error::ErrorData as McpError,
49    model::{
50        CancelledNotification, CancelledNotificationParam, Extensions, GetExtensions, GetMeta,
51        JsonRpcError, JsonRpcMessage, JsonRpcNotification, JsonRpcRequest, JsonRpcResponse, Meta,
52        NumberOrString, ProgressToken, RequestId,
53    },
54    transport::{DynamicTransportError, IntoTransport, Transport},
55};
56#[cfg(feature = "client")]
57mod client;
58#[cfg(feature = "client")]
59pub use client::*;
60#[cfg(feature = "server")]
61mod server;
62#[cfg(feature = "server")]
63pub use server::*;
64#[cfg(feature = "tower")]
65mod tower;
66use tokio_util::sync::{CancellationToken, DropGuard};
67#[cfg(feature = "tower")]
68pub use tower::*;
69use tracing::{Instrument as _, instrument};
70#[derive(Error, Debug)]
71#[non_exhaustive]
72pub enum ServiceError {
73    #[error("Mcp error: {0}")]
74    McpError(McpError),
75    #[error("Transport send error: {0}")]
76    TransportSend(DynamicTransportError),
77    #[error("Transport closed")]
78    TransportClosed,
79    #[error("Unexpected response type")]
80    UnexpectedResponse,
81    #[error("task cancelled for reason {}", reason.as_deref().unwrap_or("<unknown>"))]
82    Cancelled { reason: Option<String> },
83    #[error("request timeout after {}", chrono::Duration::from_std(*timeout).unwrap_or_default())]
84    Timeout { timeout: Duration },
85}
86
87trait TransferObject:
88    std::fmt::Debug + Clone + serde::Serialize + serde::de::DeserializeOwned + Send + Sync + 'static
89{
90}
91
92impl<T> TransferObject for T where
93    T: std::fmt::Debug
94        + serde::Serialize
95        + serde::de::DeserializeOwned
96        + Send
97        + Sync
98        + 'static
99        + Clone
100{
101}
102
#[allow(private_bounds, reason = "there's no the third implementation")]
pub trait ServiceRole: std::fmt::Debug + Send + Sync + 'static + Copy + Clone {
    /// Request type this side sends.
    type Req: TransferObject + GetMeta + GetExtensions;
    /// Response type this side sends back for peer requests.
    type Resp: TransferObject;
    /// Notification type this side sends; convertible to/from
    /// [`CancelledNotification`] so cancellations can be detected.
    type Not: TryInto<CancelledNotification, Error = Self::Not>
        + From<CancelledNotification>
        + TransferObject;
    /// Request type received from the remote peer.
    type PeerReq: TransferObject + GetMeta + GetExtensions;
    /// Response type received from the remote peer.
    type PeerResp: TransferObject;
    /// Notification type received from the remote peer; convertible to/from
    /// [`CancelledNotification`] so incoming cancellations can be detected.
    type PeerNot: TryInto<CancelledNotification, Error = Self::PeerNot>
        + From<CancelledNotification>
        + TransferObject
        + GetMeta
        + GetExtensions;
    /// Error produced during the initialization handshake.
    type InitializeError;
    /// True for the client role, false for the server role.
    const IS_CLIENT: bool;
    /// This side's own info structure.
    type Info: TransferObject;
    /// The remote peer's info structure.
    type PeerInfo: TransferObject;
}
122
/// Outgoing JSON-RPC message type for role `R`
/// (our requests / responses / notifications).
pub type TxJsonRpcMessage<R> =
    JsonRpcMessage<<R as ServiceRole>::Req, <R as ServiceRole>::Resp, <R as ServiceRole>::Not>;
/// Incoming JSON-RPC message type for role `R`
/// (the peer's requests / responses / notifications).
pub type RxJsonRpcMessage<R> = JsonRpcMessage<
    <R as ServiceRole>::PeerReq,
    <R as ServiceRole>::PeerResp,
    <R as ServiceRole>::PeerNot,
>;
130
/// The core handler trait for one side of an MCP connection.
///
/// Thread-safe variant used when the `local` feature is disabled: the service
/// and its futures must be `Send` so the loop can run on any runtime thread.
#[cfg(not(feature = "local"))]
pub trait Service<R: ServiceRole>: Send + Sync + 'static {
    /// Handle a request coming from the remote peer.
    fn handle_request(
        &self,
        request: R::PeerReq,
        context: RequestContext<R>,
    ) -> impl Future<Output = Result<R::Resp, McpError>> + MaybeSendFuture + '_;
    /// Handle a notification coming from the remote peer.
    fn handle_notification(
        &self,
        notification: R::PeerNot,
        context: NotificationContext<R>,
    ) -> impl Future<Output = Result<(), McpError>> + MaybeSendFuture + '_;
    /// This side's info, advertised to the peer.
    fn get_info(&self) -> R::Info;
}
145
/// The core handler trait for one side of an MCP connection.
///
/// `local`-feature variant: no `Send`/`Sync` requirement, for
/// single-threaded (`spawn_local`) runtimes.
#[cfg(feature = "local")]
pub trait Service<R: ServiceRole>: 'static {
    /// Handle a request coming from the remote peer.
    fn handle_request(
        &self,
        request: R::PeerReq,
        context: RequestContext<R>,
    ) -> impl Future<Output = Result<R::Resp, McpError>> + MaybeSendFuture + '_;
    /// Handle a notification coming from the remote peer.
    fn handle_notification(
        &self,
        notification: R::PeerNot,
        context: NotificationContext<R>,
    ) -> impl Future<Output = Result<(), McpError>> + MaybeSendFuture + '_;
    /// This side's info, advertised to the peer.
    fn get_info(&self) -> R::Info;
}
160
161pub trait ServiceExt<R: ServiceRole>: Service<R> + Sized {
162    /// Convert this service to a dynamic boxed service
163    ///
164    /// This could be very helpful when you want to store the services in a collection
165    fn into_dyn(self) -> Box<dyn DynService<R>> {
166        Box::new(self)
167    }
168    fn serve<T, E, A>(
169        self,
170        transport: T,
171    ) -> impl Future<Output = Result<RunningService<R, Self>, R::InitializeError>> + MaybeSendFuture
172    where
173        T: IntoTransport<R, E, A>,
174        E: std::error::Error + Send + Sync + 'static,
175        Self: Sized,
176    {
177        Self::serve_with_ct(self, transport, Default::default())
178    }
179    fn serve_with_ct<T, E, A>(
180        self,
181        transport: T,
182        ct: CancellationToken,
183    ) -> impl Future<Output = Result<RunningService<R, Self>, R::InitializeError>> + MaybeSendFuture
184    where
185        T: IntoTransport<R, E, A>,
186        E: std::error::Error + Send + Sync + 'static,
187        Self: Sized;
188}
189
// Forward the `Service` trait through a boxed `DynService`, so a
// `Box<dyn DynService<R>>` can be used anywhere a concrete `Service` is
// expected (e.g. passed to `serve`).
impl<R: ServiceRole> Service<R> for Box<dyn DynService<R>> {
    fn handle_request(
        &self,
        request: R::PeerReq,
        context: RequestContext<R>,
    ) -> impl Future<Output = Result<R::Resp, McpError>> + MaybeSendFuture + '_ {
        // Fully-qualified call: dispatch through the object-safe trait.
        DynService::handle_request(self.as_ref(), request, context)
    }

    fn handle_notification(
        &self,
        notification: R::PeerNot,
        context: NotificationContext<R>,
    ) -> impl Future<Output = Result<(), McpError>> + MaybeSendFuture + '_ {
        DynService::handle_notification(self.as_ref(), notification, context)
    }

    fn get_info(&self) -> R::Info {
        DynService::get_info(self.as_ref())
    }
}
211
/// Object-safe mirror of [`Service`]: returns boxed futures instead of
/// `impl Future`, enabling `Box<dyn DynService<R>>` trait objects.
/// (Thread-safe variant used when the `local` feature is disabled.)
#[cfg(not(feature = "local"))]
pub trait DynService<R: ServiceRole>: Send + Sync {
    /// Handle a request coming from the remote peer.
    fn handle_request(
        &self,
        request: R::PeerReq,
        context: RequestContext<R>,
    ) -> MaybeBoxFuture<'_, Result<R::Resp, McpError>>;
    /// Handle a notification coming from the remote peer.
    fn handle_notification(
        &self,
        notification: R::PeerNot,
        context: NotificationContext<R>,
    ) -> MaybeBoxFuture<'_, Result<(), McpError>>;
    /// This side's info, advertised to the peer.
    fn get_info(&self) -> R::Info;
}
226
/// Object-safe mirror of [`Service`]: returns boxed futures instead of
/// `impl Future`, enabling `Box<dyn DynService<R>>` trait objects.
/// (`local`-feature variant: no `Send`/`Sync` requirement.)
#[cfg(feature = "local")]
pub trait DynService<R: ServiceRole> {
    /// Handle a request coming from the remote peer.
    fn handle_request(
        &self,
        request: R::PeerReq,
        context: RequestContext<R>,
    ) -> MaybeBoxFuture<'_, Result<R::Resp, McpError>>;
    /// Handle a notification coming from the remote peer.
    fn handle_notification(
        &self,
        notification: R::PeerNot,
        context: NotificationContext<R>,
    ) -> MaybeBoxFuture<'_, Result<(), McpError>>;
    /// This side's info, advertised to the peer.
    fn get_info(&self) -> R::Info;
}
241
// Blanket impl: every `Service` is usable as an object-safe `DynService`
// by pinning its `impl Future`s on the heap.
impl<R: ServiceRole, S: Service<R>> DynService<R> for S {
    fn handle_request(
        &self,
        request: R::PeerReq,
        context: RequestContext<R>,
    ) -> MaybeBoxFuture<'_, Result<R::Resp, McpError>> {
        Box::pin(self.handle_request(request, context))
    }
    fn handle_notification(
        &self,
        notification: R::PeerNot,
        context: NotificationContext<R>,
    ) -> MaybeBoxFuture<'_, Result<(), McpError>> {
        Box::pin(self.handle_notification(notification, context))
    }
    fn get_info(&self) -> R::Info {
        self.get_info()
    }
}
261
262use std::{
263    collections::{HashMap, VecDeque},
264    ops::Deref,
265    sync::{Arc, atomic::AtomicU64},
266    time::Duration,
267};
268
269use tokio::sync::mpsc;
270
/// Supplies ids for outgoing requests; must be unique per connection.
pub trait RequestIdProvider: Send + Sync + 'static {
    fn next_request_id(&self) -> RequestId;
}

/// Supplies progress tokens attached to outgoing requests' meta.
pub trait ProgressTokenProvider: Send + Sync + 'static {
    fn next_progress_token(&self) -> ProgressToken;
}
278
pub type AtomicU32RequestIdProvider = AtomicU32Provider;
pub type AtomicU32ProgressTokenProvider = AtomicU32Provider;

/// Monotonic counter backing both request ids and progress tokens.
// NOTE(review): the name says `U32` but the counter is an `AtomicU64` —
// presumably widened without renaming; confirm before relying on the width.
#[derive(Debug, Default)]
pub struct AtomicU32Provider {
    // Next value to hand out; starts at 0.
    id: AtomicU64,
}
286
287impl RequestIdProvider for AtomicU32Provider {
288    fn next_request_id(&self) -> RequestId {
289        let id = self.id.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
290        // Safe conversion: we start at 0 and increment by 1, so we won't overflow i64::MAX in practice
291        RequestId::Number(id as i64)
292    }
293}
294
295impl ProgressTokenProvider for AtomicU32Provider {
296    fn next_progress_token(&self) -> ProgressToken {
297        let id = self.id.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
298        ProgressToken(NumberOrString::Number(id as i64))
299    }
300}
301
/// One-shot channel used by the service loop to report a send's outcome
/// back to the caller that queued it.
type Responder<T> = tokio::sync::oneshot::Sender<T>;

/// A handle to a remote request
///
/// You can cancel it by call [`RequestHandle::cancel`] with a reason,
///
/// or wait for response by call [`RequestHandle::await_response`]
#[derive(Debug)]
#[non_exhaustive]
pub struct RequestHandle<R: ServiceRole> {
    /// Receives the peer's response (or a send-side error).
    pub rx: tokio::sync::oneshot::Receiver<Result<R::PeerResp, ServiceError>>,
    /// Options the request was sent with (timeout, extra meta).
    pub options: PeerRequestOptions,
    /// Peer used to send a cancellation notification if needed.
    pub peer: Peer<R>,
    /// Id assigned to the outgoing request.
    pub id: RequestId,
    /// Progress token attached to the request's meta.
    pub progress_token: ProgressToken,
}
318
impl<R: ServiceRole> RequestHandle<R> {
    /// Cancellation reason sent to the peer when a request times out.
    pub const REQUEST_TIMEOUT_REASON: &str = "request timeout";

    /// Wait for the peer's response, honoring `options.timeout` if set.
    ///
    /// On timeout, a best-effort [`CancelledNotification`] carrying
    /// [`Self::REQUEST_TIMEOUT_REASON`] is sent to the peer and
    /// [`ServiceError::Timeout`] is returned.
    pub async fn await_response(self) -> Result<R::PeerResp, ServiceError> {
        if let Some(timeout) = self.options.timeout {
            let timeout_result = tokio::time::timeout(timeout, async move {
                // A dropped sender means the service loop has gone away.
                self.rx.await.map_err(|_e| ServiceError::TransportClosed)?
            })
            .await;
            match timeout_result {
                Ok(response) => response,
                Err(_) => {
                    let error = Err(ServiceError::Timeout { timeout });
                    // cancel this request
                    let notification = CancelledNotification {
                        params: CancelledNotificationParam {
                            request_id: self.id,
                            reason: Some(Self::REQUEST_TIMEOUT_REASON.to_owned()),
                        },
                        method: crate::model::CancelledNotificationMethod,
                        extensions: Default::default(),
                    };
                    // Best effort: a failure to notify doesn't change the outcome.
                    let _ = self.peer.send_notification(notification.into()).await;
                    error
                }
            }
        } else {
            // No timeout configured: wait indefinitely for the response.
            self.rx.await.map_err(|_e| ServiceError::TransportClosed)?
        }
    }

    /// Cancel this request
    pub async fn cancel(self, reason: Option<String>) -> Result<(), ServiceError> {
        let notification = CancelledNotification {
            params: CancelledNotificationParam {
                request_id: self.id,
                reason,
            },
            method: crate::model::CancelledNotificationMethod,
            extensions: Default::default(),
        };
        self.peer.send_notification(notification.into()).await?;
        Ok(())
    }
}
363
/// Messages queued from a [`Peer`] to the service loop for sending.
#[derive(Debug)]
pub(crate) enum PeerSinkMessage<R: ServiceRole> {
    /// An outgoing request; `responder` delivers the peer's response.
    Request {
        request: R::Req,
        id: RequestId,
        responder: Responder<Result<R::PeerResp, ServiceError>>,
    },
    /// An outgoing notification; `responder` delivers the send outcome.
    Notification {
        notification: R::Not,
        responder: Responder<Result<(), ServiceError>>,
    },
}
376
/// An interface to fetch the remote client or server
///
/// For general purpose, call [`Peer::send_request`] or [`Peer::send_notification`] to send message to remote peer.
///
/// To create a cancellable request, call [`Peer::send_request_with_option`].
#[derive(Clone)]
pub struct Peer<R: ServiceRole> {
    // Outbound queue consumed by the service loop.
    tx: mpsc::Sender<PeerSinkMessage<R>>,
    // Generates unique ids for outgoing requests.
    request_id_provider: Arc<dyn RequestIdProvider>,
    // Generates progress tokens attached to outgoing requests' meta.
    progress_token_provider: Arc<dyn ProgressTokenProvider>,
    // Peer info, set at most once (up front or after initialization).
    info: Arc<tokio::sync::OnceCell<R::PeerInfo>>,
}
389
390impl<R: ServiceRole> std::fmt::Debug for Peer<R> {
391    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
392        f.debug_struct("PeerSink")
393            .field("tx", &self.tx)
394            .field("is_client", &R::IS_CLIENT)
395            .finish()
396    }
397}
398
/// Receiving end of a [`Peer`]'s outbound queue, consumed by the service loop.
type ProxyOutbound<R> = mpsc::Receiver<PeerSinkMessage<R>>;

/// Per-request options for [`Peer::send_request_with_option`].
#[derive(Debug, Default)]
#[non_exhaustive]
pub struct PeerRequestOptions {
    /// Fail (and cancel) the request if no response arrives within this duration.
    pub timeout: Option<Duration>,
    /// Extra metadata merged into the request's meta before sending.
    pub meta: Option<Meta>,
}

impl PeerRequestOptions {
    /// Default options: no timeout, no extra meta.
    pub fn no_options() -> Self {
        Self::default()
    }
}
413
414impl<R: ServiceRole> Peer<R> {
415    const CLIENT_CHANNEL_BUFFER_SIZE: usize = 1024;
416    pub(crate) fn new(
417        request_id_provider: Arc<dyn RequestIdProvider>,
418        peer_info: Option<R::PeerInfo>,
419    ) -> (Peer<R>, ProxyOutbound<R>) {
420        let (tx, rx) = mpsc::channel(Self::CLIENT_CHANNEL_BUFFER_SIZE);
421        (
422            Self {
423                tx,
424                request_id_provider,
425                progress_token_provider: Arc::new(AtomicU32ProgressTokenProvider::default()),
426                info: Arc::new(tokio::sync::OnceCell::new_with(peer_info)),
427            },
428            rx,
429        )
430    }
431    pub async fn send_notification(&self, notification: R::Not) -> Result<(), ServiceError> {
432        let (responder, receiver) = tokio::sync::oneshot::channel();
433        self.tx
434            .send(PeerSinkMessage::Notification {
435                notification,
436                responder,
437            })
438            .await
439            .map_err(|_m| ServiceError::TransportClosed)?;
440        receiver.await.map_err(|_e| ServiceError::TransportClosed)?
441    }
442    pub async fn send_request(&self, request: R::Req) -> Result<R::PeerResp, ServiceError> {
443        self.send_request_with_option(request, PeerRequestOptions::no_options())
444            .await?
445            .await_response()
446            .await
447    }
448
449    pub async fn send_cancellable_request(
450        &self,
451        request: R::Req,
452        options: PeerRequestOptions,
453    ) -> Result<RequestHandle<R>, ServiceError> {
454        self.send_request_with_option(request, options).await
455    }
456
457    pub async fn send_request_with_option(
458        &self,
459        mut request: R::Req,
460        options: PeerRequestOptions,
461    ) -> Result<RequestHandle<R>, ServiceError> {
462        let id = self.request_id_provider.next_request_id();
463        let progress_token = self.progress_token_provider.next_progress_token();
464        request
465            .get_meta_mut()
466            .set_progress_token(progress_token.clone());
467        if let Some(meta) = options.meta.clone() {
468            request.get_meta_mut().extend(meta);
469        }
470        let (responder, receiver) = tokio::sync::oneshot::channel();
471        self.tx
472            .send(PeerSinkMessage::Request {
473                request,
474                id: id.clone(),
475                responder,
476            })
477            .await
478            .map_err(|_m| ServiceError::TransportClosed)?;
479        Ok(RequestHandle {
480            id,
481            rx: receiver,
482            progress_token,
483            options,
484            peer: self.clone(),
485        })
486    }
487    pub fn peer_info(&self) -> Option<&R::PeerInfo> {
488        self.info.get()
489    }
490
491    pub fn set_peer_info(&self, info: R::PeerInfo) {
492        if self.info.initialized() {
493            tracing::warn!("trying to set peer info, which is already initialized");
494        } else {
495            let _ = self.info.set(info);
496        }
497    }
498
499    pub fn is_transport_closed(&self) -> bool {
500        self.tx.is_closed()
501    }
502}
503
/// A service bound to a running transport loop.
///
/// Dereferences to [`Peer`], so requests and notifications can be sent
/// directly on this value.
#[derive(Debug)]
pub struct RunningService<R: ServiceRole, S: Service<R>> {
    // The user's service, shared with the background task.
    service: Arc<S>,
    // Handle for talking to the remote peer.
    peer: Peer<R>,
    // Background task driving the transport; `None` once closed/awaited.
    handle: Option<tokio::task::JoinHandle<QuitReason>>,
    // Cancels the service loop.
    cancellation_token: CancellationToken,
    // Cancels the token if this struct is dropped without explicit close.
    dg: DropGuard,
}
// Deref to the peer handle so callers can invoke `Peer` methods directly.
impl<R: ServiceRole, S: Service<R>> Deref for RunningService<R, S> {
    type Target = Peer<R>;

    fn deref(&self) -> &Self::Target {
        &self.peer
    }
}
519
impl<R: ServiceRole, S: Service<R>> RunningService<R, S> {
    /// Handle for sending requests/notifications to the remote peer.
    #[inline]
    pub fn peer(&self) -> &Peer<R> {
        &self.peer
    }
    /// The service implementation driving this connection.
    #[inline]
    pub fn service(&self) -> &S {
        self.service.as_ref()
    }
    /// A token that can cancel this service from elsewhere.
    #[inline]
    pub fn cancellation_token(&self) -> RunningServiceCancellationToken {
        RunningServiceCancellationToken(self.cancellation_token.clone())
    }

    /// Returns true if the service has been closed or cancelled.
    #[inline]
    pub fn is_closed(&self) -> bool {
        self.handle.is_none() || self.cancellation_token.is_cancelled()
    }

    /// Wait for the service to complete.
    ///
    /// This will block until the service loop terminates (either due to
    /// cancellation, transport closure, or an error).
    #[inline]
    pub async fn waiting(mut self) -> Result<QuitReason, tokio::task::JoinError> {
        match self.handle.take() {
            Some(handle) => handle.await,
            // Already closed: report that without awaiting anything.
            None => Ok(QuitReason::Closed),
        }
    }

    /// Gracefully close the connection and wait for cleanup to complete.
    ///
    /// This method cancels the service, waits for the background task to finish
    /// (which includes calling `transport.close()`), and ensures all cleanup
    /// operations complete before returning.
    ///
    /// Unlike [`cancel`](Self::cancel), this method takes `&mut self` and can be
    /// called without consuming the `RunningService`. After calling this method,
    /// the service is considered closed and subsequent operations will fail.
    ///
    /// # Example
    ///
    /// ```rust,ignore
    /// let mut client = ().serve(transport).await?;
    /// // ... use the client ...
    /// client.close().await?;
    /// ```
    pub async fn close(&mut self) -> Result<QuitReason, tokio::task::JoinError> {
        if let Some(handle) = self.handle.take() {
            // Disarm the drop guard so it doesn't try to cancel again
            // We need to cancel manually and wait for completion
            self.cancellation_token.cancel();
            handle.await
        } else {
            // Already closed
            Ok(QuitReason::Closed)
        }
    }

    /// Gracefully close the connection with a timeout.
    ///
    /// Similar to [`close`](Self::close), but returns after the specified timeout
    /// if the cleanup doesn't complete in time. This is useful for ensuring
    /// a bounded shutdown time.
    ///
    /// Returns `Ok(Some(reason))` if shutdown completed within the timeout,
    /// `Ok(None)` if the timeout was reached, or `Err` if there was a join error.
    pub async fn close_with_timeout(
        &mut self,
        timeout: Duration,
    ) -> Result<Option<QuitReason>, tokio::task::JoinError> {
        if let Some(handle) = self.handle.take() {
            self.cancellation_token.cancel();
            match tokio::time::timeout(timeout, handle).await {
                Ok(result) => result.map(Some),
                Err(_elapsed) => {
                    // The task keeps running detached; we only stop waiting.
                    tracing::warn!(
                        "close_with_timeout: cleanup did not complete within {:?}",
                        timeout
                    );
                    Ok(None)
                }
            }
        } else {
            Ok(Some(QuitReason::Closed))
        }
    }

    /// Cancel the service and wait for cleanup to complete.
    ///
    /// This consumes the `RunningService` and ensures the connection is properly
    /// closed. For a non-consuming alternative, see [`close`](Self::close).
    pub async fn cancel(mut self) -> Result<QuitReason, tokio::task::JoinError> {
        // Disarm the drop guard since we're handling cancellation explicitly
        // NOTE(review): `mem::replace` drops the OLD guard, which cancels the
        // token immediately; the `close()` below then finds it already
        // cancelled, so the net effect is correct — confirm this is intended.
        let _ = std::mem::replace(&mut self.dg, self.cancellation_token.clone().drop_guard());
        self.close().await
    }
}
620
621impl<R: ServiceRole, S: Service<R>> Drop for RunningService<R, S> {
622    fn drop(&mut self) {
623        if self.handle.is_some() && !self.cancellation_token.is_cancelled() {
624            tracing::debug!(
625                "RunningService dropped without explicit close(). \
626                 The connection will be closed asynchronously. \
627                 For guaranteed cleanup, call close() or cancel() before dropping."
628            );
629        }
630        // The DropGuard will handle cancellation
631    }
632}
633
// use a wrapper type so we can tweak the implementation if needed
pub struct RunningServiceCancellationToken(CancellationToken);

impl RunningServiceCancellationToken {
    /// Cancel the associated running service; consumes the token.
    pub fn cancel(self) {
        self.0.cancel();
    }
}
642
/// Why the service loop terminated.
#[derive(Debug)]
#[non_exhaustive]
pub enum QuitReason {
    /// The cancellation token was triggered.
    Cancelled,
    /// The transport's input stream ended, or the service was already closed.
    Closed,
    /// The background task failed to join.
    JoinError(tokio::task::JoinError),
}
650
/// Request execution context
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct RequestContext<R: ServiceRole> {
    /// this token will be cancelled when the [`CancelledNotification`] is received.
    pub ct: CancellationToken,
    /// Id of the request being handled.
    pub id: RequestId,
    /// Metadata carried by the request.
    pub meta: Meta,
    /// Extensions carried by the request.
    pub extensions: Extensions,
    /// An interface to fetch the remote client or server
    pub peer: Peer<R>,
}
663
664impl<R: ServiceRole> RequestContext<R> {
665    /// Create a new RequestContext.
666    pub fn new(id: RequestId, peer: Peer<R>) -> Self {
667        Self {
668            ct: CancellationToken::new(),
669            id,
670            meta: Meta::default(),
671            extensions: Extensions::default(),
672            peer,
673        }
674    }
675}
676
/// Notification execution context
// (was copy-pasted as "Request execution context"; this type is passed to
// `Service::handle_notification`)
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct NotificationContext<R: ServiceRole> {
    /// Metadata carried by the notification.
    pub meta: Meta,
    /// Extensions carried by the notification.
    pub extensions: Extensions,
    /// An interface to fetch the remote client or server
    pub peer: Peer<R>,
}
686
687/// Use this function to skip initialization process
688pub fn serve_directly<R, S, T, E, A>(
689    service: S,
690    transport: T,
691    peer_info: Option<R::PeerInfo>,
692) -> RunningService<R, S>
693where
694    R: ServiceRole,
695    S: Service<R>,
696    T: IntoTransport<R, E, A>,
697    E: std::error::Error + Send + Sync + 'static,
698{
699    serve_directly_with_ct(service, transport, peer_info, Default::default())
700}
701
702/// Use this function to skip initialization process
703pub fn serve_directly_with_ct<R, S, T, E, A>(
704    service: S,
705    transport: T,
706    peer_info: Option<R::PeerInfo>,
707    ct: CancellationToken,
708) -> RunningService<R, S>
709where
710    R: ServiceRole,
711    S: Service<R>,
712    T: IntoTransport<R, E, A>,
713    E: std::error::Error + Send + Sync + 'static,
714{
715    let (peer, peer_rx) = Peer::new(Arc::new(AtomicU32RequestIdProvider::default()), peer_info);
716    serve_inner(service, transport.into_transport(), peer, peer_rx, ct)
717}
718
/// Spawn a task that may hold `!Send` state when the `local` feature is active.
///
/// Without the `local` feature this is `tokio::spawn` (requires `Future: Send + 'static`).
/// With `local` it uses `tokio::task::spawn_local` (requires only `Future: 'static`).
#[cfg(not(feature = "local"))]
fn spawn_service_task<F>(future: F) -> tokio::task::JoinHandle<F::Output>
where
    F: Future + Send + 'static,
    F::Output: Send + 'static,
{
    tokio::spawn(future)
}
731
// `local` variant: spawns on the current thread's LocalSet, so the future
// need not be `Send`.
#[cfg(feature = "local")]
fn spawn_service_task<F>(future: F) -> tokio::task::JoinHandle<F::Output>
where
    F: Future + 'static,
    F::Output: 'static,
{
    tokio::task::spawn_local(future)
}
740
741#[instrument(skip_all)]
742fn serve_inner<R, S, T>(
743    service: S,
744    transport: T,
745    peer: Peer<R>,
746    mut peer_rx: tokio::sync::mpsc::Receiver<PeerSinkMessage<R>>,
747    ct: CancellationToken,
748) -> RunningService<R, S>
749where
750    R: ServiceRole,
751    S: Service<R>,
752    T: Transport<R> + 'static,
753{
754    const SINK_PROXY_BUFFER_SIZE: usize = 64;
755    let (sink_proxy_tx, mut sink_proxy_rx) =
756        tokio::sync::mpsc::channel::<TxJsonRpcMessage<R>>(SINK_PROXY_BUFFER_SIZE);
757    let peer_info = peer.peer_info();
758    if R::IS_CLIENT {
759        tracing::info!(?peer_info, "Service initialized as client");
760    } else {
761        tracing::info!(?peer_info, "Service initialized as server");
762    }
763
764    let mut local_responder_pool =
765        HashMap::<RequestId, Responder<Result<R::PeerResp, ServiceError>>>::new();
766    let mut local_ct_pool = HashMap::<RequestId, CancellationToken>::new();
767    let shared_service = Arc::new(service);
768    // for return
769    let service = shared_service.clone();
770
771    // let message_sink = tokio::sync::
772    // let mut stream = std::pin::pin!(stream);
773    let serve_loop_ct = ct.child_token();
774    let peer_return: Peer<R> = peer.clone();
775    let current_span = tracing::Span::current();
776    let handle = spawn_service_task(async move {
777        let mut transport = transport.into_transport();
778        let mut batch_messages = VecDeque::<RxJsonRpcMessage<R>>::new();
779        let mut send_task_set = tokio::task::JoinSet::<SendTaskResult>::new();
780        let mut response_send_tasks = tokio::task::JoinSet::<()>::new();
781        #[derive(Debug)]
782        enum SendTaskResult {
783            Request {
784                id: RequestId,
785                result: Result<(), DynamicTransportError>,
786            },
787            Notification {
788                responder: Responder<Result<(), ServiceError>>,
789                cancellation_param: Option<CancelledNotificationParam>,
790                result: Result<(), DynamicTransportError>,
791            },
792        }
793        #[derive(Debug)]
794        enum Event<R: ServiceRole> {
795            ProxyMessage(PeerSinkMessage<R>),
796            PeerMessage(RxJsonRpcMessage<R>),
797            ToSink(TxJsonRpcMessage<R>),
798            SendTaskResult(SendTaskResult),
799        }
800
801        let quit_reason = loop {
802            let evt = if let Some(m) = batch_messages.pop_front() {
803                Event::PeerMessage(m)
804            } else {
805                tokio::select! {
806                    m = sink_proxy_rx.recv(), if !sink_proxy_rx.is_closed() => {
807                        if let Some(m) = m {
808                            Event::ToSink(m)
809                        } else {
810                            continue
811                        }
812                    }
813                    m = transport.receive() => {
814                        if let Some(m) = m {
815                            Event::PeerMessage(m)
816                        } else {
817                            // input stream closed
818                            tracing::info!("input stream terminated");
819                            break QuitReason::Closed
820                        }
821                    }
822                    m = peer_rx.recv(), if !peer_rx.is_closed() => {
823                        if let Some(m) = m {
824                            Event::ProxyMessage(m)
825                        } else {
826                            continue
827                        }
828                    }
829                    m = send_task_set.join_next(), if !send_task_set.is_empty() => {
830                        let Some(result) = m else {
831                            continue
832                        };
833                        match result {
834                            Err(e) => {
835                                // join error, which is serious, we should quit.
836                                tracing::error!(%e, "send request task encounter a tokio join error");
837                                break QuitReason::JoinError(e)
838                            }
839                            Ok(result) => {
840                                Event::SendTaskResult(result)
841                            }
842                        }
843                    }
844                    _ = serve_loop_ct.cancelled() => {
845                        tracing::info!("task cancelled");
846                        break QuitReason::Cancelled
847                    }
848                }
849            };
850
851            tracing::trace!(?evt, "new event");
852            match evt {
853                Event::SendTaskResult(SendTaskResult::Request { id, result }) => {
854                    if let Err(e) = result {
855                        if let Some(responder) = local_responder_pool.remove(&id) {
856                            let _ = responder.send(Err(ServiceError::TransportSend(e)));
857                        }
858                    }
859                }
860                Event::SendTaskResult(SendTaskResult::Notification {
861                    responder,
862                    result,
863                    cancellation_param,
864                }) => {
865                    let response = if let Err(e) = result {
866                        Err(ServiceError::TransportSend(e))
867                    } else {
868                        Ok(())
869                    };
870                    let _ = responder.send(response);
871                    if let Some(param) = cancellation_param {
872                        if let Some(responder) = local_responder_pool.remove(&param.request_id) {
873                            tracing::info!(id = %param.request_id, reason = param.reason, "cancelled");
874                            let _response_result = responder.send(Err(ServiceError::Cancelled {
875                                reason: param.reason.clone(),
876                            }));
877                        }
878                    }
879                }
880                // response and error
881                Event::ToSink(m) => {
882                    if let Some(id) = match &m {
883                        JsonRpcMessage::Response(response) => Some(&response.id),
884                        JsonRpcMessage::Error(error) => Some(&error.id),
885                        _ => None,
886                    } {
887                        if let Some(ct) = local_ct_pool.remove(id) {
888                            ct.cancel();
889                        }
890                        let send = transport.send(m);
891                        let current_span = tracing::Span::current();
892                        response_send_tasks.spawn(async move {
893                            let send_result = send.await;
894                            if let Err(error) = send_result {
895                                tracing::error!(%error, "fail to response message");
896                            }
897                        }.instrument(current_span));
898                    }
899                }
900                Event::ProxyMessage(PeerSinkMessage::Request {
901                    request,
902                    id,
903                    responder,
904                }) => {
905                    local_responder_pool.insert(id.clone(), responder);
906                    let send = transport.send(JsonRpcMessage::request(request, id.clone()));
907                    {
908                        let id = id.clone();
909                        let current_span = tracing::Span::current();
910                        send_task_set.spawn(send.map(move |r| SendTaskResult::Request {
911                            id,
912                            result: r.map_err(DynamicTransportError::new::<T, R>),
913                        }).instrument(current_span));
914                    }
915                }
916                Event::ProxyMessage(PeerSinkMessage::Notification {
917                    notification,
918                    responder,
919                }) => {
920                    // catch cancellation notification
921                    let mut cancellation_param = None;
922                    let notification = match notification.try_into() {
923                        Ok::<CancelledNotification, _>(cancelled) => {
924                            cancellation_param.replace(cancelled.params.clone());
925                            cancelled.into()
926                        }
927                        Err(notification) => notification,
928                    };
929                    let send = transport.send(JsonRpcMessage::notification(notification));
930                    let current_span = tracing::Span::current();
931                    send_task_set.spawn(send.map(move |result| SendTaskResult::Notification {
932                        responder,
933                        cancellation_param,
934                        result: result.map_err(DynamicTransportError::new::<T, R>),
935                    }).instrument(current_span));
936                }
937                Event::PeerMessage(JsonRpcMessage::Request(JsonRpcRequest {
938                    id,
939                    mut request,
940                    ..
941                })) => {
942                    tracing::debug!(%id, ?request, "received request");
943                    {
944                        let service = shared_service.clone();
945                        let sink = sink_proxy_tx.clone();
946                        let request_ct = serve_loop_ct.child_token();
947                        let context_ct = request_ct.child_token();
948                        local_ct_pool.insert(id.clone(), request_ct);
949                        let mut extensions = Extensions::new();
950                        let mut meta = Meta::new();
951                        // avoid clone
952                        // swap meta firstly, otherwise progress token will be lost
953                        std::mem::swap(&mut meta, request.get_meta_mut());
954                        std::mem::swap(&mut extensions, request.extensions_mut());
955                        let context = RequestContext {
956                            ct: context_ct,
957                            id: id.clone(),
958                            peer: peer.clone(),
959                            meta,
960                            extensions,
961                        };
962                        let current_span = tracing::Span::current();
963                        spawn_service_task(async move {
964                            let result = service
965                                .handle_request(request, context)
966                                .await;
967                            let response = match result {
968                                Ok(result) => {
969                                    tracing::debug!(%id, ?result, "response message");
970                                    JsonRpcMessage::response(result, id)
971                                }
972                                Err(error) => {
973                                    tracing::warn!(%id, ?error, "response error");
974                                    JsonRpcMessage::error(error, id)
975                                }
976                            };
977                            let _send_result = sink.send(response).await;
978                        }.instrument(current_span));
979                    }
980                }
981                Event::PeerMessage(JsonRpcMessage::Notification(JsonRpcNotification {
982                    notification,
983                    ..
984                })) => {
985                    tracing::info!(?notification, "received notification");
986                    // catch cancelled notification
987                    let mut notification = match notification.try_into() {
988                        Ok::<CancelledNotification, _>(cancelled) => {
989                            if let Some(ct) = local_ct_pool.remove(&cancelled.params.request_id) {
990                                tracing::info!(id = %cancelled.params.request_id, reason = cancelled.params.reason, "cancelled");
991                                ct.cancel();
992                            }
993                            cancelled.into()
994                        }
995                        Err(notification) => notification,
996                    };
997                    {
998                        let service = shared_service.clone();
999                        let mut extensions = Extensions::new();
1000                        let mut meta = Meta::new();
1001                        // avoid clone
1002                        std::mem::swap(&mut extensions, notification.extensions_mut());
1003                        std::mem::swap(&mut meta, notification.get_meta_mut());
1004                        let context = NotificationContext {
1005                            peer: peer.clone(),
1006                            meta,
1007                            extensions,
1008                        };
1009                        let current_span = tracing::Span::current();
1010                        spawn_service_task(async move {
1011                            let result = service.handle_notification(notification, context).await;
1012                            if let Err(error) = result {
1013                                tracing::warn!(%error, "Error sending notification");
1014                            }
1015                        }.instrument(current_span));
1016                    }
1017                }
1018                Event::PeerMessage(JsonRpcMessage::Response(JsonRpcResponse {
1019                    result,
1020                    id,
1021                    ..
1022                })) => {
1023                    if let Some(responder) = local_responder_pool.remove(&id) {
1024                        let response_result = responder.send(Ok(result));
1025                        if let Err(_error) = response_result {
1026                            tracing::warn!(%id, "Error sending response");
1027                        }
1028                    }
1029                }
1030                Event::PeerMessage(JsonRpcMessage::Error(JsonRpcError { error, id, .. })) => {
1031                    if let Some(responder) = local_responder_pool.remove(&id) {
1032                        let _response_result = responder.send(Err(ServiceError::McpError(error)));
1033                        if let Err(_error) = _response_result {
1034                            tracing::warn!(%id, "Error sending response");
1035                        }
1036                    }
1037                }
1038            }
1039        };
1040
1041        // Drain in-flight handler responses before closing the transport.
1042        // When stdin EOF or cancellation arrives, spawned handler tasks may still
1043        // be finishing. We need to:
1044        // 1. Wait for response sends that were already spawned in the main loop
1045        // 2. Drain any remaining handler responses from the channel
1046        let drain_timeout = match &quit_reason {
1047            QuitReason::Closed => Some(Duration::from_secs(5)),
1048            QuitReason::Cancelled => Some(Duration::from_secs(2)),
1049            _ => None,
1050        };
1051        if let Some(timeout_duration) = drain_timeout {
1052            // Drop our sender so the channel closes once all handler task
1053            // clones finish sending their responses (or are dropped).
1054            drop(sink_proxy_tx);
1055            let drain_result = tokio::time::timeout(timeout_duration, async {
1056                // First, wait for any response sends already dispatched by the
1057                // main loop (these hold transport write futures).
1058                while let Some(result) = response_send_tasks.join_next().await {
1059                    if let Err(error) = result {
1060                        tracing::error!(%error, "response send task failed during drain");
1061                    }
1062                }
1063                // Then drain any handler responses still in the channel
1064                // (handlers that finished after the loop broke).
1065                while let Some(m) = sink_proxy_rx.recv().await {
1066                    if let Err(error) = transport.send(m).await {
1067                        tracing::error!(%error, "failed to send pending response during drain");
1068                        break;
1069                    }
1070                }
1071            })
1072            .await;
1073            if drain_result.is_err() {
1074                tracing::warn!("timed out draining in-flight responses");
1075            }
1076        }
1077
1078        let sink_close_result = transport.close().await;
1079        if let Err(e) = sink_close_result {
1080            tracing::error!(%e, "fail to close sink");
1081        }
1082        tracing::info!(?quit_reason, "serve finished");
1083        quit_reason
1084    }.instrument(current_span));
1085    RunningService {
1086        service,
1087        peer: peer_return,
1088        handle: Some(handle),
1089        cancellation_token: ct.clone(),
1090        dg: ct.drop_guard(),
1091    }
1092}