1use std::borrow::Cow;
2
3use thiserror::Error;
4
5use super::*;
6use crate::{
7 model::{
8 ArgumentInfo, CallToolRequest, CallToolRequestParams, CallToolResult,
9 CancelledNotification, CancelledNotificationParam, ClientInfo, ClientJsonRpcMessage,
10 ClientNotification, ClientRequest, ClientResult, CompleteRequest, CompleteRequestParams,
11 CompleteResult, CompletionContext, CompletionInfo, ErrorData, GetPromptRequest,
12 GetPromptRequestParams, GetPromptResult, InitializeRequest, InitializedNotification,
13 JsonRpcResponse, ListPromptsRequest, ListPromptsResult, ListResourceTemplatesRequest,
14 ListResourceTemplatesResult, ListResourcesRequest, ListResourcesResult, ListToolsRequest,
15 ListToolsResult, PaginatedRequestParams, ProgressNotification, ProgressNotificationParam,
16 ReadResourceRequest, ReadResourceRequestParams, ReadResourceResult, Reference, RequestId,
17 RootsListChangedNotification, ServerInfo, ServerJsonRpcMessage, ServerNotification,
18 ServerRequest, ServerResult, SetLevelRequest, SetLevelRequestParams, SubscribeRequest,
19 SubscribeRequestParams, UnsubscribeRequest, UnsubscribeRequestParams,
20 },
21 transport::DynamicTransportError,
22};
23
/// Errors that can occur while a client performs the MCP initialization handshake
/// (see `serve_client` / `serve_client_with_ct`).
#[derive(Error, Debug)]
#[non_exhaustive]
pub enum ClientInitializeError {
    /// A message other than the expected initialize response was received.
    /// NOTE(review): no constructor for this variant is visible in this module chunk.
    #[error("expect initialized response, but received: {0:?}")]
    ExpectedInitResponse(Option<ServerJsonRpcMessage>),

    /// The server answered the initialize request with a result that is not an
    /// `InitializeResult`; the unexpected result is carried along.
    #[error("expect initialized result, but received: {0:?}")]
    ExpectedInitResult(Option<ServerResult>),

    /// The response id did not match the initialize request id
    /// (fields: expected id, received id).
    #[error("conflict initialized response id: expected {0}, got {1}")]
    ConflictInitResponseId(RequestId, RequestId),

    /// The transport stopped yielding messages; the string names what was being awaited.
    #[error("connection closed: {0}")]
    ConnectionClosed(String),

    /// A transport operation failed. In this module it is built for failed `send` calls.
    #[error("Send message error {error}, when {context}")]
    TransportError {
        /// The transport's error, wrapped via `DynamicTransportError::new::<T, _>`.
        error: DynamicTransportError,
        /// Human-readable description of the operation that failed.
        context: Cow<'static, str>,
    },

    /// The server replied with a JSON-RPC error message instead of a result.
    #[error("JSON-RPC error: {0}")]
    JsonRpcError(ErrorData),

    /// The cancellation token fired before the handshake completed.
    #[error("Cancelled")]
    Cancelled,
}
54
55impl ClientInitializeError {
56 pub fn transport<T: Transport<RoleClient> + 'static>(
57 error: T::Error,
58 context: impl Into<Cow<'static, str>>,
59 ) -> Self {
60 Self::TransportError {
61 error: DynamicTransportError::new::<T, _>(error),
62 context: context.into(),
63 }
64 }
65}
66
67async fn expect_next_message<T>(
69 transport: &mut T,
70 context: &str,
71) -> Result<ServerJsonRpcMessage, ClientInitializeError>
72where
73 T: Transport<RoleClient>,
74{
75 transport
76 .receive()
77 .await
78 .ok_or_else(|| ClientInitializeError::ConnectionClosed(context.to_string()))
79}
80
81async fn expect_response<T, S>(
83 transport: &mut T,
84 context: &str,
85 service: &S,
86 peer: Peer<RoleClient>,
87) -> Result<(ServerResult, RequestId), ClientInitializeError>
88where
89 T: Transport<RoleClient>,
90 S: Service<RoleClient>,
91{
92 loop {
93 let message = expect_next_message(transport, context).await?;
94 match message {
95 ServerJsonRpcMessage::Response(JsonRpcResponse { id, result, .. }) => {
97 break Ok((result, id));
98 }
99 ServerJsonRpcMessage::Error(error) => {
101 break Err(ClientInitializeError::JsonRpcError(error.error));
102 }
103 ServerJsonRpcMessage::Notification(mut notification) => {
105 let ServerNotification::LoggingMessageNotification(logging) =
106 &mut notification.notification
107 else {
108 tracing::warn!(?notification, "Received unexpected message");
109 continue;
110 };
111
112 let mut context = NotificationContext {
113 peer: peer.clone(),
114 meta: Meta::default(),
115 extensions: Extensions::default(),
116 };
117
118 if let Some(meta) = logging.extensions.get_mut::<Meta>() {
119 std::mem::swap(&mut context.meta, meta);
120 }
121 std::mem::swap(&mut context.extensions, &mut logging.extensions);
122
123 if let Err(error) = service
124 .handle_notification(notification.notification, context)
125 .await
126 {
127 tracing::warn!(?error, "Handle logging before handshake failed.");
128 }
129 }
130 ServerJsonRpcMessage::Request(ref request)
132 if matches!(request.request, ServerRequest::PingRequest(_)) =>
133 {
134 tracing::trace!("Received ping request. Ignored.")
135 }
136 _ => tracing::warn!(?message, "Received unexpected message"),
138 }
139 }
140}
141
/// Zero-sized marker type for the client side of an MCP connection.
///
/// Used as the role parameter of the generic service machinery, e.g. `Peer<RoleClient>`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct RoleClient;
144
/// Binds the client role to its message types: this side sends `Client*` requests,
/// results and notifications, and receives the `Server*` counterparts from its peer.
impl ServiceRole for RoleClient {
    // Messages this side sends.
    type Req = ClientRequest;
    type Resp = ClientResult;
    type Not = ClientNotification;
    // Messages received from the peer (the server).
    type PeerReq = ServerRequest;
    type PeerResp = ServerResult;
    type PeerNot = ServerNotification;
    // Handshake info exchanged by each side.
    type Info = ClientInfo;
    type PeerInfo = ServerInfo;
    // Error type of the client handshake (see `serve_client_with_ct`).
    type InitializeError = ClientInitializeError;
    const IS_CLIENT: bool = true;
}
157
/// Handle a client uses to send requests and notifications to the connected server
/// (alias for `Peer<RoleClient>`).
pub type ServerSink = Peer<RoleClient>;
159
/// Implements [`ServiceExt`] for every client-role [`Service`]; starting a service
/// delegates the handshake and serving to [`serve_client_with_ct`].
impl<S: Service<RoleClient>> ServiceExt<RoleClient> for S {
    fn serve_with_ct<T, E, A>(
        self,
        transport: T,
        ct: CancellationToken,
    ) -> impl Future<Output = Result<RunningService<RoleClient, Self>, ClientInitializeError>> + Send
    where
        T: IntoTransport<RoleClient, E, A>,
        E: std::error::Error + Send + Sync + 'static,
        Self: Sized,
    {
        serve_client_with_ct(self, transport, ct)
    }
}
174
175pub async fn serve_client<S, T, E, A>(
176 service: S,
177 transport: T,
178) -> Result<RunningService<RoleClient, S>, ClientInitializeError>
179where
180 S: Service<RoleClient>,
181 T: IntoTransport<RoleClient, E, A>,
182 E: std::error::Error + Send + Sync + 'static,
183{
184 serve_client_with_ct(service, transport, Default::default()).await
185}
186
187pub async fn serve_client_with_ct<S, T, E, A>(
188 service: S,
189 transport: T,
190 ct: CancellationToken,
191) -> Result<RunningService<RoleClient, S>, ClientInitializeError>
192where
193 S: Service<RoleClient>,
194 T: IntoTransport<RoleClient, E, A>,
195 E: std::error::Error + Send + Sync + 'static,
196{
197 tokio::select! {
198 result = serve_client_with_ct_inner(service, transport.into_transport(), ct.clone()) => { result }
199 _ = ct.cancelled() => {
200 Err(ClientInitializeError::Cancelled)
201 }
202 }
203}
204
205async fn serve_client_with_ct_inner<S, T>(
206 service: S,
207 transport: T,
208 ct: CancellationToken,
209) -> Result<RunningService<RoleClient, S>, ClientInitializeError>
210where
211 S: Service<RoleClient>,
212 T: Transport<RoleClient> + 'static,
213{
214 let mut transport = transport.into_transport();
215 let id_provider = <Arc<AtomicU32RequestIdProvider>>::default();
216
217 let id = id_provider.next_request_id();
219 let init_request = InitializeRequest {
220 method: Default::default(),
221 params: service.get_info(),
222 extensions: Default::default(),
223 };
224 transport
225 .send(ClientJsonRpcMessage::request(
226 ClientRequest::InitializeRequest(init_request),
227 id.clone(),
228 ))
229 .await
230 .map_err(|error| ClientInitializeError::TransportError {
231 error: DynamicTransportError::new::<T, _>(error),
232 context: "send initialize request".into(),
233 })?;
234
235 let (peer, peer_rx) = Peer::new(id_provider, None);
236
237 let (response, response_id) = expect_response(
238 &mut transport,
239 "initialize response",
240 &service,
241 peer.clone(),
242 )
243 .await?;
244
245 if id != response_id {
246 return Err(ClientInitializeError::ConflictInitResponseId(
247 id,
248 response_id,
249 ));
250 }
251
252 let ServerResult::InitializeResult(initialize_result) = response else {
253 return Err(ClientInitializeError::ExpectedInitResult(Some(response)));
254 };
255 peer.set_peer_info(initialize_result);
256
257 let notification = ClientJsonRpcMessage::notification(
259 ClientNotification::InitializedNotification(InitializedNotification {
260 method: Default::default(),
261 extensions: Default::default(),
262 }),
263 );
264 transport.send(notification).await.map_err(|error| {
265 ClientInitializeError::transport::<T>(error, "send initialized notification")
266 })?;
267 Ok(serve_inner(service, transport, peer, peer_rx, ct))
268}
269
// Generates typed wrappers around `send_request` / `send_notification` inside an
// `impl Peer<RoleClient>` block. Supported forms:
// - `peer_req name Req() => Resp`       request without params, expects `ServerResult::Resp`
// - `peer_req name Req(Param) => Resp`  request with params, expects `ServerResult::Resp`
// - `peer_req name Req(Param)? => Resp` request with optional params (`Option<Param>`)
// - `peer_req name Req(Param)`          request with params, expects `ServerResult::EmptyResult`
// - `peer_not name Not(Param)`          notification with params
// - `peer_not name Not`                 notification without params
// For requests, any other `ServerResult` variant maps to `ServiceError::UnexpectedResponse`.
macro_rules! method {
    (peer_req $method:ident $Req:ident() => $Resp: ident ) => {
        pub async fn $method(&self) -> Result<$Resp, ServiceError> {
            let result = self
                .send_request(ClientRequest::$Req($Req {
                    method: Default::default(),
                    // Every other generated message initializes `extensions`; the
                    // param-less request types are presumed to carry it as well
                    // (like `InitializedNotification`) — NOTE(review): confirm.
                    extensions: Default::default(),
                }))
                .await?;
            match result {
                ServerResult::$Resp(result) => Ok(result),
                _ => Err(ServiceError::UnexpectedResponse),
            }
        }
    };
    (peer_req $method:ident $Req:ident($Param: ident) => $Resp: ident ) => {
        pub async fn $method(&self, params: $Param) -> Result<$Resp, ServiceError> {
            let result = self
                .send_request(ClientRequest::$Req($Req {
                    method: Default::default(),
                    params,
                    extensions: Default::default(),
                }))
                .await?;
            match result {
                ServerResult::$Resp(result) => Ok(result),
                _ => Err(ServiceError::UnexpectedResponse),
            }
        }
    };
    (peer_req $method:ident $Req:ident($Param: ident)? => $Resp: ident ) => {
        pub async fn $method(&self, params: Option<$Param>) -> Result<$Resp, ServiceError> {
            let result = self
                .send_request(ClientRequest::$Req($Req {
                    method: Default::default(),
                    params,
                    extensions: Default::default(),
                }))
                .await?;
            match result {
                ServerResult::$Resp(result) => Ok(result),
                _ => Err(ServiceError::UnexpectedResponse),
            }
        }
    };
    (peer_req $method:ident $Req:ident($Param: ident)) => {
        pub async fn $method(&self, params: $Param) -> Result<(), ServiceError> {
            let result = self
                .send_request(ClientRequest::$Req($Req {
                    method: Default::default(),
                    params,
                    extensions: Default::default(),
                }))
                .await?;
            match result {
                ServerResult::EmptyResult(_) => Ok(()),
                _ => Err(ServiceError::UnexpectedResponse),
            }
        }
    };

    (peer_not $method:ident $Not:ident($Param: ident)) => {
        pub async fn $method(&self, params: $Param) -> Result<(), ServiceError> {
            self.send_notification(ClientNotification::$Not($Not {
                method: Default::default(),
                params,
                extensions: Default::default(),
            }))
            .await?;
            Ok(())
        }
    };
    (peer_not $method:ident $Not:ident) => {
        pub async fn $method(&self) -> Result<(), ServiceError> {
            self.send_notification(ClientNotification::$Not($Not {
                method: Default::default(),
                extensions: Default::default(),
            }))
            .await?;
            Ok(())
        }
    };
}
352
353impl Peer<RoleClient> {
354 method!(peer_req complete CompleteRequest(CompleteRequestParams) => CompleteResult);
355 method!(peer_req set_level SetLevelRequest(SetLevelRequestParams));
356 method!(peer_req get_prompt GetPromptRequest(GetPromptRequestParams) => GetPromptResult);
357 method!(peer_req list_prompts ListPromptsRequest(PaginatedRequestParams)? => ListPromptsResult);
358 method!(peer_req list_resources ListResourcesRequest(PaginatedRequestParams)? => ListResourcesResult);
359 method!(peer_req list_resource_templates ListResourceTemplatesRequest(PaginatedRequestParams)? => ListResourceTemplatesResult);
360 method!(peer_req read_resource ReadResourceRequest(ReadResourceRequestParams) => ReadResourceResult);
361 method!(peer_req subscribe SubscribeRequest(SubscribeRequestParams) );
362 method!(peer_req unsubscribe UnsubscribeRequest(UnsubscribeRequestParams));
363 method!(peer_req call_tool CallToolRequest(CallToolRequestParams) => CallToolResult);
364 method!(peer_req list_tools ListToolsRequest(PaginatedRequestParams)? => ListToolsResult);
365
366 method!(peer_not notify_cancelled CancelledNotification(CancelledNotificationParam));
367 method!(peer_not notify_progress ProgressNotification(ProgressNotificationParam));
368 method!(peer_not notify_initialized InitializedNotification);
369 method!(peer_not notify_roots_list_changed RootsListChangedNotification);
370}
371
372impl Peer<RoleClient> {
373 pub async fn list_all_tools(&self) -> Result<Vec<crate::model::Tool>, ServiceError> {
377 let mut tools = Vec::new();
378 let mut cursor = None;
379 loop {
380 let result = self
381 .list_tools(Some(PaginatedRequestParams { meta: None, cursor }))
382 .await?;
383 tools.extend(result.tools);
384 cursor = result.next_cursor;
385 if cursor.is_none() {
386 break;
387 }
388 }
389 Ok(tools)
390 }
391
392 pub async fn list_all_prompts(&self) -> Result<Vec<crate::model::Prompt>, ServiceError> {
396 let mut prompts = Vec::new();
397 let mut cursor = None;
398 loop {
399 let result = self
400 .list_prompts(Some(PaginatedRequestParams { meta: None, cursor }))
401 .await?;
402 prompts.extend(result.prompts);
403 cursor = result.next_cursor;
404 if cursor.is_none() {
405 break;
406 }
407 }
408 Ok(prompts)
409 }
410
411 pub async fn list_all_resources(&self) -> Result<Vec<crate::model::Resource>, ServiceError> {
415 let mut resources = Vec::new();
416 let mut cursor = None;
417 loop {
418 let result = self
419 .list_resources(Some(PaginatedRequestParams { meta: None, cursor }))
420 .await?;
421 resources.extend(result.resources);
422 cursor = result.next_cursor;
423 if cursor.is_none() {
424 break;
425 }
426 }
427 Ok(resources)
428 }
429
430 pub async fn list_all_resource_templates(
434 &self,
435 ) -> Result<Vec<crate::model::ResourceTemplate>, ServiceError> {
436 let mut resource_templates = Vec::new();
437 let mut cursor = None;
438 loop {
439 let result = self
440 .list_resource_templates(Some(PaginatedRequestParams { meta: None, cursor }))
441 .await?;
442 resource_templates.extend(result.resource_templates);
443 cursor = result.next_cursor;
444 if cursor.is_none() {
445 break;
446 }
447 }
448 Ok(resource_templates)
449 }
450
451 pub async fn complete_prompt_argument(
462 &self,
463 prompt_name: impl Into<String>,
464 argument_name: impl Into<String>,
465 current_value: impl Into<String>,
466 context: Option<CompletionContext>,
467 ) -> Result<CompletionInfo, ServiceError> {
468 let request = CompleteRequestParams {
469 meta: None,
470 r#ref: Reference::for_prompt(prompt_name),
471 argument: ArgumentInfo {
472 name: argument_name.into(),
473 value: current_value.into(),
474 },
475 context,
476 };
477
478 let result = self.complete(request).await?;
479 Ok(result.completion)
480 }
481
482 pub async fn complete_resource_argument(
493 &self,
494 uri_template: impl Into<String>,
495 argument_name: impl Into<String>,
496 current_value: impl Into<String>,
497 context: Option<CompletionContext>,
498 ) -> Result<CompletionInfo, ServiceError> {
499 let request = CompleteRequestParams {
500 meta: None,
501 r#ref: Reference::for_resource(uri_template),
502 argument: ArgumentInfo {
503 name: argument_name.into(),
504 value: current_value.into(),
505 },
506 context,
507 };
508
509 let result = self.complete(request).await?;
510 Ok(result.completion)
511 }
512
513 pub async fn complete_prompt_simple(
518 &self,
519 prompt_name: impl Into<String>,
520 argument_name: impl Into<String>,
521 current_value: impl Into<String>,
522 ) -> Result<Vec<String>, ServiceError> {
523 let completion = self
524 .complete_prompt_argument(prompt_name, argument_name, current_value, None)
525 .await?;
526 Ok(completion.values)
527 }
528
529 pub async fn complete_resource_simple(
534 &self,
535 uri_template: impl Into<String>,
536 argument_name: impl Into<String>,
537 current_value: impl Into<String>,
538 ) -> Result<Vec<String>, ServiceError> {
539 let completion = self
540 .complete_resource_argument(uri_template, argument_name, current_value, None)
541 .await?;
542 Ok(completion.values)
543 }
544}