Skip to main content

a3s_code_core/tools/
registry.rs

//! Tool Registry
//!
//! Central registry for all tools (built-in and dynamic).
//! Provides thread-safe registration, lookup, and execution. Built-in tools
//! are protected from being shadowed by dynamically registered tools.
5
6use super::types::{Tool, ToolContext, ToolOutput};
7use super::ToolResult;
8use crate::llm::ToolDefinition;
9use anyhow::Result;
10use std::collections::HashMap;
11use std::path::PathBuf;
12use std::sync::{Arc, RwLock};
13
14/// Tool registry for managing all available tools
15pub struct ToolRegistry {
16    tools: RwLock<HashMap<String, Arc<dyn Tool>>>,
17    /// Names of builtin tools that cannot be overridden
18    builtins: RwLock<std::collections::HashSet<String>>,
19    context: RwLock<ToolContext>,
20}
21
22impl ToolRegistry {
23    /// Create a new tool registry
24    pub fn new(workspace: PathBuf) -> Self {
25        Self {
26            tools: RwLock::new(HashMap::new()),
27            builtins: RwLock::new(std::collections::HashSet::new()),
28            context: RwLock::new(ToolContext::new(workspace)),
29        }
30    }
31
32    /// Register a builtin tool (cannot be overridden by dynamic tools)
33    pub fn register_builtin(&self, tool: Arc<dyn Tool>) {
34        let name = tool.name().to_string();
35        let mut tools = self.tools.write().unwrap();
36        let mut builtins = self.builtins.write().unwrap();
37        tracing::debug!("Registering builtin tool: {}", name);
38        tools.insert(name.clone(), tool);
39        builtins.insert(name);
40    }
41
42    /// Register a tool
43    ///
44    /// If a tool with the same name already exists as a builtin, the registration
45    /// is rejected to prevent shadowing of core tools.
46    pub fn register(&self, tool: Arc<dyn Tool>) {
47        let name = tool.name().to_string();
48        let builtins = self.builtins.read().unwrap();
49        if builtins.contains(&name) {
50            tracing::warn!(
51                "Rejected registration of tool '{}': cannot shadow builtin",
52                name
53            );
54            return;
55        }
56        drop(builtins);
57        let mut tools = self.tools.write().unwrap();
58        tracing::debug!("Registering tool: {}", name);
59        tools.insert(name, tool);
60    }
61
62    /// Unregister a tool by name
63    ///
64    /// Returns true if the tool was found and removed.
65    pub fn unregister(&self, name: &str) -> bool {
66        let mut tools = self.tools.write().unwrap();
67        tracing::debug!("Unregistering tool: {}", name);
68        tools.remove(name).is_some()
69    }
70
71    /// Unregister all tools whose names start with the given prefix.
72    pub fn unregister_by_prefix(&self, prefix: &str) {
73        let mut tools = self.tools.write().unwrap();
74        tools.retain(|name, _| !name.starts_with(prefix));
75        tracing::debug!("Unregistered tools with prefix: {}", prefix);
76    }
77
78    /// Get a tool by name
79    pub fn get(&self, name: &str) -> Option<Arc<dyn Tool>> {
80        let tools = self.tools.read().unwrap();
81        tools.get(name).cloned()
82    }
83
84    /// Check if a tool exists
85    pub fn contains(&self, name: &str) -> bool {
86        let tools = self.tools.read().unwrap();
87        tools.contains_key(name)
88    }
89
90    /// Get all tool definitions for LLM
91    pub fn definitions(&self) -> Vec<ToolDefinition> {
92        let tools = self.tools.read().unwrap();
93        tools
94            .values()
95            .map(|tool| ToolDefinition {
96                name: tool.name().to_string(),
97                description: tool.description().to_string(),
98                parameters: tool.parameters(),
99            })
100            .collect()
101    }
102
103    /// List all registered tool names
104    pub fn list(&self) -> Vec<String> {
105        let tools = self.tools.read().unwrap();
106        tools.keys().cloned().collect()
107    }
108
109    /// Get the number of registered tools
110    pub fn len(&self) -> usize {
111        let tools = self.tools.read().unwrap();
112        tools.len()
113    }
114
115    /// Check if registry is empty
116    pub fn is_empty(&self) -> bool {
117        self.len() == 0
118    }
119
120    /// Get the tool context
121    pub fn context(&self) -> ToolContext {
122        self.context.read().unwrap().clone()
123    }
124
125    /// Set the search configuration for the tool context
126    pub fn set_search_config(&self, config: crate::config::SearchConfig) {
127        let mut ctx = self.context.write().unwrap();
128        *ctx = ctx.clone().with_search_config(config);
129    }
130
131    /// Set a sandbox executor so that `bash` tool calls use the sandbox even
132    /// when executed without an explicit `ToolContext` (i.e., via `execute()`).
133    pub fn set_sandbox(&self, sandbox: std::sync::Arc<dyn crate::sandbox::BashSandbox>) {
134        let mut ctx = self.context.write().unwrap();
135        *ctx = ctx.clone().with_sandbox(sandbox);
136    }
137
138    /// Execute a tool by name using the registry's default context
139    pub async fn execute(&self, name: &str, args: &serde_json::Value) -> Result<ToolResult> {
140        let ctx = self.context();
141        self.execute_with_context(name, args, &ctx).await
142    }
143
144    /// Execute a tool by name with an external context
145    pub async fn execute_with_context(
146        &self,
147        name: &str,
148        args: &serde_json::Value,
149        ctx: &ToolContext,
150    ) -> Result<ToolResult> {
151        let start = std::time::Instant::now();
152
153        let tool = self.get(name);
154
155        let result = match tool {
156            Some(tool) => {
157                let output = tool.execute(args, ctx).await?;
158                Ok(ToolResult {
159                    name: name.to_string(),
160                    output: output.content,
161                    exit_code: if output.success { 0 } else { 1 },
162                    metadata: output.metadata,
163                    images: output.images,
164                })
165            }
166            None => Ok(ToolResult::error(name, format!("Unknown tool: {}", name))),
167        };
168
169        if let Ok(ref r) = result {
170            crate::telemetry::record_tool_result(r.exit_code, start.elapsed());
171        }
172
173        result
174    }
175
176    /// Execute a tool and return raw output using the registry's default context
177    pub async fn execute_raw(
178        &self,
179        name: &str,
180        args: &serde_json::Value,
181    ) -> Result<Option<ToolOutput>> {
182        let ctx = self.context();
183        self.execute_raw_with_context(name, args, &ctx).await
184    }
185
186    /// Execute a tool and return raw output with an external context
187    pub async fn execute_raw_with_context(
188        &self,
189        name: &str,
190        args: &serde_json::Value,
191        ctx: &ToolContext,
192    ) -> Result<Option<ToolOutput>> {
193        let tool = self.get(name);
194
195        match tool {
196            Some(tool) => {
197                let output = tool.execute(args, ctx).await?;
198                Ok(Some(output))
199            }
200            None => Ok(None),
201        }
202    }
203}
204
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;

    /// Tool that always succeeds with `"mock output"`; its name is configurable.
    struct MockTool {
        name: String,
    }

    #[async_trait]
    impl Tool for MockTool {
        fn name(&self) -> &str {
            &self.name
        }

        fn description(&self) -> &str {
            "A mock tool for testing"
        }

        fn parameters(&self) -> serde_json::Value {
            serde_json::json!({
                "type": "object",
                "additionalProperties": false,
                "properties": {},
                "required": []
            })
        }

        async fn execute(
            &self,
            _args: &serde_json::Value,
            _ctx: &ToolContext,
        ) -> Result<ToolOutput> {
            Ok(ToolOutput::success("mock output"))
        }
    }

    #[test]
    fn test_registry_register_and_get() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));

        let tool = Arc::new(MockTool {
            name: "test".to_string(),
        });
        registry.register(tool);

        assert!(registry.contains("test"));
        assert!(!registry.contains("nonexistent"));

        let retrieved = registry.get("test");
        assert!(retrieved.is_some());
        assert_eq!(retrieved.unwrap().name(), "test");
    }

    #[test]
    fn test_registry_unregister() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));

        let tool = Arc::new(MockTool {
            name: "test".to_string(),
        });
        registry.register(tool);

        assert!(registry.contains("test"));
        assert!(registry.unregister("test"));
        assert!(!registry.contains("test"));
        assert!(!registry.unregister("test")); // Already removed
    }

    #[test]
    fn test_registry_unregister_by_prefix() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        for name in ["mcp__a", "mcp__b", "local"] {
            registry.register(Arc::new(MockTool {
                name: name.to_string(),
            }));
        }

        registry.unregister_by_prefix("mcp__");

        assert_eq!(registry.list(), vec!["local".to_string()]);
    }

    #[test]
    fn test_registry_builtin_cannot_be_shadowed() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        let builtin: Arc<dyn Tool> = Arc::new(MockTool {
            name: "read".to_string(),
        });
        registry.register_builtin(Arc::clone(&builtin));

        // A dynamic tool with the builtin's name must be rejected.
        registry.register(Arc::new(MockTool {
            name: "read".to_string(),
        }));

        assert_eq!(registry.len(), 1);
        let current = registry.get("read").unwrap();
        assert!(Arc::ptr_eq(&builtin, &current));
    }

    #[test]
    fn test_registry_builtin_replaces_dynamic() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        registry.register(Arc::new(MockTool {
            name: "grep".to_string(),
        }));

        let builtin: Arc<dyn Tool> = Arc::new(MockTool {
            name: "grep".to_string(),
        });
        registry.register_builtin(Arc::clone(&builtin));

        assert_eq!(registry.len(), 1);
        assert!(Arc::ptr_eq(&builtin, &registry.get("grep").unwrap()));
    }

    #[test]
    fn test_registry_definitions() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));

        registry.register(Arc::new(MockTool {
            name: "tool1".to_string(),
        }));
        registry.register(Arc::new(MockTool {
            name: "tool2".to_string(),
        }));

        let definitions = registry.definitions();
        assert_eq!(definitions.len(), 2);
    }

    #[tokio::test]
    async fn test_registry_execute() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));

        registry.register(Arc::new(MockTool {
            name: "test".to_string(),
        }));

        let result = registry
            .execute("test", &serde_json::json!({}))
            .await
            .unwrap();
        assert_eq!(result.exit_code, 0);
        assert_eq!(result.output, "mock output");
    }

    #[tokio::test]
    async fn test_registry_execute_unknown() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));

        let result = registry
            .execute("unknown", &serde_json::json!({}))
            .await
            .unwrap();
        assert_eq!(result.exit_code, 1);
        assert!(result.output.contains("Unknown tool"));
    }

    #[tokio::test]
    async fn test_registry_execute_with_context_success() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        let ctx = ToolContext::new(PathBuf::from("/tmp"));

        registry.register(Arc::new(MockTool {
            name: "my_tool".to_string(),
        }));

        let result = registry
            .execute_with_context("my_tool", &serde_json::json!({}), &ctx)
            .await
            .unwrap();
        assert_eq!(result.name, "my_tool");
        assert_eq!(result.exit_code, 0);
        assert_eq!(result.output, "mock output");
    }

    #[tokio::test]
    async fn test_registry_execute_with_context_unknown_tool() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        let ctx = ToolContext::new(PathBuf::from("/tmp"));

        let result = registry
            .execute_with_context("nonexistent", &serde_json::json!({}), &ctx)
            .await
            .unwrap();
        assert_eq!(result.exit_code, 1);
        assert!(result.output.contains("Unknown tool: nonexistent"));
    }

    /// Tool that completes normally but reports an unsuccessful output.
    struct FailingTool;

    #[async_trait]
    impl Tool for FailingTool {
        fn name(&self) -> &str {
            "failing"
        }

        fn description(&self) -> &str {
            "A tool that returns failure"
        }

        fn parameters(&self) -> serde_json::Value {
            serde_json::json!({
                "type": "object",
                "additionalProperties": false,
                "properties": {},
                "required": []
            })
        }

        async fn execute(
            &self,
            _args: &serde_json::Value,
            _ctx: &ToolContext,
        ) -> Result<ToolOutput> {
            Ok(ToolOutput::error("something went wrong"))
        }
    }

    #[tokio::test]
    async fn test_registry_execute_failing_tool() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        registry.register(Arc::new(FailingTool));

        let result = registry
            .execute("failing", &serde_json::json!({}))
            .await
            .unwrap();
        assert_eq!(result.exit_code, 1);
        assert_eq!(result.output, "something went wrong");
    }

    /// Tool whose `execute` itself returns `Err`.
    struct ErroringTool;

    #[async_trait]
    impl Tool for ErroringTool {
        fn name(&self) -> &str {
            "erroring"
        }

        fn description(&self) -> &str {
            "A tool whose execute returns an error"
        }

        fn parameters(&self) -> serde_json::Value {
            serde_json::json!({
                "type": "object",
                "additionalProperties": false,
                "properties": {},
                "required": []
            })
        }

        async fn execute(
            &self,
            _args: &serde_json::Value,
            _ctx: &ToolContext,
        ) -> Result<ToolOutput> {
            Err(anyhow::anyhow!("tool crashed"))
        }
    }

    #[tokio::test]
    async fn test_registry_execute_propagates_tool_error() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        registry.register(Arc::new(ErroringTool));

        let err = registry
            .execute("erroring", &serde_json::json!({}))
            .await
            .err()
            .expect("tool error should propagate from execute");
        assert!(err.to_string().contains("tool crashed"));

        let raw = registry
            .execute_raw("erroring", &serde_json::json!({}))
            .await;
        assert!(raw.is_err());
    }

    #[tokio::test]
    async fn test_registry_execute_raw_success() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        registry.register(Arc::new(MockTool {
            name: "raw_test".to_string(),
        }));

        let output = registry
            .execute_raw("raw_test", &serde_json::json!({}))
            .await
            .unwrap();
        assert!(output.is_some());
        let output = output.unwrap();
        assert!(output.success);
        assert_eq!(output.content, "mock output");
    }

    #[tokio::test]
    async fn test_registry_execute_raw_unknown() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));

        let output = registry
            .execute_raw("missing", &serde_json::json!({}))
            .await
            .unwrap();
        assert!(output.is_none());
    }

    #[test]
    fn test_registry_list() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        registry.register(Arc::new(MockTool {
            name: "alpha".to_string(),
        }));
        registry.register(Arc::new(MockTool {
            name: "beta".to_string(),
        }));

        let names = registry.list();
        assert_eq!(names.len(), 2);
        assert!(names.contains(&"alpha".to_string()));
        assert!(names.contains(&"beta".to_string()));
    }

    #[test]
    fn test_registry_len_and_is_empty() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        assert!(registry.is_empty());
        assert_eq!(registry.len(), 0);

        registry.register(Arc::new(MockTool {
            name: "t".to_string(),
        }));
        assert!(!registry.is_empty());
        assert_eq!(registry.len(), 1);
    }

    #[test]
    fn test_registry_replace_tool() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        registry.register(Arc::new(MockTool {
            name: "dup".to_string(),
        }));
        registry.register(Arc::new(MockTool {
            name: "dup".to_string(),
        }));
        // Should still have only 1 tool (replaced)
        assert_eq!(registry.len(), 1);
    }
}