1use std::process::Stdio;
36use std::sync::atomic::{AtomicI64, Ordering};
37
38use async_trait::async_trait;
39use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
40use tokio::process::{Child, Command};
41
42use crate::error::{Error, Result};
43use crate::protocol::{
44 CallToolParams, CallToolResult, ClientCapabilities, CompleteParams, CompleteResult,
45 CompletionArgument, CompletionReference, GetPromptParams, GetPromptResult, Implementation,
46 InitializeParams, InitializeResult, JsonRpcRequest, JsonRpcResponse, ListPromptsParams,
47 ListPromptsResult, ListResourcesParams, ListResourcesResult, ListRootsResult, ListToolsParams,
48 ListToolsResult, ReadResourceParams, ReadResourceResult, Root, RootsCapability, notifications,
49};
50
51#[async_trait]
53pub trait ClientTransport: Send {
54 async fn request(
56 &mut self,
57 method: &str,
58 params: serde_json::Value,
59 ) -> Result<serde_json::Value>;
60
61 async fn notify(&mut self, method: &str, params: serde_json::Value) -> Result<()>;
63
64 fn is_connected(&self) -> bool;
66
67 async fn close(self: Box<Self>) -> Result<()>;
69}
70
/// High-level MCP client layered over a [`ClientTransport`].
///
/// Tracks handshake state, the server's `initialize` result, the capabilities
/// advertised to the server, and the roots exposed to it.
pub struct McpClient<T: ClientTransport> {
    /// Underlying transport used for every request and notification.
    transport: T,
    /// Set to `true` by `initialize` once `notifications/initialized` was sent.
    initialized: bool,
    /// Result of the `initialize` request, once received.
    server_info: Option<InitializeResult>,
    /// Capabilities sent to the server in the `initialize` request.
    capabilities: ClientCapabilities,
    /// Roots returned by `roots()` / `list_roots()`.
    roots: Vec<Root>,
}
81
82impl<T: ClientTransport> McpClient<T> {
83 pub fn new(transport: T) -> Self {
85 Self {
86 transport,
87 initialized: false,
88 server_info: None,
89 capabilities: ClientCapabilities::default(),
90 roots: Vec::new(),
91 }
92 }
93
94 pub fn with_roots(mut self, roots: Vec<Root>) -> Self {
113 self.roots = roots;
114 self.capabilities.roots = Some(RootsCapability { list_changed: true });
115 self
116 }
117
118 pub fn with_capabilities(mut self, capabilities: ClientCapabilities) -> Self {
137 self.capabilities = capabilities;
138 self
139 }
140
141 pub fn server_info(&self) -> Option<&InitializeResult> {
143 self.server_info.as_ref()
144 }
145
146 pub fn is_initialized(&self) -> bool {
148 self.initialized
149 }
150
151 pub fn roots(&self) -> &[Root] {
153 &self.roots
154 }
155
156 pub async fn set_roots(&mut self, roots: Vec<Root>) -> Result<()> {
160 self.roots = roots;
161 if self.initialized {
162 self.notify_roots_changed().await?;
163 }
164 Ok(())
165 }
166
167 pub async fn add_root(&mut self, root: Root) -> Result<()> {
169 self.roots.push(root);
170 if self.initialized {
171 self.notify_roots_changed().await?;
172 }
173 Ok(())
174 }
175
176 pub async fn remove_root(&mut self, uri: &str) -> Result<bool> {
178 let initial_len = self.roots.len();
179 self.roots.retain(|r| r.uri != uri);
180 let removed = self.roots.len() < initial_len;
181 if removed && self.initialized {
182 self.notify_roots_changed().await?;
183 }
184 Ok(removed)
185 }
186
187 async fn notify_roots_changed(&mut self) -> Result<()> {
189 self.transport
190 .notify(notifications::ROOTS_LIST_CHANGED, serde_json::json!({}))
191 .await
192 }
193
194 pub fn list_roots(&self) -> ListRootsResult {
198 ListRootsResult {
199 roots: self.roots.clone(),
200 meta: None,
201 }
202 }
203
204 pub async fn initialize(
206 &mut self,
207 client_name: &str,
208 client_version: &str,
209 ) -> Result<&InitializeResult> {
210 let params = InitializeParams {
211 protocol_version: crate::protocol::LATEST_PROTOCOL_VERSION.to_string(),
212 capabilities: self.capabilities.clone(),
213 client_info: Implementation {
214 name: client_name.to_string(),
215 version: client_version.to_string(),
216 ..Default::default()
217 },
218 meta: None,
219 };
220
221 let result: InitializeResult = self.request("initialize", ¶ms).await?;
222 self.server_info = Some(result);
223
224 self.transport
226 .notify("notifications/initialized", serde_json::json!({}))
227 .await?;
228
229 self.initialized = true;
230
231 Ok(self.server_info.as_ref().unwrap())
232 }
233
234 pub async fn list_tools(&mut self) -> Result<ListToolsResult> {
236 self.ensure_initialized()?;
237 self.request(
238 "tools/list",
239 &ListToolsParams {
240 cursor: None,
241 meta: None,
242 },
243 )
244 .await
245 }
246
247 pub async fn call_tool(
249 &mut self,
250 name: &str,
251 arguments: serde_json::Value,
252 ) -> Result<CallToolResult> {
253 self.ensure_initialized()?;
254 let params = CallToolParams {
255 name: name.to_string(),
256 arguments,
257 meta: None,
258 task: None,
259 };
260 self.request("tools/call", ¶ms).await
261 }
262
263 pub async fn list_resources(&mut self) -> Result<ListResourcesResult> {
265 self.ensure_initialized()?;
266 self.request(
267 "resources/list",
268 &ListResourcesParams {
269 cursor: None,
270 meta: None,
271 },
272 )
273 .await
274 }
275
276 pub async fn read_resource(&mut self, uri: &str) -> Result<ReadResourceResult> {
278 self.ensure_initialized()?;
279 let params = ReadResourceParams {
280 uri: uri.to_string(),
281 meta: None,
282 };
283 self.request("resources/read", ¶ms).await
284 }
285
286 pub async fn list_prompts(&mut self) -> Result<ListPromptsResult> {
288 self.ensure_initialized()?;
289 self.request(
290 "prompts/list",
291 &ListPromptsParams {
292 cursor: None,
293 meta: None,
294 },
295 )
296 .await
297 }
298
299 pub async fn get_prompt(
301 &mut self,
302 name: &str,
303 arguments: Option<std::collections::HashMap<String, String>>,
304 ) -> Result<GetPromptResult> {
305 self.ensure_initialized()?;
306 let params = GetPromptParams {
307 name: name.to_string(),
308 arguments: arguments.unwrap_or_default(),
309 meta: None,
310 };
311 self.request("prompts/get", ¶ms).await
312 }
313
314 pub async fn ping(&mut self) -> Result<()> {
316 let _: serde_json::Value = self.request("ping", &serde_json::json!({})).await?;
317 Ok(())
318 }
319
320 pub async fn complete(
324 &mut self,
325 reference: CompletionReference,
326 argument_name: &str,
327 argument_value: &str,
328 ) -> Result<CompleteResult> {
329 self.ensure_initialized()?;
330 let params = CompleteParams {
331 reference,
332 argument: CompletionArgument::new(argument_name, argument_value),
333 context: None,
334 meta: None,
335 };
336 self.request("completion/complete", ¶ms).await
337 }
338
339 pub async fn complete_prompt_arg(
341 &mut self,
342 prompt_name: &str,
343 argument_name: &str,
344 argument_value: &str,
345 ) -> Result<CompleteResult> {
346 self.complete(
347 CompletionReference::prompt(prompt_name),
348 argument_name,
349 argument_value,
350 )
351 .await
352 }
353
354 pub async fn complete_resource_uri(
356 &mut self,
357 resource_uri: &str,
358 argument_name: &str,
359 argument_value: &str,
360 ) -> Result<CompleteResult> {
361 self.complete(
362 CompletionReference::resource(resource_uri),
363 argument_name,
364 argument_value,
365 )
366 .await
367 }
368
369 pub async fn request<P: serde::Serialize, R: serde::de::DeserializeOwned>(
371 &mut self,
372 method: &str,
373 params: &P,
374 ) -> Result<R> {
375 let params_value = serde_json::to_value(params)
376 .map_err(|e| Error::Transport(format!("Failed to serialize params: {}", e)))?;
377
378 let result = self.transport.request(method, params_value).await?;
379
380 serde_json::from_value(result)
381 .map_err(|e| Error::Transport(format!("Failed to deserialize response: {}", e)))
382 }
383
384 pub async fn notify<P: serde::Serialize>(&mut self, method: &str, params: &P) -> Result<()> {
386 let params_value = serde_json::to_value(params)
387 .map_err(|e| Error::Transport(format!("Failed to serialize params: {}", e)))?;
388
389 self.transport.notify(method, params_value).await
390 }
391
392 fn ensure_initialized(&self) -> Result<()> {
393 if !self.initialized {
394 return Err(Error::Transport("Client not initialized".to_string()));
395 }
396 Ok(())
397 }
398}
399
/// [`ClientTransport`] that exchanges newline-delimited JSON-RPC messages with
/// a child process over its stdin and stdout.
pub struct StdioClientTransport {
    /// Child process handle; taken (left as `None`) by `close`.
    child: Option<Child>,
    /// Pipe used to send messages to the child.
    stdin: tokio::process::ChildStdin,
    /// Buffered pipe used to read the child's messages line by line.
    stdout: BufReader<tokio::process::ChildStdout>,
    /// Next JSON-RPC request id; starts at 1 and is incremented per request.
    request_id: AtomicI64,
}
411
412impl StdioClientTransport {
413 pub async fn spawn(program: &str, args: &[&str]) -> Result<Self> {
415 let mut cmd = Command::new(program);
416 cmd.args(args)
417 .stdin(Stdio::piped())
418 .stdout(Stdio::piped())
419 .stderr(Stdio::inherit());
420
421 let mut child = cmd
422 .spawn()
423 .map_err(|e| Error::Transport(format!("Failed to spawn {}: {}", program, e)))?;
424
425 let stdin = child
426 .stdin
427 .take()
428 .ok_or_else(|| Error::Transport("Failed to get child stdin".to_string()))?;
429 let stdout = child
430 .stdout
431 .take()
432 .ok_or_else(|| Error::Transport("Failed to get child stdout".to_string()))?;
433
434 tracing::info!(program = %program, "Spawned MCP server process");
435
436 Ok(Self {
437 child: Some(child),
438 stdin,
439 stdout: BufReader::new(stdout),
440 request_id: AtomicI64::new(1),
441 })
442 }
443
444 pub fn from_child(mut child: Child) -> Result<Self> {
446 let stdin = child
447 .stdin
448 .take()
449 .ok_or_else(|| Error::Transport("Failed to get child stdin".to_string()))?;
450 let stdout = child
451 .stdout
452 .take()
453 .ok_or_else(|| Error::Transport("Failed to get child stdout".to_string()))?;
454
455 Ok(Self {
456 child: Some(child),
457 stdin,
458 stdout: BufReader::new(stdout),
459 request_id: AtomicI64::new(1),
460 })
461 }
462
463 async fn send_line(&mut self, line: &str) -> Result<()> {
464 self.stdin
465 .write_all(line.as_bytes())
466 .await
467 .map_err(|e| Error::Transport(format!("Failed to write: {}", e)))?;
468 self.stdin
469 .write_all(b"\n")
470 .await
471 .map_err(|e| Error::Transport(format!("Failed to write newline: {}", e)))?;
472 self.stdin
473 .flush()
474 .await
475 .map_err(|e| Error::Transport(format!("Failed to flush: {}", e)))?;
476 Ok(())
477 }
478
479 async fn read_line(&mut self) -> Result<String> {
480 let mut line = String::new();
481 self.stdout
482 .read_line(&mut line)
483 .await
484 .map_err(|e| Error::Transport(format!("Failed to read: {}", e)))?;
485
486 if line.is_empty() {
487 return Err(Error::Transport("Connection closed".to_string()));
488 }
489
490 Ok(line)
491 }
492}
493
494#[async_trait]
495impl ClientTransport for StdioClientTransport {
496 async fn request(
497 &mut self,
498 method: &str,
499 params: serde_json::Value,
500 ) -> Result<serde_json::Value> {
501 let id = self.request_id.fetch_add(1, Ordering::Relaxed);
502 let request = JsonRpcRequest::new(id, method).with_params(params);
503
504 let request_json = serde_json::to_string(&request)
505 .map_err(|e| Error::Transport(format!("Failed to serialize: {}", e)))?;
506
507 tracing::debug!(method = %method, id = %id, "Sending request");
508 self.send_line(&request_json).await?;
509
510 let response_line = self.read_line().await?;
511 tracing::debug!(response = %response_line.trim(), "Received response");
512
513 let response: JsonRpcResponse = serde_json::from_str(response_line.trim())
514 .map_err(|e| Error::Transport(format!("Failed to parse response: {}", e)))?;
515
516 match response {
517 JsonRpcResponse::Result(r) => Ok(r.result),
518 JsonRpcResponse::Error(e) => Err(Error::JsonRpc(e.error)),
519 }
520 }
521
522 async fn notify(&mut self, method: &str, params: serde_json::Value) -> Result<()> {
523 let notification = serde_json::json!({
524 "jsonrpc": "2.0",
525 "method": method,
526 "params": params
527 });
528
529 let json = serde_json::to_string(¬ification)
530 .map_err(|e| Error::Transport(format!("Failed to serialize: {}", e)))?;
531
532 tracing::debug!(method = %method, "Sending notification");
533 self.send_line(&json).await
534 }
535
536 fn is_connected(&self) -> bool {
537 self.child.is_some()
539 }
540
541 async fn close(mut self: Box<Self>) -> Result<()> {
542 drop(self.stdin);
544
545 if let Some(mut child) = self.child.take() {
546 let result =
548 tokio::time::timeout(std::time::Duration::from_secs(5), child.wait()).await;
549
550 match result {
551 Ok(Ok(status)) => {
552 tracing::info!(status = ?status, "Child process exited");
553 }
554 Ok(Err(e)) => {
555 tracing::error!(error = %e, "Error waiting for child");
556 }
557 Err(_) => {
558 tracing::warn!("Timeout waiting for child, killing");
559 let _ = child.kill().await;
560 }
561 }
562 }
563
564 Ok(())
565 }
566}
567
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::VecDeque;
    use std::sync::{Arc, Mutex};

    /// In-memory transport that replays queued responses and records every
    /// request and notification it is asked to send.
    struct MockTransport {
        /// Responses returned, in order, by successive `request` calls.
        responses: Arc<Mutex<VecDeque<serde_json::Value>>>,
        /// `(method, params)` of every request sent.
        requests: Arc<Mutex<Vec<(String, serde_json::Value)>>>,
        /// `(method, params)` of every notification sent.
        notifications: Arc<Mutex<Vec<(String, serde_json::Value)>>>,
        connected: bool,
    }

    impl MockTransport {
        fn new() -> Self {
            Self::with_responses(Vec::new())
        }

        fn with_responses(responses: Vec<serde_json::Value>) -> Self {
            Self {
                responses: Arc::new(Mutex::new(responses.into())),
                requests: Arc::new(Mutex::new(Vec::new())),
                notifications: Arc::new(Mutex::new(Vec::new())),
                connected: true,
            }
        }

        fn get_requests(&self) -> Vec<(String, serde_json::Value)> {
            self.requests.lock().unwrap().clone()
        }

        fn get_notifications(&self) -> Vec<(String, serde_json::Value)> {
            self.notifications.lock().unwrap().clone()
        }
    }

    #[async_trait]
    impl ClientTransport for MockTransport {
        async fn request(
            &mut self,
            method: &str,
            params: serde_json::Value,
        ) -> Result<serde_json::Value> {
            self.requests
                .lock()
                .unwrap()
                .push((method.to_string(), params));
            self.responses
                .lock()
                .unwrap()
                .pop_front()
                .ok_or_else(|| Error::Transport("No more mock responses".to_string()))
        }

        async fn notify(&mut self, method: &str, params: serde_json::Value) -> Result<()> {
            self.notifications
                .lock()
                .unwrap()
                .push((method.to_string(), params));
            Ok(())
        }

        fn is_connected(&self) -> bool {
            self.connected
        }

        async fn close(self: Box<Self>) -> Result<()> {
            Ok(())
        }
    }

    /// Minimal valid `initialize` result advertising a `tools` capability.
    fn mock_initialize_response() -> serde_json::Value {
        serde_json::json!({
            "protocolVersion": "2025-11-25",
            "serverInfo": {
                "name": "test-server",
                "version": "1.0.0"
            },
            "capabilities": {
                "tools": {}
            }
        })
    }

    #[tokio::test]
    async fn test_client_not_initialized() {
        let mut client = McpClient::new(MockTransport::new());

        let result = client.list_tools().await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("not initialized"));
        // The guard must reject the call before anything reaches the transport.
        assert!(client.transport.get_requests().is_empty());
    }

    #[tokio::test]
    async fn test_client_initialize() {
        let transport = MockTransport::with_responses(vec![mock_initialize_response()]);
        let mut client = McpClient::new(transport);

        assert!(!client.is_initialized());

        let result = client.initialize("test-client", "1.0.0").await;
        assert!(result.is_ok());
        assert!(client.is_initialized());

        let server_info = client.server_info().unwrap();
        assert_eq!(server_info.server_info.name, "test-server");

        let requests = client.transport.get_requests();
        assert_eq!(requests.len(), 1);
        assert_eq!(requests[0].0, "initialize");

        let notifications = client.transport.get_notifications();
        assert_eq!(notifications.len(), 1);
        assert_eq!(notifications[0].0, "notifications/initialized");
    }

    #[tokio::test]
    async fn test_list_tools() {
        let transport = MockTransport::with_responses(vec![
            mock_initialize_response(),
            serde_json::json!({
                "tools": [
                    {
                        "name": "test_tool",
                        "description": "A test tool",
                        "inputSchema": {
                            "type": "object",
                            "properties": {}
                        }
                    }
                ]
            }),
        ]);
        let mut client = McpClient::new(transport);

        client.initialize("test-client", "1.0.0").await.unwrap();
        let tools = client.list_tools().await.unwrap();

        assert_eq!(tools.tools.len(), 1);
        assert_eq!(tools.tools[0].name, "test_tool");
    }

    #[tokio::test]
    async fn test_call_tool() {
        let transport = MockTransport::with_responses(vec![
            mock_initialize_response(),
            serde_json::json!({
                "content": [
                    {
                        "type": "text",
                        "text": "Tool result"
                    }
                ]
            }),
        ]);
        let mut client = McpClient::new(transport);

        client.initialize("test-client", "1.0.0").await.unwrap();
        let result = client
            .call_tool("test_tool", serde_json::json!({"arg": "value"}))
            .await
            .unwrap();

        assert!(!result.content.is_empty());

        let requests = client.transport.get_requests();
        assert_eq!(requests.len(), 2);
        assert_eq!(requests[1].0, "tools/call");
        assert_eq!(requests[1].1["name"], "test_tool");
        assert_eq!(requests[1].1["arguments"]["arg"], "value");
    }

    #[tokio::test]
    async fn test_list_resources() {
        let transport = MockTransport::with_responses(vec![
            mock_initialize_response(),
            serde_json::json!({
                "resources": [
                    {
                        "uri": "file://test.txt",
                        "name": "Test File"
                    }
                ]
            }),
        ]);
        let mut client = McpClient::new(transport);

        client.initialize("test-client", "1.0.0").await.unwrap();
        let resources = client.list_resources().await.unwrap();

        assert_eq!(resources.resources.len(), 1);
        assert_eq!(resources.resources[0].uri, "file://test.txt");
    }

    #[tokio::test]
    async fn test_read_resource() {
        let transport = MockTransport::with_responses(vec![
            mock_initialize_response(),
            serde_json::json!({
                "contents": [
                    {
                        "uri": "file://test.txt",
                        "text": "File contents"
                    }
                ]
            }),
        ]);
        let mut client = McpClient::new(transport);

        client.initialize("test-client", "1.0.0").await.unwrap();
        let result = client.read_resource("file://test.txt").await.unwrap();

        assert_eq!(result.contents.len(), 1);
        assert_eq!(result.contents[0].text.as_deref(), Some("File contents"));
    }

    #[tokio::test]
    async fn test_list_prompts() {
        let transport = MockTransport::with_responses(vec![
            mock_initialize_response(),
            serde_json::json!({
                "prompts": [
                    {
                        "name": "test_prompt",
                        "description": "A test prompt"
                    }
                ]
            }),
        ]);
        let mut client = McpClient::new(transport);

        client.initialize("test-client", "1.0.0").await.unwrap();
        let prompts = client.list_prompts().await.unwrap();

        assert_eq!(prompts.prompts.len(), 1);
        assert_eq!(prompts.prompts[0].name, "test_prompt");
    }

    #[tokio::test]
    async fn test_get_prompt() {
        let transport = MockTransport::with_responses(vec![
            mock_initialize_response(),
            serde_json::json!({
                "messages": [
                    {
                        "role": "user",
                        "content": {
                            "type": "text",
                            "text": "Prompt message"
                        }
                    }
                ]
            }),
        ]);
        let mut client = McpClient::new(transport);

        client.initialize("test-client", "1.0.0").await.unwrap();
        let result = client.get_prompt("test_prompt", None).await.unwrap();

        assert_eq!(result.messages.len(), 1);
    }

    #[tokio::test]
    async fn test_ping() {
        let transport =
            MockTransport::with_responses(vec![mock_initialize_response(), serde_json::json!({})]);
        let mut client = McpClient::new(transport);

        client.initialize("test-client", "1.0.0").await.unwrap();
        let result = client.ping().await;

        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_ping_without_initialize() {
        // `ping` is not gated on the handshake, unlike the other request helpers.
        let transport = MockTransport::with_responses(vec![serde_json::json!({})]);
        let mut client = McpClient::new(transport);

        assert!(client.ping().await.is_ok());
        assert_eq!(client.transport.get_requests()[0].0, "ping");
    }

    #[tokio::test]
    async fn test_roots_management() {
        let transport = MockTransport::with_responses(vec![mock_initialize_response()]);
        let notifications = transport.notifications.clone();
        let mut client = McpClient::new(transport);

        assert!(client.roots().is_empty());

        // Before initialization, root changes must not notify the server.
        client.add_root(Root::new("file:///project")).await.unwrap();
        assert_eq!(client.roots().len(), 1);
        assert!(notifications.lock().unwrap().is_empty());

        // initialize itself sends `notifications/initialized`.
        client.initialize("test-client", "1.0.0").await.unwrap();
        assert_eq!(notifications.lock().unwrap().len(), 1);

        client.add_root(Root::new("file:///other")).await.unwrap();
        assert_eq!(client.roots().len(), 2);
        assert_eq!(notifications.lock().unwrap().len(), 2);

        let removed = client.remove_root("file:///project").await.unwrap();
        assert!(removed);
        assert_eq!(client.roots().len(), 1);
        assert_eq!(notifications.lock().unwrap().len(), 3);

        // Removing an unknown URI changes nothing and sends no notification.
        let not_removed = client.remove_root("file:///nonexistent").await.unwrap();
        assert!(!not_removed);
        assert_eq!(client.roots().len(), 1);
        assert_eq!(notifications.lock().unwrap().len(), 3);
    }

    #[tokio::test]
    async fn test_with_roots() {
        let roots = vec![Root::new("file:///test")];
        let transport = MockTransport::with_responses(vec![mock_initialize_response()]);
        let client = McpClient::new(transport).with_roots(roots);

        assert_eq!(client.roots().len(), 1);
        assert!(client.capabilities.roots.is_some());
    }

    #[tokio::test]
    async fn test_with_capabilities() {
        let capabilities = ClientCapabilities {
            sampling: Some(Default::default()),
            ..Default::default()
        };

        let transport = MockTransport::with_responses(vec![mock_initialize_response()]);
        let client = McpClient::new(transport).with_capabilities(capabilities);

        assert!(client.capabilities.sampling.is_some());
    }

    #[tokio::test]
    async fn test_list_roots() {
        let roots = vec![
            Root::new("file:///project1"),
            Root::with_name("file:///project2", "Project 2"),
        ];
        let transport = MockTransport::new();
        let client = McpClient::new(transport).with_roots(roots);

        let result = client.list_roots();
        assert_eq!(result.roots.len(), 2);
        assert_eq!(result.roots[1].name, Some("Project 2".to_string()));
    }
}