1mod channel;
34mod handler;
35#[cfg(feature = "http-client")]
36mod http;
37#[cfg(feature = "oauth-client")]
38mod oauth;
39mod stdio;
40mod transport;
41
42pub use channel::ChannelTransport;
43pub use handler::{ClientHandler, NotificationHandler, ServerNotification};
44#[cfg(feature = "http-client")]
45pub use http::{HttpClientConfig, HttpClientTransport};
46#[cfg(feature = "oauth-client")]
47pub use oauth::{
48 OAuthClientCredentials, OAuthClientCredentialsBuilder, OAuthClientError, TokenProvider,
49};
50pub use stdio::StdioClientTransport;
51pub use transport::ClientTransport;
52
53use std::collections::HashMap;
54use std::sync::atomic::{AtomicBool, AtomicI64, Ordering};
55use std::sync::{Arc, OnceLock};
56
57use tokio::sync::{RwLock, mpsc, oneshot};
58use tokio::task::JoinHandle;
59
60use crate::error::{Error, Result};
61use crate::protocol::{
62 CallToolParams, CallToolResult, ClientCapabilities, CompleteParams, CompleteResult,
63 CompletionArgument, CompletionReference, ElicitationCapability, GetPromptParams,
64 GetPromptResult, Implementation, InitializeParams, InitializeResult, JsonRpcNotification,
65 JsonRpcRequest, ListPromptsParams, ListPromptsResult, ListResourceTemplatesParams,
66 ListResourceTemplatesResult, ListResourcesParams, ListResourcesResult, ListRootsResult,
67 ListToolsParams, ListToolsResult, PromptDefinition, ReadResourceParams, ReadResourceResult,
68 RequestId, ResourceDefinition, ResourceTemplateDefinition, Root, RootsCapability,
69 SamplingCapability, ToolDefinition, notifications,
70};
71use tower_mcp_types::JsonRpcError;
72
/// Commands sent from an [`McpClient`] handle to its background message loop.
///
/// The message loop owns the transport, so every outbound message is routed
/// through this enum.
enum LoopCommand {
    /// Send a JSON-RPC request. The loop assigns the request id and later
    /// completes `response_tx` with the server's `result` or an error.
    Request {
        /// JSON-RPC method name.
        method: String,
        /// Request params, already converted to a JSON value by the caller.
        params: serde_json::Value,
        /// Receives the response value or an error for this request.
        response_tx: oneshot::Sender<Result<serde_json::Value>>,
    },
    /// Send a JSON-RPC notification (no id, no response is awaited).
    Notify {
        /// JSON-RPC method name.
        method: String,
        /// Notification params, already converted to a JSON value by the caller.
        params: serde_json::Value,
    },
    /// Ask the message loop to exit.
    Shutdown,
}
89
/// An MCP client talking to a server through a [`ClientTransport`].
///
/// Transport I/O happens on a background task spawned at connect time; this
/// handle communicates with it over a command channel. Most MCP request
/// helpers require a completed [`McpClient::initialize`] call and fail
/// otherwise.
pub struct McpClient {
    /// Sends requests, notifications, and shutdown commands to the message loop.
    command_tx: mpsc::Sender<LoopCommand>,
    /// Handle to the spawned message loop; taken by `shutdown` or aborted on drop.
    task: Option<JoinHandle<()>>,
    /// Set after the `initialize` request and `initialized` notification succeed.
    initialized: AtomicBool,
    /// The server's `initialize` response; only the first value set is kept.
    server_info: OnceLock<InitializeResult>,
    /// Capabilities advertised to the server in the `initialize` request.
    capabilities: ClientCapabilities,
    /// Roots shared with the message loop, which consults them for `roots/list`.
    roots: Arc<RwLock<Vec<Root>>>,
    /// Cleared by the message loop when it exits.
    connected: Arc<AtomicBool>,
}
135
/// Builder for [`McpClient`] that configures advertised capabilities and the
/// initial set of roots before connecting.
pub struct McpClientBuilder {
    /// Capabilities sent to the server in the `initialize` request.
    capabilities: ClientCapabilities,
    /// Roots the client starts with.
    roots: Vec<Root>,
}
159
160impl McpClientBuilder {
161 pub fn new() -> Self {
163 Self {
164 capabilities: ClientCapabilities::default(),
165 roots: Vec::new(),
166 }
167 }
168
169 pub fn with_roots(mut self, roots: Vec<Root>) -> Self {
174 self.roots = roots;
175 self.capabilities.roots = Some(RootsCapability { list_changed: true });
176 self
177 }
178
179 pub fn with_capabilities(mut self, capabilities: ClientCapabilities) -> Self {
181 self.capabilities = capabilities;
182 self
183 }
184
185 pub fn with_sampling(mut self) -> Self {
192 self.capabilities.sampling = Some(SamplingCapability::default());
193 self
194 }
195
196 pub fn with_elicitation(mut self) -> Self {
203 self.capabilities.elicitation = Some(ElicitationCapability::default());
204 self
205 }
206
207 pub async fn connect<T, H>(self, transport: T, handler: H) -> Result<McpClient>
212 where
213 T: ClientTransport,
214 H: ClientHandler,
215 {
216 McpClient::connect_inner(transport, handler, self.capabilities, self.roots).await
217 }
218
219 pub async fn connect_simple<T: ClientTransport>(self, transport: T) -> Result<McpClient> {
223 self.connect(transport, ()).await
224 }
225}
226
227impl Default for McpClientBuilder {
228 fn default() -> Self {
229 Self::new()
230 }
231}
232
233impl McpClient {
234 pub async fn connect<T: ClientTransport>(transport: T) -> Result<Self> {
238 McpClientBuilder::new().connect_simple(transport).await
239 }
240
241 pub async fn connect_with_handler<T, H>(transport: T, handler: H) -> Result<Self>
243 where
244 T: ClientTransport,
245 H: ClientHandler,
246 {
247 McpClientBuilder::new().connect(transport, handler).await
248 }
249
250 pub fn builder() -> McpClientBuilder {
252 McpClientBuilder::new()
253 }
254
255 async fn connect_inner<T, H>(
257 transport: T,
258 handler: H,
259 capabilities: ClientCapabilities,
260 roots: Vec<Root>,
261 ) -> Result<Self>
262 where
263 T: ClientTransport,
264 H: ClientHandler,
265 {
266 let (command_tx, command_rx) = mpsc::channel::<LoopCommand>(64);
267 let connected = Arc::new(AtomicBool::new(true));
268 let roots = Arc::new(RwLock::new(roots));
269
270 let loop_connected = connected.clone();
271 let loop_roots = roots.clone();
272
273 let task = tokio::spawn(async move {
274 message_loop(transport, handler, command_rx, loop_connected, loop_roots).await;
275 });
276
277 Ok(Self {
278 command_tx,
279 task: Some(task),
280 initialized: AtomicBool::new(false),
281 server_info: OnceLock::new(),
282 capabilities,
283 roots,
284 connected,
285 })
286 }
287
288 pub fn is_initialized(&self) -> bool {
290 self.initialized.load(Ordering::Acquire)
291 }
292
293 pub fn is_connected(&self) -> bool {
295 self.connected.load(Ordering::Acquire)
296 }
297
298 pub fn server_info(&self) -> Option<&InitializeResult> {
300 self.server_info.get()
301 }
302
303 pub async fn initialize(
308 &self,
309 client_name: &str,
310 client_version: &str,
311 ) -> Result<&InitializeResult> {
312 let params = InitializeParams {
313 protocol_version: crate::protocol::LATEST_PROTOCOL_VERSION.to_string(),
314 capabilities: self.capabilities.clone(),
315 client_info: Implementation {
316 name: client_name.to_string(),
317 version: client_version.to_string(),
318 ..Default::default()
319 },
320 meta: None,
321 };
322
323 let result: InitializeResult = self.send_request("initialize", ¶ms).await?;
324 let _ = self.server_info.set(result);
325
326 self.send_notification("notifications/initialized", &serde_json::json!({}))
328 .await?;
329 self.initialized.store(true, Ordering::Release);
330
331 Ok(self.server_info.get().unwrap())
332 }
333
334 pub async fn list_tools(&self) -> Result<ListToolsResult> {
336 self.ensure_initialized()?;
337 self.send_request(
338 "tools/list",
339 &ListToolsParams {
340 cursor: None,
341 meta: None,
342 },
343 )
344 .await
345 }
346
347 pub async fn call_tool(
349 &self,
350 name: &str,
351 arguments: serde_json::Value,
352 ) -> Result<CallToolResult> {
353 self.ensure_initialized()?;
354 let params = CallToolParams {
355 name: name.to_string(),
356 arguments,
357 meta: None,
358 task: None,
359 };
360 self.send_request("tools/call", ¶ms).await
361 }
362
363 pub async fn list_resources(&self) -> Result<ListResourcesResult> {
365 self.ensure_initialized()?;
366 self.send_request(
367 "resources/list",
368 &ListResourcesParams {
369 cursor: None,
370 meta: None,
371 },
372 )
373 .await
374 }
375
376 pub async fn read_resource(&self, uri: &str) -> Result<ReadResourceResult> {
378 self.ensure_initialized()?;
379 let params = ReadResourceParams {
380 uri: uri.to_string(),
381 meta: None,
382 };
383 self.send_request("resources/read", ¶ms).await
384 }
385
386 pub async fn list_prompts(&self) -> Result<ListPromptsResult> {
388 self.ensure_initialized()?;
389 self.send_request(
390 "prompts/list",
391 &ListPromptsParams {
392 cursor: None,
393 meta: None,
394 },
395 )
396 .await
397 }
398
399 pub async fn list_tools_with_cursor(&self, cursor: Option<String>) -> Result<ListToolsResult> {
401 self.ensure_initialized()?;
402 self.send_request("tools/list", &ListToolsParams { cursor, meta: None })
403 .await
404 }
405
406 pub async fn list_resources_with_cursor(
408 &self,
409 cursor: Option<String>,
410 ) -> Result<ListResourcesResult> {
411 self.ensure_initialized()?;
412 self.send_request(
413 "resources/list",
414 &ListResourcesParams { cursor, meta: None },
415 )
416 .await
417 }
418
419 pub async fn list_resource_templates(&self) -> Result<ListResourceTemplatesResult> {
421 self.ensure_initialized()?;
422 self.send_request(
423 "resources/templates/list",
424 &ListResourceTemplatesParams {
425 cursor: None,
426 meta: None,
427 },
428 )
429 .await
430 }
431
432 pub async fn list_resource_templates_with_cursor(
434 &self,
435 cursor: Option<String>,
436 ) -> Result<ListResourceTemplatesResult> {
437 self.ensure_initialized()?;
438 self.send_request(
439 "resources/templates/list",
440 &ListResourceTemplatesParams { cursor, meta: None },
441 )
442 .await
443 }
444
445 pub async fn list_prompts_with_cursor(
447 &self,
448 cursor: Option<String>,
449 ) -> Result<ListPromptsResult> {
450 self.ensure_initialized()?;
451 self.send_request("prompts/list", &ListPromptsParams { cursor, meta: None })
452 .await
453 }
454
455 pub async fn list_all_tools(&self) -> Result<Vec<ToolDefinition>> {
457 let mut all = Vec::new();
458 let mut cursor = None;
459 loop {
460 let result = self.list_tools_with_cursor(cursor).await?;
461 all.extend(result.tools);
462 match result.next_cursor {
463 Some(c) => cursor = Some(c),
464 None => break,
465 }
466 }
467 Ok(all)
468 }
469
470 pub async fn list_all_resources(&self) -> Result<Vec<ResourceDefinition>> {
472 let mut all = Vec::new();
473 let mut cursor = None;
474 loop {
475 let result = self.list_resources_with_cursor(cursor).await?;
476 all.extend(result.resources);
477 match result.next_cursor {
478 Some(c) => cursor = Some(c),
479 None => break,
480 }
481 }
482 Ok(all)
483 }
484
485 pub async fn list_all_resource_templates(&self) -> Result<Vec<ResourceTemplateDefinition>> {
487 let mut all = Vec::new();
488 let mut cursor = None;
489 loop {
490 let result = self.list_resource_templates_with_cursor(cursor).await?;
491 all.extend(result.resource_templates);
492 match result.next_cursor {
493 Some(c) => cursor = Some(c),
494 None => break,
495 }
496 }
497 Ok(all)
498 }
499
500 pub async fn list_all_prompts(&self) -> Result<Vec<PromptDefinition>> {
502 let mut all = Vec::new();
503 let mut cursor = None;
504 loop {
505 let result = self.list_prompts_with_cursor(cursor).await?;
506 all.extend(result.prompts);
507 match result.next_cursor {
508 Some(c) => cursor = Some(c),
509 None => break,
510 }
511 }
512 Ok(all)
513 }
514
515 pub async fn call_tool_text(&self, name: &str, arguments: serde_json::Value) -> Result<String> {
523 let result = self.call_tool(name, arguments).await?;
524 if result.is_error {
525 return Err(Error::Internal(result.all_text()));
526 }
527 Ok(result.all_text())
528 }
529
530 pub async fn get_prompt(
532 &self,
533 name: &str,
534 arguments: Option<std::collections::HashMap<String, String>>,
535 ) -> Result<GetPromptResult> {
536 self.ensure_initialized()?;
537 let params = GetPromptParams {
538 name: name.to_string(),
539 arguments: arguments.unwrap_or_default(),
540 meta: None,
541 };
542 self.send_request("prompts/get", ¶ms).await
543 }
544
545 pub async fn ping(&self) -> Result<()> {
547 let _: serde_json::Value = self.send_request("ping", &serde_json::json!({})).await?;
548 Ok(())
549 }
550
551 pub async fn complete(
553 &self,
554 reference: CompletionReference,
555 argument_name: &str,
556 argument_value: &str,
557 ) -> Result<CompleteResult> {
558 self.ensure_initialized()?;
559 let params = CompleteParams {
560 reference,
561 argument: CompletionArgument::new(argument_name, argument_value),
562 context: None,
563 meta: None,
564 };
565 self.send_request("completion/complete", ¶ms).await
566 }
567
568 pub async fn complete_prompt_arg(
570 &self,
571 prompt_name: &str,
572 argument_name: &str,
573 argument_value: &str,
574 ) -> Result<CompleteResult> {
575 self.complete(
576 CompletionReference::prompt(prompt_name),
577 argument_name,
578 argument_value,
579 )
580 .await
581 }
582
583 pub async fn complete_resource_uri(
585 &self,
586 resource_uri: &str,
587 argument_name: &str,
588 argument_value: &str,
589 ) -> Result<CompleteResult> {
590 self.complete(
591 CompletionReference::resource(resource_uri),
592 argument_name,
593 argument_value,
594 )
595 .await
596 }
597
598 pub async fn request<P: serde::Serialize, R: serde::de::DeserializeOwned>(
600 &self,
601 method: &str,
602 params: &P,
603 ) -> Result<R> {
604 self.send_request(method, params).await
605 }
606
607 pub async fn notify<P: serde::Serialize>(&self, method: &str, params: &P) -> Result<()> {
609 self.send_notification(method, params).await
610 }
611
612 pub async fn roots(&self) -> Vec<Root> {
614 self.roots.read().await.clone()
615 }
616
617 pub async fn set_roots(&self, roots: Vec<Root>) -> Result<()> {
619 *self.roots.write().await = roots;
620 if self.is_initialized() {
621 self.send_notification(notifications::ROOTS_LIST_CHANGED, &serde_json::json!({}))
622 .await?;
623 }
624 Ok(())
625 }
626
627 pub async fn add_root(&self, root: Root) -> Result<()> {
629 self.roots.write().await.push(root);
630 if self.is_initialized() {
631 self.send_notification(notifications::ROOTS_LIST_CHANGED, &serde_json::json!({}))
632 .await?;
633 }
634 Ok(())
635 }
636
637 pub async fn remove_root(&self, uri: &str) -> Result<bool> {
639 let mut roots = self.roots.write().await;
640 let initial_len = roots.len();
641 roots.retain(|r| r.uri != uri);
642 let removed = roots.len() < initial_len;
643 drop(roots);
644
645 if removed && self.is_initialized() {
646 self.send_notification(notifications::ROOTS_LIST_CHANGED, &serde_json::json!({}))
647 .await?;
648 }
649 Ok(removed)
650 }
651
652 pub async fn list_roots(&self) -> ListRootsResult {
654 ListRootsResult {
655 roots: self.roots.read().await.clone(),
656 meta: None,
657 }
658 }
659
660 pub async fn shutdown(mut self) -> Result<()> {
662 let _ = self.command_tx.send(LoopCommand::Shutdown).await;
663 if let Some(task) = self.task.take() {
664 let _ = task.await;
665 }
666 Ok(())
667 }
668
669 async fn send_request<P: serde::Serialize, R: serde::de::DeserializeOwned>(
672 &self,
673 method: &str,
674 params: &P,
675 ) -> Result<R> {
676 self.ensure_connected()?;
677 let params_value = serde_json::to_value(params)
678 .map_err(|e| Error::Transport(format!("Failed to serialize params: {}", e)))?;
679
680 let (response_tx, response_rx) = oneshot::channel();
681 self.command_tx
682 .send(LoopCommand::Request {
683 method: method.to_string(),
684 params: params_value,
685 response_tx,
686 })
687 .await
688 .map_err(|_| Error::Transport("Connection closed".to_string()))?;
689
690 let result = response_rx
691 .await
692 .map_err(|_| Error::Transport("Connection closed".to_string()))??;
693
694 serde_json::from_value(result)
695 .map_err(|e| Error::Transport(format!("Failed to deserialize response: {}", e)))
696 }
697
698 async fn send_notification<P: serde::Serialize>(&self, method: &str, params: &P) -> Result<()> {
699 self.ensure_connected()?;
700 let params_value = serde_json::to_value(params)
701 .map_err(|e| Error::Transport(format!("Failed to serialize params: {}", e)))?;
702
703 self.command_tx
704 .send(LoopCommand::Notify {
705 method: method.to_string(),
706 params: params_value,
707 })
708 .await
709 .map_err(|_| Error::Transport("Connection closed".to_string()))?;
710
711 Ok(())
712 }
713
714 fn ensure_connected(&self) -> Result<()> {
715 if !self.connected.load(Ordering::Acquire) {
716 return Err(Error::Transport("Connection closed".to_string()));
717 }
718 Ok(())
719 }
720
721 fn ensure_initialized(&self) -> Result<()> {
722 if !self.initialized.load(Ordering::Acquire) {
723 return Err(Error::Transport("Client not initialized".to_string()));
724 }
725 Ok(())
726 }
727}
728
729impl Drop for McpClient {
730 fn drop(&mut self) {
731 if let Some(task) = self.task.take() {
732 task.abort();
733 }
734 }
735}
736
/// A request sent by the message loop that is still awaiting a server
/// response; stored in the loop's map keyed by request id.
struct PendingRequest {
    /// Completes the caller's `send_request` with the response value or an error.
    response_tx: oneshot::Sender<Result<serde_json::Value>>,
}
745
746async fn message_loop<T: ClientTransport, H: ClientHandler>(
748 mut transport: T,
749 handler: H,
750 mut command_rx: mpsc::Receiver<LoopCommand>,
751 connected: Arc<AtomicBool>,
752 roots: Arc<RwLock<Vec<Root>>>,
753) {
754 let handler = Arc::new(handler);
755 let mut pending_requests: HashMap<RequestId, PendingRequest> = HashMap::new();
756 let next_id = AtomicI64::new(1);
757
758 loop {
759 tokio::select! {
760 command = command_rx.recv() => {
762 match command {
763 Some(LoopCommand::Request { method, params, response_tx }) => {
764 let id = RequestId::Number(next_id.fetch_add(1, Ordering::Relaxed));
765
766 let request = JsonRpcRequest::new(id.clone(), &method)
767 .with_params(params);
768 let json = match serde_json::to_string(&request) {
769 Ok(j) => j,
770 Err(e) => {
771 let _ = response_tx.send(Err(Error::Transport(
772 format!("Serialization failed: {}", e)
773 )));
774 continue;
775 }
776 };
777
778 tracing::debug!(method = %method, id = ?id, "Sending request");
779 pending_requests.insert(id, PendingRequest { response_tx });
780
781 if let Err(e) = transport.send(&json).await {
782 tracing::error!(error = %e, "Transport send error");
783 fail_all_pending(&mut pending_requests, &format!("Transport error: {}", e));
784 break;
785 }
786 }
787 Some(LoopCommand::Notify { method, params }) => {
788 let notification = JsonRpcNotification::new(&method)
789 .with_params(params);
790 if let Ok(json) = serde_json::to_string(¬ification) {
791 tracing::debug!(method = %method, "Sending notification");
792 let _ = transport.send(&json).await;
793 }
794 }
795 Some(LoopCommand::Shutdown) | None => {
796 tracing::debug!("Message loop shutting down");
797 break;
798 }
799 }
800 }
801
802 result = transport.recv() => {
804 match result {
805 Ok(Some(line)) => {
806 handle_incoming(
807 &line,
808 &mut pending_requests,
809 &handler,
810 &roots,
811 &mut transport,
812 ).await;
813 }
814 Ok(None) => {
815 tracing::info!("Transport closed (EOF)");
816 break;
817 }
818 Err(e) => {
819 tracing::error!(error = %e, "Transport receive error");
820 break;
821 }
822 }
823 }
824 }
825 }
826
827 connected.store(false, Ordering::Release);
829 fail_all_pending(&mut pending_requests, "Connection closed");
830 let _ = transport.close().await;
831}
832
833async fn handle_incoming<T: ClientTransport, H: ClientHandler>(
835 line: &str,
836 pending_requests: &mut HashMap<RequestId, PendingRequest>,
837 handler: &Arc<H>,
838 roots: &Arc<RwLock<Vec<Root>>>,
839 transport: &mut T,
840) {
841 let parsed: serde_json::Value = match serde_json::from_str(line) {
842 Ok(v) => v,
843 Err(e) => {
844 tracing::warn!(error = %e, "Failed to parse incoming message");
845 return;
846 }
847 };
848
849 if parsed.get("method").is_none()
851 && (parsed.get("result").is_some() || parsed.get("error").is_some())
852 {
853 handle_response(&parsed, pending_requests);
854 return;
855 }
856
857 if parsed.get("id").is_some() && parsed.get("method").is_some() {
859 let id = parse_request_id(&parsed);
860 let method = parsed["method"].as_str().unwrap_or("");
861 let params = parsed.get("params").cloned();
862
863 let result = dispatch_server_request(handler, roots, method, params).await;
864
865 let response = match result {
867 Ok(value) => {
868 if let Some(id) = id {
869 serde_json::json!({
870 "jsonrpc": "2.0",
871 "id": id,
872 "result": value
873 })
874 } else {
875 return;
876 }
877 }
878 Err(error) => {
879 serde_json::json!({
880 "jsonrpc": "2.0",
881 "id": id,
882 "error": {
883 "code": error.code,
884 "message": error.message
885 }
886 })
887 }
888 };
889
890 if let Ok(json) = serde_json::to_string(&response) {
891 let _ = transport.send(&json).await;
892 }
893 return;
894 }
895
896 if parsed.get("method").is_some() && parsed.get("id").is_none() {
898 let method = parsed["method"].as_str().unwrap_or("");
899 let params = parsed.get("params").cloned();
900 let notification = parse_server_notification(method, params);
901 handler.on_notification(notification).await;
902 }
903}
904
905fn handle_response(
907 parsed: &serde_json::Value,
908 pending_requests: &mut HashMap<RequestId, PendingRequest>,
909) {
910 let id = match parse_request_id(parsed) {
911 Some(id) => id,
912 None => {
913 tracing::warn!("Response without id");
914 return;
915 }
916 };
917
918 let pending = match pending_requests.remove(&id) {
919 Some(p) => p,
920 None => {
921 tracing::warn!(id = ?id, "Response for unknown request");
922 return;
923 }
924 };
925
926 tracing::debug!(id = ?id, "Received response");
927
928 if let Some(error) = parsed.get("error") {
929 let code = error.get("code").and_then(|c| c.as_i64()).unwrap_or(-1) as i32;
930 let message = error
931 .get("message")
932 .and_then(|m| m.as_str())
933 .unwrap_or("Unknown error")
934 .to_string();
935 let data = error.get("data").cloned();
936 let json_rpc_error = JsonRpcError {
937 code,
938 message,
939 data,
940 };
941 let _ = pending
942 .response_tx
943 .send(Err(Error::JsonRpc(json_rpc_error)));
944 } else if let Some(result) = parsed.get("result") {
945 let _ = pending.response_tx.send(Ok(result.clone()));
946 } else {
947 let _ = pending
948 .response_tx
949 .send(Err(Error::Transport("Invalid response".to_string())));
950 }
951}
952
953async fn dispatch_server_request<H: ClientHandler>(
955 handler: &Arc<H>,
956 roots: &Arc<RwLock<Vec<Root>>>,
957 method: &str,
958 params: Option<serde_json::Value>,
959) -> std::result::Result<serde_json::Value, JsonRpcError> {
960 match method {
961 "sampling/createMessage" => {
962 let p = serde_json::from_value(params.unwrap_or_default())
963 .map_err(|e| JsonRpcError::invalid_params(e.to_string()))?;
964 let result = handler.handle_create_message(p).await?;
965 serde_json::to_value(result).map_err(|e| JsonRpcError::internal_error(e.to_string()))
966 }
967 "elicitation/create" => {
968 let p = serde_json::from_value(params.unwrap_or_default())
969 .map_err(|e| JsonRpcError::invalid_params(e.to_string()))?;
970 let result = handler.handle_elicit(p).await?;
971 serde_json::to_value(result).map_err(|e| JsonRpcError::internal_error(e.to_string()))
972 }
973 "roots/list" => {
974 let roots_list = roots.read().await;
976 if !roots_list.is_empty() {
977 let result = ListRootsResult {
978 roots: roots_list.clone(),
979 meta: None,
980 };
981 return serde_json::to_value(result)
982 .map_err(|e| JsonRpcError::internal_error(e.to_string()));
983 }
984 drop(roots_list);
985
986 let result = handler.handle_list_roots().await?;
987 serde_json::to_value(result).map_err(|e| JsonRpcError::internal_error(e.to_string()))
988 }
989 "ping" => Ok(serde_json::json!({})),
990 _ => Err(JsonRpcError::method_not_found(method)),
991 }
992}
993
994fn parse_request_id(parsed: &serde_json::Value) -> Option<RequestId> {
996 parsed.get("id").and_then(|id| {
997 if let Some(n) = id.as_i64() {
998 Some(RequestId::Number(n))
999 } else {
1000 id.as_str().map(|s| RequestId::String(s.to_string()))
1001 }
1002 })
1003}
1004
1005fn parse_server_notification(
1007 method: &str,
1008 params: Option<serde_json::Value>,
1009) -> ServerNotification {
1010 match method {
1011 notifications::PROGRESS => {
1012 if let Some(params) = params
1013 && let Ok(p) = serde_json::from_value(params)
1014 {
1015 return ServerNotification::Progress(p);
1016 }
1017 ServerNotification::Unknown {
1018 method: method.to_string(),
1019 params: None,
1020 }
1021 }
1022 notifications::MESSAGE => {
1023 if let Some(params) = params
1024 && let Ok(p) = serde_json::from_value(params)
1025 {
1026 return ServerNotification::LogMessage(p);
1027 }
1028 ServerNotification::Unknown {
1029 method: method.to_string(),
1030 params: None,
1031 }
1032 }
1033 notifications::RESOURCE_UPDATED => {
1034 if let Some(params) = ¶ms
1035 && let Some(uri) = params.get("uri").and_then(|u| u.as_str())
1036 {
1037 return ServerNotification::ResourceUpdated {
1038 uri: uri.to_string(),
1039 };
1040 }
1041 ServerNotification::Unknown {
1042 method: method.to_string(),
1043 params,
1044 }
1045 }
1046 notifications::RESOURCES_LIST_CHANGED => ServerNotification::ResourcesListChanged,
1047 notifications::TOOLS_LIST_CHANGED => ServerNotification::ToolsListChanged,
1048 notifications::PROMPTS_LIST_CHANGED => ServerNotification::PromptsListChanged,
1049 _ => ServerNotification::Unknown {
1050 method: method.to_string(),
1051 params,
1052 },
1053 }
1054}
1055
1056fn fail_all_pending(pending: &mut HashMap<RequestId, PendingRequest>, reason: &str) {
1058 for (_, req) in pending.drain() {
1059 let _ = req
1060 .response_tx
1061 .send(Err(Error::Transport(reason.to_string())));
1062 }
1063}
1064
1065#[cfg(test)]
1066mod tests {
1067 use super::*;
1068 use async_trait::async_trait;
1069 use std::sync::Mutex;
1070
1071 struct MockTransport {
1079 responses: Arc<Mutex<Vec<serde_json::Value>>>,
1081 response_idx: Arc<std::sync::atomic::AtomicUsize>,
1083 incoming_tx: mpsc::Sender<String>,
1085 incoming_rx: mpsc::Receiver<String>,
1087 outgoing: Arc<Mutex<Vec<String>>>,
1089 connected: Arc<AtomicBool>,
1090 }
1091
1092 #[allow(dead_code)]
1093 impl MockTransport {
1094 fn new() -> Self {
1095 let (tx, rx) = mpsc::channel(32);
1096 Self {
1097 responses: Arc::new(Mutex::new(Vec::new())),
1098 response_idx: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
1099 incoming_tx: tx,
1100 incoming_rx: rx,
1101 outgoing: Arc::new(Mutex::new(Vec::new())),
1102 connected: Arc::new(AtomicBool::new(true)),
1103 }
1104 }
1105
1106 fn with_responses(responses: Vec<serde_json::Value>) -> Self {
1112 let (tx, rx) = mpsc::channel(32);
1113 Self {
1114 responses: Arc::new(Mutex::new(responses)),
1115 response_idx: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
1116 incoming_tx: tx,
1117 incoming_rx: rx,
1118 outgoing: Arc::new(Mutex::new(Vec::new())),
1119 connected: Arc::new(AtomicBool::new(true)),
1120 }
1121 }
1122 }
1123
1124 #[async_trait]
1125 impl ClientTransport for MockTransport {
1126 async fn send(&mut self, message: &str) -> Result<()> {
1127 self.outgoing.lock().unwrap().push(message.to_string());
1128
1129 if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(message) {
1131 if let Some(id) = parsed.get("id") {
1133 let idx = self
1134 .response_idx
1135 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
1136 let responses = self.responses.lock().unwrap();
1137 if let Some(result) = responses.get(idx) {
1138 let response = serde_json::json!({
1139 "jsonrpc": "2.0",
1140 "id": id,
1141 "result": result
1142 });
1143 let _ = self.incoming_tx.try_send(response.to_string());
1144 }
1145 }
1146 }
1147
1148 Ok(())
1149 }
1150
1151 async fn recv(&mut self) -> Result<Option<String>> {
1152 match self.incoming_rx.recv().await {
1155 Some(msg) => Ok(Some(msg)),
1156 None => Ok(None),
1157 }
1158 }
1159
1160 fn is_connected(&self) -> bool {
1161 self.connected.load(Ordering::Relaxed)
1162 }
1163
1164 async fn close(&mut self) -> Result<()> {
1165 self.connected.store(false, Ordering::Relaxed);
1166 Ok(())
1167 }
1168 }
1169
1170 fn mock_initialize_response() -> serde_json::Value {
1171 serde_json::json!({
1172 "protocolVersion": "2025-11-25",
1173 "serverInfo": {
1174 "name": "test-server",
1175 "version": "1.0.0"
1176 },
1177 "capabilities": {
1178 "tools": {}
1179 }
1180 })
1181 }
1182
1183 #[tokio::test]
1184 async fn test_client_not_initialized() {
1185 let client = McpClient::connect(MockTransport::with_responses(vec![]))
1186 .await
1187 .unwrap();
1188
1189 let result = client.list_tools().await;
1190 assert!(result.is_err());
1191 assert!(result.unwrap_err().to_string().contains("not initialized"));
1192 }
1193
1194 #[tokio::test]
1195 async fn test_client_initialize() {
1196 let client = McpClient::connect(MockTransport::with_responses(vec![
1197 mock_initialize_response(),
1198 ]))
1199 .await
1200 .unwrap();
1201
1202 assert!(!client.is_initialized());
1203
1204 let result = client.initialize("test-client", "1.0.0").await;
1205 assert!(result.is_ok());
1206 assert!(client.is_initialized());
1207
1208 let server_info = client.server_info().unwrap();
1209 assert_eq!(server_info.server_info.name, "test-server");
1210 }
1211
1212 #[tokio::test]
1213 async fn test_list_tools() {
1214 let client = McpClient::connect(MockTransport::with_responses(vec![
1215 mock_initialize_response(),
1216 serde_json::json!({
1217 "tools": [
1218 {
1219 "name": "test_tool",
1220 "description": "A test tool",
1221 "inputSchema": {
1222 "type": "object",
1223 "properties": {}
1224 }
1225 }
1226 ]
1227 }),
1228 ]))
1229 .await
1230 .unwrap();
1231
1232 client.initialize("test-client", "1.0.0").await.unwrap();
1233 let tools = client.list_tools().await.unwrap();
1234
1235 assert_eq!(tools.tools.len(), 1);
1236 assert_eq!(tools.tools[0].name, "test_tool");
1237 }
1238
1239 #[tokio::test]
1240 async fn test_call_tool() {
1241 let client = McpClient::connect(MockTransport::with_responses(vec![
1242 mock_initialize_response(),
1243 serde_json::json!({
1244 "content": [
1245 {
1246 "type": "text",
1247 "text": "Tool result"
1248 }
1249 ]
1250 }),
1251 ]))
1252 .await
1253 .unwrap();
1254
1255 client.initialize("test-client", "1.0.0").await.unwrap();
1256 let result = client
1257 .call_tool("test_tool", serde_json::json!({"arg": "value"}))
1258 .await
1259 .unwrap();
1260
1261 assert!(!result.content.is_empty());
1262 }
1263
1264 #[tokio::test]
1265 async fn test_list_resources() {
1266 let client = McpClient::connect(MockTransport::with_responses(vec![
1267 mock_initialize_response(),
1268 serde_json::json!({
1269 "resources": [
1270 {
1271 "uri": "file://test.txt",
1272 "name": "Test File"
1273 }
1274 ]
1275 }),
1276 ]))
1277 .await
1278 .unwrap();
1279
1280 client.initialize("test-client", "1.0.0").await.unwrap();
1281 let resources = client.list_resources().await.unwrap();
1282
1283 assert_eq!(resources.resources.len(), 1);
1284 assert_eq!(resources.resources[0].uri, "file://test.txt");
1285 }
1286
1287 #[tokio::test]
1288 async fn test_read_resource() {
1289 let client = McpClient::connect(MockTransport::with_responses(vec![
1290 mock_initialize_response(),
1291 serde_json::json!({
1292 "contents": [
1293 {
1294 "uri": "file://test.txt",
1295 "text": "File contents"
1296 }
1297 ]
1298 }),
1299 ]))
1300 .await
1301 .unwrap();
1302
1303 client.initialize("test-client", "1.0.0").await.unwrap();
1304 let result = client.read_resource("file://test.txt").await.unwrap();
1305
1306 assert_eq!(result.contents.len(), 1);
1307 assert_eq!(result.contents[0].text.as_deref(), Some("File contents"));
1308 }
1309
1310 #[tokio::test]
1311 async fn test_list_prompts() {
1312 let client = McpClient::connect(MockTransport::with_responses(vec![
1313 mock_initialize_response(),
1314 serde_json::json!({
1315 "prompts": [
1316 {
1317 "name": "test_prompt",
1318 "description": "A test prompt"
1319 }
1320 ]
1321 }),
1322 ]))
1323 .await
1324 .unwrap();
1325
1326 client.initialize("test-client", "1.0.0").await.unwrap();
1327 let prompts = client.list_prompts().await.unwrap();
1328
1329 assert_eq!(prompts.prompts.len(), 1);
1330 assert_eq!(prompts.prompts[0].name, "test_prompt");
1331 }
1332
1333 #[tokio::test]
1334 async fn test_get_prompt() {
1335 let client = McpClient::connect(MockTransport::with_responses(vec![
1336 mock_initialize_response(),
1337 serde_json::json!({
1338 "messages": [
1339 {
1340 "role": "user",
1341 "content": {
1342 "type": "text",
1343 "text": "Prompt message"
1344 }
1345 }
1346 ]
1347 }),
1348 ]))
1349 .await
1350 .unwrap();
1351
1352 client.initialize("test-client", "1.0.0").await.unwrap();
1353 let result = client.get_prompt("test_prompt", None).await.unwrap();
1354
1355 assert_eq!(result.messages.len(), 1);
1356 }
1357
1358 #[tokio::test]
1359 async fn test_ping() {
1360 let client = McpClient::connect(MockTransport::with_responses(vec![
1361 mock_initialize_response(),
1362 serde_json::json!({}),
1363 ]))
1364 .await
1365 .unwrap();
1366
1367 client.initialize("test-client", "1.0.0").await.unwrap();
1368 let result = client.ping().await;
1369 assert!(result.is_ok());
1370 }
1371
1372 #[tokio::test]
1373 async fn test_with_roots() {
1374 let roots = vec![Root::new("file:///test")];
1375 let client = McpClient::builder()
1376 .with_roots(roots)
1377 .connect_simple(MockTransport::with_responses(vec![]))
1378 .await
1379 .unwrap();
1380
1381 let current_roots = client.roots().await;
1382 assert_eq!(current_roots.len(), 1);
1383 }
1384
1385 #[tokio::test]
1386 async fn test_roots_management() {
1387 let client = McpClient::connect(MockTransport::with_responses(vec![
1388 mock_initialize_response(),
1389 ]))
1390 .await
1391 .unwrap();
1392
1393 assert!(client.roots().await.is_empty());
1395
1396 client.add_root(Root::new("file:///project")).await.unwrap();
1398 assert_eq!(client.roots().await.len(), 1);
1399
1400 client.initialize("test-client", "1.0.0").await.unwrap();
1402
1403 let removed = client.remove_root("file:///project").await.unwrap();
1405 assert!(removed);
1406 assert!(client.roots().await.is_empty());
1407
1408 let not_removed = client.remove_root("file:///nonexistent").await.unwrap();
1410 assert!(!not_removed);
1411 }
1412
1413 #[tokio::test]
1414 async fn test_list_roots() {
1415 let roots = vec![
1416 Root::new("file:///project1"),
1417 Root::with_name("file:///project2", "Project 2"),
1418 ];
1419 let client = McpClient::builder()
1420 .with_roots(roots)
1421 .connect_simple(MockTransport::with_responses(vec![]))
1422 .await
1423 .unwrap();
1424
1425 let result = client.list_roots().await;
1426 assert_eq!(result.roots.len(), 2);
1427 assert_eq!(result.roots[1].name, Some("Project 2".to_string()));
1428 }
1429
#[test]
fn test_builder_with_sampling() {
    // Enabling sampling must populate the sampling capability.
    assert!(McpClientBuilder::new()
        .with_sampling()
        .capabilities
        .sampling
        .is_some());
}
1435
#[test]
fn test_builder_with_elicitation() {
    // Enabling elicitation must populate the elicitation capability.
    assert!(McpClientBuilder::new()
        .with_elicitation()
        .capabilities
        .elicitation
        .is_some());
}
1441
#[test]
fn test_builder_chaining() {
    // Each `with_*` call must keep the capabilities enabled by earlier calls.
    let builder = McpClientBuilder::new()
        .with_sampling()
        .with_elicitation()
        .with_roots(vec![Root::new("file:///project")]);

    let caps = &builder.capabilities;
    assert!(caps.sampling.is_some());
    assert!(caps.elicitation.is_some());
    assert!(caps.roots.is_some());
}
1452
1453 #[tokio::test]
1454 async fn test_bidirectional_sampling_round_trip() {
1455 use crate::protocol::{
1456 ContentRole, CreateMessageParams, CreateMessageResult, SamplingContent,
1457 SamplingContentOrArray,
1458 };
1459
1460 struct RecordingHandler {
1462 called: Arc<AtomicBool>,
1463 }
1464
1465 #[async_trait]
1466 impl ClientHandler for RecordingHandler {
1467 async fn handle_create_message(
1468 &self,
1469 _params: CreateMessageParams,
1470 ) -> std::result::Result<CreateMessageResult, tower_mcp_types::JsonRpcError>
1471 {
1472 self.called.store(true, Ordering::SeqCst);
1473 Ok(CreateMessageResult {
1474 content: SamplingContentOrArray::Single(SamplingContent::Text {
1475 text: "test response".to_string(),
1476 annotations: None,
1477 meta: None,
1478 }),
1479 model: "test-model".to_string(),
1480 role: ContentRole::Assistant,
1481 stop_reason: Some("end_turn".to_string()),
1482 meta: None,
1483 })
1484 }
1485 }
1486
1487 let called = Arc::new(AtomicBool::new(false));
1488 let handler = RecordingHandler {
1489 called: called.clone(),
1490 };
1491
1492 let (inject_tx, rx) = mpsc::channel::<String>(32);
1495 let responses = vec![mock_initialize_response()];
1496 let inject_tx_clone = inject_tx.clone();
1497
1498 let transport = MockTransport {
1499 responses: Arc::new(Mutex::new(responses)),
1500 response_idx: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
1501 incoming_tx: inject_tx,
1502 incoming_rx: rx,
1503 outgoing: Arc::new(Mutex::new(Vec::new())),
1504 connected: Arc::new(AtomicBool::new(true)),
1505 };
1506
1507 let client = McpClient::builder()
1508 .with_sampling()
1509 .connect(transport, handler)
1510 .await
1511 .unwrap();
1512
1513 client.initialize("test-client", "1.0.0").await.unwrap();
1515
1516 let sampling_request = serde_json::json!({
1518 "jsonrpc": "2.0",
1519 "id": 100,
1520 "method": "sampling/createMessage",
1521 "params": {
1522 "messages": [
1523 {
1524 "role": "user",
1525 "content": {
1526 "type": "text",
1527 "text": "Hello"
1528 }
1529 }
1530 ],
1531 "maxTokens": 100
1532 }
1533 });
1534 inject_tx_clone
1535 .send(sampling_request.to_string())
1536 .await
1537 .unwrap();
1538
1539 tokio::time::sleep(std::time::Duration::from_millis(100)).await;
1541
1542 assert!(
1544 called.load(Ordering::SeqCst),
1545 "handle_create_message should have been called"
1546 );
1547 }
1548
1549 #[tokio::test]
1550 async fn test_list_resource_templates() {
1551 let client = McpClient::connect(MockTransport::with_responses(vec![
1552 mock_initialize_response(),
1553 serde_json::json!({
1554 "resourceTemplates": [
1555 {
1556 "uriTemplate": "file:///{path}",
1557 "name": "File Template",
1558 "description": "A file template"
1559 }
1560 ]
1561 }),
1562 ]))
1563 .await
1564 .unwrap();
1565
1566 client.initialize("test-client", "1.0.0").await.unwrap();
1567 let result = client.list_resource_templates().await.unwrap();
1568
1569 assert_eq!(result.resource_templates.len(), 1);
1570 assert_eq!(result.resource_templates[0].name, "File Template");
1571 }
1572
1573 #[tokio::test]
1574 async fn test_list_all_tools_single_page() {
1575 let client = McpClient::connect(MockTransport::with_responses(vec![
1576 mock_initialize_response(),
1577 serde_json::json!({
1578 "tools": [
1579 {
1580 "name": "tool_a",
1581 "description": "Tool A",
1582 "inputSchema": { "type": "object", "properties": {} }
1583 },
1584 {
1585 "name": "tool_b",
1586 "description": "Tool B",
1587 "inputSchema": { "type": "object", "properties": {} }
1588 }
1589 ]
1590 }),
1591 ]))
1592 .await
1593 .unwrap();
1594
1595 client.initialize("test-client", "1.0.0").await.unwrap();
1596 let tools = client.list_all_tools().await.unwrap();
1597
1598 assert_eq!(tools.len(), 2);
1599 assert_eq!(tools[0].name, "tool_a");
1600 assert_eq!(tools[1].name, "tool_b");
1601 }
1602
1603 #[tokio::test]
1604 async fn test_list_all_tools_paginated() {
1605 let client = McpClient::connect(MockTransport::with_responses(vec![
1606 mock_initialize_response(),
1607 serde_json::json!({
1609 "tools": [
1610 {
1611 "name": "tool_a",
1612 "description": "Tool A",
1613 "inputSchema": { "type": "object", "properties": {} }
1614 }
1615 ],
1616 "nextCursor": "page2"
1617 }),
1618 serde_json::json!({
1620 "tools": [
1621 {
1622 "name": "tool_b",
1623 "description": "Tool B",
1624 "inputSchema": { "type": "object", "properties": {} }
1625 }
1626 ]
1627 }),
1628 ]))
1629 .await
1630 .unwrap();
1631
1632 client.initialize("test-client", "1.0.0").await.unwrap();
1633 let tools = client.list_all_tools().await.unwrap();
1634
1635 assert_eq!(tools.len(), 2);
1636 assert_eq!(tools[0].name, "tool_a");
1637 assert_eq!(tools[1].name, "tool_b");
1638 }
1639
1640 #[tokio::test]
1641 async fn test_call_tool_text_success() {
1642 let client = McpClient::connect(MockTransport::with_responses(vec![
1643 mock_initialize_response(),
1644 serde_json::json!({
1645 "content": [
1646 { "type": "text", "text": "Hello " },
1647 { "type": "text", "text": "World" }
1648 ]
1649 }),
1650 ]))
1651 .await
1652 .unwrap();
1653
1654 client.initialize("test-client", "1.0.0").await.unwrap();
1655 let text = client
1656 .call_tool_text("test_tool", serde_json::json!({}))
1657 .await
1658 .unwrap();
1659
1660 assert_eq!(text, "Hello World");
1661 }
1662
1663 #[tokio::test]
1664 async fn test_call_tool_text_error() {
1665 let client = McpClient::connect(MockTransport::with_responses(vec![
1666 mock_initialize_response(),
1667 serde_json::json!({
1668 "content": [
1669 { "type": "text", "text": "something went wrong" }
1670 ],
1671 "isError": true
1672 }),
1673 ]))
1674 .await
1675 .unwrap();
1676
1677 client.initialize("test-client", "1.0.0").await.unwrap();
1678 let result = client
1679 .call_tool_text("test_tool", serde_json::json!({}))
1680 .await;
1681
1682 assert!(result.is_err());
1683 let err = result.unwrap_err();
1684 assert!(
1685 err.to_string().contains("something went wrong"),
1686 "Error message should contain tool error text, got: {}",
1687 err
1688 );
1689 }
1690
#[test]
fn test_server_notification_parsing() {
    // `parse_server_notification` is synchronous (its result is matched
    // directly, never awaited), so this test needs no async runtime.

    // Parameterless list-changed notifications map to unit variants.
    let notification = parse_server_notification("notifications/tools/list_changed", None);
    assert!(matches!(notification, ServerNotification::ToolsListChanged));

    let notification = parse_server_notification("notifications/resources/list_changed", None);
    assert!(matches!(
        notification,
        ServerNotification::ResourcesListChanged
    ));

    // `resources/updated` carries the updated URI from its params.
    let notification = parse_server_notification(
        "notifications/resources/updated",
        Some(serde_json::json!({"uri": "file:///test"})),
    );
    let ServerNotification::ResourceUpdated { uri } = notification else {
        panic!("Expected ResourceUpdated");
    };
    assert_eq!(uri, "file:///test");

    // Unrecognised methods fall through to `Unknown`, keeping method and params.
    let notification =
        parse_server_notification("custom/notification", Some(serde_json::json!({"data": 42})));
    let ServerNotification::Unknown { method, params } = notification else {
        panic!("Expected Unknown");
    };
    assert_eq!(method, "custom/notification");
    assert!(params.is_some());
}
1723}