mcp_protocol_sdk/core/
tool_discovery.rs

//! Tool Discovery and Management System
//!
//! This module provides advanced tool discovery, filtering, and management capabilities
//! based on the enhanced metadata system. It allows for intelligent tool selection,
//! categorization, performance monitoring, and lifecycle management.
6
7use crate::core::error::{McpError, McpResult};
8use crate::core::tool::Tool;
9use crate::core::tool_metadata::{
10    CategoryFilter, DeprecationSeverity, EnhancedToolMetadata, ToolBehaviorHints,
11};
12use chrono::Utc;
13use std::collections::HashMap;
14use std::time::Duration;
15
16/// Tool discovery and management system
17pub struct ToolRegistry {
18    /// Registered tools indexed by name
19    tools: HashMap<String, Tool>,
20    /// Tool execution statistics
21    global_stats: GlobalToolStats,
22}
23
/// Aggregate statistics across every tool in a registry.
///
/// The derived [`Default`] yields an "empty registry" snapshot: all counters
/// are zero, the success rate is `0.0`, and no tool is singled out.
#[derive(Debug, Clone, Default)]
pub struct GlobalToolStats {
    /// Total number of registered tools
    pub total_tools: usize,
    /// Number of registered tools that are deprecated
    pub deprecated_tools: usize,
    /// Number of registered tools that are disabled
    pub disabled_tools: usize,
    /// Total executions summed over all tools
    pub total_executions: u64,
    /// Total successful executions summed over all tools
    pub total_successes: u64,
    /// Overall success rate as a percentage (0.0 to 100.0); stays `0.0` while
    /// no executions have been recorded
    pub overall_success_rate: f64,
    /// Name of the tool with the most executions, or `None` if nothing has run yet
    pub most_used_tool: Option<String>,
    /// Name of the tool with the highest success rate among tools that have
    /// run at least five times, if any
    pub most_reliable_tool: Option<String>,
}
59
60/// Tool discovery result with ranking information
61#[derive(Debug, Clone)]
62pub struct DiscoveryResult {
63    /// Tool name
64    pub name: String,
65    /// Match score (0.0 to 1.0, higher is better)
66    pub match_score: f64,
67    /// Reason for recommendation
68    pub recommendation_reason: String,
69    /// Tool metadata snapshot
70    pub metadata: EnhancedToolMetadata,
71    /// Whether tool is deprecated
72    pub is_deprecated: bool,
73    /// Whether tool is enabled
74    pub is_enabled: bool,
75}
76
77/// Tool discovery criteria
78#[derive(Debug, Clone, Default)]
79pub struct DiscoveryCriteria {
80    /// Category filter
81    pub category_filter: Option<CategoryFilter>,
82    /// Required behavior hints
83    pub required_hints: ToolBehaviorHints,
84    /// Preferred behavior hints (for ranking)
85    pub preferred_hints: ToolBehaviorHints,
86    /// Exclude deprecated tools
87    pub exclude_deprecated: bool,
88    /// Exclude disabled tools
89    pub exclude_disabled: bool,
90    /// Minimum success rate (0.0 to 1.0)
91    pub min_success_rate: Option<f64>,
92    /// Maximum average execution time
93    pub max_execution_time: Option<Duration>,
94    /// Text search in name/description
95    pub text_search: Option<String>,
96    /// Minimum number of executions (for reliability filtering)
97    pub min_executions: Option<u64>,
98}
99
100impl Default for ToolRegistry {
101    fn default() -> Self {
102        Self::new()
103    }
104}
105
106impl ToolRegistry {
107    /// Create a new tool registry
108    pub fn new() -> Self {
109        Self {
110            tools: HashMap::new(),
111            global_stats: GlobalToolStats::default(),
112        }
113    }
114
115    /// Register a tool in the registry
116    pub fn register_tool(&mut self, tool: Tool) -> McpResult<()> {
117        let name = tool.info.name.clone();
118
119        if self.tools.contains_key(&name) {
120            return Err(McpError::validation(format!(
121                "Tool '{name}' is already registered"
122            )));
123        }
124
125        self.tools.insert(name, tool);
126        self.update_global_stats();
127        Ok(())
128    }
129
130    /// Unregister a tool from the registry
131    pub fn unregister_tool(&mut self, name: &str) -> McpResult<Tool> {
132        let tool = self
133            .tools
134            .remove(name)
135            .ok_or_else(|| McpError::validation(format!("Tool '{name}' not found")))?;
136
137        self.update_global_stats();
138        Ok(tool)
139    }
140
141    /// Get a tool by name
142    pub fn get_tool(&self, name: &str) -> Option<&Tool> {
143        self.tools.get(name)
144    }
145
146    /// Get a mutable reference to a tool by name
147    pub fn get_tool_mut(&mut self, name: &str) -> Option<&mut Tool> {
148        self.tools.get_mut(name)
149    }
150
151    /// List all tool names
152    pub fn list_tool_names(&self) -> Vec<String> {
153        self.tools.keys().cloned().collect()
154    }
155
156    /// Discover tools based on criteria
157    pub fn discover_tools(&self, criteria: &DiscoveryCriteria) -> Vec<DiscoveryResult> {
158        let mut results = Vec::new();
159
160        for (name, tool) in &self.tools {
161            if let Some(result) = self.evaluate_tool_match(name, tool, criteria) {
162                results.push(result);
163            }
164        }
165
166        // Sort by match score (descending)
167        results.sort_by(|a, b| {
168            b.match_score
169                .partial_cmp(&a.match_score)
170                .unwrap_or(std::cmp::Ordering::Equal)
171        });
172
173        results
174    }
175
176    /// Get tools by category
177    pub fn get_tools_by_category(&self, filter: &CategoryFilter) -> Vec<String> {
178        self.tools
179            .iter()
180            .filter(|(_, tool)| tool.matches_category_filter(filter))
181            .map(|(name, _)| name.clone())
182            .collect()
183    }
184
185    /// Get deprecated tools
186    pub fn get_deprecated_tools(&self) -> Vec<String> {
187        self.tools
188            .iter()
189            .filter(|(_, tool)| tool.is_deprecated())
190            .map(|(name, _)| name.clone())
191            .collect()
192    }
193
194    /// Get disabled tools
195    pub fn get_disabled_tools(&self) -> Vec<String> {
196        self.tools
197            .iter()
198            .filter(|(_, tool)| !tool.is_enabled())
199            .map(|(name, _)| name.clone())
200            .collect()
201    }
202
203    /// Get performance report for all tools
204    pub fn get_performance_report(
205        &self,
206    ) -> HashMap<String, crate::core::tool_metadata::ToolPerformanceMetrics> {
207        self.tools
208            .iter()
209            .map(|(name, tool)| (name.clone(), tool.performance_metrics()))
210            .collect()
211    }
212
213    /// Get global statistics
214    pub fn get_global_stats(&self) -> &GlobalToolStats {
215        &self.global_stats
216    }
217
218    /// Recommend best tool for a specific use case
219    pub fn recommend_tool(
220        &self,
221        use_case: &str,
222        criteria: &DiscoveryCriteria,
223    ) -> Option<DiscoveryResult> {
224        let mut enhanced_criteria = criteria.clone();
225
226        // Add text search based on use case
227        enhanced_criteria.text_search = Some(use_case.to_string());
228
229        let results = self.discover_tools(&enhanced_criteria);
230        results.into_iter().next()
231    }
232
233    /// Clean up deprecated tools based on policy
234    pub fn cleanup_deprecated_tools(&mut self, policy: &DeprecationCleanupPolicy) -> Vec<String> {
235        let mut removed_tools = Vec::new();
236        let current_time = Utc::now();
237
238        let tools_to_remove: Vec<String> = self
239            .tools
240            .iter()
241            .filter(|(_, tool)| {
242                if let Some(ref deprecation) = tool.enhanced_metadata.deprecation {
243                    if !deprecation.deprecated {
244                        return false;
245                    }
246
247                    // Check severity-based removal
248                    if matches!(deprecation.severity, DeprecationSeverity::Critical) {
249                        return true;
250                    }
251
252                    // Check time-based removal
253                    if let Some(removal_date) = deprecation.removal_date {
254                        if current_time >= removal_date {
255                            return true;
256                        }
257                    }
258
259                    // Check age-based removal
260                    if let Some(deprecated_date) = deprecation.deprecated_date {
261                        let age = current_time.signed_duration_since(deprecated_date);
262                        if age.num_days() > policy.max_deprecated_days as i64 {
263                            return true;
264                        }
265                    }
266                }
267                false
268            })
269            .map(|(name, _)| name.clone())
270            .collect();
271
272        for name in tools_to_remove {
273            if self.tools.remove(&name).is_some() {
274                removed_tools.push(name);
275            }
276        }
277
278        if !removed_tools.is_empty() {
279            self.update_global_stats();
280        }
281
282        removed_tools
283    }
284
285    /// Update global statistics
286    fn update_global_stats(&mut self) {
287        let mut stats = GlobalToolStats {
288            total_tools: self.tools.len(),
289            ..Default::default()
290        };
291
292        let mut max_executions = 0u64;
293        let mut max_success_rate = 0.0f64;
294        let mut most_used = None;
295        let mut most_reliable = None;
296
297        for (name, tool) in &self.tools {
298            let metrics = tool.performance_metrics();
299
300            if tool.is_deprecated() {
301                stats.deprecated_tools += 1;
302            }
303
304            if !tool.is_enabled() {
305                stats.disabled_tools += 1;
306            }
307
308            stats.total_executions += metrics.execution_count;
309            stats.total_successes += metrics.success_count;
310
311            // Track most used tool
312            if metrics.execution_count > max_executions {
313                max_executions = metrics.execution_count;
314                most_used = Some(name.clone());
315            }
316
317            // Track most reliable tool (with minimum executions)
318            if metrics.execution_count >= 5 && metrics.success_rate > max_success_rate {
319                max_success_rate = metrics.success_rate;
320                most_reliable = Some(name.clone());
321            }
322        }
323
324        if stats.total_executions > 0 {
325            stats.overall_success_rate =
326                (stats.total_successes as f64 / stats.total_executions as f64) * 100.0;
327        }
328
329        stats.most_used_tool = most_used;
330        stats.most_reliable_tool = most_reliable;
331        self.global_stats = stats;
332    }
333
334    /// Evaluate how well a tool matches the discovery criteria
335    fn evaluate_tool_match(
336        &self,
337        name: &str,
338        tool: &Tool,
339        criteria: &DiscoveryCriteria,
340    ) -> Option<DiscoveryResult> {
341        let mut score = 0.0f64;
342        let mut reasons = Vec::new();
343
344        // Filter out tools that don't meet basic criteria
345        if criteria.exclude_deprecated && tool.is_deprecated() {
346            return None;
347        }
348
349        if criteria.exclude_disabled && !tool.is_enabled() {
350            return None;
351        }
352
353        let metrics = tool.performance_metrics();
354
355        // Filter by minimum success rate
356        if let Some(min_rate) = criteria.min_success_rate {
357            if metrics.execution_count > 0 && metrics.success_rate < min_rate * 100.0 {
358                return None;
359            }
360        }
361
362        // Filter by maximum execution time
363        if let Some(max_time) = criteria.max_execution_time {
364            if metrics.execution_count > 0 && metrics.average_execution_time > max_time {
365                return None;
366            }
367        }
368
369        // Filter by minimum executions
370        if let Some(min_execs) = criteria.min_executions {
371            if metrics.execution_count < min_execs {
372                return None;
373            }
374        }
375
376        // Category matching
377        if let Some(ref filter) = criteria.category_filter {
378            if tool.matches_category_filter(filter) {
379                score += 0.3;
380                reasons.push("matches category criteria".to_string());
381            } else {
382                return None;
383            }
384        }
385
386        // Text search matching
387        if let Some(ref search_text) = criteria.text_search {
388            let search_lower = search_text.to_lowercase();
389            let name_match = name.to_lowercase().contains(&search_lower);
390            let desc_match = tool
391                .info
392                .description
393                .as_ref()
394                .map(|d| d.to_lowercase().contains(&search_lower))
395                .unwrap_or(false);
396
397            if name_match || desc_match {
398                score += if name_match { 0.4 } else { 0.2 };
399                reasons.push("matches text search".to_string());
400            } else {
401                // If text search is specified but doesn't match, exclude this tool
402                return None;
403            }
404        }
405
406        // Behavior hints matching - check required hints first
407        let hints = tool.behavior_hints();
408
409        // Filter out tools that don't meet required hints
410        if criteria.required_hints.read_only.unwrap_or(false) && !hints.read_only.unwrap_or(false) {
411            return None;
412        }
413        if criteria.required_hints.idempotent.unwrap_or(false) && !hints.idempotent.unwrap_or(false)
414        {
415            return None;
416        }
417        if criteria.required_hints.cacheable.unwrap_or(false) && !hints.cacheable.unwrap_or(false) {
418            return None;
419        }
420        if criteria.required_hints.destructive.unwrap_or(false)
421            && !hints.destructive.unwrap_or(false)
422        {
423            return None;
424        }
425        if criteria.required_hints.requires_auth.unwrap_or(false)
426            && !hints.requires_auth.unwrap_or(false)
427        {
428            return None;
429        }
430
431        // Add score bonuses for meeting required hints
432        if criteria.required_hints.read_only.unwrap_or(false) && hints.read_only.unwrap_or(false) {
433            score += 0.2;
434            reasons.push("read-only as required".to_string());
435        }
436        if criteria.required_hints.idempotent.unwrap_or(false) && hints.idempotent.unwrap_or(false)
437        {
438            score += 0.2;
439            reasons.push("idempotent as required".to_string());
440        }
441        if criteria.required_hints.cacheable.unwrap_or(false) && hints.cacheable.unwrap_or(false) {
442            score += 0.15;
443            reasons.push("cacheable as required".to_string());
444        }
445
446        // Preferred hints bonus
447        if criteria.preferred_hints.read_only.unwrap_or(false) && hints.read_only.unwrap_or(false) {
448            score += 0.1;
449            reasons.push("preferred: read-only".to_string());
450        }
451        if criteria.preferred_hints.idempotent.unwrap_or(false) && hints.idempotent.unwrap_or(false)
452        {
453            score += 0.1;
454            reasons.push("preferred: idempotent".to_string());
455        }
456
457        // Performance-based scoring
458        if metrics.execution_count > 0 {
459            // Success rate bonus
460            let success_bonus = (metrics.success_rate / 100.0) * 0.2;
461            score += success_bonus;
462
463            // Usage frequency bonus (logarithmic scale)
464            let usage_bonus = (metrics.execution_count as f64).ln() * 0.05;
465            score += usage_bonus.min(0.15);
466
467            if metrics.success_rate > 95.0 {
468                reasons.push("high reliability".to_string());
469            }
470            if metrics.execution_count > 100 {
471                reasons.push("well-tested".to_string());
472            }
473        }
474
475        // Deprecation penalty
476        if tool.is_deprecated() {
477            score *= 0.5;
478            reasons.push("deprecated (reduced score)".to_string());
479        }
480
481        // Disabled penalty
482        if !tool.is_enabled() {
483            score *= 0.1;
484            reasons.push("disabled (reduced score)".to_string());
485        }
486
487        Some(DiscoveryResult {
488            name: name.to_string(),
489            match_score: score.min(1.0),
490            recommendation_reason: reasons.join(", "),
491            metadata: tool.enhanced_metadata.clone(),
492            is_deprecated: tool.is_deprecated(),
493            is_enabled: tool.is_enabled(),
494        })
495    }
496}
497
/// Rules that decide which deprecated tools a cleanup pass may remove.
#[derive(Debug, Clone)]
pub struct DeprecationCleanupPolicy {
    /// How many days a tool may stay deprecated before it becomes eligible for removal
    pub max_deprecated_days: u32,
    /// Whether tools deprecated with critical severity should be removed immediately
    pub remove_critical_immediately: bool,
}

impl Default for DeprecationCleanupPolicy {
    /// Default policy: a 90-day grace period, with critical tools removed immediately.
    fn default() -> Self {
        let max_deprecated_days = 90;
        let remove_critical_immediately = true;
        Self {
            max_deprecated_days,
            remove_critical_immediately,
        }
    }
}
515
516#[cfg(test)]
517mod tests {
518    use super::*;
519    use crate::core::tool::{ToolBuilder, ToolHandler};
520    use crate::core::tool_metadata::*;
521    use async_trait::async_trait;
522    use serde_json::Value;
523    use std::collections::HashMap;
524
525    struct MockHandler {
526        result: String,
527    }
528
529    #[async_trait]
530    impl ToolHandler for MockHandler {
531        async fn call(
532            &self,
533            _args: HashMap<String, Value>,
534        ) -> McpResult<crate::protocol::types::ToolResult> {
535            Ok(crate::protocol::types::ToolResult {
536                content: vec![crate::protocol::types::ContentBlock::Text {
537                    text: self.result.clone(),
538                    annotations: None,
539                    meta: None,
540                }],
541                is_error: None,
542                structured_content: None,
543                meta: None,
544            })
545        }
546    }
547
548    #[test]
549    fn test_tool_registry_basic_operations() {
550        let mut registry = ToolRegistry::new();
551
552        let tool = ToolBuilder::new("test_tool")
553            .description("A test tool")
554            .build(MockHandler {
555                result: "test".to_string(),
556            })
557            .unwrap();
558
559        // Register tool
560        registry.register_tool(tool).unwrap();
561        assert_eq!(registry.list_tool_names().len(), 1);
562        assert!(registry.get_tool("test_tool").is_some());
563
564        // Try to register duplicate - should fail
565        let duplicate_tool = ToolBuilder::new("test_tool")
566            .build(MockHandler {
567                result: "duplicate".to_string(),
568            })
569            .unwrap();
570        assert!(registry.register_tool(duplicate_tool).is_err());
571
572        // Unregister tool
573        let removed = registry.unregister_tool("test_tool").unwrap();
574        assert_eq!(removed.info.name, "test_tool");
575        assert_eq!(registry.list_tool_names().len(), 0);
576    }
577
578    #[test]
579    fn test_tool_discovery_by_category() {
580        let mut registry = ToolRegistry::new();
581
582        // Add tools with different categories
583        let file_tool = ToolBuilder::new("file_reader")
584            .category_simple("file".to_string(), Some("read".to_string()))
585            .tag("filesystem".to_string())
586            .build(MockHandler {
587                result: "file".to_string(),
588            })
589            .unwrap();
590
591        let network_tool = ToolBuilder::new("http_client")
592            .category_simple("network".to_string(), Some("http".to_string()))
593            .tag("client".to_string())
594            .build(MockHandler {
595                result: "network".to_string(),
596            })
597            .unwrap();
598
599        registry.register_tool(file_tool).unwrap();
600        registry.register_tool(network_tool).unwrap();
601
602        // Test category filtering
603        let file_filter = CategoryFilter::new().with_primary("file".to_string());
604        let file_tools = registry.get_tools_by_category(&file_filter);
605        assert_eq!(file_tools.len(), 1);
606        assert!(file_tools.contains(&"file_reader".to_string()));
607
608        let network_filter = CategoryFilter::new().with_primary("network".to_string());
609        let network_tools = registry.get_tools_by_category(&network_filter);
610        assert_eq!(network_tools.len(), 1);
611        assert!(network_tools.contains(&"http_client".to_string()));
612    }
613
614    #[test]
615    fn test_tool_discovery_criteria() {
616        let mut registry = ToolRegistry::new();
617
618        // Add tools with different characteristics
619        let read_only_tool = ToolBuilder::new("reader")
620            .description("Reads data")
621            .read_only()
622            .idempotent()
623            .cacheable()
624            .build(MockHandler {
625                result: "read".to_string(),
626            })
627            .unwrap();
628
629        let destructive_tool = ToolBuilder::new("deleter")
630            .description("Deletes data")
631            .destructive()
632            .build(MockHandler {
633                result: "delete".to_string(),
634            })
635            .unwrap();
636
637        let deprecated_tool = ToolBuilder::new("old_tool")
638            .description("Old tool")
639            .deprecated_simple("Use new_tool instead")
640            .build(MockHandler {
641                result: "old".to_string(),
642            })
643            .unwrap();
644
645        registry.register_tool(read_only_tool).unwrap();
646        registry.register_tool(destructive_tool).unwrap();
647        registry.register_tool(deprecated_tool).unwrap();
648
649        // Test discovery with read-only requirement
650        let criteria = DiscoveryCriteria {
651            required_hints: ToolBehaviorHints::new().read_only(),
652            exclude_deprecated: false,
653            exclude_disabled: false,
654            ..Default::default()
655        };
656
657        let results = registry.discover_tools(&criteria);
658        assert_eq!(results.len(), 1);
659        assert_eq!(results[0].name, "reader");
660
661        // Test discovery excluding deprecated
662        let criteria = DiscoveryCriteria {
663            exclude_deprecated: true,
664            ..Default::default()
665        };
666
667        let results = registry.discover_tools(&criteria);
668        assert_eq!(results.len(), 2); // Should exclude deprecated tool
669        assert!(!results.iter().any(|r| r.name == "old_tool"));
670
671        // Test text search
672        let criteria = DiscoveryCriteria {
673            text_search: Some("delete".to_string()),
674            exclude_deprecated: false,
675            ..Default::default()
676        };
677
678        let results = registry.discover_tools(&criteria);
679        assert_eq!(results.len(), 1);
680        assert_eq!(results[0].name, "deleter");
681    }
682
683    #[test]
684    fn test_global_statistics() {
685        let mut registry = ToolRegistry::new();
686
687        let tool1 = ToolBuilder::new("tool1")
688            .build(MockHandler {
689                result: "1".to_string(),
690            })
691            .unwrap();
692
693        let tool2 = ToolBuilder::new("tool2")
694            .deprecated_simple("Old tool")
695            .build(MockHandler {
696                result: "2".to_string(),
697            })
698            .unwrap();
699
700        registry.register_tool(tool1).unwrap();
701        registry.register_tool(tool2).unwrap();
702
703        let stats = registry.get_global_stats();
704        assert_eq!(stats.total_tools, 2);
705        assert_eq!(stats.deprecated_tools, 1);
706        assert_eq!(stats.disabled_tools, 0);
707    }
708
709    #[test]
710    fn test_tool_recommendation() {
711        let mut registry = ToolRegistry::new();
712
713        let file_tool = ToolBuilder::new("file_processor")
714            .description("Processes files efficiently")
715            .category_simple("file".to_string(), Some("process".to_string()))
716            .read_only()
717            .build(MockHandler {
718                result: "processed".to_string(),
719            })
720            .unwrap();
721
722        let network_tool = ToolBuilder::new("network_handler")
723            .description("Handles network requests")
724            .category_simple("network".to_string(), None)
725            .build(MockHandler {
726                result: "handled".to_string(),
727            })
728            .unwrap();
729
730        registry.register_tool(file_tool).unwrap();
731        registry.register_tool(network_tool).unwrap();
732
733        // Recommend tool for file processing
734        let criteria = DiscoveryCriteria::default();
735        let recommendation = registry.recommend_tool("file", &criteria);
736
737        assert!(recommendation.is_some());
738        let result = recommendation.unwrap();
739        assert_eq!(result.name, "file_processor");
740        assert!(result.match_score > 0.0);
741        assert!(result.recommendation_reason.contains("matches text search"));
742    }
743
744    #[test]
745    fn test_deprecation_cleanup() {
746        let mut registry = ToolRegistry::new();
747
748        // Add tools with different deprecation states
749        let normal_tool = ToolBuilder::new("normal")
750            .build(MockHandler {
751                result: "normal".to_string(),
752            })
753            .unwrap();
754
755        let deprecated_tool = ToolBuilder::new("deprecated")
756            .deprecated(
757                ToolDeprecation::new("Old version".to_string())
758                    .with_severity(DeprecationSeverity::Low),
759            )
760            .build(MockHandler {
761                result: "deprecated".to_string(),
762            })
763            .unwrap();
764
765        let critical_tool = ToolBuilder::new("critical")
766            .deprecated(
767                ToolDeprecation::new("Security issue".to_string())
768                    .with_severity(DeprecationSeverity::Critical),
769            )
770            .build(MockHandler {
771                result: "critical".to_string(),
772            })
773            .unwrap();
774
775        registry.register_tool(normal_tool).unwrap();
776        registry.register_tool(deprecated_tool).unwrap();
777        registry.register_tool(critical_tool).unwrap();
778
779        assert_eq!(registry.list_tool_names().len(), 3);
780
781        // Clean up with default policy (should remove critical tools)
782        let policy = DeprecationCleanupPolicy::default();
783        let removed = registry.cleanup_deprecated_tools(&policy);
784
785        assert_eq!(removed.len(), 1);
786        assert!(removed.contains(&"critical".to_string()));
787        assert_eq!(registry.list_tool_names().len(), 2);
788    }
789}