1use crate::{McpError, McpServer, Result, ToolSchema};
7use async_trait::async_trait;
8use serde::{Deserialize, Serialize};
9use serde_json::Value;
10use std::collections::HashMap;
11use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
12use std::sync::Arc;
13use std::time::Instant;
14use tokio::sync::RwLock;
15
/// Shared map from server id to its [`ServerEntry`], guarded by an async read/write lock.
type ServerMap = Arc<RwLock<HashMap<String, ServerEntry>>>;
18
/// Strategy used by `McpRegistry::select_server` to pick one healthy server.
#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize, PartialEq, Eq)]
pub enum LoadBalanceStrategy {
    /// Cycle through the healthy servers using the registry's shared counter (the default).
    #[default]
    RoundRobin,
    /// Pick the healthy server with the lowest `active_connections` value.
    LeastConnections,
    /// Pick a healthy server at a pseudo-random index.
    Random,
    /// Cycle through the healthy servers, giving each a share proportional to its `weight`.
    WeightedRoundRobin,
}
32
/// Health state recorded for a registered server.
///
/// The selection and failover methods in this file only consider servers marked
/// `Healthy`; `Degraded` and `Unhealthy` entries are both skipped.
#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize, PartialEq, Eq)]
pub enum ServerHealth {
    /// Eligible for selection (the default for new entries).
    #[default]
    Healthy,
    /// Not eligible for selection; this file does not otherwise distinguish it from `Unhealthy`.
    Degraded,
    /// Not eligible for selection.
    Unhealthy,
}
41
/// A registered server together with its load-balancing metadata and counters.
///
/// Cloning an entry shares the same server and the same atomic counters,
/// because those fields are `Arc`s.
#[derive(Clone)]
pub struct ServerEntry {
    /// The server implementation that tool calls are dispatched to.
    pub server: Arc<Box<dyn McpServer>>,
    /// Relative share used by weighted round-robin selection (1 for new entries).
    pub weight: u32,
    /// Current health; only `Healthy` entries are selected.
    pub health: ServerHealth,
    /// Number of invocations currently in flight via the failover path.
    pub active_connections: Arc<AtomicU64>,
    /// Number of invocation attempts made via the failover path.
    pub request_count: Arc<AtomicU64>,
    /// Number of failed invocation attempts made via the failover path.
    pub error_count: Arc<AtomicU64>,
    /// Smoothed response time in milliseconds, updated by the failover path.
    pub avg_response_time_ms: Arc<AtomicU64>,
    /// When the health was last set through `set_server_health`, if ever.
    pub last_health_check: Option<Instant>,
    /// Optional group name used by group-based selection and listing.
    pub group: Option<String>,
    /// Tags used by tag-based selection and listing.
    pub tags: Vec<String>,
}
66
67impl ServerEntry {
68 fn new(server: Arc<Box<dyn McpServer>>) -> Self {
69 Self {
70 server,
71 weight: 1,
72 health: ServerHealth::Healthy,
73 active_connections: Arc::new(AtomicU64::new(0)),
74 request_count: Arc::new(AtomicU64::new(0)),
75 error_count: Arc::new(AtomicU64::new(0)),
76 avg_response_time_ms: Arc::new(AtomicU64::new(0)),
77 last_health_check: None,
78 group: None,
79 tags: Vec::new(),
80 }
81 }
82
83 fn with_weight(mut self, weight: u32) -> Self {
84 self.weight = weight;
85 self
86 }
87
88 fn with_group(mut self, group: String) -> Self {
89 self.group = Some(group);
90 self
91 }
92
93 fn with_tags(mut self, tags: Vec<String>) -> Self {
94 self.tags = tags;
95 self
96 }
97}
98
/// Registry of MCP servers with health tracking and load-balanced selection.
pub struct McpRegistry {
    /// Registered servers keyed by server id.
    servers: ServerMap,
    /// Counter shared by every round-robin style selection method.
    rr_counter: AtomicUsize,
    /// Strategy used by `select_server` when the caller passes `None`.
    default_strategy: LoadBalanceStrategy,
    /// Configured health-check interval in seconds; stored and reported by `get_config` only.
    health_check_interval_secs: u64,
    /// Configured unhealthy threshold; stored and reported by `get_config` only.
    unhealthy_threshold: u32,
    /// Configured recovery threshold; stored and reported by `get_config` only.
    recovery_threshold: u32,
}
114
115impl McpRegistry {
116 pub fn new() -> Self {
118 Self {
119 servers: Arc::new(RwLock::new(HashMap::new())),
120 rr_counter: AtomicUsize::new(0),
121 default_strategy: LoadBalanceStrategy::RoundRobin,
122 health_check_interval_secs: 30,
123 unhealthy_threshold: 3,
124 recovery_threshold: 2,
125 }
126 }
127
128 pub fn with_config(
130 strategy: LoadBalanceStrategy,
131 health_check_interval_secs: u64,
132 unhealthy_threshold: u32,
133 recovery_threshold: u32,
134 ) -> Self {
135 Self {
136 servers: Arc::new(RwLock::new(HashMap::new())),
137 rr_counter: AtomicUsize::new(0),
138 default_strategy: strategy,
139 health_check_interval_secs,
140 unhealthy_threshold,
141 recovery_threshold,
142 }
143 }
144
145 pub async fn register<S: McpServer + 'static>(
147 &self,
148 server_id: String,
149 server: S,
150 ) -> Result<()> {
151 let mut servers = self.servers.write().await;
152 let entry = ServerEntry::new(Arc::new(Box::new(server)));
153 servers.insert(server_id.clone(), entry);
154 tracing::info!("Registered MCP server: {}", server_id);
155 Ok(())
156 }
157
158 pub async fn register_with_weight<S: McpServer + 'static>(
160 &self,
161 server_id: String,
162 server: S,
163 weight: u32,
164 ) -> Result<()> {
165 let mut servers = self.servers.write().await;
166 let entry = ServerEntry::new(Arc::new(Box::new(server))).with_weight(weight);
167 servers.insert(server_id.clone(), entry);
168 tracing::info!(
169 "Registered MCP server: {} with weight {}",
170 server_id,
171 weight
172 );
173 Ok(())
174 }
175
176 pub async fn register_with_group<S: McpServer + 'static>(
178 &self,
179 server_id: String,
180 server: S,
181 group: String,
182 ) -> Result<()> {
183 let mut servers = self.servers.write().await;
184 let entry = ServerEntry::new(Arc::new(Box::new(server))).with_group(group.clone());
185 servers.insert(server_id.clone(), entry);
186 tracing::info!("Registered MCP server: {} in group {}", server_id, group);
187 Ok(())
188 }
189
190 pub async fn register_with_tags<S: McpServer + 'static>(
192 &self,
193 server_id: String,
194 server: S,
195 tags: Vec<String>,
196 ) -> Result<()> {
197 let mut servers = self.servers.write().await;
198 let entry = ServerEntry::new(Arc::new(Box::new(server))).with_tags(tags.clone());
199 servers.insert(server_id.clone(), entry);
200 tracing::info!("Registered MCP server: {} with tags {:?}", server_id, tags);
201 Ok(())
202 }
203
204 pub async fn unregister(&self, server_id: &str) -> Result<()> {
206 let mut servers = self.servers.write().await;
207 servers
208 .remove(server_id)
209 .ok_or_else(|| McpError::ServerError(format!("Server '{}' not found", server_id)))?;
210 tracing::info!("Unregistered MCP server: {}", server_id);
211 Ok(())
212 }
213
214 pub async fn get_server(&self, server_id: &str) -> Option<Arc<Box<dyn McpServer>>> {
216 let servers = self.servers.read().await;
217 servers.get(server_id).map(|entry| entry.server.clone())
218 }
219
220 pub async fn get_server_entry(&self, server_id: &str) -> Option<ServerEntry> {
222 let servers = self.servers.read().await;
223 servers.get(server_id).cloned()
224 }
225
226 pub async fn set_server_health(&self, server_id: &str, health: ServerHealth) -> Result<()> {
228 let mut servers = self.servers.write().await;
229 let entry = servers
230 .get_mut(server_id)
231 .ok_or_else(|| McpError::ServerError(format!("Server '{}' not found", server_id)))?;
232 entry.health = health;
233 entry.last_health_check = Some(Instant::now());
234 tracing::info!("Updated server {} health to {:?}", server_id, health);
235 Ok(())
236 }
237
238 pub async fn set_server_weight(&self, server_id: &str, weight: u32) -> Result<()> {
240 let mut servers = self.servers.write().await;
241 let entry = servers
242 .get_mut(server_id)
243 .ok_or_else(|| McpError::ServerError(format!("Server '{}' not found", server_id)))?;
244 entry.weight = weight;
245 tracing::info!("Updated server {} weight to {}", server_id, weight);
246 Ok(())
247 }
248
249 pub async fn select_server(&self, strategy: Option<LoadBalanceStrategy>) -> Option<String> {
251 let strategy = strategy.unwrap_or(self.default_strategy);
252 let servers = self.servers.read().await;
253
254 let healthy_servers: Vec<(&String, &ServerEntry)> = servers
256 .iter()
257 .filter(|(_, e)| e.health == ServerHealth::Healthy)
258 .collect();
259
260 if healthy_servers.is_empty() {
261 return None;
262 }
263
264 match strategy {
265 LoadBalanceStrategy::RoundRobin => {
266 let idx = self.rr_counter.fetch_add(1, Ordering::Relaxed) % healthy_servers.len();
267 healthy_servers.get(idx).map(|(id, _)| (*id).clone())
268 }
269 LoadBalanceStrategy::LeastConnections => healthy_servers
270 .iter()
271 .min_by_key(|(_, e)| e.active_connections.load(Ordering::Relaxed))
272 .map(|(id, _)| (*id).clone()),
273 LoadBalanceStrategy::Random => {
274 use std::collections::hash_map::DefaultHasher;
275 use std::hash::{Hash, Hasher};
276 let mut hasher = DefaultHasher::new();
277 std::time::Instant::now().hash(&mut hasher);
278 let hash = hasher.finish() as usize;
279 let idx = hash % healthy_servers.len();
280 healthy_servers.get(idx).map(|(id, _)| (*id).clone())
281 }
282 LoadBalanceStrategy::WeightedRoundRobin => {
283 let total_weight: u32 = healthy_servers.iter().map(|(_, e)| e.weight).sum();
284 if total_weight == 0 {
285 return healthy_servers.first().map(|(id, _)| (*id).clone());
286 }
287 let idx = self.rr_counter.fetch_add(1, Ordering::Relaxed);
288 let mut position = (idx as u32) % total_weight;
289
290 for (id, entry) in &healthy_servers {
291 if position < entry.weight {
292 return Some((*id).clone());
293 }
294 position -= entry.weight;
295 }
296 healthy_servers.first().map(|(id, _)| (*id).clone())
297 }
298 }
299 }
300
301 pub async fn select_server_from_group(&self, group: &str) -> Option<String> {
303 let servers = self.servers.read().await;
304
305 let group_servers: Vec<(&String, &ServerEntry)> = servers
306 .iter()
307 .filter(|(_, e)| e.health == ServerHealth::Healthy && e.group.as_deref() == Some(group))
308 .collect();
309
310 if group_servers.is_empty() {
311 return None;
312 }
313
314 let idx = self.rr_counter.fetch_add(1, Ordering::Relaxed) % group_servers.len();
315 group_servers.get(idx).map(|(id, _)| (*id).clone())
316 }
317
318 pub async fn select_server_by_tag(&self, tag: &str) -> Option<String> {
320 let servers = self.servers.read().await;
321
322 let tagged_servers: Vec<(&String, &ServerEntry)> = servers
323 .iter()
324 .filter(|(_, e)| e.health == ServerHealth::Healthy && e.tags.contains(&tag.to_string()))
325 .collect();
326
327 if tagged_servers.is_empty() {
328 return None;
329 }
330
331 let idx = self.rr_counter.fetch_add(1, Ordering::Relaxed) % tagged_servers.len();
332 tagged_servers.get(idx).map(|(id, _)| (*id).clone())
333 }
334
335 pub async fn invoke_tool_with_failover(
337 &self,
338 tool_name: &str,
339 arguments: Value,
340 max_retries: u32,
341 ) -> Result<Value> {
342 let server_ids = self.find_tool(tool_name).await?;
344 if server_ids.is_empty() {
345 return Err(McpError::ToolNotFound(format!(
346 "Tool '{}' not found in any server",
347 tool_name
348 )));
349 }
350
351 let mut last_error = None;
352 let mut tried_servers: Vec<String> = Vec::new();
353
354 for _ in 0..max_retries.min(server_ids.len() as u32) {
355 let servers = self.servers.read().await;
357 let available_server = server_ids
358 .iter()
359 .filter(|id| !tried_servers.contains(id))
360 .find(|id| {
361 servers
362 .get(*id)
363 .map(|e| e.health == ServerHealth::Healthy)
364 .unwrap_or(false)
365 })
366 .cloned();
367 drop(servers);
368
369 let server_id = match available_server {
370 Some(id) => id,
371 None => break, };
373
374 tried_servers.push(server_id.clone());
375
376 {
378 let servers = self.servers.read().await;
379 if let Some(entry) = servers.get(&server_id) {
380 entry.active_connections.fetch_add(1, Ordering::Relaxed);
381 entry.request_count.fetch_add(1, Ordering::Relaxed);
382 }
383 }
384
385 let start_time = Instant::now();
386 let result = self
387 .invoke_tool(&server_id, tool_name, arguments.clone())
388 .await;
389
390 {
392 let servers = self.servers.read().await;
393 if let Some(entry) = servers.get(&server_id) {
394 entry.active_connections.fetch_sub(1, Ordering::Relaxed);
395 let elapsed_ms = start_time.elapsed().as_millis() as u64;
396 let old_avg = entry.avg_response_time_ms.load(Ordering::Relaxed);
398 let new_avg = (old_avg + elapsed_ms) / 2;
399 entry.avg_response_time_ms.store(new_avg, Ordering::Relaxed);
400 }
401 }
402
403 match result {
404 Ok(value) => return Ok(value),
405 Err(e) => {
406 tracing::warn!(
407 "Tool invocation failed on server {}: {}. Trying failover...",
408 server_id,
409 e
410 );
411
412 {
414 let servers = self.servers.read().await;
415 if let Some(entry) = servers.get(&server_id) {
416 entry.error_count.fetch_add(1, Ordering::Relaxed);
417 }
418 }
419
420 last_error = Some(e);
421 }
422 }
423 }
424
425 Err(last_error.unwrap_or_else(|| {
426 McpError::ServerError("All servers failed or unavailable".to_string())
427 }))
428 }
429
430 pub async fn get_server_health_status(&self) -> HashMap<String, ServerHealth> {
432 let servers = self.servers.read().await;
433 servers
434 .iter()
435 .map(|(id, entry)| (id.clone(), entry.health))
436 .collect()
437 }
438
439 pub async fn get_server_metrics(&self, server_id: &str) -> Option<ServerMetrics> {
441 let servers = self.servers.read().await;
442 servers.get(server_id).map(|entry| ServerMetrics {
443 server_id: server_id.to_string(),
444 health: entry.health,
445 weight: entry.weight,
446 active_connections: entry.active_connections.load(Ordering::Relaxed),
447 request_count: entry.request_count.load(Ordering::Relaxed),
448 error_count: entry.error_count.load(Ordering::Relaxed),
449 avg_response_time_ms: entry.avg_response_time_ms.load(Ordering::Relaxed),
450 group: entry.group.clone(),
451 tags: entry.tags.clone(),
452 })
453 }
454
455 pub async fn get_all_server_metrics(&self) -> Vec<ServerMetrics> {
457 let servers = self.servers.read().await;
458 servers
459 .iter()
460 .map(|(id, entry)| ServerMetrics {
461 server_id: id.clone(),
462 health: entry.health,
463 weight: entry.weight,
464 active_connections: entry.active_connections.load(Ordering::Relaxed),
465 request_count: entry.request_count.load(Ordering::Relaxed),
466 error_count: entry.error_count.load(Ordering::Relaxed),
467 avg_response_time_ms: entry.avg_response_time_ms.load(Ordering::Relaxed),
468 group: entry.group.clone(),
469 tags: entry.tags.clone(),
470 })
471 .collect()
472 }
473
474 pub async fn list_servers_in_group(&self, group: &str) -> Vec<String> {
476 let servers = self.servers.read().await;
477 servers
478 .iter()
479 .filter(|(_, e)| e.group.as_deref() == Some(group))
480 .map(|(id, _)| id.clone())
481 .collect()
482 }
483
484 pub async fn list_servers_with_tag(&self, tag: &str) -> Vec<String> {
486 let servers = self.servers.read().await;
487 servers
488 .iter()
489 .filter(|(_, e)| e.tags.contains(&tag.to_string()))
490 .map(|(id, _)| id.clone())
491 .collect()
492 }
493
494 pub fn get_config(&self) -> LoadBalanceConfig {
496 LoadBalanceConfig {
497 default_strategy: self.default_strategy,
498 health_check_interval_secs: self.health_check_interval_secs,
499 unhealthy_threshold: self.unhealthy_threshold,
500 recovery_threshold: self.recovery_threshold,
501 }
502 }
503
504 pub async fn list_server_ids(&self) -> Vec<String> {
506 let servers = self.servers.read().await;
507 servers.keys().cloned().collect()
508 }
509
510 pub async fn list_all_tools(&self) -> Result<HashMap<String, Vec<ToolSchema>>> {
512 let servers = self.servers.read().await;
513 let mut all_tools = HashMap::new();
514
515 for (server_id, entry) in servers.iter() {
516 let tools_json = entry.server.list_tools().await?;
517 let tools: Vec<ToolSchema> = tools_json
518 .into_iter()
519 .filter_map(|v| serde_json::from_value(v).ok())
520 .collect();
521 all_tools.insert(server_id.clone(), tools);
522 }
523
524 Ok(all_tools)
525 }
526
527 pub async fn list_tools(&self, server_id: &str) -> Result<Vec<ToolSchema>> {
529 let server = self
530 .get_server(server_id)
531 .await
532 .ok_or_else(|| McpError::ServerError(format!("Server '{}' not found", server_id)))?;
533
534 let tools_json = server.list_tools().await?;
535 let tools: Vec<ToolSchema> = tools_json
536 .into_iter()
537 .filter_map(|v| serde_json::from_value(v).ok())
538 .collect();
539
540 Ok(tools)
541 }
542
543 pub async fn invoke_tool(
545 &self,
546 server_id: &str,
547 tool_name: &str,
548 arguments: Value,
549 ) -> Result<Value> {
550 let server = self
551 .get_server(server_id)
552 .await
553 .ok_or_else(|| McpError::ServerError(format!("Server '{}' not found", server_id)))?;
554
555 server.call_tool(tool_name, arguments).await
556 }
557
558 pub async fn find_tool(&self, tool_name: &str) -> Result<Vec<String>> {
560 let all_tools = self.list_all_tools().await?;
561 let mut server_ids = Vec::new();
562
563 for (server_id, tools) in all_tools {
564 if tools.iter().any(|t| t.name == tool_name) {
565 server_ids.push(server_id);
566 }
567 }
568
569 Ok(server_ids)
570 }
571
572 pub async fn get_stats(&self) -> RegistryStats {
574 let servers = self.servers.read().await;
575 let server_count = servers.len();
576
577 let mut total_tools = 0;
578 for entry in servers.values() {
579 if let Ok(tools) = entry.server.list_tools().await {
580 total_tools += tools.len();
581 }
582 }
583
584 RegistryStats {
585 server_count,
586 total_tools,
587 }
588 }
589
590 pub async fn clear(&self) {
592 let mut servers = self.servers.write().await;
593 servers.clear();
594 tracing::info!("Cleared all registered MCP servers");
595 }
596}
597
impl Default for McpRegistry {
    /// Equivalent to [`McpRegistry::new`].
    fn default() -> Self {
        Self::new()
    }
}
603
/// Summary counts produced by `McpRegistry::get_stats`.
#[derive(Debug, Clone)]
pub struct RegistryStats {
    /// Number of servers registered when the stats were taken.
    pub server_count: usize,
    /// Sum of the tool counts reported by the servers that answered `list_tools`.
    pub total_tools: usize,
}
610
/// Serializable point-in-time copy of one server's entry data and counters.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ServerMetrics {
    /// Id the server is registered under.
    pub server_id: String,
    /// Health recorded for the server.
    pub health: ServerHealth,
    /// Load-balancing weight of the server.
    pub weight: u32,
    /// Value of the in-flight invocation counter.
    pub active_connections: u64,
    /// Value of the invocation-attempt counter.
    pub request_count: u64,
    /// Value of the failed-attempt counter.
    pub error_count: u64,
    /// Smoothed response time in milliseconds.
    pub avg_response_time_ms: u64,
    /// Group the server belongs to, if any.
    pub group: Option<String>,
    /// Tags attached to the server.
    pub tags: Vec<String>,
}
624
625impl ServerMetrics {
626 pub fn error_rate(&self) -> f64 {
628 if self.request_count == 0 {
629 return 0.0;
630 }
631 self.error_count as f64 / self.request_count as f64
632 }
633
634 pub fn should_mark_unhealthy(&self, error_rate_threshold: f64) -> bool {
636 self.error_rate() > error_rate_threshold && self.request_count > 10
637 }
638}
639
/// Load-balancing settings as reported by `McpRegistry::get_config`.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub struct LoadBalanceConfig {
    /// Strategy used when a selection call does not name one.
    pub default_strategy: LoadBalanceStrategy,
    /// Health-check interval in seconds.
    pub health_check_interval_secs: u64,
    /// Configured unhealthy threshold.
    pub unhealthy_threshold: u32,
    /// Configured recovery threshold.
    pub recovery_threshold: u32,
}
648
649impl Default for LoadBalanceConfig {
650 fn default() -> Self {
651 Self {
652 default_strategy: LoadBalanceStrategy::RoundRobin,
653 health_check_interval_secs: 30,
654 unhealthy_threshold: 3,
655 recovery_threshold: 2,
656 }
657 }
658}
659
660#[async_trait]
661impl McpServer for McpRegistry {
662 async fn call_tool(&self, name: &str, arguments: Value) -> Result<Value> {
665 let server_id = arguments
666 .get("server_id")
667 .and_then(|v| v.as_str())
668 .ok_or_else(|| McpError::InvalidRequest("Missing 'server_id' field".to_string()))?
669 .to_string();
670
671 self.invoke_tool(&server_id, name, arguments).await
672 }
673
674 async fn list_tools(&self) -> Result<Vec<Value>> {
676 let all_tools = self.list_all_tools().await?;
677 let mut tools = Vec::new();
678
679 for (server_id, server_tools) in all_tools {
680 for tool in server_tools {
681 tools.push(serde_json::json!({
682 "server_id": server_id,
683 "name": tool.name,
684 "description": tool.description,
685 "inputSchema": tool.input_schema,
686 }));
687 }
688 }
689
690 Ok(tools)
691 }
692}
693
#[cfg(test)]
mod tests {
    use super::*;
    use crate::servers::FilesystemServer;
    use std::path::PathBuf;

    /// Builds a registry holding one filesystem server rooted at `/tmp`,
    /// registered under the id `fs`.
    async fn registry_with_fs() -> McpRegistry {
        let registry = McpRegistry::new();
        let fs_server = FilesystemServer::new(PathBuf::from("/tmp"));
        registry
            .register("fs".to_string(), fs_server)
            .await
            .unwrap();
        registry
    }

    #[tokio::test]
    async fn test_registry_creation() {
        let registry = McpRegistry::new();
        assert_eq!(registry.list_server_ids().await.len(), 0);
    }

    #[tokio::test]
    async fn test_register_server() {
        let registry = registry_with_fs().await;

        let server_ids = registry.list_server_ids().await;
        assert_eq!(server_ids.len(), 1);
        assert!(server_ids.contains(&"fs".to_string()));
    }

    #[tokio::test]
    async fn test_unregister_server() {
        let registry = registry_with_fs().await;

        registry.unregister("fs").await.unwrap();

        assert_eq!(registry.list_server_ids().await.len(), 0);
    }

    #[tokio::test]
    async fn test_list_tools() {
        let registry = registry_with_fs().await;

        let tools = registry.list_tools("fs").await.unwrap();
        assert!(!tools.is_empty());
        assert!(tools.iter().any(|t| t.name == "fs_read"));
    }

    #[tokio::test]
    async fn test_find_tool() {
        let registry = registry_with_fs().await;

        let servers = registry.find_tool("fs_read").await.unwrap();
        assert_eq!(servers.len(), 1);
        assert_eq!(servers[0], "fs");
    }

    #[tokio::test]
    async fn test_get_stats() {
        let registry = registry_with_fs().await;

        let stats = registry.get_stats().await;
        assert_eq!(stats.server_count, 1);
        assert!(stats.total_tools > 0);
    }

    #[tokio::test]
    async fn test_clear() {
        let registry = registry_with_fs().await;

        registry.clear().await;

        assert_eq!(registry.list_server_ids().await.len(), 0);
    }
}
798}