1use futures::FutureExt;
2#[cfg(not(feature = "local"))]
3use futures::future::BoxFuture;
4#[cfg(feature = "local")]
5use futures::future::LocalBoxFuture;
6use thiserror::Error;
7
8#[cfg(not(feature = "local"))]
17#[doc(hidden)]
18pub trait MaybeSend: Send + Sync {}
19#[cfg(not(feature = "local"))]
20impl<T: Send + Sync> MaybeSend for T {}
21
22#[cfg(feature = "local")]
23#[doc(hidden)]
24pub trait MaybeSend {}
25#[cfg(feature = "local")]
26impl<T> MaybeSend for T {}
27
28#[cfg(not(feature = "local"))]
29#[doc(hidden)]
30pub trait MaybeSendFuture: Send {}
31#[cfg(not(feature = "local"))]
32impl<T: Send> MaybeSendFuture for T {}
33
34#[cfg(feature = "local")]
35#[doc(hidden)]
36pub trait MaybeSendFuture {}
37#[cfg(feature = "local")]
38impl<T> MaybeSendFuture for T {}
39
40#[cfg(not(feature = "local"))]
41pub(crate) type MaybeBoxFuture<'a, T> = BoxFuture<'a, T>;
42#[cfg(feature = "local")]
43pub(crate) type MaybeBoxFuture<'a, T> = LocalBoxFuture<'a, T>;
44
45#[cfg(feature = "server")]
46use crate::model::ServerJsonRpcMessage;
47use crate::{
48 error::ErrorData as McpError,
49 model::{
50 CancelledNotification, CancelledNotificationParam, Extensions, GetExtensions, GetMeta,
51 JsonRpcError, JsonRpcMessage, JsonRpcNotification, JsonRpcRequest, JsonRpcResponse, Meta,
52 NumberOrString, ProgressToken, RequestId,
53 },
54 transport::{DynamicTransportError, IntoTransport, Transport},
55};
56#[cfg(feature = "client")]
57mod client;
58#[cfg(feature = "client")]
59pub use client::*;
60#[cfg(feature = "server")]
61mod server;
62#[cfg(feature = "server")]
63pub use server::*;
64#[cfg(feature = "tower")]
65mod tower;
66use tokio_util::sync::{CancellationToken, DropGuard};
67#[cfg(feature = "tower")]
68pub use tower::*;
69use tracing::{Instrument as _, instrument};
70#[derive(Error, Debug)]
71#[non_exhaustive]
72pub enum ServiceError {
73 #[error("Mcp error: {0}")]
74 McpError(McpError),
75 #[error("Transport send error: {0}")]
76 TransportSend(DynamicTransportError),
77 #[error("Transport closed")]
78 TransportClosed,
79 #[error("Unexpected response type")]
80 UnexpectedResponse,
81 #[error("task cancelled for reason {}", reason.as_deref().unwrap_or("<unknown>"))]
82 Cancelled { reason: Option<String> },
83 #[error("request timeout after {}", chrono::Duration::from_std(*timeout).unwrap_or_default())]
84 Timeout { timeout: Duration },
85}
86
87trait TransferObject:
88 std::fmt::Debug + Clone + serde::Serialize + serde::de::DeserializeOwned + Send + Sync + 'static
89{
90}
91
92impl<T> TransferObject for T where
93 T: std::fmt::Debug
94 + serde::Serialize
95 + serde::de::DeserializeOwned
96 + Send
97 + Sync
98 + 'static
99 + Clone
100{
101}
102
103#[allow(private_bounds, reason = "there's no the third implementation")]
104pub trait ServiceRole: std::fmt::Debug + Send + Sync + 'static + Copy + Clone {
105 type Req: TransferObject + GetMeta + GetExtensions;
106 type Resp: TransferObject;
107 type Not: TryInto<CancelledNotification, Error = Self::Not>
108 + From<CancelledNotification>
109 + TransferObject;
110 type PeerReq: TransferObject + GetMeta + GetExtensions;
111 type PeerResp: TransferObject;
112 type PeerNot: TryInto<CancelledNotification, Error = Self::PeerNot>
113 + From<CancelledNotification>
114 + TransferObject
115 + GetMeta
116 + GetExtensions;
117 type InitializeError;
118 const IS_CLIENT: bool;
119 type Info: TransferObject;
120 type PeerInfo: TransferObject;
121}
122
123pub type TxJsonRpcMessage<R> =
124 JsonRpcMessage<<R as ServiceRole>::Req, <R as ServiceRole>::Resp, <R as ServiceRole>::Not>;
125pub type RxJsonRpcMessage<R> = JsonRpcMessage<
126 <R as ServiceRole>::PeerReq,
127 <R as ServiceRole>::PeerResp,
128 <R as ServiceRole>::PeerNot,
129>;
130
131#[cfg(not(feature = "local"))]
132pub trait Service<R: ServiceRole>: Send + Sync + 'static {
133 fn handle_request(
134 &self,
135 request: R::PeerReq,
136 context: RequestContext<R>,
137 ) -> impl Future<Output = Result<R::Resp, McpError>> + MaybeSendFuture + '_;
138 fn handle_notification(
139 &self,
140 notification: R::PeerNot,
141 context: NotificationContext<R>,
142 ) -> impl Future<Output = Result<(), McpError>> + MaybeSendFuture + '_;
143 fn get_info(&self) -> R::Info;
144}
145
146#[cfg(feature = "local")]
147pub trait Service<R: ServiceRole>: 'static {
148 fn handle_request(
149 &self,
150 request: R::PeerReq,
151 context: RequestContext<R>,
152 ) -> impl Future<Output = Result<R::Resp, McpError>> + MaybeSendFuture + '_;
153 fn handle_notification(
154 &self,
155 notification: R::PeerNot,
156 context: NotificationContext<R>,
157 ) -> impl Future<Output = Result<(), McpError>> + MaybeSendFuture + '_;
158 fn get_info(&self) -> R::Info;
159}
160
161pub trait ServiceExt<R: ServiceRole>: Service<R> + Sized {
162 fn into_dyn(self) -> Box<dyn DynService<R>> {
166 Box::new(self)
167 }
168 fn serve<T, E, A>(
169 self,
170 transport: T,
171 ) -> impl Future<Output = Result<RunningService<R, Self>, R::InitializeError>> + MaybeSendFuture
172 where
173 T: IntoTransport<R, E, A>,
174 E: std::error::Error + Send + Sync + 'static,
175 Self: Sized,
176 {
177 Self::serve_with_ct(self, transport, Default::default())
178 }
179 fn serve_with_ct<T, E, A>(
180 self,
181 transport: T,
182 ct: CancellationToken,
183 ) -> impl Future<Output = Result<RunningService<R, Self>, R::InitializeError>> + MaybeSendFuture
184 where
185 T: IntoTransport<R, E, A>,
186 E: std::error::Error + Send + Sync + 'static,
187 Self: Sized;
188}
189
190impl<R: ServiceRole> Service<R> for Box<dyn DynService<R>> {
191 fn handle_request(
192 &self,
193 request: R::PeerReq,
194 context: RequestContext<R>,
195 ) -> impl Future<Output = Result<R::Resp, McpError>> + MaybeSendFuture + '_ {
196 DynService::handle_request(self.as_ref(), request, context)
197 }
198
199 fn handle_notification(
200 &self,
201 notification: R::PeerNot,
202 context: NotificationContext<R>,
203 ) -> impl Future<Output = Result<(), McpError>> + MaybeSendFuture + '_ {
204 DynService::handle_notification(self.as_ref(), notification, context)
205 }
206
207 fn get_info(&self) -> R::Info {
208 DynService::get_info(self.as_ref())
209 }
210}
211
212#[cfg(not(feature = "local"))]
213pub trait DynService<R: ServiceRole>: Send + Sync {
214 fn handle_request(
215 &self,
216 request: R::PeerReq,
217 context: RequestContext<R>,
218 ) -> MaybeBoxFuture<'_, Result<R::Resp, McpError>>;
219 fn handle_notification(
220 &self,
221 notification: R::PeerNot,
222 context: NotificationContext<R>,
223 ) -> MaybeBoxFuture<'_, Result<(), McpError>>;
224 fn get_info(&self) -> R::Info;
225}
226
227#[cfg(feature = "local")]
228pub trait DynService<R: ServiceRole> {
229 fn handle_request(
230 &self,
231 request: R::PeerReq,
232 context: RequestContext<R>,
233 ) -> MaybeBoxFuture<'_, Result<R::Resp, McpError>>;
234 fn handle_notification(
235 &self,
236 notification: R::PeerNot,
237 context: NotificationContext<R>,
238 ) -> MaybeBoxFuture<'_, Result<(), McpError>>;
239 fn get_info(&self) -> R::Info;
240}
241
242impl<R: ServiceRole, S: Service<R>> DynService<R> for S {
243 fn handle_request(
244 &self,
245 request: R::PeerReq,
246 context: RequestContext<R>,
247 ) -> MaybeBoxFuture<'_, Result<R::Resp, McpError>> {
248 Box::pin(self.handle_request(request, context))
249 }
250 fn handle_notification(
251 &self,
252 notification: R::PeerNot,
253 context: NotificationContext<R>,
254 ) -> MaybeBoxFuture<'_, Result<(), McpError>> {
255 Box::pin(self.handle_notification(notification, context))
256 }
257 fn get_info(&self) -> R::Info {
258 self.get_info()
259 }
260}
261
262use std::{
263 collections::{HashMap, VecDeque},
264 ops::Deref,
265 sync::{Arc, atomic::AtomicU64},
266 time::Duration,
267};
268
269use tokio::sync::mpsc;
270
271pub trait RequestIdProvider: Send + Sync + 'static {
272 fn next_request_id(&self) -> RequestId;
273}
274
275pub trait ProgressTokenProvider: Send + Sync + 'static {
276 fn next_progress_token(&self) -> ProgressToken;
277}
278
279pub type AtomicU32RequestIdProvider = AtomicU32Provider;
280pub type AtomicU32ProgressTokenProvider = AtomicU32Provider;
281
282#[derive(Debug, Default)]
283pub struct AtomicU32Provider {
284 id: AtomicU64,
285}
286
287impl RequestIdProvider for AtomicU32Provider {
288 fn next_request_id(&self) -> RequestId {
289 let id = self.id.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
290 RequestId::Number(id as i64)
292 }
293}
294
295impl ProgressTokenProvider for AtomicU32Provider {
296 fn next_progress_token(&self) -> ProgressToken {
297 let id = self.id.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
298 ProgressToken(NumberOrString::Number(id as i64))
299 }
300}
301
302type Responder<T> = tokio::sync::oneshot::Sender<T>;
303
304#[derive(Debug)]
310#[non_exhaustive]
311pub struct RequestHandle<R: ServiceRole> {
312 pub rx: tokio::sync::oneshot::Receiver<Result<R::PeerResp, ServiceError>>,
313 pub options: PeerRequestOptions,
314 pub peer: Peer<R>,
315 pub id: RequestId,
316 pub progress_token: ProgressToken,
317}
318
319impl<R: ServiceRole> RequestHandle<R> {
320 pub const REQUEST_TIMEOUT_REASON: &str = "request timeout";
321 pub async fn await_response(self) -> Result<R::PeerResp, ServiceError> {
322 if let Some(timeout) = self.options.timeout {
323 let timeout_result = tokio::time::timeout(timeout, async move {
324 self.rx.await.map_err(|_e| ServiceError::TransportClosed)?
325 })
326 .await;
327 match timeout_result {
328 Ok(response) => response,
329 Err(_) => {
330 let error = Err(ServiceError::Timeout { timeout });
331 let notification = CancelledNotification {
333 params: CancelledNotificationParam {
334 request_id: self.id,
335 reason: Some(Self::REQUEST_TIMEOUT_REASON.to_owned()),
336 },
337 method: crate::model::CancelledNotificationMethod,
338 extensions: Default::default(),
339 };
340 let _ = self.peer.send_notification(notification.into()).await;
341 error
342 }
343 }
344 } else {
345 self.rx.await.map_err(|_e| ServiceError::TransportClosed)?
346 }
347 }
348
349 pub async fn cancel(self, reason: Option<String>) -> Result<(), ServiceError> {
351 let notification = CancelledNotification {
352 params: CancelledNotificationParam {
353 request_id: self.id,
354 reason,
355 },
356 method: crate::model::CancelledNotificationMethod,
357 extensions: Default::default(),
358 };
359 self.peer.send_notification(notification.into()).await?;
360 Ok(())
361 }
362}
363
364#[derive(Debug)]
365pub(crate) enum PeerSinkMessage<R: ServiceRole> {
366 Request {
367 request: R::Req,
368 id: RequestId,
369 responder: Responder<Result<R::PeerResp, ServiceError>>,
370 },
371 Notification {
372 notification: R::Not,
373 responder: Responder<Result<(), ServiceError>>,
374 },
375}
376
377#[derive(Clone)]
383pub struct Peer<R: ServiceRole> {
384 tx: mpsc::Sender<PeerSinkMessage<R>>,
385 request_id_provider: Arc<dyn RequestIdProvider>,
386 progress_token_provider: Arc<dyn ProgressTokenProvider>,
387 info: Arc<tokio::sync::OnceCell<R::PeerInfo>>,
388}
389
390impl<R: ServiceRole> std::fmt::Debug for Peer<R> {
391 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
392 f.debug_struct("PeerSink")
393 .field("tx", &self.tx)
394 .field("is_client", &R::IS_CLIENT)
395 .finish()
396 }
397}
398
399type ProxyOutbound<R> = mpsc::Receiver<PeerSinkMessage<R>>;
400
401#[derive(Debug, Default)]
402#[non_exhaustive]
403pub struct PeerRequestOptions {
404 pub timeout: Option<Duration>,
405 pub meta: Option<Meta>,
406}
407
408impl PeerRequestOptions {
409 pub fn no_options() -> Self {
410 Self::default()
411 }
412}
413
414impl<R: ServiceRole> Peer<R> {
415 const CLIENT_CHANNEL_BUFFER_SIZE: usize = 1024;
416 pub(crate) fn new(
417 request_id_provider: Arc<dyn RequestIdProvider>,
418 peer_info: Option<R::PeerInfo>,
419 ) -> (Peer<R>, ProxyOutbound<R>) {
420 let (tx, rx) = mpsc::channel(Self::CLIENT_CHANNEL_BUFFER_SIZE);
421 (
422 Self {
423 tx,
424 request_id_provider,
425 progress_token_provider: Arc::new(AtomicU32ProgressTokenProvider::default()),
426 info: Arc::new(tokio::sync::OnceCell::new_with(peer_info)),
427 },
428 rx,
429 )
430 }
431 pub async fn send_notification(&self, notification: R::Not) -> Result<(), ServiceError> {
432 let (responder, receiver) = tokio::sync::oneshot::channel();
433 self.tx
434 .send(PeerSinkMessage::Notification {
435 notification,
436 responder,
437 })
438 .await
439 .map_err(|_m| ServiceError::TransportClosed)?;
440 receiver.await.map_err(|_e| ServiceError::TransportClosed)?
441 }
442 pub async fn send_request(&self, request: R::Req) -> Result<R::PeerResp, ServiceError> {
443 self.send_request_with_option(request, PeerRequestOptions::no_options())
444 .await?
445 .await_response()
446 .await
447 }
448
449 pub async fn send_cancellable_request(
450 &self,
451 request: R::Req,
452 options: PeerRequestOptions,
453 ) -> Result<RequestHandle<R>, ServiceError> {
454 self.send_request_with_option(request, options).await
455 }
456
457 pub async fn send_request_with_option(
458 &self,
459 mut request: R::Req,
460 options: PeerRequestOptions,
461 ) -> Result<RequestHandle<R>, ServiceError> {
462 let id = self.request_id_provider.next_request_id();
463 let progress_token = self.progress_token_provider.next_progress_token();
464 request
465 .get_meta_mut()
466 .set_progress_token(progress_token.clone());
467 if let Some(meta) = options.meta.clone() {
468 request.get_meta_mut().extend(meta);
469 }
470 let (responder, receiver) = tokio::sync::oneshot::channel();
471 self.tx
472 .send(PeerSinkMessage::Request {
473 request,
474 id: id.clone(),
475 responder,
476 })
477 .await
478 .map_err(|_m| ServiceError::TransportClosed)?;
479 Ok(RequestHandle {
480 id,
481 rx: receiver,
482 progress_token,
483 options,
484 peer: self.clone(),
485 })
486 }
487 pub fn peer_info(&self) -> Option<&R::PeerInfo> {
488 self.info.get()
489 }
490
491 pub fn set_peer_info(&self, info: R::PeerInfo) {
492 if self.info.initialized() {
493 tracing::warn!("trying to set peer info, which is already initialized");
494 } else {
495 let _ = self.info.set(info);
496 }
497 }
498
499 pub fn is_transport_closed(&self) -> bool {
500 self.tx.is_closed()
501 }
502}
503
504#[derive(Debug)]
505pub struct RunningService<R: ServiceRole, S: Service<R>> {
506 service: Arc<S>,
507 peer: Peer<R>,
508 handle: Option<tokio::task::JoinHandle<QuitReason>>,
509 cancellation_token: CancellationToken,
510 dg: DropGuard,
511}
512impl<R: ServiceRole, S: Service<R>> Deref for RunningService<R, S> {
513 type Target = Peer<R>;
514
515 fn deref(&self) -> &Self::Target {
516 &self.peer
517 }
518}
519
520impl<R: ServiceRole, S: Service<R>> RunningService<R, S> {
521 #[inline]
522 pub fn peer(&self) -> &Peer<R> {
523 &self.peer
524 }
525 #[inline]
526 pub fn service(&self) -> &S {
527 self.service.as_ref()
528 }
529 #[inline]
530 pub fn cancellation_token(&self) -> RunningServiceCancellationToken {
531 RunningServiceCancellationToken(self.cancellation_token.clone())
532 }
533
534 #[inline]
536 pub fn is_closed(&self) -> bool {
537 self.handle.is_none() || self.cancellation_token.is_cancelled()
538 }
539
540 #[inline]
545 pub async fn waiting(mut self) -> Result<QuitReason, tokio::task::JoinError> {
546 match self.handle.take() {
547 Some(handle) => handle.await,
548 None => Ok(QuitReason::Closed),
549 }
550 }
551
552 pub async fn close(&mut self) -> Result<QuitReason, tokio::task::JoinError> {
570 if let Some(handle) = self.handle.take() {
571 self.cancellation_token.cancel();
574 handle.await
575 } else {
576 Ok(QuitReason::Closed)
578 }
579 }
580
581 pub async fn close_with_timeout(
590 &mut self,
591 timeout: Duration,
592 ) -> Result<Option<QuitReason>, tokio::task::JoinError> {
593 if let Some(handle) = self.handle.take() {
594 self.cancellation_token.cancel();
595 match tokio::time::timeout(timeout, handle).await {
596 Ok(result) => result.map(Some),
597 Err(_elapsed) => {
598 tracing::warn!(
599 "close_with_timeout: cleanup did not complete within {:?}",
600 timeout
601 );
602 Ok(None)
603 }
604 }
605 } else {
606 Ok(Some(QuitReason::Closed))
607 }
608 }
609
610 pub async fn cancel(mut self) -> Result<QuitReason, tokio::task::JoinError> {
615 let _ = std::mem::replace(&mut self.dg, self.cancellation_token.clone().drop_guard());
617 self.close().await
618 }
619}
620
621impl<R: ServiceRole, S: Service<R>> Drop for RunningService<R, S> {
622 fn drop(&mut self) {
623 if self.handle.is_some() && !self.cancellation_token.is_cancelled() {
624 tracing::debug!(
625 "RunningService dropped without explicit close(). \
626 The connection will be closed asynchronously. \
627 For guaranteed cleanup, call close() or cancel() before dropping."
628 );
629 }
630 }
632}
633
634pub struct RunningServiceCancellationToken(CancellationToken);
636
637impl RunningServiceCancellationToken {
638 pub fn cancel(self) {
639 self.0.cancel();
640 }
641}
642
643#[derive(Debug)]
644#[non_exhaustive]
645pub enum QuitReason {
646 Cancelled,
647 Closed,
648 JoinError(tokio::task::JoinError),
649}
650
651#[derive(Debug, Clone)]
653#[non_exhaustive]
654pub struct RequestContext<R: ServiceRole> {
655 pub ct: CancellationToken,
657 pub id: RequestId,
658 pub meta: Meta,
659 pub extensions: Extensions,
660 pub peer: Peer<R>,
662}
663
664impl<R: ServiceRole> RequestContext<R> {
665 pub fn new(id: RequestId, peer: Peer<R>) -> Self {
667 Self {
668 ct: CancellationToken::new(),
669 id,
670 meta: Meta::default(),
671 extensions: Extensions::default(),
672 peer,
673 }
674 }
675}
676
677#[derive(Debug, Clone)]
679#[non_exhaustive]
680pub struct NotificationContext<R: ServiceRole> {
681 pub meta: Meta,
682 pub extensions: Extensions,
683 pub peer: Peer<R>,
685}
686
687pub fn serve_directly<R, S, T, E, A>(
689 service: S,
690 transport: T,
691 peer_info: Option<R::PeerInfo>,
692) -> RunningService<R, S>
693where
694 R: ServiceRole,
695 S: Service<R>,
696 T: IntoTransport<R, E, A>,
697 E: std::error::Error + Send + Sync + 'static,
698{
699 serve_directly_with_ct(service, transport, peer_info, Default::default())
700}
701
702pub fn serve_directly_with_ct<R, S, T, E, A>(
704 service: S,
705 transport: T,
706 peer_info: Option<R::PeerInfo>,
707 ct: CancellationToken,
708) -> RunningService<R, S>
709where
710 R: ServiceRole,
711 S: Service<R>,
712 T: IntoTransport<R, E, A>,
713 E: std::error::Error + Send + Sync + 'static,
714{
715 let (peer, peer_rx) = Peer::new(Arc::new(AtomicU32RequestIdProvider::default()), peer_info);
716 serve_inner(service, transport.into_transport(), peer, peer_rx, ct)
717}
718
719#[cfg(not(feature = "local"))]
724fn spawn_service_task<F>(future: F) -> tokio::task::JoinHandle<F::Output>
725where
726 F: Future + Send + 'static,
727 F::Output: Send + 'static,
728{
729 tokio::spawn(future)
730}
731
732#[cfg(feature = "local")]
733fn spawn_service_task<F>(future: F) -> tokio::task::JoinHandle<F::Output>
734where
735 F: Future + 'static,
736 F::Output: 'static,
737{
738 tokio::task::spawn_local(future)
739}
740
741#[instrument(skip_all)]
742fn serve_inner<R, S, T>(
743 service: S,
744 transport: T,
745 peer: Peer<R>,
746 mut peer_rx: tokio::sync::mpsc::Receiver<PeerSinkMessage<R>>,
747 ct: CancellationToken,
748) -> RunningService<R, S>
749where
750 R: ServiceRole,
751 S: Service<R>,
752 T: Transport<R> + 'static,
753{
754 const SINK_PROXY_BUFFER_SIZE: usize = 64;
755 let (sink_proxy_tx, mut sink_proxy_rx) =
756 tokio::sync::mpsc::channel::<TxJsonRpcMessage<R>>(SINK_PROXY_BUFFER_SIZE);
757 let peer_info = peer.peer_info();
758 if R::IS_CLIENT {
759 tracing::info!(?peer_info, "Service initialized as client");
760 } else {
761 tracing::info!(?peer_info, "Service initialized as server");
762 }
763
764 let mut local_responder_pool =
765 HashMap::<RequestId, Responder<Result<R::PeerResp, ServiceError>>>::new();
766 let mut local_ct_pool = HashMap::<RequestId, CancellationToken>::new();
767 let shared_service = Arc::new(service);
768 let service = shared_service.clone();
770
771 let serve_loop_ct = ct.child_token();
774 let peer_return: Peer<R> = peer.clone();
775 let current_span = tracing::Span::current();
776 let handle = spawn_service_task(async move {
777 let mut transport = transport.into_transport();
778 let mut batch_messages = VecDeque::<RxJsonRpcMessage<R>>::new();
779 let mut send_task_set = tokio::task::JoinSet::<SendTaskResult>::new();
780 let mut response_send_tasks = tokio::task::JoinSet::<()>::new();
781 #[derive(Debug)]
782 enum SendTaskResult {
783 Request {
784 id: RequestId,
785 result: Result<(), DynamicTransportError>,
786 },
787 Notification {
788 responder: Responder<Result<(), ServiceError>>,
789 cancellation_param: Option<CancelledNotificationParam>,
790 result: Result<(), DynamicTransportError>,
791 },
792 }
793 #[derive(Debug)]
794 enum Event<R: ServiceRole> {
795 ProxyMessage(PeerSinkMessage<R>),
796 PeerMessage(RxJsonRpcMessage<R>),
797 ToSink(TxJsonRpcMessage<R>),
798 SendTaskResult(SendTaskResult),
799 }
800
801 let quit_reason = loop {
802 let evt = if let Some(m) = batch_messages.pop_front() {
803 Event::PeerMessage(m)
804 } else {
805 tokio::select! {
806 m = sink_proxy_rx.recv(), if !sink_proxy_rx.is_closed() => {
807 if let Some(m) = m {
808 Event::ToSink(m)
809 } else {
810 continue
811 }
812 }
813 m = transport.receive() => {
814 if let Some(m) = m {
815 Event::PeerMessage(m)
816 } else {
817 tracing::info!("input stream terminated");
819 break QuitReason::Closed
820 }
821 }
822 m = peer_rx.recv(), if !peer_rx.is_closed() => {
823 if let Some(m) = m {
824 Event::ProxyMessage(m)
825 } else {
826 continue
827 }
828 }
829 m = send_task_set.join_next(), if !send_task_set.is_empty() => {
830 let Some(result) = m else {
831 continue
832 };
833 match result {
834 Err(e) => {
835 tracing::error!(%e, "send request task encounter a tokio join error");
837 break QuitReason::JoinError(e)
838 }
839 Ok(result) => {
840 Event::SendTaskResult(result)
841 }
842 }
843 }
844 _ = serve_loop_ct.cancelled() => {
845 tracing::info!("task cancelled");
846 break QuitReason::Cancelled
847 }
848 }
849 };
850
851 tracing::trace!(?evt, "new event");
852 match evt {
853 Event::SendTaskResult(SendTaskResult::Request { id, result }) => {
854 if let Err(e) = result {
855 if let Some(responder) = local_responder_pool.remove(&id) {
856 let _ = responder.send(Err(ServiceError::TransportSend(e)));
857 }
858 }
859 }
860 Event::SendTaskResult(SendTaskResult::Notification {
861 responder,
862 result,
863 cancellation_param,
864 }) => {
865 let response = if let Err(e) = result {
866 Err(ServiceError::TransportSend(e))
867 } else {
868 Ok(())
869 };
870 let _ = responder.send(response);
871 if let Some(param) = cancellation_param {
872 if let Some(responder) = local_responder_pool.remove(¶m.request_id) {
873 tracing::info!(id = %param.request_id, reason = param.reason, "cancelled");
874 let _response_result = responder.send(Err(ServiceError::Cancelled {
875 reason: param.reason.clone(),
876 }));
877 }
878 }
879 }
880 Event::ToSink(m) => {
882 if let Some(id) = match &m {
883 JsonRpcMessage::Response(response) => Some(&response.id),
884 JsonRpcMessage::Error(error) => Some(&error.id),
885 _ => None,
886 } {
887 if let Some(ct) = local_ct_pool.remove(id) {
888 ct.cancel();
889 }
890 let send = transport.send(m);
891 let current_span = tracing::Span::current();
892 response_send_tasks.spawn(async move {
893 let send_result = send.await;
894 if let Err(error) = send_result {
895 tracing::error!(%error, "fail to response message");
896 }
897 }.instrument(current_span));
898 }
899 }
900 Event::ProxyMessage(PeerSinkMessage::Request {
901 request,
902 id,
903 responder,
904 }) => {
905 local_responder_pool.insert(id.clone(), responder);
906 let send = transport.send(JsonRpcMessage::request(request, id.clone()));
907 {
908 let id = id.clone();
909 let current_span = tracing::Span::current();
910 send_task_set.spawn(send.map(move |r| SendTaskResult::Request {
911 id,
912 result: r.map_err(DynamicTransportError::new::<T, R>),
913 }).instrument(current_span));
914 }
915 }
916 Event::ProxyMessage(PeerSinkMessage::Notification {
917 notification,
918 responder,
919 }) => {
920 let mut cancellation_param = None;
922 let notification = match notification.try_into() {
923 Ok::<CancelledNotification, _>(cancelled) => {
924 cancellation_param.replace(cancelled.params.clone());
925 cancelled.into()
926 }
927 Err(notification) => notification,
928 };
929 let send = transport.send(JsonRpcMessage::notification(notification));
930 let current_span = tracing::Span::current();
931 send_task_set.spawn(send.map(move |result| SendTaskResult::Notification {
932 responder,
933 cancellation_param,
934 result: result.map_err(DynamicTransportError::new::<T, R>),
935 }).instrument(current_span));
936 }
937 Event::PeerMessage(JsonRpcMessage::Request(JsonRpcRequest {
938 id,
939 mut request,
940 ..
941 })) => {
942 tracing::debug!(%id, ?request, "received request");
943 {
944 let service = shared_service.clone();
945 let sink = sink_proxy_tx.clone();
946 let request_ct = serve_loop_ct.child_token();
947 let context_ct = request_ct.child_token();
948 local_ct_pool.insert(id.clone(), request_ct);
949 let mut extensions = Extensions::new();
950 let mut meta = Meta::new();
951 std::mem::swap(&mut meta, request.get_meta_mut());
954 std::mem::swap(&mut extensions, request.extensions_mut());
955 let context = RequestContext {
956 ct: context_ct,
957 id: id.clone(),
958 peer: peer.clone(),
959 meta,
960 extensions,
961 };
962 let current_span = tracing::Span::current();
963 spawn_service_task(async move {
964 let result = service
965 .handle_request(request, context)
966 .await;
967 let response = match result {
968 Ok(result) => {
969 tracing::debug!(%id, ?result, "response message");
970 JsonRpcMessage::response(result, id)
971 }
972 Err(error) => {
973 tracing::warn!(%id, ?error, "response error");
974 JsonRpcMessage::error(error, id)
975 }
976 };
977 let _send_result = sink.send(response).await;
978 }.instrument(current_span));
979 }
980 }
981 Event::PeerMessage(JsonRpcMessage::Notification(JsonRpcNotification {
982 notification,
983 ..
984 })) => {
985 tracing::info!(?notification, "received notification");
986 let mut notification = match notification.try_into() {
988 Ok::<CancelledNotification, _>(cancelled) => {
989 if let Some(ct) = local_ct_pool.remove(&cancelled.params.request_id) {
990 tracing::info!(id = %cancelled.params.request_id, reason = cancelled.params.reason, "cancelled");
991 ct.cancel();
992 }
993 cancelled.into()
994 }
995 Err(notification) => notification,
996 };
997 {
998 let service = shared_service.clone();
999 let mut extensions = Extensions::new();
1000 let mut meta = Meta::new();
1001 std::mem::swap(&mut extensions, notification.extensions_mut());
1003 std::mem::swap(&mut meta, notification.get_meta_mut());
1004 let context = NotificationContext {
1005 peer: peer.clone(),
1006 meta,
1007 extensions,
1008 };
1009 let current_span = tracing::Span::current();
1010 spawn_service_task(async move {
1011 let result = service.handle_notification(notification, context).await;
1012 if let Err(error) = result {
1013 tracing::warn!(%error, "Error sending notification");
1014 }
1015 }.instrument(current_span));
1016 }
1017 }
1018 Event::PeerMessage(JsonRpcMessage::Response(JsonRpcResponse {
1019 result,
1020 id,
1021 ..
1022 })) => {
1023 if let Some(responder) = local_responder_pool.remove(&id) {
1024 let response_result = responder.send(Ok(result));
1025 if let Err(_error) = response_result {
1026 tracing::warn!(%id, "Error sending response");
1027 }
1028 }
1029 }
1030 Event::PeerMessage(JsonRpcMessage::Error(JsonRpcError { error, id, .. })) => {
1031 if let Some(responder) = local_responder_pool.remove(&id) {
1032 let _response_result = responder.send(Err(ServiceError::McpError(error)));
1033 if let Err(_error) = _response_result {
1034 tracing::warn!(%id, "Error sending response");
1035 }
1036 }
1037 }
1038 }
1039 };
1040
1041 let drain_timeout = match &quit_reason {
1047 QuitReason::Closed => Some(Duration::from_secs(5)),
1048 QuitReason::Cancelled => Some(Duration::from_secs(2)),
1049 _ => None,
1050 };
1051 if let Some(timeout_duration) = drain_timeout {
1052 drop(sink_proxy_tx);
1055 let drain_result = tokio::time::timeout(timeout_duration, async {
1056 while let Some(result) = response_send_tasks.join_next().await {
1059 if let Err(error) = result {
1060 tracing::error!(%error, "response send task failed during drain");
1061 }
1062 }
1063 while let Some(m) = sink_proxy_rx.recv().await {
1066 if let Err(error) = transport.send(m).await {
1067 tracing::error!(%error, "failed to send pending response during drain");
1068 break;
1069 }
1070 }
1071 })
1072 .await;
1073 if drain_result.is_err() {
1074 tracing::warn!("timed out draining in-flight responses");
1075 }
1076 }
1077
1078 let sink_close_result = transport.close().await;
1079 if let Err(e) = sink_close_result {
1080 tracing::error!(%e, "fail to close sink");
1081 }
1082 tracing::info!(?quit_reason, "serve finished");
1083 quit_reason
1084 }.instrument(current_span));
1085 RunningService {
1086 service,
1087 peer: peer_return,
1088 handle: Some(handle),
1089 cancellation_token: ct.clone(),
1090 dg: ct.drop_guard(),
1091 }
1092}