// a3s_code_core/tools/registry.rs

//! Tool Registry
//!
//! Central registry for all tools (built-in and dynamic).
//! Provides thread-safe registration, lookup, and execution.
//! Built-in tools are protected: a dynamic registration cannot shadow them.

use super::types::{Tool, ToolContext, ToolOutput};
use super::ToolResult;
use crate::llm::ToolDefinition;
use anyhow::Result;
use std::collections::{HashMap, HashSet};
use std::path::PathBuf;
use std::sync::{Arc, RwLock};

14/// Tool registry for managing all available tools
15pub struct ToolRegistry {
16    tools: RwLock<HashMap<String, Arc<dyn Tool>>>,
17    /// Names of builtin tools that cannot be overridden
18    builtins: RwLock<std::collections::HashSet<String>>,
19    context: ToolContext,
20}
21
22impl ToolRegistry {
23    /// Create a new tool registry
24    pub fn new(workspace: PathBuf) -> Self {
25        Self {
26            tools: RwLock::new(HashMap::new()),
27            builtins: RwLock::new(std::collections::HashSet::new()),
28            context: ToolContext::new(workspace),
29        }
30    }
31
32    /// Register a builtin tool (cannot be overridden by dynamic tools)
33    pub fn register_builtin(&self, tool: Arc<dyn Tool>) {
34        let name = tool.name().to_string();
35        let mut tools = self.tools.write().unwrap();
36        let mut builtins = self.builtins.write().unwrap();
37        tracing::debug!("Registering builtin tool: {}", name);
38        tools.insert(name.clone(), tool);
39        builtins.insert(name);
40    }
41
42    /// Register a tool
43    ///
44    /// If a tool with the same name already exists as a builtin, the registration
45    /// is rejected to prevent shadowing of core tools.
46    pub fn register(&self, tool: Arc<dyn Tool>) {
47        let name = tool.name().to_string();
48        let builtins = self.builtins.read().unwrap();
49        if builtins.contains(&name) {
50            tracing::warn!(
51                "Rejected registration of tool '{}': cannot shadow builtin",
52                name
53            );
54            return;
55        }
56        drop(builtins);
57        let mut tools = self.tools.write().unwrap();
58        tracing::debug!("Registering tool: {}", name);
59        tools.insert(name, tool);
60    }
61
62    /// Unregister a tool by name
63    ///
64    /// Returns true if the tool was found and removed.
65    pub fn unregister(&self, name: &str) -> bool {
66        let mut tools = self.tools.write().unwrap();
67        tracing::debug!("Unregistering tool: {}", name);
68        tools.remove(name).is_some()
69    }
70
71    /// Get a tool by name
72    pub fn get(&self, name: &str) -> Option<Arc<dyn Tool>> {
73        let tools = self.tools.read().unwrap();
74        tools.get(name).cloned()
75    }
76
77    /// Check if a tool exists
78    pub fn contains(&self, name: &str) -> bool {
79        let tools = self.tools.read().unwrap();
80        tools.contains_key(name)
81    }
82
83    /// Get all tool definitions for LLM
84    pub fn definitions(&self) -> Vec<ToolDefinition> {
85        let tools = self.tools.read().unwrap();
86        tools
87            .values()
88            .map(|tool| ToolDefinition {
89                name: tool.name().to_string(),
90                description: tool.description().to_string(),
91                parameters: tool.parameters(),
92            })
93            .collect()
94    }
95
96    /// List all registered tool names
97    pub fn list(&self) -> Vec<String> {
98        let tools = self.tools.read().unwrap();
99        tools.keys().cloned().collect()
100    }
101
102    /// Get the number of registered tools
103    pub fn len(&self) -> usize {
104        let tools = self.tools.read().unwrap();
105        tools.len()
106    }
107
108    /// Check if registry is empty
109    pub fn is_empty(&self) -> bool {
110        self.len() == 0
111    }
112
113    /// Get the tool context
114    pub fn context(&self) -> &ToolContext {
115        &self.context
116    }
117
118    /// Execute a tool by name using the registry's default context
119    pub async fn execute(&self, name: &str, args: &serde_json::Value) -> Result<ToolResult> {
120        self.execute_with_context(name, args, &self.context).await
121    }
122
123    /// Execute a tool by name with an external context
124    pub async fn execute_with_context(
125        &self,
126        name: &str,
127        args: &serde_json::Value,
128        ctx: &ToolContext,
129    ) -> Result<ToolResult> {
130        let start = std::time::Instant::now();
131
132        let tool = self.get(name);
133
134        let result = match tool {
135            Some(tool) => {
136                let output = tool.execute(args, ctx).await?;
137                Ok(ToolResult {
138                    name: name.to_string(),
139                    output: output.content,
140                    exit_code: if output.success { 0 } else { 1 },
141                    metadata: output.metadata,
142                })
143            }
144            None => Ok(ToolResult::error(name, format!("Unknown tool: {}", name))),
145        };
146
147        if let Ok(ref r) = result {
148            crate::telemetry::record_tool_result(r.exit_code, start.elapsed());
149        }
150
151        result
152    }
153
154    /// Execute a tool and return raw output using the registry's default context
155    pub async fn execute_raw(
156        &self,
157        name: &str,
158        args: &serde_json::Value,
159    ) -> Result<Option<ToolOutput>> {
160        self.execute_raw_with_context(name, args, &self.context)
161            .await
162    }
163
164    /// Execute a tool and return raw output with an external context
165    pub async fn execute_raw_with_context(
166        &self,
167        name: &str,
168        args: &serde_json::Value,
169        ctx: &ToolContext,
170    ) -> Result<Option<ToolOutput>> {
171        let tool = self.get(name);
172
173        match tool {
174            Some(tool) => {
175                let output = tool.execute(args, ctx).await?;
176                Ok(Some(output))
177            }
178            None => Ok(None),
179        }
180    }
181}
182
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;

    /// Tool with a configurable name that always succeeds with "mock output".
    struct MockTool {
        name: String,
    }

    #[async_trait]
    impl Tool for MockTool {
        fn name(&self) -> &str {
            &self.name
        }

        fn description(&self) -> &str {
            "A mock tool for testing"
        }

        fn parameters(&self) -> serde_json::Value {
            serde_json::json!({
                "type": "object",
                "properties": {},
                "required": []
            })
        }

        async fn execute(
            &self,
            _args: &serde_json::Value,
            _ctx: &ToolContext,
        ) -> Result<ToolOutput> {
            Ok(ToolOutput::success("mock output"))
        }
    }

    #[test]
    fn test_registry_register_and_get() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));

        let tool = Arc::new(MockTool {
            name: "test".to_string(),
        });
        registry.register(tool);

        assert!(registry.contains("test"));
        assert!(!registry.contains("nonexistent"));

        let retrieved = registry.get("test");
        assert!(retrieved.is_some());
        assert_eq!(retrieved.unwrap().name(), "test");
    }

    #[test]
    fn test_registry_unregister() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));

        let tool = Arc::new(MockTool {
            name: "test".to_string(),
        });
        registry.register(tool);

        assert!(registry.contains("test"));
        assert!(registry.unregister("test"));
        assert!(!registry.contains("test"));
        assert!(!registry.unregister("test")); // Already removed
    }

    #[test]
    fn test_registry_definitions() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));

        registry.register(Arc::new(MockTool {
            name: "tool1".to_string(),
        }));
        registry.register(Arc::new(MockTool {
            name: "tool2".to_string(),
        }));

        let definitions = registry.definitions();
        assert_eq!(definitions.len(), 2);
    }

    #[tokio::test]
    async fn test_registry_execute() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));

        registry.register(Arc::new(MockTool {
            name: "test".to_string(),
        }));

        let result = registry
            .execute("test", &serde_json::json!({}))
            .await
            .unwrap();
        assert_eq!(result.exit_code, 0);
        assert_eq!(result.output, "mock output");
    }

    #[tokio::test]
    async fn test_registry_execute_unknown() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));

        let result = registry
            .execute("unknown", &serde_json::json!({}))
            .await
            .unwrap();
        assert_eq!(result.exit_code, 1);
        assert!(result.output.contains("Unknown tool"));
    }

    #[tokio::test]
    async fn test_registry_execute_with_context_success() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        let ctx = ToolContext::new(PathBuf::from("/tmp"));

        registry.register(Arc::new(MockTool {
            name: "my_tool".to_string(),
        }));

        let result = registry
            .execute_with_context("my_tool", &serde_json::json!({}), &ctx)
            .await
            .unwrap();
        assert_eq!(result.name, "my_tool");
        assert_eq!(result.exit_code, 0);
        assert_eq!(result.output, "mock output");
    }

    #[tokio::test]
    async fn test_registry_execute_with_context_unknown_tool() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        let ctx = ToolContext::new(PathBuf::from("/tmp"));

        let result = registry
            .execute_with_context("nonexistent", &serde_json::json!({}), &ctx)
            .await
            .unwrap();
        assert_eq!(result.exit_code, 1);
        assert!(result.output.contains("Unknown tool: nonexistent"));
    }

    /// Tool named "failing" that completes but reports failure.
    struct FailingTool;

    #[async_trait]
    impl Tool for FailingTool {
        fn name(&self) -> &str {
            "failing"
        }

        fn description(&self) -> &str {
            "A tool that returns failure"
        }

        fn parameters(&self) -> serde_json::Value {
            serde_json::json!({"type": "object", "properties": {}, "required": []})
        }

        async fn execute(
            &self,
            _args: &serde_json::Value,
            _ctx: &ToolContext,
        ) -> Result<ToolOutput> {
            Ok(ToolOutput::error("something went wrong"))
        }
    }

    #[tokio::test]
    async fn test_registry_execute_failing_tool() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        registry.register(Arc::new(FailingTool));

        let result = registry
            .execute("failing", &serde_json::json!({}))
            .await
            .unwrap();
        assert_eq!(result.exit_code, 1);
        assert_eq!(result.output, "something went wrong");
    }

    #[tokio::test]
    async fn test_registry_execute_raw_success() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        registry.register(Arc::new(MockTool {
            name: "raw_test".to_string(),
        }));

        let output = registry
            .execute_raw("raw_test", &serde_json::json!({}))
            .await
            .unwrap();
        assert!(output.is_some());
        let output = output.unwrap();
        assert!(output.success);
        assert_eq!(output.content, "mock output");
    }

    #[tokio::test]
    async fn test_registry_execute_raw_unknown() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));

        let output = registry
            .execute_raw("missing", &serde_json::json!({}))
            .await
            .unwrap();
        assert!(output.is_none());
    }

    #[test]
    fn test_registry_list() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        registry.register(Arc::new(MockTool {
            name: "alpha".to_string(),
        }));
        registry.register(Arc::new(MockTool {
            name: "beta".to_string(),
        }));

        let names = registry.list();
        assert_eq!(names.len(), 2);
        assert!(names.contains(&"alpha".to_string()));
        assert!(names.contains(&"beta".to_string()));
    }

    #[test]
    fn test_registry_len_and_is_empty() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        assert!(registry.is_empty());
        assert_eq!(registry.len(), 0);

        registry.register(Arc::new(MockTool {
            name: "t".to_string(),
        }));
        assert!(!registry.is_empty());
        assert_eq!(registry.len(), 1);
    }

    #[test]
    fn test_registry_replace_tool() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        registry.register(Arc::new(MockTool {
            name: "dup".to_string(),
        }));
        registry.register(Arc::new(MockTool {
            name: "dup".to_string(),
        }));
        // Should still have only 1 tool (replaced)
        assert_eq!(registry.len(), 1);
    }

    #[tokio::test]
    async fn test_registry_register_cannot_shadow_builtin() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        // A succeeding builtin that happens to share FailingTool's name.
        registry.register_builtin(Arc::new(MockTool {
            name: "failing".to_string(),
        }));
        // A dynamic registration under the same name must be rejected.
        registry.register(Arc::new(FailingTool));

        assert_eq!(registry.len(), 1);
        let result = registry
            .execute("failing", &serde_json::json!({}))
            .await
            .unwrap();
        assert_eq!(result.exit_code, 0);
        assert_eq!(result.output, "mock output");
    }

    #[tokio::test]
    async fn test_registry_builtin_replaces_dynamic() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        registry.register(Arc::new(FailingTool));
        // Registering a builtin under an existing dynamic name replaces it.
        registry.register_builtin(Arc::new(MockTool {
            name: "failing".to_string(),
        }));

        assert_eq!(registry.len(), 1);
        let result = registry
            .execute("failing", &serde_json::json!({}))
            .await
            .unwrap();
        assert_eq!(result.exit_code, 0);
        assert_eq!(result.output, "mock output");
    }
}