1use crate::error::{Error, Result};
4use crate::metadata::ToolMetadata;
5use std::collections::HashMap;
6use std::sync::Arc;
7use tokio::sync::RwLock;
8use tracing::{debug, info};
9
10#[derive(Debug, Clone)]
12pub struct ServerConnection {
13 pub id: String,
14 pub name: String,
15 pub is_connected: bool,
16 pub tools: Vec<ToolMetadata>,
17}
18
19#[derive(Debug, Clone)]
21pub struct MCPClient {
22 connections: Arc<RwLock<HashMap<String, ServerConnection>>>,
23}
24
25impl MCPClient {
26 pub fn new() -> Self {
28 Self {
29 connections: Arc::new(RwLock::new(HashMap::new())),
30 }
31 }
32
33 pub fn with_timeout(_timeout_ms: u64) -> Self {
35 Self {
36 connections: Arc::new(RwLock::new(HashMap::new())),
37 }
38 }
39
40 pub async fn connect(&self, server_id: &str, server_name: &str) -> Result<()> {
52 debug!("Connecting to MCP server: {} ({})", server_id, server_name);
53
54 let connection = ServerConnection {
56 id: server_id.to_string(),
57 name: server_name.to_string(),
58 is_connected: true,
59 tools: Vec::new(),
60 };
61
62 let mut connections = self.connections.write().await;
63 connections.insert(server_id.to_string(), connection);
64
65 info!("Connected to MCP server: {}", server_id);
66 Ok(())
67 }
68
69 pub async fn disconnect(&self, server_id: &str) -> Result<()> {
77 debug!("Disconnecting from MCP server: {}", server_id);
78
79 let mut connections = self.connections.write().await;
80 if let Some(conn) = connections.get_mut(server_id) {
81 conn.is_connected = false;
82 }
83
84 info!("Disconnected from MCP server: {}", server_id);
85 Ok(())
86 }
87
88 pub async fn discover_servers(&self) -> Result<Vec<String>> {
93 debug!("Discovering MCP servers");
94
95 let connections = self.connections.read().await;
96 let servers: Vec<String> = connections.keys().cloned().collect();
97
98 info!("Discovered {} MCP servers", servers.len());
99 Ok(servers)
100 }
101
102 pub async fn discover_tools(&self, server_id: &str) -> Result<Vec<ToolMetadata>> {
113 debug!("Discovering tools from server: {}", server_id);
114
115 let connections = self.connections.read().await;
116 let connection = connections
117 .get(server_id)
118 .ok_or_else(|| Error::ConnectionError(format!("Server not connected: {}", server_id)))?;
119
120 if !connection.is_connected {
121 return Err(Error::ConnectionError(format!(
122 "Server not connected: {}",
123 server_id
124 )));
125 }
126
127 let tools = connection.tools.clone();
128 info!(
129 "Discovered {} tools from server: {}",
130 tools.len(),
131 server_id
132 );
133 Ok(tools)
134 }
135
136 pub async fn register_tools(&self, server_id: &str, tools: Vec<ToolMetadata>) -> Result<()> {
148 debug!(
149 "Registering {} tools from server: {}",
150 tools.len(),
151 server_id
152 );
153
154 let mut connections = self.connections.write().await;
155 let connection = connections
156 .get_mut(server_id)
157 .ok_or_else(|| Error::ConnectionError(format!("Server not connected: {}", server_id)))?;
158
159 connection.tools = tools;
160 info!("Registered tools for server: {}", server_id);
161 Ok(())
162 }
163
164 pub async fn get_connected_servers(&self) -> Result<Vec<ServerConnection>> {
169 let connections = self.connections.read().await;
170 let servers: Vec<ServerConnection> = connections
171 .values()
172 .filter(|c| c.is_connected)
173 .cloned()
174 .collect();
175
176 Ok(servers)
177 }
178
179 pub async fn get_server(&self, server_id: &str) -> Result<Option<ServerConnection>> {
187 let connections = self.connections.read().await;
188 Ok(connections.get(server_id).cloned())
189 }
190
191 pub async fn is_connected(&self, server_id: &str) -> bool {
199 let connections = self.connections.read().await;
200 connections
201 .get(server_id)
202 .map(|c| c.is_connected)
203 .unwrap_or(false)
204 }
205
206 pub async fn connected_server_count(&self) -> usize {
208 let connections = self.connections.read().await;
209 connections.values().filter(|c| c.is_connected).count()
210 }
211
212 pub async fn get_all_tools(&self) -> Result<Vec<ToolMetadata>> {
217 let connections = self.connections.read().await;
218 let mut all_tools = Vec::new();
219
220 for connection in connections.values() {
221 if connection.is_connected {
222 all_tools.extend(connection.tools.clone());
223 }
224 }
225
226 Ok(all_tools)
227 }
228}
229
230impl Default for MCPClient {
231 fn default() -> Self {
232 Self::new()
233 }
234}
235
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds tool metadata attributed to `server1` for use in the tests below.
    fn sample_tool(id: &str, name: &str, description: &str) -> ToolMetadata {
        use crate::metadata::ToolSource;

        ToolMetadata {
            id: id.to_string(),
            name: name.to_string(),
            description: description.to_string(),
            category: "test".to_string(),
            parameters: vec![],
            return_type: "string".to_string(),
            source: ToolSource::Mcp("server1".to_string()),
            server_id: Some("server1".to_string()),
        }
    }

    // A fresh client starts with no connected servers.
    #[tokio::test]
    async fn test_create_client() {
        let client = MCPClient::new();
        assert_eq!(client.connected_server_count().await, 0);
    }

    // Connecting succeeds and the server is reported as connected.
    #[tokio::test]
    async fn test_connect_server() {
        let client = MCPClient::new();
        let outcome = client.connect("server1", "Test Server").await;
        assert!(outcome.is_ok());
        assert!(client.is_connected("server1").await);
    }

    // Disconnecting flips the connected flag back to false.
    #[tokio::test]
    async fn test_disconnect_server() {
        let client = MCPClient::new();
        client.connect("server1", "Test Server").await.unwrap();
        assert!(client.is_connected("server1").await);

        let outcome = client.disconnect("server1").await;
        assert!(outcome.is_ok());
        assert!(!client.is_connected("server1").await);
    }

    // Every connected server id shows up in discovery.
    #[tokio::test]
    async fn test_discover_servers() {
        let client = MCPClient::new();
        client.connect("server1", "Server 1").await.unwrap();
        client.connect("server2", "Server 2").await.unwrap();

        let ids = client.discover_servers().await.unwrap();
        assert_eq!(ids.len(), 2);
    }

    // Tools registered for a server can be read back from it.
    #[tokio::test]
    async fn test_register_and_discover_tools() {
        let client = MCPClient::new();
        client.connect("server1", "Test Server").await.unwrap();

        let tool = sample_tool("test-tool", "Test Tool", "A test tool");
        client.register_tools("server1", vec![tool]).await.unwrap();

        let found = client.discover_tools("server1").await.unwrap();
        assert_eq!(found.len(), 1);
        assert_eq!(found[0].id, "test-tool");
    }

    // Both connected servers are returned.
    #[tokio::test]
    async fn test_get_connected_servers() {
        let client = MCPClient::new();
        client.connect("server1", "Server 1").await.unwrap();
        client.connect("server2", "Server 2").await.unwrap();

        let connected = client.get_connected_servers().await.unwrap();
        assert_eq!(connected.len(), 2);
    }

    // Aggregation returns every tool registered on a connected server.
    #[tokio::test]
    async fn test_get_all_tools() {
        let client = MCPClient::new();
        client.connect("server1", "Server 1").await.unwrap();

        let first = sample_tool("tool1", "Tool 1", "Tool 1");
        let second = sample_tool("tool2", "Tool 2", "Tool 2");
        client
            .register_tools("server1", vec![first, second])
            .await
            .unwrap();

        let combined = client.get_all_tools().await.unwrap();
        assert_eq!(combined.len(), 2);
    }

    // Asking an unknown server for tools is an error.
    #[tokio::test]
    async fn test_discover_tools_not_connected() {
        let client = MCPClient::new();
        let outcome = client.discover_tools("nonexistent").await;
        assert!(outcome.is_err());
    }
}