1use std::collections::BTreeMap;
4use std::sync::Arc;
5use std::time::Duration;
6
7use anyhow::Result;
8use tokio::sync::{mpsc, RwLock};
9
10use atomcode_telemetry::{Event as TelemetryEvent, McpErrorKind, McpTransport};
11
12use super::client::{McpClient, McpToolInfo};
13use super::config::{load_mcp_config, McpServerConfig};
14use super::transport_http::HttpClient;
15use super::transport_stdio::StdioClient;
16use super::types::ServerStatus;
17
/// Status notification emitted while connecting to MCP servers or listing their tools.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum McpConnectEvent {
    /// The named server finished initialization and was registered.
    Connected { name: String },
    /// Connecting to the named server failed. A config-load failure is reported
    /// with `name == "config"`.
    Failed { name: String, error: String },
    /// A non-fatal problem for the named server, e.g. a failed or skipped `tools/list`.
    Warning { name: String, message: String },
}
28
29pub struct McpRegistry {
31 servers: Arc<RwLock<BTreeMap<String, Arc<dyn McpClient>>>>,
32 server_timeouts_ms: Arc<RwLock<BTreeMap<String, u64>>>,
33 connect_events: Option<mpsc::UnboundedSender<McpConnectEvent>>,
35 initial_ready: Arc<tokio::sync::Notify>,
37 telemetry: Option<Arc<atomcode_telemetry::Telemetry>>,
39}
40
41impl McpRegistry {
42 pub fn new() -> Self {
44 Self {
45 servers: Arc::new(RwLock::new(BTreeMap::new())),
46 server_timeouts_ms: Arc::new(RwLock::new(BTreeMap::new())),
47 connect_events: None,
48 initial_ready: Arc::new(tokio::sync::Notify::new()),
49 telemetry: None,
50 }
51 }
52
53 pub fn with_telemetry(mut self, tel: Arc<atomcode_telemetry::Telemetry>) -> Self {
55 self.telemetry = Some(tel);
56 self
57 }
58
59 pub fn with_event_channel() -> (Self, mpsc::UnboundedReceiver<McpConnectEvent>) {
61 let (tx, rx) = mpsc::unbounded_channel();
62 (
63 Self {
64 servers: Arc::new(RwLock::new(BTreeMap::new())),
65 server_timeouts_ms: Arc::new(RwLock::new(BTreeMap::new())),
66 connect_events: Some(tx),
67 initial_ready: Arc::new(tokio::sync::Notify::new()),
68 telemetry: None,
69 },
70 rx,
71 )
72 }
73
74 pub fn event_sender(&self) -> Option<mpsc::UnboundedSender<McpConnectEvent>> {
76 self.connect_events.clone()
77 }
78
79 pub fn from_config_background(project_dir: &std::path::Path) -> Self {
83 Self::from_config_background_with_events(project_dir, None)
84 }
85
86 pub fn from_config_background_with_events(
89 project_dir: &std::path::Path,
90 event_tx: Option<mpsc::UnboundedSender<McpConnectEvent>>,
91 ) -> Self {
92 let mut registry = Self::new();
93 let combined_tx = event_tx.or(registry.connect_events.clone());
95 registry.connect_events = combined_tx.clone();
96
97 let configs = match load_mcp_config(project_dir) {
98 Ok(c) => c,
99 Err(e) => {
100 if let Some(tx) = &combined_tx {
101 let _ = tx.send(McpConnectEvent::Failed {
102 name: "config".to_string(),
103 error: format!("Failed to load config: {}", e),
104 });
105 }
106 return registry;
107 }
108 };
109
110 if !configs.is_empty() {
111 let servers = registry.servers.clone();
112 let server_timeouts_ms = registry.server_timeouts_ms.clone();
113 let initial_ready = registry.initial_ready.clone();
114 let telemetry = registry.telemetry.clone();
115 tokio::spawn(async move {
116 let tasks: Vec<_> = configs
118 .into_iter()
119 .map(|config| {
120 let servers = servers.clone();
121 let server_timeouts_ms = server_timeouts_ms.clone();
122 let tx = combined_tx.clone();
123 let telemetry = telemetry.clone();
124 async move {
125 let name = config.name.clone();
126 let timeout_ms = config.timeout_ms();
127 let config_source = config.source;
128 let transport = match &config.config {
129 super::config::McpTransportConfig::Stdio { .. } => McpTransport::Stdio,
130 super::config::McpTransportConfig::Http { .. } => McpTransport::StreamableHttp,
131 };
132 let start = std::time::Instant::now();
133 let mut client: Box<dyn McpClient> = match &config.config {
134 super::config::McpTransportConfig::Stdio {
135 command,
136 args,
137 env,
138 timeout_ms,
139 } => Box::new(StdioClient::new(
140 name.clone(),
141 command.clone(),
142 args.clone(),
143 env.clone(),
144 *timeout_ms,
145 )),
146 super::config::McpTransportConfig::Http {
147 url,
148 headers,
149 auth,
150 timeout_ms,
151 } => Box::new(HttpClient::new(
152 name.clone(),
153 url.clone(),
154 headers.clone(),
155 auth.clone(),
156 *timeout_ms,
157 )),
158 };
159
160 match client.initialize().await {
161 Ok(_result) => {
162 let duration_ms = start.elapsed().as_millis() as u32;
163 let mut servers = servers.write().await;
164 servers.insert(name.clone(), Arc::from(client));
165 drop(servers);
166 let mut timeouts = server_timeouts_ms.write().await;
167 timeouts.insert(name.clone(), timeout_ms);
168 if let Some(tx) = tx {
169 let _ = tx.send(McpConnectEvent::Connected {
170 name: name.clone(),
171 });
172 }
173 if let Some(tel) = &telemetry {
174 tel.track(TelemetryEvent::McpConnect {
175 server_name: name.clone(),
176 transport,
177 success: true,
178 duration_ms: Some(duration_ms),
179 error_kind: None,
180 error_data: Some(serde_json::json!({
181 "server_name": name,
182 "transport": match transport { McpTransport::Stdio => "stdio", McpTransport::Sse => "sse", McpTransport::StreamableHttp => "streamable_http" },
183 "duration_ms": duration_ms,
184 "tool_count": 0, "config_source": config_source.as_str(),
186 }).to_string()),
187 });
188 }
189 }
190 Err(e) => {
191 let duration_ms = start.elapsed().as_millis() as u32;
192 let error_str = format!("{}", e);
193 if let Some(tx) = tx {
194 let _ = tx.send(McpConnectEvent::Failed {
195 name: name.clone(),
196 error: error_str.clone(),
197 });
198 }
199 if let Some(tel) = &telemetry {
200 let error_kind = classify_mcp_error(&error_str);
201 tel.track(TelemetryEvent::McpConnect {
202 server_name: name.clone(),
203 transport,
204 success: false,
205 duration_ms: Some(duration_ms),
206 error_kind: Some(error_kind),
207 error_data: Some(serde_json::json!({
208 "server_name": name,
209 "transport": match transport { McpTransport::Stdio => "stdio", McpTransport::Sse => "sse", McpTransport::StreamableHttp => "streamable_http" },
210 "duration_ms": duration_ms,
211 "message": atomcode_telemetry::scrub::truncate_head(&error_str, 200),
212 "config_source": config_source.as_str(),
213 }).to_string()),
214 });
215 }
216 }
217 }
218 }
219 })
220 .collect();
221
222 futures::future::join_all(tasks).await;
224 initial_ready.notify_waiters();
226 });
227 } else {
228 registry.initial_ready.notify_waiters();
230 }
231
232 registry
233 }
234
235 pub async fn from_config(project_dir: &std::path::Path) -> Self {
238 let registry = Self::new();
239
240 let configs = match load_mcp_config(project_dir) {
241 Ok(c) => c,
242 Err(e) => {
243 eprintln!("[mcp] Failed to load config: {}", e);
244 return registry;
245 }
246 };
247
248 for config in configs {
249 if let Err(e) = registry.add_server(config).await {
250 eprintln!("[mcp] Failed to connect server: {}", e);
251 }
252 }
253
254 registry
255 }
256
257 pub async fn add_server(&self, config: McpServerConfig) -> Result<()> {
259 let mut client: Box<dyn McpClient> = match &config.config {
260 super::config::McpTransportConfig::Stdio {
261 command,
262 args,
263 env,
264 timeout_ms,
265 } => Box::new(StdioClient::new(
266 config.name.clone(),
267 command.clone(),
268 args.clone(),
269 env.clone(),
270 *timeout_ms,
271 )),
272 super::config::McpTransportConfig::Http {
273 url,
274 headers,
275 auth,
276 timeout_ms,
277 } => Box::new(HttpClient::new(
278 config.name.clone(),
279 url.clone(),
280 headers.clone(),
281 auth.clone(),
282 *timeout_ms,
283 )),
284 };
285
286 client.initialize().await?;
287
288 let mut servers = self.servers.write().await;
289 servers.insert(config.name.clone(), Arc::from(client));
290 drop(servers);
291 let mut timeouts = self.server_timeouts_ms.write().await;
292 timeouts.insert(config.name.clone(), config.timeout_ms());
293
294 Ok(())
295 }
296
297 pub async fn list_tools_timeout(&self, server_name: &str) -> Duration {
303 let configured_ms = {
304 let timeouts = self.server_timeouts_ms.read().await;
305 timeouts.get(server_name).copied().unwrap_or(30_000)
306 };
307 Duration::from_millis(configured_ms.saturating_add(5_000))
308 }
309
310 pub async fn list_all_tools(&self) -> Vec<McpToolInfo> {
312 let server_snapshot: Vec<(String, Arc<dyn McpClient>)> = {
315 let servers = self.servers.read().await;
316 servers
317 .iter()
318 .map(|(name, client)| (name.clone(), Arc::clone(client)))
319 .collect()
320 };
321 let mut all_tools = Vec::new();
322
323 for (server_name, client) in server_snapshot {
324 match client.list_tools().await {
325 Ok(result) => {
326 for tool in result.tools {
327 all_tools.push(McpToolInfo {
328 server_name: server_name.clone(),
329 tool_name: tool.name,
330 description: tool.description,
331 input_schema: tool.input_schema,
332 });
333 }
334 }
335 Err(e) => {
336 if let Some(tx) = &self.connect_events {
337 let _ = tx.send(McpConnectEvent::Warning {
338 name: server_name.clone(),
339 message: format!("tools/list failed: {}", e),
340 });
341 } else {
342 eprintln!("[mcp] Failed to list tools from {}: {}", server_name, e);
343 }
344 }
345 }
346 }
347
348 all_tools
349 }
350
351 pub async fn list_tools_for_server(&self, server_name: &str) -> Vec<McpToolInfo> {
353 let client = {
354 let servers = self.servers.read().await;
355 servers.get(server_name).map(Arc::clone)
356 };
357 let Some(client) = client else {
358 if let Some(tx) = &self.connect_events {
359 let _ = tx.send(McpConnectEvent::Warning {
360 name: server_name.to_string(),
361 message: "tools/list skipped: server not found".to_string(),
362 });
363 }
364 return Vec::new();
365 };
366
367 match client.list_tools().await {
368 Ok(result) => result
369 .tools
370 .into_iter()
371 .map(|tool| McpToolInfo {
372 server_name: server_name.to_string(),
373 tool_name: tool.name,
374 description: tool.description,
375 input_schema: tool.input_schema,
376 })
377 .collect(),
378 Err(e) => {
379 if let Some(tx) = &self.connect_events {
380 let _ = tx.send(McpConnectEvent::Warning {
381 name: server_name.to_string(),
382 message: format!("tools/list failed: {}", e),
383 });
384 } else {
385 eprintln!("[mcp] Failed to list tools from {}: {}", server_name, e);
386 }
387 Vec::new()
388 }
389 }
390 }
391
392 pub async fn call_tool(
394 &self,
395 server_name: &str,
396 tool_name: &str,
397 arguments: serde_json::Value,
398 ) -> Result<String> {
399 let servers = self.servers.read().await;
400 let client = servers
401 .get(server_name)
402 .ok_or_else(|| anyhow::anyhow!("MCP server '{}' not found", server_name))?;
403
404 let result = client.call_tool(tool_name, arguments).await?;
405
406 let output = result
408 .content
409 .into_iter()
410 .filter_map(|c| match c {
411 super::types::ContentBlock::Text { text } => Some(text),
412 _ => None,
413 })
414 .collect::<Vec<_>>()
415 .join("\n");
416
417 if result.is_error {
418 anyhow::bail!("MCP tool error: {}", output);
419 }
420
421 Ok(output)
422 }
423
424 pub async fn server_statuses(&self) -> Vec<(String, ServerStatus)> {
426 let servers = self.servers.read().await;
427 servers
428 .iter()
429 .map(|(name, client)| (name.clone(), client.status()))
430 .collect()
431 }
432
433 pub async fn wait_for_initial_connections(&self, timeout: Duration) {
436 let _ = tokio::time::timeout(timeout, self.initial_ready.notified()).await;
437 }
438
439 pub fn share(&self) -> Arc<Self> {
441 Arc::new(Self {
442 servers: self.servers.clone(),
443 server_timeouts_ms: self.server_timeouts_ms.clone(),
444 connect_events: self.connect_events.clone(),
445 initial_ready: self.initial_ready.clone(),
446 telemetry: self.telemetry.clone(),
447 })
448 }
449}
450
451fn classify_mcp_error(error: &str) -> McpErrorKind {
453 let e = error.to_lowercase();
454 if e.contains("connection refused") || e.contains("dns") || e.contains("network") {
455 McpErrorKind::NetworkError
456 } else if e.contains("401") || e.contains("403") || e.contains("unauthorized") || e.contains("oauth") {
457 McpErrorKind::AuthError
458 } else if e.contains("not found") || e.contains("no such") || e.contains("path") || e.contains("spawn") {
459 McpErrorKind::ExecutionFailed
460 } else if e.contains("timeout") || e.contains("timed out") {
461 McpErrorKind::Timeout
462 } else if e.contains("server") || e.contains("-326") || e.contains("mcp error") {
463 McpErrorKind::ServerError
464 } else {
465 McpErrorKind::Other
466 }
467}
468
469impl McpServerConfig {
470 fn timeout_ms(&self) -> u64 {
471 match &self.config {
472 super::config::McpTransportConfig::Stdio { timeout_ms, .. }
473 | super::config::McpTransportConfig::Http { timeout_ms, .. } => {
474 timeout_ms.unwrap_or(30_000)
475 }
476 }
477 }
478}
479
480impl Default for McpRegistry {
481 fn default() -> Self {
482 Self::new()
483 }
484}