hanzo_llm_mcp/client.rs
1use crate::tools::{
2 Function, Tool, ToolCallback, ToolCallbackKind, ToolCallbackWithTool, ToolType,
3};
4use crate::transport::{HttpTransport, McpTransport, ProcessTransport, WebSocketTransport};
5use crate::types::McpToolResult;
6use crate::{McpClientConfig, McpServerConfig, McpServerSource, McpToolInfo};
7use anyhow::Result;
8use rust_mcp_schema::Resource;
9use serde_json::Value;
10use std::collections::HashMap;
11use std::sync::Arc;
12use std::time::Duration;
13use tokio::sync::Semaphore;
14use tracing::warn;
15
16/// Trait for MCP server connections
17#[async_trait::async_trait]
18pub trait McpServerConnection: Send + Sync {
19 /// Get the server ID
20 fn server_id(&self) -> &str;
21
22 /// Get the server name
23 fn server_name(&self) -> &str;
24
25 /// List available tools from this server
26 async fn list_tools(&self) -> Result<Vec<McpToolInfo>>;
27
28 /// Call a tool on this server
29 async fn call_tool(&self, name: &str, arguments: serde_json::Value) -> Result<String>;
30
31 /// List available resources from this server
32 async fn list_resources(&self) -> Result<Vec<Resource>>;
33
34 /// Read a resource from this server
35 async fn read_resource(&self, uri: &str) -> Result<String>;
36
37 /// Check if the connection is healthy
38 async fn ping(&self) -> Result<()>;
39
40 /// Close the server connection
41 async fn close(&self) -> Result<()>;
42}
43
44fn initialize_params() -> Value {
45 serde_json::json!({
46 "protocolVersion": rust_mcp_schema::ProtocolVersion::latest().to_string(),
47 "capabilities": {
48 "tools": {}
49 },
50 "clientInfo": {
51 "name": "hanzo",
52 "version": env!("CARGO_PKG_VERSION"),
53 }
54 })
55}
56
57async fn initialize_transport(transport: &Arc<dyn McpTransport>) -> Result<()> {
58 transport
59 .send_request("initialize", initialize_params())
60 .await?;
61 transport.send_initialization_notification().await
62}
63
64async fn list_tools_via_transport(
65 transport: &Arc<dyn McpTransport>,
66 server_id: &str,
67 server_name: &str,
68) -> Result<Vec<McpToolInfo>> {
69 let result = transport.send_request("tools/list", Value::Null).await?;
70
71 let tools = result
72 .get("tools")
73 .and_then(|t| t.as_array())
74 .ok_or_else(|| anyhow::anyhow!("Invalid tools response format"))?;
75
76 let mut tool_infos = Vec::new();
77 for tool in tools {
78 let name = tool
79 .get("name")
80 .and_then(|n| n.as_str())
81 .ok_or_else(|| anyhow::anyhow!("Tool missing name"))?
82 .to_string();
83
84 let description = tool
85 .get("description")
86 .and_then(|d| d.as_str())
87 .map(|s| s.to_string());
88
89 let input_schema = tool
90 .get("inputSchema")
91 .cloned()
92 .unwrap_or(Value::Object(serde_json::Map::new()));
93
94 tool_infos.push(McpToolInfo {
95 name,
96 description,
97 input_schema,
98 server_id: server_id.to_string(),
99 server_name: server_name.to_string(),
100 });
101 }
102
103 Ok(tool_infos)
104}
105
106async fn call_tool_via_transport(
107 transport: &Arc<dyn McpTransport>,
108 name: &str,
109 arguments: Value,
110) -> Result<String> {
111 let params = serde_json::json!({
112 "name": name,
113 "arguments": arguments
114 });
115
116 let result = transport.send_request("tools/call", params).await?;
117 let tool_result: McpToolResult = serde_json::from_value(result)?;
118
119 if tool_result.is_error.unwrap_or(false) {
120 return Err(anyhow::anyhow!("Tool execution failed: {tool_result}"));
121 }
122
123 Ok(tool_result.to_string())
124}
125
126async fn list_resources_via_transport(transport: &Arc<dyn McpTransport>) -> Result<Vec<Resource>> {
127 let result = transport
128 .send_request("resources/list", Value::Null)
129 .await?;
130
131 let resources = result
132 .get("resources")
133 .and_then(|r| r.as_array())
134 .ok_or_else(|| anyhow::anyhow!("Invalid resources response format"))?;
135
136 let mut resource_list = Vec::new();
137 for resource in resources {
138 resource_list.push(serde_json::from_value(resource.clone())?);
139 }
140
141 Ok(resource_list)
142}
143
144async fn read_resource_via_transport(
145 transport: &Arc<dyn McpTransport>,
146 uri: &str,
147) -> Result<String> {
148 let params = serde_json::json!({ "uri": uri });
149 let result = transport.send_request("resources/read", params).await?;
150
151 if let Some(contents) = result.get("contents").and_then(|c| c.as_array()) {
152 if let Some(first_content) = contents.first() {
153 if let Some(text) = first_content.get("text").and_then(|t| t.as_str()) {
154 return Ok(text.to_string());
155 }
156 }
157 }
158
159 Err(anyhow::anyhow!("No readable content found in resource"))
160}
161
162async fn ping_transport(transport: &Arc<dyn McpTransport>) -> Result<()> {
163 transport.send_request("ping", Value::Null).await?;
164 Ok(())
165}
166
167/// MCP client that manages connections to multiple MCP servers
168///
169/// The main interface for interacting with Model Context Protocol servers.
170/// Handles connection lifecycle, tool discovery, and provides integration
171/// with tool calling systems.
172///
173/// # Features
174///
175/// - **Multi-server Management**: Connects to and manages multiple MCP servers simultaneously
176/// - **Automatic Tool Discovery**: Discovers available tools from connected servers
177/// - **Tool Registration**: Converts MCP tools to internal Tool format for seamless integration
178/// - **Connection Pooling**: Maintains persistent connections for efficient tool execution
179/// - **Error Handling**: Robust error handling with proper cleanup and reconnection logic
180///
181/// # Example
182///
183/// ```rust,no_run
184/// use hanzo_llm_mcp::{McpClient, McpClientConfig};
185///
186/// #[tokio::main]
187/// async fn main() -> anyhow::Result<()> {
188/// let config = McpClientConfig::default();
189/// let mut client = McpClient::new(config);
190///
191/// // Initialize all configured server connections
192/// client.initialize().await?;
193///
194/// // Get tool callbacks for model integration
195/// let callbacks = client.get_tool_callbacks_with_tools();
196///
197/// Ok(())
198/// }
199/// ```
200pub struct McpClient {
201 /// Configuration for the client including server list and policies
202 config: McpClientConfig,
203 /// Active connections to MCP servers, indexed by server ID
204 servers: HashMap<String, Arc<dyn McpServerConnection>>,
205 /// Registry of discovered tools from all connected servers
206 tools: HashMap<String, McpToolInfo>,
207 /// Legacy tool callbacks for backward compatibility
208 tool_callbacks: HashMap<String, Arc<ToolCallback>>,
209 /// Tool callbacks with associated Tool definitions for automatic tool calling
210 tool_callbacks_with_tools: HashMap<String, ToolCallbackWithTool>,
211 /// Semaphore to control maximum concurrent tool calls
212 concurrency_semaphore: Arc<Semaphore>,
213}
214
215impl McpClient {
216 /// Create a new MCP client with the given configuration
217 pub fn new(config: McpClientConfig) -> Self {
218 let max_concurrent = config.max_concurrent_calls.unwrap_or(10);
219 Self {
220 config,
221 servers: HashMap::new(),
222 tools: HashMap::new(),
223 tool_callbacks: HashMap::new(),
224 tool_callbacks_with_tools: HashMap::new(),
225 concurrency_semaphore: Arc::new(Semaphore::new(max_concurrent)),
226 }
227 }
228
229 /// Initialize connections to all configured servers
230 pub async fn initialize(&mut self) -> Result<()> {
231 for server_config in &self.config.servers {
232 if server_config.enabled {
233 let connection = self.create_connection(server_config).await?;
234 self.servers.insert(server_config.id.clone(), connection);
235 }
236 }
237
238 if self.config.auto_register_tools {
239 self.discover_and_register_tools().await?;
240 }
241
242 Ok(())
243 }
244
245 /// Get tool callbacks for use with legacy tool calling systems.
246 ///
247 /// Returns a map of tool names to their callback functions. These callbacks
248 /// handle argument parsing, concurrency control, and timeout enforcement
249 /// automatically.
250 ///
251 /// For new integrations, prefer [`Self::get_tool_callbacks_with_tools`] which
252 /// includes tool definitions alongside callbacks.
253 pub fn get_tool_callbacks(&self) -> &HashMap<String, Arc<ToolCallback>> {
254 &self.tool_callbacks
255 }
256
257 /// Get tool callbacks paired with their tool definitions.
258 ///
259 /// This is the primary method for integrating MCP tools with the model's
260 /// automatic tool calling system. Each entry contains:
261 /// - A callback function that executes the tool with timeout and concurrency controls
262 /// - A [`Tool`] definition with name, description, and parameter schema
263 ///
264 /// # Example
265 ///
266 /// ```rust,no_run
267 /// # use hanzo_llm_mcp::{McpClient, McpClientConfig};
268 /// # async fn example() -> anyhow::Result<()> {
269 /// let config = McpClientConfig::default();
270 /// let mut client = McpClient::new(config);
271 /// client.initialize().await?;
272 ///
273 /// let tools = client.get_tool_callbacks_with_tools();
274 /// for (name, tool_with_callback) in tools {
275 /// println!("Tool: {} - {:?}", name, tool_with_callback.tool.function.description);
276 /// }
277 /// # Ok(())
278 /// # }
279 /// ```
280 pub fn get_tool_callbacks_with_tools(&self) -> &HashMap<String, ToolCallbackWithTool> {
281 &self.tool_callbacks_with_tools
282 }
283
284 /// Get information about all discovered tools.
285 ///
286 /// Returns metadata about tools discovered from connected MCP servers,
287 /// including their names, descriptions, input schemas, and which server
288 /// they came from.
289 pub fn get_tools(&self) -> &HashMap<String, McpToolInfo> {
290 &self.tools
291 }
292
293 /// Get a reference to all connected MCP server connections.
294 ///
295 /// This provides direct access to server connections, allowing you to:
296 /// - List available resources with [`McpServerConnection::list_resources`]
297 /// - Read resources with [`McpServerConnection::read_resource`]
298 /// - Check server health with [`McpServerConnection::ping`]
299 /// - Call tools directly with [`McpServerConnection::call_tool`]
300 ///
301 /// # Example
302 ///
303 /// ```rust,no_run
304 /// # use hanzo_llm_mcp::{McpClient, McpClientConfig};
305 /// # async fn example() -> anyhow::Result<()> {
306 /// let config = McpClientConfig::default();
307 /// let mut client = McpClient::new(config);
308 /// client.initialize().await?;
309 ///
310 /// for (server_id, connection) in client.servers() {
311 /// println!("Server: {} ({})", connection.server_name(), server_id);
312 /// let resources = connection.list_resources().await?;
313 /// println!(" Resources: {:?}", resources);
314 /// }
315 /// # Ok(())
316 /// # }
317 /// ```
318 pub fn servers(&self) -> &HashMap<String, Arc<dyn McpServerConnection>> {
319 &self.servers
320 }
321
322 /// Get a specific server connection by its ID.
323 ///
324 /// Returns `None` if no server with the given ID is connected.
325 ///
326 /// # Example
327 ///
328 /// ```rust,no_run
329 /// # use hanzo_llm_mcp::{McpClient, McpClientConfig};
330 /// # async fn example() -> anyhow::Result<()> {
331 /// let config = McpClientConfig::default();
332 /// let mut client = McpClient::new(config);
333 /// client.initialize().await?;
334 ///
335 /// if let Some(server) = client.server("my_server_id") {
336 /// server.ping().await?;
337 /// let resources = server.list_resources().await?;
338 /// }
339 /// # Ok(())
340 /// # }
341 /// ```
342 pub fn server(&self, id: &str) -> Option<&Arc<dyn McpServerConnection>> {
343 self.servers.get(id)
344 }
345
346 /// Get the client configuration.
347 pub fn config(&self) -> &McpClientConfig {
348 &self.config
349 }
350
351 /// Create connection based on server source type
352 async fn create_connection(
353 &self,
354 config: &McpServerConfig,
355 ) -> Result<Arc<dyn McpServerConnection>> {
356 match &config.source {
357 McpServerSource::Http {
358 url,
359 timeout_secs,
360 headers,
361 } => {
362 // Merge Bearer token with existing headers if provided
363 let mut merged_headers = headers.clone().unwrap_or_default();
364 if let Some(token) = &config.bearer_token {
365 merged_headers.insert("Authorization".to_string(), format!("Bearer {token}"));
366 }
367
368 let connection = HttpMcpConnection::new(
369 config.id.clone(),
370 config.name.clone(),
371 url.clone(),
372 *timeout_secs,
373 Some(merged_headers),
374 )
375 .await?;
376 Ok(Arc::new(connection))
377 }
378 McpServerSource::Process {
379 command,
380 args,
381 work_dir,
382 env,
383 } => {
384 let connection = ProcessMcpConnection::new(
385 config.id.clone(),
386 config.name.clone(),
387 command.clone(),
388 args.clone(),
389 work_dir.clone(),
390 env.clone(),
391 )
392 .await?;
393 Ok(Arc::new(connection))
394 }
395 McpServerSource::WebSocket {
396 url,
397 timeout_secs,
398 headers,
399 } => {
400 // Merge Bearer token with existing headers if provided
401 let mut merged_headers = headers.clone().unwrap_or_default();
402 if let Some(token) = &config.bearer_token {
403 merged_headers.insert("Authorization".to_string(), format!("Bearer {token}"));
404 }
405
406 let connection = WebSocketMcpConnection::new(
407 config.id.clone(),
408 config.name.clone(),
409 url.clone(),
410 *timeout_secs,
411 Some(merged_headers),
412 )
413 .await?;
414 Ok(Arc::new(connection))
415 }
416 }
417 }
418
419 /// Discover tools from all connected servers and register them
420 async fn discover_and_register_tools(&mut self) -> Result<()> {
421 for (server_id, connection) in &self.servers {
422 let tools = connection.list_tools().await?;
423 let server_config = self
424 .config
425 .servers
426 .iter()
427 .find(|s| &s.id == server_id)
428 .ok_or_else(|| anyhow::anyhow!("Server config not found for {}", server_id))?;
429
430 for tool in tools {
431 let tool_name = if let Some(prefix) = &server_config.tool_prefix {
432 format!("{}_{}", prefix, tool.name)
433 } else {
434 tool.name.clone()
435 };
436
437 // Create tool callback that calls the MCP server with timeout and concurrency controls
438 let connection_clone = Arc::clone(connection);
439 let original_tool_name = tool.name.clone();
440 let semaphore_clone = Arc::clone(&self.concurrency_semaphore);
441 let timeout_duration =
442 Duration::from_secs(self.config.tool_timeout_secs.unwrap_or(30));
443
444 let callback: Arc<ToolCallback> = Arc::new(move |called_function, _ctx| {
445 let connection = Arc::clone(&connection_clone);
446 let tool_name = original_tool_name.clone();
447 let semaphore = Arc::clone(&semaphore_clone);
448 let arguments: serde_json::Value =
449 serde_json::from_str(&called_function.arguments)?;
450
451 // Use tokio::task::spawn_blocking to handle the async-to-sync bridge
452 let rt = tokio::runtime::Handle::current();
453 std::thread::spawn(move || {
454 rt.block_on(async move {
455 // Acquire semaphore permit for concurrency control
456 let _permit = semaphore.acquire().await.map_err(|_| {
457 anyhow::anyhow!("Failed to acquire concurrency permit")
458 })?;
459
460 // Execute tool call with timeout
461 match tokio::time::timeout(
462 timeout_duration,
463 connection.call_tool(&tool_name, arguments),
464 )
465 .await
466 {
467 Ok(result) => result,
468 Err(_) => Err(anyhow::anyhow!(
469 "Tool call timed out after {} seconds",
470 timeout_duration.as_secs()
471 )),
472 }
473 })
474 })
475 .join()
476 .map_err(|_| anyhow::anyhow!("Tool call thread panicked"))?
477 });
478
479 // MCP tools always provide an inputSchema, so we can always enable strict mode and let constrained decoding enforce it.
480 let parameters = Self::convert_mcp_schema_to_parameters(&tool.input_schema);
481 let function_def = Function {
482 name: tool_name.clone(),
483 description: tool.description.clone(),
484 strict: if parameters.is_some() {
485 Some(true)
486 } else {
487 None
488 },
489 parameters,
490 };
491
492 let tool_def = Tool {
493 tp: ToolType::Function,
494 function: function_def,
495 };
496
497 // Store in both collections for backward compatibility
498 self.tool_callbacks
499 .insert(tool_name.clone(), callback.clone());
500 self.tool_callbacks_with_tools.insert(
501 tool_name.clone(),
502 ToolCallbackWithTool {
503 callback: ToolCallbackKind::Text(callback),
504 tool: tool_def,
505 },
506 );
507 self.tools.insert(tool_name, tool);
508 }
509 }
510
511 Ok(())
512 }
513
514 /// Convert MCP tool input schema to Tool parameters format
515 fn convert_mcp_schema_to_parameters(
516 schema: &serde_json::Value,
517 ) -> Option<HashMap<String, serde_json::Value>> {
518 // MCP tools can have various schema formats, we'll try to convert common ones
519 match schema {
520 serde_json::Value::Object(obj) => {
521 let mut params = HashMap::new();
522
523 // If it's a JSON schema object, extract properties
524 if let Some(properties) = obj.get("properties") {
525 if let serde_json::Value::Object(props) = properties {
526 for (key, value) in props {
527 params.insert(key.clone(), value.clone());
528 }
529 }
530 } else {
531 // If it's just a direct object, use it as-is
532 for (key, value) in obj {
533 params.insert(key.clone(), value.clone());
534 }
535 }
536
537 if params.is_empty() {
538 None
539 } else {
540 Some(params)
541 }
542 }
543 _ => {
544 // For non-object schemas, we can't easily convert to parameters
545 None
546 }
547 }
548 }
549
550 /// Remove tools associated with a specific server
551 fn remove_tools_for_server(&mut self, server_id: &str) {
552 let tools_to_remove: Vec<String> = self
553 .tools
554 .iter()
555 .filter(|(_, info)| info.server_id == server_id)
556 .map(|(name, _)| name.clone())
557 .collect();
558
559 for name in tools_to_remove {
560 self.tools.remove(&name);
561 self.tool_callbacks.remove(&name);
562 self.tool_callbacks_with_tools.remove(&name);
563 }
564 }
565
566 /// Register tools for a single server
567 async fn register_tools_for_server(&mut self, server_id: &str) -> Result<()> {
568 let connection = self
569 .servers
570 .get(server_id)
571 .ok_or_else(|| anyhow::anyhow!("Server not connected: {}", server_id))?
572 .clone();
573
574 let server_config = self
575 .config
576 .servers
577 .iter()
578 .find(|s| s.id == server_id)
579 .ok_or_else(|| anyhow::anyhow!("Server config not found for {}", server_id))?
580 .clone();
581
582 let tools = connection.list_tools().await?;
583
584 for tool in tools {
585 let tool_name = if let Some(prefix) = &server_config.tool_prefix {
586 format!("{}_{}", prefix, tool.name)
587 } else {
588 tool.name.clone()
589 };
590
591 // Create tool callback that calls the MCP server with timeout and concurrency controls
592 let connection_clone = Arc::clone(&connection);
593 let original_tool_name = tool.name.clone();
594 let semaphore_clone = Arc::clone(&self.concurrency_semaphore);
595 let timeout_duration = Duration::from_secs(self.config.tool_timeout_secs.unwrap_or(30));
596
597 let callback: Arc<ToolCallback> = Arc::new(move |called_function, _ctx| {
598 let connection = Arc::clone(&connection_clone);
599 let tool_name = original_tool_name.clone();
600 let semaphore = Arc::clone(&semaphore_clone);
601 let arguments: serde_json::Value =
602 serde_json::from_str(&called_function.arguments)?;
603
604 // Use tokio::task::spawn_blocking to handle the async-to-sync bridge
605 let rt = tokio::runtime::Handle::current();
606 std::thread::spawn(move || {
607 rt.block_on(async move {
608 // Acquire semaphore permit for concurrency control
609 let _permit = semaphore
610 .acquire()
611 .await
612 .map_err(|_| anyhow::anyhow!("Failed to acquire concurrency permit"))?;
613
614 // Execute tool call with timeout
615 match tokio::time::timeout(
616 timeout_duration,
617 connection.call_tool(&tool_name, arguments),
618 )
619 .await
620 {
621 Ok(result) => result,
622 Err(_) => Err(anyhow::anyhow!(
623 "Tool call timed out after {} seconds",
624 timeout_duration.as_secs()
625 )),
626 }
627 })
628 })
629 .join()
630 .map_err(|_| anyhow::anyhow!("Tool call thread panicked"))?
631 });
632
633 // Convert MCP tool schema to Tool definition.
634 // Auto-enable strict mode when a schema is available.
635 let parameters = Self::convert_mcp_schema_to_parameters(&tool.input_schema);
636 let function_def = Function {
637 name: tool_name.clone(),
638 description: tool.description.clone(),
639 strict: if parameters.is_some() {
640 Some(true)
641 } else {
642 None
643 },
644 parameters,
645 };
646
647 let tool_def = Tool {
648 tp: ToolType::Function,
649 function: function_def,
650 };
651
652 // Store in both collections for backward compatibility
653 self.tool_callbacks
654 .insert(tool_name.clone(), callback.clone());
655 self.tool_callbacks_with_tools.insert(
656 tool_name.clone(),
657 ToolCallbackWithTool {
658 callback: ToolCallbackKind::Text(callback),
659 tool: tool_def,
660 },
661 );
662 self.tools.insert(tool_name, tool);
663 }
664
665 Ok(())
666 }
667
668 // ==================== Connection Management Methods ====================
669
670 /// Gracefully shutdown all server connections.
671 ///
672 /// Closes all active connections and clears the tools and callbacks.
673 /// The client cannot be used after calling this method without re-initializing.
674 ///
675 /// # Example
676 ///
677 /// ```rust,no_run
678 /// # use hanzo_llm_mcp::{McpClient, McpClientConfig};
679 /// # async fn example() -> anyhow::Result<()> {
680 /// let config = McpClientConfig::default();
681 /// let mut client = McpClient::new(config);
682 /// client.initialize().await?;
683 ///
684 /// // ... use the client ...
685 ///
686 /// // Gracefully shutdown when done
687 /// client.shutdown().await?;
688 /// # Ok(())
689 /// # }
690 /// ```
691 pub async fn shutdown(&mut self) -> Result<()> {
692 // Close all connections
693 for connection in self.servers.values() {
694 let _ = connection.close().await;
695 }
696
697 // Clear all state
698 self.servers.clear();
699 self.tools.clear();
700 self.tool_callbacks.clear();
701 self.tool_callbacks_with_tools.clear();
702
703 Ok(())
704 }
705
706 /// Disconnect a specific server by its ID.
707 ///
708 /// Removes the server from active connections and clears its associated tools.
709 /// Returns an error if the server ID is not found.
710 ///
711 /// # Example
712 ///
713 /// ```rust,no_run
714 /// # use hanzo_llm_mcp::{McpClient, McpClientConfig};
715 /// # async fn example() -> anyhow::Result<()> {
716 /// let config = McpClientConfig::default();
717 /// let mut client = McpClient::new(config);
718 /// client.initialize().await?;
719 ///
720 /// // Disconnect a specific server
721 /// client.disconnect("my_server_id").await?;
722 /// # Ok(())
723 /// # }
724 /// ```
725 pub async fn disconnect(&mut self, id: &str) -> Result<()> {
726 let connection = self
727 .servers
728 .remove(id)
729 .ok_or_else(|| anyhow::anyhow!("Server not connected: {}", id))?;
730
731 connection.close().await?;
732 self.remove_tools_for_server(id);
733
734 Ok(())
735 }
736
737 /// Reconnect to a specific server by its ID.
738 ///
739 /// Re-establishes the connection using the stored configuration.
740 /// Returns an error if the server ID is not in the configuration.
741 ///
742 /// # Example
743 ///
744 /// ```rust,no_run
745 /// # use hanzo_llm_mcp::{McpClient, McpClientConfig};
746 /// # async fn example() -> anyhow::Result<()> {
747 /// let config = McpClientConfig::default();
748 /// let mut client = McpClient::new(config);
749 /// client.initialize().await?;
750 ///
751 /// // Reconnect to a server after it was disconnected or lost connection
752 /// client.reconnect("my_server_id").await?;
753 /// # Ok(())
754 /// # }
755 /// ```
756 pub async fn reconnect(&mut self, id: &str) -> Result<()> {
757 // Find the server config
758 let server_config = self
759 .config
760 .servers
761 .iter()
762 .find(|s| s.id == id)
763 .ok_or_else(|| anyhow::anyhow!("Server config not found: {}", id))?
764 .clone();
765
766 // Close existing connection if any
767 if let Some(connection) = self.servers.remove(id) {
768 let _ = connection.close().await;
769 }
770
771 // Remove old tools for this server
772 self.remove_tools_for_server(id);
773
774 // Create new connection
775 let connection = self.create_connection(&server_config).await?;
776 self.servers.insert(id.to_string(), connection);
777
778 // Re-register tools if auto_register_tools is enabled
779 if self.config.auto_register_tools {
780 self.register_tools_for_server(id).await?;
781 }
782
783 Ok(())
784 }
785
786 /// Check if a specific server is currently connected.
787 ///
788 /// # Example
789 ///
790 /// ```rust,no_run
791 /// # use hanzo_llm_mcp::{McpClient, McpClientConfig};
792 /// # async fn example() -> anyhow::Result<()> {
793 /// let config = McpClientConfig::default();
794 /// let mut client = McpClient::new(config);
795 /// client.initialize().await?;
796 ///
797 /// if client.is_connected("my_server_id") {
798 /// println!("Server is connected");
799 /// }
800 /// # Ok(())
801 /// # }
802 /// ```
803 pub fn is_connected(&self, id: &str) -> bool {
804 self.servers.contains_key(id)
805 }
806
807 /// Dynamically add and connect a new server at runtime.
808 ///
809 /// Adds the server configuration and establishes the connection.
810 /// If auto_register_tools is enabled, discovers and registers the server's tools.
811 ///
812 /// # Example
813 ///
814 /// ```rust,no_run
815 /// # use hanzo_llm_mcp::{McpClient, McpClientConfig, McpServerConfig, McpServerSource};
816 /// # async fn example() -> anyhow::Result<()> {
817 /// let config = McpClientConfig::default();
818 /// let mut client = McpClient::new(config);
819 /// client.initialize().await?;
820 ///
821 /// // Add a new server dynamically
822 /// let new_server = McpServerConfig {
823 /// id: "new_server".to_string(),
824 /// name: "New MCP Server".to_string(),
825 /// source: McpServerSource::Http {
826 /// url: "https://api.example.com/mcp".to_string(),
827 /// timeout_secs: Some(30),
828 /// headers: None,
829 /// },
830 /// ..Default::default()
831 /// };
832 /// client.add_server(new_server).await?;
833 /// # Ok(())
834 /// # }
835 /// ```
836 pub async fn add_server(&mut self, config: McpServerConfig) -> Result<()> {
837 let id = config.id.clone();
838
839 // Check if server already exists
840 if self.servers.contains_key(&id) {
841 return Err(anyhow::anyhow!("Server already exists: {}", id));
842 }
843
844 // Create connection
845 let connection = self.create_connection(&config).await?;
846 self.servers.insert(id.clone(), connection);
847
848 // Store config
849 self.config.servers.push(config);
850
851 // Register tools if enabled
852 if self.config.auto_register_tools {
853 self.register_tools_for_server(&id).await?;
854 }
855
856 Ok(())
857 }
858
859 /// Disconnect and remove a server from the client.
860 ///
861 /// Closes the connection and removes the server from the configuration.
862 ///
863 /// # Example
864 ///
865 /// ```rust,no_run
866 /// # use hanzo_llm_mcp::{McpClient, McpClientConfig};
867 /// # async fn example() -> anyhow::Result<()> {
868 /// let config = McpClientConfig::default();
869 /// let mut client = McpClient::new(config);
870 /// client.initialize().await?;
871 ///
872 /// // Remove a server completely
873 /// client.remove_server("my_server_id").await?;
874 /// # Ok(())
875 /// # }
876 /// ```
877 pub async fn remove_server(&mut self, id: &str) -> Result<()> {
878 // Disconnect first
879 if let Some(connection) = self.servers.remove(id) {
880 let _ = connection.close().await;
881 }
882
883 // Remove tools
884 self.remove_tools_for_server(id);
885
886 // Remove from config
887 self.config.servers.retain(|s| s.id != id);
888
889 Ok(())
890 }
891
892 // ==================== Tool Management Methods ====================
893
894 /// Re-discover tools from all connected servers.
895 ///
896 /// Clears existing tool registrations and re-queries all servers.
897 /// Useful for long-running clients when servers update their tools.
898 ///
899 /// # Example
900 ///
901 /// ```rust,no_run
902 /// # use hanzo_llm_mcp::{McpClient, McpClientConfig};
903 /// # async fn example() -> anyhow::Result<()> {
904 /// let config = McpClientConfig::default();
905 /// let mut client = McpClient::new(config);
906 /// client.initialize().await?;
907 ///
908 /// // Refresh tools after servers have been updated
909 /// client.refresh_tools().await?;
910 /// # Ok(())
911 /// # }
912 /// ```
913 pub async fn refresh_tools(&mut self) -> Result<()> {
914 // Clear all existing tools
915 self.tools.clear();
916 self.tool_callbacks.clear();
917 self.tool_callbacks_with_tools.clear();
918
919 // Re-discover tools from all servers
920 self.discover_and_register_tools().await
921 }
922
923 /// Get a specific tool by name.
924 ///
925 /// Returns `None` if no tool with the given name is registered.
926 ///
927 /// # Example
928 ///
929 /// ```rust,no_run
930 /// # use hanzo_llm_mcp::{McpClient, McpClientConfig};
931 /// # async fn example() -> anyhow::Result<()> {
932 /// let config = McpClientConfig::default();
933 /// let mut client = McpClient::new(config);
934 /// client.initialize().await?;
935 ///
936 /// if let Some(tool) = client.get_tool("web_search") {
937 /// println!("Found tool: {:?}", tool.description);
938 /// }
939 /// # Ok(())
940 /// # }
941 /// ```
942 pub fn get_tool(&self, name: &str) -> Option<&McpToolInfo> {
943 self.tools.get(name)
944 }
945
946 /// Check if a tool with the given name exists.
947 ///
948 /// # Example
949 ///
950 /// ```rust,no_run
951 /// # use hanzo_llm_mcp::{McpClient, McpClientConfig};
952 /// # async fn example() -> anyhow::Result<()> {
953 /// let config = McpClientConfig::default();
954 /// let mut client = McpClient::new(config);
955 /// client.initialize().await?;
956 ///
957 /// if client.has_tool("web_search") {
958 /// println!("Tool is available");
959 /// }
960 /// # Ok(())
961 /// # }
962 /// ```
963 pub fn has_tool(&self, name: &str) -> bool {
964 self.tools.contains_key(name)
965 }
966
967 /// Directly call a tool by name with the given arguments.
968 ///
969 /// This bypasses the callback system and calls the tool directly
970 /// on the appropriate server with timeout and concurrency controls.
971 ///
972 /// # Example
973 ///
974 /// ```rust,no_run
975 /// # use hanzo_llm_mcp::{McpClient, McpClientConfig};
976 /// # use serde_json::json;
977 /// # async fn example() -> anyhow::Result<()> {
978 /// let config = McpClientConfig::default();
979 /// let mut client = McpClient::new(config);
980 /// client.initialize().await?;
981 ///
982 /// let result = client.call_tool("web_search", json!({"query": "rust programming"})).await?;
983 /// println!("Result: {}", result);
984 /// # Ok(())
985 /// # }
986 /// ```
987 pub async fn call_tool(&self, name: &str, arguments: serde_json::Value) -> Result<String> {
988 let tool_info = self
989 .tools
990 .get(name)
991 .ok_or_else(|| anyhow::anyhow!("Tool not found: {}", name))?;
992
993 let connection = self
994 .servers
995 .get(&tool_info.server_id)
996 .ok_or_else(|| anyhow::anyhow!("Server not connected: {}", tool_info.server_id))?;
997
998 // Acquire semaphore permit for concurrency control
999 let _permit = self
1000 .concurrency_semaphore
1001 .acquire()
1002 .await
1003 .map_err(|_| anyhow::anyhow!("Failed to acquire concurrency permit"))?;
1004
1005 let timeout_duration = Duration::from_secs(self.config.tool_timeout_secs.unwrap_or(30));
1006
1007 // Execute tool call with timeout
1008 match tokio::time::timeout(
1009 timeout_duration,
1010 connection.call_tool(&tool_info.name, arguments),
1011 )
1012 .await
1013 {
1014 Ok(result) => result,
1015 Err(_) => Err(anyhow::anyhow!(
1016 "Tool call timed out after {} seconds",
1017 timeout_duration.as_secs()
1018 )),
1019 }
1020 }
1021
1022 /// Get the total number of registered tools.
1023 ///
1024 /// # Example
1025 ///
1026 /// ```rust,no_run
1027 /// # use hanzo_llm_mcp::{McpClient, McpClientConfig};
1028 /// # async fn example() -> anyhow::Result<()> {
1029 /// let config = McpClientConfig::default();
1030 /// let mut client = McpClient::new(config);
1031 /// client.initialize().await?;
1032 ///
1033 /// println!("Total tools: {}", client.tool_count());
1034 /// # Ok(())
1035 /// # }
1036 /// ```
1037 pub fn tool_count(&self) -> usize {
1038 self.tools.len()
1039 }
1040
1041 // ==================== Status / Convenience Methods ====================
1042
1043 /// Get the number of connected servers.
1044 ///
1045 /// # Example
1046 ///
1047 /// ```rust,no_run
1048 /// # use hanzo_llm_mcp::{McpClient, McpClientConfig};
1049 /// # async fn example() -> anyhow::Result<()> {
1050 /// let config = McpClientConfig::default();
1051 /// let mut client = McpClient::new(config);
1052 /// client.initialize().await?;
1053 ///
1054 /// println!("Connected servers: {}", client.server_count());
1055 /// # Ok(())
1056 /// # }
1057 /// ```
1058 pub fn server_count(&self) -> usize {
1059 self.servers.len()
1060 }
1061
1062 /// Get a list of all connected server IDs.
1063 ///
1064 /// # Example
1065 ///
1066 /// ```rust,no_run
1067 /// # use hanzo_llm_mcp::{McpClient, McpClientConfig};
1068 /// # async fn example() -> anyhow::Result<()> {
1069 /// let config = McpClientConfig::default();
1070 /// let mut client = McpClient::new(config);
1071 /// client.initialize().await?;
1072 ///
1073 /// for id in client.server_ids() {
1074 /// println!("Server: {}", id);
1075 /// }
1076 /// # Ok(())
1077 /// # }
1078 /// ```
1079 pub fn server_ids(&self) -> Vec<&str> {
1080 self.servers.keys().map(|s| s.as_str()).collect()
1081 }
1082
1083 /// Ping all connected servers and return results per server.
1084 ///
1085 /// Returns a map of server ID to ping result. Useful for health monitoring.
1086 ///
1087 /// # Example
1088 ///
1089 /// ```rust,no_run
1090 /// # use hanzo_llm_mcp::{McpClient, McpClientConfig};
1091 /// # async fn example() -> anyhow::Result<()> {
1092 /// let config = McpClientConfig::default();
1093 /// let mut client = McpClient::new(config);
1094 /// client.initialize().await?;
1095 ///
1096 /// let results = client.ping_all().await;
1097 /// for (server_id, result) in results {
1098 /// match result {
1099 /// Ok(()) => println!("{}: healthy", server_id),
1100 /// Err(e) => println!("{}: unhealthy - {}", server_id, e),
1101 /// }
1102 /// }
1103 /// # Ok(())
1104 /// # }
1105 /// ```
1106 pub async fn ping_all(&self) -> HashMap<String, Result<()>> {
1107 let mut results = HashMap::new();
1108
1109 for (server_id, connection) in &self.servers {
1110 let result = connection.ping().await;
1111 results.insert(server_id.clone(), result);
1112 }
1113
1114 results
1115 }
1116
1117 // ==================== Resource Access Methods ====================
1118
1119 /// List resources from all connected servers.
1120 ///
1121 /// Returns a vector of (server_id, resource) tuples.
1122 ///
1123 /// # Example
1124 ///
1125 /// ```rust,no_run
1126 /// # use hanzo_llm_mcp::{McpClient, McpClientConfig};
1127 /// # async fn example() -> anyhow::Result<()> {
1128 /// let config = McpClientConfig::default();
1129 /// let mut client = McpClient::new(config);
1130 /// client.initialize().await?;
1131 ///
1132 /// let resources = client.list_all_resources().await?;
1133 /// for (server_id, resource) in resources {
1134 /// println!("Server {}: {:?}", server_id, resource);
1135 /// }
1136 /// # Ok(())
1137 /// # }
1138 /// ```
1139 pub async fn list_all_resources(&self) -> Result<Vec<(String, Resource)>> {
1140 let mut all_resources = Vec::new();
1141
1142 for (server_id, connection) in &self.servers {
1143 match connection.list_resources().await {
1144 Ok(resources) => {
1145 for resource in resources {
1146 all_resources.push((server_id.clone(), resource));
1147 }
1148 }
1149 Err(e) => {
1150 // Log error but continue with other servers
1151 warn!("Failed to list resources from server {}: {}", server_id, e);
1152 }
1153 }
1154 }
1155
1156 Ok(all_resources)
1157 }
1158}
1159
1160impl Drop for McpClient {
1161 fn drop(&mut self) {
1162 // Try to get the tokio runtime handle
1163 if let Ok(handle) = tokio::runtime::Handle::try_current() {
1164 let servers = std::mem::take(&mut self.servers);
1165 handle.spawn(async move {
1166 for (_, connection) in servers {
1167 let _ = connection.close().await;
1168 }
1169 });
1170 }
1171 }
1172}
1173
1174/// HTTP-based MCP server connection
1175pub struct HttpMcpConnection {
1176 server_id: String,
1177 server_name: String,
1178 transport: Arc<dyn McpTransport>,
1179}
1180
1181impl HttpMcpConnection {
1182 pub async fn new(
1183 server_id: String,
1184 server_name: String,
1185 url: String,
1186 timeout_secs: Option<u64>,
1187 headers: Option<HashMap<String, String>>,
1188 ) -> Result<Self> {
1189 let transport = HttpTransport::new(url, timeout_secs, headers)?;
1190
1191 let connection = Self {
1192 server_id,
1193 server_name,
1194 transport: Arc::new(transport),
1195 };
1196
1197 // Initialize the connection
1198 connection.initialize().await?;
1199
1200 Ok(connection)
1201 }
1202
1203 async fn initialize(&self) -> Result<()> {
1204 initialize_transport(&self.transport).await
1205 }
1206}
1207
1208#[async_trait::async_trait]
1209impl McpServerConnection for HttpMcpConnection {
1210 fn server_id(&self) -> &str {
1211 &self.server_id
1212 }
1213
1214 fn server_name(&self) -> &str {
1215 &self.server_name
1216 }
1217
1218 async fn list_tools(&self) -> Result<Vec<McpToolInfo>> {
1219 list_tools_via_transport(&self.transport, &self.server_id, &self.server_name).await
1220 }
1221
1222 async fn call_tool(&self, name: &str, arguments: Value) -> Result<String> {
1223 call_tool_via_transport(&self.transport, name, arguments).await
1224 }
1225
1226 async fn list_resources(&self) -> Result<Vec<Resource>> {
1227 list_resources_via_transport(&self.transport).await
1228 }
1229
1230 async fn read_resource(&self, uri: &str) -> Result<String> {
1231 read_resource_via_transport(&self.transport, uri).await
1232 }
1233
1234 async fn ping(&self) -> Result<()> {
1235 ping_transport(&self.transport).await
1236 }
1237
1238 async fn close(&self) -> Result<()> {
1239 self.transport.close().await
1240 }
1241}
1242
1243/// Process-based MCP server connection
1244pub struct ProcessMcpConnection {
1245 server_id: String,
1246 server_name: String,
1247 transport: Arc<dyn McpTransport>,
1248}
1249
1250impl ProcessMcpConnection {
1251 pub async fn new(
1252 server_id: String,
1253 server_name: String,
1254 command: String,
1255 args: Vec<String>,
1256 work_dir: Option<String>,
1257 env: Option<HashMap<String, String>>,
1258 ) -> Result<Self> {
1259 let transport = ProcessTransport::new(command, args, work_dir, env).await?;
1260
1261 let connection = Self {
1262 server_id,
1263 server_name,
1264 transport: Arc::new(transport),
1265 };
1266
1267 // Initialize the connection
1268 connection.initialize().await?;
1269
1270 Ok(connection)
1271 }
1272
1273 async fn initialize(&self) -> Result<()> {
1274 initialize_transport(&self.transport).await
1275 }
1276}
1277
1278#[async_trait::async_trait]
1279impl McpServerConnection for ProcessMcpConnection {
1280 fn server_id(&self) -> &str {
1281 &self.server_id
1282 }
1283
1284 fn server_name(&self) -> &str {
1285 &self.server_name
1286 }
1287
1288 async fn list_tools(&self) -> Result<Vec<McpToolInfo>> {
1289 list_tools_via_transport(&self.transport, &self.server_id, &self.server_name).await
1290 }
1291
1292 async fn call_tool(&self, name: &str, arguments: Value) -> Result<String> {
1293 call_tool_via_transport(&self.transport, name, arguments).await
1294 }
1295
1296 async fn list_resources(&self) -> Result<Vec<Resource>> {
1297 list_resources_via_transport(&self.transport).await
1298 }
1299
1300 async fn read_resource(&self, uri: &str) -> Result<String> {
1301 read_resource_via_transport(&self.transport, uri).await
1302 }
1303
1304 async fn ping(&self) -> Result<()> {
1305 ping_transport(&self.transport).await
1306 }
1307
1308 async fn close(&self) -> Result<()> {
1309 self.transport.close().await
1310 }
1311}
1312
1313/// WebSocket-based MCP server connection
1314pub struct WebSocketMcpConnection {
1315 server_id: String,
1316 server_name: String,
1317 transport: Arc<dyn McpTransport>,
1318}
1319
1320impl WebSocketMcpConnection {
1321 pub async fn new(
1322 server_id: String,
1323 server_name: String,
1324 url: String,
1325 timeout_secs: Option<u64>,
1326 headers: Option<HashMap<String, String>>,
1327 ) -> Result<Self> {
1328 let transport = WebSocketTransport::new(url, timeout_secs, headers).await?;
1329
1330 let connection = Self {
1331 server_id,
1332 server_name,
1333 transport: Arc::new(transport),
1334 };
1335
1336 // Initialize the connection
1337 connection.initialize().await?;
1338
1339 Ok(connection)
1340 }
1341
1342 async fn initialize(&self) -> Result<()> {
1343 initialize_transport(&self.transport).await
1344 }
1345}
1346
1347#[async_trait::async_trait]
1348impl McpServerConnection for WebSocketMcpConnection {
1349 fn server_id(&self) -> &str {
1350 &self.server_id
1351 }
1352
1353 fn server_name(&self) -> &str {
1354 &self.server_name
1355 }
1356
1357 async fn list_tools(&self) -> Result<Vec<McpToolInfo>> {
1358 list_tools_via_transport(&self.transport, &self.server_id, &self.server_name).await
1359 }
1360
1361 async fn call_tool(&self, name: &str, arguments: Value) -> Result<String> {
1362 call_tool_via_transport(&self.transport, name, arguments).await
1363 }
1364
1365 async fn list_resources(&self) -> Result<Vec<Resource>> {
1366 list_resources_via_transport(&self.transport).await
1367 }
1368
1369 async fn read_resource(&self, uri: &str) -> Result<String> {
1370 read_resource_via_transport(&self.transport, uri).await
1371 }
1372
1373 async fn ping(&self) -> Result<()> {
1374 ping_transport(&self.transport).await
1375 }
1376
1377 async fn close(&self) -> Result<()> {
1378 self.transport.close().await
1379 }
1380}