Skip to main content

a3s_code_core/tools/
registry.rs

1//! Tool Registry
2//!
3//! Central registry for all tools (built-in and dynamic).
4//! Provides thread-safe registration, lookup, and execution.
5
6use super::types::{Tool, ToolContext, ToolOutput};
7use super::ToolResult;
8use crate::llm::ToolDefinition;
9use anyhow::Result;
10use std::collections::HashMap;
11use std::path::PathBuf;
12use std::sync::{Arc, RwLock};
13
14/// Tool registry for managing all available tools
15pub struct ToolRegistry {
16    tools: RwLock<HashMap<String, Arc<dyn Tool>>>,
17    /// Names of builtin tools that cannot be overridden
18    builtins: RwLock<std::collections::HashSet<String>>,
19    context: RwLock<ToolContext>,
20}
21
22impl ToolRegistry {
23    /// Create a new tool registry
24    pub fn new(workspace: PathBuf) -> Self {
25        Self {
26            tools: RwLock::new(HashMap::new()),
27            builtins: RwLock::new(std::collections::HashSet::new()),
28            context: RwLock::new(ToolContext::new(workspace)),
29        }
30    }
31
32    /// Register a builtin tool (cannot be overridden by dynamic tools)
33    pub fn register_builtin(&self, tool: Arc<dyn Tool>) {
34        let name = tool.name().to_string();
35        let mut tools = self.tools.write().unwrap();
36        let mut builtins = self.builtins.write().unwrap();
37        tracing::debug!("Registering builtin tool: {}", name);
38        tools.insert(name.clone(), tool);
39        builtins.insert(name);
40    }
41
42    /// Register a tool
43    ///
44    /// If a tool with the same name already exists as a builtin, the registration
45    /// is rejected to prevent shadowing of core tools.
46    pub fn register(&self, tool: Arc<dyn Tool>) {
47        let name = tool.name().to_string();
48        let builtins = self.builtins.read().unwrap();
49        if builtins.contains(&name) {
50            tracing::warn!(
51                "Rejected registration of tool '{}': cannot shadow builtin",
52                name
53            );
54            return;
55        }
56        drop(builtins);
57        let mut tools = self.tools.write().unwrap();
58        tracing::debug!("Registering tool: {}", name);
59        tools.insert(name, tool);
60    }
61
62    /// Unregister a tool by name
63    ///
64    /// Returns true if the tool was found and removed.
65    pub fn unregister(&self, name: &str) -> bool {
66        let mut tools = self.tools.write().unwrap();
67        tracing::debug!("Unregistering tool: {}", name);
68        tools.remove(name).is_some()
69    }
70
71    /// Unregister all tools whose names start with the given prefix.
72    pub fn unregister_by_prefix(&self, prefix: &str) {
73        let mut tools = self.tools.write().unwrap();
74        tools.retain(|name, _| !name.starts_with(prefix));
75        tracing::debug!("Unregistered tools with prefix: {}", prefix);
76    }
77
78    /// Get a tool by name
79    pub fn get(&self, name: &str) -> Option<Arc<dyn Tool>> {
80        let tools = self.tools.read().unwrap();
81        tools.get(name).cloned()
82    }
83
84    /// Check if a tool exists
85    pub fn contains(&self, name: &str) -> bool {
86        let tools = self.tools.read().unwrap();
87        tools.contains_key(name)
88    }
89
90    /// Get all tool definitions for LLM
91    pub fn definitions(&self) -> Vec<ToolDefinition> {
92        let tools = self.tools.read().unwrap();
93        tools
94            .values()
95            .map(|tool| ToolDefinition {
96                name: tool.name().to_string(),
97                description: tool.description().to_string(),
98                parameters: tool.parameters(),
99            })
100            .collect()
101    }
102
103    /// List all registered tool names
104    pub fn list(&self) -> Vec<String> {
105        let tools = self.tools.read().unwrap();
106        tools.keys().cloned().collect()
107    }
108
109    /// Get the number of registered tools
110    pub fn len(&self) -> usize {
111        let tools = self.tools.read().unwrap();
112        tools.len()
113    }
114
115    /// Check if registry is empty
116    pub fn is_empty(&self) -> bool {
117        self.len() == 0
118    }
119
120    /// Get the tool context
121    pub fn context(&self) -> ToolContext {
122        self.context.read().unwrap().clone()
123    }
124
125    /// Set the search configuration for the tool context
126    pub fn set_search_config(&self, config: crate::config::SearchConfig) {
127        let mut ctx = self.context.write().unwrap();
128        *ctx = ctx.clone().with_search_config(config);
129    }
130
131    /// Set the built-in `agentic_search` configuration for the default tool context.
132    pub fn set_agentic_search_config(&self, config: crate::config::AgenticSearchConfig) {
133        let mut ctx = self.context.write().unwrap();
134        *ctx = ctx.clone().with_agentic_search_config(config);
135    }
136
137    /// Set the built-in `agentic_parse` configuration for the default tool context.
138    pub fn set_agentic_parse_config(&self, config: crate::config::AgenticParseConfig) {
139        let mut ctx = self.context.write().unwrap();
140        *ctx = ctx.clone().with_agentic_parse_config(config);
141    }
142
143    /// Set the built-in document context extraction configuration for the default tool context.
144    pub fn set_document_parser_config(&self, config: crate::config::DocumentParserConfig) {
145        let mut ctx = self.context.write().unwrap();
146        *ctx = ctx.clone().with_document_parser_config(config);
147    }
148
149    /// Set the document parser registry for the default tool context.
150    pub fn set_document_parsers(
151        &self,
152        registry: std::sync::Arc<crate::document_parser::DocumentParserRegistry>,
153    ) {
154        let mut ctx = self.context.write().unwrap();
155        *ctx = ctx.clone().with_document_parsers(registry);
156    }
157
158    /// Set the internal document pipeline registry for the default tool context.
159    pub(crate) fn set_document_pipeline(
160        &self,
161        registry: std::sync::Arc<crate::document_pipeline::DocumentPipelineRegistry>,
162    ) {
163        let mut ctx = self.context.write().unwrap();
164        *ctx = ctx.clone().with_document_pipeline(registry);
165    }
166
167    /// Set a sandbox executor so that `bash` tool calls use the sandbox even
168    /// when executed without an explicit `ToolContext` (i.e., via `execute()`).
169    pub fn set_sandbox(&self, sandbox: std::sync::Arc<dyn crate::sandbox::BashSandbox>) {
170        let mut ctx = self.context.write().unwrap();
171        *ctx = ctx.clone().with_sandbox(sandbox);
172    }
173
174    /// Execute a tool by name using the registry's default context
175    pub async fn execute(&self, name: &str, args: &serde_json::Value) -> Result<ToolResult> {
176        let ctx = self.context();
177        self.execute_with_context(name, args, &ctx).await
178    }
179
180    /// Execute a tool by name with an external context
181    pub async fn execute_with_context(
182        &self,
183        name: &str,
184        args: &serde_json::Value,
185        ctx: &ToolContext,
186    ) -> Result<ToolResult> {
187        let start = std::time::Instant::now();
188
189        let tool = self.get(name);
190
191        let result = match tool {
192            Some(tool) => {
193                let output = tool.execute(args, ctx).await?;
194                Ok(ToolResult {
195                    name: name.to_string(),
196                    output: output.content,
197                    exit_code: if output.success { 0 } else { 1 },
198                    metadata: output.metadata,
199                    images: output.images,
200                })
201            }
202            None => Ok(ToolResult::error(name, format!("Unknown tool: {}", name))),
203        };
204
205        if let Ok(ref r) = result {
206            crate::telemetry::record_tool_result(r.exit_code, start.elapsed());
207        }
208
209        result
210    }
211
212    /// Execute a tool and return raw output using the registry's default context
213    pub async fn execute_raw(
214        &self,
215        name: &str,
216        args: &serde_json::Value,
217    ) -> Result<Option<ToolOutput>> {
218        let ctx = self.context();
219        self.execute_raw_with_context(name, args, &ctx).await
220    }
221
222    /// Execute a tool and return raw output with an external context
223    pub async fn execute_raw_with_context(
224        &self,
225        name: &str,
226        args: &serde_json::Value,
227        ctx: &ToolContext,
228    ) -> Result<Option<ToolOutput>> {
229        let tool = self.get(name);
230
231        match tool {
232            Some(tool) => {
233                let output = tool.execute(args, ctx).await?;
234                Ok(Some(output))
235            }
236            None => Ok(None),
237        }
238    }
239}
240
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;

    /// Test tool with a configurable name that always succeeds with "mock output".
    struct MockTool {
        name: String,
    }

    #[async_trait]
    impl Tool for MockTool {
        fn name(&self) -> &str {
            &self.name
        }

        fn description(&self) -> &str {
            "A mock tool for testing"
        }

        fn parameters(&self) -> serde_json::Value {
            serde_json::json!({
                "type": "object",
                "additionalProperties": false,
                "properties": {},
                "required": []
            })
        }

        async fn execute(
            &self,
            _args: &serde_json::Value,
            _ctx: &ToolContext,
        ) -> Result<ToolOutput> {
            Ok(ToolOutput::success("mock output"))
        }
    }

    /// Test tool named "failing" that always reports a tool-level failure.
    struct FailingTool;

    #[async_trait]
    impl Tool for FailingTool {
        fn name(&self) -> &str {
            "failing"
        }

        fn description(&self) -> &str {
            "A tool that returns failure"
        }

        fn parameters(&self) -> serde_json::Value {
            serde_json::json!({
                "type": "object",
                "additionalProperties": false,
                "properties": {},
                "required": []
            })
        }

        async fn execute(
            &self,
            _args: &serde_json::Value,
            _ctx: &ToolContext,
        ) -> Result<ToolOutput> {
            Ok(ToolOutput::error("something went wrong"))
        }
    }

    #[test]
    fn test_registry_register_and_get() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));

        let tool = Arc::new(MockTool {
            name: "test".to_string(),
        });
        registry.register(tool);

        assert!(registry.contains("test"));
        assert!(!registry.contains("nonexistent"));

        let retrieved = registry.get("test");
        assert!(retrieved.is_some());
        assert_eq!(retrieved.unwrap().name(), "test");
    }

    #[test]
    fn test_registry_unregister() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));

        let tool = Arc::new(MockTool {
            name: "test".to_string(),
        });
        registry.register(tool);

        assert!(registry.contains("test"));
        assert!(registry.unregister("test"));
        assert!(!registry.contains("test"));
        assert!(!registry.unregister("test")); // Already removed
    }

    #[test]
    fn test_registry_unregister_by_prefix() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        for name in ["mcp__a", "mcp__b", "local"] {
            registry.register(Arc::new(MockTool {
                name: name.to_string(),
            }));
        }

        registry.unregister_by_prefix("mcp__");

        assert_eq!(registry.list(), vec!["local".to_string()]);
    }

    #[test]
    fn test_registry_definitions() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));

        registry.register(Arc::new(MockTool {
            name: "tool1".to_string(),
        }));
        registry.register(Arc::new(MockTool {
            name: "tool2".to_string(),
        }));

        let definitions = registry.definitions();
        assert_eq!(definitions.len(), 2);
    }

    #[test]
    fn test_registry_set_document_parser_config() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        registry.set_document_parser_config(crate::config::DocumentParserConfig {
            enabled: true,
            max_file_size_mb: 48,
            ocr: None,
            ..Default::default()
        });

        assert_eq!(
            registry
                .context()
                .document_parser_config
                .as_ref()
                .unwrap()
                .max_file_size_mb,
            48
        );
    }

    #[tokio::test]
    async fn test_registry_execute() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));

        registry.register(Arc::new(MockTool {
            name: "test".to_string(),
        }));

        let result = registry
            .execute("test", &serde_json::json!({}))
            .await
            .unwrap();
        assert_eq!(result.exit_code, 0);
        assert_eq!(result.output, "mock output");
    }

    #[tokio::test]
    async fn test_registry_execute_unknown() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));

        let result = registry
            .execute("unknown", &serde_json::json!({}))
            .await
            .unwrap();
        assert_eq!(result.exit_code, 1);
        assert!(result.output.contains("Unknown tool"));
    }

    #[tokio::test]
    async fn test_registry_execute_with_context_success() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        let ctx = ToolContext::new(PathBuf::from("/tmp"));

        registry.register(Arc::new(MockTool {
            name: "my_tool".to_string(),
        }));

        let result = registry
            .execute_with_context("my_tool", &serde_json::json!({}), &ctx)
            .await
            .unwrap();
        assert_eq!(result.name, "my_tool");
        assert_eq!(result.exit_code, 0);
        assert_eq!(result.output, "mock output");
    }

    #[tokio::test]
    async fn test_registry_execute_with_context_unknown_tool() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        let ctx = ToolContext::new(PathBuf::from("/tmp"));

        let result = registry
            .execute_with_context("nonexistent", &serde_json::json!({}), &ctx)
            .await
            .unwrap();
        assert_eq!(result.exit_code, 1);
        assert!(result.output.contains("Unknown tool: nonexistent"));
    }

    #[tokio::test]
    async fn test_registry_execute_failing_tool() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        registry.register(Arc::new(FailingTool));

        let result = registry
            .execute("failing", &serde_json::json!({}))
            .await
            .unwrap();
        assert_eq!(result.exit_code, 1);
        assert_eq!(result.output, "something went wrong");
    }

    #[tokio::test]
    async fn test_register_cannot_shadow_builtin() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        registry.register_builtin(Arc::new(FailingTool));
        // A dynamic tool reusing the builtin's name must be rejected.
        registry.register(Arc::new(MockTool {
            name: "failing".to_string(),
        }));

        assert_eq!(registry.len(), 1);
        let result = registry
            .execute("failing", &serde_json::json!({}))
            .await
            .unwrap();
        assert_eq!(result.exit_code, 1);
        assert_eq!(result.output, "something went wrong");
    }

    #[tokio::test]
    async fn test_register_builtin_replaces_dynamic_tool() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        registry.register(Arc::new(MockTool {
            name: "failing".to_string(),
        }));
        registry.register_builtin(Arc::new(FailingTool));

        assert_eq!(registry.len(), 1);
        let result = registry
            .execute("failing", &serde_json::json!({}))
            .await
            .unwrap();
        assert_eq!(result.exit_code, 1);
        assert_eq!(result.output, "something went wrong");
    }

    #[tokio::test]
    async fn test_registry_execute_raw_success() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        registry.register(Arc::new(MockTool {
            name: "raw_test".to_string(),
        }));

        let output = registry
            .execute_raw("raw_test", &serde_json::json!({}))
            .await
            .unwrap();
        assert!(output.is_some());
        let output = output.unwrap();
        assert!(output.success);
        assert_eq!(output.content, "mock output");
    }

    #[tokio::test]
    async fn test_registry_execute_raw_unknown() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));

        let output = registry
            .execute_raw("missing", &serde_json::json!({}))
            .await
            .unwrap();
        assert!(output.is_none());
    }

    #[test]
    fn test_registry_list() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        registry.register(Arc::new(MockTool {
            name: "alpha".to_string(),
        }));
        registry.register(Arc::new(MockTool {
            name: "beta".to_string(),
        }));

        let names = registry.list();
        assert_eq!(names.len(), 2);
        assert!(names.contains(&"alpha".to_string()));
        assert!(names.contains(&"beta".to_string()));
    }

    #[test]
    fn test_registry_len_and_is_empty() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        assert!(registry.is_empty());
        assert_eq!(registry.len(), 0);

        registry.register(Arc::new(MockTool {
            name: "t".to_string(),
        }));
        assert!(!registry.is_empty());
        assert_eq!(registry.len(), 1);
    }

    #[test]
    fn test_registry_replace_tool() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        registry.register(Arc::new(MockTool {
            name: "dup".to_string(),
        }));
        registry.register(Arc::new(MockTool {
            name: "dup".to_string(),
        }));
        // Should still have only 1 tool (replaced)
        assert_eq!(registry.len(), 1);
    }
}