Skip to main content

mcpkit_rs/service/
client.rs

1use std::borrow::Cow;
2
3use thiserror::Error;
4
5use super::*;
6use crate::{
7    model::{
8        ArgumentInfo, CallToolRequest, CallToolRequestParams, CallToolResult,
9        CancelledNotification, CancelledNotificationParam, ClientInfo, ClientJsonRpcMessage,
10        ClientNotification, ClientRequest, ClientResult, CompleteRequest, CompleteRequestParams,
11        CompleteResult, CompletionContext, CompletionInfo, ErrorData, GetPromptRequest,
12        GetPromptRequestParams, GetPromptResult, InitializeRequest, InitializedNotification,
13        JsonRpcResponse, ListPromptsRequest, ListPromptsResult, ListResourceTemplatesRequest,
14        ListResourceTemplatesResult, ListResourcesRequest, ListResourcesResult, ListToolsRequest,
15        ListToolsResult, PaginatedRequestParams, ProgressNotification, ProgressNotificationParam,
16        ReadResourceRequest, ReadResourceRequestParams, ReadResourceResult, Reference, RequestId,
17        RootsListChangedNotification, ServerInfo, ServerJsonRpcMessage, ServerNotification,
18        ServerRequest, ServerResult, SetLevelRequest, SetLevelRequestParams, SubscribeRequest,
19        SubscribeRequestParams, UnsubscribeRequest, UnsubscribeRequestParams,
20    },
21    transport::DynamicTransportError,
22};
23
24/// It represents the error that may occur when serving the client.
25///
26/// if you want to handle the error, you can use `serve_client_with_ct` or `serve_client` with `Result<RunningService<RoleClient, S>, ClientError>`
27#[derive(Error, Debug)]
28pub enum ClientInitializeError {
29    #[error("expect initialized response, but received: {0:?}")]
30    ExpectedInitResponse(Option<ServerJsonRpcMessage>),
31
32    #[error("expect initialized result, but received: {0:?}")]
33    ExpectedInitResult(Option<ServerResult>),
34
35    #[error("conflict initialized response id: expected {0}, got {1}")]
36    ConflictInitResponseId(RequestId, RequestId),
37
38    #[error("connection closed: {0}")]
39    ConnectionClosed(String),
40
41    #[error("Send message error {error}, when {context}")]
42    TransportError {
43        error: DynamicTransportError,
44        context: Cow<'static, str>,
45    },
46
47    #[error("JSON-RPC error: {0}")]
48    JsonRpcError(ErrorData),
49
50    #[error("Cancelled")]
51    Cancelled,
52}
53
54impl ClientInitializeError {
55    pub fn transport<T: Transport<RoleClient> + 'static>(
56        error: T::Error,
57        context: impl Into<Cow<'static, str>>,
58    ) -> Self {
59        Self::TransportError {
60            error: DynamicTransportError::new::<T, _>(error),
61            context: context.into(),
62        }
63    }
64}
65
66/// Helper function to get the next message from the stream
67async fn expect_next_message<T>(
68    transport: &mut T,
69    context: &str,
70) -> Result<ServerJsonRpcMessage, ClientInitializeError>
71where
72    T: Transport<RoleClient>,
73{
74    transport
75        .receive()
76        .await
77        .ok_or_else(|| ClientInitializeError::ConnectionClosed(context.to_string()))
78}
79
80/// Helper function to expect a response from the stream
81async fn expect_response<T, S>(
82    transport: &mut T,
83    context: &str,
84    service: &S,
85    peer: Peer<RoleClient>,
86) -> Result<(ServerResult, RequestId), ClientInitializeError>
87where
88    T: Transport<RoleClient>,
89    S: Service<RoleClient>,
90{
91    loop {
92        let message = expect_next_message(transport, context).await?;
93        match message {
94            // Expected message to complete the initialization
95            ServerJsonRpcMessage::Response(JsonRpcResponse { id, result, .. }) => {
96                break Ok((result, id));
97            }
98            // Handle JSON-RPC error responses
99            ServerJsonRpcMessage::Error(error) => {
100                break Err(ClientInitializeError::JsonRpcError(error.error));
101            }
102            // Server could send logging messages before handshake
103            ServerJsonRpcMessage::Notification(mut notification) => {
104                let ServerNotification::LoggingMessageNotification(logging) =
105                    &mut notification.notification
106                else {
107                    tracing::warn!(?notification, "Received unexpected message");
108                    continue;
109                };
110
111                let mut context = NotificationContext {
112                    peer: peer.clone(),
113                    meta: Meta::default(),
114                    extensions: Extensions::default(),
115                };
116
117                if let Some(meta) = logging.extensions.get_mut::<Meta>() {
118                    std::mem::swap(&mut context.meta, meta);
119                }
120                std::mem::swap(&mut context.extensions, &mut logging.extensions);
121
122                if let Err(error) = service
123                    .handle_notification(notification.notification, context)
124                    .await
125                {
126                    tracing::warn!(?error, "Handle logging before handshake failed.");
127                }
128            }
129            // Server could send pings before handshake
130            ServerJsonRpcMessage::Request(ref request)
131                if matches!(request.request, ServerRequest::PingRequest(_)) =>
132            {
133                tracing::trace!("Received ping request. Ignored.")
134            }
135            // Server SHOULD NOT send any other messages before handshake. We ignore them anyway
136            _ => tracing::warn!(?message, "Received unexpected message"),
137        }
138    }
139}
140
141#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
142pub struct RoleClient;
143
144impl ServiceRole for RoleClient {
145    type Req = ClientRequest;
146    type Resp = ClientResult;
147    type Not = ClientNotification;
148    type PeerReq = ServerRequest;
149    type PeerResp = ServerResult;
150    type PeerNot = ServerNotification;
151    type Info = ClientInfo;
152    type PeerInfo = ServerInfo;
153    type InitializeError = ClientInitializeError;
154    const IS_CLIENT: bool = true;
155}
156
157pub type ServerSink = Peer<RoleClient>;
158
159impl<S: Service<RoleClient>> ServiceExt<RoleClient> for S {
160    fn serve_with_ct<T, E, A>(
161        self,
162        transport: T,
163        ct: CancellationToken,
164    ) -> impl Future<Output = Result<RunningService<RoleClient, Self>, ClientInitializeError>> + Send
165    where
166        T: IntoTransport<RoleClient, E, A>,
167        E: std::error::Error + Send + Sync + 'static,
168        Self: Sized,
169    {
170        serve_client_with_ct(self, transport, ct)
171    }
172}
173
174pub async fn serve_client<S, T, E, A>(
175    service: S,
176    transport: T,
177) -> Result<RunningService<RoleClient, S>, ClientInitializeError>
178where
179    S: Service<RoleClient>,
180    T: IntoTransport<RoleClient, E, A>,
181    E: std::error::Error + Send + Sync + 'static,
182{
183    serve_client_with_ct(service, transport, Default::default()).await
184}
185
186pub async fn serve_client_with_ct<S, T, E, A>(
187    service: S,
188    transport: T,
189    ct: CancellationToken,
190) -> Result<RunningService<RoleClient, S>, ClientInitializeError>
191where
192    S: Service<RoleClient>,
193    T: IntoTransport<RoleClient, E, A>,
194    E: std::error::Error + Send + Sync + 'static,
195{
196    tokio::select! {
197        result = serve_client_with_ct_inner(service, transport.into_transport(), ct.clone()) => { result }
198        _ = ct.cancelled() => {
199            Err(ClientInitializeError::Cancelled)
200        }
201    }
202}
203
204async fn serve_client_with_ct_inner<S, T>(
205    service: S,
206    transport: T,
207    ct: CancellationToken,
208) -> Result<RunningService<RoleClient, S>, ClientInitializeError>
209where
210    S: Service<RoleClient>,
211    T: Transport<RoleClient> + 'static,
212{
213    let mut transport = transport.into_transport();
214    let id_provider = <Arc<AtomicU32RequestIdProvider>>::default();
215
216    // service
217    let id = id_provider.next_request_id();
218    let init_request = InitializeRequest {
219        method: Default::default(),
220        params: service.get_info(),
221        extensions: Default::default(),
222    };
223    transport
224        .send(ClientJsonRpcMessage::request(
225            ClientRequest::InitializeRequest(init_request),
226            id.clone(),
227        ))
228        .await
229        .map_err(|error| ClientInitializeError::TransportError {
230            error: DynamicTransportError::new::<T, _>(error),
231            context: "send initialize request".into(),
232        })?;
233
234    let (peer, peer_rx) = Peer::new(id_provider, None);
235
236    let (response, response_id) = expect_response(
237        &mut transport,
238        "initialize response",
239        &service,
240        peer.clone(),
241    )
242    .await?;
243
244    if id != response_id {
245        return Err(ClientInitializeError::ConflictInitResponseId(
246            id,
247            response_id,
248        ));
249    }
250
251    let ServerResult::InitializeResult(initialize_result) = response else {
252        return Err(ClientInitializeError::ExpectedInitResult(Some(response)));
253    };
254    peer.set_peer_info(initialize_result);
255
256    // send notification
257    let notification = ClientJsonRpcMessage::notification(
258        ClientNotification::InitializedNotification(InitializedNotification {
259            method: Default::default(),
260            extensions: Default::default(),
261        }),
262    );
263    transport.send(notification).await.map_err(|error| {
264        ClientInitializeError::transport::<T>(error, "send initialized notification")
265    })?;
266    Ok(serve_inner(service, transport, peer, peer_rx, ct))
267}
268
/// Generates typed `Peer<RoleClient>` methods for client-to-server messages.
///
/// Accepted forms:
/// - `peer_req name Req() => Resp`: request without params, returns the `ServerResult::Resp`
///   payload.
/// - `peer_req name Req(Param) => Resp`: request with required params.
/// - `peer_req name Req(Param)? => Resp`: request with optional params (`Option<Param>`).
/// - `peer_req name Req(Param)`: request whose only accepted reply is `ServerResult::EmptyResult`.
/// - `peer_not name Not(Param)` and `peer_not name Not`: notifications, with or without params.
///
/// Any reply variant other than the expected one becomes `ServiceError::UnexpectedResponse`.
macro_rules! method {
    (peer_req $fn_name:ident $Req:ident() => $Resp: ident ) => {
        pub async fn $fn_name(&self) -> Result<$Resp, ServiceError> {
            let request = ClientRequest::$Req($Req {
                method: Default::default(),
            });
            let ServerResult::$Resp(response) = self.send_request(request).await? else {
                return Err(ServiceError::UnexpectedResponse);
            };
            Ok(response)
        }
    };
    (peer_req $fn_name:ident $Req:ident($Param: ident) => $Resp: ident ) => {
        pub async fn $fn_name(&self, params: $Param) -> Result<$Resp, ServiceError> {
            let request = ClientRequest::$Req($Req {
                method: Default::default(),
                params,
                extensions: Default::default(),
            });
            let ServerResult::$Resp(response) = self.send_request(request).await? else {
                return Err(ServiceError::UnexpectedResponse);
            };
            Ok(response)
        }
    };
    (peer_req $fn_name:ident $Req:ident($Param: ident)? => $Resp: ident ) => {
        pub async fn $fn_name(&self, params: Option<$Param>) -> Result<$Resp, ServiceError> {
            let request = ClientRequest::$Req($Req {
                method: Default::default(),
                params,
                extensions: Default::default(),
            });
            let ServerResult::$Resp(response) = self.send_request(request).await? else {
                return Err(ServiceError::UnexpectedResponse);
            };
            Ok(response)
        }
    };
    (peer_req $fn_name:ident $Req:ident($Param: ident)) => {
        pub async fn $fn_name(&self, params: $Param) -> Result<(), ServiceError> {
            let request = ClientRequest::$Req($Req {
                method: Default::default(),
                params,
                extensions: Default::default(),
            });
            let ServerResult::EmptyResult(_) = self.send_request(request).await? else {
                return Err(ServiceError::UnexpectedResponse);
            };
            Ok(())
        }
    };

    (peer_not $fn_name:ident $Not:ident($Param: ident)) => {
        pub async fn $fn_name(&self, params: $Param) -> Result<(), ServiceError> {
            let notification = ClientNotification::$Not($Not {
                method: Default::default(),
                params,
                extensions: Default::default(),
            });
            self.send_notification(notification).await?;
            Ok(())
        }
    };
    (peer_not $fn_name:ident $Not:ident) => {
        pub async fn $fn_name(&self) -> Result<(), ServiceError> {
            let notification = ClientNotification::$Not($Not {
                method: Default::default(),
                extensions: Default::default(),
            });
            self.send_notification(notification).await?;
            Ok(())
        }
    };
}
351
352impl Peer<RoleClient> {
353    method!(peer_req complete CompleteRequest(CompleteRequestParams) => CompleteResult);
354    method!(peer_req set_level SetLevelRequest(SetLevelRequestParams));
355    method!(peer_req get_prompt GetPromptRequest(GetPromptRequestParams) => GetPromptResult);
356    method!(peer_req list_prompts ListPromptsRequest(PaginatedRequestParams)? => ListPromptsResult);
357    method!(peer_req list_resources ListResourcesRequest(PaginatedRequestParams)? => ListResourcesResult);
358    method!(peer_req list_resource_templates ListResourceTemplatesRequest(PaginatedRequestParams)? => ListResourceTemplatesResult);
359    method!(peer_req read_resource ReadResourceRequest(ReadResourceRequestParams) => ReadResourceResult);
360    method!(peer_req subscribe SubscribeRequest(SubscribeRequestParams) );
361    method!(peer_req unsubscribe UnsubscribeRequest(UnsubscribeRequestParams));
362    method!(peer_req call_tool CallToolRequest(CallToolRequestParams) => CallToolResult);
363    method!(peer_req list_tools ListToolsRequest(PaginatedRequestParams)? => ListToolsResult);
364
365    method!(peer_not notify_cancelled CancelledNotification(CancelledNotificationParam));
366    method!(peer_not notify_progress ProgressNotification(ProgressNotificationParam));
367    method!(peer_not notify_initialized InitializedNotification);
368    method!(peer_not notify_roots_list_changed RootsListChangedNotification);
369}
370
371impl Peer<RoleClient> {
372    /// A wrapper method for [`Peer<RoleClient>::list_tools`].
373    ///
374    /// This function will call [`Peer<RoleClient>::list_tools`] multiple times until all tools are listed.
375    pub async fn list_all_tools(&self) -> Result<Vec<crate::model::Tool>, ServiceError> {
376        let mut tools = Vec::new();
377        let mut cursor = None;
378        loop {
379            let result = self
380                .list_tools(Some(PaginatedRequestParams { meta: None, cursor }))
381                .await?;
382            tools.extend(result.tools);
383            cursor = result.next_cursor;
384            if cursor.is_none() {
385                break;
386            }
387        }
388        Ok(tools)
389    }
390
391    /// A wrapper method for [`Peer<RoleClient>::list_prompts`].
392    ///
393    /// This function will call [`Peer<RoleClient>::list_prompts`] multiple times until all prompts are listed.
394    pub async fn list_all_prompts(&self) -> Result<Vec<crate::model::Prompt>, ServiceError> {
395        let mut prompts = Vec::new();
396        let mut cursor = None;
397        loop {
398            let result = self
399                .list_prompts(Some(PaginatedRequestParams { meta: None, cursor }))
400                .await?;
401            prompts.extend(result.prompts);
402            cursor = result.next_cursor;
403            if cursor.is_none() {
404                break;
405            }
406        }
407        Ok(prompts)
408    }
409
410    /// A wrapper method for [`Peer<RoleClient>::list_resources`].
411    ///
412    /// This function will call [`Peer<RoleClient>::list_resources`] multiple times until all resources are listed.
413    pub async fn list_all_resources(&self) -> Result<Vec<crate::model::Resource>, ServiceError> {
414        let mut resources = Vec::new();
415        let mut cursor = None;
416        loop {
417            let result = self
418                .list_resources(Some(PaginatedRequestParams { meta: None, cursor }))
419                .await?;
420            resources.extend(result.resources);
421            cursor = result.next_cursor;
422            if cursor.is_none() {
423                break;
424            }
425        }
426        Ok(resources)
427    }
428
429    /// A wrapper method for [`Peer<RoleClient>::list_resource_templates`].
430    ///
431    /// This function will call [`Peer<RoleClient>::list_resource_templates`] multiple times until all resource templates are listed.
432    pub async fn list_all_resource_templates(
433        &self,
434    ) -> Result<Vec<crate::model::ResourceTemplate>, ServiceError> {
435        let mut resource_templates = Vec::new();
436        let mut cursor = None;
437        loop {
438            let result = self
439                .list_resource_templates(Some(PaginatedRequestParams { meta: None, cursor }))
440                .await?;
441            resource_templates.extend(result.resource_templates);
442            cursor = result.next_cursor;
443            if cursor.is_none() {
444                break;
445            }
446        }
447        Ok(resource_templates)
448    }
449
450    /// Convenient method to get completion suggestions for a prompt argument
451    ///
452    /// # Arguments
453    /// * `prompt_name` - Name of the prompt being completed
454    /// * `argument_name` - Name of the argument being completed  
455    /// * `current_value` - Current partial value of the argument
456    /// * `context` - Optional context with previously resolved arguments
457    ///
458    /// # Returns
459    /// CompletionInfo with suggestions for the specified prompt argument
460    pub async fn complete_prompt_argument(
461        &self,
462        prompt_name: impl Into<String>,
463        argument_name: impl Into<String>,
464        current_value: impl Into<String>,
465        context: Option<CompletionContext>,
466    ) -> Result<CompletionInfo, ServiceError> {
467        let request = CompleteRequestParams {
468            meta: None,
469            r#ref: Reference::for_prompt(prompt_name),
470            argument: ArgumentInfo {
471                name: argument_name.into(),
472                value: current_value.into(),
473            },
474            context,
475        };
476
477        let result = self.complete(request).await?;
478        Ok(result.completion)
479    }
480
481    /// Convenient method to get completion suggestions for a resource URI argument
482    ///
483    /// # Arguments
484    /// * `uri_template` - URI template pattern being completed
485    /// * `argument_name` - Name of the URI parameter being completed
486    /// * `current_value` - Current partial value of the parameter
487    /// * `context` - Optional context with previously resolved arguments
488    ///
489    /// # Returns
490    /// CompletionInfo with suggestions for the specified resource URI argument
491    pub async fn complete_resource_argument(
492        &self,
493        uri_template: impl Into<String>,
494        argument_name: impl Into<String>,
495        current_value: impl Into<String>,
496        context: Option<CompletionContext>,
497    ) -> Result<CompletionInfo, ServiceError> {
498        let request = CompleteRequestParams {
499            meta: None,
500            r#ref: Reference::for_resource(uri_template),
501            argument: ArgumentInfo {
502                name: argument_name.into(),
503                value: current_value.into(),
504            },
505            context,
506        };
507
508        let result = self.complete(request).await?;
509        Ok(result.completion)
510    }
511
512    /// Simple completion for a prompt argument without context
513    ///
514    /// This is a convenience wrapper around `complete_prompt_argument` for
515    /// simple completion scenarios that don't require context awareness.
516    pub async fn complete_prompt_simple(
517        &self,
518        prompt_name: impl Into<String>,
519        argument_name: impl Into<String>,
520        current_value: impl Into<String>,
521    ) -> Result<Vec<String>, ServiceError> {
522        let completion = self
523            .complete_prompt_argument(prompt_name, argument_name, current_value, None)
524            .await?;
525        Ok(completion.values)
526    }
527
528    /// Simple completion for a resource URI argument without context
529    ///
530    /// This is a convenience wrapper around `complete_resource_argument` for
531    /// simple completion scenarios that don't require context awareness.
532    pub async fn complete_resource_simple(
533        &self,
534        uri_template: impl Into<String>,
535        argument_name: impl Into<String>,
536        current_value: impl Into<String>,
537    ) -> Result<Vec<String>, ServiceError> {
538        let completion = self
539            .complete_resource_argument(uri_template, argument_name, current_value, None)
540            .await?;
541        Ok(completion.values)
542    }
543}