Skip to main content

aster/tools/
registry.rs

1//! Tool Registry Module
2//!
3//! This module implements the `ToolRegistry` that manages all available tools
4//! in the system. It supports:
5//! - Native tool registration (high priority)
6//! - MCP tool registration (low priority)
7//! - Tool lookup and execution
8//! - Permission checking integration
9//! - Audit logging integration
10//!
11//! Requirements: 2.1, 2.2, 2.3, 2.4, 2.5, 2.6, 8.1, 8.2, 11.3, 11.4
12
13use std::collections::HashMap;
14use std::future::Future;
15use std::pin::Pin;
16use std::sync::Arc;
17use std::time::Instant;
18
19use async_trait::async_trait;
20
21use super::base::{PermissionBehavior, Tool};
22use super::context::{ToolContext, ToolDefinition, ToolResult};
23use super::error::ToolError;
24use crate::permission::{
25    AuditLogEntry, AuditLogLevel, AuditLogger, PermissionContext, ToolPermissionManager,
26};
27
/// Callback invoked when a tool's own permission check answers `Ask`.
///
/// `ToolRegistry::execute` calls it with the tool name and a prompt message,
/// then awaits the returned boxed future: resolving to `true` lets execution
/// proceed, resolving to `false` denies it. Both the callback and the future it
/// returns must be `Send`; the callback must also be `Sync`.
pub type PermissionRequestCallback = Box<
    dyn Fn(String, String) -> Pin<Box<dyn Future<Output = bool> + Send>> + Send + Sync,
>;
34
35/// MCP Tool Wrapper
36///
37/// Wraps an MCP tool to implement the `Tool` trait, allowing MCP tools
38/// to be registered alongside native tools in the registry.
39///
40/// Requirements: 11.1, 11.2
41#[derive(Clone)]
42pub struct McpToolWrapper {
43    /// Tool name
44    name: String,
45    /// Tool description
46    description: String,
47    /// Input schema
48    input_schema: serde_json::Value,
49    /// MCP server name
50    server_name: String,
51}
52
53impl McpToolWrapper {
54    /// Create a new MCP tool wrapper
55    pub fn new(
56        name: impl Into<String>,
57        description: impl Into<String>,
58        input_schema: serde_json::Value,
59        server_name: impl Into<String>,
60    ) -> Self {
61        Self {
62            name: name.into(),
63            description: description.into(),
64            input_schema,
65            server_name: server_name.into(),
66        }
67    }
68
69    /// Get the MCP server name
70    pub fn server_name(&self) -> &str {
71        &self.server_name
72    }
73}
74
75#[async_trait]
76impl Tool for McpToolWrapper {
77    fn name(&self) -> &str {
78        &self.name
79    }
80
81    fn description(&self) -> &str {
82        &self.description
83    }
84
85    fn input_schema(&self) -> serde_json::Value {
86        self.input_schema.clone()
87    }
88
89    async fn execute(
90        &self,
91        _params: serde_json::Value,
92        _context: &ToolContext,
93    ) -> Result<ToolResult, ToolError> {
94        // MCP tool execution is handled externally
95        // This is a placeholder that should be overridden by the actual MCP execution logic
96        Err(ToolError::execution_failed(
97            "MCP tool execution must be handled by the MCP client",
98        ))
99    }
100}
101
102/// Tool Registry
103///
104/// Manages all available tools in the system, including both native tools
105/// and MCP tools. Native tools have higher priority than MCP tools with
106/// the same name.
107///
108/// Requirements: 2.1, 2.2, 2.3
109pub struct ToolRegistry {
110    /// Native tools (high priority)
111    native_tools: HashMap<String, Box<dyn Tool>>,
112    /// MCP tools (low priority)
113    mcp_tools: HashMap<String, McpToolWrapper>,
114    /// Permission manager for checking tool permissions
115    permission_manager: Option<Arc<ToolPermissionManager>>,
116    /// Audit logger for recording tool executions
117    audit_logger: Option<Arc<AuditLogger>>,
118}
119
120impl Default for ToolRegistry {
121    fn default() -> Self {
122        Self::new()
123    }
124}
125
126impl ToolRegistry {
127    /// Create a new empty tool registry
128    pub fn new() -> Self {
129        Self {
130            native_tools: HashMap::new(),
131            mcp_tools: HashMap::new(),
132            permission_manager: None,
133            audit_logger: None,
134        }
135    }
136
137    /// Create a new tool registry with permission manager and audit logger
138    pub fn with_managers(
139        permission_manager: Arc<ToolPermissionManager>,
140        audit_logger: Arc<AuditLogger>,
141    ) -> Self {
142        Self {
143            native_tools: HashMap::new(),
144            mcp_tools: HashMap::new(),
145            permission_manager: Some(permission_manager),
146            audit_logger: Some(audit_logger),
147        }
148    }
149
150    /// Set the permission manager
151    pub fn set_permission_manager(&mut self, manager: Arc<ToolPermissionManager>) {
152        self.permission_manager = Some(manager);
153    }
154
155    /// Set the audit logger
156    pub fn set_audit_logger(&mut self, logger: Arc<AuditLogger>) {
157        self.audit_logger = Some(logger);
158    }
159
160    /// Get the permission manager
161    pub fn permission_manager(&self) -> Option<&Arc<ToolPermissionManager>> {
162        self.permission_manager.as_ref()
163    }
164
165    /// Get the audit logger
166    pub fn audit_logger(&self) -> Option<&Arc<AuditLogger>> {
167        self.audit_logger.as_ref()
168    }
169}
170
171// =============================================================================
172// Registration Methods (Requirements: 2.1, 11.4)
173// =============================================================================
174
175impl ToolRegistry {
176    /// Register a native tool
177    ///
178    /// Native tools have higher priority than MCP tools with the same name.
179    /// If a native tool with the same name already exists, it will be replaced.
180    ///
181    /// # Arguments
182    /// * `tool` - The tool to register
183    ///
184    /// Requirements: 2.1
185    pub fn register(&mut self, tool: Box<dyn Tool>) {
186        let name = tool.name().to_string();
187        self.native_tools.insert(name, tool);
188    }
189
190    /// Register an MCP tool
191    ///
192    /// MCP tools have lower priority than native tools. If a native tool
193    /// with the same name exists, the MCP tool will be shadowed.
194    ///
195    /// # Arguments
196    /// * `name` - The tool name
197    /// * `tool` - The MCP tool wrapper
198    ///
199    /// Requirements: 11.4
200    pub fn register_mcp(&mut self, name: String, tool: McpToolWrapper) {
201        self.mcp_tools.insert(name, tool);
202    }
203
204    /// Unregister a native tool
205    ///
206    /// # Arguments
207    /// * `name` - The name of the tool to unregister
208    ///
209    /// # Returns
210    /// The unregistered tool if it existed
211    pub fn unregister(&mut self, name: &str) -> Option<Box<dyn Tool>> {
212        self.native_tools.remove(name)
213    }
214
215    /// Unregister an MCP tool
216    ///
217    /// # Arguments
218    /// * `name` - The name of the tool to unregister
219    ///
220    /// # Returns
221    /// The unregistered MCP tool wrapper if it existed
222    pub fn unregister_mcp(&mut self, name: &str) -> Option<McpToolWrapper> {
223        self.mcp_tools.remove(name)
224    }
225
226    /// Check if a tool is registered (native or MCP)
227    ///
228    /// # Arguments
229    /// * `name` - The tool name to check
230    ///
231    /// # Returns
232    /// `true` if the tool is registered
233    pub fn contains(&self, name: &str) -> bool {
234        self.native_tools.contains_key(name) || self.mcp_tools.contains_key(name)
235    }
236
237    /// Check if a native tool is registered
238    pub fn contains_native(&self, name: &str) -> bool {
239        self.native_tools.contains_key(name)
240    }
241
242    /// Check if an MCP tool is registered
243    pub fn contains_mcp(&self, name: &str) -> bool {
244        self.mcp_tools.contains_key(name)
245    }
246
247    /// Get the number of registered native tools
248    pub fn native_tool_count(&self) -> usize {
249        self.native_tools.len()
250    }
251
252    /// Get the number of registered MCP tools
253    pub fn mcp_tool_count(&self) -> usize {
254        self.mcp_tools.len()
255    }
256
257    /// Get the total number of registered tools
258    pub fn tool_count(&self) -> usize {
259        // Count unique tool names (native tools shadow MCP tools)
260        let mut names: std::collections::HashSet<&str> =
261            self.native_tools.keys().map(|s| s.as_str()).collect();
262        for name in self.mcp_tools.keys() {
263            names.insert(name.as_str());
264        }
265        names.len()
266    }
267}
268
269// =============================================================================
270// Query Methods (Requirements: 2.2, 2.3, 2.4)
271// =============================================================================
272
273impl ToolRegistry {
274    /// Get a tool by name (native tools have priority)
275    ///
276    /// # Arguments
277    /// * `name` - The tool name to look up
278    ///
279    /// # Returns
280    /// A reference to the tool if found, with native tools taking priority
281    ///
282    /// Requirements: 2.2
283    pub fn get(&self, name: &str) -> Option<&dyn Tool> {
284        // Native tools have priority over MCP tools
285        if let Some(tool) = self.native_tools.get(name) {
286            return Some(tool.as_ref());
287        }
288        if let Some(tool) = self.mcp_tools.get(name) {
289            return Some(tool as &dyn Tool);
290        }
291        None
292    }
293
294    /// Get all registered tools
295    ///
296    /// Returns all tools with native tools taking priority over MCP tools
297    /// with the same name.
298    ///
299    /// # Returns
300    /// A vector of references to all registered tools
301    ///
302    /// Requirements: 2.3
303    pub fn get_all(&self) -> Vec<&dyn Tool> {
304        let mut tools: Vec<&dyn Tool> = Vec::new();
305        let mut seen_names: std::collections::HashSet<&str> = std::collections::HashSet::new();
306
307        // Add native tools first (higher priority)
308        for (name, tool) in &self.native_tools {
309            tools.push(tool.as_ref());
310            seen_names.insert(name.as_str());
311        }
312
313        // Add MCP tools that aren't shadowed by native tools
314        for (name, tool) in &self.mcp_tools {
315            if !seen_names.contains(name.as_str()) {
316                tools.push(tool as &dyn Tool);
317            }
318        }
319
320        tools
321    }
322
323    /// Get all tool definitions for LLM consumption
324    ///
325    /// Returns definitions for all tools, with native tools taking priority
326    /// over MCP tools with the same name.
327    ///
328    /// # Returns
329    /// A vector of tool definitions
330    ///
331    /// Requirements: 2.4
332    pub fn get_definitions(&self) -> Vec<ToolDefinition> {
333        self.get_all()
334            .iter()
335            .map(|tool| tool.get_definition())
336            .collect()
337    }
338
339    /// Get all native tool names
340    pub fn native_tool_names(&self) -> Vec<&str> {
341        self.native_tools.keys().map(|s| s.as_str()).collect()
342    }
343
344    /// Get all MCP tool names
345    pub fn mcp_tool_names(&self) -> Vec<&str> {
346        self.mcp_tools.keys().map(|s| s.as_str()).collect()
347    }
348
349    /// Get all tool names (unique, native tools shadow MCP tools)
350    pub fn tool_names(&self) -> Vec<&str> {
351        let mut names: std::collections::HashSet<&str> =
352            self.native_tools.keys().map(|s| s.as_str()).collect();
353        for name in self.mcp_tools.keys() {
354            names.insert(name.as_str());
355        }
356        names.into_iter().collect()
357    }
358
359    /// Check if a tool is a native tool
360    pub fn is_native(&self, name: &str) -> bool {
361        self.native_tools.contains_key(name)
362    }
363
364    /// Check if a tool is an MCP tool (and not shadowed by a native tool)
365    pub fn is_mcp(&self, name: &str) -> bool {
366        !self.native_tools.contains_key(name) && self.mcp_tools.contains_key(name)
367    }
368}
369
370// =============================================================================
371// Execution Methods (Requirements: 2.5, 2.6, 8.1, 8.2)
372// =============================================================================
373
374impl ToolRegistry {
375    /// Execute a tool by name with permission checking and audit logging
376    ///
377    /// This method:
378    /// 1. Looks up the tool by name
379    /// 2. Performs permission check (if permission manager is configured)
380    /// 3. Handles permission request callback for 'Ask' behavior
381    /// 4. Executes the tool
382    /// 5. Records audit log (if audit logger is configured)
383    ///
384    /// # Arguments
385    /// * `name` - The tool name to execute
386    /// * `params` - The tool parameters
387    /// * `context` - The execution context
388    /// * `on_permission_request` - Optional callback for permission requests
389    ///
390    /// # Returns
391    /// * `Ok(ToolResult)` - The execution result
392    /// * `Err(ToolError)` - If the tool is not found, permission denied, or execution fails
393    ///
394    /// Requirements: 2.5, 2.6, 8.1, 8.2
395    pub async fn execute(
396        &self,
397        name: &str,
398        params: serde_json::Value,
399        context: &ToolContext,
400        on_permission_request: Option<PermissionRequestCallback>,
401    ) -> Result<ToolResult, ToolError> {
402        let start_time = Instant::now();
403
404        // Step 1: Look up the tool
405        let tool = self.get(name).ok_or_else(|| ToolError::not_found(name))?;
406
407        // Step 2: Check tool-level permissions
408        let permission_result = tool.check_permissions(&params, context).await;
409
410        // Handle tool-level permission check result
411        match permission_result.behavior {
412            PermissionBehavior::Deny => {
413                let reason = permission_result
414                    .message
415                    .unwrap_or_else(|| format!("Permission denied for tool '{}'", name));
416
417                // Log permission denial
418                self.log_permission_denied(name, &params, context, &reason, start_time.elapsed());
419
420                return Err(ToolError::permission_denied(reason));
421            }
422            PermissionBehavior::Ask => {
423                // Handle user confirmation request
424                if let Some(callback) = on_permission_request {
425                    let message = permission_result.message.unwrap_or_else(|| {
426                        format!("Tool '{}' requires permission to execute", name)
427                    });
428
429                    let approved = callback(name.to_string(), message.clone()).await;
430
431                    if !approved {
432                        self.log_permission_denied(
433                            name,
434                            &params,
435                            context,
436                            "User denied permission",
437                            start_time.elapsed(),
438                        );
439                        return Err(ToolError::permission_denied("User denied permission"));
440                    }
441                } else {
442                    // No callback provided, deny by default
443                    let reason =
444                        "Permission request requires user confirmation but no callback provided";
445                    self.log_permission_denied(
446                        name,
447                        &params,
448                        context,
449                        reason,
450                        start_time.elapsed(),
451                    );
452                    return Err(ToolError::permission_denied(reason));
453                }
454            }
455            PermissionBehavior::Allow => {
456                // Permission granted, continue
457            }
458        }
459
460        // Step 3: Check system-level permissions (if permission manager is configured)
461        if let Some(ref permission_manager) = self.permission_manager {
462            let perm_context = self.create_permission_context(context);
463            let params_map = self.params_to_hashmap(&params);
464            let perm_result = permission_manager.is_allowed(name, &params_map, &perm_context);
465
466            if !perm_result.allowed {
467                let reason = perm_result
468                    .reason
469                    .unwrap_or_else(|| format!("Permission denied for tool '{}'", name));
470
471                self.log_permission_denied(name, &params, context, &reason, start_time.elapsed());
472
473                return Err(ToolError::permission_denied(reason));
474            }
475        }
476
477        // Step 4: Execute the tool
478        let params_to_use = permission_result.updated_params.unwrap_or(params.clone());
479        let result = tool.execute(params_to_use, context).await;
480
481        // Step 5: Log the execution
482        let duration = start_time.elapsed();
483        match &result {
484            Ok(tool_result) => {
485                self.log_tool_execution(name, &params, context, tool_result, duration);
486            }
487            Err(err) => {
488                self.log_tool_error(name, &params, context, err, duration);
489            }
490        }
491
492        result
493    }
494
495    /// Create a PermissionContext from ToolContext
496    fn create_permission_context(&self, context: &ToolContext) -> PermissionContext {
497        PermissionContext {
498            working_directory: context.working_directory.clone(),
499            session_id: context.session_id.clone(),
500            timestamp: chrono::Utc::now().timestamp(),
501            user: context.user.clone(),
502            environment: context.environment.clone(),
503            metadata: HashMap::new(),
504        }
505    }
506
507    /// Convert JSON params to HashMap for permission checking
508    fn params_to_hashmap(&self, params: &serde_json::Value) -> HashMap<String, serde_json::Value> {
509        match params {
510            serde_json::Value::Object(map) => {
511                map.iter().map(|(k, v)| (k.clone(), v.clone())).collect()
512            }
513            _ => HashMap::new(),
514        }
515    }
516
517    /// Log a permission denial
518    fn log_permission_denied(
519        &self,
520        tool_name: &str,
521        params: &serde_json::Value,
522        context: &ToolContext,
523        reason: &str,
524        duration: std::time::Duration,
525    ) {
526        if let Some(ref logger) = self.audit_logger {
527            let entry = AuditLogEntry::new("permission_denied", tool_name)
528                .with_level(AuditLogLevel::Warn)
529                .with_parameters(self.params_to_hashmap(params))
530                .with_context(self.create_permission_context(context))
531                .with_duration_ms(duration.as_millis() as u64)
532                .add_metadata("reason", serde_json::json!(reason));
533
534            logger.log(entry);
535        }
536    }
537
538    /// Log a successful tool execution
539    fn log_tool_execution(
540        &self,
541        tool_name: &str,
542        params: &serde_json::Value,
543        context: &ToolContext,
544        result: &ToolResult,
545        duration: std::time::Duration,
546    ) {
547        if let Some(ref logger) = self.audit_logger {
548            let level = if result.is_success() {
549                AuditLogLevel::Info
550            } else {
551                AuditLogLevel::Warn
552            };
553
554            let entry = AuditLogEntry::new("tool_execution", tool_name)
555                .with_level(level)
556                .with_parameters(self.params_to_hashmap(params))
557                .with_context(self.create_permission_context(context))
558                .with_duration_ms(duration.as_millis() as u64)
559                .add_metadata("success", serde_json::json!(result.is_success()))
560                .add_metadata(
561                    "output_size",
562                    serde_json::json!(result.output.as_ref().map(|s| s.len()).unwrap_or(0)),
563                );
564
565            logger.log_tool_execution(entry);
566        }
567    }
568
569    /// Log a tool execution error
570    fn log_tool_error(
571        &self,
572        tool_name: &str,
573        params: &serde_json::Value,
574        context: &ToolContext,
575        error: &ToolError,
576        duration: std::time::Duration,
577    ) {
578        if let Some(ref logger) = self.audit_logger {
579            let entry = AuditLogEntry::new("tool_error", tool_name)
580                .with_level(AuditLogLevel::Error)
581                .with_parameters(self.params_to_hashmap(params))
582                .with_context(self.create_permission_context(context))
583                .with_duration_ms(duration.as_millis() as u64)
584                .add_metadata("error", serde_json::json!(error.to_string()));
585
586            logger.log_tool_execution(entry);
587        }
588    }
589}
590
591#[cfg(test)]
592mod tests {
593    use super::*;
594    use crate::tools::PermissionCheckResult;
595    use std::path::PathBuf;
596
597    /// A simple test tool for unit testing
598    struct TestTool {
599        name: String,
600        should_fail: bool,
601        permission_behavior: PermissionBehavior,
602    }
603
604    impl TestTool {
605        fn new(name: &str) -> Self {
606            Self {
607                name: name.to_string(),
608                should_fail: false,
609                permission_behavior: PermissionBehavior::Allow,
610            }
611        }
612
613        fn failing(name: &str) -> Self {
614            Self {
615                name: name.to_string(),
616                should_fail: true,
617                permission_behavior: PermissionBehavior::Allow,
618            }
619        }
620
621        fn with_permission(name: &str, behavior: PermissionBehavior) -> Self {
622            Self {
623                name: name.to_string(),
624                should_fail: false,
625                permission_behavior: behavior,
626            }
627        }
628    }
629
630    #[async_trait]
631    impl Tool for TestTool {
632        fn name(&self) -> &str {
633            &self.name
634        }
635
636        fn description(&self) -> &str {
637            "A test tool for unit testing"
638        }
639
640        fn input_schema(&self) -> serde_json::Value {
641            serde_json::json!({
642                "type": "object",
643                "properties": {
644                    "input": { "type": "string" }
645                },
646                "required": ["input"]
647            })
648        }
649
650        async fn execute(
651            &self,
652            params: serde_json::Value,
653            _context: &ToolContext,
654        ) -> Result<ToolResult, ToolError> {
655            if self.should_fail {
656                return Err(ToolError::execution_failed("Test failure"));
657            }
658
659            let input = params
660                .get("input")
661                .and_then(|v| v.as_str())
662                .unwrap_or("default");
663
664            Ok(ToolResult::success(format!("Processed: {}", input)))
665        }
666
667        async fn check_permissions(
668            &self,
669            _params: &serde_json::Value,
670            _context: &ToolContext,
671        ) -> PermissionCheckResult {
672            match self.permission_behavior {
673                PermissionBehavior::Allow => PermissionCheckResult::allow(),
674                PermissionBehavior::Deny => PermissionCheckResult::deny("Test denial"),
675                PermissionBehavior::Ask => PermissionCheckResult::ask("Test confirmation required"),
676            }
677        }
678    }
679
680    fn create_test_context() -> ToolContext {
681        ToolContext::new(PathBuf::from("/tmp"))
682            .with_session_id("test-session")
683            .with_user("test-user")
684    }
685
686    #[test]
687    fn test_registry_new() {
688        let registry = ToolRegistry::new();
689        assert_eq!(registry.native_tool_count(), 0);
690        assert_eq!(registry.mcp_tool_count(), 0);
691        assert_eq!(registry.tool_count(), 0);
692    }
693
694    #[test]
695    fn test_registry_register_native_tool() {
696        let mut registry = ToolRegistry::new();
697        registry.register(Box::new(TestTool::new("test_tool")));
698
699        assert_eq!(registry.native_tool_count(), 1);
700        assert!(registry.contains("test_tool"));
701        assert!(registry.contains_native("test_tool"));
702        assert!(!registry.contains_mcp("test_tool"));
703    }
704
705    #[test]
706    fn test_registry_register_mcp_tool() {
707        let mut registry = ToolRegistry::new();
708        let mcp_tool = McpToolWrapper::new(
709            "mcp_tool",
710            "An MCP tool",
711            serde_json::json!({}),
712            "test_server",
713        );
714        registry.register_mcp("mcp_tool".to_string(), mcp_tool);
715
716        assert_eq!(registry.mcp_tool_count(), 1);
717        assert!(registry.contains("mcp_tool"));
718        assert!(!registry.contains_native("mcp_tool"));
719        assert!(registry.contains_mcp("mcp_tool"));
720    }
721
722    #[test]
723    fn test_registry_native_priority_over_mcp() {
724        let mut registry = ToolRegistry::new();
725
726        // Register MCP tool first
727        let mcp_tool = McpToolWrapper::new(
728            "shared_tool",
729            "MCP version",
730            serde_json::json!({}),
731            "test_server",
732        );
733        registry.register_mcp("shared_tool".to_string(), mcp_tool);
734
735        // Register native tool with same name
736        registry.register(Box::new(TestTool::new("shared_tool")));
737
738        // Native tool should take priority
739        let tool = registry.get("shared_tool").unwrap();
740        assert_eq!(tool.description(), "A test tool for unit testing");
741        assert!(registry.is_native("shared_tool"));
742        assert!(!registry.is_mcp("shared_tool"));
743    }
744
745    #[test]
746    fn test_registry_get_nonexistent() {
747        let registry = ToolRegistry::new();
748        assert!(registry.get("nonexistent").is_none());
749    }
750
751    #[test]
752    fn test_registry_get_all() {
753        let mut registry = ToolRegistry::new();
754        registry.register(Box::new(TestTool::new("tool1")));
755        registry.register(Box::new(TestTool::new("tool2")));
756
757        let mcp_tool = McpToolWrapper::new(
758            "mcp_tool",
759            "An MCP tool",
760            serde_json::json!({}),
761            "test_server",
762        );
763        registry.register_mcp("mcp_tool".to_string(), mcp_tool);
764
765        let all_tools = registry.get_all();
766        assert_eq!(all_tools.len(), 3);
767    }
768
769    #[test]
770    fn test_registry_get_all_with_shadowing() {
771        let mut registry = ToolRegistry::new();
772
773        // Register MCP tool
774        let mcp_tool = McpToolWrapper::new(
775            "shared_tool",
776            "MCP version",
777            serde_json::json!({}),
778            "test_server",
779        );
780        registry.register_mcp("shared_tool".to_string(), mcp_tool);
781
782        // Register native tool with same name
783        registry.register(Box::new(TestTool::new("shared_tool")));
784
785        // Should only return 1 tool (native shadows MCP)
786        let all_tools = registry.get_all();
787        assert_eq!(all_tools.len(), 1);
788        assert_eq!(all_tools[0].description(), "A test tool for unit testing");
789    }
790
791    #[test]
792    fn test_registry_get_definitions() {
793        let mut registry = ToolRegistry::new();
794        registry.register(Box::new(TestTool::new("tool1")));
795        registry.register(Box::new(TestTool::new("tool2")));
796
797        let definitions = registry.get_definitions();
798        assert_eq!(definitions.len(), 2);
799
800        let names: Vec<&str> = definitions.iter().map(|d| d.name.as_str()).collect();
801        assert!(names.contains(&"tool1"));
802        assert!(names.contains(&"tool2"));
803    }
804
805    #[test]
806    fn test_registry_unregister() {
807        let mut registry = ToolRegistry::new();
808        registry.register(Box::new(TestTool::new("test_tool")));
809
810        assert!(registry.contains("test_tool"));
811
812        let removed = registry.unregister("test_tool");
813        assert!(removed.is_some());
814        assert!(!registry.contains("test_tool"));
815    }
816
817    #[test]
818    fn test_registry_unregister_mcp() {
819        let mut registry = ToolRegistry::new();
820        let mcp_tool = McpToolWrapper::new(
821            "mcp_tool",
822            "An MCP tool",
823            serde_json::json!({}),
824            "test_server",
825        );
826        registry.register_mcp("mcp_tool".to_string(), mcp_tool);
827
828        assert!(registry.contains("mcp_tool"));
829
830        let removed = registry.unregister_mcp("mcp_tool");
831        assert!(removed.is_some());
832        assert!(!registry.contains("mcp_tool"));
833    }
834
835    #[test]
836    fn test_registry_tool_names() {
837        let mut registry = ToolRegistry::new();
838        registry.register(Box::new(TestTool::new("native1")));
839        registry.register(Box::new(TestTool::new("native2")));
840
841        let mcp_tool =
842            McpToolWrapper::new("mcp1", "An MCP tool", serde_json::json!({}), "test_server");
843        registry.register_mcp("mcp1".to_string(), mcp_tool);
844
845        let native_names = registry.native_tool_names();
846        assert_eq!(native_names.len(), 2);
847
848        let mcp_names = registry.mcp_tool_names();
849        assert_eq!(mcp_names.len(), 1);
850
851        let all_names = registry.tool_names();
852        assert_eq!(all_names.len(), 3);
853    }
854
855    #[tokio::test]
856    async fn test_registry_execute_success() {
857        let mut registry = ToolRegistry::new();
858        registry.register(Box::new(TestTool::new("test_tool")));
859
860        let context = create_test_context();
861        let params = serde_json::json!({"input": "hello"});
862
863        let result = registry.execute("test_tool", params, &context, None).await;
864        assert!(result.is_ok());
865
866        let tool_result = result.unwrap();
867        assert!(tool_result.is_success());
868        assert_eq!(tool_result.output, Some("Processed: hello".to_string()));
869    }
870
871    #[tokio::test]
872    async fn test_registry_execute_not_found() {
873        let registry = ToolRegistry::new();
874        let context = create_test_context();
875        let params = serde_json::json!({});
876
877        let result = registry
878            .execute("nonexistent", params, &context, None)
879            .await;
880        assert!(result.is_err());
881        assert!(matches!(result.unwrap_err(), ToolError::NotFound(_)));
882    }
883
884    #[tokio::test]
885    async fn test_registry_execute_tool_failure() {
886        let mut registry = ToolRegistry::new();
887        registry.register(Box::new(TestTool::failing("failing_tool")));
888
889        let context = create_test_context();
890        let params = serde_json::json!({"input": "hello"});
891
892        let result = registry
893            .execute("failing_tool", params, &context, None)
894            .await;
895        assert!(result.is_err());
896        assert!(matches!(result.unwrap_err(), ToolError::ExecutionFailed(_)));
897    }
898
899    #[tokio::test]
900    async fn test_registry_execute_permission_denied() {
901        let mut registry = ToolRegistry::new();
902        registry.register(Box::new(TestTool::with_permission(
903            "denied_tool",
904            PermissionBehavior::Deny,
905        )));
906
907        let context = create_test_context();
908        let params = serde_json::json!({"input": "hello"});
909
910        let result = registry
911            .execute("denied_tool", params, &context, None)
912            .await;
913        assert!(result.is_err());
914        assert!(matches!(
915            result.unwrap_err(),
916            ToolError::PermissionDenied(_)
917        ));
918    }
919
920    #[tokio::test]
921    async fn test_registry_execute_permission_ask_approved() {
922        let mut registry = ToolRegistry::new();
923        registry.register(Box::new(TestTool::with_permission(
924            "ask_tool",
925            PermissionBehavior::Ask,
926        )));
927
928        let context = create_test_context();
929        let params = serde_json::json!({"input": "hello"});
930
931        // Create a callback that approves the request
932        let callback: PermissionRequestCallback =
933            Box::new(|_name, _message| Box::pin(async { true }));
934
935        let result = registry
936            .execute("ask_tool", params, &context, Some(callback))
937            .await;
938        assert!(result.is_ok());
939    }
940
941    #[tokio::test]
942    async fn test_registry_execute_permission_ask_denied() {
943        let mut registry = ToolRegistry::new();
944        registry.register(Box::new(TestTool::with_permission(
945            "ask_tool",
946            PermissionBehavior::Ask,
947        )));
948
949        let context = create_test_context();
950        let params = serde_json::json!({"input": "hello"});
951
952        // Create a callback that denies the request
953        let callback: PermissionRequestCallback =
954            Box::new(|_name, _message| Box::pin(async { false }));
955
956        let result = registry
957            .execute("ask_tool", params, &context, Some(callback))
958            .await;
959        assert!(result.is_err());
960        assert!(matches!(
961            result.unwrap_err(),
962            ToolError::PermissionDenied(_)
963        ));
964    }
965
966    #[tokio::test]
967    async fn test_registry_execute_permission_ask_no_callback() {
968        let mut registry = ToolRegistry::new();
969        registry.register(Box::new(TestTool::with_permission(
970            "ask_tool",
971            PermissionBehavior::Ask,
972        )));
973
974        let context = create_test_context();
975        let params = serde_json::json!({"input": "hello"});
976
977        // No callback provided - should deny
978        let result = registry.execute("ask_tool", params, &context, None).await;
979        assert!(result.is_err());
980        assert!(matches!(
981            result.unwrap_err(),
982            ToolError::PermissionDenied(_)
983        ));
984    }
985
986    #[test]
987    fn test_mcp_tool_wrapper() {
988        let wrapper = McpToolWrapper::new(
989            "test_mcp",
990            "Test MCP tool",
991            serde_json::json!({"type": "object"}),
992            "test_server",
993        );
994
995        assert_eq!(wrapper.name(), "test_mcp");
996        assert_eq!(wrapper.description(), "Test MCP tool");
997        assert_eq!(wrapper.server_name(), "test_server");
998        assert_eq!(wrapper.input_schema()["type"], "object");
999    }
1000
1001    #[test]
1002    fn test_registry_with_managers() {
1003        let permission_manager = Arc::new(ToolPermissionManager::new(None));
1004        let audit_logger = Arc::new(AuditLogger::new(AuditLogLevel::Info));
1005
1006        let registry =
1007            ToolRegistry::with_managers(permission_manager.clone(), audit_logger.clone());
1008
1009        assert!(registry.permission_manager().is_some());
1010        assert!(registry.audit_logger().is_some());
1011    }
1012
1013    #[test]
1014    fn test_registry_set_managers() {
1015        let mut registry = ToolRegistry::new();
1016
1017        assert!(registry.permission_manager().is_none());
1018        assert!(registry.audit_logger().is_none());
1019
1020        let permission_manager = Arc::new(ToolPermissionManager::new(None));
1021        let audit_logger = Arc::new(AuditLogger::new(AuditLogLevel::Info));
1022
1023        registry.set_permission_manager(permission_manager);
1024        registry.set_audit_logger(audit_logger);
1025
1026        assert!(registry.permission_manager().is_some());
1027        assert!(registry.audit_logger().is_some());
1028    }
1029}