1use super::server::{McpServerTrait, ServerStatus};
6#[cfg(feature = "daemon")]
7use super::tools::ToolResult;
8use super::tools::{GetPromptResult, Prompt, Tool};
9use super::types::*;
10use crate::error::{Error, Result};
11use chrono::{DateTime, Utc};
12use serde::{Deserialize, Serialize};
13use std::collections::HashMap;
14use std::sync::Arc;
15use tokio::sync::RwLock;
16use uuid::Uuid;
17
/// Coarse health classification attached to a [`HealthCheck`].
///
/// Serialized in `snake_case` (e.g. `"healthy"`, `"unhealthy"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum HealthStatus {
    /// The server is considered fully operational.
    Healthy,
    /// The server reports a degraded state.
    Degraded,
    /// The server reports an unhealthy or failed state.
    Unhealthy,
    /// NOTE(review): not produced anywhere in this file; presumably marks a
    /// check that is in progress — confirm with callers.
    Checking,
    /// The server's state could not be mapped to any of the above.
    Unknown,
}
33
/// Outcome of one health probe against one registered server.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HealthCheck {
    /// Registry id of the probed server.
    pub server_id: Uuid,
    /// Server name taken from its registration (empty if the registration was missing).
    pub server_name: String,
    /// Health classification derived from the probe and the server's reported status.
    pub status: HealthStatus,
    /// UTC timestamp taken when the check result was assembled.
    pub checked_at: DateTime<Utc>,
    /// Duration of the health probe in milliseconds.
    pub response_time_ms: Option<f64>,
    /// Failure description, or `None` when the probe passed.
    pub error: Option<String>,
}
50
/// Metadata the registry keeps for each registered server.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ServerRegistration {
    /// Registry-assigned id (a random v4 UUID generated at registration).
    pub id: Uuid,
    /// Copy of `info.name`, kept for quick lookup.
    pub name: String,
    /// Server info captured once at registration time.
    pub info: ServerInfo,
    /// Capabilities captured once at registration time.
    pub capabilities: ServerCapabilities,
    /// When the server was registered (UTC).
    pub registered_at: DateTime<Utc>,
    /// Most recent health check result, if any check has run.
    pub last_health_check: Option<HealthCheck>,
    /// Free-form labels usable with `McpRegistry::find_servers_by_tag`.
    pub tags: Vec<String>,
}
69
/// Registry of MCP servers keyed by a registry-assigned [`Uuid`], with an
/// optional background health-monitoring task.
///
/// Code that takes both maps' locks acquires `servers` before `registrations`.
pub struct McpRegistry {
    /// Live server handles, keyed by registry id.
    servers: Arc<RwLock<HashMap<Uuid, Arc<dyn McpServerTrait>>>>,
    /// Registration metadata, keyed by the same ids as `servers`.
    registrations: Arc<RwLock<HashMap<Uuid, ServerRegistration>>>,
    /// Period of the background health-monitoring loop, in seconds.
    health_check_interval_secs: u64,
    /// Handle of the spawned monitoring task, if one is running.
    health_check_handle: Arc<RwLock<Option<tokio::task::JoinHandle<()>>>>,
}
81
82impl McpRegistry {
83 pub fn new() -> Self {
85 Self {
86 servers: Arc::new(RwLock::new(HashMap::new())),
87 registrations: Arc::new(RwLock::new(HashMap::new())),
88 health_check_interval_secs: crate::mcp::DEFAULT_HEALTH_CHECK_INTERVAL_SECS,
89 health_check_handle: Arc::new(RwLock::new(None)),
90 }
91 }
92
93 pub fn with_health_check_interval(interval_secs: u64) -> Self {
95 Self {
96 servers: Arc::new(RwLock::new(HashMap::new())),
97 registrations: Arc::new(RwLock::new(HashMap::new())),
98 health_check_interval_secs: interval_secs,
99 health_check_handle: Arc::new(RwLock::new(None)),
100 }
101 }
102
103 pub async fn register_server(
105 &self,
106 server: Arc<dyn McpServerTrait>,
107 tags: Vec<String>,
108 ) -> Result<Uuid> {
109 let info = server.server_info().await;
110 let capabilities = server.capabilities().await;
111
112 let registration = ServerRegistration {
113 id: Uuid::new_v4(),
114 name: info.name.clone(),
115 info,
116 capabilities,
117 registered_at: Utc::now(),
118 last_health_check: None,
119 tags,
120 };
121
122 let id = registration.id;
123
124 let mut servers = self.servers.write().await;
125 let mut regs = self.registrations.write().await;
126
127 servers.insert(id, server);
128 regs.insert(id, registration);
129
130 Ok(id)
131 }
132
133 pub async fn unregister_server(&self, id: Uuid) -> Result<()> {
135 let mut servers = self.servers.write().await;
136 let mut regs = self.registrations.write().await;
137
138 if let Some(server) = servers.remove(&id) {
139 drop(server);
143 }
144
145 regs.remove(&id);
146
147 Ok(())
148 }
149
150 pub async fn get_server(&self, id: Uuid) -> Option<Arc<dyn McpServerTrait>> {
152 let servers = self.servers.read().await;
153 servers.get(&id).cloned()
154 }
155
156 pub async fn list_servers(&self) -> Vec<ServerRegistration> {
158 let regs = self.registrations.read().await;
159 regs.values().cloned().collect()
160 }
161
162 pub async fn find_servers_by_tag(&self, tag: &str) -> Vec<ServerRegistration> {
164 let regs = self.registrations.read().await;
165 regs.values()
166 .filter(|r| r.tags.iter().any(|t| t == tag))
167 .cloned()
168 .collect()
169 }
170
171 pub async fn list_all_tools(&self) -> Result<Vec<Tool>> {
173 let servers = self.servers.read().await;
174 let mut all_tools = Vec::new();
175
176 for (id, server) in servers.iter() {
177 let regs = self.registrations.read().await;
178 let server_name = regs.get(id).map(|r| r.name.clone()).unwrap_or_default();
179
180 let request = McpRequest::new(
182 RequestId::String(Uuid::new_v4().to_string()),
183 "tools/list",
184 None,
185 );
186
187 match server.send_request(request).await {
188 Ok(response) => {
189 if let Some(result) = response.result {
190 if let Ok(tools_response) =
191 serde_json::from_value::<ToolsListResponse>(result)
192 {
193 for mut tool in tools_response.tools {
194 tool.server_id = Some(*id);
195 tool.server_name = Some(server_name.clone());
196 all_tools.push(tool);
197 }
198 }
199 }
200 }
201 Err(_) => {
202 continue;
204 }
205 }
206 }
207
208 Ok(all_tools)
209 }
210
211 pub async fn list_all_prompts(&self) -> Result<Vec<Prompt>> {
213 let servers = self.servers.read().await;
214 let mut all_prompts = Vec::new();
215
216 for (_, server) in servers.iter() {
217 let request = McpRequest::new(
219 RequestId::String(Uuid::new_v4().to_string()),
220 "prompts/list",
221 None,
222 );
223
224 match server.send_request(request).await {
225 Ok(response) => {
226 if let Some(result) = response.result {
227 if let Ok(prompts_response) =
228 serde_json::from_value::<PromptsListResponse>(result)
229 {
230 all_prompts.extend(prompts_response.prompts);
231 }
232 }
233 }
234 Err(_) => {
235 continue;
236 }
237 }
238 }
239
240 Ok(all_prompts)
241 }
242
243 pub async fn get_prompt(
245 &self,
246 prompt_name: &str,
247 arguments: HashMap<String, String>,
248 server_id: Option<Uuid>,
249 ) -> Result<GetPromptResult> {
250 let servers = self.servers.read().await;
251
252 if let Some(id) = server_id {
254 if let Some(server) = servers.get(&id) {
255 return self
256 .get_prompt_from_server(server.clone(), prompt_name, arguments)
257 .await;
258 } else {
259 return Err(Error::NotFound {
260 resource: format!("Server {}", id),
261 });
262 }
263 }
264
265 for (_, server) in servers.iter() {
268 if let Ok(result) = self
269 .get_prompt_from_server(server.clone(), prompt_name, arguments.clone())
270 .await
271 {
272 return Ok(result);
273 }
274 }
275
276 Err(Error::NotFound {
277 resource: format!("Prompt {}", prompt_name),
278 })
279 }
280
281 async fn get_prompt_from_server(
282 &self,
283 server: Arc<dyn McpServerTrait>,
284 prompt_name: &str,
285 arguments: HashMap<String, String>,
286 ) -> Result<GetPromptResult> {
287 let params = serde_json::json!({
288 "name": prompt_name,
289 "arguments": arguments
290 });
291
292 let request = McpRequest::new(
293 RequestId::String(Uuid::new_v4().to_string()),
294 "prompts/get",
295 Some(params),
296 );
297
298 let response = server.send_request(request).await?;
299
300 if let Some(error) = response.error {
301 return Err(Error::Mcp(error.message));
302 }
303
304 if let Some(result) = response.result {
305 let prompt_result: GetPromptResult =
306 serde_json::from_value(result).map_err(Error::Json)?;
307 Ok(prompt_result)
308 } else {
309 Err(Error::Mcp("Empty response from server".to_string()))
310 }
311 }
312
313 pub async fn check_server_health(&self, id: Uuid) -> Result<HealthCheck> {
315 let server = self.get_server(id).await.ok_or_else(|| Error::NotFound {
316 resource: format!("Server {}", id),
317 })?;
318
319 let regs = self.registrations.read().await;
320 let server_name = regs.get(&id).map(|r| r.name.clone()).unwrap_or_default();
321 drop(regs);
322
323 let start = std::time::Instant::now();
324 let is_healthy = server.health_check().await?;
325 let response_time_ms = start.elapsed().as_millis() as f64;
326
327 let status = match server.status().await {
328 ServerStatus::Running => HealthStatus::Healthy,
329 ServerStatus::Degraded => HealthStatus::Degraded,
330 ServerStatus::Unhealthy | ServerStatus::Failed => HealthStatus::Unhealthy,
331 _ => HealthStatus::Unknown,
332 };
333
334 let health_check = HealthCheck {
335 server_id: id,
336 server_name,
337 status,
338 checked_at: Utc::now(),
339 response_time_ms: Some(response_time_ms),
340 error: if !is_healthy {
341 Some("Health check failed".to_string())
342 } else {
343 None
344 },
345 };
346
347 let mut regs = self.registrations.write().await;
349 if let Some(reg) = regs.get_mut(&id) {
350 reg.last_health_check = Some(health_check.clone());
351 }
352
353 Ok(health_check)
354 }
355
356 pub async fn check_all_health(&self) -> Vec<HealthCheck> {
358 let servers = self.servers.read().await;
359 let server_ids: Vec<Uuid> = servers.keys().copied().collect();
360 drop(servers);
361
362 let mut checks = Vec::new();
363 for id in server_ids {
364 if let Ok(check) = self.check_server_health(id).await {
365 checks.push(check);
366 }
367 }
368
369 checks
370 }
371
372 pub async fn start_health_monitoring(&self) {
374 let servers = self.servers.clone();
375 let registrations = self.registrations.clone();
376 let interval_secs = self.health_check_interval_secs;
377
378 let handle = tokio::spawn(async move {
379 let mut interval = tokio::time::interval(std::time::Duration::from_secs(interval_secs));
380
381 loop {
382 interval.tick().await;
383
384 let servers_guard = servers.read().await;
385 let server_ids: Vec<Uuid> = servers_guard.keys().copied().collect();
386 drop(servers_guard);
387
388 for id in server_ids {
389 let servers_guard = servers.read().await;
390 if let Some(server) = servers_guard.get(&id).cloned() {
391 drop(servers_guard);
392
393 let start = std::time::Instant::now();
394 let is_healthy = server.health_check().await.unwrap_or(false);
395 let response_time_ms = start.elapsed().as_millis() as f64;
396
397 let status = match server.status().await {
398 ServerStatus::Running => HealthStatus::Healthy,
399 ServerStatus::Degraded => HealthStatus::Degraded,
400 ServerStatus::Unhealthy | ServerStatus::Failed => {
401 HealthStatus::Unhealthy
402 }
403 _ => HealthStatus::Unknown,
404 };
405
406 let mut regs = registrations.write().await;
407 if let Some(reg) = regs.get_mut(&id) {
408 let health_check = HealthCheck {
409 server_id: id,
410 server_name: reg.name.clone(),
411 status,
412 checked_at: Utc::now(),
413 response_time_ms: Some(response_time_ms),
414 error: if !is_healthy {
415 Some("Health check failed".to_string())
416 } else {
417 None
418 },
419 };
420 reg.last_health_check = Some(health_check);
421 }
422 }
423 }
424 }
425 });
426
427 let mut handle_lock = self.health_check_handle.write().await;
428 *handle_lock = Some(handle);
429 }
430
431 pub async fn stop_health_monitoring(&self) {
433 let mut handle_lock = self.health_check_handle.write().await;
434 if let Some(handle) = handle_lock.take() {
435 handle.abort();
436 }
437 }
438
439 pub async fn statistics(&self) -> RegistryStatistics {
441 let regs = self.registrations.read().await;
442
443 let mut healthy = 0;
444 let mut degraded = 0;
445 let mut unhealthy = 0;
446 let mut unknown = 0;
447
448 for reg in regs.values() {
449 if let Some(check) = ®.last_health_check {
450 match check.status {
451 HealthStatus::Healthy => healthy += 1,
452 HealthStatus::Degraded => degraded += 1,
453 HealthStatus::Unhealthy => unhealthy += 1,
454 _ => unknown += 1,
455 }
456 } else {
457 unknown += 1;
458 }
459 }
460
461 RegistryStatistics {
462 total_servers: regs.len(),
463 healthy_servers: healthy,
464 degraded_servers: degraded,
465 unhealthy_servers: unhealthy,
466 unknown_servers: unknown,
467 }
468 }
469
470 #[cfg(feature = "daemon")]
477 pub async fn ping_server(&self, id: &Uuid) -> Result<bool> {
478 let server = self.get_server(*id).await.ok_or_else(|| Error::NotFound {
479 resource: format!("Server {}", id),
480 })?;
481
482 server.health_check().await
484 }
485
486 #[cfg(feature = "daemon")]
491 pub async fn reconnect_server(&self, id: &Uuid) -> Result<()> {
492 let server = self.get_server(*id).await.ok_or_else(|| Error::NotFound {
493 resource: format!("Server {}", id),
494 })?;
495
496 let healthy = server.health_check().await?;
498 if healthy {
499 Ok(())
500 } else {
501 Err(Error::network(
502 "Server reconnection failed - health check returned false",
503 ))
504 }
505 }
506
507 #[cfg(feature = "daemon")]
511 pub async fn call_tool_by_name(
512 &self,
513 tool_name: &str,
514 args: serde_json::Value,
515 ) -> Result<ToolResult> {
516 use std::collections::HashMap;
517
518 let servers = self.servers.read().await;
519
520 let args_map: HashMap<String, serde_json::Value> = match args {
522 serde_json::Value::Object(obj) => obj.into_iter().collect(),
523 _ => HashMap::new(),
524 };
525
526 for (_id, server) in servers.iter() {
527 let tools = server.list_tools().await;
529 if tools.iter().any(|t| t.name == tool_name) {
530 return server.call_tool(tool_name, args_map).await;
532 }
533 }
534
535 Err(Error::NotFound {
536 resource: format!("Tool {}", tool_name),
537 })
538 }
539
540 #[cfg(feature = "daemon")]
542 pub async fn disconnect_server(&self, id: &Uuid) -> Result<()> {
543 self.unregister_server(*id).await
545 }
546}
547
548impl Default for McpRegistry {
549 fn default() -> Self {
550 Self::new()
551 }
552}
553
/// Per-status server counts produced by `McpRegistry::statistics`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RegistryStatistics {
    /// Number of registered servers.
    pub total_servers: usize,
    /// Servers whose last recorded check was `Healthy`.
    pub healthy_servers: usize,
    /// Servers whose last recorded check was `Degraded`.
    pub degraded_servers: usize,
    /// Servers whose last recorded check was `Unhealthy`.
    pub unhealthy_servers: usize,
    /// Servers never checked, or whose last check was `Checking`/`Unknown`.
    pub unknown_servers: usize,
}
568
569#[derive(Debug, Deserialize)]
571struct ToolsListResponse {
572 tools: Vec<Tool>,
573 #[allow(dead_code)]
574 next_cursor: Option<String>,
575}
576
577#[derive(Debug, Deserialize)]
579struct PromptsListResponse {
580 prompts: Vec<Prompt>,
581 #[allow(dead_code)]
582 next_cursor: Option<String>,
583}
584
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_health_status() {
        let status = HealthStatus::Healthy;
        let json = serde_json::to_string(&status).unwrap();
        assert_eq!(json, "\"healthy\"");
    }

    /// Every variant serializes in snake_case and deserializes back to itself.
    #[test]
    fn test_health_status_round_trips_in_snake_case() {
        let cases = [
            (HealthStatus::Healthy, "\"healthy\""),
            (HealthStatus::Degraded, "\"degraded\""),
            (HealthStatus::Unhealthy, "\"unhealthy\""),
            (HealthStatus::Checking, "\"checking\""),
            (HealthStatus::Unknown, "\"unknown\""),
        ];
        for &(status, expected) in cases.iter() {
            assert_eq!(serde_json::to_string(&status).unwrap(), expected);
            assert_eq!(
                serde_json::from_str::<HealthStatus>(expected).unwrap(),
                status
            );
        }
    }

    #[tokio::test]
    async fn test_registry_creation() {
        let registry = McpRegistry::new();
        let stats = registry.statistics().await;
        assert_eq!(stats.total_servers, 0);
    }

    /// Queries against an empty registry return empty results rather than errors.
    #[tokio::test]
    async fn test_empty_registry_queries() {
        let registry = McpRegistry::default();

        let stats = registry.statistics().await;
        assert_eq!(stats.healthy_servers, 0);
        assert_eq!(stats.degraded_servers, 0);
        assert_eq!(stats.unhealthy_servers, 0);
        assert_eq!(stats.unknown_servers, 0);

        assert!(registry.list_servers().await.is_empty());
        assert!(registry.find_servers_by_tag("any").await.is_empty());
        assert!(registry.get_server(Uuid::new_v4()).await.is_none());
        assert!(registry.check_all_health().await.is_empty());
        assert!(matches!(registry.list_all_tools().await, Ok(tools) if tools.is_empty()));
        assert!(matches!(registry.list_all_prompts().await, Ok(prompts) if prompts.is_empty()));
    }

    /// Unknown ids are ignored on unregister but reported as errors on lookup paths.
    #[tokio::test]
    async fn test_unknown_server_lookups() {
        let registry = McpRegistry::new();
        let missing = Uuid::new_v4();

        assert!(registry.unregister_server(missing).await.is_ok());
        assert!(registry.check_server_health(missing).await.is_err());
        assert!(registry
            .get_prompt("p", HashMap::new(), Some(missing))
            .await
            .is_err());
        assert!(registry.get_prompt("p", HashMap::new(), None).await.is_err());
    }

    /// Starting monitoring twice and stopping repeatedly must not panic.
    #[tokio::test]
    async fn test_health_monitoring_restart_and_stop() {
        let registry = McpRegistry::with_health_check_interval(3600);
        registry.start_health_monitoring().await;
        registry.start_health_monitoring().await;
        registry.stop_health_monitoring().await;
        registry.stop_health_monitoring().await;
    }
}