mistralrs_mcp/client.rs
1use crate::tools::{Function, Tool, ToolCallback, ToolCallbackWithTool, ToolType};
2use crate::transport::{HttpTransport, McpTransport, ProcessTransport, WebSocketTransport};
3use crate::types::McpToolResult;
4use crate::{McpClientConfig, McpServerConfig, McpServerSource, McpToolInfo};
5use anyhow::Result;
6use rust_mcp_schema::Resource;
7use serde_json::Value;
8use std::collections::HashMap;
9use std::sync::Arc;
10use std::time::Duration;
11use tokio::sync::Semaphore;
12use tracing::warn;
13
14/// Trait for MCP server connections
15#[async_trait::async_trait]
16pub trait McpServerConnection: Send + Sync {
17 /// Get the server ID
18 fn server_id(&self) -> &str;
19
20 /// Get the server name
21 fn server_name(&self) -> &str;
22
23 /// List available tools from this server
24 async fn list_tools(&self) -> Result<Vec<McpToolInfo>>;
25
26 /// Call a tool on this server
27 async fn call_tool(&self, name: &str, arguments: serde_json::Value) -> Result<String>;
28
29 /// List available resources from this server
30 async fn list_resources(&self) -> Result<Vec<Resource>>;
31
32 /// Read a resource from this server
33 async fn read_resource(&self, uri: &str) -> Result<String>;
34
35 /// Check if the connection is healthy
36 async fn ping(&self) -> Result<()>;
37
38 /// Close the server connection
39 async fn close(&self) -> Result<()>;
40}
41
42/// MCP client that manages connections to multiple MCP servers
43///
44/// The main interface for interacting with Model Context Protocol servers.
45/// Handles connection lifecycle, tool discovery, and provides integration
46/// with tool calling systems.
47///
48/// # Features
49///
50/// - **Multi-server Management**: Connects to and manages multiple MCP servers simultaneously
51/// - **Automatic Tool Discovery**: Discovers available tools from connected servers
52/// - **Tool Registration**: Converts MCP tools to internal Tool format for seamless integration
53/// - **Connection Pooling**: Maintains persistent connections for efficient tool execution
54/// - **Error Handling**: Robust error handling with proper cleanup and reconnection logic
55///
56/// # Example
57///
58/// ```rust,no_run
59/// use mistralrs_mcp::{McpClient, McpClientConfig};
60///
61/// #[tokio::main]
62/// async fn main() -> anyhow::Result<()> {
63/// let config = McpClientConfig::default();
64/// let mut client = McpClient::new(config);
65///
66/// // Initialize all configured server connections
67/// client.initialize().await?;
68///
69/// // Get tool callbacks for model integration
70/// let callbacks = client.get_tool_callbacks_with_tools();
71///
72/// Ok(())
73/// }
74/// ```
75pub struct McpClient {
76 /// Configuration for the client including server list and policies
77 config: McpClientConfig,
78 /// Active connections to MCP servers, indexed by server ID
79 servers: HashMap<String, Arc<dyn McpServerConnection>>,
80 /// Registry of discovered tools from all connected servers
81 tools: HashMap<String, McpToolInfo>,
82 /// Legacy tool callbacks for backward compatibility
83 tool_callbacks: HashMap<String, Arc<ToolCallback>>,
84 /// Tool callbacks with associated Tool definitions for automatic tool calling
85 tool_callbacks_with_tools: HashMap<String, ToolCallbackWithTool>,
86 /// Semaphore to control maximum concurrent tool calls
87 concurrency_semaphore: Arc<Semaphore>,
88}
89
90impl McpClient {
91 /// Create a new MCP client with the given configuration
92 pub fn new(config: McpClientConfig) -> Self {
93 let max_concurrent = config.max_concurrent_calls.unwrap_or(10);
94 Self {
95 config,
96 servers: HashMap::new(),
97 tools: HashMap::new(),
98 tool_callbacks: HashMap::new(),
99 tool_callbacks_with_tools: HashMap::new(),
100 concurrency_semaphore: Arc::new(Semaphore::new(max_concurrent)),
101 }
102 }
103
104 /// Initialize connections to all configured servers
105 pub async fn initialize(&mut self) -> Result<()> {
106 for server_config in &self.config.servers {
107 if server_config.enabled {
108 let connection = self.create_connection(server_config).await?;
109 self.servers.insert(server_config.id.clone(), connection);
110 }
111 }
112
113 if self.config.auto_register_tools {
114 self.discover_and_register_tools().await?;
115 }
116
117 Ok(())
118 }
119
120 /// Get tool callbacks for use with legacy tool calling systems.
121 ///
122 /// Returns a map of tool names to their callback functions. These callbacks
123 /// handle argument parsing, concurrency control, and timeout enforcement
124 /// automatically.
125 ///
126 /// For new integrations, prefer [`Self::get_tool_callbacks_with_tools`] which
127 /// includes tool definitions alongside callbacks.
128 pub fn get_tool_callbacks(&self) -> &HashMap<String, Arc<ToolCallback>> {
129 &self.tool_callbacks
130 }
131
132 /// Get tool callbacks paired with their tool definitions.
133 ///
134 /// This is the primary method for integrating MCP tools with the model's
135 /// automatic tool calling system. Each entry contains:
136 /// - A callback function that executes the tool with timeout and concurrency controls
137 /// - A [`Tool`] definition with name, description, and parameter schema
138 ///
139 /// # Example
140 ///
141 /// ```rust,no_run
142 /// # use mistralrs_mcp::{McpClient, McpClientConfig};
143 /// # async fn example() -> anyhow::Result<()> {
144 /// let config = McpClientConfig::default();
145 /// let mut client = McpClient::new(config);
146 /// client.initialize().await?;
147 ///
148 /// let tools = client.get_tool_callbacks_with_tools();
149 /// for (name, tool_with_callback) in tools {
150 /// println!("Tool: {} - {:?}", name, tool_with_callback.tool.function.description);
151 /// }
152 /// # Ok(())
153 /// # }
154 /// ```
155 pub fn get_tool_callbacks_with_tools(&self) -> &HashMap<String, ToolCallbackWithTool> {
156 &self.tool_callbacks_with_tools
157 }
158
159 /// Get information about all discovered tools.
160 ///
161 /// Returns metadata about tools discovered from connected MCP servers,
162 /// including their names, descriptions, input schemas, and which server
163 /// they came from.
164 pub fn get_tools(&self) -> &HashMap<String, McpToolInfo> {
165 &self.tools
166 }
167
168 /// Get a reference to all connected MCP server connections.
169 ///
170 /// This provides direct access to server connections, allowing you to:
171 /// - List available resources with [`McpServerConnection::list_resources`]
172 /// - Read resources with [`McpServerConnection::read_resource`]
173 /// - Check server health with [`McpServerConnection::ping`]
174 /// - Call tools directly with [`McpServerConnection::call_tool`]
175 ///
176 /// # Example
177 ///
178 /// ```rust,no_run
179 /// # use mistralrs_mcp::{McpClient, McpClientConfig};
180 /// # async fn example() -> anyhow::Result<()> {
181 /// let config = McpClientConfig::default();
182 /// let mut client = McpClient::new(config);
183 /// client.initialize().await?;
184 ///
185 /// for (server_id, connection) in client.servers() {
186 /// println!("Server: {} ({})", connection.server_name(), server_id);
187 /// let resources = connection.list_resources().await?;
188 /// println!(" Resources: {:?}", resources);
189 /// }
190 /// # Ok(())
191 /// # }
192 /// ```
193 pub fn servers(&self) -> &HashMap<String, Arc<dyn McpServerConnection>> {
194 &self.servers
195 }
196
197 /// Get a specific server connection by its ID.
198 ///
199 /// Returns `None` if no server with the given ID is connected.
200 ///
201 /// # Example
202 ///
203 /// ```rust,no_run
204 /// # use mistralrs_mcp::{McpClient, McpClientConfig};
205 /// # async fn example() -> anyhow::Result<()> {
206 /// let config = McpClientConfig::default();
207 /// let mut client = McpClient::new(config);
208 /// client.initialize().await?;
209 ///
210 /// if let Some(server) = client.server("my_server_id") {
211 /// server.ping().await?;
212 /// let resources = server.list_resources().await?;
213 /// }
214 /// # Ok(())
215 /// # }
216 /// ```
217 pub fn server(&self, id: &str) -> Option<&Arc<dyn McpServerConnection>> {
218 self.servers.get(id)
219 }
220
221 /// Get the client configuration.
222 pub fn config(&self) -> &McpClientConfig {
223 &self.config
224 }
225
226 /// Create connection based on server source type
227 async fn create_connection(
228 &self,
229 config: &McpServerConfig,
230 ) -> Result<Arc<dyn McpServerConnection>> {
231 match &config.source {
232 McpServerSource::Http {
233 url,
234 timeout_secs,
235 headers,
236 } => {
237 // Merge Bearer token with existing headers if provided
238 let mut merged_headers = headers.clone().unwrap_or_default();
239 if let Some(token) = &config.bearer_token {
240 merged_headers.insert("Authorization".to_string(), format!("Bearer {token}"));
241 }
242
243 let connection = HttpMcpConnection::new(
244 config.id.clone(),
245 config.name.clone(),
246 url.clone(),
247 *timeout_secs,
248 Some(merged_headers),
249 )
250 .await?;
251 Ok(Arc::new(connection))
252 }
253 McpServerSource::Process {
254 command,
255 args,
256 work_dir,
257 env,
258 } => {
259 let connection = ProcessMcpConnection::new(
260 config.id.clone(),
261 config.name.clone(),
262 command.clone(),
263 args.clone(),
264 work_dir.clone(),
265 env.clone(),
266 )
267 .await?;
268 Ok(Arc::new(connection))
269 }
270 McpServerSource::WebSocket {
271 url,
272 timeout_secs,
273 headers,
274 } => {
275 // Merge Bearer token with existing headers if provided
276 let mut merged_headers = headers.clone().unwrap_or_default();
277 if let Some(token) = &config.bearer_token {
278 merged_headers.insert("Authorization".to_string(), format!("Bearer {token}"));
279 }
280
281 let connection = WebSocketMcpConnection::new(
282 config.id.clone(),
283 config.name.clone(),
284 url.clone(),
285 *timeout_secs,
286 Some(merged_headers),
287 )
288 .await?;
289 Ok(Arc::new(connection))
290 }
291 }
292 }
293
294 /// Discover tools from all connected servers and register them
295 async fn discover_and_register_tools(&mut self) -> Result<()> {
296 for (server_id, connection) in &self.servers {
297 let tools = connection.list_tools().await?;
298 let server_config = self
299 .config
300 .servers
301 .iter()
302 .find(|s| &s.id == server_id)
303 .ok_or_else(|| anyhow::anyhow!("Server config not found for {}", server_id))?;
304
305 for tool in tools {
306 let tool_name = if let Some(prefix) = &server_config.tool_prefix {
307 format!("{}_{}", prefix, tool.name)
308 } else {
309 tool.name.clone()
310 };
311
312 // Create tool callback that calls the MCP server with timeout and concurrency controls
313 let connection_clone = Arc::clone(connection);
314 let original_tool_name = tool.name.clone();
315 let semaphore_clone = Arc::clone(&self.concurrency_semaphore);
316 let timeout_duration =
317 Duration::from_secs(self.config.tool_timeout_secs.unwrap_or(30));
318
319 let callback: Arc<ToolCallback> = Arc::new(move |called_function| {
320 let connection = Arc::clone(&connection_clone);
321 let tool_name = original_tool_name.clone();
322 let semaphore = Arc::clone(&semaphore_clone);
323 let arguments: serde_json::Value =
324 serde_json::from_str(&called_function.arguments)?;
325
326 // Use tokio::task::spawn_blocking to handle the async-to-sync bridge
327 let rt = tokio::runtime::Handle::current();
328 std::thread::spawn(move || {
329 rt.block_on(async move {
330 // Acquire semaphore permit for concurrency control
331 let _permit = semaphore.acquire().await.map_err(|_| {
332 anyhow::anyhow!("Failed to acquire concurrency permit")
333 })?;
334
335 // Execute tool call with timeout
336 match tokio::time::timeout(
337 timeout_duration,
338 connection.call_tool(&tool_name, arguments),
339 )
340 .await
341 {
342 Ok(result) => result,
343 Err(_) => Err(anyhow::anyhow!(
344 "Tool call timed out after {} seconds",
345 timeout_duration.as_secs()
346 )),
347 }
348 })
349 })
350 .join()
351 .map_err(|_| anyhow::anyhow!("Tool call thread panicked"))?
352 });
353
354 // Convert MCP tool schema to Tool definition
355 let function_def = Function {
356 name: tool_name.clone(),
357 description: tool.description.clone(),
358 parameters: Self::convert_mcp_schema_to_parameters(&tool.input_schema),
359 };
360
361 let tool_def = Tool {
362 tp: ToolType::Function,
363 function: function_def,
364 };
365
366 // Store in both collections for backward compatibility
367 self.tool_callbacks
368 .insert(tool_name.clone(), callback.clone());
369 self.tool_callbacks_with_tools.insert(
370 tool_name.clone(),
371 ToolCallbackWithTool {
372 callback,
373 tool: tool_def,
374 },
375 );
376 self.tools.insert(tool_name, tool);
377 }
378 }
379
380 Ok(())
381 }
382
383 /// Convert MCP tool input schema to Tool parameters format
384 fn convert_mcp_schema_to_parameters(
385 schema: &serde_json::Value,
386 ) -> Option<HashMap<String, serde_json::Value>> {
387 // MCP tools can have various schema formats, we'll try to convert common ones
388 match schema {
389 serde_json::Value::Object(obj) => {
390 let mut params = HashMap::new();
391
392 // If it's a JSON schema object, extract properties
393 if let Some(properties) = obj.get("properties") {
394 if let serde_json::Value::Object(props) = properties {
395 for (key, value) in props {
396 params.insert(key.clone(), value.clone());
397 }
398 }
399 } else {
400 // If it's just a direct object, use it as-is
401 for (key, value) in obj {
402 params.insert(key.clone(), value.clone());
403 }
404 }
405
406 if params.is_empty() {
407 None
408 } else {
409 Some(params)
410 }
411 }
412 _ => {
413 // For non-object schemas, we can't easily convert to parameters
414 None
415 }
416 }
417 }
418
419 /// Remove tools associated with a specific server
420 fn remove_tools_for_server(&mut self, server_id: &str) {
421 let tools_to_remove: Vec<String> = self
422 .tools
423 .iter()
424 .filter(|(_, info)| info.server_id == server_id)
425 .map(|(name, _)| name.clone())
426 .collect();
427
428 for name in tools_to_remove {
429 self.tools.remove(&name);
430 self.tool_callbacks.remove(&name);
431 self.tool_callbacks_with_tools.remove(&name);
432 }
433 }
434
435 /// Register tools for a single server
436 async fn register_tools_for_server(&mut self, server_id: &str) -> Result<()> {
437 let connection = self
438 .servers
439 .get(server_id)
440 .ok_or_else(|| anyhow::anyhow!("Server not connected: {}", server_id))?
441 .clone();
442
443 let server_config = self
444 .config
445 .servers
446 .iter()
447 .find(|s| s.id == server_id)
448 .ok_or_else(|| anyhow::anyhow!("Server config not found for {}", server_id))?
449 .clone();
450
451 let tools = connection.list_tools().await?;
452
453 for tool in tools {
454 let tool_name = if let Some(prefix) = &server_config.tool_prefix {
455 format!("{}_{}", prefix, tool.name)
456 } else {
457 tool.name.clone()
458 };
459
460 // Create tool callback that calls the MCP server with timeout and concurrency controls
461 let connection_clone = Arc::clone(&connection);
462 let original_tool_name = tool.name.clone();
463 let semaphore_clone = Arc::clone(&self.concurrency_semaphore);
464 let timeout_duration = Duration::from_secs(self.config.tool_timeout_secs.unwrap_or(30));
465
466 let callback: Arc<ToolCallback> = Arc::new(move |called_function| {
467 let connection = Arc::clone(&connection_clone);
468 let tool_name = original_tool_name.clone();
469 let semaphore = Arc::clone(&semaphore_clone);
470 let arguments: serde_json::Value =
471 serde_json::from_str(&called_function.arguments)?;
472
473 // Use tokio::task::spawn_blocking to handle the async-to-sync bridge
474 let rt = tokio::runtime::Handle::current();
475 std::thread::spawn(move || {
476 rt.block_on(async move {
477 // Acquire semaphore permit for concurrency control
478 let _permit = semaphore
479 .acquire()
480 .await
481 .map_err(|_| anyhow::anyhow!("Failed to acquire concurrency permit"))?;
482
483 // Execute tool call with timeout
484 match tokio::time::timeout(
485 timeout_duration,
486 connection.call_tool(&tool_name, arguments),
487 )
488 .await
489 {
490 Ok(result) => result,
491 Err(_) => Err(anyhow::anyhow!(
492 "Tool call timed out after {} seconds",
493 timeout_duration.as_secs()
494 )),
495 }
496 })
497 })
498 .join()
499 .map_err(|_| anyhow::anyhow!("Tool call thread panicked"))?
500 });
501
502 // Convert MCP tool schema to Tool definition
503 let function_def = Function {
504 name: tool_name.clone(),
505 description: tool.description.clone(),
506 parameters: Self::convert_mcp_schema_to_parameters(&tool.input_schema),
507 };
508
509 let tool_def = Tool {
510 tp: ToolType::Function,
511 function: function_def,
512 };
513
514 // Store in both collections for backward compatibility
515 self.tool_callbacks
516 .insert(tool_name.clone(), callback.clone());
517 self.tool_callbacks_with_tools.insert(
518 tool_name.clone(),
519 ToolCallbackWithTool {
520 callback,
521 tool: tool_def,
522 },
523 );
524 self.tools.insert(tool_name, tool);
525 }
526
527 Ok(())
528 }
529
530 // ==================== Connection Management Methods ====================
531
532 /// Gracefully shutdown all server connections.
533 ///
534 /// Closes all active connections and clears the tools and callbacks.
535 /// The client cannot be used after calling this method without re-initializing.
536 ///
537 /// # Example
538 ///
539 /// ```rust,no_run
540 /// # use mistralrs_mcp::{McpClient, McpClientConfig};
541 /// # async fn example() -> anyhow::Result<()> {
542 /// let config = McpClientConfig::default();
543 /// let mut client = McpClient::new(config);
544 /// client.initialize().await?;
545 ///
546 /// // ... use the client ...
547 ///
548 /// // Gracefully shutdown when done
549 /// client.shutdown().await?;
550 /// # Ok(())
551 /// # }
552 /// ```
553 pub async fn shutdown(&mut self) -> Result<()> {
554 // Close all connections
555 for connection in self.servers.values() {
556 let _ = connection.close().await;
557 }
558
559 // Clear all state
560 self.servers.clear();
561 self.tools.clear();
562 self.tool_callbacks.clear();
563 self.tool_callbacks_with_tools.clear();
564
565 Ok(())
566 }
567
568 /// Disconnect a specific server by its ID.
569 ///
570 /// Removes the server from active connections and clears its associated tools.
571 /// Returns an error if the server ID is not found.
572 ///
573 /// # Example
574 ///
575 /// ```rust,no_run
576 /// # use mistralrs_mcp::{McpClient, McpClientConfig};
577 /// # async fn example() -> anyhow::Result<()> {
578 /// let config = McpClientConfig::default();
579 /// let mut client = McpClient::new(config);
580 /// client.initialize().await?;
581 ///
582 /// // Disconnect a specific server
583 /// client.disconnect("my_server_id").await?;
584 /// # Ok(())
585 /// # }
586 /// ```
587 pub async fn disconnect(&mut self, id: &str) -> Result<()> {
588 let connection = self
589 .servers
590 .remove(id)
591 .ok_or_else(|| anyhow::anyhow!("Server not connected: {}", id))?;
592
593 connection.close().await?;
594 self.remove_tools_for_server(id);
595
596 Ok(())
597 }
598
599 /// Reconnect to a specific server by its ID.
600 ///
601 /// Re-establishes the connection using the stored configuration.
602 /// Returns an error if the server ID is not in the configuration.
603 ///
604 /// # Example
605 ///
606 /// ```rust,no_run
607 /// # use mistralrs_mcp::{McpClient, McpClientConfig};
608 /// # async fn example() -> anyhow::Result<()> {
609 /// let config = McpClientConfig::default();
610 /// let mut client = McpClient::new(config);
611 /// client.initialize().await?;
612 ///
613 /// // Reconnect to a server after it was disconnected or lost connection
614 /// client.reconnect("my_server_id").await?;
615 /// # Ok(())
616 /// # }
617 /// ```
618 pub async fn reconnect(&mut self, id: &str) -> Result<()> {
619 // Find the server config
620 let server_config = self
621 .config
622 .servers
623 .iter()
624 .find(|s| s.id == id)
625 .ok_or_else(|| anyhow::anyhow!("Server config not found: {}", id))?
626 .clone();
627
628 // Close existing connection if any
629 if let Some(connection) = self.servers.remove(id) {
630 let _ = connection.close().await;
631 }
632
633 // Remove old tools for this server
634 self.remove_tools_for_server(id);
635
636 // Create new connection
637 let connection = self.create_connection(&server_config).await?;
638 self.servers.insert(id.to_string(), connection);
639
640 // Re-register tools if auto_register_tools is enabled
641 if self.config.auto_register_tools {
642 self.register_tools_for_server(id).await?;
643 }
644
645 Ok(())
646 }
647
648 /// Check if a specific server is currently connected.
649 ///
650 /// # Example
651 ///
652 /// ```rust,no_run
653 /// # use mistralrs_mcp::{McpClient, McpClientConfig};
654 /// # async fn example() -> anyhow::Result<()> {
655 /// let config = McpClientConfig::default();
656 /// let mut client = McpClient::new(config);
657 /// client.initialize().await?;
658 ///
659 /// if client.is_connected("my_server_id") {
660 /// println!("Server is connected");
661 /// }
662 /// # Ok(())
663 /// # }
664 /// ```
665 pub fn is_connected(&self, id: &str) -> bool {
666 self.servers.contains_key(id)
667 }
668
669 /// Dynamically add and connect a new server at runtime.
670 ///
671 /// Adds the server configuration and establishes the connection.
672 /// If auto_register_tools is enabled, discovers and registers the server's tools.
673 ///
674 /// # Example
675 ///
676 /// ```rust,no_run
677 /// # use mistralrs_mcp::{McpClient, McpClientConfig, McpServerConfig, McpServerSource};
678 /// # async fn example() -> anyhow::Result<()> {
679 /// let config = McpClientConfig::default();
680 /// let mut client = McpClient::new(config);
681 /// client.initialize().await?;
682 ///
683 /// // Add a new server dynamically
684 /// let new_server = McpServerConfig {
685 /// id: "new_server".to_string(),
686 /// name: "New MCP Server".to_string(),
687 /// source: McpServerSource::Http {
688 /// url: "https://api.example.com/mcp".to_string(),
689 /// timeout_secs: Some(30),
690 /// headers: None,
691 /// },
692 /// ..Default::default()
693 /// };
694 /// client.add_server(new_server).await?;
695 /// # Ok(())
696 /// # }
697 /// ```
698 pub async fn add_server(&mut self, config: McpServerConfig) -> Result<()> {
699 let id = config.id.clone();
700
701 // Check if server already exists
702 if self.servers.contains_key(&id) {
703 return Err(anyhow::anyhow!("Server already exists: {}", id));
704 }
705
706 // Create connection
707 let connection = self.create_connection(&config).await?;
708 self.servers.insert(id.clone(), connection);
709
710 // Store config
711 self.config.servers.push(config);
712
713 // Register tools if enabled
714 if self.config.auto_register_tools {
715 self.register_tools_for_server(&id).await?;
716 }
717
718 Ok(())
719 }
720
721 /// Disconnect and remove a server from the client.
722 ///
723 /// Closes the connection and removes the server from the configuration.
724 ///
725 /// # Example
726 ///
727 /// ```rust,no_run
728 /// # use mistralrs_mcp::{McpClient, McpClientConfig};
729 /// # async fn example() -> anyhow::Result<()> {
730 /// let config = McpClientConfig::default();
731 /// let mut client = McpClient::new(config);
732 /// client.initialize().await?;
733 ///
734 /// // Remove a server completely
735 /// client.remove_server("my_server_id").await?;
736 /// # Ok(())
737 /// # }
738 /// ```
739 pub async fn remove_server(&mut self, id: &str) -> Result<()> {
740 // Disconnect first
741 if let Some(connection) = self.servers.remove(id) {
742 let _ = connection.close().await;
743 }
744
745 // Remove tools
746 self.remove_tools_for_server(id);
747
748 // Remove from config
749 self.config.servers.retain(|s| s.id != id);
750
751 Ok(())
752 }
753
754 // ==================== Tool Management Methods ====================
755
756 /// Re-discover tools from all connected servers.
757 ///
758 /// Clears existing tool registrations and re-queries all servers.
759 /// Useful for long-running clients when servers update their tools.
760 ///
761 /// # Example
762 ///
763 /// ```rust,no_run
764 /// # use mistralrs_mcp::{McpClient, McpClientConfig};
765 /// # async fn example() -> anyhow::Result<()> {
766 /// let config = McpClientConfig::default();
767 /// let mut client = McpClient::new(config);
768 /// client.initialize().await?;
769 ///
770 /// // Refresh tools after servers have been updated
771 /// client.refresh_tools().await?;
772 /// # Ok(())
773 /// # }
774 /// ```
775 pub async fn refresh_tools(&mut self) -> Result<()> {
776 // Clear all existing tools
777 self.tools.clear();
778 self.tool_callbacks.clear();
779 self.tool_callbacks_with_tools.clear();
780
781 // Re-discover tools from all servers
782 self.discover_and_register_tools().await
783 }
784
785 /// Get a specific tool by name.
786 ///
787 /// Returns `None` if no tool with the given name is registered.
788 ///
789 /// # Example
790 ///
791 /// ```rust,no_run
792 /// # use mistralrs_mcp::{McpClient, McpClientConfig};
793 /// # async fn example() -> anyhow::Result<()> {
794 /// let config = McpClientConfig::default();
795 /// let mut client = McpClient::new(config);
796 /// client.initialize().await?;
797 ///
798 /// if let Some(tool) = client.get_tool("web_search") {
799 /// println!("Found tool: {:?}", tool.description);
800 /// }
801 /// # Ok(())
802 /// # }
803 /// ```
804 pub fn get_tool(&self, name: &str) -> Option<&McpToolInfo> {
805 self.tools.get(name)
806 }
807
808 /// Check if a tool with the given name exists.
809 ///
810 /// # Example
811 ///
812 /// ```rust,no_run
813 /// # use mistralrs_mcp::{McpClient, McpClientConfig};
814 /// # async fn example() -> anyhow::Result<()> {
815 /// let config = McpClientConfig::default();
816 /// let mut client = McpClient::new(config);
817 /// client.initialize().await?;
818 ///
819 /// if client.has_tool("web_search") {
820 /// println!("Tool is available");
821 /// }
822 /// # Ok(())
823 /// # }
824 /// ```
825 pub fn has_tool(&self, name: &str) -> bool {
826 self.tools.contains_key(name)
827 }
828
829 /// Directly call a tool by name with the given arguments.
830 ///
831 /// This bypasses the callback system and calls the tool directly
832 /// on the appropriate server with timeout and concurrency controls.
833 ///
834 /// # Example
835 ///
836 /// ```rust,no_run
837 /// # use mistralrs_mcp::{McpClient, McpClientConfig};
838 /// # use serde_json::json;
839 /// # async fn example() -> anyhow::Result<()> {
840 /// let config = McpClientConfig::default();
841 /// let mut client = McpClient::new(config);
842 /// client.initialize().await?;
843 ///
844 /// let result = client.call_tool("web_search", json!({"query": "rust programming"})).await?;
845 /// println!("Result: {}", result);
846 /// # Ok(())
847 /// # }
848 /// ```
849 pub async fn call_tool(&self, name: &str, arguments: serde_json::Value) -> Result<String> {
850 let tool_info = self
851 .tools
852 .get(name)
853 .ok_or_else(|| anyhow::anyhow!("Tool not found: {}", name))?;
854
855 let connection = self
856 .servers
857 .get(&tool_info.server_id)
858 .ok_or_else(|| anyhow::anyhow!("Server not connected: {}", tool_info.server_id))?;
859
860 // Acquire semaphore permit for concurrency control
861 let _permit = self
862 .concurrency_semaphore
863 .acquire()
864 .await
865 .map_err(|_| anyhow::anyhow!("Failed to acquire concurrency permit"))?;
866
867 let timeout_duration = Duration::from_secs(self.config.tool_timeout_secs.unwrap_or(30));
868
869 // Execute tool call with timeout
870 match tokio::time::timeout(
871 timeout_duration,
872 connection.call_tool(&tool_info.name, arguments),
873 )
874 .await
875 {
876 Ok(result) => result,
877 Err(_) => Err(anyhow::anyhow!(
878 "Tool call timed out after {} seconds",
879 timeout_duration.as_secs()
880 )),
881 }
882 }
883
884 /// Get the total number of registered tools.
885 ///
886 /// # Example
887 ///
888 /// ```rust,no_run
889 /// # use mistralrs_mcp::{McpClient, McpClientConfig};
890 /// # async fn example() -> anyhow::Result<()> {
891 /// let config = McpClientConfig::default();
892 /// let mut client = McpClient::new(config);
893 /// client.initialize().await?;
894 ///
895 /// println!("Total tools: {}", client.tool_count());
896 /// # Ok(())
897 /// # }
898 /// ```
899 pub fn tool_count(&self) -> usize {
900 self.tools.len()
901 }
902
903 // ==================== Status / Convenience Methods ====================
904
905 /// Get the number of connected servers.
906 ///
907 /// # Example
908 ///
909 /// ```rust,no_run
910 /// # use mistralrs_mcp::{McpClient, McpClientConfig};
911 /// # async fn example() -> anyhow::Result<()> {
912 /// let config = McpClientConfig::default();
913 /// let mut client = McpClient::new(config);
914 /// client.initialize().await?;
915 ///
916 /// println!("Connected servers: {}", client.server_count());
917 /// # Ok(())
918 /// # }
919 /// ```
920 pub fn server_count(&self) -> usize {
921 self.servers.len()
922 }
923
924 /// Get a list of all connected server IDs.
925 ///
926 /// # Example
927 ///
928 /// ```rust,no_run
929 /// # use mistralrs_mcp::{McpClient, McpClientConfig};
930 /// # async fn example() -> anyhow::Result<()> {
931 /// let config = McpClientConfig::default();
932 /// let mut client = McpClient::new(config);
933 /// client.initialize().await?;
934 ///
935 /// for id in client.server_ids() {
936 /// println!("Server: {}", id);
937 /// }
938 /// # Ok(())
939 /// # }
940 /// ```
941 pub fn server_ids(&self) -> Vec<&str> {
942 self.servers.keys().map(|s| s.as_str()).collect()
943 }
944
945 /// Ping all connected servers and return results per server.
946 ///
947 /// Returns a map of server ID to ping result. Useful for health monitoring.
948 ///
949 /// # Example
950 ///
951 /// ```rust,no_run
952 /// # use mistralrs_mcp::{McpClient, McpClientConfig};
953 /// # async fn example() -> anyhow::Result<()> {
954 /// let config = McpClientConfig::default();
955 /// let mut client = McpClient::new(config);
956 /// client.initialize().await?;
957 ///
958 /// let results = client.ping_all().await;
959 /// for (server_id, result) in results {
960 /// match result {
961 /// Ok(()) => println!("{}: healthy", server_id),
962 /// Err(e) => println!("{}: unhealthy - {}", server_id, e),
963 /// }
964 /// }
965 /// # Ok(())
966 /// # }
967 /// ```
968 pub async fn ping_all(&self) -> HashMap<String, Result<()>> {
969 let mut results = HashMap::new();
970
971 for (server_id, connection) in &self.servers {
972 let result = connection.ping().await;
973 results.insert(server_id.clone(), result);
974 }
975
976 results
977 }
978
979 // ==================== Resource Access Methods ====================
980
981 /// List resources from all connected servers.
982 ///
983 /// Returns a vector of (server_id, resource) tuples.
984 ///
985 /// # Example
986 ///
987 /// ```rust,no_run
988 /// # use mistralrs_mcp::{McpClient, McpClientConfig};
989 /// # async fn example() -> anyhow::Result<()> {
990 /// let config = McpClientConfig::default();
991 /// let mut client = McpClient::new(config);
992 /// client.initialize().await?;
993 ///
994 /// let resources = client.list_all_resources().await?;
995 /// for (server_id, resource) in resources {
996 /// println!("Server {}: {:?}", server_id, resource);
997 /// }
998 /// # Ok(())
999 /// # }
1000 /// ```
1001 pub async fn list_all_resources(&self) -> Result<Vec<(String, Resource)>> {
1002 let mut all_resources = Vec::new();
1003
1004 for (server_id, connection) in &self.servers {
1005 match connection.list_resources().await {
1006 Ok(resources) => {
1007 for resource in resources {
1008 all_resources.push((server_id.clone(), resource));
1009 }
1010 }
1011 Err(e) => {
1012 // Log error but continue with other servers
1013 warn!("Failed to list resources from server {}: {}", server_id, e);
1014 }
1015 }
1016 }
1017
1018 Ok(all_resources)
1019 }
1020}
1021
1022impl Drop for McpClient {
1023 fn drop(&mut self) {
1024 // Try to get the tokio runtime handle
1025 if let Ok(handle) = tokio::runtime::Handle::try_current() {
1026 let servers = std::mem::take(&mut self.servers);
1027 handle.spawn(async move {
1028 for (_, connection) in servers {
1029 let _ = connection.close().await;
1030 }
1031 });
1032 }
1033 }
1034}
1035
1036/// HTTP-based MCP server connection
1037pub struct HttpMcpConnection {
1038 server_id: String,
1039 server_name: String,
1040 transport: Arc<dyn McpTransport>,
1041}
1042
1043impl HttpMcpConnection {
1044 pub async fn new(
1045 server_id: String,
1046 server_name: String,
1047 url: String,
1048 timeout_secs: Option<u64>,
1049 headers: Option<HashMap<String, String>>,
1050 ) -> Result<Self> {
1051 let transport = HttpTransport::new(url, timeout_secs, headers)?;
1052
1053 let connection = Self {
1054 server_id,
1055 server_name,
1056 transport: Arc::new(transport),
1057 };
1058
1059 // Initialize the connection
1060 connection.initialize().await?;
1061
1062 Ok(connection)
1063 }
1064
1065 async fn initialize(&self) -> Result<()> {
1066 let init_params = serde_json::json!({
1067 "protocolVersion": rust_mcp_schema::ProtocolVersion::latest().to_string(),
1068 "capabilities": {
1069 "tools": {}
1070 },
1071 "clientInfo": {
1072 "name": "mistral.rs",
1073 "version": env!("CARGO_PKG_VERSION"),
1074 }
1075 });
1076
1077 self.transport
1078 .send_request("initialize", init_params)
1079 .await?;
1080 self.transport.send_initialization_notification().await?;
1081 Ok(())
1082 }
1083}
1084
1085#[async_trait::async_trait]
1086impl McpServerConnection for HttpMcpConnection {
1087 fn server_id(&self) -> &str {
1088 &self.server_id
1089 }
1090
1091 fn server_name(&self) -> &str {
1092 &self.server_name
1093 }
1094
1095 async fn list_tools(&self) -> Result<Vec<McpToolInfo>> {
1096 let result = self
1097 .transport
1098 .send_request("tools/list", Value::Null)
1099 .await?;
1100
1101 let tools = result
1102 .get("tools")
1103 .and_then(|t| t.as_array())
1104 .ok_or_else(|| anyhow::anyhow!("Invalid tools response format"))?;
1105
1106 let mut tool_infos = Vec::new();
1107 for tool in tools {
1108 let name = tool
1109 .get("name")
1110 .and_then(|n| n.as_str())
1111 .ok_or_else(|| anyhow::anyhow!("Tool missing name"))?
1112 .to_string();
1113
1114 let description = tool
1115 .get("description")
1116 .and_then(|d| d.as_str())
1117 .map(|s| s.to_string());
1118
1119 let input_schema = tool
1120 .get("inputSchema")
1121 .cloned()
1122 .unwrap_or(Value::Object(serde_json::Map::new()));
1123
1124 tool_infos.push(McpToolInfo {
1125 name,
1126 description,
1127 input_schema,
1128 server_id: self.server_id.clone(),
1129 server_name: self.server_name.clone(),
1130 });
1131 }
1132
1133 Ok(tool_infos)
1134 }
1135
1136 async fn call_tool(&self, name: &str, arguments: Value) -> Result<String> {
1137 let params = serde_json::json!({
1138 "name": name,
1139 "arguments": arguments
1140 });
1141
1142 let result = self.transport.send_request("tools/call", params).await?;
1143
1144 // Parse the MCP tool result
1145 let tool_result: McpToolResult = serde_json::from_value(result)?;
1146
1147 // Check if the result indicates an error
1148 if tool_result.is_error.unwrap_or(false) {
1149 return Err(anyhow::anyhow!("Tool execution failed: {tool_result}"));
1150 }
1151
1152 Ok(tool_result.to_string())
1153 }
1154
1155 async fn list_resources(&self) -> Result<Vec<Resource>> {
1156 let result = self
1157 .transport
1158 .send_request("resources/list", Value::Null)
1159 .await?;
1160
1161 let resources = result
1162 .get("resources")
1163 .and_then(|r| r.as_array())
1164 .ok_or_else(|| anyhow::anyhow!("Invalid resources response format"))?;
1165
1166 let mut resource_list = Vec::new();
1167 for resource in resources {
1168 let mcp_resource: Resource = serde_json::from_value(resource.clone())?;
1169 resource_list.push(mcp_resource);
1170 }
1171
1172 Ok(resource_list)
1173 }
1174
1175 async fn read_resource(&self, uri: &str) -> Result<String> {
1176 let params = serde_json::json!({ "uri": uri });
1177 let result = self
1178 .transport
1179 .send_request("resources/read", params)
1180 .await?;
1181
1182 // Extract content from the response
1183 if let Some(contents) = result.get("contents").and_then(|c| c.as_array()) {
1184 if let Some(first_content) = contents.first() {
1185 if let Some(text) = first_content.get("text").and_then(|t| t.as_str()) {
1186 return Ok(text.to_string());
1187 }
1188 }
1189 }
1190
1191 Err(anyhow::anyhow!("No readable content found in resource"))
1192 }
1193
1194 async fn ping(&self) -> Result<()> {
1195 // Send a simple ping to check if the server is responsive
1196 self.transport.send_request("ping", Value::Null).await?;
1197 Ok(())
1198 }
1199
1200 async fn close(&self) -> Result<()> {
1201 self.transport.close().await
1202 }
1203}
1204
1205/// Process-based MCP server connection
1206pub struct ProcessMcpConnection {
1207 server_id: String,
1208 server_name: String,
1209 transport: Arc<dyn McpTransport>,
1210}
1211
1212impl ProcessMcpConnection {
1213 pub async fn new(
1214 server_id: String,
1215 server_name: String,
1216 command: String,
1217 args: Vec<String>,
1218 work_dir: Option<String>,
1219 env: Option<HashMap<String, String>>,
1220 ) -> Result<Self> {
1221 let transport = ProcessTransport::new(command, args, work_dir, env).await?;
1222
1223 let connection = Self {
1224 server_id,
1225 server_name,
1226 transport: Arc::new(transport),
1227 };
1228
1229 // Initialize the connection
1230 connection.initialize().await?;
1231
1232 Ok(connection)
1233 }
1234
1235 async fn initialize(&self) -> Result<()> {
1236 let init_params = serde_json::json!({
1237 "protocolVersion": rust_mcp_schema::ProtocolVersion::latest().to_string(),
1238 "capabilities": {
1239 "tools": {}
1240 },
1241 "clientInfo": {
1242 "name": "mistral.rs",
1243 "version": env!("CARGO_PKG_VERSION"),
1244 }
1245 });
1246
1247 self.transport
1248 .send_request("initialize", init_params)
1249 .await?;
1250 self.transport.send_initialization_notification().await?;
1251 Ok(())
1252 }
1253}
1254
1255#[async_trait::async_trait]
1256impl McpServerConnection for ProcessMcpConnection {
1257 fn server_id(&self) -> &str {
1258 &self.server_id
1259 }
1260
1261 fn server_name(&self) -> &str {
1262 &self.server_name
1263 }
1264
1265 async fn list_tools(&self) -> Result<Vec<McpToolInfo>> {
1266 let result = self
1267 .transport
1268 .send_request("tools/list", Value::Null)
1269 .await?;
1270
1271 let tools = result
1272 .get("tools")
1273 .and_then(|t| t.as_array())
1274 .ok_or_else(|| anyhow::anyhow!("Invalid tools response format"))?;
1275
1276 let mut tool_infos = Vec::new();
1277 for tool in tools {
1278 let name = tool
1279 .get("name")
1280 .and_then(|n| n.as_str())
1281 .ok_or_else(|| anyhow::anyhow!("Tool missing name"))?
1282 .to_string();
1283
1284 let description = tool
1285 .get("description")
1286 .and_then(|d| d.as_str())
1287 .map(|s| s.to_string());
1288
1289 let input_schema = tool
1290 .get("inputSchema")
1291 .cloned()
1292 .unwrap_or(Value::Object(serde_json::Map::new()));
1293
1294 tool_infos.push(McpToolInfo {
1295 name,
1296 description,
1297 input_schema,
1298 server_id: self.server_id.clone(),
1299 server_name: self.server_name.clone(),
1300 });
1301 }
1302
1303 Ok(tool_infos)
1304 }
1305
1306 async fn call_tool(&self, name: &str, arguments: Value) -> Result<String> {
1307 let params = serde_json::json!({
1308 "name": name,
1309 "arguments": arguments
1310 });
1311
1312 let result = self.transport.send_request("tools/call", params).await?;
1313
1314 // Parse the MCP tool result
1315 let tool_result: McpToolResult = serde_json::from_value(result)?;
1316
1317 // Check if the result indicates an error
1318 if tool_result.is_error.unwrap_or(false) {
1319 return Err(anyhow::anyhow!("Tool execution failed: {tool_result}"));
1320 }
1321
1322 Ok(tool_result.to_string())
1323 }
1324
1325 async fn list_resources(&self) -> Result<Vec<Resource>> {
1326 let result = self
1327 .transport
1328 .send_request("resources/list", Value::Null)
1329 .await?;
1330
1331 let resources = result
1332 .get("resources")
1333 .and_then(|r| r.as_array())
1334 .ok_or_else(|| anyhow::anyhow!("Invalid resources response format"))?;
1335
1336 let mut resource_list = Vec::new();
1337 for resource in resources {
1338 let mcp_resource: Resource = serde_json::from_value(resource.clone())?;
1339 resource_list.push(mcp_resource);
1340 }
1341
1342 Ok(resource_list)
1343 }
1344
1345 async fn read_resource(&self, uri: &str) -> Result<String> {
1346 let params = serde_json::json!({ "uri": uri });
1347 let result = self
1348 .transport
1349 .send_request("resources/read", params)
1350 .await?;
1351
1352 // Extract content from the response
1353 if let Some(contents) = result.get("contents").and_then(|c| c.as_array()) {
1354 if let Some(first_content) = contents.first() {
1355 if let Some(text) = first_content.get("text").and_then(|t| t.as_str()) {
1356 return Ok(text.to_string());
1357 }
1358 }
1359 }
1360
1361 Err(anyhow::anyhow!("No readable content found in resource"))
1362 }
1363
1364 async fn ping(&self) -> Result<()> {
1365 // Send a simple ping to check if the server is responsive
1366 self.transport.send_request("ping", Value::Null).await?;
1367 Ok(())
1368 }
1369
1370 async fn close(&self) -> Result<()> {
1371 self.transport.close().await
1372 }
1373}
1374
1375/// WebSocket-based MCP server connection
1376pub struct WebSocketMcpConnection {
1377 server_id: String,
1378 server_name: String,
1379 transport: Arc<dyn McpTransport>,
1380}
1381
1382impl WebSocketMcpConnection {
1383 pub async fn new(
1384 server_id: String,
1385 server_name: String,
1386 url: String,
1387 timeout_secs: Option<u64>,
1388 headers: Option<HashMap<String, String>>,
1389 ) -> Result<Self> {
1390 let transport = WebSocketTransport::new(url, timeout_secs, headers).await?;
1391
1392 let connection = Self {
1393 server_id,
1394 server_name,
1395 transport: Arc::new(transport),
1396 };
1397
1398 // Initialize the connection
1399 connection.initialize().await?;
1400
1401 Ok(connection)
1402 }
1403
1404 async fn initialize(&self) -> Result<()> {
1405 let init_params = serde_json::json!({
1406 "protocolVersion": rust_mcp_schema::ProtocolVersion::latest().to_string(),
1407 "capabilities": {
1408 "tools": {}
1409 },
1410 "clientInfo": {
1411 "name": "mistral.rs",
1412 "version": env!("CARGO_PKG_VERSION"),
1413 }
1414 });
1415
1416 self.transport
1417 .send_request("initialize", init_params)
1418 .await?;
1419 self.transport.send_initialization_notification().await?;
1420 Ok(())
1421 }
1422}
1423
1424#[async_trait::async_trait]
1425impl McpServerConnection for WebSocketMcpConnection {
1426 fn server_id(&self) -> &str {
1427 &self.server_id
1428 }
1429
1430 fn server_name(&self) -> &str {
1431 &self.server_name
1432 }
1433
1434 async fn list_tools(&self) -> Result<Vec<McpToolInfo>> {
1435 let result = self
1436 .transport
1437 .send_request("tools/list", Value::Null)
1438 .await?;
1439
1440 let tools = result
1441 .get("tools")
1442 .and_then(|t| t.as_array())
1443 .ok_or_else(|| anyhow::anyhow!("Invalid tools response format"))?;
1444
1445 let mut tool_infos = Vec::new();
1446 for tool in tools {
1447 let name = tool
1448 .get("name")
1449 .and_then(|n| n.as_str())
1450 .ok_or_else(|| anyhow::anyhow!("Tool missing name"))?
1451 .to_string();
1452
1453 let description = tool
1454 .get("description")
1455 .and_then(|d| d.as_str())
1456 .map(|s| s.to_string());
1457
1458 let input_schema = tool
1459 .get("inputSchema")
1460 .cloned()
1461 .unwrap_or(Value::Object(serde_json::Map::new()));
1462
1463 tool_infos.push(McpToolInfo {
1464 name,
1465 description,
1466 input_schema,
1467 server_id: self.server_id.clone(),
1468 server_name: self.server_name.clone(),
1469 });
1470 }
1471
1472 Ok(tool_infos)
1473 }
1474
1475 async fn call_tool(&self, name: &str, arguments: Value) -> Result<String> {
1476 let params = serde_json::json!({
1477 "name": name,
1478 "arguments": arguments
1479 });
1480
1481 let result = self.transport.send_request("tools/call", params).await?;
1482
1483 // Parse the MCP tool result
1484 let tool_result: McpToolResult = serde_json::from_value(result)?;
1485
1486 // Check if the result indicates an error
1487 if tool_result.is_error.unwrap_or(false) {
1488 return Err(anyhow::anyhow!("Tool execution failed: {tool_result}"));
1489 }
1490
1491 Ok(tool_result.to_string())
1492 }
1493
1494 async fn list_resources(&self) -> Result<Vec<Resource>> {
1495 let result = self
1496 .transport
1497 .send_request("resources/list", Value::Null)
1498 .await?;
1499
1500 let resources = result
1501 .get("resources")
1502 .and_then(|r| r.as_array())
1503 .ok_or_else(|| anyhow::anyhow!("Invalid resources response format"))?;
1504
1505 let mut resource_list = Vec::new();
1506 for resource in resources {
1507 let mcp_resource: Resource = serde_json::from_value(resource.clone())?;
1508 resource_list.push(mcp_resource);
1509 }
1510
1511 Ok(resource_list)
1512 }
1513
1514 async fn read_resource(&self, uri: &str) -> Result<String> {
1515 let params = serde_json::json!({ "uri": uri });
1516 let result = self
1517 .transport
1518 .send_request("resources/read", params)
1519 .await?;
1520
1521 // Extract content from the response
1522 if let Some(contents) = result.get("contents").and_then(|c| c.as_array()) {
1523 if let Some(first_content) = contents.first() {
1524 if let Some(text) = first_content.get("text").and_then(|t| t.as_str()) {
1525 return Ok(text.to_string());
1526 }
1527 }
1528 }
1529
1530 Err(anyhow::anyhow!("No readable content found in resource"))
1531 }
1532
1533 async fn ping(&self) -> Result<()> {
1534 // Send a simple ping to check if the server is responsive
1535 self.transport.send_request("ping", Value::Null).await?;
1536 Ok(())
1537 }
1538
1539 async fn close(&self) -> Result<()> {
1540 self.transport.close().await
1541 }
1542}