1use std::borrow::Cow;
2
3use thiserror::Error;
4
5use super::*;
6use crate::{
7 model::{
8 ArgumentInfo, CallToolRequest, CallToolRequestParams, CallToolResult,
9 CancelledNotification, CancelledNotificationParam, ClientInfo, ClientJsonRpcMessage,
10 ClientNotification, ClientRequest, ClientResult, CompleteRequest, CompleteRequestParams,
11 CompleteResult, CompletionContext, CompletionInfo, ErrorData, GetPromptRequest,
12 GetPromptRequestParams, GetPromptResult, InitializeRequest, InitializedNotification,
13 JsonRpcResponse, ListPromptsRequest, ListPromptsResult, ListResourceTemplatesRequest,
14 ListResourceTemplatesResult, ListResourcesRequest, ListResourcesResult, ListToolsRequest,
15 ListToolsResult, PaginatedRequestParams, ProgressNotification, ProgressNotificationParam,
16 ReadResourceRequest, ReadResourceRequestParams, ReadResourceResult, Reference, RequestId,
17 RootsListChangedNotification, ServerInfo, ServerJsonRpcMessage, ServerNotification,
18 ServerRequest, ServerResult, SetLevelRequest, SetLevelRequestParams, SubscribeRequest,
19 SubscribeRequestParams, UnsubscribeRequest, UnsubscribeRequestParams,
20 },
21 transport::DynamicTransportError,
22};
23
/// Errors that can occur while the client performs the initialization handshake
/// with a server.
#[derive(Error, Debug)]
#[non_exhaustive]
pub enum ClientInitializeError {
    /// The server did not reply with the expected initialize response.
    /// Carries the message that arrived instead, if any.
    #[error("expect initialized response, but received: {0:?}")]
    ExpectedInitResponse(Option<ServerJsonRpcMessage>),

    /// A response arrived for the initialize request, but its result was not an
    /// `InitializeResult`. Carries the result that was received.
    #[error("expect initialized result, but received: {0:?}")]
    ExpectedInitResult(Option<ServerResult>),

    /// The response id did not match the id of the initialize request.
    /// Fields are `(expected, received)`.
    #[error("conflict initialized response id: expected {0}, got {1}")]
    ConflictInitResponseId(RequestId, RequestId),

    /// The transport stopped yielding messages while a reply was still awaited.
    /// The string describes what was being waited for.
    #[error("connection closed: {0}")]
    ConnectionClosed(String),

    /// Sending a message over the transport failed.
    #[error("Send message error {error}, when {context}")]
    TransportError {
        /// The underlying, type-erased transport error.
        error: DynamicTransportError,
        /// Short description of the operation that failed.
        context: Cow<'static, str>,
    },

    /// The server answered with a JSON-RPC error object.
    #[error("JSON-RPC error: {0}")]
    JsonRpcError(ErrorData),

    /// The cancellation token fired before the handshake completed.
    #[error("Cancelled")]
    Cancelled,
}
54
55impl ClientInitializeError {
56 pub fn transport<T: Transport<RoleClient> + 'static>(
57 error: T::Error,
58 context: impl Into<Cow<'static, str>>,
59 ) -> Self {
60 Self::TransportError {
61 error: DynamicTransportError::new::<T, _>(error),
62 context: context.into(),
63 }
64 }
65}
66
67async fn expect_next_message<T>(
69 transport: &mut T,
70 context: &str,
71) -> Result<ServerJsonRpcMessage, ClientInitializeError>
72where
73 T: Transport<RoleClient>,
74{
75 transport
76 .receive()
77 .await
78 .ok_or_else(|| ClientInitializeError::ConnectionClosed(context.to_string()))
79}
80
/// Reads messages from `transport` until the server answers with a response or an error.
///
/// Returns the `(result, id)` pair of the first JSON-RPC response received. The id is
/// not validated here. A JSON-RPC error message ends the wait with
/// [`ClientInitializeError::JsonRpcError`]. A closed transport yields
/// [`ClientInitializeError::ConnectionClosed`] tagged with `context`.
///
/// While waiting:
/// - `LoggingMessageNotification`s are forwarded to `service.handle_notification`;
///   handler failures are only logged.
/// - Ping requests are trace-logged and otherwise ignored.
/// - Every other message is logged as unexpected and dropped.
///
/// NOTE(review): ignored pings receive no reply here — confirm the server does not
/// wait for one during the handshake.
async fn expect_response<T, S>(
    transport: &mut T,
    context: &str,
    service: &S,
    peer: Peer<RoleClient>,
) -> Result<(ServerResult, RequestId), ClientInitializeError>
where
    T: Transport<RoleClient>,
    S: Service<RoleClient>,
{
    loop {
        let message = expect_next_message(transport, context).await?;
        match message {
            ServerJsonRpcMessage::Response(JsonRpcResponse { id, result, .. }) => {
                // The id is returned so the caller can compare it with the request id.
                break Ok((result, id));
            }
            ServerJsonRpcMessage::Error(error) => {
                break Err(ClientInitializeError::JsonRpcError(error.error));
            }
            ServerJsonRpcMessage::Notification(mut notification) => {
                // Only logging notifications are handled before the handshake completes.
                let ServerNotification::LoggingMessageNotification(logging) =
                    &mut notification.notification
                else {
                    tracing::warn!(?notification, "Received unexpected message");
                    continue;
                };

                // NOTE: this `context` shadows the `&str` parameter for the rest of this arm.
                let mut context = NotificationContext {
                    peer: peer.clone(),
                    meta: Meta::default(),
                    extensions: Extensions::default(),
                };

                // Move a `Meta` stored in the notification's extensions (if any) into
                // the context, leaving a default `Meta` behind.
                if let Some(meta) = logging.extensions.get_mut::<Meta>() {
                    std::mem::swap(&mut context.meta, meta);
                }
                // Hand the remaining extensions to the context; the notification is left
                // with an empty set. This ends the mutable borrow of `logging`, so
                // `notification.notification` can be moved below.
                std::mem::swap(&mut context.extensions, &mut logging.extensions);

                if let Err(error) = service
                    .handle_notification(notification.notification, context)
                    .await
                {
                    tracing::warn!(?error, "Handle logging before handshake failed.");
                }
            }
            ServerJsonRpcMessage::Request(ref request)
                if matches!(request.request, ServerRequest::PingRequest(_)) =>
            {
                tracing::trace!("Received ping request. Ignored.")
            }
            _ => tracing::warn!(?message, "Received unexpected message"),
        }
    }
}
141
/// Zero-sized marker type selecting the client role.
///
/// Used as the role parameter of `Peer`, `Service`, `RunningService` and friends.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
#[expect(clippy::exhaustive_structs, reason = "intentionally exhaustive")]
pub struct RoleClient;
145
146impl ServiceRole for RoleClient {
147 type Req = ClientRequest;
148 type Resp = ClientResult;
149 type Not = ClientNotification;
150 type PeerReq = ServerRequest;
151 type PeerResp = ServerResult;
152 type PeerNot = ServerNotification;
153 type Info = ClientInfo;
154 type PeerInfo = ServerInfo;
155 type InitializeError = ClientInitializeError;
156 const IS_CLIENT: bool = true;
157}
158
/// Client-side handle for talking to the connected server; an alias for
/// [`Peer<RoleClient>`].
pub type ServerSink = Peer<RoleClient>;
160
/// Any [`Service<RoleClient>`] can be served as a client.
impl<S: Service<RoleClient>> ServiceExt<RoleClient> for S {
    /// Runs the client handshake over `transport` by delegating to
    /// [`serve_client_with_ct`]; cancelling `ct` aborts it.
    fn serve_with_ct<T, E, A>(
        self,
        transport: T,
        ct: CancellationToken,
    ) -> impl Future<Output = Result<RunningService<RoleClient, Self>, ClientInitializeError>>
    + MaybeSendFuture
    where
        T: IntoTransport<RoleClient, E, A>,
        E: std::error::Error + Send + Sync + 'static,
        Self: Sized,
    {
        serve_client_with_ct(self, transport, ct)
    }
}
176
177pub async fn serve_client<S, T, E, A>(
178 service: S,
179 transport: T,
180) -> Result<RunningService<RoleClient, S>, ClientInitializeError>
181where
182 S: Service<RoleClient>,
183 T: IntoTransport<RoleClient, E, A>,
184 E: std::error::Error + Send + Sync + 'static,
185{
186 serve_client_with_ct(service, transport, Default::default()).await
187}
188
189pub async fn serve_client_with_ct<S, T, E, A>(
190 service: S,
191 transport: T,
192 ct: CancellationToken,
193) -> Result<RunningService<RoleClient, S>, ClientInitializeError>
194where
195 S: Service<RoleClient>,
196 T: IntoTransport<RoleClient, E, A>,
197 E: std::error::Error + Send + Sync + 'static,
198{
199 tokio::select! {
200 result = serve_client_with_ct_inner(service, transport.into_transport(), ct.clone()) => { result }
201 _ = ct.cancelled() => {
202 Err(ClientInitializeError::Cancelled)
203 }
204 }
205}
206
207async fn serve_client_with_ct_inner<S, T>(
208 service: S,
209 transport: T,
210 ct: CancellationToken,
211) -> Result<RunningService<RoleClient, S>, ClientInitializeError>
212where
213 S: Service<RoleClient>,
214 T: Transport<RoleClient> + 'static,
215{
216 let mut transport = transport.into_transport();
217 let id_provider = <Arc<AtomicU32RequestIdProvider>>::default();
218
219 let id = id_provider.next_request_id();
221 let init_request = InitializeRequest {
222 method: Default::default(),
223 params: service.get_info(),
224 extensions: Default::default(),
225 };
226 transport
227 .send(ClientJsonRpcMessage::request(
228 ClientRequest::InitializeRequest(init_request),
229 id.clone(),
230 ))
231 .await
232 .map_err(|error| ClientInitializeError::TransportError {
233 error: DynamicTransportError::new::<T, _>(error),
234 context: "send initialize request".into(),
235 })?;
236
237 let (peer, peer_rx) = Peer::new(id_provider, None);
238
239 let (response, response_id) = expect_response(
240 &mut transport,
241 "initialize response",
242 &service,
243 peer.clone(),
244 )
245 .await?;
246
247 if id != response_id {
248 return Err(ClientInitializeError::ConflictInitResponseId(
249 id,
250 response_id,
251 ));
252 }
253
254 let ServerResult::InitializeResult(initialize_result) = response else {
255 return Err(ClientInitializeError::ExpectedInitResult(Some(response)));
256 };
257 peer.set_peer_info(initialize_result);
258
259 let notification = ClientJsonRpcMessage::notification(
261 ClientNotification::InitializedNotification(InitializedNotification {
262 method: Default::default(),
263 extensions: Default::default(),
264 }),
265 );
266 transport.send(notification).await.map_err(|error| {
267 ClientInitializeError::transport::<T>(error, "send initialized notification")
268 })?;
269 Ok(serve_inner(service, transport, peer, peer_rx, ct))
270}
271
/// Generates typed `Peer<RoleClient>` wrappers around `send_request` / `send_notification`.
///
/// Supported forms:
/// - `peer_req name Req() => Resp`: request without params, expects `ServerResult::Resp`.
/// - `peer_req name Req(Param) => Resp`: request with required params.
/// - `peer_req name Req(Param)? => Resp`: request with optional params (`Option<Param>`).
/// - `peer_req name Req(Param)`: request whose reply must be `ServerResult::EmptyResult`.
/// - `peer_not name Not(Param)` / `peer_not name Not`: notification with / without params.
///
/// A reply of any other `ServerResult` variant maps to `ServiceError::UnexpectedResponse`.
/// Arm order matters: `Req(Param) => Resp` must be tried before `Req(Param)? => Resp`.
macro_rules! method {
    (peer_req $method:ident $Req:ident() => $Resp: ident ) => {
        pub async fn $method(&self) -> Result<$Resp, ServiceError> {
            let result = self
                .send_request(ClientRequest::$Req($Req {
                    method: Default::default(),
                    // Every other arm (including the param-less notification arm) fills in
                    // `extensions`; omitting it here broke this arm for such request types.
                    extensions: Default::default(),
                }))
                .await?;
            match result {
                ServerResult::$Resp(result) => Ok(result),
                _ => Err(ServiceError::UnexpectedResponse),
            }
        }
    };
    (peer_req $method:ident $Req:ident($Param: ident) => $Resp: ident ) => {
        pub async fn $method(&self, params: $Param) -> Result<$Resp, ServiceError> {
            let result = self
                .send_request(ClientRequest::$Req($Req {
                    method: Default::default(),
                    params,
                    extensions: Default::default(),
                }))
                .await?;
            match result {
                ServerResult::$Resp(result) => Ok(result),
                _ => Err(ServiceError::UnexpectedResponse),
            }
        }
    };
    (peer_req $method:ident $Req:ident($Param: ident)? => $Resp: ident ) => {
        pub async fn $method(&self, params: Option<$Param>) -> Result<$Resp, ServiceError> {
            let result = self
                .send_request(ClientRequest::$Req($Req {
                    method: Default::default(),
                    params,
                    extensions: Default::default(),
                }))
                .await?;
            match result {
                ServerResult::$Resp(result) => Ok(result),
                _ => Err(ServiceError::UnexpectedResponse),
            }
        }
    };
    (peer_req $method:ident $Req:ident($Param: ident)) => {
        pub async fn $method(&self, params: $Param) -> Result<(), ServiceError> {
            let result = self
                .send_request(ClientRequest::$Req($Req {
                    method: Default::default(),
                    params,
                    extensions: Default::default(),
                }))
                .await?;
            match result {
                ServerResult::EmptyResult(_) => Ok(()),
                _ => Err(ServiceError::UnexpectedResponse),
            }
        }
    };

    (peer_not $method:ident $Not:ident($Param: ident)) => {
        pub async fn $method(&self, params: $Param) -> Result<(), ServiceError> {
            self.send_notification(ClientNotification::$Not($Not {
                method: Default::default(),
                params,
                extensions: Default::default(),
            }))
            .await?;
            Ok(())
        }
    };
    (peer_not $method:ident $Not:ident) => {
        pub async fn $method(&self) -> Result<(), ServiceError> {
            self.send_notification(ClientNotification::$Not($Not {
                method: Default::default(),
                extensions: Default::default(),
            }))
            .await?;
            Ok(())
        }
    };
}
354
// Typed request/notification wrappers for the client side, generated by `method!`.
impl Peer<RoleClient> {
    // Requests that return a typed `ServerResult` variant.
    method!(peer_req complete CompleteRequest(CompleteRequestParams) => CompleteResult);
    // Request whose reply must be `ServerResult::EmptyResult`.
    method!(peer_req set_level SetLevelRequest(SetLevelRequestParams));
    method!(peer_req get_prompt GetPromptRequest(GetPromptRequestParams) => GetPromptResult);
    // Paginated listings: pass `None` for the first page, or a cursor to continue.
    method!(peer_req list_prompts ListPromptsRequest(PaginatedRequestParams)? => ListPromptsResult);
    method!(peer_req list_resources ListResourcesRequest(PaginatedRequestParams)? => ListResourcesResult);
    method!(peer_req list_resource_templates ListResourceTemplatesRequest(PaginatedRequestParams)? => ListResourceTemplatesResult);
    method!(peer_req read_resource ReadResourceRequest(ReadResourceRequestParams) => ReadResourceResult);
    // Subscription management; both expect an empty result.
    method!(peer_req subscribe SubscribeRequest(SubscribeRequestParams));
    method!(peer_req unsubscribe UnsubscribeRequest(UnsubscribeRequestParams));
    method!(peer_req call_tool CallToolRequest(CallToolRequestParams) => CallToolResult);
    method!(peer_req list_tools ListToolsRequest(PaginatedRequestParams)? => ListToolsResult);

    // Notifications sent to the server; these do not wait for a reply.
    method!(peer_not notify_cancelled CancelledNotification(CancelledNotificationParam));
    method!(peer_not notify_progress ProgressNotification(ProgressNotificationParam));
    method!(peer_not notify_initialized InitializedNotification);
    method!(peer_not notify_roots_list_changed RootsListChangedNotification);
}
373
374impl Peer<RoleClient> {
375 pub async fn list_all_tools(&self) -> Result<Vec<crate::model::Tool>, ServiceError> {
379 let mut tools = Vec::new();
380 let mut cursor = None;
381 loop {
382 let result = self
383 .list_tools(Some(PaginatedRequestParams { meta: None, cursor }))
384 .await?;
385 tools.extend(result.tools);
386 cursor = result.next_cursor;
387 if cursor.is_none() {
388 break;
389 }
390 }
391 Ok(tools)
392 }
393
394 pub async fn list_all_prompts(&self) -> Result<Vec<crate::model::Prompt>, ServiceError> {
398 let mut prompts = Vec::new();
399 let mut cursor = None;
400 loop {
401 let result = self
402 .list_prompts(Some(PaginatedRequestParams { meta: None, cursor }))
403 .await?;
404 prompts.extend(result.prompts);
405 cursor = result.next_cursor;
406 if cursor.is_none() {
407 break;
408 }
409 }
410 Ok(prompts)
411 }
412
413 pub async fn list_all_resources(&self) -> Result<Vec<crate::model::Resource>, ServiceError> {
417 let mut resources = Vec::new();
418 let mut cursor = None;
419 loop {
420 let result = self
421 .list_resources(Some(PaginatedRequestParams { meta: None, cursor }))
422 .await?;
423 resources.extend(result.resources);
424 cursor = result.next_cursor;
425 if cursor.is_none() {
426 break;
427 }
428 }
429 Ok(resources)
430 }
431
432 pub async fn list_all_resource_templates(
436 &self,
437 ) -> Result<Vec<crate::model::ResourceTemplate>, ServiceError> {
438 let mut resource_templates = Vec::new();
439 let mut cursor = None;
440 loop {
441 let result = self
442 .list_resource_templates(Some(PaginatedRequestParams { meta: None, cursor }))
443 .await?;
444 resource_templates.extend(result.resource_templates);
445 cursor = result.next_cursor;
446 if cursor.is_none() {
447 break;
448 }
449 }
450 Ok(resource_templates)
451 }
452
453 pub async fn complete_prompt_argument(
464 &self,
465 prompt_name: impl Into<String>,
466 argument_name: impl Into<String>,
467 current_value: impl Into<String>,
468 context: Option<CompletionContext>,
469 ) -> Result<CompletionInfo, ServiceError> {
470 let request = CompleteRequestParams {
471 meta: None,
472 r#ref: Reference::for_prompt(prompt_name),
473 argument: ArgumentInfo {
474 name: argument_name.into(),
475 value: current_value.into(),
476 },
477 context,
478 };
479
480 let result = self.complete(request).await?;
481 Ok(result.completion)
482 }
483
484 pub async fn complete_resource_argument(
495 &self,
496 uri_template: impl Into<String>,
497 argument_name: impl Into<String>,
498 current_value: impl Into<String>,
499 context: Option<CompletionContext>,
500 ) -> Result<CompletionInfo, ServiceError> {
501 let request = CompleteRequestParams {
502 meta: None,
503 r#ref: Reference::for_resource(uri_template),
504 argument: ArgumentInfo {
505 name: argument_name.into(),
506 value: current_value.into(),
507 },
508 context,
509 };
510
511 let result = self.complete(request).await?;
512 Ok(result.completion)
513 }
514
515 pub async fn complete_prompt_simple(
520 &self,
521 prompt_name: impl Into<String>,
522 argument_name: impl Into<String>,
523 current_value: impl Into<String>,
524 ) -> Result<Vec<String>, ServiceError> {
525 let completion = self
526 .complete_prompt_argument(prompt_name, argument_name, current_value, None)
527 .await?;
528 Ok(completion.values)
529 }
530
531 pub async fn complete_resource_simple(
536 &self,
537 uri_template: impl Into<String>,
538 argument_name: impl Into<String>,
539 current_value: impl Into<String>,
540 ) -> Result<Vec<String>, ServiceError> {
541 let completion = self
542 .complete_resource_argument(uri_template, argument_name, current_value, None)
543 .await?;
544 Ok(completion.values)
545 }
546}