1use crate::core::error::{McpError, McpResult};
8use crate::core::tool::Tool;
9use crate::core::tool_metadata::{
10 CategoryFilter, DeprecationSeverity, EnhancedToolMetadata, ToolBehaviorHints,
11};
12use chrono::Utc;
13use std::collections::HashMap;
14use std::time::Duration;
15
16pub struct ToolRegistry {
18 tools: HashMap<String, Tool>,
20 global_stats: GlobalToolStats,
22}
23
/// Aggregate statistics over every tool held by a registry.
///
/// The default value has every counter at zero, a success rate of `0.0`, and
/// no "most used" / "most reliable" tool.
#[derive(Debug, Clone, Default)]
pub struct GlobalToolStats {
    /// Number of tools currently registered.
    pub total_tools: usize,
    /// Number of registered tools that report themselves as deprecated.
    pub deprecated_tools: usize,
    /// Number of registered tools that report themselves as not enabled.
    pub disabled_tools: usize,
    /// Sum of the execution counts of all registered tools.
    pub total_executions: u64,
    /// Sum of the success counts of all registered tools.
    pub total_successes: u64,
    /// `total_successes / total_executions` expressed as a percentage
    /// (0–100); stays `0.0` while nothing has been executed.
    pub overall_success_rate: f64,
    /// Name of the tool with the highest execution count, if any tool has
    /// been executed at least once.
    pub most_used_tool: Option<String>,
    /// Name of the tool with the highest success rate among tools that have
    /// at least five executions and a success rate above zero.
    pub most_reliable_tool: Option<String>,
}
59
60#[derive(Debug, Clone)]
62pub struct DiscoveryResult {
63 pub name: String,
65 pub match_score: f64,
67 pub recommendation_reason: String,
69 pub metadata: EnhancedToolMetadata,
71 pub is_deprecated: bool,
73 pub is_enabled: bool,
75}
76
77#[derive(Debug, Clone, Default)]
79pub struct DiscoveryCriteria {
80 pub category_filter: Option<CategoryFilter>,
82 pub required_hints: ToolBehaviorHints,
84 pub preferred_hints: ToolBehaviorHints,
86 pub exclude_deprecated: bool,
88 pub exclude_disabled: bool,
90 pub min_success_rate: Option<f64>,
92 pub max_execution_time: Option<Duration>,
94 pub text_search: Option<String>,
96 pub min_executions: Option<u64>,
98}
99
100impl Default for ToolRegistry {
101 fn default() -> Self {
102 Self::new()
103 }
104}
105
106impl ToolRegistry {
107 pub fn new() -> Self {
109 Self {
110 tools: HashMap::new(),
111 global_stats: GlobalToolStats::default(),
112 }
113 }
114
115 pub fn register_tool(&mut self, tool: Tool) -> McpResult<()> {
117 let name = tool.info.name.clone();
118
119 if self.tools.contains_key(&name) {
120 return Err(McpError::validation(format!(
121 "Tool '{name}' is already registered"
122 )));
123 }
124
125 self.tools.insert(name, tool);
126 self.update_global_stats();
127 Ok(())
128 }
129
130 pub fn unregister_tool(&mut self, name: &str) -> McpResult<Tool> {
132 let tool = self
133 .tools
134 .remove(name)
135 .ok_or_else(|| McpError::validation(format!("Tool '{name}' not found")))?;
136
137 self.update_global_stats();
138 Ok(tool)
139 }
140
141 pub fn get_tool(&self, name: &str) -> Option<&Tool> {
143 self.tools.get(name)
144 }
145
146 pub fn get_tool_mut(&mut self, name: &str) -> Option<&mut Tool> {
148 self.tools.get_mut(name)
149 }
150
151 pub fn list_tool_names(&self) -> Vec<String> {
153 self.tools.keys().cloned().collect()
154 }
155
156 pub fn discover_tools(&self, criteria: &DiscoveryCriteria) -> Vec<DiscoveryResult> {
158 let mut results = Vec::new();
159
160 for (name, tool) in &self.tools {
161 if let Some(result) = self.evaluate_tool_match(name, tool, criteria) {
162 results.push(result);
163 }
164 }
165
166 results.sort_by(|a, b| {
168 b.match_score
169 .partial_cmp(&a.match_score)
170 .unwrap_or(std::cmp::Ordering::Equal)
171 });
172
173 results
174 }
175
176 pub fn get_tools_by_category(&self, filter: &CategoryFilter) -> Vec<String> {
178 self.tools
179 .iter()
180 .filter(|(_, tool)| tool.matches_category_filter(filter))
181 .map(|(name, _)| name.clone())
182 .collect()
183 }
184
185 pub fn get_deprecated_tools(&self) -> Vec<String> {
187 self.tools
188 .iter()
189 .filter(|(_, tool)| tool.is_deprecated())
190 .map(|(name, _)| name.clone())
191 .collect()
192 }
193
194 pub fn get_disabled_tools(&self) -> Vec<String> {
196 self.tools
197 .iter()
198 .filter(|(_, tool)| !tool.is_enabled())
199 .map(|(name, _)| name.clone())
200 .collect()
201 }
202
203 pub fn get_performance_report(
205 &self,
206 ) -> HashMap<String, crate::core::tool_metadata::ToolPerformanceMetrics> {
207 self.tools
208 .iter()
209 .map(|(name, tool)| (name.clone(), tool.performance_metrics()))
210 .collect()
211 }
212
213 pub fn get_global_stats(&self) -> &GlobalToolStats {
215 &self.global_stats
216 }
217
218 pub fn recommend_tool(
220 &self,
221 use_case: &str,
222 criteria: &DiscoveryCriteria,
223 ) -> Option<DiscoveryResult> {
224 let mut enhanced_criteria = criteria.clone();
225
226 enhanced_criteria.text_search = Some(use_case.to_string());
228
229 let results = self.discover_tools(&enhanced_criteria);
230 results.into_iter().next()
231 }
232
233 pub fn cleanup_deprecated_tools(&mut self, policy: &DeprecationCleanupPolicy) -> Vec<String> {
235 let mut removed_tools = Vec::new();
236 let current_time = Utc::now();
237
238 let tools_to_remove: Vec<String> = self
239 .tools
240 .iter()
241 .filter(|(_, tool)| {
242 if let Some(ref deprecation) = tool.enhanced_metadata.deprecation {
243 if !deprecation.deprecated {
244 return false;
245 }
246
247 if matches!(deprecation.severity, DeprecationSeverity::Critical) {
249 return true;
250 }
251
252 if let Some(removal_date) = deprecation.removal_date {
254 if current_time >= removal_date {
255 return true;
256 }
257 }
258
259 if let Some(deprecated_date) = deprecation.deprecated_date {
261 let age = current_time.signed_duration_since(deprecated_date);
262 if age.num_days() > policy.max_deprecated_days as i64 {
263 return true;
264 }
265 }
266 }
267 false
268 })
269 .map(|(name, _)| name.clone())
270 .collect();
271
272 for name in tools_to_remove {
273 if self.tools.remove(&name).is_some() {
274 removed_tools.push(name);
275 }
276 }
277
278 if !removed_tools.is_empty() {
279 self.update_global_stats();
280 }
281
282 removed_tools
283 }
284
285 fn update_global_stats(&mut self) {
287 let mut stats = GlobalToolStats {
288 total_tools: self.tools.len(),
289 ..Default::default()
290 };
291
292 let mut max_executions = 0u64;
293 let mut max_success_rate = 0.0f64;
294 let mut most_used = None;
295 let mut most_reliable = None;
296
297 for (name, tool) in &self.tools {
298 let metrics = tool.performance_metrics();
299
300 if tool.is_deprecated() {
301 stats.deprecated_tools += 1;
302 }
303
304 if !tool.is_enabled() {
305 stats.disabled_tools += 1;
306 }
307
308 stats.total_executions += metrics.execution_count;
309 stats.total_successes += metrics.success_count;
310
311 if metrics.execution_count > max_executions {
313 max_executions = metrics.execution_count;
314 most_used = Some(name.clone());
315 }
316
317 if metrics.execution_count >= 5 && metrics.success_rate > max_success_rate {
319 max_success_rate = metrics.success_rate;
320 most_reliable = Some(name.clone());
321 }
322 }
323
324 if stats.total_executions > 0 {
325 stats.overall_success_rate =
326 (stats.total_successes as f64 / stats.total_executions as f64) * 100.0;
327 }
328
329 stats.most_used_tool = most_used;
330 stats.most_reliable_tool = most_reliable;
331 self.global_stats = stats;
332 }
333
334 fn evaluate_tool_match(
336 &self,
337 name: &str,
338 tool: &Tool,
339 criteria: &DiscoveryCriteria,
340 ) -> Option<DiscoveryResult> {
341 let mut score = 0.0f64;
342 let mut reasons = Vec::new();
343
344 if criteria.exclude_deprecated && tool.is_deprecated() {
346 return None;
347 }
348
349 if criteria.exclude_disabled && !tool.is_enabled() {
350 return None;
351 }
352
353 let metrics = tool.performance_metrics();
354
355 if let Some(min_rate) = criteria.min_success_rate {
357 if metrics.execution_count > 0 && metrics.success_rate < min_rate * 100.0 {
358 return None;
359 }
360 }
361
362 if let Some(max_time) = criteria.max_execution_time {
364 if metrics.execution_count > 0 && metrics.average_execution_time > max_time {
365 return None;
366 }
367 }
368
369 if let Some(min_execs) = criteria.min_executions {
371 if metrics.execution_count < min_execs {
372 return None;
373 }
374 }
375
376 if let Some(ref filter) = criteria.category_filter {
378 if tool.matches_category_filter(filter) {
379 score += 0.3;
380 reasons.push("matches category criteria".to_string());
381 } else {
382 return None;
383 }
384 }
385
386 if let Some(ref search_text) = criteria.text_search {
388 let search_lower = search_text.to_lowercase();
389 let name_match = name.to_lowercase().contains(&search_lower);
390 let desc_match = tool
391 .info
392 .description
393 .as_ref()
394 .map(|d| d.to_lowercase().contains(&search_lower))
395 .unwrap_or(false);
396
397 if name_match || desc_match {
398 score += if name_match { 0.4 } else { 0.2 };
399 reasons.push("matches text search".to_string());
400 } else {
401 return None;
403 }
404 }
405
406 let hints = tool.behavior_hints();
408
409 if criteria.required_hints.read_only.unwrap_or(false) && !hints.read_only.unwrap_or(false) {
411 return None;
412 }
413 if criteria.required_hints.idempotent.unwrap_or(false) && !hints.idempotent.unwrap_or(false)
414 {
415 return None;
416 }
417 if criteria.required_hints.cacheable.unwrap_or(false) && !hints.cacheable.unwrap_or(false) {
418 return None;
419 }
420 if criteria.required_hints.destructive.unwrap_or(false)
421 && !hints.destructive.unwrap_or(false)
422 {
423 return None;
424 }
425 if criteria.required_hints.requires_auth.unwrap_or(false)
426 && !hints.requires_auth.unwrap_or(false)
427 {
428 return None;
429 }
430
431 if criteria.required_hints.read_only.unwrap_or(false) && hints.read_only.unwrap_or(false) {
433 score += 0.2;
434 reasons.push("read-only as required".to_string());
435 }
436 if criteria.required_hints.idempotent.unwrap_or(false) && hints.idempotent.unwrap_or(false)
437 {
438 score += 0.2;
439 reasons.push("idempotent as required".to_string());
440 }
441 if criteria.required_hints.cacheable.unwrap_or(false) && hints.cacheable.unwrap_or(false) {
442 score += 0.15;
443 reasons.push("cacheable as required".to_string());
444 }
445
446 if criteria.preferred_hints.read_only.unwrap_or(false) && hints.read_only.unwrap_or(false) {
448 score += 0.1;
449 reasons.push("preferred: read-only".to_string());
450 }
451 if criteria.preferred_hints.idempotent.unwrap_or(false) && hints.idempotent.unwrap_or(false)
452 {
453 score += 0.1;
454 reasons.push("preferred: idempotent".to_string());
455 }
456
457 if metrics.execution_count > 0 {
459 let success_bonus = (metrics.success_rate / 100.0) * 0.2;
461 score += success_bonus;
462
463 let usage_bonus = (metrics.execution_count as f64).ln() * 0.05;
465 score += usage_bonus.min(0.15);
466
467 if metrics.success_rate > 95.0 {
468 reasons.push("high reliability".to_string());
469 }
470 if metrics.execution_count > 100 {
471 reasons.push("well-tested".to_string());
472 }
473 }
474
475 if tool.is_deprecated() {
477 score *= 0.5;
478 reasons.push("deprecated (reduced score)".to_string());
479 }
480
481 if !tool.is_enabled() {
483 score *= 0.1;
484 reasons.push("disabled (reduced score)".to_string());
485 }
486
487 Some(DiscoveryResult {
488 name: name.to_string(),
489 match_score: score.min(1.0),
490 recommendation_reason: reasons.join(", "),
491 metadata: tool.enhanced_metadata.clone(),
492 is_deprecated: tool.is_deprecated(),
493 is_enabled: tool.is_enabled(),
494 })
495 }
496}
497
/// Settings that govern when deprecated tools are purged from a registry.
#[derive(Debug, Clone)]
pub struct DeprecationCleanupPolicy {
    /// Number of whole days a tool may stay deprecated before it becomes
    /// eligible for removal.
    pub max_deprecated_days: u32,
    /// Intended to control whether tools deprecated with critical severity
    /// are removed regardless of how long they have been deprecated.
    pub remove_critical_immediately: bool,
}

impl Default for DeprecationCleanupPolicy {
    /// Default policy: 90 days of grace, critical deprecations removed
    /// immediately.
    fn default() -> Self {
        DeprecationCleanupPolicy {
            remove_critical_immediately: true,
            max_deprecated_days: 90,
        }
    }
}
515
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::tool::{ToolBuilder, ToolHandler};
    use crate::core::tool_metadata::*;
    use async_trait::async_trait;
    use serde_json::Value;
    use std::collections::HashMap;

    /// Handler stub that answers every call with a fixed text block.
    struct MockHandler {
        result: String,
    }

    /// Shorthand for a [`MockHandler`] that replies with `text`.
    fn mock(text: &str) -> MockHandler {
        MockHandler {
            result: text.to_string(),
        }
    }

    #[async_trait]
    impl ToolHandler for MockHandler {
        async fn call(
            &self,
            _args: HashMap<String, Value>,
        ) -> McpResult<crate::protocol::types::ToolResult> {
            let text_block = crate::protocol::types::ContentBlock::Text {
                text: self.result.clone(),
                annotations: None,
                meta: None,
            };

            Ok(crate::protocol::types::ToolResult {
                content: vec![text_block],
                is_error: None,
                structured_content: None,
                meta: None,
            })
        }
    }

    /// Register, look up, reject a duplicate, then unregister.
    #[test]
    fn test_tool_registry_basic_operations() {
        let mut registry = ToolRegistry::new();

        let tool = ToolBuilder::new("test_tool")
            .description("A test tool")
            .build(mock("test"))
            .unwrap();
        registry.register_tool(tool).unwrap();

        assert_eq!(registry.list_tool_names().len(), 1);
        assert!(registry.get_tool("test_tool").is_some());

        // A second tool with the same name must be rejected.
        let duplicate_tool = ToolBuilder::new("test_tool")
            .build(mock("duplicate"))
            .unwrap();
        assert!(registry.register_tool(duplicate_tool).is_err());

        let removed = registry.unregister_tool("test_tool").unwrap();
        assert_eq!(removed.info.name, "test_tool");
        assert_eq!(registry.list_tool_names().len(), 0);
    }

    /// Category filters select exactly the tools of the requested category.
    #[test]
    fn test_tool_discovery_by_category() {
        let mut registry = ToolRegistry::new();

        let file_tool = ToolBuilder::new("file_reader")
            .category_simple("file".to_string(), Some("read".to_string()))
            .tag("filesystem".to_string())
            .build(mock("file"))
            .unwrap();
        let network_tool = ToolBuilder::new("http_client")
            .category_simple("network".to_string(), Some("http".to_string()))
            .tag("client".to_string())
            .build(mock("network"))
            .unwrap();

        registry.register_tool(file_tool).unwrap();
        registry.register_tool(network_tool).unwrap();

        let file_filter = CategoryFilter::new().with_primary("file".to_string());
        let file_tools = registry.get_tools_by_category(&file_filter);
        assert_eq!(file_tools.len(), 1);
        assert!(file_tools.contains(&"file_reader".to_string()));

        let network_filter = CategoryFilter::new().with_primary("network".to_string());
        let network_tools = registry.get_tools_by_category(&network_filter);
        assert_eq!(network_tools.len(), 1);
        assert!(network_tools.contains(&"http_client".to_string()));
    }

    /// Required hints, deprecation exclusion and text search each narrow the
    /// result set as expected.
    #[test]
    fn test_tool_discovery_criteria() {
        let mut registry = ToolRegistry::new();

        let read_only_tool = ToolBuilder::new("reader")
            .description("Reads data")
            .read_only()
            .idempotent()
            .cacheable()
            .build(mock("read"))
            .unwrap();
        let destructive_tool = ToolBuilder::new("deleter")
            .description("Deletes data")
            .destructive()
            .build(mock("delete"))
            .unwrap();
        let deprecated_tool = ToolBuilder::new("old_tool")
            .description("Old tool")
            .deprecated_simple("Use new_tool instead")
            .build(mock("old"))
            .unwrap();

        registry.register_tool(read_only_tool).unwrap();
        registry.register_tool(destructive_tool).unwrap();
        registry.register_tool(deprecated_tool).unwrap();

        // Only the read-only tool satisfies a required read-only hint.
        let read_only_criteria = DiscoveryCriteria {
            required_hints: ToolBehaviorHints::new().read_only(),
            exclude_deprecated: false,
            exclude_disabled: false,
            ..Default::default()
        };
        let results = registry.discover_tools(&read_only_criteria);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].name, "reader");

        // Excluding deprecated tools drops `old_tool`.
        let no_deprecated_criteria = DiscoveryCriteria {
            exclude_deprecated: true,
            ..Default::default()
        };
        let results = registry.discover_tools(&no_deprecated_criteria);
        assert_eq!(results.len(), 2);
        assert!(!results.iter().any(|r| r.name == "old_tool"));

        // Text search matches on name or description.
        let search_criteria = DiscoveryCriteria {
            text_search: Some("delete".to_string()),
            exclude_deprecated: false,
            ..Default::default()
        };
        let results = registry.discover_tools(&search_criteria);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].name, "deleter");
    }

    /// Global counters reflect the registered tools.
    #[test]
    fn test_global_statistics() {
        let mut registry = ToolRegistry::new();

        let tool1 = ToolBuilder::new("tool1").build(mock("1")).unwrap();
        let tool2 = ToolBuilder::new("tool2")
            .deprecated_simple("Old tool")
            .build(mock("2"))
            .unwrap();

        registry.register_tool(tool1).unwrap();
        registry.register_tool(tool2).unwrap();

        let stats = registry.get_global_stats();
        assert_eq!(stats.total_tools, 2);
        assert_eq!(stats.deprecated_tools, 1);
        assert_eq!(stats.disabled_tools, 0);
    }

    /// A use-case string picks the tool whose name or description matches it.
    #[test]
    fn test_tool_recommendation() {
        let mut registry = ToolRegistry::new();

        let file_tool = ToolBuilder::new("file_processor")
            .description("Processes files efficiently")
            .category_simple("file".to_string(), Some("process".to_string()))
            .read_only()
            .build(mock("processed"))
            .unwrap();
        let network_tool = ToolBuilder::new("network_handler")
            .description("Handles network requests")
            .category_simple("network".to_string(), None)
            .build(mock("handled"))
            .unwrap();

        registry.register_tool(file_tool).unwrap();
        registry.register_tool(network_tool).unwrap();

        let criteria = DiscoveryCriteria::default();
        let recommendation = registry.recommend_tool("file", &criteria);

        assert!(recommendation.is_some());
        let result = recommendation.unwrap();
        assert_eq!(result.name, "file_processor");
        assert!(result.match_score > 0.0);
        assert!(result.recommendation_reason.contains("matches text search"));
    }

    /// With the default policy only the critically deprecated tool is purged.
    #[test]
    fn test_deprecation_cleanup() {
        let mut registry = ToolRegistry::new();

        let normal_tool = ToolBuilder::new("normal").build(mock("normal")).unwrap();
        let deprecated_tool = ToolBuilder::new("deprecated")
            .deprecated(
                ToolDeprecation::new("Old version".to_string())
                    .with_severity(DeprecationSeverity::Low),
            )
            .build(mock("deprecated"))
            .unwrap();
        let critical_tool = ToolBuilder::new("critical")
            .deprecated(
                ToolDeprecation::new("Security issue".to_string())
                    .with_severity(DeprecationSeverity::Critical),
            )
            .build(mock("critical"))
            .unwrap();

        registry.register_tool(normal_tool).unwrap();
        registry.register_tool(deprecated_tool).unwrap();
        registry.register_tool(critical_tool).unwrap();
        assert_eq!(registry.list_tool_names().len(), 3);

        let policy = DeprecationCleanupPolicy::default();
        let removed = registry.cleanup_deprecated_tools(&policy);

        assert_eq!(removed.len(), 1);
        assert!(removed.contains(&"critical".to_string()));
        assert_eq!(registry.list_tool_names().len(), 2);
    }
}