Skip to main content

agent_diva_tooling/
registry.rs

1//! Tool registry.
2
3use crate::Tool;
4use agent_diva_core::error_context::{find_problematic_chars, ErrorContext};
5use serde_json::Value;
6use std::collections::HashMap;
7use std::sync::Arc;
8use tracing::{error, warn};
9
/// Suffix appended to every error string returned by [`ToolRegistry::execute`]
/// (lookup failure, invalid parameters, tool-reported errors, and execution
/// failures). Presumably meant to nudge an LLM caller into retrying differently.
const ERROR_HINT: &str = "\n\n[Analyze the error above and try a different approach.]";
11
/// Registry of available tools.
///
/// Tools are stored behind `Arc` so lookups can hand out cheap shared handles.
pub struct ToolRegistry {
    /// Registered tools keyed by the name each tool reports via `Tool::name`.
    tools: HashMap<String, Arc<dyn Tool>>,
}
16
17impl ToolRegistry {
18    /// Create a new tool registry.
19    pub fn new() -> Self {
20        Self {
21            tools: HashMap::new(),
22        }
23    }
24
25    /// Register a tool.
26    pub fn register(&mut self, tool: Arc<dyn Tool>) {
27        let name = tool.name().to_string();
28        self.tools.insert(name, tool);
29    }
30
31    /// Unregister a tool by name.
32    pub fn unregister(&mut self, name: &str) {
33        self.tools.remove(name);
34    }
35
36    /// Get a tool by name.
37    pub fn get(&self, name: &str) -> Option<Arc<dyn Tool>> {
38        self.tools.get(name).cloned()
39    }
40
41    /// Check if a tool is registered.
42    pub fn has(&self, name: &str) -> bool {
43        self.tools.contains_key(name)
44    }
45
46    /// Get all tool definitions in OpenAI format.
47    pub fn get_definitions(&self) -> Vec<Value> {
48        self.tools.values().map(|tool| tool.to_schema()).collect()
49    }
50
51    /// Execute a tool by name with given parameters.
52    pub async fn execute(&self, name: &str, params: Value) -> String {
53        let tool = match self.tools.get(name) {
54            Some(tool) => tool,
55            None => {
56                let ctx = ErrorContext::new("tool_lookup", format!("Tool '{}' not found", name))
57                    .with_metadata("tool_name", name.to_string())
58                    .with_metadata("available_tools", self.tool_names().join(", "));
59                warn!("{}", ctx.to_detailed_string());
60                return format!("Error: Tool '{}' not found{}", name, ERROR_HINT);
61            }
62        };
63
64        let errors = tool.validate_params(&params);
65        if !errors.is_empty() {
66            let params_str = serde_json::to_string(&params).unwrap_or_default();
67            let problems = find_problematic_chars(&params_str);
68            let ctx = ErrorContext::new("tool_validation", errors.join("; "))
69                .with_content(&params_str)
70                .with_metadata("tool_name", name.to_string());
71            let ctx_str = ctx.to_detailed_string();
72            if problems.is_empty() {
73                warn!("{}", ctx_str);
74            } else {
75                warn!(
76                    "{}\n  Problematic characters found:\n    - {}",
77                    ctx_str,
78                    problems.join("\n    - ")
79                );
80            }
81            return format!(
82                "Error: Invalid parameters for tool '{}': {}{}",
83                name,
84                errors.join("; "),
85                ERROR_HINT,
86            );
87        }
88
89        match tool.execute(params.clone()).await {
90            Ok(result) => {
91                if result.starts_with("Error") {
92                    let params_str = serde_json::to_string(&params).unwrap_or_default();
93                    let ctx = ErrorContext::new("tool_execution", &result)
94                        .with_content(&params_str)
95                        .with_metadata("tool_name", name.to_string());
96                    warn!("{}", ctx.to_detailed_string());
97                    format!("{}{}", result, ERROR_HINT)
98                } else {
99                    result
100                }
101            }
102            Err(e) => {
103                let params_str = serde_json::to_string(&params).unwrap_or_default();
104                let problems = find_problematic_chars(&params_str);
105                let ctx = ErrorContext::new("tool_execution", e.to_string())
106                    .with_content(&params_str)
107                    .with_metadata("tool_name", name.to_string());
108                let ctx_str = ctx.to_detailed_string();
109                if problems.is_empty() {
110                    error!("{}", ctx_str);
111                } else {
112                    error!(
113                        "{}\n  Problematic characters found:\n    - {}",
114                        ctx_str,
115                        problems.join("\n    - ")
116                    );
117                }
118                format!("Error executing {}: {}{}", name, e, ERROR_HINT)
119            }
120        }
121    }
122
123    /// Get list of registered tool names.
124    pub fn tool_names(&self) -> Vec<String> {
125        self.tools.keys().cloned().collect()
126    }
127
128    /// Get number of registered tools.
129    pub fn len(&self) -> usize {
130        self.tools.len()
131    }
132
133    /// Check if registry is empty.
134    pub fn is_empty(&self) -> bool {
135        self.tools.is_empty()
136    }
137}
138
139impl Default for ToolRegistry {
140    fn default() -> Self {
141        Self::new()
142    }
143}
144
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;

    /// Minimal tool for exercising the registry: it declares an empty object
    /// schema and always succeeds with a fixed string.
    struct MockTool;

    #[async_trait]
    impl Tool for MockTool {
        fn name(&self) -> &str {
            "mock"
        }

        fn description(&self) -> &str {
            "A mock tool"
        }

        fn parameters(&self) -> Value {
            serde_json::json!({
                "type": "object",
                "properties": {},
                "required": []
            })
        }

        async fn execute(&self, _args: Value) -> crate::Result<String> {
            Ok(String::from("mock result"))
        }
    }

    /// Fixture: a registry holding exactly one [`MockTool`].
    fn registry_with_mock() -> ToolRegistry {
        let mut reg = ToolRegistry::new();
        reg.register(Arc::new(MockTool));
        reg
    }

    #[test]
    fn test_register_tool() {
        let reg = registry_with_mock();
        assert!(reg.has("mock"));
        assert_eq!(reg.len(), 1);
    }

    #[test]
    fn test_unregister_tool() {
        let mut reg = registry_with_mock();
        reg.unregister("mock");
        assert!(!reg.has("mock"));
        assert_eq!(reg.len(), 0);
    }

    #[tokio::test]
    async fn test_execute_tool() {
        let reg = registry_with_mock();
        let output = reg.execute("mock", serde_json::json!({})).await;
        assert_eq!(output, "mock result");
    }

    #[tokio::test]
    async fn test_execute_unknown_tool() {
        let empty = ToolRegistry::new();
        let output = empty.execute("nonexistent", serde_json::json!({})).await;
        assert!(output.contains("Tool 'nonexistent' not found"));
        assert!(output.contains("[Analyze the error above"));
    }
}