Skip to main content

codetether_agent/mcp/
client.rs

1//! MCP Client - Connect to external MCP servers
2//!
3//! Allows CodeTether to use tools from other MCP servers:
4//! - Filesystem servers
5//! - Database servers
6//! - API integration servers
7//! - Custom tool servers
8
9use super::transport::{McpMessage, ProcessTransport, Transport};
10use super::types::*;
11use anyhow::Result;
12use serde_json::Value;
13use std::collections::HashMap;
14use std::sync::atomic::{AtomicI64, Ordering};
15use std::sync::Arc;
16use std::time::Duration;
17use tokio::sync::{oneshot, RwLock};
18use tokio::time::timeout;
19use tracing::{debug, error, info, warn};
20
21/// MCP Client for connecting to external servers
22pub struct McpClient {
23    transport: Arc<dyn Transport>,
24    pending_requests: RwLock<HashMap<RequestId, oneshot::Sender<JsonRpcResponse>>>,
25    request_id: AtomicI64,
26    server_info: RwLock<Option<ServerInfo>>,
27    server_capabilities: RwLock<Option<ServerCapabilities>>,
28    available_tools: RwLock<Vec<McpTool>>,
29    /// Registry for managing multiple MCP server connections and capability tracking
30    registry: Arc<McpRegistry>,
31    /// Server name identifier for registry tracking
32    server_name: RwLock<Option<String>>,
33}
34
35impl McpClient {
36    /// Connect to an MCP server via subprocess
37    pub async fn connect_subprocess(command: &str, args: &[&str]) -> Result<Arc<Self>> {
38        let transport = Arc::new(ProcessTransport::spawn(command, args).await?);
39        let client = Arc::new(Self::new(transport));
40        
41        // Start message receiver
42        let client_clone = Arc::clone(&client);
43        tokio::spawn(async move {
44            client_clone.receive_loop().await;
45        });
46        
47        // Initialize the connection
48        client.initialize().await?;
49        
50        Ok(client)
51    }
52    
53    /// Create a new MCP client with custom transport
54    pub fn new(transport: Arc<dyn Transport>) -> Self {
55        Self {
56            transport,
57            pending_requests: RwLock::new(HashMap::new()),
58            request_id: AtomicI64::new(1),
59            server_info: RwLock::new(None),
60            server_capabilities: RwLock::new(None),
61            available_tools: RwLock::new(Vec::new()),
62            registry: Arc::new(McpRegistry::new()),
63            server_name: RwLock::new(None),
64        }
65    }
66
67    /// Create a new MCP client with a shared registry for multi-server management
68    pub fn with_registry(transport: Arc<dyn Transport>, registry: Arc<McpRegistry>, name: Option<String>) -> Self {
69        Self {
70            transport,
71            pending_requests: RwLock::new(HashMap::new()),
72            request_id: AtomicI64::new(1),
73            server_info: RwLock::new(None),
74            server_capabilities: RwLock::new(None),
75            available_tools: RwLock::new(Vec::new()),
76            registry,
77            server_name: RwLock::new(name),
78        }
79    }
80    
81    /// Initialize the connection with the server
82    pub async fn initialize(&self) -> Result<InitializeResult> {
83        let params = InitializeParams {
84            protocol_version: PROTOCOL_VERSION.to_string(),
85            capabilities: ClientCapabilities {
86                roots: Some(RootsCapability { list_changed: true }),
87                sampling: Some(SamplingCapability {}),
88                experimental: None,
89            },
90            client_info: ClientInfo {
91                name: "codetether".to_string(),
92                version: env!("CARGO_PKG_VERSION").to_string(),
93            },
94        };
95        
96        let response = self.request("initialize", Some(serde_json::to_value(&params)?)).await?;
97        let result: InitializeResult = serde_json::from_value(response)?;
98        
99        // Store server info
100        *self.server_info.write().await = Some(result.server_info.clone());
101        *self.server_capabilities.write().await = Some(result.capabilities.clone());
102        
103        // Register this client with the registry if a server name is set
104        if let Some(name) = self.server_name.read().await.clone() {
105            // Create a self-reference for registration
106            // Note: This is called after construction, so we need to handle registration externally
107            // or use a post-initialization hook
108            debug!("Client initialized with server name: {}", name);
109        }
110        
111        // Send initialized notification
112        self.notify("notifications/initialized", None).await?;
113        
114        info!(
115            "Connected to MCP server: {} v{}",
116            result.server_info.name, result.server_info.version
117        );
118        
119        // Fetch available tools
120        if result.capabilities.tools.is_some() {
121            self.refresh_tools().await?;
122        }
123        
124        Ok(result)
125    }
126    
127    /// Get the registry associated with this client
128    pub fn registry(&self) -> Arc<McpRegistry> {
129        Arc::clone(&self.registry)
130    }
131    
132    /// Get the server name if set
133    pub async fn server_name(&self) -> Option<String> {
134        self.server_name.read().await.clone()
135    }
136    
137    /// Set the server name for registry tracking
138    pub async fn set_server_name(&self, name: String) {
139        *self.server_name.write().await = Some(name);
140    }
141    
142    /// Check if the connected server has a specific capability
143    pub async fn has_capability(&self, capability: &str) -> bool {
144        let caps = self.server_capabilities.read().await;
145        match capability {
146            "tools" => caps.as_ref().map(|c| c.tools.is_some()).unwrap_or(false),
147            "resources" => caps.as_ref().map(|c| c.resources.is_some()).unwrap_or(false),
148            "prompts" => caps.as_ref().map(|c| c.prompts.is_some()).unwrap_or(false),
149            "logging" => caps.as_ref().map(|c| c.logging.is_some()).unwrap_or(false),
150            _ => false,
151        }
152    }
153    
154    /// Get server capabilities
155    pub async fn capabilities(&self) -> Option<ServerCapabilities> {
156        self.server_capabilities.read().await.clone()
157    }
158    
159    /// Discover tools from the registry across all connected servers
160    pub async fn discover_tools_from_registry(&self) -> Vec<(String, McpTool)> {
161        self.registry.all_tools().await
162    }
163    
164    /// Find a tool across all servers in the registry
165    pub async fn find_tool_in_registry(&self, tool_name: &str) -> Option<(String, McpTool)> {
166        self.registry.find_tool(tool_name).await
167    }
168    
169    /// Refresh the list of available tools
170    pub async fn refresh_tools(&self) -> Result<Vec<McpTool>> {
171        let response = self.request("tools/list", None).await?;
172        let result: ListToolsResult = serde_json::from_value(response)?;
173        
174        *self.available_tools.write().await = result.tools.clone();
175        
176        info!("Loaded {} tools from MCP server", result.tools.len());
177        
178        Ok(result.tools)
179    }
180    
181    /// Get available tools
182    pub async fn tools(&self) -> Vec<McpTool> {
183        self.available_tools.read().await.clone()
184    }
185    
186    /// Call a tool
187    pub async fn call_tool(&self, name: &str, arguments: Value) -> Result<CallToolResult> {
188        let params = CallToolParams {
189            name: name.to_string(),
190            arguments,
191        };
192        
193        let response = self.request("tools/call", Some(serde_json::to_value(&params)?)).await?;
194        let result: CallToolResult = serde_json::from_value(response)?;
195        
196        Ok(result)
197    }
198    
199    /// List available resources
200    pub async fn list_resources(&self) -> Result<Vec<McpResource>> {
201        let response = self.request("resources/list", None).await?;
202        let result: ListResourcesResult = serde_json::from_value(response)?;
203        Ok(result.resources)
204    }
205    
206    /// Read a resource
207    pub async fn read_resource(&self, uri: &str) -> Result<ReadResourceResult> {
208        let params = ReadResourceParams { uri: uri.to_string() };
209        let response = self.request("resources/read", Some(serde_json::to_value(&params)?)).await?;
210        let result: ReadResourceResult = serde_json::from_value(response)?;
211        Ok(result)
212    }
213    
214    /// List available prompts
215    pub async fn list_prompts(&self) -> Result<Vec<McpPrompt>> {
216        let response = self.request("prompts/list", None).await?;
217        let result: ListPromptsResult = serde_json::from_value(response)?;
218        Ok(result.prompts)
219    }
220    
221    /// Get a prompt
222    pub async fn get_prompt(&self, name: &str, arguments: Value) -> Result<GetPromptResult> {
223        let params = GetPromptParams {
224            name: name.to_string(),
225            arguments,
226        };
227        let response = self.request("prompts/get", Some(serde_json::to_value(&params)?)).await?;
228        let result: GetPromptResult = serde_json::from_value(response)?;
229        Ok(result)
230    }
231    
232    /// Send a JSON-RPC request and wait for response
233    async fn request(&self, method: &str, params: Option<Value>) -> Result<Value> {
234        let id = RequestId::Number(self.request_id.fetch_add(1, Ordering::SeqCst));
235        let request = JsonRpcRequest::new(id.clone(), method, params);
236        
237        // Create response channel
238        let (tx, rx) = oneshot::channel();
239        self.pending_requests.write().await.insert(id.clone(), tx);
240        
241        // Send request
242        self.transport.send_request(request).await?;
243        
244        // Wait for response with timeout
245        let response = timeout(Duration::from_secs(30), rx)
246            .await
247            .map_err(|_| anyhow::anyhow!("Request timed out"))??;
248        
249        if let Some(error) = response.error {
250            return Err(anyhow::anyhow!("MCP error {}: {}", error.code, error.message));
251        }
252        
253        response.result.ok_or_else(|| anyhow::anyhow!("Empty response"))
254    }
255    
256    /// Send a notification (no response expected)
257    async fn notify(&self, method: &str, params: Option<Value>) -> Result<()> {
258        let notification = JsonRpcNotification::new(method, params);
259        self.transport.send_notification(notification).await
260    }
261    
262    /// Message receive loop
263    async fn receive_loop(&self) {
264        loop {
265            match self.transport.receive().await {
266                Ok(Some(message)) => {
267                    self.handle_message(message).await;
268                }
269                Ok(None) => {
270                    info!("MCP connection closed");
271                    break;
272                }
273                Err(e) => {
274                    error!("Error receiving MCP message: {}", e);
275                    break;
276                }
277            }
278        }
279    }
280    
281    /// Handle an incoming message
282    async fn handle_message(&self, message: McpMessage) {
283        match message {
284            McpMessage::Response(response) => {
285                // Find and notify the waiting request
286                if let Some(tx) = self.pending_requests.write().await.remove(&response.id) {
287                    let _ = tx.send(response);
288                } else {
289                    warn!("Received response for unknown request: {:?}", response.id);
290                }
291            }
292            McpMessage::Request(request) => {
293                // Server is making a request to us (e.g., sampling)
294                debug!("Received request from server: {}", request.method);
295                
296                let response = match request.method.as_str() {
297                    "sampling/createMessage" => {
298                        // Handle sampling request
299                        // TODO: Implement sampling using our provider
300                        JsonRpcResponse::error(
301                            request.id,
302                            JsonRpcError::method_not_found("Sampling not yet implemented"),
303                        )
304                    }
305                    _ => {
306                        JsonRpcResponse::error(
307                            request.id,
308                            JsonRpcError::method_not_found(&request.method),
309                        )
310                    }
311                };
312                
313                if let Err(e) = self.transport.send_response(response).await {
314                    error!("Failed to send response: {}", e);
315                }
316            }
317            McpMessage::Notification(notification) => {
318                debug!("Received notification: {}", notification.method);
319                
320                match notification.method.as_str() {
321                    "notifications/tools/list_changed" => {
322                        info!("Tools list changed, refreshing...");
323                        if let Err(e) = self.refresh_tools().await {
324                            error!("Failed to refresh tools: {}", e);
325                        }
326                    }
327                    "notifications/resources/list_changed" => {
328                        info!("Resources list changed");
329                    }
330                    _ => {
331                        debug!("Unknown notification: {}", notification.method);
332                    }
333                }
334            }
335        }
336    }
337    
338    /// Close the connection
339    pub async fn close(&self) -> Result<()> {
340        self.transport.close().await
341    }
342}
343
344/// MCP Server Registry - manages multiple MCP server connections
345/// 
346/// This registry allows managing connections to multiple external MCP servers,
347/// enabling CodeTether to use tools from various sources like filesystem servers,
348/// database servers, and custom tool servers.
349pub struct McpRegistry {
350    clients: RwLock<HashMap<String, Arc<McpClient>>>,
351    /// Track server capabilities for quick lookup without querying each client
352    server_capabilities: RwLock<HashMap<String, ServerCapabilities>>,
353    /// Track all discovered tools across servers for efficient discovery
354    tool_index: RwLock<HashMap<String, String>>, // tool_name -> server_name
355}
356
357impl McpRegistry {
358    /// Create a new registry
359    pub fn new() -> Self {
360        Self {
361            clients: RwLock::new(HashMap::new()),
362            server_capabilities: RwLock::new(HashMap::new()),
363            tool_index: RwLock::new(HashMap::new()),
364        }
365    }
366    
367    /// Connect to an MCP server and register it with the registry
368    pub async fn connect(&self, name: &str, command: &str, args: &[&str]) -> Result<Arc<McpClient>> {
369        let transport = Arc::new(ProcessTransport::spawn(command, args).await?);
370        let client = Arc::new(McpClient::with_registry(
371            transport, 
372            Arc::new(McpRegistry::new()), // Each client gets its own registry for now
373            Some(name.to_string())
374        ));
375        
376        // Start message receiver
377        let client_clone = Arc::clone(&client);
378        tokio::spawn(async move {
379            client_clone.receive_loop().await;
380        });
381        
382        // Initialize the connection
383        let init_result = client.initialize().await?;
384        
385        // Register the client
386        self.register(name, Arc::clone(&client), init_result.capabilities).await;
387        
388        Ok(client)
389    }
390    
391    /// Register a client with the registry
392    pub async fn register(&self, name: &str, client: Arc<McpClient>, capabilities: ServerCapabilities) {
393        // Store client
394        self.clients.write().await.insert(name.to_string(), client);
395        
396        // Store capabilities
397        self.server_capabilities.write().await.insert(name.to_string(), capabilities);
398        
399        info!("Registered MCP server '{}' with registry", name);
400    }
401    
402    /// Get a connected client
403    pub async fn get(&self, name: &str) -> Option<Arc<McpClient>> {
404        self.clients.read().await.get(name).cloned()
405    }
406    
407    /// List all connected servers
408    pub async fn list(&self) -> Vec<String> {
409        self.clients.read().await.keys().cloned().collect()
410    }
411    
412    /// Get capabilities for a specific server
413    pub async fn get_capabilities(&self, name: &str) -> Option<ServerCapabilities> {
414        self.server_capabilities.read().await.get(name).cloned()
415    }
416    
417    /// Check if a server has a specific capability
418    pub async fn has_capability(&self, name: &str, capability: &str) -> bool {
419        let caps = self.server_capabilities.read().await;
420        caps.get(name).map(|c| {
421            match capability {
422                "tools" => c.tools.is_some(),
423                "resources" => c.resources.is_some(),
424                "prompts" => c.prompts.is_some(),
425                "logging" => c.logging.is_some(),
426                _ => false,
427            }
428        }).unwrap_or(false)
429    }
430    
431    /// List servers that have a specific capability
432    pub async fn list_by_capability(&self, capability: &str) -> Vec<String> {
433        let mut result = Vec::new();
434        let caps = self.server_capabilities.read().await;
435        
436        for (name, caps) in caps.iter() {
437            let has_cap = match capability {
438                "tools" => caps.tools.is_some(),
439                "resources" => caps.resources.is_some(),
440                "prompts" => caps.prompts.is_some(),
441                "logging" => caps.logging.is_some(),
442                _ => false,
443            };
444            if has_cap {
445                result.push(name.clone());
446            }
447        }
448        
449        result
450    }
451    
452    /// Disconnect from a server
453    pub async fn disconnect(&self, name: &str) -> Result<()> {
454        if let Some(client) = self.clients.write().await.remove(name) {
455            // Remove capabilities
456            self.server_capabilities.write().await.remove(name);
457            // Remove from tool index
458            let mut tool_index = self.tool_index.write().await;
459            tool_index.retain(|_, server| server != name);
460            // Close connection
461            client.close().await?;
462        }
463        Ok(())
464    }
465    
466    /// Get all available tools from all servers
467    pub async fn all_tools(&self) -> Vec<(String, McpTool)> {
468        let mut all_tools = Vec::new();
469        
470        for (name, client) in self.clients.read().await.iter() {
471            for tool in client.tools().await {
472                all_tools.push((name.clone(), tool));
473            }
474        }
475        
476        all_tools
477    }
478    
479    /// Find a specific tool across all servers
480    pub async fn find_tool(&self, tool_name: &str) -> Option<(String, McpTool)> {
481        // First check the tool index
482        if let Some(server_name) = self.tool_index.read().await.get(tool_name) {
483            if let Some(client) = self.get(server_name).await {
484                if let Some(tool) = client.tools().await.iter().find(|t| t.name == tool_name) {
485                    return Some((server_name.clone(), tool.clone()));
486                }
487            }
488        }
489        
490        // Fallback: search all clients
491        for (name, client) in self.clients.read().await.iter() {
492            if let Some(tool) = client.tools().await.iter().find(|t| t.name == tool_name) {
493                // Update index
494                self.tool_index.write().await.insert(tool_name.to_string(), name.clone());
495                return Some((name.clone(), tool.clone()));
496            }
497        }
498        
499        None
500    }
501    
502    /// Refresh the tool index from all servers
503    pub async fn refresh_tool_index(&self) {
504        let mut tool_index = self.tool_index.write().await;
505        tool_index.clear();
506        
507        for (name, client) in self.clients.read().await.iter() {
508            for tool in client.tools().await {
509                tool_index.insert(tool.name.clone(), name.clone());
510            }
511        }
512        
513        info!("Refreshed tool index with {} tools", tool_index.len());
514    }
515    
516    /// Call a tool on a specific server
517    pub async fn call_tool(&self, server: &str, tool: &str, arguments: Value) -> Result<CallToolResult> {
518        let client = self.get(server).await
519            .ok_or_else(|| anyhow::anyhow!("Server not found: {}", server))?;
520        client.call_tool(tool, arguments).await
521    }
522    
523    /// Call a tool by name, finding the appropriate server automatically
524    pub async fn call_tool_auto(&self, tool_name: &str, arguments: Value) -> Result<CallToolResult> {
525        let (server, _) = self.find_tool(tool_name).await
526            .ok_or_else(|| anyhow::anyhow!("Tool not found: {}", tool_name))?;
527        self.call_tool(&server, tool_name, arguments).await
528    }
529}
530
531impl Default for McpRegistry {
532    fn default() -> Self {
533        Self::new()
534    }
535}