Skip to main content

a3s_code_core/tools/
registry.rs

1//! Tool Registry
2//!
3//! Central registry for all tools (built-in and dynamic).
4//! Provides thread-safe registration, lookup, and execution.
5
6use super::types::{Tool, ToolContext, ToolOutput};
7use super::ToolResult;
8use crate::llm::ToolDefinition;
9use anyhow::Result;
10use std::collections::HashMap;
11use std::path::PathBuf;
12use std::sync::{Arc, RwLock};
13
14/// Tool registry for managing all available tools
15pub struct ToolRegistry {
16    tools: RwLock<HashMap<String, Arc<dyn Tool>>>,
17    /// Names of builtin tools that cannot be overridden
18    builtins: RwLock<std::collections::HashSet<String>>,
19    context: ToolContext,
20}
21
22impl ToolRegistry {
23    /// Create a new tool registry
24    pub fn new(workspace: PathBuf) -> Self {
25        Self {
26            tools: RwLock::new(HashMap::new()),
27            builtins: RwLock::new(std::collections::HashSet::new()),
28            context: ToolContext::new(workspace),
29        }
30    }
31
32    /// Register a builtin tool (cannot be overridden by dynamic tools)
33    pub fn register_builtin(&self, tool: Arc<dyn Tool>) {
34        let name = tool.name().to_string();
35        let mut tools = self.tools.write().unwrap();
36        let mut builtins = self.builtins.write().unwrap();
37        tracing::debug!("Registering builtin tool: {}", name);
38        tools.insert(name.clone(), tool);
39        builtins.insert(name);
40    }
41
42    /// Register a tool
43    ///
44    /// If a tool with the same name already exists as a builtin, the registration
45    /// is rejected to prevent shadowing of core tools.
46    pub fn register(&self, tool: Arc<dyn Tool>) {
47        let name = tool.name().to_string();
48        let builtins = self.builtins.read().unwrap();
49        if builtins.contains(&name) {
50            tracing::warn!(
51                "Rejected registration of tool '{}': cannot shadow builtin",
52                name
53            );
54            return;
55        }
56        drop(builtins);
57        let mut tools = self.tools.write().unwrap();
58        tracing::debug!("Registering tool: {}", name);
59        tools.insert(name, tool);
60    }
61
62    /// Unregister a tool by name
63    ///
64    /// Returns true if the tool was found and removed.
65    pub fn unregister(&self, name: &str) -> bool {
66        let mut tools = self.tools.write().unwrap();
67        tracing::debug!("Unregistering tool: {}", name);
68        tools.remove(name).is_some()
69    }
70
71    /// Get a tool by name
72    pub fn get(&self, name: &str) -> Option<Arc<dyn Tool>> {
73        let tools = self.tools.read().unwrap();
74        tools.get(name).cloned()
75    }
76
77    /// Check if a tool exists
78    pub fn contains(&self, name: &str) -> bool {
79        let tools = self.tools.read().unwrap();
80        tools.contains_key(name)
81    }
82
83    /// Get all tool definitions for LLM
84    pub fn definitions(&self) -> Vec<ToolDefinition> {
85        let tools = self.tools.read().unwrap();
86        tools
87            .values()
88            .map(|tool| ToolDefinition {
89                name: tool.name().to_string(),
90                description: tool.description().to_string(),
91                parameters: tool.parameters(),
92            })
93            .collect()
94    }
95
96    /// List all registered tool names
97    pub fn list(&self) -> Vec<String> {
98        let tools = self.tools.read().unwrap();
99        tools.keys().cloned().collect()
100    }
101
102    /// Get the number of registered tools
103    pub fn len(&self) -> usize {
104        let tools = self.tools.read().unwrap();
105        tools.len()
106    }
107
108    /// Check if registry is empty
109    pub fn is_empty(&self) -> bool {
110        self.len() == 0
111    }
112
113    /// Get the tool context
114    pub fn context(&self) -> &ToolContext {
115        &self.context
116    }
117
118    /// Execute a tool by name using the registry's default context
119    pub async fn execute(&self, name: &str, args: &serde_json::Value) -> Result<ToolResult> {
120        self.execute_with_context(name, args, &self.context).await
121    }
122
123    /// Execute a tool by name with an external context
124    pub async fn execute_with_context(
125        &self,
126        name: &str,
127        args: &serde_json::Value,
128        ctx: &ToolContext,
129    ) -> Result<ToolResult> {
130        let span = tracing::info_span!(
131            "a3s.tool.execute",
132            "a3s.tool.name" = %name,
133            "a3s.tool.success" = tracing::field::Empty,
134            "a3s.tool.exit_code" = tracing::field::Empty,
135            "a3s.tool.duration_ms" = tracing::field::Empty,
136        );
137        let _guard = span.enter();
138        let start = std::time::Instant::now();
139
140        let tool = self.get(name);
141
142        let result = match tool {
143            Some(tool) => {
144                let output = tool.execute(args, ctx).await?;
145                Ok(ToolResult {
146                    name: name.to_string(),
147                    output: output.content,
148                    exit_code: if output.success { 0 } else { 1 },
149                    metadata: output.metadata,
150                })
151            }
152            None => Ok(ToolResult::error(name, format!("Unknown tool: {}", name))),
153        };
154
155        if let Ok(ref r) = result {
156            crate::telemetry::record_tool_result(r.exit_code, start.elapsed());
157        }
158
159        result
160    }
161
162    /// Execute a tool and return raw output using the registry's default context
163    pub async fn execute_raw(
164        &self,
165        name: &str,
166        args: &serde_json::Value,
167    ) -> Result<Option<ToolOutput>> {
168        self.execute_raw_with_context(name, args, &self.context)
169            .await
170    }
171
172    /// Execute a tool and return raw output with an external context
173    pub async fn execute_raw_with_context(
174        &self,
175        name: &str,
176        args: &serde_json::Value,
177        ctx: &ToolContext,
178    ) -> Result<Option<ToolOutput>> {
179        let tool = self.get(name);
180
181        match tool {
182            Some(tool) => {
183                let output = tool.execute(args, ctx).await?;
184                Ok(Some(output))
185            }
186            None => Ok(None),
187        }
188    }
189}
190
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;

    /// Test tool with a configurable name that always succeeds with "mock output".
    struct MockTool {
        name: String,
    }

    #[async_trait]
    impl Tool for MockTool {
        fn name(&self) -> &str {
            &self.name
        }

        fn description(&self) -> &str {
            "A mock tool for testing"
        }

        fn parameters(&self) -> serde_json::Value {
            serde_json::json!({
                "type": "object",
                "properties": {},
                "required": []
            })
        }

        async fn execute(
            &self,
            _args: &serde_json::Value,
            _ctx: &ToolContext,
        ) -> Result<ToolOutput> {
            Ok(ToolOutput::success("mock output"))
        }
    }

    #[test]
    fn test_registry_register_and_get() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));

        let tool = Arc::new(MockTool {
            name: "test".to_string(),
        });
        registry.register(tool);

        assert!(registry.contains("test"));
        assert!(!registry.contains("nonexistent"));

        let retrieved = registry.get("test");
        assert!(retrieved.is_some());
        assert_eq!(retrieved.unwrap().name(), "test");
    }

    #[test]
    fn test_registry_unregister() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));

        let tool = Arc::new(MockTool {
            name: "test".to_string(),
        });
        registry.register(tool);

        assert!(registry.contains("test"));
        assert!(registry.unregister("test"));
        assert!(!registry.contains("test"));
        assert!(!registry.unregister("test")); // Already removed
    }

    #[test]
    fn test_registry_definitions() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));

        registry.register(Arc::new(MockTool {
            name: "tool1".to_string(),
        }));
        registry.register(Arc::new(MockTool {
            name: "tool2".to_string(),
        }));

        let definitions = registry.definitions();
        assert_eq!(definitions.len(), 2);

        // Order-agnostic check that each registered tool produced a definition.
        let mut names: Vec<&str> = definitions.iter().map(|d| d.name.as_str()).collect();
        names.sort_unstable();
        assert_eq!(names, vec!["tool1", "tool2"]);
        assert!(definitions
            .iter()
            .all(|d| d.description == "A mock tool for testing"));
    }

    #[tokio::test]
    async fn test_registry_execute() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));

        registry.register(Arc::new(MockTool {
            name: "test".to_string(),
        }));

        let result = registry
            .execute("test", &serde_json::json!({}))
            .await
            .unwrap();
        assert_eq!(result.exit_code, 0);
        assert_eq!(result.output, "mock output");
    }

    #[tokio::test]
    async fn test_registry_execute_unknown() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));

        let result = registry
            .execute("unknown", &serde_json::json!({}))
            .await
            .unwrap();
        assert_eq!(result.exit_code, 1);
        assert!(result.output.contains("Unknown tool"));
    }

    #[tokio::test]
    async fn test_registry_execute_with_context_success() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        let ctx = ToolContext::new(PathBuf::from("/tmp"));

        registry.register(Arc::new(MockTool {
            name: "my_tool".to_string(),
        }));

        let result = registry
            .execute_with_context("my_tool", &serde_json::json!({}), &ctx)
            .await
            .unwrap();
        assert_eq!(result.name, "my_tool");
        assert_eq!(result.exit_code, 0);
        assert_eq!(result.output, "mock output");
    }

    #[tokio::test]
    async fn test_registry_execute_with_context_unknown_tool() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        let ctx = ToolContext::new(PathBuf::from("/tmp"));

        let result = registry
            .execute_with_context("nonexistent", &serde_json::json!({}), &ctx)
            .await
            .unwrap();
        assert_eq!(result.exit_code, 1);
        assert!(result.output.contains("Unknown tool: nonexistent"));
    }

    /// Test tool named "failing" that returns an unsuccessful `ToolOutput`.
    struct FailingTool;

    #[async_trait]
    impl Tool for FailingTool {
        fn name(&self) -> &str {
            "failing"
        }

        fn description(&self) -> &str {
            "A tool that returns failure"
        }

        fn parameters(&self) -> serde_json::Value {
            serde_json::json!({"type": "object", "properties": {}, "required": []})
        }

        async fn execute(
            &self,
            _args: &serde_json::Value,
            _ctx: &ToolContext,
        ) -> Result<ToolOutput> {
            Ok(ToolOutput::error("something went wrong"))
        }
    }

    #[tokio::test]
    async fn test_registry_execute_failing_tool() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        registry.register(Arc::new(FailingTool));

        let result = registry
            .execute("failing", &serde_json::json!({}))
            .await
            .unwrap();
        assert_eq!(result.exit_code, 1);
        assert_eq!(result.output, "something went wrong");
    }

    #[tokio::test]
    async fn test_registry_execute_raw_success() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        registry.register(Arc::new(MockTool {
            name: "raw_test".to_string(),
        }));

        let output = registry
            .execute_raw("raw_test", &serde_json::json!({}))
            .await
            .unwrap();
        assert!(output.is_some());
        let output = output.unwrap();
        assert!(output.success);
        assert_eq!(output.content, "mock output");
    }

    #[tokio::test]
    async fn test_registry_execute_raw_unknown() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));

        let output = registry
            .execute_raw("missing", &serde_json::json!({}))
            .await
            .unwrap();
        assert!(output.is_none());
    }

    #[test]
    fn test_registry_list() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        registry.register(Arc::new(MockTool {
            name: "alpha".to_string(),
        }));
        registry.register(Arc::new(MockTool {
            name: "beta".to_string(),
        }));

        let names = registry.list();
        assert_eq!(names.len(), 2);
        assert!(names.contains(&"alpha".to_string()));
        assert!(names.contains(&"beta".to_string()));
    }

    #[test]
    fn test_registry_len_and_is_empty() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        assert!(registry.is_empty());
        assert_eq!(registry.len(), 0);

        registry.register(Arc::new(MockTool {
            name: "t".to_string(),
        }));
        assert!(!registry.is_empty());
        assert_eq!(registry.len(), 1);
    }

    #[test]
    fn test_registry_replace_tool() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        registry.register(Arc::new(MockTool {
            name: "dup".to_string(),
        }));
        registry.register(Arc::new(MockTool {
            name: "dup".to_string(),
        }));
        // Should still have only 1 tool (replaced)
        assert_eq!(registry.len(), 1);
    }

    #[test]
    fn test_registry_replace_tool_keeps_latest() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        // Two distinguishable tools sharing the name "failing".
        registry.register(Arc::new(MockTool {
            name: "failing".to_string(),
        }));
        registry.register(Arc::new(FailingTool));

        assert_eq!(registry.len(), 1);
        let kept = registry.get("failing").expect("tool should be registered");
        assert_eq!(kept.description(), "A tool that returns failure");
    }

    #[test]
    fn test_register_cannot_shadow_builtin() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        registry.register_builtin(Arc::new(FailingTool));

        // A dynamic tool with the builtin's name must be rejected.
        registry.register(Arc::new(MockTool {
            name: "failing".to_string(),
        }));

        assert_eq!(registry.len(), 1);
        let kept = registry.get("failing").expect("builtin must stay registered");
        assert_eq!(kept.description(), "A tool that returns failure");
    }

    #[test]
    fn test_register_builtin_overrides_dynamic() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        registry.register(Arc::new(MockTool {
            name: "failing".to_string(),
        }));
        registry.register_builtin(Arc::new(FailingTool));

        assert_eq!(registry.len(), 1);
        let kept = registry.get("failing").expect("builtin should be registered");
        assert_eq!(kept.description(), "A tool that returns failure");
    }
}