Skip to main content

codetether_agent/mcp/
client.rs

1//! MCP Client - Connect to external MCP servers
2//!
3//! Allows CodeTether to use tools from other MCP servers:
4//! - Filesystem servers
5//! - Database servers
6//! - API integration servers
7//! - Custom tool servers
8
9use super::transport::{McpMessage, ProcessTransport, Transport};
10use super::types::*;
11use anyhow::Result;
12use serde_json::Value;
13use std::collections::HashMap;
14use std::sync::Arc;
15use std::sync::atomic::{AtomicI64, Ordering};
16use std::time::Duration;
17use tokio::sync::{RwLock, oneshot};
18use tokio::time::timeout;
19use tracing::{debug, error, info, warn};
20
21/// MCP Client for connecting to external servers
22pub struct McpClient {
23    transport: Arc<dyn Transport>,
24    pending_requests: RwLock<HashMap<RequestId, oneshot::Sender<JsonRpcResponse>>>,
25    request_id: AtomicI64,
26    server_info: RwLock<Option<ServerInfo>>,
27    server_capabilities: RwLock<Option<ServerCapabilities>>,
28    available_tools: RwLock<Vec<McpTool>>,
29    /// Registry for managing multiple MCP server connections and capability tracking
30    registry: Arc<McpRegistry>,
31    /// Server name identifier for registry tracking
32    server_name: RwLock<Option<String>>,
33}
34
35impl McpClient {
36    /// Connect to an MCP server via subprocess
37    pub async fn connect_subprocess(command: &str, args: &[&str]) -> Result<Arc<Self>> {
38        let transport = Arc::new(ProcessTransport::spawn(command, args).await?);
39        let client = Arc::new(Self::new(transport));
40
41        // Start message receiver
42        let client_clone = Arc::clone(&client);
43        tokio::spawn(async move {
44            client_clone.receive_loop().await;
45        });
46
47        // Initialize the connection
48        client.initialize().await?;
49
50        Ok(client)
51    }
52
53    /// Create a new MCP client with custom transport
54    pub fn new(transport: Arc<dyn Transport>) -> Self {
55        Self {
56            transport,
57            pending_requests: RwLock::new(HashMap::new()),
58            request_id: AtomicI64::new(1),
59            server_info: RwLock::new(None),
60            server_capabilities: RwLock::new(None),
61            available_tools: RwLock::new(Vec::new()),
62            registry: Arc::new(McpRegistry::new()),
63            server_name: RwLock::new(None),
64        }
65    }
66
67    /// Create a new MCP client with a shared registry for multi-server management
68    pub fn with_registry(
69        transport: Arc<dyn Transport>,
70        registry: Arc<McpRegistry>,
71        name: Option<String>,
72    ) -> Self {
73        Self {
74            transport,
75            pending_requests: RwLock::new(HashMap::new()),
76            request_id: AtomicI64::new(1),
77            server_info: RwLock::new(None),
78            server_capabilities: RwLock::new(None),
79            available_tools: RwLock::new(Vec::new()),
80            registry,
81            server_name: RwLock::new(name),
82        }
83    }
84
85    /// Initialize the connection with the server
86    pub async fn initialize(&self) -> Result<InitializeResult> {
87        let params = InitializeParams {
88            protocol_version: PROTOCOL_VERSION.to_string(),
89            capabilities: ClientCapabilities {
90                roots: Some(RootsCapability { list_changed: true }),
91                sampling: Some(SamplingCapability {}),
92                experimental: None,
93            },
94            client_info: ClientInfo {
95                name: "codetether".to_string(),
96                version: env!("CARGO_PKG_VERSION").to_string(),
97            },
98        };
99
100        let response = self
101            .request("initialize", Some(serde_json::to_value(&params)?))
102            .await?;
103        let result: InitializeResult = serde_json::from_value(response)?;
104
105        // Store server info
106        *self.server_info.write().await = Some(result.server_info.clone());
107        *self.server_capabilities.write().await = Some(result.capabilities.clone());
108
109        // Register this client with the registry if a server name is set
110        if let Some(name) = self.server_name.read().await.clone() {
111            // Create a self-reference for registration
112            // Note: This is called after construction, so we need to handle registration externally
113            // or use a post-initialization hook
114            debug!("Client initialized with server name: {}", name);
115        }
116
117        // Send initialized notification
118        self.notify("notifications/initialized", None).await?;
119
120        info!(
121            "Connected to MCP server: {} v{}",
122            result.server_info.name, result.server_info.version
123        );
124
125        // Fetch available tools
126        if result.capabilities.tools.is_some() {
127            self.refresh_tools().await?;
128        }
129
130        Ok(result)
131    }
132
133    /// Get the registry associated with this client
134    pub fn registry(&self) -> Arc<McpRegistry> {
135        Arc::clone(&self.registry)
136    }
137
138    /// Get the server name if set
139    pub async fn server_name(&self) -> Option<String> {
140        self.server_name.read().await.clone()
141    }
142
143    /// Set the server name for registry tracking
144    pub async fn set_server_name(&self, name: String) {
145        *self.server_name.write().await = Some(name);
146    }
147
148    /// Check if the connected server has a specific capability
149    pub async fn has_capability(&self, capability: &str) -> bool {
150        let caps = self.server_capabilities.read().await;
151        match capability {
152            "tools" => caps.as_ref().map(|c| c.tools.is_some()).unwrap_or(false),
153            "resources" => caps
154                .as_ref()
155                .map(|c| c.resources.is_some())
156                .unwrap_or(false),
157            "prompts" => caps.as_ref().map(|c| c.prompts.is_some()).unwrap_or(false),
158            "logging" => caps.as_ref().map(|c| c.logging.is_some()).unwrap_or(false),
159            _ => false,
160        }
161    }
162
163    /// Get server capabilities
164    pub async fn capabilities(&self) -> Option<ServerCapabilities> {
165        self.server_capabilities.read().await.clone()
166    }
167
168    /// Discover tools from the registry across all connected servers
169    pub async fn discover_tools_from_registry(&self) -> Vec<(String, McpTool)> {
170        self.registry.all_tools().await
171    }
172
173    /// Find a tool across all servers in the registry
174    pub async fn find_tool_in_registry(&self, tool_name: &str) -> Option<(String, McpTool)> {
175        self.registry.find_tool(tool_name).await
176    }
177
178    /// Refresh the list of available tools
179    pub async fn refresh_tools(&self) -> Result<Vec<McpTool>> {
180        let response = self.request("tools/list", None).await?;
181        let result: ListToolsResult = serde_json::from_value(response)?;
182
183        *self.available_tools.write().await = result.tools.clone();
184
185        info!("Loaded {} tools from MCP server", result.tools.len());
186
187        Ok(result.tools)
188    }
189
190    /// Get available tools
191    pub async fn tools(&self) -> Vec<McpTool> {
192        self.available_tools.read().await.clone()
193    }
194
195    /// Call a tool
196    pub async fn call_tool(&self, name: &str, arguments: Value) -> Result<CallToolResult> {
197        let params = CallToolParams {
198            name: name.to_string(),
199            arguments,
200        };
201
202        let response = self
203            .request("tools/call", Some(serde_json::to_value(&params)?))
204            .await?;
205        let result: CallToolResult = serde_json::from_value(response)?;
206
207        Ok(result)
208    }
209
210    /// List available resources
211    pub async fn list_resources(&self) -> Result<Vec<McpResource>> {
212        let response = self.request("resources/list", None).await?;
213        let result: ListResourcesResult = serde_json::from_value(response)?;
214        Ok(result.resources)
215    }
216
217    /// Read a resource
218    pub async fn read_resource(&self, uri: &str) -> Result<ReadResourceResult> {
219        let params = ReadResourceParams {
220            uri: uri.to_string(),
221        };
222        let response = self
223            .request("resources/read", Some(serde_json::to_value(&params)?))
224            .await?;
225        let result: ReadResourceResult = serde_json::from_value(response)?;
226        Ok(result)
227    }
228
229    /// List available prompts
230    pub async fn list_prompts(&self) -> Result<Vec<McpPrompt>> {
231        let response = self.request("prompts/list", None).await?;
232        let result: ListPromptsResult = serde_json::from_value(response)?;
233        Ok(result.prompts)
234    }
235
236    /// Get a prompt
237    pub async fn get_prompt(&self, name: &str, arguments: Value) -> Result<GetPromptResult> {
238        let params = GetPromptParams {
239            name: name.to_string(),
240            arguments,
241        };
242        let response = self
243            .request("prompts/get", Some(serde_json::to_value(&params)?))
244            .await?;
245        let result: GetPromptResult = serde_json::from_value(response)?;
246        Ok(result)
247    }
248
249    /// Send a JSON-RPC request and wait for response
250    async fn request(&self, method: &str, params: Option<Value>) -> Result<Value> {
251        let id = RequestId::Number(self.request_id.fetch_add(1, Ordering::SeqCst));
252        let request = JsonRpcRequest::new(id.clone(), method, params);
253
254        // Create response channel
255        let (tx, rx) = oneshot::channel();
256        self.pending_requests.write().await.insert(id.clone(), tx);
257
258        // Send request
259        self.transport.send_request(request).await?;
260
261        // Wait for response with timeout
262        let response = timeout(Duration::from_secs(30), rx)
263            .await
264            .map_err(|_| anyhow::anyhow!("Request timed out"))??;
265
266        if let Some(error) = response.error {
267            return Err(anyhow::anyhow!(
268                "MCP error {}: {}",
269                error.code,
270                error.message
271            ));
272        }
273
274        response
275            .result
276            .ok_or_else(|| anyhow::anyhow!("Empty response"))
277    }
278
279    /// Send a notification (no response expected)
280    async fn notify(&self, method: &str, params: Option<Value>) -> Result<()> {
281        let notification = JsonRpcNotification::new(method, params);
282        self.transport.send_notification(notification).await
283    }
284
285    /// Message receive loop
286    async fn receive_loop(&self) {
287        loop {
288            match self.transport.receive().await {
289                Ok(Some(message)) => {
290                    self.handle_message(message).await;
291                }
292                Ok(None) => {
293                    info!("MCP connection closed");
294                    break;
295                }
296                Err(e) => {
297                    error!("Error receiving MCP message: {}", e);
298                    break;
299                }
300            }
301        }
302    }
303
304    /// Handle an incoming message
305    async fn handle_message(&self, message: McpMessage) {
306        match message {
307            McpMessage::Response(response) => {
308                // Find and notify the waiting request
309                if let Some(tx) = self.pending_requests.write().await.remove(&response.id) {
310                    let _ = tx.send(response);
311                } else {
312                    warn!("Received response for unknown request: {:?}", response.id);
313                }
314            }
315            McpMessage::Request(request) => {
316                // Server is making a request to us (e.g., sampling)
317                debug!("Received request from server: {}", request.method);
318
319                let response = match request.method.as_str() {
320                    "sampling/createMessage" => {
321                        // Handle sampling request
322                        // TODO: Implement sampling using our provider
323                        JsonRpcResponse::error(
324                            request.id,
325                            JsonRpcError::method_not_found("Sampling not yet implemented"),
326                        )
327                    }
328                    _ => JsonRpcResponse::error(
329                        request.id,
330                        JsonRpcError::method_not_found(&request.method),
331                    ),
332                };
333
334                if let Err(e) = self.transport.send_response(response).await {
335                    error!("Failed to send response: {}", e);
336                }
337            }
338            McpMessage::Notification(notification) => {
339                debug!("Received notification: {}", notification.method);
340
341                match notification.method.as_str() {
342                    "notifications/tools/list_changed" => {
343                        info!("Tools list changed, refreshing...");
344                        if let Err(e) = self.refresh_tools().await {
345                            error!("Failed to refresh tools: {}", e);
346                        }
347                    }
348                    "notifications/resources/list_changed" => {
349                        info!("Resources list changed");
350                    }
351                    _ => {
352                        debug!("Unknown notification: {}", notification.method);
353                    }
354                }
355            }
356        }
357    }
358
359    /// Close the connection
360    pub async fn close(&self) -> Result<()> {
361        self.transport.close().await
362    }
363}
364
365/// MCP Server Registry - manages multiple MCP server connections
366///
367/// This registry allows managing connections to multiple external MCP servers,
368/// enabling CodeTether to use tools from various sources like filesystem servers,
369/// database servers, and custom tool servers.
370pub struct McpRegistry {
371    clients: RwLock<HashMap<String, Arc<McpClient>>>,
372    /// Track server capabilities for quick lookup without querying each client
373    server_capabilities: RwLock<HashMap<String, ServerCapabilities>>,
374    /// Track all discovered tools across servers for efficient discovery
375    tool_index: RwLock<HashMap<String, String>>, // tool_name -> server_name
376}
377
378impl McpRegistry {
379    /// Create a new registry
380    pub fn new() -> Self {
381        Self {
382            clients: RwLock::new(HashMap::new()),
383            server_capabilities: RwLock::new(HashMap::new()),
384            tool_index: RwLock::new(HashMap::new()),
385        }
386    }
387
388    /// Connect to an MCP server and register it with the registry
389    pub async fn connect(
390        &self,
391        name: &str,
392        command: &str,
393        args: &[&str],
394    ) -> Result<Arc<McpClient>> {
395        let transport = Arc::new(ProcessTransport::spawn(command, args).await?);
396        let client = Arc::new(McpClient::with_registry(
397            transport,
398            Arc::new(McpRegistry::new()), // Each client gets its own registry for now
399            Some(name.to_string()),
400        ));
401
402        // Start message receiver
403        let client_clone = Arc::clone(&client);
404        tokio::spawn(async move {
405            client_clone.receive_loop().await;
406        });
407
408        // Initialize the connection
409        let init_result = client.initialize().await?;
410
411        // Register the client
412        self.register(name, Arc::clone(&client), init_result.capabilities)
413            .await;
414
415        Ok(client)
416    }
417
418    /// Register a client with the registry
419    pub async fn register(
420        &self,
421        name: &str,
422        client: Arc<McpClient>,
423        capabilities: ServerCapabilities,
424    ) {
425        // Store client
426        self.clients.write().await.insert(name.to_string(), client);
427
428        // Store capabilities
429        self.server_capabilities
430            .write()
431            .await
432            .insert(name.to_string(), capabilities);
433
434        info!("Registered MCP server '{}' with registry", name);
435    }
436
437    /// Get a connected client
438    pub async fn get(&self, name: &str) -> Option<Arc<McpClient>> {
439        self.clients.read().await.get(name).cloned()
440    }
441
442    /// List all connected servers
443    pub async fn list(&self) -> Vec<String> {
444        self.clients.read().await.keys().cloned().collect()
445    }
446
447    /// Get capabilities for a specific server
448    pub async fn get_capabilities(&self, name: &str) -> Option<ServerCapabilities> {
449        self.server_capabilities.read().await.get(name).cloned()
450    }
451
452    /// Check if a server has a specific capability
453    pub async fn has_capability(&self, name: &str, capability: &str) -> bool {
454        let caps = self.server_capabilities.read().await;
455        caps.get(name)
456            .map(|c| match capability {
457                "tools" => c.tools.is_some(),
458                "resources" => c.resources.is_some(),
459                "prompts" => c.prompts.is_some(),
460                "logging" => c.logging.is_some(),
461                _ => false,
462            })
463            .unwrap_or(false)
464    }
465
466    /// List servers that have a specific capability
467    pub async fn list_by_capability(&self, capability: &str) -> Vec<String> {
468        let mut result = Vec::new();
469        let caps = self.server_capabilities.read().await;
470
471        for (name, caps) in caps.iter() {
472            let has_cap = match capability {
473                "tools" => caps.tools.is_some(),
474                "resources" => caps.resources.is_some(),
475                "prompts" => caps.prompts.is_some(),
476                "logging" => caps.logging.is_some(),
477                _ => false,
478            };
479            if has_cap {
480                result.push(name.clone());
481            }
482        }
483
484        result
485    }
486
487    /// Disconnect from a server
488    pub async fn disconnect(&self, name: &str) -> Result<()> {
489        if let Some(client) = self.clients.write().await.remove(name) {
490            // Remove capabilities
491            self.server_capabilities.write().await.remove(name);
492            // Remove from tool index
493            let mut tool_index = self.tool_index.write().await;
494            tool_index.retain(|_, server| server != name);
495            // Close connection
496            client.close().await?;
497        }
498        Ok(())
499    }
500
501    /// Get all available tools from all servers
502    pub async fn all_tools(&self) -> Vec<(String, McpTool)> {
503        let mut all_tools = Vec::new();
504
505        for (name, client) in self.clients.read().await.iter() {
506            for tool in client.tools().await {
507                all_tools.push((name.clone(), tool));
508            }
509        }
510
511        all_tools
512    }
513
514    /// Find a specific tool across all servers
515    pub async fn find_tool(&self, tool_name: &str) -> Option<(String, McpTool)> {
516        // First check the tool index
517        if let Some(server_name) = self.tool_index.read().await.get(tool_name) {
518            if let Some(client) = self.get(server_name).await {
519                if let Some(tool) = client.tools().await.iter().find(|t| t.name == tool_name) {
520                    return Some((server_name.clone(), tool.clone()));
521                }
522            }
523        }
524
525        // Fallback: search all clients
526        for (name, client) in self.clients.read().await.iter() {
527            if let Some(tool) = client.tools().await.iter().find(|t| t.name == tool_name) {
528                // Update index
529                self.tool_index
530                    .write()
531                    .await
532                    .insert(tool_name.to_string(), name.clone());
533                return Some((name.clone(), tool.clone()));
534            }
535        }
536
537        None
538    }
539
540    /// Refresh the tool index from all servers
541    pub async fn refresh_tool_index(&self) {
542        let mut tool_index = self.tool_index.write().await;
543        tool_index.clear();
544
545        for (name, client) in self.clients.read().await.iter() {
546            for tool in client.tools().await {
547                tool_index.insert(tool.name.clone(), name.clone());
548            }
549        }
550
551        info!("Refreshed tool index with {} tools", tool_index.len());
552    }
553
554    /// Call a tool on a specific server
555    pub async fn call_tool(
556        &self,
557        server: &str,
558        tool: &str,
559        arguments: Value,
560    ) -> Result<CallToolResult> {
561        let client = self
562            .get(server)
563            .await
564            .ok_or_else(|| anyhow::anyhow!("Server not found: {}", server))?;
565        client.call_tool(tool, arguments).await
566    }
567
568    /// Call a tool by name, finding the appropriate server automatically
569    pub async fn call_tool_auto(
570        &self,
571        tool_name: &str,
572        arguments: Value,
573    ) -> Result<CallToolResult> {
574        let (server, _) = self
575            .find_tool(tool_name)
576            .await
577            .ok_or_else(|| anyhow::anyhow!("Tool not found: {}", tool_name))?;
578        self.call_tool(&server, tool_name, arguments).await
579    }
580}
581
582impl Default for McpRegistry {
583    fn default() -> Self {
584        Self::new()
585    }
586}