Skip to main content

mistralrs_mcp/
client.rs

1use crate::tools::{Function, Tool, ToolCallback, ToolCallbackWithTool, ToolType};
2use crate::transport::{HttpTransport, McpTransport, ProcessTransport, WebSocketTransport};
3use crate::types::McpToolResult;
4use crate::{McpClientConfig, McpServerConfig, McpServerSource, McpToolInfo};
5use anyhow::Result;
6use rust_mcp_schema::Resource;
7use serde_json::Value;
8use std::collections::HashMap;
9use std::sync::Arc;
10use std::time::Duration;
11use tokio::sync::Semaphore;
12use tracing::warn;
13
14/// Trait for MCP server connections
15#[async_trait::async_trait]
16pub trait McpServerConnection: Send + Sync {
17    /// Get the server ID
18    fn server_id(&self) -> &str;
19
20    /// Get the server name
21    fn server_name(&self) -> &str;
22
23    /// List available tools from this server
24    async fn list_tools(&self) -> Result<Vec<McpToolInfo>>;
25
26    /// Call a tool on this server
27    async fn call_tool(&self, name: &str, arguments: serde_json::Value) -> Result<String>;
28
29    /// List available resources from this server
30    async fn list_resources(&self) -> Result<Vec<Resource>>;
31
32    /// Read a resource from this server
33    async fn read_resource(&self, uri: &str) -> Result<String>;
34
35    /// Check if the connection is healthy
36    async fn ping(&self) -> Result<()>;
37
38    /// Close the server connection
39    async fn close(&self) -> Result<()>;
40}
41
42/// MCP client that manages connections to multiple MCP servers
43///
44/// The main interface for interacting with Model Context Protocol servers.
45/// Handles connection lifecycle, tool discovery, and provides integration
46/// with tool calling systems.
47///
48/// # Features
49///
50/// - **Multi-server Management**: Connects to and manages multiple MCP servers simultaneously
51/// - **Automatic Tool Discovery**: Discovers available tools from connected servers
52/// - **Tool Registration**: Converts MCP tools to internal Tool format for seamless integration
53/// - **Connection Pooling**: Maintains persistent connections for efficient tool execution
54/// - **Error Handling**: Robust error handling with proper cleanup and reconnection logic
55///
56/// # Example
57///
58/// ```rust,no_run
59/// use mistralrs_mcp::{McpClient, McpClientConfig};
60///
61/// #[tokio::main]
62/// async fn main() -> anyhow::Result<()> {
63///     let config = McpClientConfig::default();
64///     let mut client = McpClient::new(config);
65///     
66///     // Initialize all configured server connections
67///     client.initialize().await?;
68///     
69///     // Get tool callbacks for model integration
70///     let callbacks = client.get_tool_callbacks_with_tools();
71///     
72///     Ok(())
73/// }
74/// ```
75pub struct McpClient {
76    /// Configuration for the client including server list and policies
77    config: McpClientConfig,
78    /// Active connections to MCP servers, indexed by server ID
79    servers: HashMap<String, Arc<dyn McpServerConnection>>,
80    /// Registry of discovered tools from all connected servers
81    tools: HashMap<String, McpToolInfo>,
82    /// Legacy tool callbacks for backward compatibility
83    tool_callbacks: HashMap<String, Arc<ToolCallback>>,
84    /// Tool callbacks with associated Tool definitions for automatic tool calling
85    tool_callbacks_with_tools: HashMap<String, ToolCallbackWithTool>,
86    /// Semaphore to control maximum concurrent tool calls
87    concurrency_semaphore: Arc<Semaphore>,
88}
89
90impl McpClient {
91    /// Create a new MCP client with the given configuration
92    pub fn new(config: McpClientConfig) -> Self {
93        let max_concurrent = config.max_concurrent_calls.unwrap_or(10);
94        Self {
95            config,
96            servers: HashMap::new(),
97            tools: HashMap::new(),
98            tool_callbacks: HashMap::new(),
99            tool_callbacks_with_tools: HashMap::new(),
100            concurrency_semaphore: Arc::new(Semaphore::new(max_concurrent)),
101        }
102    }
103
104    /// Initialize connections to all configured servers
105    pub async fn initialize(&mut self) -> Result<()> {
106        for server_config in &self.config.servers {
107            if server_config.enabled {
108                let connection = self.create_connection(server_config).await?;
109                self.servers.insert(server_config.id.clone(), connection);
110            }
111        }
112
113        if self.config.auto_register_tools {
114            self.discover_and_register_tools().await?;
115        }
116
117        Ok(())
118    }
119
120    /// Get tool callbacks for use with legacy tool calling systems.
121    ///
122    /// Returns a map of tool names to their callback functions. These callbacks
123    /// handle argument parsing, concurrency control, and timeout enforcement
124    /// automatically.
125    ///
126    /// For new integrations, prefer [`Self::get_tool_callbacks_with_tools`] which
127    /// includes tool definitions alongside callbacks.
128    pub fn get_tool_callbacks(&self) -> &HashMap<String, Arc<ToolCallback>> {
129        &self.tool_callbacks
130    }
131
132    /// Get tool callbacks paired with their tool definitions.
133    ///
134    /// This is the primary method for integrating MCP tools with the model's
135    /// automatic tool calling system. Each entry contains:
136    /// - A callback function that executes the tool with timeout and concurrency controls
137    /// - A [`Tool`] definition with name, description, and parameter schema
138    ///
139    /// # Example
140    ///
141    /// ```rust,no_run
142    /// # use mistralrs_mcp::{McpClient, McpClientConfig};
143    /// # async fn example() -> anyhow::Result<()> {
144    /// let config = McpClientConfig::default();
145    /// let mut client = McpClient::new(config);
146    /// client.initialize().await?;
147    ///
148    /// let tools = client.get_tool_callbacks_with_tools();
149    /// for (name, tool_with_callback) in tools {
150    ///     println!("Tool: {} - {:?}", name, tool_with_callback.tool.function.description);
151    /// }
152    /// # Ok(())
153    /// # }
154    /// ```
155    pub fn get_tool_callbacks_with_tools(&self) -> &HashMap<String, ToolCallbackWithTool> {
156        &self.tool_callbacks_with_tools
157    }
158
159    /// Get information about all discovered tools.
160    ///
161    /// Returns metadata about tools discovered from connected MCP servers,
162    /// including their names, descriptions, input schemas, and which server
163    /// they came from.
164    pub fn get_tools(&self) -> &HashMap<String, McpToolInfo> {
165        &self.tools
166    }
167
168    /// Get a reference to all connected MCP server connections.
169    ///
170    /// This provides direct access to server connections, allowing you to:
171    /// - List available resources with [`McpServerConnection::list_resources`]
172    /// - Read resources with [`McpServerConnection::read_resource`]
173    /// - Check server health with [`McpServerConnection::ping`]
174    /// - Call tools directly with [`McpServerConnection::call_tool`]
175    ///
176    /// # Example
177    ///
178    /// ```rust,no_run
179    /// # use mistralrs_mcp::{McpClient, McpClientConfig};
180    /// # async fn example() -> anyhow::Result<()> {
181    /// let config = McpClientConfig::default();
182    /// let mut client = McpClient::new(config);
183    /// client.initialize().await?;
184    ///
185    /// for (server_id, connection) in client.servers() {
186    ///     println!("Server: {} ({})", connection.server_name(), server_id);
187    ///     let resources = connection.list_resources().await?;
188    ///     println!("  Resources: {:?}", resources);
189    /// }
190    /// # Ok(())
191    /// # }
192    /// ```
193    pub fn servers(&self) -> &HashMap<String, Arc<dyn McpServerConnection>> {
194        &self.servers
195    }
196
197    /// Get a specific server connection by its ID.
198    ///
199    /// Returns `None` if no server with the given ID is connected.
200    ///
201    /// # Example
202    ///
203    /// ```rust,no_run
204    /// # use mistralrs_mcp::{McpClient, McpClientConfig};
205    /// # async fn example() -> anyhow::Result<()> {
206    /// let config = McpClientConfig::default();
207    /// let mut client = McpClient::new(config);
208    /// client.initialize().await?;
209    ///
210    /// if let Some(server) = client.server("my_server_id") {
211    ///     server.ping().await?;
212    ///     let resources = server.list_resources().await?;
213    /// }
214    /// # Ok(())
215    /// # }
216    /// ```
217    pub fn server(&self, id: &str) -> Option<&Arc<dyn McpServerConnection>> {
218        self.servers.get(id)
219    }
220
221    /// Get the client configuration.
222    pub fn config(&self) -> &McpClientConfig {
223        &self.config
224    }
225
226    /// Create connection based on server source type
227    async fn create_connection(
228        &self,
229        config: &McpServerConfig,
230    ) -> Result<Arc<dyn McpServerConnection>> {
231        match &config.source {
232            McpServerSource::Http {
233                url,
234                timeout_secs,
235                headers,
236            } => {
237                // Merge Bearer token with existing headers if provided
238                let mut merged_headers = headers.clone().unwrap_or_default();
239                if let Some(token) = &config.bearer_token {
240                    merged_headers.insert("Authorization".to_string(), format!("Bearer {token}"));
241                }
242
243                let connection = HttpMcpConnection::new(
244                    config.id.clone(),
245                    config.name.clone(),
246                    url.clone(),
247                    *timeout_secs,
248                    Some(merged_headers),
249                )
250                .await?;
251                Ok(Arc::new(connection))
252            }
253            McpServerSource::Process {
254                command,
255                args,
256                work_dir,
257                env,
258            } => {
259                let connection = ProcessMcpConnection::new(
260                    config.id.clone(),
261                    config.name.clone(),
262                    command.clone(),
263                    args.clone(),
264                    work_dir.clone(),
265                    env.clone(),
266                )
267                .await?;
268                Ok(Arc::new(connection))
269            }
270            McpServerSource::WebSocket {
271                url,
272                timeout_secs,
273                headers,
274            } => {
275                // Merge Bearer token with existing headers if provided
276                let mut merged_headers = headers.clone().unwrap_or_default();
277                if let Some(token) = &config.bearer_token {
278                    merged_headers.insert("Authorization".to_string(), format!("Bearer {token}"));
279                }
280
281                let connection = WebSocketMcpConnection::new(
282                    config.id.clone(),
283                    config.name.clone(),
284                    url.clone(),
285                    *timeout_secs,
286                    Some(merged_headers),
287                )
288                .await?;
289                Ok(Arc::new(connection))
290            }
291        }
292    }
293
294    /// Discover tools from all connected servers and register them
295    async fn discover_and_register_tools(&mut self) -> Result<()> {
296        for (server_id, connection) in &self.servers {
297            let tools = connection.list_tools().await?;
298            let server_config = self
299                .config
300                .servers
301                .iter()
302                .find(|s| &s.id == server_id)
303                .ok_or_else(|| anyhow::anyhow!("Server config not found for {}", server_id))?;
304
305            for tool in tools {
306                let tool_name = if let Some(prefix) = &server_config.tool_prefix {
307                    format!("{}_{}", prefix, tool.name)
308                } else {
309                    tool.name.clone()
310                };
311
312                // Create tool callback that calls the MCP server with timeout and concurrency controls
313                let connection_clone = Arc::clone(connection);
314                let original_tool_name = tool.name.clone();
315                let semaphore_clone = Arc::clone(&self.concurrency_semaphore);
316                let timeout_duration =
317                    Duration::from_secs(self.config.tool_timeout_secs.unwrap_or(30));
318
319                let callback: Arc<ToolCallback> = Arc::new(move |called_function| {
320                    let connection = Arc::clone(&connection_clone);
321                    let tool_name = original_tool_name.clone();
322                    let semaphore = Arc::clone(&semaphore_clone);
323                    let arguments: serde_json::Value =
324                        serde_json::from_str(&called_function.arguments)?;
325
326                    // Use tokio::task::spawn_blocking to handle the async-to-sync bridge
327                    let rt = tokio::runtime::Handle::current();
328                    std::thread::spawn(move || {
329                        rt.block_on(async move {
330                            // Acquire semaphore permit for concurrency control
331                            let _permit = semaphore.acquire().await.map_err(|_| {
332                                anyhow::anyhow!("Failed to acquire concurrency permit")
333                            })?;
334
335                            // Execute tool call with timeout
336                            match tokio::time::timeout(
337                                timeout_duration,
338                                connection.call_tool(&tool_name, arguments),
339                            )
340                            .await
341                            {
342                                Ok(result) => result,
343                                Err(_) => Err(anyhow::anyhow!(
344                                    "Tool call timed out after {} seconds",
345                                    timeout_duration.as_secs()
346                                )),
347                            }
348                        })
349                    })
350                    .join()
351                    .map_err(|_| anyhow::anyhow!("Tool call thread panicked"))?
352                });
353
354                // Convert MCP tool schema to Tool definition
355                let function_def = Function {
356                    name: tool_name.clone(),
357                    description: tool.description.clone(),
358                    parameters: Self::convert_mcp_schema_to_parameters(&tool.input_schema),
359                };
360
361                let tool_def = Tool {
362                    tp: ToolType::Function,
363                    function: function_def,
364                };
365
366                // Store in both collections for backward compatibility
367                self.tool_callbacks
368                    .insert(tool_name.clone(), callback.clone());
369                self.tool_callbacks_with_tools.insert(
370                    tool_name.clone(),
371                    ToolCallbackWithTool {
372                        callback,
373                        tool: tool_def,
374                    },
375                );
376                self.tools.insert(tool_name, tool);
377            }
378        }
379
380        Ok(())
381    }
382
383    /// Convert MCP tool input schema to Tool parameters format
384    fn convert_mcp_schema_to_parameters(
385        schema: &serde_json::Value,
386    ) -> Option<HashMap<String, serde_json::Value>> {
387        // MCP tools can have various schema formats, we'll try to convert common ones
388        match schema {
389            serde_json::Value::Object(obj) => {
390                let mut params = HashMap::new();
391
392                // If it's a JSON schema object, extract properties
393                if let Some(properties) = obj.get("properties") {
394                    if let serde_json::Value::Object(props) = properties {
395                        for (key, value) in props {
396                            params.insert(key.clone(), value.clone());
397                        }
398                    }
399                } else {
400                    // If it's just a direct object, use it as-is
401                    for (key, value) in obj {
402                        params.insert(key.clone(), value.clone());
403                    }
404                }
405
406                if params.is_empty() {
407                    None
408                } else {
409                    Some(params)
410                }
411            }
412            _ => {
413                // For non-object schemas, we can't easily convert to parameters
414                None
415            }
416        }
417    }
418
419    /// Remove tools associated with a specific server
420    fn remove_tools_for_server(&mut self, server_id: &str) {
421        let tools_to_remove: Vec<String> = self
422            .tools
423            .iter()
424            .filter(|(_, info)| info.server_id == server_id)
425            .map(|(name, _)| name.clone())
426            .collect();
427
428        for name in tools_to_remove {
429            self.tools.remove(&name);
430            self.tool_callbacks.remove(&name);
431            self.tool_callbacks_with_tools.remove(&name);
432        }
433    }
434
435    /// Register tools for a single server
436    async fn register_tools_for_server(&mut self, server_id: &str) -> Result<()> {
437        let connection = self
438            .servers
439            .get(server_id)
440            .ok_or_else(|| anyhow::anyhow!("Server not connected: {}", server_id))?
441            .clone();
442
443        let server_config = self
444            .config
445            .servers
446            .iter()
447            .find(|s| s.id == server_id)
448            .ok_or_else(|| anyhow::anyhow!("Server config not found for {}", server_id))?
449            .clone();
450
451        let tools = connection.list_tools().await?;
452
453        for tool in tools {
454            let tool_name = if let Some(prefix) = &server_config.tool_prefix {
455                format!("{}_{}", prefix, tool.name)
456            } else {
457                tool.name.clone()
458            };
459
460            // Create tool callback that calls the MCP server with timeout and concurrency controls
461            let connection_clone = Arc::clone(&connection);
462            let original_tool_name = tool.name.clone();
463            let semaphore_clone = Arc::clone(&self.concurrency_semaphore);
464            let timeout_duration = Duration::from_secs(self.config.tool_timeout_secs.unwrap_or(30));
465
466            let callback: Arc<ToolCallback> = Arc::new(move |called_function| {
467                let connection = Arc::clone(&connection_clone);
468                let tool_name = original_tool_name.clone();
469                let semaphore = Arc::clone(&semaphore_clone);
470                let arguments: serde_json::Value =
471                    serde_json::from_str(&called_function.arguments)?;
472
473                // Use tokio::task::spawn_blocking to handle the async-to-sync bridge
474                let rt = tokio::runtime::Handle::current();
475                std::thread::spawn(move || {
476                    rt.block_on(async move {
477                        // Acquire semaphore permit for concurrency control
478                        let _permit = semaphore
479                            .acquire()
480                            .await
481                            .map_err(|_| anyhow::anyhow!("Failed to acquire concurrency permit"))?;
482
483                        // Execute tool call with timeout
484                        match tokio::time::timeout(
485                            timeout_duration,
486                            connection.call_tool(&tool_name, arguments),
487                        )
488                        .await
489                        {
490                            Ok(result) => result,
491                            Err(_) => Err(anyhow::anyhow!(
492                                "Tool call timed out after {} seconds",
493                                timeout_duration.as_secs()
494                            )),
495                        }
496                    })
497                })
498                .join()
499                .map_err(|_| anyhow::anyhow!("Tool call thread panicked"))?
500            });
501
502            // Convert MCP tool schema to Tool definition
503            let function_def = Function {
504                name: tool_name.clone(),
505                description: tool.description.clone(),
506                parameters: Self::convert_mcp_schema_to_parameters(&tool.input_schema),
507            };
508
509            let tool_def = Tool {
510                tp: ToolType::Function,
511                function: function_def,
512            };
513
514            // Store in both collections for backward compatibility
515            self.tool_callbacks
516                .insert(tool_name.clone(), callback.clone());
517            self.tool_callbacks_with_tools.insert(
518                tool_name.clone(),
519                ToolCallbackWithTool {
520                    callback,
521                    tool: tool_def,
522                },
523            );
524            self.tools.insert(tool_name, tool);
525        }
526
527        Ok(())
528    }
529
530    // ==================== Connection Management Methods ====================
531
532    /// Gracefully shutdown all server connections.
533    ///
534    /// Closes all active connections and clears the tools and callbacks.
535    /// The client cannot be used after calling this method without re-initializing.
536    ///
537    /// # Example
538    ///
539    /// ```rust,no_run
540    /// # use mistralrs_mcp::{McpClient, McpClientConfig};
541    /// # async fn example() -> anyhow::Result<()> {
542    /// let config = McpClientConfig::default();
543    /// let mut client = McpClient::new(config);
544    /// client.initialize().await?;
545    ///
546    /// // ... use the client ...
547    ///
548    /// // Gracefully shutdown when done
549    /// client.shutdown().await?;
550    /// # Ok(())
551    /// # }
552    /// ```
553    pub async fn shutdown(&mut self) -> Result<()> {
554        // Close all connections
555        for connection in self.servers.values() {
556            let _ = connection.close().await;
557        }
558
559        // Clear all state
560        self.servers.clear();
561        self.tools.clear();
562        self.tool_callbacks.clear();
563        self.tool_callbacks_with_tools.clear();
564
565        Ok(())
566    }
567
568    /// Disconnect a specific server by its ID.
569    ///
570    /// Removes the server from active connections and clears its associated tools.
571    /// Returns an error if the server ID is not found.
572    ///
573    /// # Example
574    ///
575    /// ```rust,no_run
576    /// # use mistralrs_mcp::{McpClient, McpClientConfig};
577    /// # async fn example() -> anyhow::Result<()> {
578    /// let config = McpClientConfig::default();
579    /// let mut client = McpClient::new(config);
580    /// client.initialize().await?;
581    ///
582    /// // Disconnect a specific server
583    /// client.disconnect("my_server_id").await?;
584    /// # Ok(())
585    /// # }
586    /// ```
587    pub async fn disconnect(&mut self, id: &str) -> Result<()> {
588        let connection = self
589            .servers
590            .remove(id)
591            .ok_or_else(|| anyhow::anyhow!("Server not connected: {}", id))?;
592
593        connection.close().await?;
594        self.remove_tools_for_server(id);
595
596        Ok(())
597    }
598
599    /// Reconnect to a specific server by its ID.
600    ///
601    /// Re-establishes the connection using the stored configuration.
602    /// Returns an error if the server ID is not in the configuration.
603    ///
604    /// # Example
605    ///
606    /// ```rust,no_run
607    /// # use mistralrs_mcp::{McpClient, McpClientConfig};
608    /// # async fn example() -> anyhow::Result<()> {
609    /// let config = McpClientConfig::default();
610    /// let mut client = McpClient::new(config);
611    /// client.initialize().await?;
612    ///
613    /// // Reconnect to a server after it was disconnected or lost connection
614    /// client.reconnect("my_server_id").await?;
615    /// # Ok(())
616    /// # }
617    /// ```
618    pub async fn reconnect(&mut self, id: &str) -> Result<()> {
619        // Find the server config
620        let server_config = self
621            .config
622            .servers
623            .iter()
624            .find(|s| s.id == id)
625            .ok_or_else(|| anyhow::anyhow!("Server config not found: {}", id))?
626            .clone();
627
628        // Close existing connection if any
629        if let Some(connection) = self.servers.remove(id) {
630            let _ = connection.close().await;
631        }
632
633        // Remove old tools for this server
634        self.remove_tools_for_server(id);
635
636        // Create new connection
637        let connection = self.create_connection(&server_config).await?;
638        self.servers.insert(id.to_string(), connection);
639
640        // Re-register tools if auto_register_tools is enabled
641        if self.config.auto_register_tools {
642            self.register_tools_for_server(id).await?;
643        }
644
645        Ok(())
646    }
647
648    /// Check if a specific server is currently connected.
649    ///
650    /// # Example
651    ///
652    /// ```rust,no_run
653    /// # use mistralrs_mcp::{McpClient, McpClientConfig};
654    /// # async fn example() -> anyhow::Result<()> {
655    /// let config = McpClientConfig::default();
656    /// let mut client = McpClient::new(config);
657    /// client.initialize().await?;
658    ///
659    /// if client.is_connected("my_server_id") {
660    ///     println!("Server is connected");
661    /// }
662    /// # Ok(())
663    /// # }
664    /// ```
665    pub fn is_connected(&self, id: &str) -> bool {
666        self.servers.contains_key(id)
667    }
668
669    /// Dynamically add and connect a new server at runtime.
670    ///
671    /// Adds the server configuration and establishes the connection.
672    /// If auto_register_tools is enabled, discovers and registers the server's tools.
673    ///
674    /// # Example
675    ///
676    /// ```rust,no_run
677    /// # use mistralrs_mcp::{McpClient, McpClientConfig, McpServerConfig, McpServerSource};
678    /// # async fn example() -> anyhow::Result<()> {
679    /// let config = McpClientConfig::default();
680    /// let mut client = McpClient::new(config);
681    /// client.initialize().await?;
682    ///
683    /// // Add a new server dynamically
684    /// let new_server = McpServerConfig {
685    ///     id: "new_server".to_string(),
686    ///     name: "New MCP Server".to_string(),
687    ///     source: McpServerSource::Http {
688    ///         url: "https://api.example.com/mcp".to_string(),
689    ///         timeout_secs: Some(30),
690    ///         headers: None,
691    ///     },
692    ///     ..Default::default()
693    /// };
694    /// client.add_server(new_server).await?;
695    /// # Ok(())
696    /// # }
697    /// ```
698    pub async fn add_server(&mut self, config: McpServerConfig) -> Result<()> {
699        let id = config.id.clone();
700
701        // Check if server already exists
702        if self.servers.contains_key(&id) {
703            return Err(anyhow::anyhow!("Server already exists: {}", id));
704        }
705
706        // Create connection
707        let connection = self.create_connection(&config).await?;
708        self.servers.insert(id.clone(), connection);
709
710        // Store config
711        self.config.servers.push(config);
712
713        // Register tools if enabled
714        if self.config.auto_register_tools {
715            self.register_tools_for_server(&id).await?;
716        }
717
718        Ok(())
719    }
720
721    /// Disconnect and remove a server from the client.
722    ///
723    /// Closes the connection and removes the server from the configuration.
724    ///
725    /// # Example
726    ///
727    /// ```rust,no_run
728    /// # use mistralrs_mcp::{McpClient, McpClientConfig};
729    /// # async fn example() -> anyhow::Result<()> {
730    /// let config = McpClientConfig::default();
731    /// let mut client = McpClient::new(config);
732    /// client.initialize().await?;
733    ///
734    /// // Remove a server completely
735    /// client.remove_server("my_server_id").await?;
736    /// # Ok(())
737    /// # }
738    /// ```
739    pub async fn remove_server(&mut self, id: &str) -> Result<()> {
740        // Disconnect first
741        if let Some(connection) = self.servers.remove(id) {
742            let _ = connection.close().await;
743        }
744
745        // Remove tools
746        self.remove_tools_for_server(id);
747
748        // Remove from config
749        self.config.servers.retain(|s| s.id != id);
750
751        Ok(())
752    }
753
754    // ==================== Tool Management Methods ====================
755
756    /// Re-discover tools from all connected servers.
757    ///
758    /// Clears existing tool registrations and re-queries all servers.
759    /// Useful for long-running clients when servers update their tools.
760    ///
761    /// # Example
762    ///
763    /// ```rust,no_run
764    /// # use mistralrs_mcp::{McpClient, McpClientConfig};
765    /// # async fn example() -> anyhow::Result<()> {
766    /// let config = McpClientConfig::default();
767    /// let mut client = McpClient::new(config);
768    /// client.initialize().await?;
769    ///
770    /// // Refresh tools after servers have been updated
771    /// client.refresh_tools().await?;
772    /// # Ok(())
773    /// # }
774    /// ```
775    pub async fn refresh_tools(&mut self) -> Result<()> {
776        // Clear all existing tools
777        self.tools.clear();
778        self.tool_callbacks.clear();
779        self.tool_callbacks_with_tools.clear();
780
781        // Re-discover tools from all servers
782        self.discover_and_register_tools().await
783    }
784
785    /// Get a specific tool by name.
786    ///
787    /// Returns `None` if no tool with the given name is registered.
788    ///
789    /// # Example
790    ///
791    /// ```rust,no_run
792    /// # use mistralrs_mcp::{McpClient, McpClientConfig};
793    /// # async fn example() -> anyhow::Result<()> {
794    /// let config = McpClientConfig::default();
795    /// let mut client = McpClient::new(config);
796    /// client.initialize().await?;
797    ///
798    /// if let Some(tool) = client.get_tool("web_search") {
799    ///     println!("Found tool: {:?}", tool.description);
800    /// }
801    /// # Ok(())
802    /// # }
803    /// ```
804    pub fn get_tool(&self, name: &str) -> Option<&McpToolInfo> {
805        self.tools.get(name)
806    }
807
808    /// Check if a tool with the given name exists.
809    ///
810    /// # Example
811    ///
812    /// ```rust,no_run
813    /// # use mistralrs_mcp::{McpClient, McpClientConfig};
814    /// # async fn example() -> anyhow::Result<()> {
815    /// let config = McpClientConfig::default();
816    /// let mut client = McpClient::new(config);
817    /// client.initialize().await?;
818    ///
819    /// if client.has_tool("web_search") {
820    ///     println!("Tool is available");
821    /// }
822    /// # Ok(())
823    /// # }
824    /// ```
825    pub fn has_tool(&self, name: &str) -> bool {
826        self.tools.contains_key(name)
827    }
828
829    /// Directly call a tool by name with the given arguments.
830    ///
831    /// This bypasses the callback system and calls the tool directly
832    /// on the appropriate server with timeout and concurrency controls.
833    ///
834    /// # Example
835    ///
836    /// ```rust,no_run
837    /// # use mistralrs_mcp::{McpClient, McpClientConfig};
838    /// # use serde_json::json;
839    /// # async fn example() -> anyhow::Result<()> {
840    /// let config = McpClientConfig::default();
841    /// let mut client = McpClient::new(config);
842    /// client.initialize().await?;
843    ///
844    /// let result = client.call_tool("web_search", json!({"query": "rust programming"})).await?;
845    /// println!("Result: {}", result);
846    /// # Ok(())
847    /// # }
848    /// ```
849    pub async fn call_tool(&self, name: &str, arguments: serde_json::Value) -> Result<String> {
850        let tool_info = self
851            .tools
852            .get(name)
853            .ok_or_else(|| anyhow::anyhow!("Tool not found: {}", name))?;
854
855        let connection = self
856            .servers
857            .get(&tool_info.server_id)
858            .ok_or_else(|| anyhow::anyhow!("Server not connected: {}", tool_info.server_id))?;
859
860        // Acquire semaphore permit for concurrency control
861        let _permit = self
862            .concurrency_semaphore
863            .acquire()
864            .await
865            .map_err(|_| anyhow::anyhow!("Failed to acquire concurrency permit"))?;
866
867        let timeout_duration = Duration::from_secs(self.config.tool_timeout_secs.unwrap_or(30));
868
869        // Execute tool call with timeout
870        match tokio::time::timeout(
871            timeout_duration,
872            connection.call_tool(&tool_info.name, arguments),
873        )
874        .await
875        {
876            Ok(result) => result,
877            Err(_) => Err(anyhow::anyhow!(
878                "Tool call timed out after {} seconds",
879                timeout_duration.as_secs()
880            )),
881        }
882    }
883
884    /// Get the total number of registered tools.
885    ///
886    /// # Example
887    ///
888    /// ```rust,no_run
889    /// # use mistralrs_mcp::{McpClient, McpClientConfig};
890    /// # async fn example() -> anyhow::Result<()> {
891    /// let config = McpClientConfig::default();
892    /// let mut client = McpClient::new(config);
893    /// client.initialize().await?;
894    ///
895    /// println!("Total tools: {}", client.tool_count());
896    /// # Ok(())
897    /// # }
898    /// ```
899    pub fn tool_count(&self) -> usize {
900        self.tools.len()
901    }
902
903    // ==================== Status / Convenience Methods ====================
904
905    /// Get the number of connected servers.
906    ///
907    /// # Example
908    ///
909    /// ```rust,no_run
910    /// # use mistralrs_mcp::{McpClient, McpClientConfig};
911    /// # async fn example() -> anyhow::Result<()> {
912    /// let config = McpClientConfig::default();
913    /// let mut client = McpClient::new(config);
914    /// client.initialize().await?;
915    ///
916    /// println!("Connected servers: {}", client.server_count());
917    /// # Ok(())
918    /// # }
919    /// ```
920    pub fn server_count(&self) -> usize {
921        self.servers.len()
922    }
923
924    /// Get a list of all connected server IDs.
925    ///
926    /// # Example
927    ///
928    /// ```rust,no_run
929    /// # use mistralrs_mcp::{McpClient, McpClientConfig};
930    /// # async fn example() -> anyhow::Result<()> {
931    /// let config = McpClientConfig::default();
932    /// let mut client = McpClient::new(config);
933    /// client.initialize().await?;
934    ///
935    /// for id in client.server_ids() {
936    ///     println!("Server: {}", id);
937    /// }
938    /// # Ok(())
939    /// # }
940    /// ```
941    pub fn server_ids(&self) -> Vec<&str> {
942        self.servers.keys().map(|s| s.as_str()).collect()
943    }
944
945    /// Ping all connected servers and return results per server.
946    ///
947    /// Returns a map of server ID to ping result. Useful for health monitoring.
948    ///
949    /// # Example
950    ///
951    /// ```rust,no_run
952    /// # use mistralrs_mcp::{McpClient, McpClientConfig};
953    /// # async fn example() -> anyhow::Result<()> {
954    /// let config = McpClientConfig::default();
955    /// let mut client = McpClient::new(config);
956    /// client.initialize().await?;
957    ///
958    /// let results = client.ping_all().await;
959    /// for (server_id, result) in results {
960    ///     match result {
961    ///         Ok(()) => println!("{}: healthy", server_id),
962    ///         Err(e) => println!("{}: unhealthy - {}", server_id, e),
963    ///     }
964    /// }
965    /// # Ok(())
966    /// # }
967    /// ```
968    pub async fn ping_all(&self) -> HashMap<String, Result<()>> {
969        let mut results = HashMap::new();
970
971        for (server_id, connection) in &self.servers {
972            let result = connection.ping().await;
973            results.insert(server_id.clone(), result);
974        }
975
976        results
977    }
978
979    // ==================== Resource Access Methods ====================
980
981    /// List resources from all connected servers.
982    ///
983    /// Returns a vector of (server_id, resource) tuples.
984    ///
985    /// # Example
986    ///
987    /// ```rust,no_run
988    /// # use mistralrs_mcp::{McpClient, McpClientConfig};
989    /// # async fn example() -> anyhow::Result<()> {
990    /// let config = McpClientConfig::default();
991    /// let mut client = McpClient::new(config);
992    /// client.initialize().await?;
993    ///
994    /// let resources = client.list_all_resources().await?;
995    /// for (server_id, resource) in resources {
996    ///     println!("Server {}: {:?}", server_id, resource);
997    /// }
998    /// # Ok(())
999    /// # }
1000    /// ```
1001    pub async fn list_all_resources(&self) -> Result<Vec<(String, Resource)>> {
1002        let mut all_resources = Vec::new();
1003
1004        for (server_id, connection) in &self.servers {
1005            match connection.list_resources().await {
1006                Ok(resources) => {
1007                    for resource in resources {
1008                        all_resources.push((server_id.clone(), resource));
1009                    }
1010                }
1011                Err(e) => {
1012                    // Log error but continue with other servers
1013                    warn!("Failed to list resources from server {}: {}", server_id, e);
1014                }
1015            }
1016        }
1017
1018        Ok(all_resources)
1019    }
1020}
1021
1022impl Drop for McpClient {
1023    fn drop(&mut self) {
1024        // Try to get the tokio runtime handle
1025        if let Ok(handle) = tokio::runtime::Handle::try_current() {
1026            let servers = std::mem::take(&mut self.servers);
1027            handle.spawn(async move {
1028                for (_, connection) in servers {
1029                    let _ = connection.close().await;
1030                }
1031            });
1032        }
1033    }
1034}
1035
1036/// HTTP-based MCP server connection
1037pub struct HttpMcpConnection {
1038    server_id: String,
1039    server_name: String,
1040    transport: Arc<dyn McpTransport>,
1041}
1042
1043impl HttpMcpConnection {
1044    pub async fn new(
1045        server_id: String,
1046        server_name: String,
1047        url: String,
1048        timeout_secs: Option<u64>,
1049        headers: Option<HashMap<String, String>>,
1050    ) -> Result<Self> {
1051        let transport = HttpTransport::new(url, timeout_secs, headers)?;
1052
1053        let connection = Self {
1054            server_id,
1055            server_name,
1056            transport: Arc::new(transport),
1057        };
1058
1059        // Initialize the connection
1060        connection.initialize().await?;
1061
1062        Ok(connection)
1063    }
1064
1065    async fn initialize(&self) -> Result<()> {
1066        let init_params = serde_json::json!({
1067            "protocolVersion": rust_mcp_schema::ProtocolVersion::latest().to_string(),
1068            "capabilities": {
1069                "tools": {}
1070            },
1071            "clientInfo": {
1072                "name": "mistral.rs",
1073                "version": env!("CARGO_PKG_VERSION"),
1074            }
1075        });
1076
1077        self.transport
1078            .send_request("initialize", init_params)
1079            .await?;
1080        self.transport.send_initialization_notification().await?;
1081        Ok(())
1082    }
1083}
1084
1085#[async_trait::async_trait]
1086impl McpServerConnection for HttpMcpConnection {
1087    fn server_id(&self) -> &str {
1088        &self.server_id
1089    }
1090
1091    fn server_name(&self) -> &str {
1092        &self.server_name
1093    }
1094
1095    async fn list_tools(&self) -> Result<Vec<McpToolInfo>> {
1096        let result = self
1097            .transport
1098            .send_request("tools/list", Value::Null)
1099            .await?;
1100
1101        let tools = result
1102            .get("tools")
1103            .and_then(|t| t.as_array())
1104            .ok_or_else(|| anyhow::anyhow!("Invalid tools response format"))?;
1105
1106        let mut tool_infos = Vec::new();
1107        for tool in tools {
1108            let name = tool
1109                .get("name")
1110                .and_then(|n| n.as_str())
1111                .ok_or_else(|| anyhow::anyhow!("Tool missing name"))?
1112                .to_string();
1113
1114            let description = tool
1115                .get("description")
1116                .and_then(|d| d.as_str())
1117                .map(|s| s.to_string());
1118
1119            let input_schema = tool
1120                .get("inputSchema")
1121                .cloned()
1122                .unwrap_or(Value::Object(serde_json::Map::new()));
1123
1124            tool_infos.push(McpToolInfo {
1125                name,
1126                description,
1127                input_schema,
1128                server_id: self.server_id.clone(),
1129                server_name: self.server_name.clone(),
1130            });
1131        }
1132
1133        Ok(tool_infos)
1134    }
1135
1136    async fn call_tool(&self, name: &str, arguments: Value) -> Result<String> {
1137        let params = serde_json::json!({
1138            "name": name,
1139            "arguments": arguments
1140        });
1141
1142        let result = self.transport.send_request("tools/call", params).await?;
1143
1144        // Parse the MCP tool result
1145        let tool_result: McpToolResult = serde_json::from_value(result)?;
1146
1147        // Check if the result indicates an error
1148        if tool_result.is_error.unwrap_or(false) {
1149            return Err(anyhow::anyhow!("Tool execution failed: {tool_result}"));
1150        }
1151
1152        Ok(tool_result.to_string())
1153    }
1154
1155    async fn list_resources(&self) -> Result<Vec<Resource>> {
1156        let result = self
1157            .transport
1158            .send_request("resources/list", Value::Null)
1159            .await?;
1160
1161        let resources = result
1162            .get("resources")
1163            .and_then(|r| r.as_array())
1164            .ok_or_else(|| anyhow::anyhow!("Invalid resources response format"))?;
1165
1166        let mut resource_list = Vec::new();
1167        for resource in resources {
1168            let mcp_resource: Resource = serde_json::from_value(resource.clone())?;
1169            resource_list.push(mcp_resource);
1170        }
1171
1172        Ok(resource_list)
1173    }
1174
1175    async fn read_resource(&self, uri: &str) -> Result<String> {
1176        let params = serde_json::json!({ "uri": uri });
1177        let result = self
1178            .transport
1179            .send_request("resources/read", params)
1180            .await?;
1181
1182        // Extract content from the response
1183        if let Some(contents) = result.get("contents").and_then(|c| c.as_array()) {
1184            if let Some(first_content) = contents.first() {
1185                if let Some(text) = first_content.get("text").and_then(|t| t.as_str()) {
1186                    return Ok(text.to_string());
1187                }
1188            }
1189        }
1190
1191        Err(anyhow::anyhow!("No readable content found in resource"))
1192    }
1193
1194    async fn ping(&self) -> Result<()> {
1195        // Send a simple ping to check if the server is responsive
1196        self.transport.send_request("ping", Value::Null).await?;
1197        Ok(())
1198    }
1199
1200    async fn close(&self) -> Result<()> {
1201        self.transport.close().await
1202    }
1203}
1204
1205/// Process-based MCP server connection
1206pub struct ProcessMcpConnection {
1207    server_id: String,
1208    server_name: String,
1209    transport: Arc<dyn McpTransport>,
1210}
1211
1212impl ProcessMcpConnection {
1213    pub async fn new(
1214        server_id: String,
1215        server_name: String,
1216        command: String,
1217        args: Vec<String>,
1218        work_dir: Option<String>,
1219        env: Option<HashMap<String, String>>,
1220    ) -> Result<Self> {
1221        let transport = ProcessTransport::new(command, args, work_dir, env).await?;
1222
1223        let connection = Self {
1224            server_id,
1225            server_name,
1226            transport: Arc::new(transport),
1227        };
1228
1229        // Initialize the connection
1230        connection.initialize().await?;
1231
1232        Ok(connection)
1233    }
1234
1235    async fn initialize(&self) -> Result<()> {
1236        let init_params = serde_json::json!({
1237            "protocolVersion": rust_mcp_schema::ProtocolVersion::latest().to_string(),
1238            "capabilities": {
1239                "tools": {}
1240            },
1241            "clientInfo": {
1242                "name": "mistral.rs",
1243                "version": env!("CARGO_PKG_VERSION"),
1244            }
1245        });
1246
1247        self.transport
1248            .send_request("initialize", init_params)
1249            .await?;
1250        self.transport.send_initialization_notification().await?;
1251        Ok(())
1252    }
1253}
1254
1255#[async_trait::async_trait]
1256impl McpServerConnection for ProcessMcpConnection {
1257    fn server_id(&self) -> &str {
1258        &self.server_id
1259    }
1260
1261    fn server_name(&self) -> &str {
1262        &self.server_name
1263    }
1264
1265    async fn list_tools(&self) -> Result<Vec<McpToolInfo>> {
1266        let result = self
1267            .transport
1268            .send_request("tools/list", Value::Null)
1269            .await?;
1270
1271        let tools = result
1272            .get("tools")
1273            .and_then(|t| t.as_array())
1274            .ok_or_else(|| anyhow::anyhow!("Invalid tools response format"))?;
1275
1276        let mut tool_infos = Vec::new();
1277        for tool in tools {
1278            let name = tool
1279                .get("name")
1280                .and_then(|n| n.as_str())
1281                .ok_or_else(|| anyhow::anyhow!("Tool missing name"))?
1282                .to_string();
1283
1284            let description = tool
1285                .get("description")
1286                .and_then(|d| d.as_str())
1287                .map(|s| s.to_string());
1288
1289            let input_schema = tool
1290                .get("inputSchema")
1291                .cloned()
1292                .unwrap_or(Value::Object(serde_json::Map::new()));
1293
1294            tool_infos.push(McpToolInfo {
1295                name,
1296                description,
1297                input_schema,
1298                server_id: self.server_id.clone(),
1299                server_name: self.server_name.clone(),
1300            });
1301        }
1302
1303        Ok(tool_infos)
1304    }
1305
1306    async fn call_tool(&self, name: &str, arguments: Value) -> Result<String> {
1307        let params = serde_json::json!({
1308            "name": name,
1309            "arguments": arguments
1310        });
1311
1312        let result = self.transport.send_request("tools/call", params).await?;
1313
1314        // Parse the MCP tool result
1315        let tool_result: McpToolResult = serde_json::from_value(result)?;
1316
1317        // Check if the result indicates an error
1318        if tool_result.is_error.unwrap_or(false) {
1319            return Err(anyhow::anyhow!("Tool execution failed: {tool_result}"));
1320        }
1321
1322        Ok(tool_result.to_string())
1323    }
1324
1325    async fn list_resources(&self) -> Result<Vec<Resource>> {
1326        let result = self
1327            .transport
1328            .send_request("resources/list", Value::Null)
1329            .await?;
1330
1331        let resources = result
1332            .get("resources")
1333            .and_then(|r| r.as_array())
1334            .ok_or_else(|| anyhow::anyhow!("Invalid resources response format"))?;
1335
1336        let mut resource_list = Vec::new();
1337        for resource in resources {
1338            let mcp_resource: Resource = serde_json::from_value(resource.clone())?;
1339            resource_list.push(mcp_resource);
1340        }
1341
1342        Ok(resource_list)
1343    }
1344
1345    async fn read_resource(&self, uri: &str) -> Result<String> {
1346        let params = serde_json::json!({ "uri": uri });
1347        let result = self
1348            .transport
1349            .send_request("resources/read", params)
1350            .await?;
1351
1352        // Extract content from the response
1353        if let Some(contents) = result.get("contents").and_then(|c| c.as_array()) {
1354            if let Some(first_content) = contents.first() {
1355                if let Some(text) = first_content.get("text").and_then(|t| t.as_str()) {
1356                    return Ok(text.to_string());
1357                }
1358            }
1359        }
1360
1361        Err(anyhow::anyhow!("No readable content found in resource"))
1362    }
1363
1364    async fn ping(&self) -> Result<()> {
1365        // Send a simple ping to check if the server is responsive
1366        self.transport.send_request("ping", Value::Null).await?;
1367        Ok(())
1368    }
1369
1370    async fn close(&self) -> Result<()> {
1371        self.transport.close().await
1372    }
1373}
1374
1375/// WebSocket-based MCP server connection
1376pub struct WebSocketMcpConnection {
1377    server_id: String,
1378    server_name: String,
1379    transport: Arc<dyn McpTransport>,
1380}
1381
1382impl WebSocketMcpConnection {
1383    pub async fn new(
1384        server_id: String,
1385        server_name: String,
1386        url: String,
1387        timeout_secs: Option<u64>,
1388        headers: Option<HashMap<String, String>>,
1389    ) -> Result<Self> {
1390        let transport = WebSocketTransport::new(url, timeout_secs, headers).await?;
1391
1392        let connection = Self {
1393            server_id,
1394            server_name,
1395            transport: Arc::new(transport),
1396        };
1397
1398        // Initialize the connection
1399        connection.initialize().await?;
1400
1401        Ok(connection)
1402    }
1403
1404    async fn initialize(&self) -> Result<()> {
1405        let init_params = serde_json::json!({
1406            "protocolVersion": rust_mcp_schema::ProtocolVersion::latest().to_string(),
1407            "capabilities": {
1408                "tools": {}
1409            },
1410            "clientInfo": {
1411                "name": "mistral.rs",
1412                "version": env!("CARGO_PKG_VERSION"),
1413            }
1414        });
1415
1416        self.transport
1417            .send_request("initialize", init_params)
1418            .await?;
1419        self.transport.send_initialization_notification().await?;
1420        Ok(())
1421    }
1422}
1423
1424#[async_trait::async_trait]
1425impl McpServerConnection for WebSocketMcpConnection {
1426    fn server_id(&self) -> &str {
1427        &self.server_id
1428    }
1429
1430    fn server_name(&self) -> &str {
1431        &self.server_name
1432    }
1433
1434    async fn list_tools(&self) -> Result<Vec<McpToolInfo>> {
1435        let result = self
1436            .transport
1437            .send_request("tools/list", Value::Null)
1438            .await?;
1439
1440        let tools = result
1441            .get("tools")
1442            .and_then(|t| t.as_array())
1443            .ok_or_else(|| anyhow::anyhow!("Invalid tools response format"))?;
1444
1445        let mut tool_infos = Vec::new();
1446        for tool in tools {
1447            let name = tool
1448                .get("name")
1449                .and_then(|n| n.as_str())
1450                .ok_or_else(|| anyhow::anyhow!("Tool missing name"))?
1451                .to_string();
1452
1453            let description = tool
1454                .get("description")
1455                .and_then(|d| d.as_str())
1456                .map(|s| s.to_string());
1457
1458            let input_schema = tool
1459                .get("inputSchema")
1460                .cloned()
1461                .unwrap_or(Value::Object(serde_json::Map::new()));
1462
1463            tool_infos.push(McpToolInfo {
1464                name,
1465                description,
1466                input_schema,
1467                server_id: self.server_id.clone(),
1468                server_name: self.server_name.clone(),
1469            });
1470        }
1471
1472        Ok(tool_infos)
1473    }
1474
1475    async fn call_tool(&self, name: &str, arguments: Value) -> Result<String> {
1476        let params = serde_json::json!({
1477            "name": name,
1478            "arguments": arguments
1479        });
1480
1481        let result = self.transport.send_request("tools/call", params).await?;
1482
1483        // Parse the MCP tool result
1484        let tool_result: McpToolResult = serde_json::from_value(result)?;
1485
1486        // Check if the result indicates an error
1487        if tool_result.is_error.unwrap_or(false) {
1488            return Err(anyhow::anyhow!("Tool execution failed: {tool_result}"));
1489        }
1490
1491        Ok(tool_result.to_string())
1492    }
1493
1494    async fn list_resources(&self) -> Result<Vec<Resource>> {
1495        let result = self
1496            .transport
1497            .send_request("resources/list", Value::Null)
1498            .await?;
1499
1500        let resources = result
1501            .get("resources")
1502            .and_then(|r| r.as_array())
1503            .ok_or_else(|| anyhow::anyhow!("Invalid resources response format"))?;
1504
1505        let mut resource_list = Vec::new();
1506        for resource in resources {
1507            let mcp_resource: Resource = serde_json::from_value(resource.clone())?;
1508            resource_list.push(mcp_resource);
1509        }
1510
1511        Ok(resource_list)
1512    }
1513
1514    async fn read_resource(&self, uri: &str) -> Result<String> {
1515        let params = serde_json::json!({ "uri": uri });
1516        let result = self
1517            .transport
1518            .send_request("resources/read", params)
1519            .await?;
1520
1521        // Extract content from the response
1522        if let Some(contents) = result.get("contents").and_then(|c| c.as_array()) {
1523            if let Some(first_content) = contents.first() {
1524                if let Some(text) = first_content.get("text").and_then(|t| t.as_str()) {
1525                    return Ok(text.to_string());
1526                }
1527            }
1528        }
1529
1530        Err(anyhow::anyhow!("No readable content found in resource"))
1531    }
1532
1533    async fn ping(&self) -> Result<()> {
1534        // Send a simple ping to check if the server is responsive
1535        self.transport.send_request("ping", Value::Null).await?;
1536        Ok(())
1537    }
1538
1539    async fn close(&self) -> Result<()> {
1540        self.transport.close().await
1541    }
1542}