Skip to main content

a3s_code_core/tools/
registry.rs

//! Tool Registry
//!
//! Central registry for all tools (built-in and dynamic).
//! Provides thread-safe registration, lookup, and execution.
use super::artifacts::{ArtifactStore, ArtifactStoreLimits, ToolArtifact};
use super::types::{Tool, ToolContext, ToolOutput};
use super::ToolResult;
use super::{
    merge_tool_output_artifact_metadata, truncate_tool_output_with_artifact, ToolOutputArtifact,
};
use crate::llm::ToolDefinition;
use crate::trace::{InMemoryTraceSink, TraceEvent, TraceSink};
use anyhow::Result;
use std::collections::{HashMap, HashSet};
use std::path::PathBuf;
use std::sync::{Arc, RwLock};
use std::time::{Duration, Instant};
19/// Tool registry for managing all available tools
20pub struct ToolRegistry {
21    tools: RwLock<HashMap<String, Arc<dyn Tool>>>,
22    /// Names of builtin tools that cannot be overridden
23    builtins: RwLock<std::collections::HashSet<String>>,
24    context: RwLock<ToolContext>,
25    artifact_store: ArtifactStore,
26    trace_sink: RwLock<Arc<dyn TraceSink>>,
27}
28
29impl ToolRegistry {
30    /// Create a new tool registry
31    pub fn new(workspace: PathBuf) -> Self {
32        Self::with_artifact_limits(workspace, ArtifactStoreLimits::default())
33    }
34
35    /// Create a new tool registry with custom artifact retention limits.
36    pub fn with_artifact_limits(workspace: PathBuf, artifact_limits: ArtifactStoreLimits) -> Self {
37        Self::with_artifact_limits_and_workspace_services(
38            workspace.clone(),
39            artifact_limits,
40            crate::workspace::WorkspaceServices::local(workspace),
41        )
42    }
43
44    /// Create a new tool registry with custom artifact limits and workspace backend.
45    pub fn with_artifact_limits_and_workspace_services(
46        workspace: PathBuf,
47        artifact_limits: ArtifactStoreLimits,
48        workspace_services: Arc<crate::workspace::WorkspaceServices>,
49    ) -> Self {
50        let context = ToolContext::new(workspace).with_workspace_services(workspace_services);
51        Self {
52            tools: RwLock::new(HashMap::new()),
53            builtins: RwLock::new(std::collections::HashSet::new()),
54            context: RwLock::new(context),
55            artifact_store: ArtifactStore::with_limits(artifact_limits),
56            trace_sink: RwLock::new(Arc::new(InMemoryTraceSink::default())),
57        }
58    }
59
60    /// Register a builtin tool (cannot be overridden by dynamic tools)
61    pub fn register_builtin(&self, tool: Arc<dyn Tool>) {
62        let name = tool.name().to_string();
63        let mut tools = self.tools.write().unwrap();
64        let mut builtins = self.builtins.write().unwrap();
65        tracing::debug!("Registering builtin tool: {}", name);
66        tools.insert(name.clone(), tool);
67        builtins.insert(name);
68    }
69
70    /// Register a tool
71    ///
72    /// If a tool with the same name already exists as a builtin, the registration
73    /// is rejected to prevent shadowing of core tools.
74    pub fn register(&self, tool: Arc<dyn Tool>) {
75        let name = tool.name().to_string();
76        let builtins = self.builtins.read().unwrap();
77        if builtins.contains(&name) {
78            tracing::warn!(
79                "Rejected registration of tool '{}': cannot shadow builtin",
80                name
81            );
82            return;
83        }
84        drop(builtins);
85        let mut tools = self.tools.write().unwrap();
86        tracing::debug!("Registering tool: {}", name);
87        tools.insert(name, tool);
88    }
89
90    /// Unregister a tool by name
91    ///
92    /// Returns true if the tool was found and removed.
93    pub fn unregister(&self, name: &str) -> bool {
94        let mut tools = self.tools.write().unwrap();
95        tracing::debug!("Unregistering tool: {}", name);
96        tools.remove(name).is_some()
97    }
98
99    /// Unregister all tools whose names start with the given prefix.
100    pub fn unregister_by_prefix(&self, prefix: &str) {
101        let mut tools = self.tools.write().unwrap();
102        tools.retain(|name, _| !name.starts_with(prefix));
103        tracing::debug!("Unregistered tools with prefix: {}", prefix);
104    }
105
106    /// Get a tool by name
107    pub fn get(&self, name: &str) -> Option<Arc<dyn Tool>> {
108        let tools = self.tools.read().unwrap();
109        tools.get(name).cloned()
110    }
111
112    /// Check if a tool exists
113    pub fn contains(&self, name: &str) -> bool {
114        let tools = self.tools.read().unwrap();
115        tools.contains_key(name)
116    }
117
118    /// Get all tool definitions for LLM
119    pub fn definitions(&self) -> Vec<ToolDefinition> {
120        let tools = self.tools.read().unwrap();
121        tools
122            .values()
123            .map(|tool| ToolDefinition {
124                name: tool.name().to_string(),
125                description: tool.description().to_string(),
126                parameters: tool.parameters(),
127            })
128            .collect()
129    }
130
131    /// List all registered tool names
132    pub fn list(&self) -> Vec<String> {
133        let tools = self.tools.read().unwrap();
134        tools.keys().cloned().collect()
135    }
136
137    /// Get the number of registered tools
138    pub fn len(&self) -> usize {
139        let tools = self.tools.read().unwrap();
140        tools.len()
141    }
142
143    /// Check if registry is empty
144    pub fn is_empty(&self) -> bool {
145        self.len() == 0
146    }
147
148    /// Get the tool context
149    pub fn context(&self) -> ToolContext {
150        self.context.read().unwrap().clone()
151    }
152
153    /// Return a clone of the registry's artifact store handle.
154    pub fn artifact_store(&self) -> ArtifactStore {
155        self.artifact_store.clone()
156    }
157
158    /// Get a stored tool artifact by URI.
159    pub fn get_artifact(&self, artifact_uri: &str) -> Option<ToolArtifact> {
160        self.artifact_store.get(artifact_uri)
161    }
162
163    /// Replace the trace sink used for compact tool/program execution events.
164    pub fn set_trace_sink(&self, sink: Arc<dyn TraceSink>) {
165        *self.trace_sink.write().unwrap() = sink;
166    }
167
168    /// Return the current trace sink.
169    pub fn trace_sink(&self) -> Arc<dyn TraceSink> {
170        Arc::clone(&self.trace_sink.read().unwrap())
171    }
172
173    /// Set the search configuration for the tool context
174    pub fn set_search_config(&self, config: crate::config::SearchConfig) {
175        let mut ctx = self.context.write().unwrap();
176        *ctx = ctx.clone().with_search_config(config);
177    }
178
179    /// Set a sandbox executor so that `bash` tool calls use the sandbox even
180    /// when executed without an explicit `ToolContext` (i.e., via `execute()`).
181    pub fn set_sandbox(&self, sandbox: std::sync::Arc<dyn crate::sandbox::BashSandbox>) {
182        let mut ctx = self.context.write().unwrap();
183        *ctx = ctx.clone().with_sandbox(sandbox);
184    }
185
186    /// Set environment overrides used by subprocess-backed tools when executed
187    /// without an explicit context.
188    pub fn set_command_env(&self, env: Arc<HashMap<String, String>>) {
189        let mut ctx = self.context.write().unwrap();
190        *ctx = ctx.clone().with_command_env(env);
191    }
192
193    /// Execute a tool by name using the registry's default context
194    pub async fn execute(&self, name: &str, args: &serde_json::Value) -> Result<ToolResult> {
195        let ctx = self.context();
196        self.execute_with_context(name, args, &ctx).await
197    }
198
199    /// Execute a tool by name with an external context
200    pub async fn execute_with_context(
201        &self,
202        name: &str,
203        args: &serde_json::Value,
204        ctx: &ToolContext,
205    ) -> Result<ToolResult> {
206        let start = std::time::Instant::now();
207
208        let tool = self.get(name);
209
210        let result = match tool {
211            Some(tool) => {
212                let mut output = tool.execute(args, ctx).await?;
213                let original_content = output.content.clone();
214                let truncated = truncate_tool_output_with_artifact(name, &output.content);
215                output.content = truncated.content;
216                if let Some(artifact) = truncated.artifact {
217                    self.store_tool_artifact(name, &original_content, &artifact);
218                    output.metadata = Some(merge_tool_output_artifact_metadata(
219                        output.metadata,
220                        &artifact,
221                    ));
222                }
223                Ok(ToolResult {
224                    name: name.to_string(),
225                    output: output.content,
226                    exit_code: if output.success { 0 } else { 1 },
227                    metadata: output.metadata,
228                    images: output.images,
229                })
230            }
231            None => Ok(ToolResult::error(name, format!("Unknown tool: {}", name))),
232        };
233
234        if let Ok(ref r) = result {
235            crate::telemetry::record_tool_result(r.exit_code, start.elapsed());
236            self.record_trace_event(name, r, start.elapsed());
237        }
238
239        result
240    }
241
242    /// Execute a tool and return raw output using the registry's default context
243    pub async fn execute_raw(
244        &self,
245        name: &str,
246        args: &serde_json::Value,
247    ) -> Result<Option<ToolOutput>> {
248        let ctx = self.context();
249        self.execute_raw_with_context(name, args, &ctx).await
250    }
251
252    /// Execute a tool and return raw output with an external context
253    pub async fn execute_raw_with_context(
254        &self,
255        name: &str,
256        args: &serde_json::Value,
257        ctx: &ToolContext,
258    ) -> Result<Option<ToolOutput>> {
259        let tool = self.get(name);
260
261        match tool {
262            Some(tool) => {
263                let mut output = tool.execute(args, ctx).await?;
264                let original_content = output.content.clone();
265                let truncated = truncate_tool_output_with_artifact(name, &output.content);
266                output.content = truncated.content;
267                if let Some(artifact) = truncated.artifact {
268                    self.store_tool_artifact(name, &original_content, &artifact);
269                    output.metadata = Some(merge_tool_output_artifact_metadata(
270                        output.metadata,
271                        &artifact,
272                    ));
273                }
274                Ok(Some(output))
275            }
276            None => Ok(None),
277        }
278    }
279
280    fn store_tool_artifact(&self, tool_name: &str, content: &str, artifact: &ToolOutputArtifact) {
281        self.artifact_store.put(ToolArtifact {
282            artifact_id: artifact.artifact_id.clone(),
283            artifact_uri: artifact.artifact_uri.clone(),
284            tool_name: tool_name.to_string(),
285            content: content.to_string(),
286            original_bytes: artifact.original_bytes,
287            shown_bytes: artifact.shown_bytes,
288        });
289    }
290
291    fn record_trace_event(&self, name: &str, result: &ToolResult, duration: std::time::Duration) {
292        let sink = self.trace_sink();
293        sink.record(TraceEvent::tool_execution(
294            name,
295            result.exit_code == 0,
296            result.exit_code,
297            duration,
298            result.output.len(),
299            result.metadata.as_ref(),
300        ));
301
302        if name == "program" {
303            sink.record(TraceEvent::program_execution(
304                name,
305                result.exit_code == 0,
306                result.exit_code,
307                duration,
308                result.output.len(),
309                result.metadata.as_ref(),
310            ));
311        }
312    }
313}
314
#[cfg(test)]
mod tests {
    use super::*;
    use crate::trace::{InMemoryTraceSink, TraceEventKind};
    use async_trait::async_trait;

    /// Tool with a configurable name that always succeeds with "mock output".
    struct MockTool {
        name: String,
    }

    #[async_trait]
    impl Tool for MockTool {
        fn name(&self) -> &str {
            &self.name
        }

        fn description(&self) -> &str {
            "A mock tool for testing"
        }

        fn parameters(&self) -> serde_json::Value {
            serde_json::json!({
                "type": "object",
                "additionalProperties": false,
                "properties": {},
                "required": []
            })
        }

        async fn execute(
            &self,
            _args: &serde_json::Value,
            _ctx: &ToolContext,
        ) -> Result<ToolOutput> {
            Ok(ToolOutput::success("mock output"))
        }
    }

    #[test]
    fn test_registry_register_and_get() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));

        let tool = Arc::new(MockTool {
            name: "test".to_string(),
        });
        registry.register(tool);

        assert!(registry.contains("test"));
        assert!(!registry.contains("nonexistent"));

        let retrieved = registry.get("test");
        assert!(retrieved.is_some());
        assert_eq!(retrieved.unwrap().name(), "test");
    }

    #[test]
    fn test_registry_unregister() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));

        let tool = Arc::new(MockTool {
            name: "test".to_string(),
        });
        registry.register(tool);

        assert!(registry.contains("test"));
        assert!(registry.unregister("test"));
        assert!(!registry.contains("test"));
        assert!(!registry.unregister("test")); // Already removed
    }

    #[test]
    fn test_registry_unregister_by_prefix() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        for name in ["mcp__server__a", "mcp__server__b", "local"] {
            registry.register(Arc::new(MockTool {
                name: name.to_string(),
            }));
        }

        registry.unregister_by_prefix("mcp__server__");

        assert_eq!(registry.list(), vec!["local".to_string()]);
    }

    #[test]
    fn test_registry_definitions() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));

        registry.register(Arc::new(MockTool {
            name: "tool1".to_string(),
        }));
        registry.register(Arc::new(MockTool {
            name: "tool2".to_string(),
        }));

        let definitions = registry.definitions();
        assert_eq!(definitions.len(), 2);
    }

    #[tokio::test]
    async fn test_registry_execute() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));

        registry.register(Arc::new(MockTool {
            name: "test".to_string(),
        }));

        let result = registry
            .execute("test", &serde_json::json!({}))
            .await
            .unwrap();
        assert_eq!(result.exit_code, 0);
        assert_eq!(result.output, "mock output");
    }

    #[tokio::test]
    async fn test_registry_execute_unknown() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));

        let result = registry
            .execute("unknown", &serde_json::json!({}))
            .await
            .unwrap();
        assert_eq!(result.exit_code, 1);
        assert!(result.output.contains("Unknown tool"));
    }

    #[tokio::test]
    async fn test_registry_execute_with_context_success() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        let ctx = ToolContext::new(PathBuf::from("/tmp"));
        let trace_sink = InMemoryTraceSink::default();
        registry.set_trace_sink(Arc::new(trace_sink.clone()));

        registry.register(Arc::new(MockTool {
            name: "my_tool".to_string(),
        }));

        let result = registry
            .execute_with_context("my_tool", &serde_json::json!({}), &ctx)
            .await
            .unwrap();
        assert_eq!(result.name, "my_tool");
        assert_eq!(result.exit_code, 0);
        assert_eq!(result.output, "mock output");

        let events = trace_sink.events();
        assert_eq!(events.len(), 1);
        assert_eq!(events[0].kind, TraceEventKind::ToolExecution);
        assert_eq!(events[0].name, "my_tool");
        assert!(events[0].success);
        assert_eq!(events[0].output_bytes, "mock output".len());
    }

    #[tokio::test]
    async fn test_registry_execute_with_context_unknown_tool() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        let ctx = ToolContext::new(PathBuf::from("/tmp"));

        let result = registry
            .execute_with_context("nonexistent", &serde_json::json!({}), &ctx)
            .await
            .unwrap();
        assert_eq!(result.exit_code, 1);
        assert!(result.output.contains("Unknown tool: nonexistent"));
    }

    /// Tool named "failing" that always reports a (non-`Err`) failure.
    struct FailingTool;

    #[async_trait]
    impl Tool for FailingTool {
        fn name(&self) -> &str {
            "failing"
        }

        fn description(&self) -> &str {
            "A tool that returns failure"
        }

        fn parameters(&self) -> serde_json::Value {
            serde_json::json!({
                "type": "object",
                "additionalProperties": false,
                "properties": {},
                "required": []
            })
        }

        async fn execute(
            &self,
            _args: &serde_json::Value,
            _ctx: &ToolContext,
        ) -> Result<ToolOutput> {
            Ok(ToolOutput::error("something went wrong"))
        }
    }

    #[tokio::test]
    async fn test_registry_execute_failing_tool() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        registry.register(Arc::new(FailingTool));

        let result = registry
            .execute("failing", &serde_json::json!({}))
            .await
            .unwrap();
        assert_eq!(result.exit_code, 1);
        assert_eq!(result.output, "something went wrong");
    }

    #[tokio::test]
    async fn test_registry_builtin_cannot_be_shadowed() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        registry.register_builtin(Arc::new(FailingTool));

        // A dynamic tool reusing the builtin's name must be rejected.
        registry.register(Arc::new(MockTool {
            name: "failing".to_string(),
        }));
        assert_eq!(registry.len(), 1);

        let result = registry
            .execute("failing", &serde_json::json!({}))
            .await
            .unwrap();
        assert_eq!(result.exit_code, 1);
        assert_eq!(result.output, "something went wrong");
    }

    /// Tool whose output is one byte over the registry's truncation limit.
    struct LargeOutputTool;

    #[async_trait]
    impl Tool for LargeOutputTool {
        fn name(&self) -> &str {
            "large_output"
        }

        fn description(&self) -> &str {
            "A tool that returns more than the maximum output size"
        }

        fn parameters(&self) -> serde_json::Value {
            serde_json::json!({
                "type": "object",
                "additionalProperties": false,
                "properties": {},
                "required": []
            })
        }

        async fn execute(
            &self,
            _args: &serde_json::Value,
            _ctx: &ToolContext,
        ) -> Result<ToolOutput> {
            Ok(ToolOutput::success(
                "x".repeat(super::super::MAX_OUTPUT_SIZE + 1),
            ))
        }
    }

    #[tokio::test]
    async fn test_registry_truncates_large_tool_output() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        let trace_sink = InMemoryTraceSink::default();
        registry.set_trace_sink(Arc::new(trace_sink.clone()));
        registry.register(Arc::new(LargeOutputTool));

        let result = registry
            .execute("large_output", &serde_json::json!({}))
            .await
            .unwrap();

        assert_eq!(result.exit_code, 0);
        assert!(result.output.contains("[tool output truncated:"));
        assert!(result
            .output
            .contains("Full output artifact: a3s://tool-output/large_output/"));
        assert!(result.output.len() < super::super::MAX_OUTPUT_SIZE + 512);
        let metadata = result.metadata.expect("artifact metadata");
        assert_eq!(
            metadata["artifact"]["original_bytes"],
            serde_json::json!(super::super::MAX_OUTPUT_SIZE + 1)
        );
        assert_eq!(
            metadata["artifact"]["shown_bytes"],
            serde_json::json!(super::super::MAX_OUTPUT_SIZE)
        );
        assert!(metadata["artifact"]["artifact_id"]
            .as_str()
            .unwrap()
            .starts_with("tool-output:large_output:"));
        assert!(metadata["artifact"]["artifact_uri"]
            .as_str()
            .unwrap()
            .starts_with("a3s://tool-output/large_output/"));

        let artifact_uri = metadata["artifact"]["artifact_uri"].as_str().unwrap();
        let artifact = registry
            .get_artifact(artifact_uri)
            .expect("full output artifact");
        assert_eq!(artifact.tool_name, "large_output");
        assert_eq!(artifact.original_bytes, super::super::MAX_OUTPUT_SIZE + 1);
        assert_eq!(artifact.shown_bytes, super::super::MAX_OUTPUT_SIZE);
        assert_eq!(
            artifact.content,
            "x".repeat(super::super::MAX_OUTPUT_SIZE + 1)
        );

        let events = trace_sink.events();
        assert_eq!(events.len(), 1);
        assert_eq!(events[0].artifact_uris, vec![artifact_uri]);
    }

    #[tokio::test]
    async fn test_registry_execute_raw_success() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        registry.register(Arc::new(MockTool {
            name: "raw_test".to_string(),
        }));

        let output = registry
            .execute_raw("raw_test", &serde_json::json!({}))
            .await
            .unwrap();
        assert!(output.is_some());
        let output = output.unwrap();
        assert!(output.success);
        assert_eq!(output.content, "mock output");
    }

    #[tokio::test]
    async fn test_registry_execute_raw_stores_truncated_artifact() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        registry.register(Arc::new(LargeOutputTool));

        let output = registry
            .execute_raw("large_output", &serde_json::json!({}))
            .await
            .unwrap()
            .expect("raw output");

        assert!(output.content.contains("[tool output truncated:"));
        let metadata = output.metadata.expect("artifact metadata");
        let artifact_uri = metadata["artifact"]["artifact_uri"].as_str().unwrap();
        let artifact = registry
            .get_artifact(artifact_uri)
            .expect("full output artifact");
        assert_eq!(artifact.tool_name, "large_output");
        assert_eq!(artifact.content.len(), super::super::MAX_OUTPUT_SIZE + 1);
    }

    #[tokio::test]
    async fn test_registry_execute_raw_unknown() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));

        let output = registry
            .execute_raw("missing", &serde_json::json!({}))
            .await
            .unwrap();
        assert!(output.is_none());
    }

    #[test]
    fn test_registry_list() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        registry.register(Arc::new(MockTool {
            name: "alpha".to_string(),
        }));
        registry.register(Arc::new(MockTool {
            name: "beta".to_string(),
        }));

        let names = registry.list();
        assert_eq!(names.len(), 2);
        assert!(names.contains(&"alpha".to_string()));
        assert!(names.contains(&"beta".to_string()));
    }

    #[test]
    fn test_registry_len_and_is_empty() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        assert!(registry.is_empty());
        assert_eq!(registry.len(), 0);

        registry.register(Arc::new(MockTool {
            name: "t".to_string(),
        }));
        assert!(!registry.is_empty());
        assert_eq!(registry.len(), 1);
    }

    #[test]
    fn test_registry_replace_tool() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        registry.register(Arc::new(MockTool {
            name: "dup".to_string(),
        }));
        registry.register(Arc::new(MockTool {
            name: "dup".to_string(),
        }));
        // Should still have only 1 tool (replaced)
        assert_eq!(registry.len(), 1);
    }
}