Skip to main content

a3s_code_core/tools/
registry.rs

1//! Tool Registry
2//!
3//! Central registry for all tools (built-in and dynamic).
4//! Provides thread-safe registration, lookup, and execution.
5
6use super::types::{Tool, ToolContext, ToolOutput};
7use super::ToolResult;
8use crate::llm::ToolDefinition;
9use anyhow::Result;
10use std::collections::HashMap;
11use std::path::PathBuf;
12use std::sync::{Arc, RwLock};
13
14/// Tool registry for managing all available tools
15pub struct ToolRegistry {
16    tools: RwLock<HashMap<String, Arc<dyn Tool>>>,
17    /// Names of builtin tools that cannot be overridden
18    builtins: RwLock<std::collections::HashSet<String>>,
19    context: RwLock<ToolContext>,
20}
21
22impl ToolRegistry {
23    /// Create a new tool registry
24    pub fn new(workspace: PathBuf) -> Self {
25        Self {
26            tools: RwLock::new(HashMap::new()),
27            builtins: RwLock::new(std::collections::HashSet::new()),
28            context: RwLock::new(ToolContext::new(workspace)),
29        }
30    }
31
32    /// Register a builtin tool (cannot be overridden by dynamic tools)
33    pub fn register_builtin(&self, tool: Arc<dyn Tool>) {
34        let name = tool.name().to_string();
35        let mut tools = self.tools.write().unwrap();
36        let mut builtins = self.builtins.write().unwrap();
37        tracing::debug!("Registering builtin tool: {}", name);
38        tools.insert(name.clone(), tool);
39        builtins.insert(name);
40    }
41
42    /// Register a tool
43    ///
44    /// If a tool with the same name already exists as a builtin, the registration
45    /// is rejected to prevent shadowing of core tools.
46    pub fn register(&self, tool: Arc<dyn Tool>) {
47        let name = tool.name().to_string();
48        let builtins = self.builtins.read().unwrap();
49        if builtins.contains(&name) {
50            tracing::warn!(
51                "Rejected registration of tool '{}': cannot shadow builtin",
52                name
53            );
54            return;
55        }
56        drop(builtins);
57        let mut tools = self.tools.write().unwrap();
58        tracing::debug!("Registering tool: {}", name);
59        tools.insert(name, tool);
60    }
61
62    /// Unregister a tool by name
63    ///
64    /// Returns true if the tool was found and removed.
65    pub fn unregister(&self, name: &str) -> bool {
66        let mut tools = self.tools.write().unwrap();
67        tracing::debug!("Unregistering tool: {}", name);
68        tools.remove(name).is_some()
69    }
70
71    /// Get a tool by name
72    pub fn get(&self, name: &str) -> Option<Arc<dyn Tool>> {
73        let tools = self.tools.read().unwrap();
74        tools.get(name).cloned()
75    }
76
77    /// Check if a tool exists
78    pub fn contains(&self, name: &str) -> bool {
79        let tools = self.tools.read().unwrap();
80        tools.contains_key(name)
81    }
82
83    /// Get all tool definitions for LLM
84    pub fn definitions(&self) -> Vec<ToolDefinition> {
85        let tools = self.tools.read().unwrap();
86        tools
87            .values()
88            .map(|tool| ToolDefinition {
89                name: tool.name().to_string(),
90                description: tool.description().to_string(),
91                parameters: tool.parameters(),
92            })
93            .collect()
94    }
95
96    /// List all registered tool names
97    pub fn list(&self) -> Vec<String> {
98        let tools = self.tools.read().unwrap();
99        tools.keys().cloned().collect()
100    }
101
102    /// Get the number of registered tools
103    pub fn len(&self) -> usize {
104        let tools = self.tools.read().unwrap();
105        tools.len()
106    }
107
108    /// Check if registry is empty
109    pub fn is_empty(&self) -> bool {
110        self.len() == 0
111    }
112
113    /// Get the tool context
114    pub fn context(&self) -> ToolContext {
115        self.context.read().unwrap().clone()
116    }
117
118    /// Set the search configuration for the tool context
119    pub fn set_search_config(&self, config: crate::config::SearchConfig) {
120        let mut ctx = self.context.write().unwrap();
121        *ctx = ctx.clone().with_search_config(config);
122    }
123
124    /// Set a sandbox executor so that `bash` tool calls use the sandbox even
125    /// when executed without an explicit `ToolContext` (i.e., via `execute()`).
126    pub fn set_sandbox(&self, sandbox: std::sync::Arc<dyn crate::sandbox::BashSandbox>) {
127        let mut ctx = self.context.write().unwrap();
128        *ctx = ctx.clone().with_sandbox(sandbox);
129    }
130
131    /// Execute a tool by name using the registry's default context
132    pub async fn execute(&self, name: &str, args: &serde_json::Value) -> Result<ToolResult> {
133        let ctx = self.context();
134        self.execute_with_context(name, args, &ctx).await
135    }
136
137    /// Execute a tool by name with an external context
138    pub async fn execute_with_context(
139        &self,
140        name: &str,
141        args: &serde_json::Value,
142        ctx: &ToolContext,
143    ) -> Result<ToolResult> {
144        let start = std::time::Instant::now();
145
146        let tool = self.get(name);
147
148        let result = match tool {
149            Some(tool) => {
150                let output = tool.execute(args, ctx).await?;
151                Ok(ToolResult {
152                    name: name.to_string(),
153                    output: output.content,
154                    exit_code: if output.success { 0 } else { 1 },
155                    metadata: output.metadata,
156                    images: output.images,
157                })
158            }
159            None => Ok(ToolResult::error(name, format!("Unknown tool: {}", name))),
160        };
161
162        if let Ok(ref r) = result {
163            crate::telemetry::record_tool_result(r.exit_code, start.elapsed());
164        }
165
166        result
167    }
168
169    /// Execute a tool and return raw output using the registry's default context
170    pub async fn execute_raw(
171        &self,
172        name: &str,
173        args: &serde_json::Value,
174    ) -> Result<Option<ToolOutput>> {
175        let ctx = self.context();
176        self.execute_raw_with_context(name, args, &ctx).await
177    }
178
179    /// Execute a tool and return raw output with an external context
180    pub async fn execute_raw_with_context(
181        &self,
182        name: &str,
183        args: &serde_json::Value,
184        ctx: &ToolContext,
185    ) -> Result<Option<ToolOutput>> {
186        let tool = self.get(name);
187
188        match tool {
189            Some(tool) => {
190                let output = tool.execute(args, ctx).await?;
191                Ok(Some(output))
192            }
193            None => Ok(None),
194        }
195    }
196}
197
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;

    /// Tool stub that always succeeds with `"mock output"`.
    struct MockTool {
        name: String,
    }

    /// Build an `Arc`-wrapped [`MockTool`] with the given name.
    fn mock(name: &str) -> Arc<MockTool> {
        Arc::new(MockTool {
            name: name.to_string(),
        })
    }

    #[async_trait]
    impl Tool for MockTool {
        fn name(&self) -> &str {
            &self.name
        }

        fn description(&self) -> &str {
            "A mock tool for testing"
        }

        fn parameters(&self) -> serde_json::Value {
            serde_json::json!({
                "type": "object",
                "properties": {},
                "required": []
            })
        }

        async fn execute(
            &self,
            _args: &serde_json::Value,
            _ctx: &ToolContext,
        ) -> Result<ToolOutput> {
            Ok(ToolOutput::success("mock output"))
        }
    }

    /// Tool stub named `"failing"` that reports a (non-`Err`) failure.
    struct FailingTool;

    #[async_trait]
    impl Tool for FailingTool {
        fn name(&self) -> &str {
            "failing"
        }

        fn description(&self) -> &str {
            "A tool that returns failure"
        }

        fn parameters(&self) -> serde_json::Value {
            serde_json::json!({"type": "object", "properties": {}, "required": []})
        }

        async fn execute(
            &self,
            _args: &serde_json::Value,
            _ctx: &ToolContext,
        ) -> Result<ToolOutput> {
            Ok(ToolOutput::error("something went wrong"))
        }
    }

    #[test]
    fn test_registry_register_and_get() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        registry.register(mock("test"));

        assert!(registry.contains("test"));
        assert!(!registry.contains("nonexistent"));

        let retrieved = registry.get("test");
        assert!(retrieved.is_some());
        assert_eq!(retrieved.unwrap().name(), "test");
    }

    #[test]
    fn test_registry_unregister() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        registry.register(mock("test"));

        assert!(registry.contains("test"));
        assert!(registry.unregister("test"));
        assert!(!registry.contains("test"));
        assert!(!registry.unregister("test")); // Already removed
    }

    #[test]
    fn test_registry_definitions() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        registry.register(mock("tool1"));
        registry.register(mock("tool2"));

        let definitions = registry.definitions();
        assert_eq!(definitions.len(), 2);

        // Compare as a sorted list so the check does not depend on ordering.
        let mut names: Vec<String> = definitions.iter().map(|d| d.name.clone()).collect();
        names.sort();
        assert_eq!(names, vec!["tool1".to_string(), "tool2".to_string()]);
        assert!(definitions
            .iter()
            .all(|d| d.description == "A mock tool for testing"));
    }

    #[tokio::test]
    async fn test_registry_execute() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        registry.register(mock("test"));

        let result = registry
            .execute("test", &serde_json::json!({}))
            .await
            .unwrap();
        assert_eq!(result.exit_code, 0);
        assert_eq!(result.output, "mock output");
    }

    #[tokio::test]
    async fn test_registry_execute_unknown() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));

        let result = registry
            .execute("unknown", &serde_json::json!({}))
            .await
            .unwrap();
        assert_eq!(result.exit_code, 1);
        assert!(result.output.contains("Unknown tool"));
    }

    #[tokio::test]
    async fn test_registry_execute_with_context_success() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        let ctx = ToolContext::new(PathBuf::from("/tmp"));
        registry.register(mock("my_tool"));

        let result = registry
            .execute_with_context("my_tool", &serde_json::json!({}), &ctx)
            .await
            .unwrap();
        assert_eq!(result.name, "my_tool");
        assert_eq!(result.exit_code, 0);
        assert_eq!(result.output, "mock output");
    }

    #[tokio::test]
    async fn test_registry_execute_with_context_unknown_tool() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        let ctx = ToolContext::new(PathBuf::from("/tmp"));

        let result = registry
            .execute_with_context("nonexistent", &serde_json::json!({}), &ctx)
            .await
            .unwrap();
        assert_eq!(result.exit_code, 1);
        assert!(result.output.contains("Unknown tool: nonexistent"));
    }

    #[tokio::test]
    async fn test_registry_execute_failing_tool() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        registry.register(Arc::new(FailingTool));

        let result = registry
            .execute("failing", &serde_json::json!({}))
            .await
            .unwrap();
        assert_eq!(result.exit_code, 1);
        assert_eq!(result.output, "something went wrong");
    }

    #[tokio::test]
    async fn test_registry_execute_raw_success() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        registry.register(mock("raw_test"));

        let output = registry
            .execute_raw("raw_test", &serde_json::json!({}))
            .await
            .unwrap();
        assert!(output.is_some());
        let output = output.unwrap();
        assert!(output.success);
        assert_eq!(output.content, "mock output");
    }

    #[tokio::test]
    async fn test_registry_execute_raw_unknown() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));

        let output = registry
            .execute_raw("missing", &serde_json::json!({}))
            .await
            .unwrap();
        assert!(output.is_none());
    }

    #[tokio::test]
    async fn test_registry_execute_raw_with_context_failing_tool() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        let ctx = ToolContext::new(PathBuf::from("/tmp"));
        registry.register(Arc::new(FailingTool));

        let output = registry
            .execute_raw_with_context("failing", &serde_json::json!({}), &ctx)
            .await
            .unwrap()
            .expect("tool was registered above");
        assert!(!output.success);
        assert_eq!(output.content, "something went wrong");
    }

    #[test]
    fn test_registry_list() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        registry.register(mock("alpha"));
        registry.register(mock("beta"));

        let names = registry.list();
        assert_eq!(names.len(), 2);
        assert!(names.contains(&"alpha".to_string()));
        assert!(names.contains(&"beta".to_string()));
    }

    #[test]
    fn test_registry_len_and_is_empty() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        assert!(registry.is_empty());
        assert_eq!(registry.len(), 0);

        registry.register(mock("t"));
        assert!(!registry.is_empty());
        assert_eq!(registry.len(), 1);
    }

    #[test]
    fn test_registry_replace_tool() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        registry.register(mock("dup"));
        registry.register(mock("dup"));
        // Should still have only 1 tool (replaced)
        assert_eq!(registry.len(), 1);
    }

    #[tokio::test]
    async fn test_register_cannot_shadow_builtin() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        // A succeeding builtin registered under the name "failing"...
        registry.register_builtin(mock("failing"));
        // ...must not be replaced by a dynamic tool with the same name.
        registry.register(Arc::new(FailingTool));

        assert_eq!(registry.len(), 1);
        let result = registry
            .execute("failing", &serde_json::json!({}))
            .await
            .unwrap();
        assert_eq!(result.exit_code, 0);
        assert_eq!(result.output, "mock output");
    }

    #[tokio::test]
    async fn test_register_builtin_replaces_dynamic_tool() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        registry.register(Arc::new(FailingTool));
        registry.register_builtin(mock("failing"));

        assert_eq!(registry.len(), 1);
        let result = registry
            .execute("failing", &serde_json::json!({}))
            .await
            .unwrap();
        assert_eq!(result.exit_code, 0);
        assert_eq!(result.output, "mock output");
    }
}