1use std::process::Stdio;
35use std::sync::atomic::{AtomicI64, Ordering};
36
37use async_trait::async_trait;
38use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
39use tokio::process::{Child, Command};
40
41use crate::error::{Error, Result};
42use crate::protocol::{
43 CallToolParams, CallToolResult, ClientCapabilities, CompleteParams, CompleteResult,
44 CompletionArgument, CompletionReference, GetPromptParams, GetPromptResult, Implementation,
45 InitializeParams, InitializeResult, JsonRpcRequest, JsonRpcResponse, ListPromptsParams,
46 ListPromptsResult, ListResourcesParams, ListResourcesResult, ListRootsResult, ListToolsParams,
47 ListToolsResult, ReadResourceParams, ReadResourceResult, Root, RootsCapability, notifications,
48};
49
50#[async_trait]
52pub trait ClientTransport: Send {
53 async fn request(
55 &mut self,
56 method: &str,
57 params: serde_json::Value,
58 ) -> Result<serde_json::Value>;
59
60 async fn notify(&mut self, method: &str, params: serde_json::Value) -> Result<()>;
62
63 fn is_connected(&self) -> bool;
65
66 async fn close(self: Box<Self>) -> Result<()>;
68}
69
70pub struct McpClient<T: ClientTransport> {
72 transport: T,
73 initialized: bool,
74 server_info: Option<InitializeResult>,
75 capabilities: ClientCapabilities,
77 roots: Vec<Root>,
79}
80
81impl<T: ClientTransport> McpClient<T> {
82 pub fn new(transport: T) -> Self {
84 Self {
85 transport,
86 initialized: false,
87 server_info: None,
88 capabilities: ClientCapabilities::default(),
89 roots: Vec::new(),
90 }
91 }
92
93 pub fn with_roots(transport: T, roots: Vec<Root>) -> Self {
97 Self {
98 transport,
99 initialized: false,
100 server_info: None,
101 capabilities: ClientCapabilities {
102 roots: Some(RootsCapability { list_changed: true }),
103 ..Default::default()
104 },
105 roots,
106 }
107 }
108
109 pub fn with_capabilities(transport: T, capabilities: ClientCapabilities) -> Self {
111 Self {
112 transport,
113 initialized: false,
114 server_info: None,
115 capabilities,
116 roots: Vec::new(),
117 }
118 }
119
120 pub fn server_info(&self) -> Option<&InitializeResult> {
122 self.server_info.as_ref()
123 }
124
125 pub fn is_initialized(&self) -> bool {
127 self.initialized
128 }
129
130 pub fn roots(&self) -> &[Root] {
132 &self.roots
133 }
134
135 pub async fn set_roots(&mut self, roots: Vec<Root>) -> Result<()> {
139 self.roots = roots;
140 if self.initialized {
141 self.notify_roots_changed().await?;
142 }
143 Ok(())
144 }
145
146 pub async fn add_root(&mut self, root: Root) -> Result<()> {
148 self.roots.push(root);
149 if self.initialized {
150 self.notify_roots_changed().await?;
151 }
152 Ok(())
153 }
154
155 pub async fn remove_root(&mut self, uri: &str) -> Result<bool> {
157 let initial_len = self.roots.len();
158 self.roots.retain(|r| r.uri != uri);
159 let removed = self.roots.len() < initial_len;
160 if removed && self.initialized {
161 self.notify_roots_changed().await?;
162 }
163 Ok(removed)
164 }
165
166 async fn notify_roots_changed(&mut self) -> Result<()> {
168 self.transport
169 .notify(notifications::ROOTS_LIST_CHANGED, serde_json::json!({}))
170 .await
171 }
172
173 pub fn list_roots(&self) -> ListRootsResult {
177 ListRootsResult {
178 roots: self.roots.clone(),
179 }
180 }
181
182 pub async fn initialize(
184 &mut self,
185 client_name: &str,
186 client_version: &str,
187 ) -> Result<&InitializeResult> {
188 let params = InitializeParams {
189 protocol_version: crate::protocol::LATEST_PROTOCOL_VERSION.to_string(),
190 capabilities: self.capabilities.clone(),
191 client_info: Implementation {
192 name: client_name.to_string(),
193 version: client_version.to_string(),
194 },
195 };
196
197 let result: InitializeResult = self.request("initialize", ¶ms).await?;
198 self.server_info = Some(result);
199
200 self.transport
202 .notify("notifications/initialized", serde_json::json!({}))
203 .await?;
204
205 self.initialized = true;
206
207 Ok(self.server_info.as_ref().unwrap())
208 }
209
210 pub async fn list_tools(&mut self) -> Result<ListToolsResult> {
212 self.ensure_initialized()?;
213 self.request("tools/list", &ListToolsParams { cursor: None })
214 .await
215 }
216
217 pub async fn call_tool(
219 &mut self,
220 name: &str,
221 arguments: serde_json::Value,
222 ) -> Result<CallToolResult> {
223 self.ensure_initialized()?;
224 let params = CallToolParams {
225 name: name.to_string(),
226 arguments,
227 meta: None,
228 };
229 self.request("tools/call", ¶ms).await
230 }
231
232 pub async fn list_resources(&mut self) -> Result<ListResourcesResult> {
234 self.ensure_initialized()?;
235 self.request("resources/list", &ListResourcesParams { cursor: None })
236 .await
237 }
238
239 pub async fn read_resource(&mut self, uri: &str) -> Result<ReadResourceResult> {
241 self.ensure_initialized()?;
242 let params = ReadResourceParams {
243 uri: uri.to_string(),
244 };
245 self.request("resources/read", ¶ms).await
246 }
247
248 pub async fn list_prompts(&mut self) -> Result<ListPromptsResult> {
250 self.ensure_initialized()?;
251 self.request("prompts/list", &ListPromptsParams { cursor: None })
252 .await
253 }
254
255 pub async fn get_prompt(
257 &mut self,
258 name: &str,
259 arguments: Option<std::collections::HashMap<String, String>>,
260 ) -> Result<GetPromptResult> {
261 self.ensure_initialized()?;
262 let params = GetPromptParams {
263 name: name.to_string(),
264 arguments: arguments.unwrap_or_default(),
265 };
266 self.request("prompts/get", ¶ms).await
267 }
268
269 pub async fn ping(&mut self) -> Result<()> {
271 let _: serde_json::Value = self.request("ping", &serde_json::json!({})).await?;
272 Ok(())
273 }
274
275 pub async fn complete(
279 &mut self,
280 reference: CompletionReference,
281 argument_name: &str,
282 argument_value: &str,
283 ) -> Result<CompleteResult> {
284 self.ensure_initialized()?;
285 let params = CompleteParams {
286 reference,
287 argument: CompletionArgument::new(argument_name, argument_value),
288 };
289 self.request("completion/complete", ¶ms).await
290 }
291
292 pub async fn complete_prompt_arg(
294 &mut self,
295 prompt_name: &str,
296 argument_name: &str,
297 argument_value: &str,
298 ) -> Result<CompleteResult> {
299 self.complete(
300 CompletionReference::prompt(prompt_name),
301 argument_name,
302 argument_value,
303 )
304 .await
305 }
306
307 pub async fn complete_resource_uri(
309 &mut self,
310 resource_uri: &str,
311 argument_name: &str,
312 argument_value: &str,
313 ) -> Result<CompleteResult> {
314 self.complete(
315 CompletionReference::resource(resource_uri),
316 argument_name,
317 argument_value,
318 )
319 .await
320 }
321
322 pub async fn request<P: serde::Serialize, R: serde::de::DeserializeOwned>(
324 &mut self,
325 method: &str,
326 params: &P,
327 ) -> Result<R> {
328 let params_value = serde_json::to_value(params)
329 .map_err(|e| Error::Transport(format!("Failed to serialize params: {}", e)))?;
330
331 let result = self.transport.request(method, params_value).await?;
332
333 serde_json::from_value(result)
334 .map_err(|e| Error::Transport(format!("Failed to deserialize response: {}", e)))
335 }
336
337 pub async fn notify<P: serde::Serialize>(&mut self, method: &str, params: &P) -> Result<()> {
339 let params_value = serde_json::to_value(params)
340 .map_err(|e| Error::Transport(format!("Failed to serialize params: {}", e)))?;
341
342 self.transport.notify(method, params_value).await
343 }
344
345 fn ensure_initialized(&self) -> Result<()> {
346 if !self.initialized {
347 return Err(Error::Transport("Client not initialized".to_string()));
348 }
349 Ok(())
350 }
351}
352
353pub struct StdioClientTransport {
359 child: Option<Child>,
360 stdin: tokio::process::ChildStdin,
361 stdout: BufReader<tokio::process::ChildStdout>,
362 request_id: AtomicI64,
363}
364
365impl StdioClientTransport {
366 pub async fn spawn(program: &str, args: &[&str]) -> Result<Self> {
368 let mut cmd = Command::new(program);
369 cmd.args(args)
370 .stdin(Stdio::piped())
371 .stdout(Stdio::piped())
372 .stderr(Stdio::inherit());
373
374 let mut child = cmd
375 .spawn()
376 .map_err(|e| Error::Transport(format!("Failed to spawn {}: {}", program, e)))?;
377
378 let stdin = child
379 .stdin
380 .take()
381 .ok_or_else(|| Error::Transport("Failed to get child stdin".to_string()))?;
382 let stdout = child
383 .stdout
384 .take()
385 .ok_or_else(|| Error::Transport("Failed to get child stdout".to_string()))?;
386
387 tracing::info!(program = %program, "Spawned MCP server process");
388
389 Ok(Self {
390 child: Some(child),
391 stdin,
392 stdout: BufReader::new(stdout),
393 request_id: AtomicI64::new(1),
394 })
395 }
396
397 pub fn from_child(mut child: Child) -> Result<Self> {
399 let stdin = child
400 .stdin
401 .take()
402 .ok_or_else(|| Error::Transport("Failed to get child stdin".to_string()))?;
403 let stdout = child
404 .stdout
405 .take()
406 .ok_or_else(|| Error::Transport("Failed to get child stdout".to_string()))?;
407
408 Ok(Self {
409 child: Some(child),
410 stdin,
411 stdout: BufReader::new(stdout),
412 request_id: AtomicI64::new(1),
413 })
414 }
415
416 async fn send_line(&mut self, line: &str) -> Result<()> {
417 self.stdin
418 .write_all(line.as_bytes())
419 .await
420 .map_err(|e| Error::Transport(format!("Failed to write: {}", e)))?;
421 self.stdin
422 .write_all(b"\n")
423 .await
424 .map_err(|e| Error::Transport(format!("Failed to write newline: {}", e)))?;
425 self.stdin
426 .flush()
427 .await
428 .map_err(|e| Error::Transport(format!("Failed to flush: {}", e)))?;
429 Ok(())
430 }
431
432 async fn read_line(&mut self) -> Result<String> {
433 let mut line = String::new();
434 self.stdout
435 .read_line(&mut line)
436 .await
437 .map_err(|e| Error::Transport(format!("Failed to read: {}", e)))?;
438
439 if line.is_empty() {
440 return Err(Error::Transport("Connection closed".to_string()));
441 }
442
443 Ok(line)
444 }
445}
446
447#[async_trait]
448impl ClientTransport for StdioClientTransport {
449 async fn request(
450 &mut self,
451 method: &str,
452 params: serde_json::Value,
453 ) -> Result<serde_json::Value> {
454 let id = self.request_id.fetch_add(1, Ordering::Relaxed);
455 let request = JsonRpcRequest::new(id, method).with_params(params);
456
457 let request_json = serde_json::to_string(&request)
458 .map_err(|e| Error::Transport(format!("Failed to serialize: {}", e)))?;
459
460 tracing::debug!(method = %method, id = %id, "Sending request");
461 self.send_line(&request_json).await?;
462
463 let response_line = self.read_line().await?;
464 tracing::debug!(response = %response_line.trim(), "Received response");
465
466 let response: JsonRpcResponse = serde_json::from_str(response_line.trim())
467 .map_err(|e| Error::Transport(format!("Failed to parse response: {}", e)))?;
468
469 match response {
470 JsonRpcResponse::Result(r) => Ok(r.result),
471 JsonRpcResponse::Error(e) => Err(Error::JsonRpc(e.error)),
472 }
473 }
474
475 async fn notify(&mut self, method: &str, params: serde_json::Value) -> Result<()> {
476 let notification = serde_json::json!({
477 "jsonrpc": "2.0",
478 "method": method,
479 "params": params
480 });
481
482 let json = serde_json::to_string(¬ification)
483 .map_err(|e| Error::Transport(format!("Failed to serialize: {}", e)))?;
484
485 tracing::debug!(method = %method, "Sending notification");
486 self.send_line(&json).await
487 }
488
489 fn is_connected(&self) -> bool {
490 self.child.is_some()
492 }
493
494 async fn close(mut self: Box<Self>) -> Result<()> {
495 drop(self.stdin);
497
498 if let Some(mut child) = self.child.take() {
499 let result =
501 tokio::time::timeout(std::time::Duration::from_secs(5), child.wait()).await;
502
503 match result {
504 Ok(Ok(status)) => {
505 tracing::info!(status = ?status, "Child process exited");
506 }
507 Ok(Err(e)) => {
508 tracing::error!(error = %e, "Error waiting for child");
509 }
510 Err(_) => {
511 tracing::warn!("Timeout waiting for child, killing");
512 let _ = child.kill().await;
513 }
514 }
515 }
516
517 Ok(())
518 }
519}
520
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::VecDeque;
    use std::sync::{Arc, Mutex};

    /// Shared log of `(method, params)` pairs recorded by the mock.
    type CallLog = Arc<Mutex<Vec<(String, serde_json::Value)>>>;

    /// In-memory transport that replays canned responses in FIFO order and
    /// records every request and notification it receives. Tests clone the
    /// `Arc` handles before moving the mock into a client, then inspect them.
    struct MockTransport {
        responses: Arc<Mutex<VecDeque<serde_json::Value>>>,
        requests: CallLog,
        notifications: CallLog,
        connected: bool,
    }

    impl MockTransport {
        fn new() -> Self {
            Self::with_responses(Vec::new())
        }

        fn with_responses(responses: Vec<serde_json::Value>) -> Self {
            Self {
                responses: Arc::new(Mutex::new(responses.into())),
                requests: Arc::new(Mutex::new(Vec::new())),
                notifications: Arc::new(Mutex::new(Vec::new())),
                connected: true,
            }
        }
    }

    #[async_trait]
    impl ClientTransport for MockTransport {
        async fn request(
            &mut self,
            method: &str,
            params: serde_json::Value,
        ) -> Result<serde_json::Value> {
            self.requests
                .lock()
                .unwrap()
                .push((method.to_string(), params));
            self.responses
                .lock()
                .unwrap()
                .pop_front()
                .ok_or_else(|| Error::Transport("No more mock responses".to_string()))
        }

        async fn notify(&mut self, method: &str, params: serde_json::Value) -> Result<()> {
            self.notifications
                .lock()
                .unwrap()
                .push((method.to_string(), params));
            Ok(())
        }

        fn is_connected(&self) -> bool {
            self.connected
        }

        async fn close(self: Box<Self>) -> Result<()> {
            Ok(())
        }
    }

    fn mock_initialize_response() -> serde_json::Value {
        serde_json::json!({
            "protocolVersion": "2025-03-26",
            "serverInfo": {
                "name": "test-server",
                "version": "1.0.0"
            },
            "capabilities": {
                "tools": {}
            }
        })
    }

    #[tokio::test]
    async fn test_client_not_initialized() {
        let mut client = McpClient::new(MockTransport::new());

        let result = client.list_tools().await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("not initialized"));
    }

    #[tokio::test]
    async fn test_client_initialize() {
        let transport = MockTransport::with_responses(vec![mock_initialize_response()]);
        let requests = transport.requests.clone();
        let sent_notifications = transport.notifications.clone();
        let mut client = McpClient::new(transport);

        assert!(!client.is_initialized());

        let result = client.initialize("test-client", "1.0.0").await;
        assert!(result.is_ok());
        assert!(client.is_initialized());

        let server_info = client.server_info().unwrap();
        assert_eq!(server_info.server_info.name, "test-server");

        // The handshake is one `initialize` request plus the initialized notification.
        {
            let recorded = requests.lock().unwrap();
            assert_eq!(recorded.len(), 1);
            assert_eq!(recorded[0].0, "initialize");
        }
        {
            let sent = sent_notifications.lock().unwrap();
            assert_eq!(sent.len(), 1);
            assert_eq!(sent[0].0, "notifications/initialized");
        }
    }

    #[tokio::test]
    async fn test_list_tools() {
        let transport = MockTransport::with_responses(vec![
            mock_initialize_response(),
            serde_json::json!({
                "tools": [
                    {
                        "name": "test_tool",
                        "description": "A test tool",
                        "inputSchema": {
                            "type": "object",
                            "properties": {}
                        }
                    }
                ]
            }),
        ]);
        let mut client = McpClient::new(transport);

        client.initialize("test-client", "1.0.0").await.unwrap();
        let tools = client.list_tools().await.unwrap();

        assert_eq!(tools.tools.len(), 1);
        assert_eq!(tools.tools[0].name, "test_tool");
    }

    #[tokio::test]
    async fn test_call_tool() {
        let transport = MockTransport::with_responses(vec![
            mock_initialize_response(),
            serde_json::json!({
                "content": [
                    {
                        "type": "text",
                        "text": "Tool result"
                    }
                ]
            }),
        ]);
        let requests = transport.requests.clone();
        let mut client = McpClient::new(transport);

        client.initialize("test-client", "1.0.0").await.unwrap();
        let result = client
            .call_tool("test_tool", serde_json::json!({"arg": "value"}))
            .await
            .unwrap();

        assert!(!result.content.is_empty());

        let recorded = requests.lock().unwrap();
        assert_eq!(recorded.len(), 2);
        assert_eq!(recorded[1].0, "tools/call");
        assert_eq!(recorded[1].1["name"], "test_tool");
        assert_eq!(recorded[1].1["arguments"], serde_json::json!({"arg": "value"}));
    }

    #[tokio::test]
    async fn test_list_resources() {
        let transport = MockTransport::with_responses(vec![
            mock_initialize_response(),
            serde_json::json!({
                "resources": [
                    {
                        "uri": "file://test.txt",
                        "name": "Test File"
                    }
                ]
            }),
        ]);
        let mut client = McpClient::new(transport);

        client.initialize("test-client", "1.0.0").await.unwrap();
        let resources = client.list_resources().await.unwrap();

        assert_eq!(resources.resources.len(), 1);
        assert_eq!(resources.resources[0].uri, "file://test.txt");
    }

    #[tokio::test]
    async fn test_read_resource() {
        let transport = MockTransport::with_responses(vec![
            mock_initialize_response(),
            serde_json::json!({
                "contents": [
                    {
                        "uri": "file://test.txt",
                        "text": "File contents"
                    }
                ]
            }),
        ]);
        let mut client = McpClient::new(transport);

        client.initialize("test-client", "1.0.0").await.unwrap();
        let result = client.read_resource("file://test.txt").await.unwrap();

        assert_eq!(result.contents.len(), 1);
        assert_eq!(result.contents[0].text.as_deref(), Some("File contents"));
    }

    #[tokio::test]
    async fn test_list_prompts() {
        let transport = MockTransport::with_responses(vec![
            mock_initialize_response(),
            serde_json::json!({
                "prompts": [
                    {
                        "name": "test_prompt",
                        "description": "A test prompt"
                    }
                ]
            }),
        ]);
        let mut client = McpClient::new(transport);

        client.initialize("test-client", "1.0.0").await.unwrap();
        let prompts = client.list_prompts().await.unwrap();

        assert_eq!(prompts.prompts.len(), 1);
        assert_eq!(prompts.prompts[0].name, "test_prompt");
    }

    #[tokio::test]
    async fn test_get_prompt() {
        let transport = MockTransport::with_responses(vec![
            mock_initialize_response(),
            serde_json::json!({
                "messages": [
                    {
                        "role": "user",
                        "content": {
                            "type": "text",
                            "text": "Prompt message"
                        }
                    }
                ]
            }),
        ]);
        let mut client = McpClient::new(transport);

        client.initialize("test-client", "1.0.0").await.unwrap();
        let result = client.get_prompt("test_prompt", None).await.unwrap();

        assert_eq!(result.messages.len(), 1);
    }

    #[tokio::test]
    async fn test_ping() {
        let transport =
            MockTransport::with_responses(vec![mock_initialize_response(), serde_json::json!({})]);
        let mut client = McpClient::new(transport);

        client.initialize("test-client", "1.0.0").await.unwrap();
        let result = client.ping().await;

        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_ping_without_initialize() {
        // `ping` is deliberately usable before the handshake.
        let transport = MockTransport::with_responses(vec![serde_json::json!({})]);
        let requests = transport.requests.clone();
        let mut client = McpClient::new(transport);

        assert!(client.ping().await.is_ok());
        assert_eq!(requests.lock().unwrap()[0].0, "ping");
    }

    #[tokio::test]
    async fn test_roots_management() {
        let transport = MockTransport::with_responses(vec![mock_initialize_response()]);
        let sent_notifications = transport.notifications.clone();
        let mut client = McpClient::new(transport);

        assert!(client.roots().is_empty());

        // Before the handshake, root edits stay local.
        client.add_root(Root::new("file:///project")).await.unwrap();
        assert_eq!(client.roots().len(), 1);
        assert!(sent_notifications.lock().unwrap().is_empty());

        client.initialize("test-client", "1.0.0").await.unwrap();

        // After the handshake, each change is announced.
        client.add_root(Root::new("file:///other")).await.unwrap();
        assert_eq!(client.roots().len(), 2);
        {
            let sent = sent_notifications.lock().unwrap();
            assert_eq!(sent.len(), 2);
            assert_eq!(sent[0].0, "notifications/initialized");
            assert_eq!(sent[1].0, notifications::ROOTS_LIST_CHANGED);
        }

        let removed = client.remove_root("file:///project").await.unwrap();
        assert!(removed);
        assert_eq!(client.roots().len(), 1);
        assert_eq!(sent_notifications.lock().unwrap().len(), 3);

        // Removing an unknown root is a no-op and announces nothing.
        let not_removed = client.remove_root("file:///nonexistent").await.unwrap();
        assert!(!not_removed);
        assert_eq!(sent_notifications.lock().unwrap().len(), 3);
    }

    #[test]
    fn test_with_roots() {
        let roots = vec![Root::new("file:///test")];
        let transport = MockTransport::with_responses(vec![mock_initialize_response()]);
        let client = McpClient::with_roots(transport, roots);

        assert_eq!(client.roots().len(), 1);
        assert!(client.capabilities.roots.is_some());
    }

    #[test]
    fn test_with_capabilities() {
        let capabilities = ClientCapabilities {
            sampling: Some(Default::default()),
            ..Default::default()
        };

        let transport = MockTransport::with_responses(vec![mock_initialize_response()]);
        let client = McpClient::with_capabilities(transport, capabilities);

        assert!(client.capabilities.sampling.is_some());
    }

    #[test]
    fn test_list_roots() {
        let roots = vec![
            Root::new("file:///project1"),
            Root::with_name("file:///project2", "Project 2"),
        ];
        let transport = MockTransport::new();
        let client = McpClient::with_roots(transport, roots);

        let result = client.list_roots();
        assert_eq!(result.roots.len(), 2);
        assert_eq!(result.roots[1].name, Some("Project 2".to_string()));
    }
}