Skip to main content

tower_mcp/
client.rs

1//! MCP Client implementation
2//!
3//! Provides client functionality for connecting to MCP servers.
4//!
5//! # Example
6//!
7//! ```rust,no_run
8//! use tower_mcp::BoxError;
9//! use tower_mcp::client::{McpClient, StdioClientTransport};
10//!
11//! #[tokio::main]
12//! async fn main() -> Result<(), BoxError> {
13//!     // Connect to an MCP server via stdio
14//!     let transport = StdioClientTransport::spawn("my-mcp-server", &["--flag"]).await?;
15//!     let mut client = McpClient::new(transport);
16//!
17//!     // Initialize the connection
18//!     let server_info = client.initialize("my-client", "1.0.0").await?;
19//!     println!("Connected to: {}", server_info.server_info.name);
20//!
21//!     // List available tools
22//!     let tools = client.list_tools().await?;
23//!     for tool in &tools.tools {
24//!         println!("Tool: {}", tool.name);
25//!     }
26//!
27//!     // Call a tool
28//!     let result = client.call_tool("my-tool", serde_json::json!({"arg": "value"})).await?;
29//!     println!("Result: {:?}", result);
30//!
31//!     Ok(())
32//! }
33//! ```
34
35use std::process::Stdio;
36use std::sync::atomic::{AtomicI64, Ordering};
37
38use async_trait::async_trait;
39use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
40use tokio::process::{Child, Command};
41
42use crate::error::{Error, Result};
43use crate::protocol::{
44    CallToolParams, CallToolResult, ClientCapabilities, CompleteParams, CompleteResult,
45    CompletionArgument, CompletionReference, GetPromptParams, GetPromptResult, Implementation,
46    InitializeParams, InitializeResult, JsonRpcRequest, JsonRpcResponse, ListPromptsParams,
47    ListPromptsResult, ListResourcesParams, ListResourcesResult, ListRootsResult, ListToolsParams,
48    ListToolsResult, ReadResourceParams, ReadResourceResult, Root, RootsCapability, notifications,
49};
50
51/// Trait for MCP client transports
52#[async_trait]
53pub trait ClientTransport: Send {
54    /// Send a request and receive a response
55    async fn request(
56        &mut self,
57        method: &str,
58        params: serde_json::Value,
59    ) -> Result<serde_json::Value>;
60
61    /// Send a notification (no response expected)
62    async fn notify(&mut self, method: &str, params: serde_json::Value) -> Result<()>;
63
64    /// Check if the transport is still connected
65    fn is_connected(&self) -> bool;
66
67    /// Close the transport
68    async fn close(self: Box<Self>) -> Result<()>;
69}
70
71/// MCP Client for connecting to MCP servers
72pub struct McpClient<T: ClientTransport> {
73    transport: T,
74    initialized: bool,
75    server_info: Option<InitializeResult>,
76    /// Client capabilities to declare during initialization
77    capabilities: ClientCapabilities,
78    /// Roots available to the server
79    roots: Vec<Root>,
80}
81
82impl<T: ClientTransport> McpClient<T> {
83    /// Create a new MCP client with the given transport
84    pub fn new(transport: T) -> Self {
85        Self {
86            transport,
87            initialized: false,
88            server_info: None,
89            capabilities: ClientCapabilities::default(),
90            roots: Vec::new(),
91        }
92    }
93
94    /// Create a new MCP client with roots capability
95    ///
96    /// The client will declare roots support during initialization.
97    pub fn with_roots(transport: T, roots: Vec<Root>) -> Self {
98        Self {
99            transport,
100            initialized: false,
101            server_info: None,
102            capabilities: ClientCapabilities {
103                roots: Some(RootsCapability { list_changed: true }),
104                ..Default::default()
105            },
106            roots,
107        }
108    }
109
110    /// Create a new MCP client with custom capabilities
111    pub fn with_capabilities(transport: T, capabilities: ClientCapabilities) -> Self {
112        Self {
113            transport,
114            initialized: false,
115            server_info: None,
116            capabilities,
117            roots: Vec::new(),
118        }
119    }
120
121    /// Get the server info (available after initialization)
122    pub fn server_info(&self) -> Option<&InitializeResult> {
123        self.server_info.as_ref()
124    }
125
126    /// Check if the client is initialized
127    pub fn is_initialized(&self) -> bool {
128        self.initialized
129    }
130
131    /// Get the current roots
132    pub fn roots(&self) -> &[Root] {
133        &self.roots
134    }
135
136    /// Set roots and notify the server if initialized
137    ///
138    /// If the client is already initialized, sends a roots list changed notification.
139    pub async fn set_roots(&mut self, roots: Vec<Root>) -> Result<()> {
140        self.roots = roots;
141        if self.initialized {
142            self.notify_roots_changed().await?;
143        }
144        Ok(())
145    }
146
147    /// Add a root and notify the server if initialized
148    pub async fn add_root(&mut self, root: Root) -> Result<()> {
149        self.roots.push(root);
150        if self.initialized {
151            self.notify_roots_changed().await?;
152        }
153        Ok(())
154    }
155
156    /// Remove a root by URI and notify the server if initialized
157    pub async fn remove_root(&mut self, uri: &str) -> Result<bool> {
158        let initial_len = self.roots.len();
159        self.roots.retain(|r| r.uri != uri);
160        let removed = self.roots.len() < initial_len;
161        if removed && self.initialized {
162            self.notify_roots_changed().await?;
163        }
164        Ok(removed)
165    }
166
167    /// Send roots list changed notification to the server
168    async fn notify_roots_changed(&mut self) -> Result<()> {
169        self.transport
170            .notify(notifications::ROOTS_LIST_CHANGED, serde_json::json!({}))
171            .await
172    }
173
174    /// Get the roots list result (for responding to server's roots/list request)
175    ///
176    /// Returns a result suitable for responding to a roots/list request from the server.
177    pub fn list_roots(&self) -> ListRootsResult {
178        ListRootsResult {
179            roots: self.roots.clone(),
180        }
181    }
182
183    /// Initialize the MCP connection
184    pub async fn initialize(
185        &mut self,
186        client_name: &str,
187        client_version: &str,
188    ) -> Result<&InitializeResult> {
189        let params = InitializeParams {
190            protocol_version: crate::protocol::LATEST_PROTOCOL_VERSION.to_string(),
191            capabilities: self.capabilities.clone(),
192            client_info: Implementation {
193                name: client_name.to_string(),
194                version: client_version.to_string(),
195                ..Default::default()
196            },
197        };
198
199        let result: InitializeResult = self.request("initialize", &params).await?;
200        self.server_info = Some(result);
201
202        // Send initialized notification
203        self.transport
204            .notify("notifications/initialized", serde_json::json!({}))
205            .await?;
206
207        self.initialized = true;
208
209        Ok(self.server_info.as_ref().unwrap())
210    }
211
212    /// List available tools
213    pub async fn list_tools(&mut self) -> Result<ListToolsResult> {
214        self.ensure_initialized()?;
215        self.request("tools/list", &ListToolsParams { cursor: None })
216            .await
217    }
218
219    /// Call a tool
220    pub async fn call_tool(
221        &mut self,
222        name: &str,
223        arguments: serde_json::Value,
224    ) -> Result<CallToolResult> {
225        self.ensure_initialized()?;
226        let params = CallToolParams {
227            name: name.to_string(),
228            arguments,
229            meta: None,
230        };
231        self.request("tools/call", &params).await
232    }
233
234    /// List available resources
235    pub async fn list_resources(&mut self) -> Result<ListResourcesResult> {
236        self.ensure_initialized()?;
237        self.request("resources/list", &ListResourcesParams { cursor: None })
238            .await
239    }
240
241    /// Read a resource
242    pub async fn read_resource(&mut self, uri: &str) -> Result<ReadResourceResult> {
243        self.ensure_initialized()?;
244        let params = ReadResourceParams {
245            uri: uri.to_string(),
246        };
247        self.request("resources/read", &params).await
248    }
249
250    /// List available prompts
251    pub async fn list_prompts(&mut self) -> Result<ListPromptsResult> {
252        self.ensure_initialized()?;
253        self.request("prompts/list", &ListPromptsParams { cursor: None })
254            .await
255    }
256
257    /// Get a prompt
258    pub async fn get_prompt(
259        &mut self,
260        name: &str,
261        arguments: Option<std::collections::HashMap<String, String>>,
262    ) -> Result<GetPromptResult> {
263        self.ensure_initialized()?;
264        let params = GetPromptParams {
265            name: name.to_string(),
266            arguments: arguments.unwrap_or_default(),
267        };
268        self.request("prompts/get", &params).await
269    }
270
271    /// Ping the server
272    pub async fn ping(&mut self) -> Result<()> {
273        let _: serde_json::Value = self.request("ping", &serde_json::json!({})).await?;
274        Ok(())
275    }
276
277    /// Request completion suggestions from the server
278    ///
279    /// This is used to get autocomplete suggestions for prompt arguments or resource URIs.
280    pub async fn complete(
281        &mut self,
282        reference: CompletionReference,
283        argument_name: &str,
284        argument_value: &str,
285    ) -> Result<CompleteResult> {
286        self.ensure_initialized()?;
287        let params = CompleteParams {
288            reference,
289            argument: CompletionArgument::new(argument_name, argument_value),
290        };
291        self.request("completion/complete", &params).await
292    }
293
294    /// Request completion for a prompt argument
295    pub async fn complete_prompt_arg(
296        &mut self,
297        prompt_name: &str,
298        argument_name: &str,
299        argument_value: &str,
300    ) -> Result<CompleteResult> {
301        self.complete(
302            CompletionReference::prompt(prompt_name),
303            argument_name,
304            argument_value,
305        )
306        .await
307    }
308
309    /// Request completion for a resource URI
310    pub async fn complete_resource_uri(
311        &mut self,
312        resource_uri: &str,
313        argument_name: &str,
314        argument_value: &str,
315    ) -> Result<CompleteResult> {
316        self.complete(
317            CompletionReference::resource(resource_uri),
318            argument_name,
319            argument_value,
320        )
321        .await
322    }
323
324    /// Send a raw request
325    pub async fn request<P: serde::Serialize, R: serde::de::DeserializeOwned>(
326        &mut self,
327        method: &str,
328        params: &P,
329    ) -> Result<R> {
330        let params_value = serde_json::to_value(params)
331            .map_err(|e| Error::Transport(format!("Failed to serialize params: {}", e)))?;
332
333        let result = self.transport.request(method, params_value).await?;
334
335        serde_json::from_value(result)
336            .map_err(|e| Error::Transport(format!("Failed to deserialize response: {}", e)))
337    }
338
339    /// Send a notification
340    pub async fn notify<P: serde::Serialize>(&mut self, method: &str, params: &P) -> Result<()> {
341        let params_value = serde_json::to_value(params)
342            .map_err(|e| Error::Transport(format!("Failed to serialize params: {}", e)))?;
343
344        self.transport.notify(method, params_value).await
345    }
346
347    fn ensure_initialized(&self) -> Result<()> {
348        if !self.initialized {
349            return Err(Error::Transport("Client not initialized".to_string()));
350        }
351        Ok(())
352    }
353}
354
355// ============================================================================
356// Stdio Client Transport
357// ============================================================================
358
359/// Client transport that communicates with a subprocess via stdio
360pub struct StdioClientTransport {
361    child: Option<Child>,
362    stdin: tokio::process::ChildStdin,
363    stdout: BufReader<tokio::process::ChildStdout>,
364    request_id: AtomicI64,
365}
366
367impl StdioClientTransport {
368    /// Spawn a new subprocess and connect to it
369    pub async fn spawn(program: &str, args: &[&str]) -> Result<Self> {
370        let mut cmd = Command::new(program);
371        cmd.args(args)
372            .stdin(Stdio::piped())
373            .stdout(Stdio::piped())
374            .stderr(Stdio::inherit());
375
376        let mut child = cmd
377            .spawn()
378            .map_err(|e| Error::Transport(format!("Failed to spawn {}: {}", program, e)))?;
379
380        let stdin = child
381            .stdin
382            .take()
383            .ok_or_else(|| Error::Transport("Failed to get child stdin".to_string()))?;
384        let stdout = child
385            .stdout
386            .take()
387            .ok_or_else(|| Error::Transport("Failed to get child stdout".to_string()))?;
388
389        tracing::info!(program = %program, "Spawned MCP server process");
390
391        Ok(Self {
392            child: Some(child),
393            stdin,
394            stdout: BufReader::new(stdout),
395            request_id: AtomicI64::new(1),
396        })
397    }
398
399    /// Create from an existing child process
400    pub fn from_child(mut child: Child) -> Result<Self> {
401        let stdin = child
402            .stdin
403            .take()
404            .ok_or_else(|| Error::Transport("Failed to get child stdin".to_string()))?;
405        let stdout = child
406            .stdout
407            .take()
408            .ok_or_else(|| Error::Transport("Failed to get child stdout".to_string()))?;
409
410        Ok(Self {
411            child: Some(child),
412            stdin,
413            stdout: BufReader::new(stdout),
414            request_id: AtomicI64::new(1),
415        })
416    }
417
418    async fn send_line(&mut self, line: &str) -> Result<()> {
419        self.stdin
420            .write_all(line.as_bytes())
421            .await
422            .map_err(|e| Error::Transport(format!("Failed to write: {}", e)))?;
423        self.stdin
424            .write_all(b"\n")
425            .await
426            .map_err(|e| Error::Transport(format!("Failed to write newline: {}", e)))?;
427        self.stdin
428            .flush()
429            .await
430            .map_err(|e| Error::Transport(format!("Failed to flush: {}", e)))?;
431        Ok(())
432    }
433
434    async fn read_line(&mut self) -> Result<String> {
435        let mut line = String::new();
436        self.stdout
437            .read_line(&mut line)
438            .await
439            .map_err(|e| Error::Transport(format!("Failed to read: {}", e)))?;
440
441        if line.is_empty() {
442            return Err(Error::Transport("Connection closed".to_string()));
443        }
444
445        Ok(line)
446    }
447}
448
449#[async_trait]
450impl ClientTransport for StdioClientTransport {
451    async fn request(
452        &mut self,
453        method: &str,
454        params: serde_json::Value,
455    ) -> Result<serde_json::Value> {
456        let id = self.request_id.fetch_add(1, Ordering::Relaxed);
457        let request = JsonRpcRequest::new(id, method).with_params(params);
458
459        let request_json = serde_json::to_string(&request)
460            .map_err(|e| Error::Transport(format!("Failed to serialize: {}", e)))?;
461
462        tracing::debug!(method = %method, id = %id, "Sending request");
463        self.send_line(&request_json).await?;
464
465        let response_line = self.read_line().await?;
466        tracing::debug!(response = %response_line.trim(), "Received response");
467
468        let response: JsonRpcResponse = serde_json::from_str(response_line.trim())
469            .map_err(|e| Error::Transport(format!("Failed to parse response: {}", e)))?;
470
471        match response {
472            JsonRpcResponse::Result(r) => Ok(r.result),
473            JsonRpcResponse::Error(e) => Err(Error::JsonRpc(e.error)),
474        }
475    }
476
477    async fn notify(&mut self, method: &str, params: serde_json::Value) -> Result<()> {
478        let notification = serde_json::json!({
479            "jsonrpc": "2.0",
480            "method": method,
481            "params": params
482        });
483
484        let json = serde_json::to_string(&notification)
485            .map_err(|e| Error::Transport(format!("Failed to serialize: {}", e)))?;
486
487        tracing::debug!(method = %method, "Sending notification");
488        self.send_line(&json).await
489    }
490
491    fn is_connected(&self) -> bool {
492        // Assume connected if we have a child process handle
493        self.child.is_some()
494    }
495
496    async fn close(mut self: Box<Self>) -> Result<()> {
497        // Close stdin to signal EOF
498        drop(self.stdin);
499
500        if let Some(mut child) = self.child.take() {
501            // Wait for process with timeout
502            let result =
503                tokio::time::timeout(std::time::Duration::from_secs(5), child.wait()).await;
504
505            match result {
506                Ok(Ok(status)) => {
507                    tracing::info!(status = ?status, "Child process exited");
508                }
509                Ok(Err(e)) => {
510                    tracing::error!(error = %e, "Error waiting for child");
511                }
512                Err(_) => {
513                    tracing::warn!("Timeout waiting for child, killing");
514                    let _ = child.kill().await;
515                }
516            }
517        }
518
519        Ok(())
520    }
521}
522
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::VecDeque;
    use std::sync::{Arc, Mutex};

    /// Shared log of `(method, params)` pairs; cloneable so a test can keep a
    /// handle after the transport has been moved into the client.
    type CallLog = Arc<Mutex<Vec<(String, serde_json::Value)>>>;

    /// Mock transport that returns preconfigured responses in FIFO order and
    /// records every request and notification it is asked to send
    struct MockTransport {
        responses: Arc<Mutex<VecDeque<serde_json::Value>>>,
        requests: CallLog,
        notifications: CallLog,
        connected: bool,
    }

    impl MockTransport {
        /// Mock with no queued responses: every request fails
        fn new() -> Self {
            Self::with_responses(Vec::new())
        }

        /// Mock that answers requests with `responses`, in order
        fn with_responses(responses: Vec<serde_json::Value>) -> Self {
            Self {
                responses: Arc::new(Mutex::new(responses.into())),
                requests: CallLog::default(),
                notifications: CallLog::default(),
                connected: true,
            }
        }
    }

    #[async_trait]
    impl ClientTransport for MockTransport {
        async fn request(
            &mut self,
            method: &str,
            params: serde_json::Value,
        ) -> Result<serde_json::Value> {
            self.requests
                .lock()
                .unwrap()
                .push((method.to_string(), params));
            self.responses
                .lock()
                .unwrap()
                .pop_front()
                .ok_or_else(|| Error::Transport("No more mock responses".to_string()))
        }

        async fn notify(&mut self, method: &str, params: serde_json::Value) -> Result<()> {
            self.notifications
                .lock()
                .unwrap()
                .push((method.to_string(), params));
            Ok(())
        }

        fn is_connected(&self) -> bool {
            self.connected
        }

        async fn close(self: Box<Self>) -> Result<()> {
            Ok(())
        }
    }

    fn mock_initialize_response() -> serde_json::Value {
        serde_json::json!({
            "protocolVersion": "2025-11-25",
            "serverInfo": {
                "name": "test-server",
                "version": "1.0.0"
            },
            "capabilities": {
                "tools": {}
            }
        })
    }

    #[tokio::test]
    async fn test_client_not_initialized() {
        let mut client = McpClient::new(MockTransport::new());

        // Should fail because not initialized
        let result = client.list_tools().await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("not initialized"));
    }

    #[tokio::test]
    async fn test_client_initialize() {
        let transport = MockTransport::with_responses(vec![mock_initialize_response()]);
        let requests = transport.requests.clone();
        let notifications = transport.notifications.clone();
        let mut client = McpClient::new(transport);

        assert!(!client.is_initialized());
        assert!(client.server_info().is_none());

        let result = client.initialize("test-client", "1.0.0").await;
        assert!(result.is_ok());
        assert!(client.is_initialized());

        let server_info = client.server_info().unwrap();
        assert_eq!(server_info.server_info.name, "test-server");

        // The handshake is one `initialize` request followed by the
        // `notifications/initialized` notification
        let sent_requests = requests.lock().unwrap();
        assert_eq!(sent_requests.len(), 1);
        assert_eq!(sent_requests[0].0, "initialize");
        let sent_notifications = notifications.lock().unwrap();
        assert_eq!(sent_notifications.len(), 1);
        assert_eq!(sent_notifications[0].0, "notifications/initialized");
    }

    #[tokio::test]
    async fn test_initialize_failure_leaves_client_uninitialized() {
        let transport = MockTransport::new();
        let notifications = transport.notifications.clone();
        let mut client = McpClient::new(transport);

        assert!(client.initialize("test-client", "1.0.0").await.is_err());
        assert!(!client.is_initialized());
        assert!(client.server_info().is_none());
        assert!(notifications.lock().unwrap().is_empty());
    }

    #[tokio::test]
    async fn test_list_tools() {
        let transport = MockTransport::with_responses(vec![
            mock_initialize_response(),
            serde_json::json!({
                "tools": [
                    {
                        "name": "test_tool",
                        "description": "A test tool",
                        "inputSchema": {
                            "type": "object",
                            "properties": {}
                        }
                    }
                ]
            }),
        ]);
        let mut client = McpClient::new(transport);

        client.initialize("test-client", "1.0.0").await.unwrap();
        let tools = client.list_tools().await.unwrap();

        assert_eq!(tools.tools.len(), 1);
        assert_eq!(tools.tools[0].name, "test_tool");
    }

    #[tokio::test]
    async fn test_malformed_response_is_an_error() {
        let transport = MockTransport::with_responses(vec![
            mock_initialize_response(),
            serde_json::json!({ "tools": "not-a-list" }),
        ]);
        let mut client = McpClient::new(transport);

        client.initialize("test-client", "1.0.0").await.unwrap();
        let err = client.list_tools().await.unwrap_err();
        assert!(err.to_string().contains("Failed to deserialize response"));
    }

    #[tokio::test]
    async fn test_call_tool() {
        let transport = MockTransport::with_responses(vec![
            mock_initialize_response(),
            serde_json::json!({
                "content": [
                    {
                        "type": "text",
                        "text": "Tool result"
                    }
                ]
            }),
        ]);
        let requests = transport.requests.clone();
        let mut client = McpClient::new(transport);

        client.initialize("test-client", "1.0.0").await.unwrap();
        let result = client
            .call_tool("test_tool", serde_json::json!({"arg": "value"}))
            .await
            .unwrap();

        assert!(!result.content.is_empty());

        // The tool name and arguments are forwarded in a `tools/call` request
        let sent_requests = requests.lock().unwrap();
        assert_eq!(sent_requests[1].0, "tools/call");
        assert_eq!(sent_requests[1].1["name"], "test_tool");
        assert_eq!(sent_requests[1].1["arguments"]["arg"], "value");
    }

    #[tokio::test]
    async fn test_list_resources() {
        let transport = MockTransport::with_responses(vec![
            mock_initialize_response(),
            serde_json::json!({
                "resources": [
                    {
                        "uri": "file://test.txt",
                        "name": "Test File"
                    }
                ]
            }),
        ]);
        let mut client = McpClient::new(transport);

        client.initialize("test-client", "1.0.0").await.unwrap();
        let resources = client.list_resources().await.unwrap();

        assert_eq!(resources.resources.len(), 1);
        assert_eq!(resources.resources[0].uri, "file://test.txt");
    }

    #[tokio::test]
    async fn test_read_resource() {
        let transport = MockTransport::with_responses(vec![
            mock_initialize_response(),
            serde_json::json!({
                "contents": [
                    {
                        "uri": "file://test.txt",
                        "text": "File contents"
                    }
                ]
            }),
        ]);
        let mut client = McpClient::new(transport);

        client.initialize("test-client", "1.0.0").await.unwrap();
        let result = client.read_resource("file://test.txt").await.unwrap();

        assert_eq!(result.contents.len(), 1);
        assert_eq!(result.contents[0].text.as_deref(), Some("File contents"));
    }

    #[tokio::test]
    async fn test_list_prompts() {
        let transport = MockTransport::with_responses(vec![
            mock_initialize_response(),
            serde_json::json!({
                "prompts": [
                    {
                        "name": "test_prompt",
                        "description": "A test prompt"
                    }
                ]
            }),
        ]);
        let mut client = McpClient::new(transport);

        client.initialize("test-client", "1.0.0").await.unwrap();
        let prompts = client.list_prompts().await.unwrap();

        assert_eq!(prompts.prompts.len(), 1);
        assert_eq!(prompts.prompts[0].name, "test_prompt");
    }

    #[tokio::test]
    async fn test_get_prompt() {
        let transport = MockTransport::with_responses(vec![
            mock_initialize_response(),
            serde_json::json!({
                "messages": [
                    {
                        "role": "user",
                        "content": {
                            "type": "text",
                            "text": "Prompt message"
                        }
                    }
                ]
            }),
        ]);
        let mut client = McpClient::new(transport);

        client.initialize("test-client", "1.0.0").await.unwrap();
        let result = client.get_prompt("test_prompt", None).await.unwrap();

        assert_eq!(result.messages.len(), 1);
    }

    #[tokio::test]
    async fn test_ping() {
        let transport =
            MockTransport::with_responses(vec![mock_initialize_response(), serde_json::json!({})]);
        let mut client = McpClient::new(transport);

        client.initialize("test-client", "1.0.0").await.unwrap();
        let result = client.ping().await;

        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_ping_does_not_require_initialization() {
        let transport = MockTransport::with_responses(vec![serde_json::json!({})]);
        let mut client = McpClient::new(transport);

        assert!(client.ping().await.is_ok());
        assert!(!client.is_initialized());
    }

    #[tokio::test]
    async fn test_roots_management() {
        let transport = MockTransport::with_responses(vec![mock_initialize_response()]);
        let notifications = transport.notifications.clone();
        let mut client = McpClient::new(transport);

        // Initially no roots
        assert!(client.roots().is_empty());

        // Add a root before initialization (no notification)
        client.add_root(Root::new("file:///project")).await.unwrap();
        assert_eq!(client.roots().len(), 1);
        assert!(notifications.lock().unwrap().is_empty());

        // Initialize
        client.initialize("test-client", "1.0.0").await.unwrap();

        // Add another root after initialization (should notify)
        client.add_root(Root::new("file:///other")).await.unwrap();
        assert_eq!(client.roots().len(), 2);
        assert_eq!(notifications.lock().unwrap().len(), 2); // initialized + roots changed

        // Remove a root (should notify again)
        let removed = client.remove_root("file:///project").await.unwrap();
        assert!(removed);
        assert_eq!(client.roots().len(), 1);
        assert_eq!(notifications.lock().unwrap().len(), 3);

        // Removing a non-existent root is a no-op and sends nothing
        let not_removed = client.remove_root("file:///nonexistent").await.unwrap();
        assert!(!not_removed);
        assert_eq!(client.roots().len(), 1);
        assert_eq!(notifications.lock().unwrap().len(), 3);

        let sent = notifications.lock().unwrap();
        assert_eq!(sent[1].0, notifications::ROOTS_LIST_CHANGED);
        assert_eq!(sent[2].0, notifications::ROOTS_LIST_CHANGED);
    }

    #[tokio::test]
    async fn test_set_roots() {
        let transport = MockTransport::with_responses(vec![mock_initialize_response()]);
        let notifications = transport.notifications.clone();
        let mut client = McpClient::new(transport);

        // Before initialization: replaced silently
        client
            .set_roots(vec![Root::new("file:///a")])
            .await
            .unwrap();
        assert_eq!(client.roots().len(), 1);
        assert!(notifications.lock().unwrap().is_empty());

        client.initialize("test-client", "1.0.0").await.unwrap();

        // After initialization: replaced and announced
        client
            .set_roots(vec![Root::new("file:///b"), Root::new("file:///c")])
            .await
            .unwrap();
        assert_eq!(client.roots().len(), 2);
        assert_eq!(client.roots()[0].uri, "file:///b");
        assert_eq!(notifications.lock().unwrap().len(), 2); // initialized + roots changed
    }

    #[test]
    fn test_with_roots() {
        let roots = vec![Root::new("file:///test")];
        let transport = MockTransport::with_responses(vec![mock_initialize_response()]);
        let client = McpClient::with_roots(transport, roots);

        assert_eq!(client.roots().len(), 1);
        assert!(client.capabilities.roots.is_some());
        assert!(client.capabilities.roots.as_ref().unwrap().list_changed);
        assert!(!client.is_initialized());
    }

    #[test]
    fn test_with_capabilities() {
        let capabilities = ClientCapabilities {
            sampling: Some(Default::default()),
            ..Default::default()
        };

        let transport = MockTransport::with_responses(vec![mock_initialize_response()]);
        let client = McpClient::with_capabilities(transport, capabilities);

        assert!(client.capabilities.sampling.is_some());
        assert!(client.roots().is_empty());
    }

    #[test]
    fn test_list_roots() {
        let roots = vec![
            Root::new("file:///project1"),
            Root::with_name("file:///project2", "Project 2"),
        ];
        let transport = MockTransport::new();
        let client = McpClient::with_roots(transport, roots);

        let result = client.list_roots();
        assert_eq!(result.roots.len(), 2);
        assert_eq!(result.roots[1].name, Some("Project 2".to_string()));
    }
}