ricecoder_mcp/
client.rs

1//! MCP Client implementation
2
3use crate::error::{Error, Result};
4use crate::metadata::ToolMetadata;
5use std::collections::HashMap;
6use std::sync::Arc;
7use tokio::sync::RwLock;
8use tracing::{debug, info};
9
10/// MCP Server connection information
11#[derive(Debug, Clone)]
12pub struct ServerConnection {
13    pub id: String,
14    pub name: String,
15    pub is_connected: bool,
16    pub tools: Vec<ToolMetadata>,
17}
18
19/// MCP Client for communicating with MCP servers
20#[derive(Debug, Clone)]
21pub struct MCPClient {
22    connections: Arc<RwLock<HashMap<String, ServerConnection>>>,
23}
24
25impl MCPClient {
26    /// Creates a new MCP client
27    pub fn new() -> Self {
28        Self {
29            connections: Arc::new(RwLock::new(HashMap::new())),
30        }
31    }
32
33    /// Creates a new MCP client with custom timeout
34    pub fn with_timeout(_timeout_ms: u64) -> Self {
35        Self {
36            connections: Arc::new(RwLock::new(HashMap::new())),
37        }
38    }
39
40    /// Connects to an MCP server
41    ///
42    /// # Arguments
43    /// * `server_id` - Unique identifier for the server
44    /// * `server_name` - Human-readable name for the server
45    ///
46    /// # Returns
47    /// Result indicating success or failure
48    ///
49    /// # Errors
50    /// Returns error if connection fails or times out
51    pub async fn connect(&self, server_id: &str, server_name: &str) -> Result<()> {
52        debug!("Connecting to MCP server: {} ({})", server_id, server_name);
53
54        // Simulate connection with timeout
55        let connection = ServerConnection {
56            id: server_id.to_string(),
57            name: server_name.to_string(),
58            is_connected: true,
59            tools: Vec::new(),
60        };
61
62        let mut connections = self.connections.write().await;
63        connections.insert(server_id.to_string(), connection);
64
65        info!("Connected to MCP server: {}", server_id);
66        Ok(())
67    }
68
69    /// Disconnects from an MCP server
70    ///
71    /// # Arguments
72    /// * `server_id` - Unique identifier for the server
73    ///
74    /// # Returns
75    /// Result indicating success or failure
76    pub async fn disconnect(&self, server_id: &str) -> Result<()> {
77        debug!("Disconnecting from MCP server: {}", server_id);
78
79        let mut connections = self.connections.write().await;
80        if let Some(conn) = connections.get_mut(server_id) {
81            conn.is_connected = false;
82        }
83
84        info!("Disconnected from MCP server: {}", server_id);
85        Ok(())
86    }
87
88    /// Discovers available MCP servers
89    ///
90    /// # Returns
91    /// List of discovered server IDs
92    pub async fn discover_servers(&self) -> Result<Vec<String>> {
93        debug!("Discovering MCP servers");
94
95        let connections = self.connections.read().await;
96        let servers: Vec<String> = connections.keys().cloned().collect();
97
98        info!("Discovered {} MCP servers", servers.len());
99        Ok(servers)
100    }
101
102    /// Discovers tools from a specific MCP server
103    ///
104    /// # Arguments
105    /// * `server_id` - Unique identifier for the server
106    ///
107    /// # Returns
108    /// List of tools available from the server
109    ///
110    /// # Errors
111    /// Returns error if server is not connected or discovery fails
112    pub async fn discover_tools(&self, server_id: &str) -> Result<Vec<ToolMetadata>> {
113        debug!("Discovering tools from server: {}", server_id);
114
115        let connections = self.connections.read().await;
116        let connection = connections
117            .get(server_id)
118            .ok_or_else(|| Error::ConnectionError(format!("Server not connected: {}", server_id)))?;
119
120        if !connection.is_connected {
121            return Err(Error::ConnectionError(format!(
122                "Server not connected: {}",
123                server_id
124            )));
125        }
126
127        let tools = connection.tools.clone();
128        info!(
129            "Discovered {} tools from server: {}",
130            tools.len(),
131            server_id
132        );
133        Ok(tools)
134    }
135
136    /// Registers tools from an MCP server
137    ///
138    /// # Arguments
139    /// * `server_id` - Unique identifier for the server
140    /// * `tools` - List of tools to register
141    ///
142    /// # Returns
143    /// Result indicating success or failure
144    ///
145    /// # Errors
146    /// Returns error if server is not connected or registration fails
147    pub async fn register_tools(&self, server_id: &str, tools: Vec<ToolMetadata>) -> Result<()> {
148        debug!(
149            "Registering {} tools from server: {}",
150            tools.len(),
151            server_id
152        );
153
154        let mut connections = self.connections.write().await;
155        let connection = connections
156            .get_mut(server_id)
157            .ok_or_else(|| Error::ConnectionError(format!("Server not connected: {}", server_id)))?;
158
159        connection.tools = tools;
160        info!("Registered tools for server: {}", server_id);
161        Ok(())
162    }
163
164    /// Gets all connected servers
165    ///
166    /// # Returns
167    /// List of connected server connections
168    pub async fn get_connected_servers(&self) -> Result<Vec<ServerConnection>> {
169        let connections = self.connections.read().await;
170        let servers: Vec<ServerConnection> = connections
171            .values()
172            .filter(|c| c.is_connected)
173            .cloned()
174            .collect();
175
176        Ok(servers)
177    }
178
179    /// Gets a specific server connection
180    ///
181    /// # Arguments
182    /// * `server_id` - Unique identifier for the server
183    ///
184    /// # Returns
185    /// Server connection if found
186    pub async fn get_server(&self, server_id: &str) -> Result<Option<ServerConnection>> {
187        let connections = self.connections.read().await;
188        Ok(connections.get(server_id).cloned())
189    }
190
191    /// Checks if a server is connected
192    ///
193    /// # Arguments
194    /// * `server_id` - Unique identifier for the server
195    ///
196    /// # Returns
197    /// True if server is connected, false otherwise
198    pub async fn is_connected(&self, server_id: &str) -> bool {
199        let connections = self.connections.read().await;
200        connections
201            .get(server_id)
202            .map(|c| c.is_connected)
203            .unwrap_or(false)
204    }
205
206    /// Gets the number of connected servers
207    pub async fn connected_server_count(&self) -> usize {
208        let connections = self.connections.read().await;
209        connections.values().filter(|c| c.is_connected).count()
210    }
211
212    /// Gets all tools from all connected servers
213    ///
214    /// # Returns
215    /// List of all tools from all connected servers
216    pub async fn get_all_tools(&self) -> Result<Vec<ToolMetadata>> {
217        let connections = self.connections.read().await;
218        let mut all_tools = Vec::new();
219
220        for connection in connections.values() {
221            if connection.is_connected {
222                all_tools.extend(connection.tools.clone());
223            }
224        }
225
226        Ok(all_tools)
227    }
228}
229
230impl Default for MCPClient {
231    fn default() -> Self {
232        Self::new()
233    }
234}
235
#[cfg(test)]
mod tests {
    use super::*;
    use crate::metadata::ToolSource;

    /// Builds a minimal tool owned by `server1` for use in tests.
    fn make_tool(id: &str, name: &str, description: &str) -> ToolMetadata {
        ToolMetadata {
            id: id.to_string(),
            name: name.to_string(),
            description: description.to_string(),
            category: "test".to_string(),
            parameters: vec![],
            return_type: "string".to_string(),
            source: ToolSource::Mcp("server1".to_string()),
            server_id: Some("server1".to_string()),
        }
    }

    #[tokio::test]
    async fn test_create_client() {
        let client = MCPClient::new();
        assert_eq!(client.connected_server_count().await, 0);
    }

    #[tokio::test]
    async fn test_connect_server() {
        let client = MCPClient::new();
        let result = client.connect("server1", "Test Server").await;
        assert!(result.is_ok());
        assert!(client.is_connected("server1").await);
    }

    #[tokio::test]
    async fn test_disconnect_server() {
        let client = MCPClient::new();
        client.connect("server1", "Test Server").await.unwrap();
        assert!(client.is_connected("server1").await);

        let result = client.disconnect("server1").await;
        assert!(result.is_ok());
        assert!(!client.is_connected("server1").await);
    }

    #[tokio::test]
    async fn test_disconnect_unknown_server_is_ok() {
        let client = MCPClient::new();
        assert!(client.disconnect("nonexistent").await.is_ok());
        assert_eq!(client.connected_server_count().await, 0);
    }

    #[tokio::test]
    async fn test_disconnected_server_is_excluded() {
        let client = MCPClient::new();
        client.connect("server1", "Server 1").await.unwrap();
        client.connect("server2", "Server 2").await.unwrap();
        client
            .register_tools("server1", vec![make_tool("tool1", "Tool 1", "Tool 1")])
            .await
            .unwrap();

        client.disconnect("server1").await.unwrap();

        assert_eq!(client.connected_server_count().await, 1);
        let connected = client.get_connected_servers().await.unwrap();
        assert_eq!(connected.len(), 1);
        assert_eq!(connected[0].id, "server2");
        assert!(client.get_all_tools().await.unwrap().is_empty());
        assert!(client.discover_tools("server1").await.is_err());

        // Disconnecting keeps the entry in the map.
        assert_eq!(client.discover_servers().await.unwrap().len(), 2);
        let server = client
            .get_server("server1")
            .await
            .unwrap()
            .expect("disconnected server entry should be retained");
        assert!(!server.is_connected);
    }

    #[tokio::test]
    async fn test_discover_servers() {
        let client = MCPClient::new();
        client.connect("server1", "Server 1").await.unwrap();
        client.connect("server2", "Server 2").await.unwrap();

        let servers = client.discover_servers().await.unwrap();
        assert_eq!(servers.len(), 2);
    }

    #[tokio::test]
    async fn test_register_and_discover_tools() {
        let client = MCPClient::new();
        client.connect("server1", "Test Server").await.unwrap();

        let tool = make_tool("test-tool", "Test Tool", "A test tool");

        client
            .register_tools("server1", vec![tool.clone()])
            .await
            .unwrap();

        let tools = client.discover_tools("server1").await.unwrap();
        assert_eq!(tools.len(), 1);
        assert_eq!(tools[0].id, "test-tool");
    }

    #[tokio::test]
    async fn test_register_tools_unknown_server_fails() {
        let client = MCPClient::new();
        let result = client.register_tools("nonexistent", vec![]).await;
        assert!(result.is_err());
        assert!(client.get_server("nonexistent").await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_get_connected_servers() {
        let client = MCPClient::new();
        client.connect("server1", "Server 1").await.unwrap();
        client.connect("server2", "Server 2").await.unwrap();

        let servers = client.get_connected_servers().await.unwrap();
        assert_eq!(servers.len(), 2);
    }

    #[tokio::test]
    async fn test_get_all_tools() {
        let client = MCPClient::new();
        client.connect("server1", "Server 1").await.unwrap();

        let tool1 = make_tool("tool1", "Tool 1", "Tool 1");
        let tool2 = make_tool("tool2", "Tool 2", "Tool 2");

        client
            .register_tools("server1", vec![tool1, tool2])
            .await
            .unwrap();

        let all_tools = client.get_all_tools().await.unwrap();
        assert_eq!(all_tools.len(), 2);
    }

    #[tokio::test]
    async fn test_discover_tools_not_connected() {
        let client = MCPClient::new();
        let result = client.discover_tools("nonexistent").await;
        assert!(result.is_err());
    }
}