agent_diva_tooling/
registry.rs1use crate::Tool;
4use agent_diva_core::error_context::{find_problematic_chars, ErrorContext};
5use serde_json::Value;
6use std::collections::HashMap;
7use std::sync::Arc;
8use tracing::{error, warn};
9
/// Suffix appended to every error string returned to the caller, nudging it to
/// reconsider its approach instead of retrying the same call verbatim.
const ERROR_HINT: &str = concat!(
    "\n\n",
    "[Analyze the error above and try a different approach.]"
);
11
12pub struct ToolRegistry {
14 tools: HashMap<String, Arc<dyn Tool>>,
15}
16
17impl ToolRegistry {
18 pub fn new() -> Self {
20 Self {
21 tools: HashMap::new(),
22 }
23 }
24
25 pub fn register(&mut self, tool: Arc<dyn Tool>) {
27 let name = tool.name().to_string();
28 self.tools.insert(name, tool);
29 }
30
31 pub fn unregister(&mut self, name: &str) {
33 self.tools.remove(name);
34 }
35
36 pub fn get(&self, name: &str) -> Option<Arc<dyn Tool>> {
38 self.tools.get(name).cloned()
39 }
40
41 pub fn has(&self, name: &str) -> bool {
43 self.tools.contains_key(name)
44 }
45
46 pub fn get_definitions(&self) -> Vec<Value> {
48 self.tools.values().map(|tool| tool.to_schema()).collect()
49 }
50
51 pub async fn execute(&self, name: &str, params: Value) -> String {
53 let tool = match self.tools.get(name) {
54 Some(tool) => tool,
55 None => {
56 let ctx = ErrorContext::new("tool_lookup", format!("Tool '{}' not found", name))
57 .with_metadata("tool_name", name.to_string())
58 .with_metadata("available_tools", self.tool_names().join(", "));
59 warn!("{}", ctx.to_detailed_string());
60 return format!("Error: Tool '{}' not found{}", name, ERROR_HINT);
61 }
62 };
63
64 let errors = tool.validate_params(¶ms);
65 if !errors.is_empty() {
66 let params_str = serde_json::to_string(¶ms).unwrap_or_default();
67 let problems = find_problematic_chars(¶ms_str);
68 let ctx = ErrorContext::new("tool_validation", errors.join("; "))
69 .with_content(¶ms_str)
70 .with_metadata("tool_name", name.to_string());
71 let ctx_str = ctx.to_detailed_string();
72 if problems.is_empty() {
73 warn!("{}", ctx_str);
74 } else {
75 warn!(
76 "{}\n Problematic characters found:\n - {}",
77 ctx_str,
78 problems.join("\n - ")
79 );
80 }
81 return format!(
82 "Error: Invalid parameters for tool '{}': {}{}",
83 name,
84 errors.join("; "),
85 ERROR_HINT,
86 );
87 }
88
89 match tool.execute(params.clone()).await {
90 Ok(result) => {
91 if result.starts_with("Error") {
92 let params_str = serde_json::to_string(¶ms).unwrap_or_default();
93 let ctx = ErrorContext::new("tool_execution", &result)
94 .with_content(¶ms_str)
95 .with_metadata("tool_name", name.to_string());
96 warn!("{}", ctx.to_detailed_string());
97 format!("{}{}", result, ERROR_HINT)
98 } else {
99 result
100 }
101 }
102 Err(e) => {
103 let params_str = serde_json::to_string(¶ms).unwrap_or_default();
104 let problems = find_problematic_chars(¶ms_str);
105 let ctx = ErrorContext::new("tool_execution", e.to_string())
106 .with_content(¶ms_str)
107 .with_metadata("tool_name", name.to_string());
108 let ctx_str = ctx.to_detailed_string();
109 if problems.is_empty() {
110 error!("{}", ctx_str);
111 } else {
112 error!(
113 "{}\n Problematic characters found:\n - {}",
114 ctx_str,
115 problems.join("\n - ")
116 );
117 }
118 format!("Error executing {}: {}{}", name, e, ERROR_HINT)
119 }
120 }
121 }
122
123 pub fn tool_names(&self) -> Vec<String> {
125 self.tools.keys().cloned().collect()
126 }
127
128 pub fn len(&self) -> usize {
130 self.tools.len()
131 }
132
133 pub fn is_empty(&self) -> bool {
135 self.tools.is_empty()
136 }
137}
138
139impl Default for ToolRegistry {
140 fn default() -> Self {
141 Self::new()
142 }
143}
144
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;

    /// Minimal tool that takes no parameters and always succeeds.
    struct MockTool;

    #[async_trait]
    impl Tool for MockTool {
        fn name(&self) -> &str {
            "mock"
        }

        fn description(&self) -> &str {
            "A mock tool"
        }

        fn parameters(&self) -> Value {
            serde_json::json!({
                "type": "object",
                "properties": {},
                "required": []
            })
        }

        async fn execute(&self, _args: Value) -> crate::Result<String> {
            Ok("mock result".to_string())
        }
    }

    /// Tool that reports failure through an `Ok` string starting with "Error".
    struct ErrorTextTool;

    #[async_trait]
    impl Tool for ErrorTextTool {
        fn name(&self) -> &str {
            "error_text"
        }

        fn description(&self) -> &str {
            "Returns an error message as its result"
        }

        fn parameters(&self) -> Value {
            serde_json::json!({
                "type": "object",
                "properties": {},
                "required": []
            })
        }

        async fn execute(&self, _args: Value) -> crate::Result<String> {
            Ok("Error: something went wrong".to_string())
        }
    }

    #[test]
    fn test_register_tool() {
        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(MockTool));
        assert_eq!(registry.len(), 1);
        assert!(registry.has("mock"));
    }

    #[test]
    fn test_register_same_name_replaces() {
        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(MockTool));
        registry.register(Arc::new(MockTool));
        assert_eq!(registry.len(), 1);
        assert!(registry.has("mock"));
    }

    #[test]
    fn test_unregister_tool() {
        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(MockTool));
        registry.unregister("mock");
        assert_eq!(registry.len(), 0);
        assert!(!registry.has("mock"));
    }

    #[test]
    fn test_default_registry_is_empty() {
        let registry = ToolRegistry::default();
        assert!(registry.is_empty());
        assert!(registry.tool_names().is_empty());
        assert!(registry.get_definitions().is_empty());
    }

    #[test]
    fn test_get_and_definitions() {
        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(MockTool));
        assert!(registry.get("mock").is_some());
        assert!(registry.get("missing").is_none());
        assert_eq!(registry.get_definitions().len(), 1);
        assert_eq!(registry.tool_names(), vec!["mock".to_string()]);
    }

    #[tokio::test]
    async fn test_execute_tool() {
        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(MockTool));
        let result = registry.execute("mock", serde_json::json!({})).await;
        assert_eq!(result, "mock result");
    }

    #[tokio::test]
    async fn test_execute_unknown_tool() {
        let registry = ToolRegistry::new();
        let result = registry.execute("nonexistent", serde_json::json!({})).await;
        assert!(result.contains("Tool 'nonexistent' not found"));
        assert!(result.contains("[Analyze the error above"));
    }

    #[tokio::test]
    async fn test_execute_error_text_result_gets_hint() {
        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(ErrorTextTool));
        let result = registry.execute("error_text", serde_json::json!({})).await;
        assert!(result.starts_with("Error: something went wrong"));
        assert!(result.ends_with(ERROR_HINT));
    }
}