Skip to main content

tower_mcp/
client.rs

1//! MCP Client implementation
2//!
3//! Provides client functionality for connecting to MCP servers.
4//!
5//! # Example
6//!
7//! ```rust,no_run
8//! use tower_mcp::BoxError;
9//! use tower_mcp::client::{McpClient, StdioClientTransport};
10//!
11//! #[tokio::main]
12//! async fn main() -> Result<(), BoxError> {
13//!     // Connect to an MCP server via stdio
14//!     let transport = StdioClientTransport::spawn("my-mcp-server", &["--flag"]).await?;
15//!     let mut client = McpClient::new(transport);
16//!
17//!     // Initialize the connection
18//!     let server_info = client.initialize("my-client", "1.0.0").await?;
19//!     println!("Connected to: {}", server_info.server_info.name);
20//!
21//!     // List available tools
22//!     let tools = client.list_tools().await?;
23//!     for tool in &tools.tools {
24//!         println!("Tool: {}", tool.name);
25//!     }
26//!
27//!     // Call a tool
28//!     let result = client.call_tool("my-tool", serde_json::json!({"arg": "value"})).await?;
29//!     println!("Result: {:?}", result);
30//!
31//!     Ok(())
32//! }
33//! ```
34
35use std::process::Stdio;
36use std::sync::atomic::{AtomicI64, Ordering};
37
38use async_trait::async_trait;
39use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
40use tokio::process::{Child, Command};
41
42use crate::error::{Error, Result};
43use crate::protocol::{
44    CallToolParams, CallToolResult, ClientCapabilities, CompleteParams, CompleteResult,
45    CompletionArgument, CompletionReference, GetPromptParams, GetPromptResult, Implementation,
46    InitializeParams, InitializeResult, JsonRpcRequest, JsonRpcResponse, ListPromptsParams,
47    ListPromptsResult, ListResourcesParams, ListResourcesResult, ListRootsResult, ListToolsParams,
48    ListToolsResult, ReadResourceParams, ReadResourceResult, Root, RootsCapability, notifications,
49};
50
51/// Trait for MCP client transports
52#[async_trait]
53pub trait ClientTransport: Send {
54    /// Send a request and receive a response
55    async fn request(
56        &mut self,
57        method: &str,
58        params: serde_json::Value,
59    ) -> Result<serde_json::Value>;
60
61    /// Send a notification (no response expected)
62    async fn notify(&mut self, method: &str, params: serde_json::Value) -> Result<()>;
63
64    /// Check if the transport is still connected
65    fn is_connected(&self) -> bool;
66
67    /// Close the transport
68    async fn close(self: Box<Self>) -> Result<()>;
69}
70
71/// MCP Client for connecting to MCP servers
72pub struct McpClient<T: ClientTransport> {
73    transport: T,
74    initialized: bool,
75    server_info: Option<InitializeResult>,
76    /// Client capabilities to declare during initialization
77    capabilities: ClientCapabilities,
78    /// Roots available to the server
79    roots: Vec<Root>,
80}
81
82impl<T: ClientTransport> McpClient<T> {
83    /// Create a new MCP client with the given transport
84    pub fn new(transport: T) -> Self {
85        Self {
86            transport,
87            initialized: false,
88            server_info: None,
89            capabilities: ClientCapabilities::default(),
90            roots: Vec::new(),
91        }
92    }
93
94    /// Configure roots for this client.
95    ///
96    /// The client will declare roots support during initialization and
97    /// provide these roots when requested by the server.
98    ///
99    /// # Example
100    ///
101    /// ```rust,no_run
102    /// use tower_mcp::client::{McpClient, StdioClientTransport};
103    /// use tower_mcp::protocol::Root;
104    ///
105    /// # async fn example() -> Result<(), tower_mcp::BoxError> {
106    /// let transport = StdioClientTransport::spawn("server", &[]).await?;
107    /// let client = McpClient::new(transport)
108    ///     .with_roots(vec![Root { uri: "file:///project".into(), name: Some("Project".into()) }]);
109    /// # Ok(())
110    /// # }
111    /// ```
112    pub fn with_roots(mut self, roots: Vec<Root>) -> Self {
113        self.roots = roots;
114        self.capabilities.roots = Some(RootsCapability { list_changed: true });
115        self
116    }
117
118    /// Configure custom capabilities for this client.
119    ///
120    /// # Example
121    ///
122    /// ```rust,no_run
123    /// use tower_mcp::client::{McpClient, StdioClientTransport};
124    /// use tower_mcp::protocol::ClientCapabilities;
125    ///
126    /// # async fn example() -> Result<(), tower_mcp::BoxError> {
127    /// let transport = StdioClientTransport::spawn("server", &[]).await?;
128    /// let client = McpClient::new(transport)
129    ///     .with_capabilities(ClientCapabilities {
130    ///         sampling: Some(Default::default()),
131    ///         ..Default::default()
132    ///     });
133    /// # Ok(())
134    /// # }
135    /// ```
136    pub fn with_capabilities(mut self, capabilities: ClientCapabilities) -> Self {
137        self.capabilities = capabilities;
138        self
139    }
140
141    /// Get the server info (available after initialization)
142    pub fn server_info(&self) -> Option<&InitializeResult> {
143        self.server_info.as_ref()
144    }
145
146    /// Check if the client is initialized
147    pub fn is_initialized(&self) -> bool {
148        self.initialized
149    }
150
151    /// Get the current roots
152    pub fn roots(&self) -> &[Root] {
153        &self.roots
154    }
155
156    /// Set roots and notify the server if initialized
157    ///
158    /// If the client is already initialized, sends a roots list changed notification.
159    pub async fn set_roots(&mut self, roots: Vec<Root>) -> Result<()> {
160        self.roots = roots;
161        if self.initialized {
162            self.notify_roots_changed().await?;
163        }
164        Ok(())
165    }
166
167    /// Add a root and notify the server if initialized
168    pub async fn add_root(&mut self, root: Root) -> Result<()> {
169        self.roots.push(root);
170        if self.initialized {
171            self.notify_roots_changed().await?;
172        }
173        Ok(())
174    }
175
176    /// Remove a root by URI and notify the server if initialized
177    pub async fn remove_root(&mut self, uri: &str) -> Result<bool> {
178        let initial_len = self.roots.len();
179        self.roots.retain(|r| r.uri != uri);
180        let removed = self.roots.len() < initial_len;
181        if removed && self.initialized {
182            self.notify_roots_changed().await?;
183        }
184        Ok(removed)
185    }
186
187    /// Send roots list changed notification to the server
188    async fn notify_roots_changed(&mut self) -> Result<()> {
189        self.transport
190            .notify(notifications::ROOTS_LIST_CHANGED, serde_json::json!({}))
191            .await
192    }
193
194    /// Get the roots list result (for responding to server's roots/list request)
195    ///
196    /// Returns a result suitable for responding to a roots/list request from the server.
197    pub fn list_roots(&self) -> ListRootsResult {
198        ListRootsResult {
199            roots: self.roots.clone(),
200        }
201    }
202
203    /// Initialize the MCP connection
204    pub async fn initialize(
205        &mut self,
206        client_name: &str,
207        client_version: &str,
208    ) -> Result<&InitializeResult> {
209        let params = InitializeParams {
210            protocol_version: crate::protocol::LATEST_PROTOCOL_VERSION.to_string(),
211            capabilities: self.capabilities.clone(),
212            client_info: Implementation {
213                name: client_name.to_string(),
214                version: client_version.to_string(),
215                ..Default::default()
216            },
217        };
218
219        let result: InitializeResult = self.request("initialize", &params).await?;
220        self.server_info = Some(result);
221
222        // Send initialized notification
223        self.transport
224            .notify("notifications/initialized", serde_json::json!({}))
225            .await?;
226
227        self.initialized = true;
228
229        Ok(self.server_info.as_ref().unwrap())
230    }
231
232    /// List available tools
233    pub async fn list_tools(&mut self) -> Result<ListToolsResult> {
234        self.ensure_initialized()?;
235        self.request("tools/list", &ListToolsParams { cursor: None })
236            .await
237    }
238
239    /// Call a tool
240    pub async fn call_tool(
241        &mut self,
242        name: &str,
243        arguments: serde_json::Value,
244    ) -> Result<CallToolResult> {
245        self.ensure_initialized()?;
246        let params = CallToolParams {
247            name: name.to_string(),
248            arguments,
249            meta: None,
250        };
251        self.request("tools/call", &params).await
252    }
253
254    /// List available resources
255    pub async fn list_resources(&mut self) -> Result<ListResourcesResult> {
256        self.ensure_initialized()?;
257        self.request("resources/list", &ListResourcesParams { cursor: None })
258            .await
259    }
260
261    /// Read a resource
262    pub async fn read_resource(&mut self, uri: &str) -> Result<ReadResourceResult> {
263        self.ensure_initialized()?;
264        let params = ReadResourceParams {
265            uri: uri.to_string(),
266        };
267        self.request("resources/read", &params).await
268    }
269
270    /// List available prompts
271    pub async fn list_prompts(&mut self) -> Result<ListPromptsResult> {
272        self.ensure_initialized()?;
273        self.request("prompts/list", &ListPromptsParams { cursor: None })
274            .await
275    }
276
277    /// Get a prompt
278    pub async fn get_prompt(
279        &mut self,
280        name: &str,
281        arguments: Option<std::collections::HashMap<String, String>>,
282    ) -> Result<GetPromptResult> {
283        self.ensure_initialized()?;
284        let params = GetPromptParams {
285            name: name.to_string(),
286            arguments: arguments.unwrap_or_default(),
287        };
288        self.request("prompts/get", &params).await
289    }
290
291    /// Ping the server
292    pub async fn ping(&mut self) -> Result<()> {
293        let _: serde_json::Value = self.request("ping", &serde_json::json!({})).await?;
294        Ok(())
295    }
296
297    /// Request completion suggestions from the server
298    ///
299    /// This is used to get autocomplete suggestions for prompt arguments or resource URIs.
300    pub async fn complete(
301        &mut self,
302        reference: CompletionReference,
303        argument_name: &str,
304        argument_value: &str,
305    ) -> Result<CompleteResult> {
306        self.ensure_initialized()?;
307        let params = CompleteParams {
308            reference,
309            argument: CompletionArgument::new(argument_name, argument_value),
310        };
311        self.request("completion/complete", &params).await
312    }
313
314    /// Request completion for a prompt argument
315    pub async fn complete_prompt_arg(
316        &mut self,
317        prompt_name: &str,
318        argument_name: &str,
319        argument_value: &str,
320    ) -> Result<CompleteResult> {
321        self.complete(
322            CompletionReference::prompt(prompt_name),
323            argument_name,
324            argument_value,
325        )
326        .await
327    }
328
329    /// Request completion for a resource URI
330    pub async fn complete_resource_uri(
331        &mut self,
332        resource_uri: &str,
333        argument_name: &str,
334        argument_value: &str,
335    ) -> Result<CompleteResult> {
336        self.complete(
337            CompletionReference::resource(resource_uri),
338            argument_name,
339            argument_value,
340        )
341        .await
342    }
343
344    /// Send a raw request
345    pub async fn request<P: serde::Serialize, R: serde::de::DeserializeOwned>(
346        &mut self,
347        method: &str,
348        params: &P,
349    ) -> Result<R> {
350        let params_value = serde_json::to_value(params)
351            .map_err(|e| Error::Transport(format!("Failed to serialize params: {}", e)))?;
352
353        let result = self.transport.request(method, params_value).await?;
354
355        serde_json::from_value(result)
356            .map_err(|e| Error::Transport(format!("Failed to deserialize response: {}", e)))
357    }
358
359    /// Send a notification
360    pub async fn notify<P: serde::Serialize>(&mut self, method: &str, params: &P) -> Result<()> {
361        let params_value = serde_json::to_value(params)
362            .map_err(|e| Error::Transport(format!("Failed to serialize params: {}", e)))?;
363
364        self.transport.notify(method, params_value).await
365    }
366
367    fn ensure_initialized(&self) -> Result<()> {
368        if !self.initialized {
369            return Err(Error::Transport("Client not initialized".to_string()));
370        }
371        Ok(())
372    }
373}
374
375// ============================================================================
376// Stdio Client Transport
377// ============================================================================
378
/// Client transport that communicates with a subprocess via stdio
///
/// Messages are exchanged as newline-delimited JSON: each request or
/// notification is written as one line to the child's stdin, and each reply is
/// read as one line from the child's stdout.
pub struct StdioClientTransport {
    /// Handle to the server process. `close` takes it in order to wait for (or
    /// kill) the process, and `is_connected` reports whether it is still held.
    child: Option<Child>,
    /// Write half: requests and notifications are sent here.
    stdin: tokio::process::ChildStdin,
    /// Read half, buffered so replies can be read line by line.
    stdout: BufReader<tokio::process::ChildStdout>,
    /// Source of JSON-RPC request ids, starting at 1.
    request_id: AtomicI64,
}
386
387impl StdioClientTransport {
388    /// Spawn a new subprocess and connect to it
389    pub async fn spawn(program: &str, args: &[&str]) -> Result<Self> {
390        let mut cmd = Command::new(program);
391        cmd.args(args)
392            .stdin(Stdio::piped())
393            .stdout(Stdio::piped())
394            .stderr(Stdio::inherit());
395
396        let mut child = cmd
397            .spawn()
398            .map_err(|e| Error::Transport(format!("Failed to spawn {}: {}", program, e)))?;
399
400        let stdin = child
401            .stdin
402            .take()
403            .ok_or_else(|| Error::Transport("Failed to get child stdin".to_string()))?;
404        let stdout = child
405            .stdout
406            .take()
407            .ok_or_else(|| Error::Transport("Failed to get child stdout".to_string()))?;
408
409        tracing::info!(program = %program, "Spawned MCP server process");
410
411        Ok(Self {
412            child: Some(child),
413            stdin,
414            stdout: BufReader::new(stdout),
415            request_id: AtomicI64::new(1),
416        })
417    }
418
419    /// Create from an existing child process
420    pub fn from_child(mut child: Child) -> Result<Self> {
421        let stdin = child
422            .stdin
423            .take()
424            .ok_or_else(|| Error::Transport("Failed to get child stdin".to_string()))?;
425        let stdout = child
426            .stdout
427            .take()
428            .ok_or_else(|| Error::Transport("Failed to get child stdout".to_string()))?;
429
430        Ok(Self {
431            child: Some(child),
432            stdin,
433            stdout: BufReader::new(stdout),
434            request_id: AtomicI64::new(1),
435        })
436    }
437
438    async fn send_line(&mut self, line: &str) -> Result<()> {
439        self.stdin
440            .write_all(line.as_bytes())
441            .await
442            .map_err(|e| Error::Transport(format!("Failed to write: {}", e)))?;
443        self.stdin
444            .write_all(b"\n")
445            .await
446            .map_err(|e| Error::Transport(format!("Failed to write newline: {}", e)))?;
447        self.stdin
448            .flush()
449            .await
450            .map_err(|e| Error::Transport(format!("Failed to flush: {}", e)))?;
451        Ok(())
452    }
453
454    async fn read_line(&mut self) -> Result<String> {
455        let mut line = String::new();
456        self.stdout
457            .read_line(&mut line)
458            .await
459            .map_err(|e| Error::Transport(format!("Failed to read: {}", e)))?;
460
461        if line.is_empty() {
462            return Err(Error::Transport("Connection closed".to_string()));
463        }
464
465        Ok(line)
466    }
467}
468
469#[async_trait]
470impl ClientTransport for StdioClientTransport {
471    async fn request(
472        &mut self,
473        method: &str,
474        params: serde_json::Value,
475    ) -> Result<serde_json::Value> {
476        let id = self.request_id.fetch_add(1, Ordering::Relaxed);
477        let request = JsonRpcRequest::new(id, method).with_params(params);
478
479        let request_json = serde_json::to_string(&request)
480            .map_err(|e| Error::Transport(format!("Failed to serialize: {}", e)))?;
481
482        tracing::debug!(method = %method, id = %id, "Sending request");
483        self.send_line(&request_json).await?;
484
485        let response_line = self.read_line().await?;
486        tracing::debug!(response = %response_line.trim(), "Received response");
487
488        let response: JsonRpcResponse = serde_json::from_str(response_line.trim())
489            .map_err(|e| Error::Transport(format!("Failed to parse response: {}", e)))?;
490
491        match response {
492            JsonRpcResponse::Result(r) => Ok(r.result),
493            JsonRpcResponse::Error(e) => Err(Error::JsonRpc(e.error)),
494        }
495    }
496
497    async fn notify(&mut self, method: &str, params: serde_json::Value) -> Result<()> {
498        let notification = serde_json::json!({
499            "jsonrpc": "2.0",
500            "method": method,
501            "params": params
502        });
503
504        let json = serde_json::to_string(&notification)
505            .map_err(|e| Error::Transport(format!("Failed to serialize: {}", e)))?;
506
507        tracing::debug!(method = %method, "Sending notification");
508        self.send_line(&json).await
509    }
510
511    fn is_connected(&self) -> bool {
512        // Assume connected if we have a child process handle
513        self.child.is_some()
514    }
515
516    async fn close(mut self: Box<Self>) -> Result<()> {
517        // Close stdin to signal EOF
518        drop(self.stdin);
519
520        if let Some(mut child) = self.child.take() {
521            // Wait for process with timeout
522            let result =
523                tokio::time::timeout(std::time::Duration::from_secs(5), child.wait()).await;
524
525            match result {
526                Ok(Ok(status)) => {
527                    tracing::info!(status = ?status, "Child process exited");
528                }
529                Ok(Err(e)) => {
530                    tracing::error!(error = %e, "Error waiting for child");
531                }
532                Err(_) => {
533                    tracing::warn!("Timeout waiting for child, killing");
534                    let _ = child.kill().await;
535                }
536            }
537        }
538
539        Ok(())
540    }
541}
542
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::VecDeque;
    use std::sync::{Arc, Mutex};

    /// Mock transport that returns preconfigured responses and records all traffic.
    struct MockTransport {
        responses: Arc<Mutex<VecDeque<serde_json::Value>>>,
        requests: Arc<Mutex<Vec<(String, serde_json::Value)>>>,
        notifications: Arc<Mutex<Vec<(String, serde_json::Value)>>>,
        connected: bool,
    }

    impl MockTransport {
        /// A transport with no queued responses.
        fn new() -> Self {
            Self::with_responses(Vec::new())
        }

        /// A transport that answers successive requests with `responses`, in order.
        fn with_responses(responses: Vec<serde_json::Value>) -> Self {
            Self {
                responses: Arc::new(Mutex::new(responses.into())),
                requests: Arc::new(Mutex::new(Vec::new())),
                notifications: Arc::new(Mutex::new(Vec::new())),
                connected: true,
            }
        }

        /// Requests sent so far, as `(method, params)` pairs.
        fn get_requests(&self) -> Vec<(String, serde_json::Value)> {
            self.requests.lock().unwrap().clone()
        }

        /// Notifications sent so far, as `(method, params)` pairs.
        fn get_notifications(&self) -> Vec<(String, serde_json::Value)> {
            self.notifications.lock().unwrap().clone()
        }
    }

    #[async_trait]
    impl ClientTransport for MockTransport {
        async fn request(
            &mut self,
            method: &str,
            params: serde_json::Value,
        ) -> Result<serde_json::Value> {
            self.requests
                .lock()
                .unwrap()
                .push((method.to_string(), params));
            self.responses
                .lock()
                .unwrap()
                .pop_front()
                .ok_or_else(|| Error::Transport("No more mock responses".to_string()))
        }

        async fn notify(&mut self, method: &str, params: serde_json::Value) -> Result<()> {
            self.notifications
                .lock()
                .unwrap()
                .push((method.to_string(), params));
            Ok(())
        }

        fn is_connected(&self) -> bool {
            self.connected
        }

        async fn close(self: Box<Self>) -> Result<()> {
            Ok(())
        }
    }

    fn mock_initialize_response() -> serde_json::Value {
        serde_json::json!({
            "protocolVersion": "2025-11-25",
            "serverInfo": {
                "name": "test-server",
                "version": "1.0.0"
            },
            "capabilities": {
                "tools": {}
            }
        })
    }

    #[tokio::test]
    async fn test_client_not_initialized() {
        let mut client = McpClient::new(MockTransport::new());

        // Should fail because not initialized
        let result = client.list_tools().await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("not initialized"));
    }

    #[tokio::test]
    async fn test_client_initialize() {
        let transport = MockTransport::with_responses(vec![mock_initialize_response()]);
        let mut client = McpClient::new(transport);

        assert!(!client.is_initialized());

        let result = client.initialize("test-client", "1.0.0").await;
        assert!(result.is_ok());
        assert!(client.is_initialized());

        let server_info = client.server_info().unwrap();
        assert_eq!(server_info.server_info.name, "test-server");
    }

    #[tokio::test]
    async fn test_initialize_handshake_traffic() {
        let transport = MockTransport::with_responses(vec![mock_initialize_response()]);
        let mut client = McpClient::new(transport);

        client.initialize("test-client", "1.0.0").await.unwrap();

        let requests = client.transport.get_requests();
        assert_eq!(requests.len(), 1);
        assert_eq!(requests[0].0.as_str(), "initialize");

        let notifications = client.transport.get_notifications();
        assert_eq!(notifications.len(), 1);
        assert_eq!(notifications[0].0.as_str(), "notifications/initialized");
    }

    #[tokio::test]
    async fn test_list_tools() {
        let transport = MockTransport::with_responses(vec![
            mock_initialize_response(),
            serde_json::json!({
                "tools": [
                    {
                        "name": "test_tool",
                        "description": "A test tool",
                        "inputSchema": {
                            "type": "object",
                            "properties": {}
                        }
                    }
                ]
            }),
        ]);
        let mut client = McpClient::new(transport);

        client.initialize("test-client", "1.0.0").await.unwrap();
        let tools = client.list_tools().await.unwrap();

        assert_eq!(tools.tools.len(), 1);
        assert_eq!(tools.tools[0].name, "test_tool");
    }

    #[tokio::test]
    async fn test_call_tool() {
        let transport = MockTransport::with_responses(vec![
            mock_initialize_response(),
            serde_json::json!({
                "content": [
                    {
                        "type": "text",
                        "text": "Tool result"
                    }
                ]
            }),
        ]);
        let mut client = McpClient::new(transport);

        client.initialize("test-client", "1.0.0").await.unwrap();
        let result = client
            .call_tool("test_tool", serde_json::json!({"arg": "value"}))
            .await
            .unwrap();

        assert!(!result.content.is_empty());

        // The request carries the tool name and the caller's arguments.
        let requests = client.transport.get_requests();
        let (method, params) = &requests[1];
        assert_eq!(method.as_str(), "tools/call");
        assert_eq!(params["name"], "test_tool");
        assert_eq!(params["arguments"], serde_json::json!({"arg": "value"}));
    }

    #[tokio::test]
    async fn test_list_resources() {
        let transport = MockTransport::with_responses(vec![
            mock_initialize_response(),
            serde_json::json!({
                "resources": [
                    {
                        "uri": "file://test.txt",
                        "name": "Test File"
                    }
                ]
            }),
        ]);
        let mut client = McpClient::new(transport);

        client.initialize("test-client", "1.0.0").await.unwrap();
        let resources = client.list_resources().await.unwrap();

        assert_eq!(resources.resources.len(), 1);
        assert_eq!(resources.resources[0].uri, "file://test.txt");
    }

    #[tokio::test]
    async fn test_read_resource() {
        let transport = MockTransport::with_responses(vec![
            mock_initialize_response(),
            serde_json::json!({
                "contents": [
                    {
                        "uri": "file://test.txt",
                        "text": "File contents"
                    }
                ]
            }),
        ]);
        let mut client = McpClient::new(transport);

        client.initialize("test-client", "1.0.0").await.unwrap();
        let result = client.read_resource("file://test.txt").await.unwrap();

        assert_eq!(result.contents.len(), 1);
        assert_eq!(result.contents[0].text.as_deref(), Some("File contents"));
    }

    #[tokio::test]
    async fn test_list_prompts() {
        let transport = MockTransport::with_responses(vec![
            mock_initialize_response(),
            serde_json::json!({
                "prompts": [
                    {
                        "name": "test_prompt",
                        "description": "A test prompt"
                    }
                ]
            }),
        ]);
        let mut client = McpClient::new(transport);

        client.initialize("test-client", "1.0.0").await.unwrap();
        let prompts = client.list_prompts().await.unwrap();

        assert_eq!(prompts.prompts.len(), 1);
        assert_eq!(prompts.prompts[0].name, "test_prompt");
    }

    #[tokio::test]
    async fn test_get_prompt() {
        let transport = MockTransport::with_responses(vec![
            mock_initialize_response(),
            serde_json::json!({
                "messages": [
                    {
                        "role": "user",
                        "content": {
                            "type": "text",
                            "text": "Prompt message"
                        }
                    }
                ]
            }),
        ]);
        let mut client = McpClient::new(transport);

        client.initialize("test-client", "1.0.0").await.unwrap();
        let result = client.get_prompt("test_prompt", None).await.unwrap();

        assert_eq!(result.messages.len(), 1);
    }

    #[tokio::test]
    async fn test_ping() {
        let transport =
            MockTransport::with_responses(vec![mock_initialize_response(), serde_json::json!({})]);
        let mut client = McpClient::new(transport);

        client.initialize("test-client", "1.0.0").await.unwrap();
        let result = client.ping().await;

        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_ping_does_not_require_initialization() {
        let transport = MockTransport::with_responses(vec![serde_json::json!({})]);
        let mut client = McpClient::new(transport);

        assert!(client.ping().await.is_ok());
        assert_eq!(client.transport.get_requests()[0].0.as_str(), "ping");
    }

    #[tokio::test]
    async fn test_complete_prompt_arg_routes_to_completion() {
        let transport = MockTransport::with_responses(vec![
            mock_initialize_response(),
            serde_json::json!({ "completion": { "values": ["alpha"] } }),
        ]);
        let mut client = McpClient::new(transport);

        client.initialize("test-client", "1.0.0").await.unwrap();
        // Only the request routing is checked here, not the result decoding.
        let _ = client.complete_prompt_arg("prompt", "arg", "a").await;

        let requests = client.transport.get_requests();
        assert_eq!(requests.len(), 2);
        assert_eq!(requests[1].0.as_str(), "completion/complete");
    }

    #[tokio::test]
    async fn test_roots_management() {
        let transport = MockTransport::with_responses(vec![mock_initialize_response()]);
        let notifications = transport.notifications.clone();
        let mut client = McpClient::new(transport);

        // Initially no roots
        assert!(client.roots().is_empty());

        // Add a root before initialization (no notification)
        client.add_root(Root::new("file:///project")).await.unwrap();
        assert_eq!(client.roots().len(), 1);
        assert!(notifications.lock().unwrap().is_empty());

        // Initialize
        client.initialize("test-client", "1.0.0").await.unwrap();

        // Add another root after initialization (should notify)
        client.add_root(Root::new("file:///other")).await.unwrap();
        assert_eq!(client.roots().len(), 2);
        {
            let sent = notifications.lock().unwrap();
            assert_eq!(sent.len(), 2); // initialized + roots changed
            assert_eq!(sent[1].0, notifications::ROOTS_LIST_CHANGED);
        }

        // Remove a root
        let removed = client.remove_root("file:///project").await.unwrap();
        assert!(removed);
        assert_eq!(client.roots().len(), 1);

        // Try to remove non-existent root
        let not_removed = client.remove_root("file:///nonexistent").await.unwrap();
        assert!(!not_removed);
    }

    #[tokio::test]
    async fn test_with_roots() {
        let roots = vec![Root::new("file:///test")];
        let transport = MockTransport::with_responses(vec![mock_initialize_response()]);
        let client = McpClient::new(transport).with_roots(roots);

        assert_eq!(client.roots().len(), 1);
        assert!(client.capabilities.roots.is_some());
    }

    #[tokio::test]
    async fn test_with_capabilities() {
        let capabilities = ClientCapabilities {
            sampling: Some(Default::default()),
            ..Default::default()
        };

        let transport = MockTransport::with_responses(vec![mock_initialize_response()]);
        let client = McpClient::new(transport).with_capabilities(capabilities);

        assert!(client.capabilities.sampling.is_some());
    }

    #[tokio::test]
    async fn test_list_roots() {
        let roots = vec![
            Root::new("file:///project1"),
            Root::with_name("file:///project2", "Project 2"),
        ];
        let transport = MockTransport::new();
        let client = McpClient::new(transport).with_roots(roots);

        let result = client.list_roots();
        assert_eq!(result.roots.len(), 2);
        assert_eq!(result.roots[1].name, Some("Project 2".to_string()));
    }
}