1use std::borrow::Cow;
2
3use thiserror::Error;
4
5use super::*;
6use crate::model::{
7 CallToolRequest, CallToolRequestParam, CallToolResult, CancelledNotification,
8 CancelledNotificationParam, ClientInfo, ClientJsonRpcMessage, ClientNotification,
9 ClientRequest, ClientResult, CompleteRequest, CompleteRequestParam, CompleteResult,
10 GetPromptRequest, GetPromptRequestParam, GetPromptResult, InitializeRequest,
11 InitializedNotification, JsonRpcResponse, ListPromptsRequest, ListPromptsResult,
12 ListResourceTemplatesRequest, ListResourceTemplatesResult, ListResourcesRequest,
13 ListResourcesResult, ListToolsRequest, ListToolsResult, PaginatedRequestParam,
14 ProgressNotification, ProgressNotificationParam, ReadResourceRequest, ReadResourceRequestParam,
15 ReadResourceResult, RequestId, RootsListChangedNotification, ServerInfo, ServerJsonRpcMessage,
16 ServerNotification, ServerRequest, ServerResult, SetLevelRequest, SetLevelRequestParam,
17 SubscribeRequest, SubscribeRequestParam, UnsubscribeRequest, UnsubscribeRequestParam,
18};
19
20#[derive(Error, Debug)]
24pub enum ClientInitializeError<E> {
25 #[error("expect initialized response, but received: {0:?}")]
26 ExpectedInitResponse(Option<ServerJsonRpcMessage>),
27
28 #[error("expect initialized result, but received: {0:?}")]
29 ExpectedInitResult(Option<ServerResult>),
30
31 #[error("conflict initialized response id: expected {0}, got {1}")]
32 ConflictInitResponseId(RequestId, RequestId),
33
34 #[error("connection closed: {0}")]
35 ConnectionClosed(String),
36
37 #[error("Send message error {error}, when {context}")]
38 TransportError {
39 error: E,
40 context: Cow<'static, str>,
41 },
42
43 #[error("Cancelled")]
44 Cancelled,
45}
46
47async fn expect_next_message<T, E>(
49 transport: &mut T,
50 context: &str,
51) -> Result<ServerJsonRpcMessage, ClientInitializeError<E>>
52where
53 T: Transport<RoleClient>,
54{
55 transport
56 .receive()
57 .await
58 .ok_or_else(|| ClientInitializeError::ConnectionClosed(context.to_string()))
59}
60
61async fn expect_response<T, E>(
63 transport: &mut T,
64 context: &str,
65) -> Result<(ServerResult, RequestId), ClientInitializeError<E>>
66where
67 T: Transport<RoleClient>,
68{
69 let msg = expect_next_message(transport, context).await?;
70
71 match msg {
72 ServerJsonRpcMessage::Response(JsonRpcResponse { id, result, .. }) => Ok((result, id)),
73 _ => Err(ClientInitializeError::ExpectedInitResponse(Some(msg))),
74 }
75}
76
77#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
78pub struct RoleClient;
79
80impl ServiceRole for RoleClient {
81 type Req = ClientRequest;
82 type Resp = ClientResult;
83 type Not = ClientNotification;
84 type PeerReq = ServerRequest;
85 type PeerResp = ServerResult;
86 type PeerNot = ServerNotification;
87 type Info = ClientInfo;
88 type PeerInfo = ServerInfo;
89 type InitializeError<E> = ClientInitializeError<E>;
90 const IS_CLIENT: bool = true;
91}
92
93pub type ServerSink = Peer<RoleClient>;
94
95impl<S: Service<RoleClient>> ServiceExt<RoleClient> for S {
96 fn serve_with_ct<T, E, A>(
97 self,
98 transport: T,
99 ct: CancellationToken,
100 ) -> impl Future<Output = Result<RunningService<RoleClient, Self>, ClientInitializeError<E>>> + Send
101 where
102 T: IntoTransport<RoleClient, E, A>,
103 E: std::error::Error + From<std::io::Error> + Send + Sync + 'static,
104 Self: Sized,
105 {
106 serve_client_with_ct(self, transport, ct)
107 }
108}
109
110pub async fn serve_client<S, T, E, A>(
111 service: S,
112 transport: T,
113) -> Result<RunningService<RoleClient, S>, ClientInitializeError<E>>
114where
115 S: Service<RoleClient>,
116 T: IntoTransport<RoleClient, E, A>,
117 E: std::error::Error + Send + Sync + 'static,
118{
119 serve_client_with_ct(service, transport, Default::default()).await
120}
121
122pub async fn serve_client_with_ct<S, T, E, A>(
123 service: S,
124 transport: T,
125 ct: CancellationToken,
126) -> Result<RunningService<RoleClient, S>, ClientInitializeError<E>>
127where
128 S: Service<RoleClient>,
129 T: IntoTransport<RoleClient, E, A>,
130 E: std::error::Error + Send + Sync + 'static,
131{
132 tokio::select! {
133 result = serve_client_with_ct_inner(service, transport, ct.clone()) => { result }
134 _ = ct.cancelled() => {
135 Err(ClientInitializeError::Cancelled)
136 }
137 }
138}
139
140async fn serve_client_with_ct_inner<S, T, E, A>(
141 service: S,
142 transport: T,
143 ct: CancellationToken,
144) -> Result<RunningService<RoleClient, S>, ClientInitializeError<E>>
145where
146 S: Service<RoleClient>,
147 T: IntoTransport<RoleClient, E, A>,
148 E: std::error::Error + Send + Sync + 'static,
149{
150 let mut transport = transport.into_transport();
151 let id_provider = <Arc<AtomicU32RequestIdProvider>>::default();
152
153 let id = id_provider.next_request_id();
155 let init_request = InitializeRequest {
156 method: Default::default(),
157 params: service.get_info(),
158 extensions: Default::default(),
159 };
160 transport
161 .send(ClientJsonRpcMessage::request(
162 ClientRequest::InitializeRequest(init_request),
163 id.clone(),
164 ))
165 .await
166 .map_err(|error| ClientInitializeError::TransportError {
167 error,
168 context: "send initialize request".into(),
169 })?;
170
171 let (response, response_id) = expect_response(&mut transport, "initialize response").await?;
172
173 if id != response_id {
174 return Err(ClientInitializeError::ConflictInitResponseId(
175 id,
176 response_id,
177 ));
178 }
179
180 let ServerResult::InitializeResult(initialize_result) = response else {
181 return Err(ClientInitializeError::ExpectedInitResult(Some(response)));
182 };
183
184 let notification = ClientJsonRpcMessage::notification(
186 ClientNotification::InitializedNotification(InitializedNotification {
187 method: Default::default(),
188 extensions: Default::default(),
189 }),
190 );
191 transport
192 .send(notification)
193 .await
194 .map_err(|error| ClientInitializeError::TransportError {
195 error,
196 context: "send initialized notification".into(),
197 })?;
198 let (peer, peer_rx) = Peer::new(id_provider, Some(initialize_result));
199 Ok(serve_inner(service, transport, peer, peer_rx, ct))
200}
201
// Generates typed convenience methods on `Peer<RoleClient>` that wrap
// `send_request` / `send_notification`.
//
// Arms:
// - `peer_req name Req() => Resp`       request without params, returns `Resp`
// - `peer_req name Req(Param) => Resp`  request with required params, returns `Resp`
// - `peer_req name Req(Param)? => Resp` request with optional params, returns `Resp`
// - `peer_req name Req(Param)`          request whose reply must be `EmptyResult`
// - `peer_not name Not(Param)`          notification with params
// - `peer_not name Not`                 notification without params
//
// Request arms turn any server result other than the expected variant into
// `ServiceError::UnexpectedResponse`.
macro_rules! method {
    (peer_req $method:ident $Req:ident() => $Resp: ident ) => {
        #[doc = concat!("Send a `", stringify!($Req), "` to the server and return its `", stringify!($Resp), "`.")]
        pub async fn $method(&self) -> Result<$Resp, ServiceError> {
            let result = self
                .send_request(ClientRequest::$Req($Req {
                    method: Default::default(),
                    // Every other arm sets `extensions`, including the no-param
                    // notification arm below. This arm previously omitted it and
                    // would not compile once used. Assumes the no-param request
                    // struct has the same field — confirm against the model.
                    extensions: Default::default(),
                }))
                .await?;
            match result {
                ServerResult::$Resp(result) => Ok(result),
                _ => Err(ServiceError::UnexpectedResponse),
            }
        }
    };
    (peer_req $method:ident $Req:ident($Param: ident) => $Resp: ident ) => {
        #[doc = concat!("Send a `", stringify!($Req), "` with the given params and return the server's `", stringify!($Resp), "`.")]
        pub async fn $method(&self, params: $Param) -> Result<$Resp, ServiceError> {
            let result = self
                .send_request(ClientRequest::$Req($Req {
                    method: Default::default(),
                    params,
                    extensions: Default::default(),
                }))
                .await?;
            match result {
                ServerResult::$Resp(result) => Ok(result),
                _ => Err(ServiceError::UnexpectedResponse),
            }
        }
    };
    (peer_req $method:ident $Req:ident($Param: ident)? => $Resp: ident ) => {
        #[doc = concat!("Send a `", stringify!($Req), "` with optional params and return the server's `", stringify!($Resp), "`.")]
        pub async fn $method(&self, params: Option<$Param>) -> Result<$Resp, ServiceError> {
            let result = self
                .send_request(ClientRequest::$Req($Req {
                    method: Default::default(),
                    params,
                    extensions: Default::default(),
                }))
                .await?;
            match result {
                ServerResult::$Resp(result) => Ok(result),
                _ => Err(ServiceError::UnexpectedResponse),
            }
        }
    };
    (peer_req $method:ident $Req:ident($Param: ident)) => {
        #[doc = concat!("Send a `", stringify!($Req), "` with the given params; the server must reply with an empty result.")]
        pub async fn $method(&self, params: $Param) -> Result<(), ServiceError> {
            let result = self
                .send_request(ClientRequest::$Req($Req {
                    method: Default::default(),
                    params,
                    extensions: Default::default(),
                }))
                .await?;
            match result {
                ServerResult::EmptyResult(_) => Ok(()),
                _ => Err(ServiceError::UnexpectedResponse),
            }
        }
    };

    (peer_not $method:ident $Not:ident($Param: ident)) => {
        #[doc = concat!("Send a `", stringify!($Not), "` with the given params to the server.")]
        pub async fn $method(&self, params: $Param) -> Result<(), ServiceError> {
            self.send_notification(ClientNotification::$Not($Not {
                method: Default::default(),
                params,
                extensions: Default::default(),
            }))
            .await?;
            Ok(())
        }
    };
    (peer_not $method:ident $Not:ident) => {
        #[doc = concat!("Send a `", stringify!($Not), "` to the server.")]
        pub async fn $method(&self) -> Result<(), ServiceError> {
            self.send_notification(ClientNotification::$Not($Not {
                method: Default::default(),
                extensions: Default::default(),
            }))
            .await?;
            Ok(())
        }
    };
}
284
285impl Peer<RoleClient> {
286 method!(peer_req complete CompleteRequest(CompleteRequestParam) => CompleteResult);
287 method!(peer_req set_level SetLevelRequest(SetLevelRequestParam));
288 method!(peer_req get_prompt GetPromptRequest(GetPromptRequestParam) => GetPromptResult);
289 method!(peer_req list_prompts ListPromptsRequest(PaginatedRequestParam)? => ListPromptsResult);
290 method!(peer_req list_resources ListResourcesRequest(PaginatedRequestParam)? => ListResourcesResult);
291 method!(peer_req list_resource_templates ListResourceTemplatesRequest(PaginatedRequestParam)? => ListResourceTemplatesResult);
292 method!(peer_req read_resource ReadResourceRequest(ReadResourceRequestParam) => ReadResourceResult);
293 method!(peer_req subscribe SubscribeRequest(SubscribeRequestParam) );
294 method!(peer_req unsubscribe UnsubscribeRequest(UnsubscribeRequestParam));
295 method!(peer_req call_tool CallToolRequest(CallToolRequestParam) => CallToolResult);
296 method!(peer_req list_tools ListToolsRequest(PaginatedRequestParam)? => ListToolsResult);
297
298 method!(peer_not notify_cancelled CancelledNotification(CancelledNotificationParam));
299 method!(peer_not notify_progress ProgressNotification(ProgressNotificationParam));
300 method!(peer_not notify_initialized InitializedNotification);
301 method!(peer_not notify_roots_list_changed RootsListChangedNotification);
302}
303
304impl Peer<RoleClient> {
305 pub async fn list_all_tools(&self) -> Result<Vec<crate::model::Tool>, ServiceError> {
309 let mut tools = Vec::new();
310 let mut cursor = None;
311 loop {
312 let result = self
313 .list_tools(Some(PaginatedRequestParam { cursor }))
314 .await?;
315 tools.extend(result.tools);
316 cursor = result.next_cursor;
317 if cursor.is_none() {
318 break;
319 }
320 }
321 Ok(tools)
322 }
323
324 pub async fn list_all_prompts(&self) -> Result<Vec<crate::model::Prompt>, ServiceError> {
328 let mut prompts = Vec::new();
329 let mut cursor = None;
330 loop {
331 let result = self
332 .list_prompts(Some(PaginatedRequestParam { cursor }))
333 .await?;
334 prompts.extend(result.prompts);
335 cursor = result.next_cursor;
336 if cursor.is_none() {
337 break;
338 }
339 }
340 Ok(prompts)
341 }
342
343 pub async fn list_all_resources(&self) -> Result<Vec<crate::model::Resource>, ServiceError> {
347 let mut resources = Vec::new();
348 let mut cursor = None;
349 loop {
350 let result = self
351 .list_resources(Some(PaginatedRequestParam { cursor }))
352 .await?;
353 resources.extend(result.resources);
354 cursor = result.next_cursor;
355 if cursor.is_none() {
356 break;
357 }
358 }
359 Ok(resources)
360 }
361
362 pub async fn list_all_resource_templates(
366 &self,
367 ) -> Result<Vec<crate::model::ResourceTemplate>, ServiceError> {
368 let mut resource_templates = Vec::new();
369 let mut cursor = None;
370 loop {
371 let result = self
372 .list_resource_templates(Some(PaginatedRequestParam { cursor }))
373 .await?;
374 resource_templates.extend(result.resource_templates);
375 cursor = result.next_cursor;
376 if cursor.is_none() {
377 break;
378 }
379 }
380 Ok(resource_templates)
381 }
382}