agenterra_rmcp/service/client.rs

1use std::borrow::Cow;
2
3use thiserror::Error;
4
5use super::*;
6use crate::model::{
7    CallToolRequest, CallToolRequestParam, CallToolResult, CancelledNotification,
8    CancelledNotificationParam, ClientInfo, ClientJsonRpcMessage, ClientNotification,
9    ClientRequest, ClientResult, CompleteRequest, CompleteRequestParam, CompleteResult,
10    GetPromptRequest, GetPromptRequestParam, GetPromptResult, InitializeRequest,
11    InitializedNotification, JsonRpcResponse, ListPromptsRequest, ListPromptsResult,
12    ListResourceTemplatesRequest, ListResourceTemplatesResult, ListResourcesRequest,
13    ListResourcesResult, ListToolsRequest, ListToolsResult, PaginatedRequestParam,
14    ProgressNotification, ProgressNotificationParam, ReadResourceRequest, ReadResourceRequestParam,
15    ReadResourceResult, RequestId, RootsListChangedNotification, ServerInfo, ServerJsonRpcMessage,
16    ServerNotification, ServerRequest, ServerResult, SetLevelRequest, SetLevelRequestParam,
17    SubscribeRequest, SubscribeRequestParam, UnsubscribeRequest, UnsubscribeRequestParam,
18};
19
20/// It represents the error that may occur when serving the client.
21///
22/// if you want to handle the error, you can use `serve_client_with_ct` or `serve_client` with `Result<RunningService<RoleClient, S>, ClientError>`
23#[derive(Error, Debug)]
24pub enum ClientInitializeError<E> {
25    #[error("expect initialized response, but received: {0:?}")]
26    ExpectedInitResponse(Option<ServerJsonRpcMessage>),
27
28    #[error("expect initialized result, but received: {0:?}")]
29    ExpectedInitResult(Option<ServerResult>),
30
31    #[error("conflict initialized response id: expected {0}, got {1}")]
32    ConflictInitResponseId(RequestId, RequestId),
33
34    #[error("connection closed: {0}")]
35    ConnectionClosed(String),
36
37    #[error("Send message error {error}, when {context}")]
38    TransportError {
39        error: E,
40        context: Cow<'static, str>,
41    },
42
43    #[error("Cancelled")]
44    Cancelled,
45}
46
47/// Helper function to get the next message from the stream
48async fn expect_next_message<T, E>(
49    transport: &mut T,
50    context: &str,
51) -> Result<ServerJsonRpcMessage, ClientInitializeError<E>>
52where
53    T: Transport<RoleClient>,
54{
55    transport
56        .receive()
57        .await
58        .ok_or_else(|| ClientInitializeError::ConnectionClosed(context.to_string()))
59}
60
61/// Helper function to expect a response from the stream
62async fn expect_response<T, E>(
63    transport: &mut T,
64    context: &str,
65) -> Result<(ServerResult, RequestId), ClientInitializeError<E>>
66where
67    T: Transport<RoleClient>,
68{
69    let msg = expect_next_message(transport, context).await?;
70
71    match msg {
72        ServerJsonRpcMessage::Response(JsonRpcResponse { id, result, .. }) => Ok((result, id)),
73        _ => Err(ClientInitializeError::ExpectedInitResponse(Some(msg))),
74    }
75}
76
77#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
78pub struct RoleClient;
79
80impl ServiceRole for RoleClient {
81    type Req = ClientRequest;
82    type Resp = ClientResult;
83    type Not = ClientNotification;
84    type PeerReq = ServerRequest;
85    type PeerResp = ServerResult;
86    type PeerNot = ServerNotification;
87    type Info = ClientInfo;
88    type PeerInfo = ServerInfo;
89    type InitializeError<E> = ClientInitializeError<E>;
90    const IS_CLIENT: bool = true;
91}
92
93pub type ServerSink = Peer<RoleClient>;
94
95impl<S: Service<RoleClient>> ServiceExt<RoleClient> for S {
96    fn serve_with_ct<T, E, A>(
97        self,
98        transport: T,
99        ct: CancellationToken,
100    ) -> impl Future<Output = Result<RunningService<RoleClient, Self>, ClientInitializeError<E>>> + Send
101    where
102        T: IntoTransport<RoleClient, E, A>,
103        E: std::error::Error + From<std::io::Error> + Send + Sync + 'static,
104        Self: Sized,
105    {
106        serve_client_with_ct(self, transport, ct)
107    }
108}
109
110pub async fn serve_client<S, T, E, A>(
111    service: S,
112    transport: T,
113) -> Result<RunningService<RoleClient, S>, ClientInitializeError<E>>
114where
115    S: Service<RoleClient>,
116    T: IntoTransport<RoleClient, E, A>,
117    E: std::error::Error + Send + Sync + 'static,
118{
119    serve_client_with_ct(service, transport, Default::default()).await
120}
121
122pub async fn serve_client_with_ct<S, T, E, A>(
123    service: S,
124    transport: T,
125    ct: CancellationToken,
126) -> Result<RunningService<RoleClient, S>, ClientInitializeError<E>>
127where
128    S: Service<RoleClient>,
129    T: IntoTransport<RoleClient, E, A>,
130    E: std::error::Error + Send + Sync + 'static,
131{
132    tokio::select! {
133        result = serve_client_with_ct_inner(service, transport, ct.clone()) => { result }
134        _ = ct.cancelled() => {
135            Err(ClientInitializeError::Cancelled)
136        }
137    }
138}
139
140async fn serve_client_with_ct_inner<S, T, E, A>(
141    service: S,
142    transport: T,
143    ct: CancellationToken,
144) -> Result<RunningService<RoleClient, S>, ClientInitializeError<E>>
145where
146    S: Service<RoleClient>,
147    T: IntoTransport<RoleClient, E, A>,
148    E: std::error::Error + Send + Sync + 'static,
149{
150    let mut transport = transport.into_transport();
151    let id_provider = <Arc<AtomicU32RequestIdProvider>>::default();
152
153    // service
154    let id = id_provider.next_request_id();
155    let init_request = InitializeRequest {
156        method: Default::default(),
157        params: service.get_info(),
158        extensions: Default::default(),
159    };
160    transport
161        .send(ClientJsonRpcMessage::request(
162            ClientRequest::InitializeRequest(init_request),
163            id.clone(),
164        ))
165        .await
166        .map_err(|error| ClientInitializeError::TransportError {
167            error,
168            context: "send initialize request".into(),
169        })?;
170
171    let (response, response_id) = expect_response(&mut transport, "initialize response").await?;
172
173    if id != response_id {
174        return Err(ClientInitializeError::ConflictInitResponseId(
175            id,
176            response_id,
177        ));
178    }
179
180    let ServerResult::InitializeResult(initialize_result) = response else {
181        return Err(ClientInitializeError::ExpectedInitResult(Some(response)));
182    };
183
184    // send notification
185    let notification = ClientJsonRpcMessage::notification(
186        ClientNotification::InitializedNotification(InitializedNotification {
187            method: Default::default(),
188            extensions: Default::default(),
189        }),
190    );
191    transport
192        .send(notification)
193        .await
194        .map_err(|error| ClientInitializeError::TransportError {
195            error,
196            context: "send initialized notification".into(),
197        })?;
198    let (peer, peer_rx) = Peer::new(id_provider, Some(initialize_result));
199    Ok(serve_inner(service, transport, peer, peer_rx, ct))
200}
201
/// Generates typed convenience methods on the client-side peer.
///
/// `peer_req` arms:
/// - build the named request struct, with `method` and `extensions` defaulted;
/// - wrap it in the matching `ClientRequest` variant and send it via `send_request`;
/// - unwrap the expected `ServerResult` variant, mapping any other variant to
///   `ServiceError::UnexpectedResponse`.
///
/// `peer_not` arms build the named notification struct, wrap it in the matching
/// `ClientNotification` variant, and send it via `send_notification`.
///
/// Supported forms:
/// - `peer_req name Req() => Resp`: request without params, typed result.
/// - `peer_req name Req(Param) => Resp`: request with required params, typed result.
/// - `peer_req name Req(Param)? => Resp`: request with optional params, typed result.
/// - `peer_req name Req(Param)`: request with params, answered by `EmptyResult`.
/// - `peer_not name Not(Param)`: notification with params.
/// - `peer_not name Not`: notification without params.
macro_rules! method {
    (peer_req $method:ident $Req:ident() => $Resp: ident ) => {
        #[doc = concat!(
            "Sends a `", stringify!($Req), "` to the server and returns its `",
            stringify!($Resp), "`."
        )]
        pub async fn $method(&self) -> Result<$Resp, ServiceError> {
            let result = self
                .send_request(ClientRequest::$Req($Req {
                    method: Default::default(),
                    // Set `extensions` like every other arm, including the
                    // parameterless `peer_not` arm below. Assumes parameterless
                    // request types carry that field too.
                    extensions: Default::default(),
                }))
                .await?;
            match result {
                ServerResult::$Resp(result) => Ok(result),
                _ => Err(ServiceError::UnexpectedResponse),
            }
        }
    };
    (peer_req $method:ident $Req:ident($Param: ident) => $Resp: ident ) => {
        #[doc = concat!(
            "Sends a `", stringify!($Req), "` to the server and returns its `",
            stringify!($Resp), "`."
        )]
        pub async fn $method(&self, params: $Param) -> Result<$Resp, ServiceError> {
            let result = self
                .send_request(ClientRequest::$Req($Req {
                    method: Default::default(),
                    params,
                    extensions: Default::default(),
                }))
                .await?;
            match result {
                ServerResult::$Resp(result) => Ok(result),
                _ => Err(ServiceError::UnexpectedResponse),
            }
        }
    };
    (peer_req $method:ident $Req:ident($Param: ident)? => $Resp: ident ) => {
        #[doc = concat!(
            "Sends a `", stringify!($Req), "` (params optional) to the server and returns its `",
            stringify!($Resp), "`."
        )]
        pub async fn $method(&self, params: Option<$Param>) -> Result<$Resp, ServiceError> {
            let result = self
                .send_request(ClientRequest::$Req($Req {
                    method: Default::default(),
                    params,
                    extensions: Default::default(),
                }))
                .await?;
            match result {
                ServerResult::$Resp(result) => Ok(result),
                _ => Err(ServiceError::UnexpectedResponse),
            }
        }
    };
    (peer_req $method:ident $Req:ident($Param: ident)) => {
        #[doc = concat!(
            "Sends a `", stringify!($Req),
            "` to the server; succeeds when it replies with an empty result."
        )]
        pub async fn $method(&self, params: $Param) -> Result<(), ServiceError> {
            let result = self
                .send_request(ClientRequest::$Req($Req {
                    method: Default::default(),
                    params,
                    extensions: Default::default(),
                }))
                .await?;
            match result {
                ServerResult::EmptyResult(_) => Ok(()),
                _ => Err(ServiceError::UnexpectedResponse),
            }
        }
    };

    (peer_not $method:ident $Not:ident($Param: ident)) => {
        #[doc = concat!("Sends a `", stringify!($Not), "` to the server.")]
        pub async fn $method(&self, params: $Param) -> Result<(), ServiceError> {
            self.send_notification(ClientNotification::$Not($Not {
                method: Default::default(),
                params,
                extensions: Default::default(),
            }))
            .await?;
            Ok(())
        }
    };
    (peer_not $method:ident $Not:ident) => {
        #[doc = concat!("Sends a `", stringify!($Not), "` to the server.")]
        pub async fn $method(&self) -> Result<(), ServiceError> {
            self.send_notification(ClientNotification::$Not($Not {
                method: Default::default(),
                extensions: Default::default(),
            }))
            .await?;
            Ok(())
        }
    };
}
284
285impl Peer<RoleClient> {
286    method!(peer_req complete CompleteRequest(CompleteRequestParam) => CompleteResult);
287    method!(peer_req set_level SetLevelRequest(SetLevelRequestParam));
288    method!(peer_req get_prompt GetPromptRequest(GetPromptRequestParam) => GetPromptResult);
289    method!(peer_req list_prompts ListPromptsRequest(PaginatedRequestParam)? => ListPromptsResult);
290    method!(peer_req list_resources ListResourcesRequest(PaginatedRequestParam)? => ListResourcesResult);
291    method!(peer_req list_resource_templates ListResourceTemplatesRequest(PaginatedRequestParam)? => ListResourceTemplatesResult);
292    method!(peer_req read_resource ReadResourceRequest(ReadResourceRequestParam) => ReadResourceResult);
293    method!(peer_req subscribe SubscribeRequest(SubscribeRequestParam) );
294    method!(peer_req unsubscribe UnsubscribeRequest(UnsubscribeRequestParam));
295    method!(peer_req call_tool CallToolRequest(CallToolRequestParam) => CallToolResult);
296    method!(peer_req list_tools ListToolsRequest(PaginatedRequestParam)? => ListToolsResult);
297
298    method!(peer_not notify_cancelled CancelledNotification(CancelledNotificationParam));
299    method!(peer_not notify_progress ProgressNotification(ProgressNotificationParam));
300    method!(peer_not notify_initialized InitializedNotification);
301    method!(peer_not notify_roots_list_changed RootsListChangedNotification);
302}
303
304impl Peer<RoleClient> {
305    /// A wrapper method for [`Peer<RoleClient>::list_tools`].
306    ///
307    /// This function will call [`Peer<RoleClient>::list_tools`] multiple times until all tools are listed.
308    pub async fn list_all_tools(&self) -> Result<Vec<crate::model::Tool>, ServiceError> {
309        let mut tools = Vec::new();
310        let mut cursor = None;
311        loop {
312            let result = self
313                .list_tools(Some(PaginatedRequestParam { cursor }))
314                .await?;
315            tools.extend(result.tools);
316            cursor = result.next_cursor;
317            if cursor.is_none() {
318                break;
319            }
320        }
321        Ok(tools)
322    }
323
324    /// A wrapper method for [`Peer<RoleClient>::list_prompts`].
325    ///
326    /// This function will call [`Peer<RoleClient>::list_prompts`] multiple times until all prompts are listed.
327    pub async fn list_all_prompts(&self) -> Result<Vec<crate::model::Prompt>, ServiceError> {
328        let mut prompts = Vec::new();
329        let mut cursor = None;
330        loop {
331            let result = self
332                .list_prompts(Some(PaginatedRequestParam { cursor }))
333                .await?;
334            prompts.extend(result.prompts);
335            cursor = result.next_cursor;
336            if cursor.is_none() {
337                break;
338            }
339        }
340        Ok(prompts)
341    }
342
343    /// A wrapper method for [`Peer<RoleClient>::list_resources`].
344    ///
345    /// This function will call [`Peer<RoleClient>::list_resources`] multiple times until all resources are listed.
346    pub async fn list_all_resources(&self) -> Result<Vec<crate::model::Resource>, ServiceError> {
347        let mut resources = Vec::new();
348        let mut cursor = None;
349        loop {
350            let result = self
351                .list_resources(Some(PaginatedRequestParam { cursor }))
352                .await?;
353            resources.extend(result.resources);
354            cursor = result.next_cursor;
355            if cursor.is_none() {
356                break;
357            }
358        }
359        Ok(resources)
360    }
361
362    /// A wrapper method for [`Peer<RoleClient>::list_resource_templates`].
363    ///
364    /// This function will call [`Peer<RoleClient>::list_resource_templates`] multiple times until all resource templates are listed.
365    pub async fn list_all_resource_templates(
366        &self,
367    ) -> Result<Vec<crate::model::ResourceTemplate>, ServiceError> {
368        let mut resource_templates = Vec::new();
369        let mut cursor = None;
370        loop {
371            let result = self
372                .list_resource_templates(Some(PaginatedRequestParam { cursor }))
373                .await?;
374            resource_templates.extend(result.resource_templates);
375            cursor = result.next_cursor;
376            if cursor.is_none() {
377                break;
378            }
379        }
380        Ok(resource_templates)
381    }
382}