Skip to main content

atomcode_core/mcp/
registry.rs

1//! MCP server registry - manages connections to multiple MCP servers.
2
3use std::collections::BTreeMap;
4use std::sync::Arc;
5use std::time::Duration;
6
7use anyhow::Result;
8use tokio::sync::{mpsc, RwLock};
9
10use atomcode_telemetry::{Event as TelemetryEvent, McpErrorKind, McpTransport};
11
12use super::client::{McpClient, McpToolInfo};
13use super::config::{load_mcp_config, McpServerConfig};
14use super::transport_http::HttpClient;
15use super::transport_stdio::StdioClient;
16use super::types::ServerStatus;
17
/// Status notification produced while MCP servers are being connected or used.
///
/// Events travel over an unbounded channel so a UI (e.g. the TUI scrollback)
/// can report progress without blocking the connection tasks.
#[derive(Clone, Debug)]
pub enum McpConnectEvent {
    /// The named server completed its handshake and was registered.
    Connected { name: String },
    /// Connecting the named server failed; `error` is the rendered error text.
    /// Configuration-load failures are reported with the name `"config"`.
    Failed { name: String, error: String },
    /// Something non-fatal went wrong for the named server, for example
    /// `tools/list` failing after the connection was established.
    Warning { name: String, message: String },
}
28
29/// Registry of connected MCP servers.
30pub struct McpRegistry {
31    servers: Arc<RwLock<BTreeMap<String, Arc<dyn McpClient>>>>,
32    server_timeouts_ms: Arc<RwLock<BTreeMap<String, u64>>>,
33    /// Channel for connection status events (used by TUI to display in scrollback).
34    connect_events: Option<mpsc::UnboundedSender<McpConnectEvent>>,
35    /// Signals when all initial background connections have completed (or failed).
36    initial_ready: Arc<tokio::sync::Notify>,
37    /// Telemetry handle for emitting McpConnect events.
38    telemetry: Option<Arc<atomcode_telemetry::Telemetry>>,
39}
40
41impl McpRegistry {
42    /// Create a new empty registry.
43    pub fn new() -> Self {
44        Self {
45            servers: Arc::new(RwLock::new(BTreeMap::new())),
46            server_timeouts_ms: Arc::new(RwLock::new(BTreeMap::new())),
47            connect_events: None,
48            initial_ready: Arc::new(tokio::sync::Notify::new()),
49            telemetry: None,
50        }
51    }
52
53    /// Set the telemetry handle for emitting McpConnect events.
54    pub fn with_telemetry(mut self, tel: Arc<atomcode_telemetry::Telemetry>) -> Self {
55        self.telemetry = Some(tel);
56        self
57    }
58
59    /// Create a registry with a channel for connection events.
60    pub fn with_event_channel() -> (Self, mpsc::UnboundedReceiver<McpConnectEvent>) {
61        let (tx, rx) = mpsc::unbounded_channel();
62        (
63            Self {
64                servers: Arc::new(RwLock::new(BTreeMap::new())),
65                server_timeouts_ms: Arc::new(RwLock::new(BTreeMap::new())),
66                connect_events: Some(tx),
67                initial_ready: Arc::new(tokio::sync::Notify::new()),
68                telemetry: None,
69            },
70            rx,
71        )
72    }
73
74    /// Get a clone of the event sender, if configured.
75    pub fn event_sender(&self) -> Option<mpsc::UnboundedSender<McpConnectEvent>> {
76        self.connect_events.clone()
77    }
78
79    /// Load MCP configuration and start connecting to servers in the background.
80    /// Returns immediately with an empty registry; servers are added as they connect.
81    /// Connection status events are sent through the internal channel if configured.
82    pub fn from_config_background(project_dir: &std::path::Path) -> Self {
83        Self::from_config_background_with_events(project_dir, None)
84    }
85
86    /// Load MCP configuration and start connecting to servers in the background,
87    /// with an external event channel for TUI status display.
88    pub fn from_config_background_with_events(
89        project_dir: &std::path::Path,
90        event_tx: Option<mpsc::UnboundedSender<McpConnectEvent>>,
91    ) -> Self {
92        let mut registry = Self::new();
93        // Merge external channel with internal one
94        let combined_tx = event_tx.or(registry.connect_events.clone());
95        registry.connect_events = combined_tx.clone();
96
97        let configs = match load_mcp_config(project_dir) {
98            Ok(c) => c,
99            Err(e) => {
100                if let Some(tx) = &combined_tx {
101                    let _ = tx.send(McpConnectEvent::Failed {
102                        name: "config".to_string(),
103                        error: format!("Failed to load config: {}", e),
104                    });
105                }
106                return registry;
107            }
108        };
109
110        if !configs.is_empty() {
111            let servers = registry.servers.clone();
112            let server_timeouts_ms = registry.server_timeouts_ms.clone();
113            let initial_ready = registry.initial_ready.clone();
114            let telemetry = registry.telemetry.clone();
115            tokio::spawn(async move {
116                // Connect servers in parallel
117                let tasks: Vec<_> = configs
118                    .into_iter()
119                    .map(|config| {
120                        let servers = servers.clone();
121                        let server_timeouts_ms = server_timeouts_ms.clone();
122                        let tx = combined_tx.clone();
123                        let telemetry = telemetry.clone();
124                        async move {
125                            let name = config.name.clone();
126                            let timeout_ms = config.timeout_ms();
127                            let config_source = config.source;
128                            let transport = match &config.config {
129                                super::config::McpTransportConfig::Stdio { .. } => McpTransport::Stdio,
130                                super::config::McpTransportConfig::Http { .. } => McpTransport::StreamableHttp,
131                            };
132                            let start = std::time::Instant::now();
133                            let mut client: Box<dyn McpClient> = match &config.config {
134                                super::config::McpTransportConfig::Stdio {
135                                    command,
136                                    args,
137                                    env,
138                                    timeout_ms,
139                                } => Box::new(StdioClient::new(
140                                    name.clone(),
141                                    command.clone(),
142                                    args.clone(),
143                                    env.clone(),
144                                    *timeout_ms,
145                                )),
146                                super::config::McpTransportConfig::Http {
147                                    url,
148                                    headers,
149                                    auth,
150                                    timeout_ms,
151                                } => Box::new(HttpClient::new(
152                                    name.clone(),
153                                    url.clone(),
154                                    headers.clone(),
155                                    auth.clone(),
156                                    *timeout_ms,
157                                )),
158                            };
159
160                            match client.initialize().await {
161                                Ok(_result) => {
162                                    let duration_ms = start.elapsed().as_millis() as u32;
163                                    let mut servers = servers.write().await;
164                                    servers.insert(name.clone(), Arc::from(client));
165                                    drop(servers);
166                                    let mut timeouts = server_timeouts_ms.write().await;
167                                    timeouts.insert(name.clone(), timeout_ms);
168                                    if let Some(tx) = tx {
169                                        let _ = tx.send(McpConnectEvent::Connected {
170                                            name: name.clone(),
171                                        });
172                                    }
173                                    if let Some(tel) = &telemetry {
174                                        tel.track(TelemetryEvent::McpConnect {
175                                            server_name: name.clone(),
176                                            transport,
177                                            success: true,
178                                            duration_ms: Some(duration_ms),
179                                            error_kind: None,
180                                            error_data: Some(serde_json::json!({
181                                                "server_name": name,
182                                                "transport": match transport { McpTransport::Stdio => "stdio", McpTransport::Sse => "sse", McpTransport::StreamableHttp => "streamable_http" },
183                                                "duration_ms": duration_ms,
184                                                "tool_count": 0, // will be populated when tools are listed
185                                                "config_source": config_source.as_str(),
186                                            }).to_string()),
187                                        });
188                                    }
189                                }
190                                Err(e) => {
191                                    let duration_ms = start.elapsed().as_millis() as u32;
192                                    let error_str = format!("{}", e);
193                                    if let Some(tx) = tx {
194                                        let _ = tx.send(McpConnectEvent::Failed {
195                                            name: name.clone(),
196                                            error: error_str.clone(),
197                                        });
198                                    }
199                                    if let Some(tel) = &telemetry {
200                                        let error_kind = classify_mcp_error(&error_str);
201                                        tel.track(TelemetryEvent::McpConnect {
202                                            server_name: name.clone(),
203                                            transport,
204                                            success: false,
205                                            duration_ms: Some(duration_ms),
206                                            error_kind: Some(error_kind),
207                                            error_data: Some(serde_json::json!({
208                                                "server_name": name,
209                                                "transport": match transport { McpTransport::Stdio => "stdio", McpTransport::Sse => "sse", McpTransport::StreamableHttp => "streamable_http" },
210                                                "duration_ms": duration_ms,
211                                                "message": atomcode_telemetry::scrub::truncate_head(&error_str, 200),
212                                                "config_source": config_source.as_str(),
213                                            }).to_string()),
214                                        });
215                                    }
216                                }
217                            }
218                        }
219                    })
220                    .collect();
221
222                // Wait for all connections to complete (each has its own timeout)
223                futures::future::join_all(tasks).await;
224                // Signal that initial connections are done
225                initial_ready.notify_waiters();
226            });
227        } else {
228            // No servers configured — signal immediately
229            registry.initial_ready.notify_waiters();
230        }
231
232        registry
233    }
234
235    /// Load MCP configuration and connect to all servers (blocking).
236    /// Prefer `from_config_background` for non-blocking startup.
237    pub async fn from_config(project_dir: &std::path::Path) -> Self {
238        let registry = Self::new();
239
240        let configs = match load_mcp_config(project_dir) {
241            Ok(c) => c,
242            Err(e) => {
243                eprintln!("[mcp] Failed to load config: {}", e);
244                return registry;
245            }
246        };
247
248        for config in configs {
249            if let Err(e) = registry.add_server(config).await {
250                eprintln!("[mcp] Failed to connect server: {}", e);
251            }
252        }
253
254        registry
255    }
256
257    /// Add a server to the registry.
258    pub async fn add_server(&self, config: McpServerConfig) -> Result<()> {
259        let mut client: Box<dyn McpClient> = match &config.config {
260            super::config::McpTransportConfig::Stdio {
261                command,
262                args,
263                env,
264                timeout_ms,
265            } => Box::new(StdioClient::new(
266                config.name.clone(),
267                command.clone(),
268                args.clone(),
269                env.clone(),
270                *timeout_ms,
271            )),
272            super::config::McpTransportConfig::Http {
273                url,
274                headers,
275                auth,
276                timeout_ms,
277            } => Box::new(HttpClient::new(
278                config.name.clone(),
279                url.clone(),
280                headers.clone(),
281                auth.clone(),
282                *timeout_ms,
283            )),
284        };
285
286        client.initialize().await?;
287
288        let mut servers = self.servers.write().await;
289        servers.insert(config.name.clone(), Arc::from(client));
290        drop(servers);
291        let mut timeouts = self.server_timeouts_ms.write().await;
292        timeouts.insert(config.name.clone(), config.timeout_ms());
293
294        Ok(())
295    }
296
297    /// Timeout budget for a slow tools/list operation on a connected server.
298    ///
299    /// The transport already has its own request timeout. This outer budget adds
300    /// a small grace period so TUI background tasks do not cancel a request right
301    /// before the transport timeout/error can surface.
302    pub async fn list_tools_timeout(&self, server_name: &str) -> Duration {
303        let configured_ms = {
304            let timeouts = self.server_timeouts_ms.read().await;
305            timeouts.get(server_name).copied().unwrap_or(30_000)
306        };
307        Duration::from_millis(configured_ms.saturating_add(5_000))
308    }
309
310    /// Get all available tools from all connected servers.
311    pub async fn list_all_tools(&self) -> Vec<McpToolInfo> {
312        // Never hold the registry lock across an .await: list_tools can be slow and
313        // status/reload should remain responsive.
314        let server_snapshot: Vec<(String, Arc<dyn McpClient>)> = {
315            let servers = self.servers.read().await;
316            servers
317                .iter()
318                .map(|(name, client)| (name.clone(), Arc::clone(client)))
319                .collect()
320        };
321        let mut all_tools = Vec::new();
322
323        for (server_name, client) in server_snapshot {
324            match client.list_tools().await {
325                Ok(result) => {
326                    for tool in result.tools {
327                        all_tools.push(McpToolInfo {
328                            server_name: server_name.clone(),
329                            tool_name: tool.name,
330                            description: tool.description,
331                            input_schema: tool.input_schema,
332                        });
333                    }
334                }
335                Err(e) => {
336                    if let Some(tx) = &self.connect_events {
337                        let _ = tx.send(McpConnectEvent::Warning {
338                            name: server_name.clone(),
339                            message: format!("tools/list failed: {}", e),
340                        });
341                    } else {
342                        eprintln!("[mcp] Failed to list tools from {}: {}", server_name, e);
343                    }
344                }
345            }
346        }
347
348        all_tools
349    }
350
351    /// Get tools from a single connected server.
352    pub async fn list_tools_for_server(&self, server_name: &str) -> Vec<McpToolInfo> {
353        let client = {
354            let servers = self.servers.read().await;
355            servers.get(server_name).map(Arc::clone)
356        };
357        let Some(client) = client else {
358            if let Some(tx) = &self.connect_events {
359                let _ = tx.send(McpConnectEvent::Warning {
360                    name: server_name.to_string(),
361                    message: "tools/list skipped: server not found".to_string(),
362                });
363            }
364            return Vec::new();
365        };
366
367        match client.list_tools().await {
368            Ok(result) => result
369                .tools
370                .into_iter()
371                .map(|tool| McpToolInfo {
372                    server_name: server_name.to_string(),
373                    tool_name: tool.name,
374                    description: tool.description,
375                    input_schema: tool.input_schema,
376                })
377                .collect(),
378            Err(e) => {
379                if let Some(tx) = &self.connect_events {
380                    let _ = tx.send(McpConnectEvent::Warning {
381                        name: server_name.to_string(),
382                        message: format!("tools/list failed: {}", e),
383                    });
384                } else {
385                    eprintln!("[mcp] Failed to list tools from {}: {}", server_name, e);
386                }
387                Vec::new()
388            }
389        }
390    }
391
392    /// Call a tool on a specific server.
393    pub async fn call_tool(
394        &self,
395        server_name: &str,
396        tool_name: &str,
397        arguments: serde_json::Value,
398    ) -> Result<String> {
399        let servers = self.servers.read().await;
400        let client = servers
401            .get(server_name)
402            .ok_or_else(|| anyhow::anyhow!("MCP server '{}' not found", server_name))?;
403
404        let result = client.call_tool(tool_name, arguments).await?;
405
406        // Extract text from content blocks
407        let output = result
408            .content
409            .into_iter()
410            .filter_map(|c| match c {
411                super::types::ContentBlock::Text { text } => Some(text),
412                _ => None,
413            })
414            .collect::<Vec<_>>()
415            .join("\n");
416
417        if result.is_error {
418            anyhow::bail!("MCP tool error: {}", output);
419        }
420
421        Ok(output)
422    }
423
424    /// Get the status of all servers.
425    pub async fn server_statuses(&self) -> Vec<(String, ServerStatus)> {
426        let servers = self.servers.read().await;
427        servers
428            .iter()
429            .map(|(name, client)| (name.clone(), client.status()))
430            .collect()
431    }
432
433    /// Wait for initial background connections to complete (or timeout).
434    /// Returns immediately if no background connections are pending.
435    pub async fn wait_for_initial_connections(&self, timeout: Duration) {
436        let _ = tokio::time::timeout(timeout, self.initial_ready.notified()).await;
437    }
438
439    /// Get an Arc clone for sharing across threads.
440    pub fn share(&self) -> Arc<Self> {
441        Arc::new(Self {
442            servers: self.servers.clone(),
443            server_timeouts_ms: self.server_timeouts_ms.clone(),
444            connect_events: self.connect_events.clone(),
445            initial_ready: self.initial_ready.clone(),
446            telemetry: self.telemetry.clone(),
447        })
448    }
449}
450
451/// Classify an MCP connection error string into a telemetry `McpErrorKind`.
452fn classify_mcp_error(error: &str) -> McpErrorKind {
453    let e = error.to_lowercase();
454    if e.contains("connection refused") || e.contains("dns") || e.contains("network") {
455        McpErrorKind::NetworkError
456    } else if e.contains("401") || e.contains("403") || e.contains("unauthorized") || e.contains("oauth") {
457        McpErrorKind::AuthError
458    } else if e.contains("not found") || e.contains("no such") || e.contains("path") || e.contains("spawn") {
459        McpErrorKind::ExecutionFailed
460    } else if e.contains("timeout") || e.contains("timed out") {
461        McpErrorKind::Timeout
462    } else if e.contains("server") || e.contains("-326") || e.contains("mcp error") {
463        McpErrorKind::ServerError
464    } else {
465        McpErrorKind::Other
466    }
467}
468
469impl McpServerConfig {
470    fn timeout_ms(&self) -> u64 {
471        match &self.config {
472            super::config::McpTransportConfig::Stdio { timeout_ms, .. }
473            | super::config::McpTransportConfig::Http { timeout_ms, .. } => {
474                timeout_ms.unwrap_or(30_000)
475            }
476        }
477    }
478}
479
480impl Default for McpRegistry {
481    fn default() -> Self {
482        Self::new()
483    }
484}