1use std::borrow::Cow;
2
3use thiserror::Error;
4
5use super::*;
6use crate::{
7 model::{
8 ArgumentInfo, CallToolRequest, CallToolRequestParams, CallToolResult,
9 CancelledNotification, CancelledNotificationParam, ClientInfo, ClientJsonRpcMessage,
10 ClientNotification, ClientRequest, ClientResult, CompleteRequest, CompleteRequestParams,
11 CompleteResult, CompletionContext, CompletionInfo, ErrorData, GetPromptRequest,
12 GetPromptRequestParams, GetPromptResult, InitializeRequest, InitializedNotification,
13 JsonRpcResponse, ListPromptsRequest, ListPromptsResult, ListResourceTemplatesRequest,
14 ListResourceTemplatesResult, ListResourcesRequest, ListResourcesResult, ListToolsRequest,
15 ListToolsResult, PaginatedRequestParams, ProgressNotification, ProgressNotificationParam,
16 ReadResourceRequest, ReadResourceRequestParams, ReadResourceResult, Reference, RequestId,
17 RootsListChangedNotification, ServerInfo, ServerJsonRpcMessage, ServerNotification,
18 ServerRequest, ServerResult, SetLevelRequest, SetLevelRequestParams, SubscribeRequest,
19 SubscribeRequestParams, UnsubscribeRequest, UnsubscribeRequestParams,
20 },
21 transport::DynamicTransportError,
22};
23
24#[derive(Error, Debug)]
28pub enum ClientInitializeError {
29 #[error("expect initialized response, but received: {0:?}")]
30 ExpectedInitResponse(Option<ServerJsonRpcMessage>),
31
32 #[error("expect initialized result, but received: {0:?}")]
33 ExpectedInitResult(Option<ServerResult>),
34
35 #[error("conflict initialized response id: expected {0}, got {1}")]
36 ConflictInitResponseId(RequestId, RequestId),
37
38 #[error("connection closed: {0}")]
39 ConnectionClosed(String),
40
41 #[error("Send message error {error}, when {context}")]
42 TransportError {
43 error: DynamicTransportError,
44 context: Cow<'static, str>,
45 },
46
47 #[error("JSON-RPC error: {0}")]
48 JsonRpcError(ErrorData),
49
50 #[error("Cancelled")]
51 Cancelled,
52}
53
54impl ClientInitializeError {
55 pub fn transport<T: Transport<RoleClient> + 'static>(
56 error: T::Error,
57 context: impl Into<Cow<'static, str>>,
58 ) -> Self {
59 Self::TransportError {
60 error: DynamicTransportError::new::<T, _>(error),
61 context: context.into(),
62 }
63 }
64}
65
66async fn expect_next_message<T>(
68 transport: &mut T,
69 context: &str,
70) -> Result<ServerJsonRpcMessage, ClientInitializeError>
71where
72 T: Transport<RoleClient>,
73{
74 transport
75 .receive()
76 .await
77 .ok_or_else(|| ClientInitializeError::ConnectionClosed(context.to_string()))
78}
79
80async fn expect_response<T, S>(
82 transport: &mut T,
83 context: &str,
84 service: &S,
85 peer: Peer<RoleClient>,
86) -> Result<(ServerResult, RequestId), ClientInitializeError>
87where
88 T: Transport<RoleClient>,
89 S: Service<RoleClient>,
90{
91 loop {
92 let message = expect_next_message(transport, context).await?;
93 match message {
94 ServerJsonRpcMessage::Response(JsonRpcResponse { id, result, .. }) => {
96 break Ok((result, id));
97 }
98 ServerJsonRpcMessage::Error(error) => {
100 break Err(ClientInitializeError::JsonRpcError(error.error));
101 }
102 ServerJsonRpcMessage::Notification(mut notification) => {
104 let ServerNotification::LoggingMessageNotification(logging) =
105 &mut notification.notification
106 else {
107 tracing::warn!(?notification, "Received unexpected message");
108 continue;
109 };
110
111 let mut context = NotificationContext {
112 peer: peer.clone(),
113 meta: Meta::default(),
114 extensions: Extensions::default(),
115 };
116
117 if let Some(meta) = logging.extensions.get_mut::<Meta>() {
118 std::mem::swap(&mut context.meta, meta);
119 }
120 std::mem::swap(&mut context.extensions, &mut logging.extensions);
121
122 if let Err(error) = service
123 .handle_notification(notification.notification, context)
124 .await
125 {
126 tracing::warn!(?error, "Handle logging before handshake failed.");
127 }
128 }
129 ServerJsonRpcMessage::Request(ref request)
131 if matches!(request.request, ServerRequest::PingRequest(_)) =>
132 {
133 tracing::trace!("Received ping request. Ignored.")
134 }
135 _ => tracing::warn!(?message, "Received unexpected message"),
137 }
138 }
139}
140
/// Zero-sized marker selecting the client side of the protocol.
///
/// It holds no data; it only parameterizes generic service types such as
/// `Peer<RoleClient>` and `RunningService<RoleClient, _>`.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct RoleClient;
143
144impl ServiceRole for RoleClient {
145 type Req = ClientRequest;
146 type Resp = ClientResult;
147 type Not = ClientNotification;
148 type PeerReq = ServerRequest;
149 type PeerResp = ServerResult;
150 type PeerNot = ServerNotification;
151 type Info = ClientInfo;
152 type PeerInfo = ServerInfo;
153 type InitializeError = ClientInitializeError;
154 const IS_CLIENT: bool = true;
155}
156
157pub type ServerSink = Peer<RoleClient>;
158
159impl<S: Service<RoleClient>> ServiceExt<RoleClient> for S {
160 fn serve_with_ct<T, E, A>(
161 self,
162 transport: T,
163 ct: CancellationToken,
164 ) -> impl Future<Output = Result<RunningService<RoleClient, Self>, ClientInitializeError>> + Send
165 where
166 T: IntoTransport<RoleClient, E, A>,
167 E: std::error::Error + Send + Sync + 'static,
168 Self: Sized,
169 {
170 serve_client_with_ct(self, transport, ct)
171 }
172}
173
174pub async fn serve_client<S, T, E, A>(
175 service: S,
176 transport: T,
177) -> Result<RunningService<RoleClient, S>, ClientInitializeError>
178where
179 S: Service<RoleClient>,
180 T: IntoTransport<RoleClient, E, A>,
181 E: std::error::Error + Send + Sync + 'static,
182{
183 serve_client_with_ct(service, transport, Default::default()).await
184}
185
186pub async fn serve_client_with_ct<S, T, E, A>(
187 service: S,
188 transport: T,
189 ct: CancellationToken,
190) -> Result<RunningService<RoleClient, S>, ClientInitializeError>
191where
192 S: Service<RoleClient>,
193 T: IntoTransport<RoleClient, E, A>,
194 E: std::error::Error + Send + Sync + 'static,
195{
196 tokio::select! {
197 result = serve_client_with_ct_inner(service, transport.into_transport(), ct.clone()) => { result }
198 _ = ct.cancelled() => {
199 Err(ClientInitializeError::Cancelled)
200 }
201 }
202}
203
204async fn serve_client_with_ct_inner<S, T>(
205 service: S,
206 transport: T,
207 ct: CancellationToken,
208) -> Result<RunningService<RoleClient, S>, ClientInitializeError>
209where
210 S: Service<RoleClient>,
211 T: Transport<RoleClient> + 'static,
212{
213 let mut transport = transport.into_transport();
214 let id_provider = <Arc<AtomicU32RequestIdProvider>>::default();
215
216 let id = id_provider.next_request_id();
218 let init_request = InitializeRequest {
219 method: Default::default(),
220 params: service.get_info(),
221 extensions: Default::default(),
222 };
223 transport
224 .send(ClientJsonRpcMessage::request(
225 ClientRequest::InitializeRequest(init_request),
226 id.clone(),
227 ))
228 .await
229 .map_err(|error| ClientInitializeError::TransportError {
230 error: DynamicTransportError::new::<T, _>(error),
231 context: "send initialize request".into(),
232 })?;
233
234 let (peer, peer_rx) = Peer::new(id_provider, None);
235
236 let (response, response_id) = expect_response(
237 &mut transport,
238 "initialize response",
239 &service,
240 peer.clone(),
241 )
242 .await?;
243
244 if id != response_id {
245 return Err(ClientInitializeError::ConflictInitResponseId(
246 id,
247 response_id,
248 ));
249 }
250
251 let ServerResult::InitializeResult(initialize_result) = response else {
252 return Err(ClientInitializeError::ExpectedInitResult(Some(response)));
253 };
254 peer.set_peer_info(initialize_result);
255
256 let notification = ClientJsonRpcMessage::notification(
258 ClientNotification::InitializedNotification(InitializedNotification {
259 method: Default::default(),
260 extensions: Default::default(),
261 }),
262 );
263 transport.send(notification).await.map_err(|error| {
264 ClientInitializeError::transport::<T>(error, "send initialized notification")
265 })?;
266 Ok(serve_inner(service, transport, peer, peer_rx, ct))
267}
268
// Generates thin `Peer<RoleClient>` wrappers around `send_request` / `send_notification`.
//
// Arms:
// - `peer_req name Req() => Resp`       : request without params, typed result.
// - `peer_req name Req(Param) => Resp`  : request with params, typed result.
// - `peer_req name Req(Param)? => Resp` : request with optional params, typed result.
// - `peer_req name Req(Param)`          : request with params, expects an empty result.
// - `peer_not name Not(Param)`          : notification with params.
// - `peer_not name Not`                 : notification without params.
//
// Request wrappers return `ServiceError::UnexpectedResponse` when the server answers with a
// `ServerResult` variant other than the expected one.
macro_rules! method {
    (peer_req $method:ident $Req:ident() => $Resp: ident ) => {
        #[doc = concat!("Sends a `", stringify!($Req), "` and returns the server's `", stringify!($Resp), "`.")]
        pub async fn $method(&self) -> Result<$Resp, ServiceError> {
            let result = self
                .send_request(ClientRequest::$Req($Req {
                    method: Default::default(),
                    // Filled like every other generated constructor (including the
                    // param-less `peer_not` arm); omitting it breaks this arm on first use.
                    // NOTE(review): assumes the param-less request type has `extensions`.
                    extensions: Default::default(),
                }))
                .await?;
            match result {
                ServerResult::$Resp(result) => Ok(result),
                _ => Err(ServiceError::UnexpectedResponse),
            }
        }
    };
    (peer_req $method:ident $Req:ident($Param: ident) => $Resp: ident ) => {
        #[doc = concat!("Sends a `", stringify!($Req), "` with the given `", stringify!($Param), "` and returns the server's `", stringify!($Resp), "`.")]
        pub async fn $method(&self, params: $Param) -> Result<$Resp, ServiceError> {
            let result = self
                .send_request(ClientRequest::$Req($Req {
                    method: Default::default(),
                    params,
                    extensions: Default::default(),
                }))
                .await?;
            match result {
                ServerResult::$Resp(result) => Ok(result),
                _ => Err(ServiceError::UnexpectedResponse),
            }
        }
    };
    (peer_req $method:ident $Req:ident($Param: ident)? => $Resp: ident ) => {
        #[doc = concat!("Sends a `", stringify!($Req), "` with optional `", stringify!($Param), "` and returns the server's `", stringify!($Resp), "`.")]
        pub async fn $method(&self, params: Option<$Param>) -> Result<$Resp, ServiceError> {
            let result = self
                .send_request(ClientRequest::$Req($Req {
                    method: Default::default(),
                    params,
                    extensions: Default::default(),
                }))
                .await?;
            match result {
                ServerResult::$Resp(result) => Ok(result),
                _ => Err(ServiceError::UnexpectedResponse),
            }
        }
    };
    (peer_req $method:ident $Req:ident($Param: ident)) => {
        #[doc = concat!("Sends a `", stringify!($Req), "` with the given `", stringify!($Param), "`; succeeds when the server replies with an empty result.")]
        pub async fn $method(&self, params: $Param) -> Result<(), ServiceError> {
            let result = self
                .send_request(ClientRequest::$Req($Req {
                    method: Default::default(),
                    params,
                    extensions: Default::default(),
                }))
                .await?;
            match result {
                ServerResult::EmptyResult(_) => Ok(()),
                _ => Err(ServiceError::UnexpectedResponse),
            }
        }
    };

    (peer_not $method:ident $Not:ident($Param: ident)) => {
        #[doc = concat!("Sends a `", stringify!($Not), "` carrying the given `", stringify!($Param), "` to the server.")]
        pub async fn $method(&self, params: $Param) -> Result<(), ServiceError> {
            self.send_notification(ClientNotification::$Not($Not {
                method: Default::default(),
                params,
                extensions: Default::default(),
            }))
            .await?;
            Ok(())
        }
    };
    (peer_not $method:ident $Not:ident) => {
        #[doc = concat!("Sends a `", stringify!($Not), "` to the server.")]
        pub async fn $method(&self) -> Result<(), ServiceError> {
            self.send_notification(ClientNotification::$Not($Not {
                method: Default::default(),
                extensions: Default::default(),
            }))
            .await?;
            Ok(())
        }
    };
}
351
352impl Peer<RoleClient> {
353 method!(peer_req complete CompleteRequest(CompleteRequestParams) => CompleteResult);
354 method!(peer_req set_level SetLevelRequest(SetLevelRequestParams));
355 method!(peer_req get_prompt GetPromptRequest(GetPromptRequestParams) => GetPromptResult);
356 method!(peer_req list_prompts ListPromptsRequest(PaginatedRequestParams)? => ListPromptsResult);
357 method!(peer_req list_resources ListResourcesRequest(PaginatedRequestParams)? => ListResourcesResult);
358 method!(peer_req list_resource_templates ListResourceTemplatesRequest(PaginatedRequestParams)? => ListResourceTemplatesResult);
359 method!(peer_req read_resource ReadResourceRequest(ReadResourceRequestParams) => ReadResourceResult);
360 method!(peer_req subscribe SubscribeRequest(SubscribeRequestParams) );
361 method!(peer_req unsubscribe UnsubscribeRequest(UnsubscribeRequestParams));
362 method!(peer_req call_tool CallToolRequest(CallToolRequestParams) => CallToolResult);
363 method!(peer_req list_tools ListToolsRequest(PaginatedRequestParams)? => ListToolsResult);
364
365 method!(peer_not notify_cancelled CancelledNotification(CancelledNotificationParam));
366 method!(peer_not notify_progress ProgressNotification(ProgressNotificationParam));
367 method!(peer_not notify_initialized InitializedNotification);
368 method!(peer_not notify_roots_list_changed RootsListChangedNotification);
369}
370
371impl Peer<RoleClient> {
372 pub async fn list_all_tools(&self) -> Result<Vec<crate::model::Tool>, ServiceError> {
376 let mut tools = Vec::new();
377 let mut cursor = None;
378 loop {
379 let result = self
380 .list_tools(Some(PaginatedRequestParams { meta: None, cursor }))
381 .await?;
382 tools.extend(result.tools);
383 cursor = result.next_cursor;
384 if cursor.is_none() {
385 break;
386 }
387 }
388 Ok(tools)
389 }
390
391 pub async fn list_all_prompts(&self) -> Result<Vec<crate::model::Prompt>, ServiceError> {
395 let mut prompts = Vec::new();
396 let mut cursor = None;
397 loop {
398 let result = self
399 .list_prompts(Some(PaginatedRequestParams { meta: None, cursor }))
400 .await?;
401 prompts.extend(result.prompts);
402 cursor = result.next_cursor;
403 if cursor.is_none() {
404 break;
405 }
406 }
407 Ok(prompts)
408 }
409
410 pub async fn list_all_resources(&self) -> Result<Vec<crate::model::Resource>, ServiceError> {
414 let mut resources = Vec::new();
415 let mut cursor = None;
416 loop {
417 let result = self
418 .list_resources(Some(PaginatedRequestParams { meta: None, cursor }))
419 .await?;
420 resources.extend(result.resources);
421 cursor = result.next_cursor;
422 if cursor.is_none() {
423 break;
424 }
425 }
426 Ok(resources)
427 }
428
429 pub async fn list_all_resource_templates(
433 &self,
434 ) -> Result<Vec<crate::model::ResourceTemplate>, ServiceError> {
435 let mut resource_templates = Vec::new();
436 let mut cursor = None;
437 loop {
438 let result = self
439 .list_resource_templates(Some(PaginatedRequestParams { meta: None, cursor }))
440 .await?;
441 resource_templates.extend(result.resource_templates);
442 cursor = result.next_cursor;
443 if cursor.is_none() {
444 break;
445 }
446 }
447 Ok(resource_templates)
448 }
449
450 pub async fn complete_prompt_argument(
461 &self,
462 prompt_name: impl Into<String>,
463 argument_name: impl Into<String>,
464 current_value: impl Into<String>,
465 context: Option<CompletionContext>,
466 ) -> Result<CompletionInfo, ServiceError> {
467 let request = CompleteRequestParams {
468 meta: None,
469 r#ref: Reference::for_prompt(prompt_name),
470 argument: ArgumentInfo {
471 name: argument_name.into(),
472 value: current_value.into(),
473 },
474 context,
475 };
476
477 let result = self.complete(request).await?;
478 Ok(result.completion)
479 }
480
481 pub async fn complete_resource_argument(
492 &self,
493 uri_template: impl Into<String>,
494 argument_name: impl Into<String>,
495 current_value: impl Into<String>,
496 context: Option<CompletionContext>,
497 ) -> Result<CompletionInfo, ServiceError> {
498 let request = CompleteRequestParams {
499 meta: None,
500 r#ref: Reference::for_resource(uri_template),
501 argument: ArgumentInfo {
502 name: argument_name.into(),
503 value: current_value.into(),
504 },
505 context,
506 };
507
508 let result = self.complete(request).await?;
509 Ok(result.completion)
510 }
511
512 pub async fn complete_prompt_simple(
517 &self,
518 prompt_name: impl Into<String>,
519 argument_name: impl Into<String>,
520 current_value: impl Into<String>,
521 ) -> Result<Vec<String>, ServiceError> {
522 let completion = self
523 .complete_prompt_argument(prompt_name, argument_name, current_value, None)
524 .await?;
525 Ok(completion.values)
526 }
527
528 pub async fn complete_resource_simple(
533 &self,
534 uri_template: impl Into<String>,
535 argument_name: impl Into<String>,
536 current_value: impl Into<String>,
537 ) -> Result<Vec<String>, ServiceError> {
538 let completion = self
539 .complete_resource_argument(uri_template, argument_name, current_value, None)
540 .await?;
541 Ok(completion.values)
542 }
543}