model_context_protocol/
hub.rs1use serde_json::Value;
7use std::sync::Arc;
8use std::time::Duration;
9
10use crate::circuit_breaker::CircuitBreakerStats;
11use crate::hub_common::HubConnections;
12use crate::protocol::McpToolDefinition;
13use crate::transport::{McpServerConnectionConfig, McpTransport, McpTransportError};
14
15pub struct McpHub {
42 connections: HubConnections,
44 discovery_timeout: Duration,
46}
47
48impl Default for McpHub {
49 fn default() -> Self {
50 Self::new()
51 }
52}
53
54impl McpHub {
55 pub fn new() -> Self {
57 Self {
58 connections: HubConnections::new(),
59 discovery_timeout: Duration::from_secs(30),
60 }
61 }
62
63 pub fn with_discovery_timeout(timeout: Duration) -> Self {
65 Self {
66 connections: HubConnections::new(),
67 discovery_timeout: timeout,
68 }
69 }
70
71 pub async fn connect(
79 &self,
80 config: McpServerConnectionConfig,
81 ) -> Result<Arc<dyn McpTransport>, McpTransportError> {
82 let conn = self.connections.connect(config).await?;
83 conn.get_transport()
84 .await
85 .ok_or(McpTransportError::ConnectionClosed)
86 }
87
88 pub async fn call_tool(&self, name: &str, args: Value) -> Result<Value, McpTransportError> {
93 self.connections.call_tool(name, args).await
94 }
95
96 pub async fn list_tools(&self) -> Result<Vec<(String, McpToolDefinition)>, McpTransportError> {
98 Ok(self.connections.list_tools())
99 }
100
101 pub async fn list_all_tools(&self) -> Result<Vec<McpToolDefinition>, McpTransportError> {
103 Ok(self.connections.list_tool_definitions())
104 }
105
106 pub async fn discover_tools_parallel(
110 &self,
111 ) -> Result<Vec<(String, McpToolDefinition)>, McpTransportError> {
112 self.connections
113 .discover_tools_parallel(self.discovery_timeout)
114 .await
115 }
116
117 pub async fn refresh_tool_cache(&self) -> Result<(), McpTransportError> {
119 self.connections
120 .refresh_tools_parallel(self.discovery_timeout)
121 .await
122 }
123
124 pub async fn shutdown_all(&self) -> Result<(), McpTransportError> {
126 let mut errors = Vec::new();
127
128 for (server_name, conn) in self.connections.iter() {
129 if let Some(transport) = conn.get_transport().await {
130 if let Err(e) = transport.shutdown().await {
131 errors.push(format!("{}: {}", server_name, e));
132 }
133 }
134 }
135 self.connections.clear();
136
137 if errors.is_empty() {
138 Ok(())
139 } else {
140 Err(McpTransportError::TransportError(errors.join("; ")))
141 }
142 }
143
144 pub async fn disconnect(&self, server_name: &str) -> Result<(), McpTransportError> {
146 let conn = self
147 .connections
148 .remove(server_name)
149 .ok_or_else(|| McpTransportError::ServerNotFound(server_name.to_string()))?;
150
151 self.connections.clear_tools_for_server(server_name);
152
153 if let Some(transport) = conn.get_transport().await {
154 transport.shutdown().await?;
155 }
156 Ok(())
157 }
158
159 pub fn list_servers(&self) -> Vec<String> {
161 self.connections.list_servers()
162 }
163
164 pub fn is_connected(&self, server_name: &str) -> bool {
166 self.connections.is_connected(server_name)
167 }
168
169 pub async fn health_check(&self) -> Vec<(String, bool)> {
171 self.connections.health_check().await
172 }
173
174 pub fn server_for_tool(&self, tool_name: &str) -> Option<String> {
176 self.connections.server_for_tool(tool_name)
177 }
178
179 pub fn circuit_breaker_stats(&self, server_name: &str) -> Option<CircuitBreakerStats> {
181 self.connections.circuit_breaker_stats(server_name)
182 }
183
184 pub fn reset_circuit_breaker(&self, server_name: &str) {
186 self.connections.reset_circuit_breaker(server_name);
187 }
188}
189
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_hub_creation() {
        let hub = McpHub::new();
        let servers = hub.list_servers();
        assert!(servers.is_empty());
    }

    /// `new` and `Default` both use the 30-second discovery timeout.
    #[tokio::test]
    async fn test_default_discovery_timeout() {
        assert_eq!(McpHub::new().discovery_timeout, Duration::from_secs(30));
        assert_eq!(McpHub::default().discovery_timeout, Duration::from_secs(30));
    }

    /// `with_discovery_timeout` stores the caller's timeout and starts empty.
    #[tokio::test]
    async fn test_custom_discovery_timeout() {
        let hub = McpHub::with_discovery_timeout(Duration::from_millis(250));
        assert_eq!(hub.discovery_timeout, Duration::from_millis(250));
        assert!(hub.list_servers().is_empty());
    }

    #[tokio::test]
    async fn test_hub_unknown_tool() {
        let hub = McpHub::new();

        let result = hub
            .call_tool("nonexistent_tool", serde_json::json!({}))
            .await;
        assert!(matches!(result, Err(McpTransportError::UnknownTool(_))));
    }

    /// Disconnecting a server that was never registered reports its name.
    #[tokio::test]
    async fn test_disconnect_unknown_server() {
        let hub = McpHub::new();

        let result = hub.disconnect("missing").await;
        assert!(matches!(
            &result,
            Err(McpTransportError::ServerNotFound(name)) if name == "missing"
        ));
    }

    /// Shutting down a hub with no connections succeeds and leaves it empty.
    #[tokio::test]
    async fn test_shutdown_all_on_empty_hub() {
        let hub = McpHub::new();

        assert!(hub.shutdown_all().await.is_ok());
        assert!(hub.list_servers().is_empty());
    }

    #[test]
    fn test_connection_config() {
        let config =
            McpServerConnectionConfig::stdio("test", "node", vec!["server.js".to_string()])
                .with_timeout(60);

        assert_eq!(config.name, "test");
        assert_eq!(config.timeout_secs, 60);
    }
}
220}