Skip to main content

rmcp_soddygo/service/
client.rs

1use std::borrow::Cow;
2
3use thiserror::Error;
4
5use super::*;
6use crate::{
7    model::{
8        ArgumentInfo, CallToolRequest, CallToolRequestParams, CallToolResult,
9        CancelledNotification, CancelledNotificationParam, ClientInfo, ClientJsonRpcMessage,
10        ClientNotification, ClientRequest, ClientResult, CompleteRequest, CompleteRequestParams,
11        CompleteResult, CompletionContext, CompletionInfo, ErrorData, GetPromptRequest,
12        GetPromptRequestParams, GetPromptResult, InitializeRequest, InitializedNotification,
13        JsonRpcResponse, ListPromptsRequest, ListPromptsResult, ListResourceTemplatesRequest,
14        ListResourceTemplatesResult, ListResourcesRequest, ListResourcesResult, ListToolsRequest,
15        ListToolsResult, PaginatedRequestParams, ProgressNotification, ProgressNotificationParam,
16        ReadResourceRequest, ReadResourceRequestParams, ReadResourceResult, Reference, RequestId,
17        RootsListChangedNotification, ServerInfo, ServerJsonRpcMessage, ServerNotification,
18        ServerRequest, ServerResult, SetLevelRequest, SetLevelRequestParams, SubscribeRequest,
19        SubscribeRequestParams, UnsubscribeRequest, UnsubscribeRequestParams,
20    },
21    transport::DynamicTransportError,
22};
23
24/// It represents the error that may occur when serving the client.
25///
26/// if you want to handle the error, you can use `serve_client_with_ct` or `serve_client` with `Result<RunningService<RoleClient, S>, ClientError>`
27#[derive(Error, Debug)]
28#[non_exhaustive]
29pub enum ClientInitializeError {
30    #[error("expect initialized response, but received: {0:?}")]
31    ExpectedInitResponse(Option<ServerJsonRpcMessage>),
32
33    #[error("expect initialized result, but received: {0:?}")]
34    ExpectedInitResult(Option<ServerResult>),
35
36    #[error("conflict initialized response id: expected {0}, got {1}")]
37    ConflictInitResponseId(RequestId, RequestId),
38
39    #[error("connection closed: {0}")]
40    ConnectionClosed(String),
41
42    #[error("Send message error {error}, when {context}")]
43    TransportError {
44        error: DynamicTransportError,
45        context: Cow<'static, str>,
46    },
47
48    #[error("JSON-RPC error: {0}")]
49    JsonRpcError(ErrorData),
50
51    #[error("Cancelled")]
52    Cancelled,
53}
54
55impl ClientInitializeError {
56    pub fn transport<T: Transport<RoleClient> + 'static>(
57        error: T::Error,
58        context: impl Into<Cow<'static, str>>,
59    ) -> Self {
60        Self::TransportError {
61            error: DynamicTransportError::new::<T, _>(error),
62            context: context.into(),
63        }
64    }
65}
66
67/// Helper function to get the next message from the stream
68async fn expect_next_message<T>(
69    transport: &mut T,
70    context: &str,
71) -> Result<ServerJsonRpcMessage, ClientInitializeError>
72where
73    T: Transport<RoleClient>,
74{
75    transport
76        .receive()
77        .await
78        .ok_or_else(|| ClientInitializeError::ConnectionClosed(context.to_string()))
79}
80
81/// Helper function to expect a response from the stream
82async fn expect_response<T, S>(
83    transport: &mut T,
84    context: &str,
85    service: &S,
86    peer: Peer<RoleClient>,
87) -> Result<(ServerResult, RequestId), ClientInitializeError>
88where
89    T: Transport<RoleClient>,
90    S: Service<RoleClient>,
91{
92    loop {
93        let message = expect_next_message(transport, context).await?;
94        match message {
95            // Expected message to complete the initialization
96            ServerJsonRpcMessage::Response(JsonRpcResponse { id, result, .. }) => {
97                break Ok((result, id));
98            }
99            // Handle JSON-RPC error responses
100            ServerJsonRpcMessage::Error(error) => {
101                break Err(ClientInitializeError::JsonRpcError(error.error));
102            }
103            // Server could send logging messages before handshake
104            ServerJsonRpcMessage::Notification(mut notification) => {
105                let ServerNotification::LoggingMessageNotification(logging) =
106                    &mut notification.notification
107                else {
108                    tracing::warn!(?notification, "Received unexpected message");
109                    continue;
110                };
111
112                let mut context = NotificationContext {
113                    peer: peer.clone(),
114                    meta: Meta::default(),
115                    extensions: Extensions::default(),
116                };
117
118                if let Some(meta) = logging.extensions.get_mut::<Meta>() {
119                    std::mem::swap(&mut context.meta, meta);
120                }
121                std::mem::swap(&mut context.extensions, &mut logging.extensions);
122
123                if let Err(error) = service
124                    .handle_notification(notification.notification, context)
125                    .await
126                {
127                    tracing::warn!(?error, "Handle logging before handshake failed.");
128                }
129            }
130            // Server could send pings before handshake
131            ServerJsonRpcMessage::Request(ref request)
132                if matches!(request.request, ServerRequest::PingRequest(_)) =>
133            {
134                tracing::trace!("Received ping request. Ignored.")
135            }
136            // Server SHOULD NOT send any other messages before handshake. We ignore them anyway
137            _ => tracing::warn!(?message, "Received unexpected message"),
138        }
139    }
140}
141
142#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
143pub struct RoleClient;
144
145impl ServiceRole for RoleClient {
146    type Req = ClientRequest;
147    type Resp = ClientResult;
148    type Not = ClientNotification;
149    type PeerReq = ServerRequest;
150    type PeerResp = ServerResult;
151    type PeerNot = ServerNotification;
152    type Info = ClientInfo;
153    type PeerInfo = ServerInfo;
154    type InitializeError = ClientInitializeError;
155    const IS_CLIENT: bool = true;
156}
157
158pub type ServerSink = Peer<RoleClient>;
159
160impl<S: Service<RoleClient>> ServiceExt<RoleClient> for S {
161    fn serve_with_ct<T, E, A>(
162        self,
163        transport: T,
164        ct: CancellationToken,
165    ) -> impl Future<Output = Result<RunningService<RoleClient, Self>, ClientInitializeError>> + Send
166    where
167        T: IntoTransport<RoleClient, E, A>,
168        E: std::error::Error + Send + Sync + 'static,
169        Self: Sized,
170    {
171        serve_client_with_ct(self, transport, ct)
172    }
173}
174
175pub async fn serve_client<S, T, E, A>(
176    service: S,
177    transport: T,
178) -> Result<RunningService<RoleClient, S>, ClientInitializeError>
179where
180    S: Service<RoleClient>,
181    T: IntoTransport<RoleClient, E, A>,
182    E: std::error::Error + Send + Sync + 'static,
183{
184    serve_client_with_ct(service, transport, Default::default()).await
185}
186
187pub async fn serve_client_with_ct<S, T, E, A>(
188    service: S,
189    transport: T,
190    ct: CancellationToken,
191) -> Result<RunningService<RoleClient, S>, ClientInitializeError>
192where
193    S: Service<RoleClient>,
194    T: IntoTransport<RoleClient, E, A>,
195    E: std::error::Error + Send + Sync + 'static,
196{
197    tokio::select! {
198        result = serve_client_with_ct_inner(service, transport.into_transport(), ct.clone()) => { result }
199        _ = ct.cancelled() => {
200            Err(ClientInitializeError::Cancelled)
201        }
202    }
203}
204
205async fn serve_client_with_ct_inner<S, T>(
206    service: S,
207    transport: T,
208    ct: CancellationToken,
209) -> Result<RunningService<RoleClient, S>, ClientInitializeError>
210where
211    S: Service<RoleClient>,
212    T: Transport<RoleClient> + 'static,
213{
214    let mut transport = transport.into_transport();
215    let id_provider = <Arc<AtomicU32RequestIdProvider>>::default();
216
217    // service
218    let id = id_provider.next_request_id();
219    let init_request = InitializeRequest {
220        method: Default::default(),
221        params: service.get_info(),
222        extensions: Default::default(),
223    };
224    transport
225        .send(ClientJsonRpcMessage::request(
226            ClientRequest::InitializeRequest(init_request),
227            id.clone(),
228        ))
229        .await
230        .map_err(|error| ClientInitializeError::TransportError {
231            error: DynamicTransportError::new::<T, _>(error),
232            context: "send initialize request".into(),
233        })?;
234
235    let (peer, peer_rx) = Peer::new(id_provider, None);
236
237    let (response, response_id) = expect_response(
238        &mut transport,
239        "initialize response",
240        &service,
241        peer.clone(),
242    )
243    .await?;
244
245    if id != response_id {
246        return Err(ClientInitializeError::ConflictInitResponseId(
247            id,
248            response_id,
249        ));
250    }
251
252    let ServerResult::InitializeResult(initialize_result) = response else {
253        return Err(ClientInitializeError::ExpectedInitResult(Some(response)));
254    };
255    peer.set_peer_info(initialize_result);
256
257    // send notification
258    let notification = ClientJsonRpcMessage::notification(
259        ClientNotification::InitializedNotification(InitializedNotification {
260            method: Default::default(),
261            extensions: Default::default(),
262        }),
263    );
264    transport.send(notification).await.map_err(|error| {
265        ClientInitializeError::transport::<T>(error, "send initialized notification")
266    })?;
267    Ok(serve_inner(service, transport, peer, peer_rx, ct))
268}
269
270macro_rules! method {
271    (peer_req $method:ident $Req:ident() => $Resp: ident ) => {
272        pub async fn $method(&self) -> Result<$Resp, ServiceError> {
273            let result = self
274                .send_request(ClientRequest::$Req($Req {
275                    method: Default::default(),
276                }))
277                .await?;
278            match result {
279                ServerResult::$Resp(result) => Ok(result),
280                _ => Err(ServiceError::UnexpectedResponse),
281            }
282        }
283    };
284    (peer_req $method:ident $Req:ident($Param: ident) => $Resp: ident ) => {
285        pub async fn $method(&self, params: $Param) -> Result<$Resp, ServiceError> {
286            let result = self
287                .send_request(ClientRequest::$Req($Req {
288                    method: Default::default(),
289                    params,
290                    extensions: Default::default(),
291                }))
292                .await?;
293            match result {
294                ServerResult::$Resp(result) => Ok(result),
295                _ => Err(ServiceError::UnexpectedResponse),
296            }
297        }
298    };
299    (peer_req $method:ident $Req:ident($Param: ident)? => $Resp: ident ) => {
300        pub async fn $method(&self, params: Option<$Param>) -> Result<$Resp, ServiceError> {
301            let result = self
302                .send_request(ClientRequest::$Req($Req {
303                    method: Default::default(),
304                    params,
305                    extensions: Default::default(),
306                }))
307                .await?;
308            match result {
309                ServerResult::$Resp(result) => Ok(result),
310                _ => Err(ServiceError::UnexpectedResponse),
311            }
312        }
313    };
314    (peer_req $method:ident $Req:ident($Param: ident)) => {
315        pub async fn $method(&self, params: $Param) -> Result<(), ServiceError> {
316            let result = self
317                .send_request(ClientRequest::$Req($Req {
318                    method: Default::default(),
319                    params,
320                    extensions: Default::default(),
321                }))
322                .await?;
323            match result {
324                ServerResult::EmptyResult(_) => Ok(()),
325                _ => Err(ServiceError::UnexpectedResponse),
326            }
327        }
328    };
329
330    (peer_not $method:ident $Not:ident($Param: ident)) => {
331        pub async fn $method(&self, params: $Param) -> Result<(), ServiceError> {
332            self.send_notification(ClientNotification::$Not($Not {
333                method: Default::default(),
334                params,
335                extensions: Default::default(),
336            }))
337            .await?;
338            Ok(())
339        }
340    };
341    (peer_not $method:ident $Not:ident) => {
342        pub async fn $method(&self) -> Result<(), ServiceError> {
343            self.send_notification(ClientNotification::$Not($Not {
344                method: Default::default(),
345                extensions: Default::default(),
346            }))
347            .await?;
348            Ok(())
349        }
350    };
351}
352
353impl Peer<RoleClient> {
354    method!(peer_req complete CompleteRequest(CompleteRequestParams) => CompleteResult);
355    method!(peer_req set_level SetLevelRequest(SetLevelRequestParams));
356    method!(peer_req get_prompt GetPromptRequest(GetPromptRequestParams) => GetPromptResult);
357    method!(peer_req list_prompts ListPromptsRequest(PaginatedRequestParams)? => ListPromptsResult);
358    method!(peer_req list_resources ListResourcesRequest(PaginatedRequestParams)? => ListResourcesResult);
359    method!(peer_req list_resource_templates ListResourceTemplatesRequest(PaginatedRequestParams)? => ListResourceTemplatesResult);
360    method!(peer_req read_resource ReadResourceRequest(ReadResourceRequestParams) => ReadResourceResult);
361    method!(peer_req subscribe SubscribeRequest(SubscribeRequestParams) );
362    method!(peer_req unsubscribe UnsubscribeRequest(UnsubscribeRequestParams));
363    method!(peer_req call_tool CallToolRequest(CallToolRequestParams) => CallToolResult);
364    method!(peer_req list_tools ListToolsRequest(PaginatedRequestParams)? => ListToolsResult);
365
366    method!(peer_not notify_cancelled CancelledNotification(CancelledNotificationParam));
367    method!(peer_not notify_progress ProgressNotification(ProgressNotificationParam));
368    method!(peer_not notify_initialized InitializedNotification);
369    method!(peer_not notify_roots_list_changed RootsListChangedNotification);
370}
371
372impl Peer<RoleClient> {
373    /// A wrapper method for [`Peer<RoleClient>::list_tools`].
374    ///
375    /// This function will call [`Peer<RoleClient>::list_tools`] multiple times until all tools are listed.
376    pub async fn list_all_tools(&self) -> Result<Vec<crate::model::Tool>, ServiceError> {
377        let mut tools = Vec::new();
378        let mut cursor = None;
379        loop {
380            let result = self
381                .list_tools(Some(PaginatedRequestParams { meta: None, cursor }))
382                .await?;
383            tools.extend(result.tools);
384            cursor = result.next_cursor;
385            if cursor.is_none() {
386                break;
387            }
388        }
389        Ok(tools)
390    }
391
392    /// A wrapper method for [`Peer<RoleClient>::list_prompts`].
393    ///
394    /// This function will call [`Peer<RoleClient>::list_prompts`] multiple times until all prompts are listed.
395    pub async fn list_all_prompts(&self) -> Result<Vec<crate::model::Prompt>, ServiceError> {
396        let mut prompts = Vec::new();
397        let mut cursor = None;
398        loop {
399            let result = self
400                .list_prompts(Some(PaginatedRequestParams { meta: None, cursor }))
401                .await?;
402            prompts.extend(result.prompts);
403            cursor = result.next_cursor;
404            if cursor.is_none() {
405                break;
406            }
407        }
408        Ok(prompts)
409    }
410
411    /// A wrapper method for [`Peer<RoleClient>::list_resources`].
412    ///
413    /// This function will call [`Peer<RoleClient>::list_resources`] multiple times until all resources are listed.
414    pub async fn list_all_resources(&self) -> Result<Vec<crate::model::Resource>, ServiceError> {
415        let mut resources = Vec::new();
416        let mut cursor = None;
417        loop {
418            let result = self
419                .list_resources(Some(PaginatedRequestParams { meta: None, cursor }))
420                .await?;
421            resources.extend(result.resources);
422            cursor = result.next_cursor;
423            if cursor.is_none() {
424                break;
425            }
426        }
427        Ok(resources)
428    }
429
430    /// A wrapper method for [`Peer<RoleClient>::list_resource_templates`].
431    ///
432    /// This function will call [`Peer<RoleClient>::list_resource_templates`] multiple times until all resource templates are listed.
433    pub async fn list_all_resource_templates(
434        &self,
435    ) -> Result<Vec<crate::model::ResourceTemplate>, ServiceError> {
436        let mut resource_templates = Vec::new();
437        let mut cursor = None;
438        loop {
439            let result = self
440                .list_resource_templates(Some(PaginatedRequestParams { meta: None, cursor }))
441                .await?;
442            resource_templates.extend(result.resource_templates);
443            cursor = result.next_cursor;
444            if cursor.is_none() {
445                break;
446            }
447        }
448        Ok(resource_templates)
449    }
450
451    /// Convenient method to get completion suggestions for a prompt argument
452    ///
453    /// # Arguments
454    /// * `prompt_name` - Name of the prompt being completed
455    /// * `argument_name` - Name of the argument being completed  
456    /// * `current_value` - Current partial value of the argument
457    /// * `context` - Optional context with previously resolved arguments
458    ///
459    /// # Returns
460    /// CompletionInfo with suggestions for the specified prompt argument
461    pub async fn complete_prompt_argument(
462        &self,
463        prompt_name: impl Into<String>,
464        argument_name: impl Into<String>,
465        current_value: impl Into<String>,
466        context: Option<CompletionContext>,
467    ) -> Result<CompletionInfo, ServiceError> {
468        let request = CompleteRequestParams {
469            meta: None,
470            r#ref: Reference::for_prompt(prompt_name),
471            argument: ArgumentInfo {
472                name: argument_name.into(),
473                value: current_value.into(),
474            },
475            context,
476        };
477
478        let result = self.complete(request).await?;
479        Ok(result.completion)
480    }
481
482    /// Convenient method to get completion suggestions for a resource URI argument
483    ///
484    /// # Arguments
485    /// * `uri_template` - URI template pattern being completed
486    /// * `argument_name` - Name of the URI parameter being completed
487    /// * `current_value` - Current partial value of the parameter
488    /// * `context` - Optional context with previously resolved arguments
489    ///
490    /// # Returns
491    /// CompletionInfo with suggestions for the specified resource URI argument
492    pub async fn complete_resource_argument(
493        &self,
494        uri_template: impl Into<String>,
495        argument_name: impl Into<String>,
496        current_value: impl Into<String>,
497        context: Option<CompletionContext>,
498    ) -> Result<CompletionInfo, ServiceError> {
499        let request = CompleteRequestParams {
500            meta: None,
501            r#ref: Reference::for_resource(uri_template),
502            argument: ArgumentInfo {
503                name: argument_name.into(),
504                value: current_value.into(),
505            },
506            context,
507        };
508
509        let result = self.complete(request).await?;
510        Ok(result.completion)
511    }
512
513    /// Simple completion for a prompt argument without context
514    ///
515    /// This is a convenience wrapper around `complete_prompt_argument` for
516    /// simple completion scenarios that don't require context awareness.
517    pub async fn complete_prompt_simple(
518        &self,
519        prompt_name: impl Into<String>,
520        argument_name: impl Into<String>,
521        current_value: impl Into<String>,
522    ) -> Result<Vec<String>, ServiceError> {
523        let completion = self
524            .complete_prompt_argument(prompt_name, argument_name, current_value, None)
525            .await?;
526        Ok(completion.values)
527    }
528
529    /// Simple completion for a resource URI argument without context
530    ///
531    /// This is a convenience wrapper around `complete_resource_argument` for
532    /// simple completion scenarios that don't require context awareness.
533    pub async fn complete_resource_simple(
534        &self,
535        uri_template: impl Into<String>,
536        argument_name: impl Into<String>,
537        current_value: impl Into<String>,
538    ) -> Result<Vec<String>, ServiceError> {
539        let completion = self
540            .complete_resource_argument(uri_template, argument_name, current_value, None)
541            .await?;
542        Ok(completion.values)
543    }
544}