Skip to main content

rmcp_soddygo/service/
client.rs

1use std::borrow::Cow;
2
3use thiserror::Error;
4
5use super::*;
6use crate::{
7    model::{
8        ArgumentInfo, CallToolRequest, CallToolRequestParams, CallToolResult,
9        CancelledNotification, CancelledNotificationParam, ClientInfo, ClientJsonRpcMessage,
10        ClientNotification, ClientRequest, ClientResult, CompleteRequest, CompleteRequestParams,
11        CompleteResult, CompletionContext, CompletionInfo, ErrorData, GetPromptRequest,
12        GetPromptRequestParams, GetPromptResult, InitializeRequest, InitializedNotification,
13        JsonRpcResponse, ListPromptsRequest, ListPromptsResult, ListResourceTemplatesRequest,
14        ListResourceTemplatesResult, ListResourcesRequest, ListResourcesResult, ListToolsRequest,
15        ListToolsResult, PaginatedRequestParams, ProgressNotification, ProgressNotificationParam,
16        ReadResourceRequest, ReadResourceRequestParams, ReadResourceResult, Reference, RequestId,
17        RootsListChangedNotification, ServerInfo, ServerJsonRpcMessage, ServerNotification,
18        ServerRequest, ServerResult, SetLevelRequest, SetLevelRequestParams, SubscribeRequest,
19        SubscribeRequestParams, UnsubscribeRequest, UnsubscribeRequestParams,
20    },
21    transport::DynamicTransportError,
22};
23
24/// It represents the error that may occur when serving the client.
25///
26/// if you want to handle the error, you can use `serve_client_with_ct` or `serve_client` with `Result<RunningService<RoleClient, S>, ClientError>`
27#[derive(Error, Debug)]
28#[non_exhaustive]
29pub enum ClientInitializeError {
30    #[error("expect initialized response, but received: {0:?}")]
31    ExpectedInitResponse(Option<ServerJsonRpcMessage>),
32
33    #[error("expect initialized result, but received: {0:?}")]
34    ExpectedInitResult(Option<ServerResult>),
35
36    #[error("conflict initialized response id: expected {0}, got {1}")]
37    ConflictInitResponseId(RequestId, RequestId),
38
39    #[error("connection closed: {0}")]
40    ConnectionClosed(String),
41
42    #[error("Send message error {error}, when {context}")]
43    TransportError {
44        error: DynamicTransportError,
45        context: Cow<'static, str>,
46    },
47
48    #[error("JSON-RPC error: {0}")]
49    JsonRpcError(ErrorData),
50
51    #[error("Cancelled")]
52    Cancelled,
53}
54
55impl ClientInitializeError {
56    pub fn transport<T: Transport<RoleClient> + 'static>(
57        error: T::Error,
58        context: impl Into<Cow<'static, str>>,
59    ) -> Self {
60        Self::TransportError {
61            error: DynamicTransportError::new::<T, _>(error),
62            context: context.into(),
63        }
64    }
65}
66
67/// Helper function to get the next message from the stream
68async fn expect_next_message<T>(
69    transport: &mut T,
70    context: &str,
71) -> Result<ServerJsonRpcMessage, ClientInitializeError>
72where
73    T: Transport<RoleClient>,
74{
75    transport
76        .receive()
77        .await
78        .ok_or_else(|| ClientInitializeError::ConnectionClosed(context.to_string()))
79}
80
81/// Helper function to expect a response from the stream
82async fn expect_response<T, S>(
83    transport: &mut T,
84    context: &str,
85    service: &S,
86    peer: Peer<RoleClient>,
87) -> Result<(ServerResult, RequestId), ClientInitializeError>
88where
89    T: Transport<RoleClient>,
90    S: Service<RoleClient>,
91{
92    loop {
93        let message = expect_next_message(transport, context).await?;
94        match message {
95            // Expected message to complete the initialization
96            ServerJsonRpcMessage::Response(JsonRpcResponse { id, result, .. }) => {
97                break Ok((result, id));
98            }
99            // Handle JSON-RPC error responses
100            ServerJsonRpcMessage::Error(error) => {
101                break Err(ClientInitializeError::JsonRpcError(error.error));
102            }
103            // Server could send logging messages before handshake
104            ServerJsonRpcMessage::Notification(mut notification) => {
105                let ServerNotification::LoggingMessageNotification(logging) =
106                    &mut notification.notification
107                else {
108                    tracing::warn!(?notification, "Received unexpected message");
109                    continue;
110                };
111
112                let mut context = NotificationContext {
113                    peer: peer.clone(),
114                    meta: Meta::default(),
115                    extensions: Extensions::default(),
116                };
117
118                if let Some(meta) = logging.extensions.get_mut::<Meta>() {
119                    std::mem::swap(&mut context.meta, meta);
120                }
121                std::mem::swap(&mut context.extensions, &mut logging.extensions);
122
123                if let Err(error) = service
124                    .handle_notification(notification.notification, context)
125                    .await
126                {
127                    tracing::warn!(?error, "Handle logging before handshake failed.");
128                }
129            }
130            // Server could send pings before handshake
131            ServerJsonRpcMessage::Request(ref request)
132                if matches!(request.request, ServerRequest::PingRequest(_)) =>
133            {
134                tracing::trace!("Received ping request. Ignored.")
135            }
136            // Server SHOULD NOT send any other messages before handshake. We ignore them anyway
137            _ => tracing::warn!(?message, "Received unexpected message"),
138        }
139    }
140}
141
142#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
143#[expect(clippy::exhaustive_structs, reason = "intentionally exhaustive")]
144pub struct RoleClient;
145
146impl ServiceRole for RoleClient {
147    type Req = ClientRequest;
148    type Resp = ClientResult;
149    type Not = ClientNotification;
150    type PeerReq = ServerRequest;
151    type PeerResp = ServerResult;
152    type PeerNot = ServerNotification;
153    type Info = ClientInfo;
154    type PeerInfo = ServerInfo;
155    type InitializeError = ClientInitializeError;
156    const IS_CLIENT: bool = true;
157}
158
159pub type ServerSink = Peer<RoleClient>;
160
// Blanket impl: any `Service<RoleClient>` can be started as a client through
// `ServiceExt::serve_with_ct`.
impl<S: Service<RoleClient>> ServiceExt<RoleClient> for S {
    /// Serves `self` as a client over `transport`, cancellable via `ct`.
    ///
    /// Delegates to `serve_client_with_ct`, which runs the initialize handshake before
    /// returning the running service.
    fn serve_with_ct<T, E, A>(
        self,
        transport: T,
        ct: CancellationToken,
    ) -> impl Future<Output = Result<RunningService<RoleClient, Self>, ClientInitializeError>>
    + MaybeSendFuture
    where
        T: IntoTransport<RoleClient, E, A>,
        E: std::error::Error + Send + Sync + 'static,
        Self: Sized,
    {
        serve_client_with_ct(self, transport, ct)
    }
}
176
177pub async fn serve_client<S, T, E, A>(
178    service: S,
179    transport: T,
180) -> Result<RunningService<RoleClient, S>, ClientInitializeError>
181where
182    S: Service<RoleClient>,
183    T: IntoTransport<RoleClient, E, A>,
184    E: std::error::Error + Send + Sync + 'static,
185{
186    serve_client_with_ct(service, transport, Default::default()).await
187}
188
189pub async fn serve_client_with_ct<S, T, E, A>(
190    service: S,
191    transport: T,
192    ct: CancellationToken,
193) -> Result<RunningService<RoleClient, S>, ClientInitializeError>
194where
195    S: Service<RoleClient>,
196    T: IntoTransport<RoleClient, E, A>,
197    E: std::error::Error + Send + Sync + 'static,
198{
199    tokio::select! {
200        result = serve_client_with_ct_inner(service, transport.into_transport(), ct.clone()) => { result }
201        _ = ct.cancelled() => {
202            Err(ClientInitializeError::Cancelled)
203        }
204    }
205}
206
207async fn serve_client_with_ct_inner<S, T>(
208    service: S,
209    transport: T,
210    ct: CancellationToken,
211) -> Result<RunningService<RoleClient, S>, ClientInitializeError>
212where
213    S: Service<RoleClient>,
214    T: Transport<RoleClient> + 'static,
215{
216    let mut transport = transport.into_transport();
217    let id_provider = <Arc<AtomicU32RequestIdProvider>>::default();
218
219    // service
220    let id = id_provider.next_request_id();
221    let init_request = InitializeRequest {
222        method: Default::default(),
223        params: service.get_info(),
224        extensions: Default::default(),
225    };
226    transport
227        .send(ClientJsonRpcMessage::request(
228            ClientRequest::InitializeRequest(init_request),
229            id.clone(),
230        ))
231        .await
232        .map_err(|error| ClientInitializeError::TransportError {
233            error: DynamicTransportError::new::<T, _>(error),
234            context: "send initialize request".into(),
235        })?;
236
237    let (peer, peer_rx) = Peer::new(id_provider, None);
238
239    let (response, response_id) = expect_response(
240        &mut transport,
241        "initialize response",
242        &service,
243        peer.clone(),
244    )
245    .await?;
246
247    if id != response_id {
248        return Err(ClientInitializeError::ConflictInitResponseId(
249            id,
250            response_id,
251        ));
252    }
253
254    let ServerResult::InitializeResult(initialize_result) = response else {
255        return Err(ClientInitializeError::ExpectedInitResult(Some(response)));
256    };
257    peer.set_peer_info(initialize_result);
258
259    // send notification
260    let notification = ClientJsonRpcMessage::notification(
261        ClientNotification::InitializedNotification(InitializedNotification {
262            method: Default::default(),
263            extensions: Default::default(),
264        }),
265    );
266    transport.send(notification).await.map_err(|error| {
267        ClientInitializeError::transport::<T>(error, "send initialized notification")
268    })?;
269    Ok(serve_inner(service, transport, peer, peer_rx, ct))
270}
271
// Generates the typed helper methods on `Peer<RoleClient>`. Supported forms:
//
// - `peer_req name Req() => Resp`       request without params; returns `Resp`.
// - `peer_req name Req(Param) => Resp`  request with required params; returns `Resp`.
// - `peer_req name Req(Param)? => Resp` request with optional params; returns `Resp`.
// - `peer_req name Req(Param)`          request whose reply must be an empty result.
// - `peer_not name Not(Param)`          notification with params.
// - `peer_not name Not`                 notification without params.
//
// Request helpers return `ServiceError::UnexpectedResponse` when the server replies with a
// different `ServerResult` variant than the one the helper expects. Every generated method
// carries rustdoc built from the request/notification and result type names.
macro_rules! method {
    (peer_req $method:ident $Req:ident() => $Resp: ident ) => {
        #[doc = concat!("Sends a `", stringify!($Req), "` to the server and returns its `", stringify!($Resp), "`.")]
        #[doc = ""]
        #[doc = "# Errors"]
        #[doc = ""]
        #[doc = concat!("Propagates any error from `send_request`; returns `ServiceError::UnexpectedResponse` if the reply is not a `", stringify!($Resp), "`.")]
        pub async fn $method(&self) -> Result<$Resp, ServiceError> {
            let result = self
                .send_request(ClientRequest::$Req($Req {
                    method: Default::default(),
                    // Request structs carry `extensions`, as in every other arm.
                    extensions: Default::default(),
                }))
                .await?;
            match result {
                ServerResult::$Resp(result) => Ok(result),
                _ => Err(ServiceError::UnexpectedResponse),
            }
        }
    };
    (peer_req $method:ident $Req:ident($Param: ident) => $Resp: ident ) => {
        #[doc = concat!("Sends a `", stringify!($Req), "` with the given `", stringify!($Param), "` and returns the server's `", stringify!($Resp), "`.")]
        #[doc = ""]
        #[doc = "# Errors"]
        #[doc = ""]
        #[doc = concat!("Propagates any error from `send_request`; returns `ServiceError::UnexpectedResponse` if the reply is not a `", stringify!($Resp), "`.")]
        pub async fn $method(&self, params: $Param) -> Result<$Resp, ServiceError> {
            let result = self
                .send_request(ClientRequest::$Req($Req {
                    method: Default::default(),
                    params,
                    extensions: Default::default(),
                }))
                .await?;
            match result {
                ServerResult::$Resp(result) => Ok(result),
                _ => Err(ServiceError::UnexpectedResponse),
            }
        }
    };
    (peer_req $method:ident $Req:ident($Param: ident)? => $Resp: ident ) => {
        #[doc = concat!("Sends a `", stringify!($Req), "` with optional `", stringify!($Param), "` and returns the server's `", stringify!($Resp), "`.")]
        #[doc = ""]
        #[doc = "# Errors"]
        #[doc = ""]
        #[doc = concat!("Propagates any error from `send_request`; returns `ServiceError::UnexpectedResponse` if the reply is not a `", stringify!($Resp), "`.")]
        pub async fn $method(&self, params: Option<$Param>) -> Result<$Resp, ServiceError> {
            let result = self
                .send_request(ClientRequest::$Req($Req {
                    method: Default::default(),
                    params,
                    extensions: Default::default(),
                }))
                .await?;
            match result {
                ServerResult::$Resp(result) => Ok(result),
                _ => Err(ServiceError::UnexpectedResponse),
            }
        }
    };
    (peer_req $method:ident $Req:ident($Param: ident)) => {
        #[doc = concat!("Sends a `", stringify!($Req), "` with the given `", stringify!($Param), "` and expects an empty result.")]
        #[doc = ""]
        #[doc = "# Errors"]
        #[doc = ""]
        #[doc = "Propagates any error from `send_request`; returns `ServiceError::UnexpectedResponse` if the reply is not an empty result."]
        pub async fn $method(&self, params: $Param) -> Result<(), ServiceError> {
            let result = self
                .send_request(ClientRequest::$Req($Req {
                    method: Default::default(),
                    params,
                    extensions: Default::default(),
                }))
                .await?;
            match result {
                ServerResult::EmptyResult(_) => Ok(()),
                _ => Err(ServiceError::UnexpectedResponse),
            }
        }
    };

    (peer_not $method:ident $Not:ident($Param: ident)) => {
        #[doc = concat!("Sends a `", stringify!($Not), "` carrying the given `", stringify!($Param), "` to the server.")]
        #[doc = ""]
        #[doc = "# Errors"]
        #[doc = ""]
        #[doc = "Propagates any error from `send_notification`."]
        pub async fn $method(&self, params: $Param) -> Result<(), ServiceError> {
            self.send_notification(ClientNotification::$Not($Not {
                method: Default::default(),
                params,
                extensions: Default::default(),
            }))
            .await?;
            Ok(())
        }
    };
    (peer_not $method:ident $Not:ident) => {
        #[doc = concat!("Sends a `", stringify!($Not), "` to the server.")]
        #[doc = ""]
        #[doc = "# Errors"]
        #[doc = ""]
        #[doc = "Propagates any error from `send_notification`."]
        pub async fn $method(&self) -> Result<(), ServiceError> {
            self.send_notification(ClientNotification::$Not($Not {
                method: Default::default(),
                extensions: Default::default(),
            }))
            .await?;
            Ok(())
        }
    };
}
354
355impl Peer<RoleClient> {
356    method!(peer_req complete CompleteRequest(CompleteRequestParams) => CompleteResult);
357    method!(peer_req set_level SetLevelRequest(SetLevelRequestParams));
358    method!(peer_req get_prompt GetPromptRequest(GetPromptRequestParams) => GetPromptResult);
359    method!(peer_req list_prompts ListPromptsRequest(PaginatedRequestParams)? => ListPromptsResult);
360    method!(peer_req list_resources ListResourcesRequest(PaginatedRequestParams)? => ListResourcesResult);
361    method!(peer_req list_resource_templates ListResourceTemplatesRequest(PaginatedRequestParams)? => ListResourceTemplatesResult);
362    method!(peer_req read_resource ReadResourceRequest(ReadResourceRequestParams) => ReadResourceResult);
363    method!(peer_req subscribe SubscribeRequest(SubscribeRequestParams) );
364    method!(peer_req unsubscribe UnsubscribeRequest(UnsubscribeRequestParams));
365    method!(peer_req call_tool CallToolRequest(CallToolRequestParams) => CallToolResult);
366    method!(peer_req list_tools ListToolsRequest(PaginatedRequestParams)? => ListToolsResult);
367
368    method!(peer_not notify_cancelled CancelledNotification(CancelledNotificationParam));
369    method!(peer_not notify_progress ProgressNotification(ProgressNotificationParam));
370    method!(peer_not notify_initialized InitializedNotification);
371    method!(peer_not notify_roots_list_changed RootsListChangedNotification);
372}
373
374impl Peer<RoleClient> {
375    /// A wrapper method for [`Peer<RoleClient>::list_tools`].
376    ///
377    /// This function will call [`Peer<RoleClient>::list_tools`] multiple times until all tools are listed.
378    pub async fn list_all_tools(&self) -> Result<Vec<crate::model::Tool>, ServiceError> {
379        let mut tools = Vec::new();
380        let mut cursor = None;
381        loop {
382            let result = self
383                .list_tools(Some(PaginatedRequestParams { meta: None, cursor }))
384                .await?;
385            tools.extend(result.tools);
386            cursor = result.next_cursor;
387            if cursor.is_none() {
388                break;
389            }
390        }
391        Ok(tools)
392    }
393
394    /// A wrapper method for [`Peer<RoleClient>::list_prompts`].
395    ///
396    /// This function will call [`Peer<RoleClient>::list_prompts`] multiple times until all prompts are listed.
397    pub async fn list_all_prompts(&self) -> Result<Vec<crate::model::Prompt>, ServiceError> {
398        let mut prompts = Vec::new();
399        let mut cursor = None;
400        loop {
401            let result = self
402                .list_prompts(Some(PaginatedRequestParams { meta: None, cursor }))
403                .await?;
404            prompts.extend(result.prompts);
405            cursor = result.next_cursor;
406            if cursor.is_none() {
407                break;
408            }
409        }
410        Ok(prompts)
411    }
412
413    /// A wrapper method for [`Peer<RoleClient>::list_resources`].
414    ///
415    /// This function will call [`Peer<RoleClient>::list_resources`] multiple times until all resources are listed.
416    pub async fn list_all_resources(&self) -> Result<Vec<crate::model::Resource>, ServiceError> {
417        let mut resources = Vec::new();
418        let mut cursor = None;
419        loop {
420            let result = self
421                .list_resources(Some(PaginatedRequestParams { meta: None, cursor }))
422                .await?;
423            resources.extend(result.resources);
424            cursor = result.next_cursor;
425            if cursor.is_none() {
426                break;
427            }
428        }
429        Ok(resources)
430    }
431
432    /// A wrapper method for [`Peer<RoleClient>::list_resource_templates`].
433    ///
434    /// This function will call [`Peer<RoleClient>::list_resource_templates`] multiple times until all resource templates are listed.
435    pub async fn list_all_resource_templates(
436        &self,
437    ) -> Result<Vec<crate::model::ResourceTemplate>, ServiceError> {
438        let mut resource_templates = Vec::new();
439        let mut cursor = None;
440        loop {
441            let result = self
442                .list_resource_templates(Some(PaginatedRequestParams { meta: None, cursor }))
443                .await?;
444            resource_templates.extend(result.resource_templates);
445            cursor = result.next_cursor;
446            if cursor.is_none() {
447                break;
448            }
449        }
450        Ok(resource_templates)
451    }
452
453    /// Convenient method to get completion suggestions for a prompt argument
454    ///
455    /// # Arguments
456    /// * `prompt_name` - Name of the prompt being completed
457    /// * `argument_name` - Name of the argument being completed  
458    /// * `current_value` - Current partial value of the argument
459    /// * `context` - Optional context with previously resolved arguments
460    ///
461    /// # Returns
462    /// CompletionInfo with suggestions for the specified prompt argument
463    pub async fn complete_prompt_argument(
464        &self,
465        prompt_name: impl Into<String>,
466        argument_name: impl Into<String>,
467        current_value: impl Into<String>,
468        context: Option<CompletionContext>,
469    ) -> Result<CompletionInfo, ServiceError> {
470        let request = CompleteRequestParams {
471            meta: None,
472            r#ref: Reference::for_prompt(prompt_name),
473            argument: ArgumentInfo {
474                name: argument_name.into(),
475                value: current_value.into(),
476            },
477            context,
478        };
479
480        let result = self.complete(request).await?;
481        Ok(result.completion)
482    }
483
484    /// Convenient method to get completion suggestions for a resource URI argument
485    ///
486    /// # Arguments
487    /// * `uri_template` - URI template pattern being completed
488    /// * `argument_name` - Name of the URI parameter being completed
489    /// * `current_value` - Current partial value of the parameter
490    /// * `context` - Optional context with previously resolved arguments
491    ///
492    /// # Returns
493    /// CompletionInfo with suggestions for the specified resource URI argument
494    pub async fn complete_resource_argument(
495        &self,
496        uri_template: impl Into<String>,
497        argument_name: impl Into<String>,
498        current_value: impl Into<String>,
499        context: Option<CompletionContext>,
500    ) -> Result<CompletionInfo, ServiceError> {
501        let request = CompleteRequestParams {
502            meta: None,
503            r#ref: Reference::for_resource(uri_template),
504            argument: ArgumentInfo {
505                name: argument_name.into(),
506                value: current_value.into(),
507            },
508            context,
509        };
510
511        let result = self.complete(request).await?;
512        Ok(result.completion)
513    }
514
515    /// Simple completion for a prompt argument without context
516    ///
517    /// This is a convenience wrapper around `complete_prompt_argument` for
518    /// simple completion scenarios that don't require context awareness.
519    pub async fn complete_prompt_simple(
520        &self,
521        prompt_name: impl Into<String>,
522        argument_name: impl Into<String>,
523        current_value: impl Into<String>,
524    ) -> Result<Vec<String>, ServiceError> {
525        let completion = self
526            .complete_prompt_argument(prompt_name, argument_name, current_value, None)
527            .await?;
528        Ok(completion.values)
529    }
530
531    /// Simple completion for a resource URI argument without context
532    ///
533    /// This is a convenience wrapper around `complete_resource_argument` for
534    /// simple completion scenarios that don't require context awareness.
535    pub async fn complete_resource_simple(
536        &self,
537        uri_template: impl Into<String>,
538        argument_name: impl Into<String>,
539        current_value: impl Into<String>,
540    ) -> Result<Vec<String>, ServiceError> {
541        let completion = self
542            .complete_resource_argument(uri_template, argument_name, current_value, None)
543            .await?;
544        Ok(completion.values)
545    }
546}