Skip to main content

a3s_code_core/tools/
registry.rs

//! Tool Registry
//!
//! Central registry for all tools (built-in and dynamic).
//! Provides thread-safe registration, lookup, and execution.
5
6use super::artifacts::{ArtifactStore, ArtifactStoreLimits, ToolArtifact};
7use super::types::{Tool, ToolContext, ToolOutput};
8use super::ToolResult;
9use super::{
10    merge_tool_output_artifact_metadata, truncate_tool_output_with_artifact, ToolOutputArtifact,
11};
12use crate::llm::ToolDefinition;
13use crate::trace::{InMemoryTraceSink, TraceEvent, TraceSink};
14use anyhow::Result;
15use std::collections::HashMap;
16use std::path::PathBuf;
17use std::sync::{Arc, RwLock};
18
19/// Tool registry for managing all available tools
20pub struct ToolRegistry {
21    tools: RwLock<HashMap<String, Arc<dyn Tool>>>,
22    /// Names of builtin tools that cannot be overridden
23    builtins: RwLock<std::collections::HashSet<String>>,
24    context: RwLock<ToolContext>,
25    artifact_store: ArtifactStore,
26    trace_sink: RwLock<Arc<dyn TraceSink>>,
27}
28
29impl ToolRegistry {
30    /// Create a new tool registry
31    pub fn new(workspace: PathBuf) -> Self {
32        Self::with_artifact_limits(workspace, ArtifactStoreLimits::default())
33    }
34
35    /// Create a new tool registry with custom artifact retention limits.
36    pub fn with_artifact_limits(workspace: PathBuf, artifact_limits: ArtifactStoreLimits) -> Self {
37        Self::with_artifact_limits_and_workspace_services(
38            workspace.clone(),
39            artifact_limits,
40            crate::workspace::WorkspaceServices::local(workspace),
41        )
42    }
43
44    /// Create a new tool registry with custom artifact limits and workspace backend.
45    pub fn with_artifact_limits_and_workspace_services(
46        workspace: PathBuf,
47        artifact_limits: ArtifactStoreLimits,
48        workspace_services: Arc<crate::workspace::WorkspaceServices>,
49    ) -> Self {
50        let context = ToolContext::new(workspace).with_workspace_services(workspace_services);
51        Self {
52            tools: RwLock::new(HashMap::new()),
53            builtins: RwLock::new(std::collections::HashSet::new()),
54            context: RwLock::new(context),
55            artifact_store: ArtifactStore::with_limits(artifact_limits),
56            trace_sink: RwLock::new(Arc::new(InMemoryTraceSink::default())),
57        }
58    }
59
60    /// Register a builtin tool (cannot be overridden by dynamic tools)
61    pub fn register_builtin(&self, tool: Arc<dyn Tool>) {
62        let name = tool.name().to_string();
63        let mut tools = self.tools.write().unwrap();
64        let mut builtins = self.builtins.write().unwrap();
65        tracing::debug!("Registering builtin tool: {}", name);
66        tools.insert(name.clone(), tool);
67        builtins.insert(name);
68    }
69
70    /// Register a tool
71    ///
72    /// If a tool with the same name already exists as a builtin, the registration
73    /// is rejected to prevent shadowing of core tools.
74    pub fn register(&self, tool: Arc<dyn Tool>) {
75        let name = tool.name().to_string();
76        let builtins = self.builtins.read().unwrap();
77        if builtins.contains(&name) {
78            tracing::warn!(
79                "Rejected registration of tool '{}': cannot shadow builtin",
80                name
81            );
82            return;
83        }
84        drop(builtins);
85        let mut tools = self.tools.write().unwrap();
86        tracing::debug!("Registering tool: {}", name);
87        tools.insert(name, tool);
88    }
89
90    /// Unregister a tool by name
91    ///
92    /// Returns true if the tool was found and removed.
93    pub fn unregister(&self, name: &str) -> bool {
94        let mut tools = self.tools.write().unwrap();
95        tracing::debug!("Unregistering tool: {}", name);
96        tools.remove(name).is_some()
97    }
98
99    /// Unregister all tools whose names start with the given prefix.
100    pub fn unregister_by_prefix(&self, prefix: &str) {
101        let mut tools = self.tools.write().unwrap();
102        tools.retain(|name, _| !name.starts_with(prefix));
103        tracing::debug!("Unregistered tools with prefix: {}", prefix);
104    }
105
106    /// Get a tool by name
107    pub fn get(&self, name: &str) -> Option<Arc<dyn Tool>> {
108        let tools = self.tools.read().unwrap();
109        tools.get(name).cloned()
110    }
111
112    /// Check if a tool exists
113    pub fn contains(&self, name: &str) -> bool {
114        let tools = self.tools.read().unwrap();
115        tools.contains_key(name)
116    }
117
118    /// Get all tool definitions for LLM
119    pub fn definitions(&self) -> Vec<ToolDefinition> {
120        let tools = self.tools.read().unwrap();
121        tools
122            .values()
123            .map(|tool| ToolDefinition {
124                name: tool.name().to_string(),
125                description: tool.description().to_string(),
126                parameters: tool.parameters(),
127            })
128            .collect()
129    }
130
131    /// List all registered tool names
132    pub fn list(&self) -> Vec<String> {
133        let tools = self.tools.read().unwrap();
134        tools.keys().cloned().collect()
135    }
136
137    /// Get the number of registered tools
138    pub fn len(&self) -> usize {
139        let tools = self.tools.read().unwrap();
140        tools.len()
141    }
142
143    /// Check if registry is empty
144    pub fn is_empty(&self) -> bool {
145        self.len() == 0
146    }
147
148    /// Get the tool context
149    pub fn context(&self) -> ToolContext {
150        self.context.read().unwrap().clone()
151    }
152
153    /// Return a clone of the registry's artifact store handle.
154    pub fn artifact_store(&self) -> ArtifactStore {
155        self.artifact_store.clone()
156    }
157
158    /// Get a stored tool artifact by URI.
159    pub fn get_artifact(&self, artifact_uri: &str) -> Option<ToolArtifact> {
160        self.artifact_store.get(artifact_uri)
161    }
162
163    /// Replace the trace sink used for compact tool/program execution events.
164    pub fn set_trace_sink(&self, sink: Arc<dyn TraceSink>) {
165        *self.trace_sink.write().unwrap() = sink;
166    }
167
168    /// Return the current trace sink.
169    pub fn trace_sink(&self) -> Arc<dyn TraceSink> {
170        Arc::clone(&self.trace_sink.read().unwrap())
171    }
172
173    /// Set the search configuration for the tool context
174    pub fn set_search_config(&self, config: crate::config::SearchConfig) {
175        let mut ctx = self.context.write().unwrap();
176        *ctx = ctx.clone().with_search_config(config);
177    }
178
179    /// Set a sandbox executor so that `bash` tool calls use the sandbox even
180    /// when executed without an explicit `ToolContext` (i.e., via `execute()`).
181    pub fn set_sandbox(&self, sandbox: std::sync::Arc<dyn crate::sandbox::BashSandbox>) {
182        let mut ctx = self.context.write().unwrap();
183        *ctx = ctx.clone().with_sandbox(sandbox);
184    }
185
186    /// Set environment overrides used by subprocess-backed tools when executed
187    /// without an explicit context.
188    pub fn set_command_env(&self, env: Arc<HashMap<String, String>>) {
189        let mut ctx = self.context.write().unwrap();
190        *ctx = ctx.clone().with_command_env(env);
191    }
192
193    /// Execute a tool by name using the registry's default context
194    pub async fn execute(&self, name: &str, args: &serde_json::Value) -> Result<ToolResult> {
195        let ctx = self.context();
196        self.execute_with_context(name, args, &ctx).await
197    }
198
199    /// Execute a tool by name with an external context
200    pub async fn execute_with_context(
201        &self,
202        name: &str,
203        args: &serde_json::Value,
204        ctx: &ToolContext,
205    ) -> Result<ToolResult> {
206        let start = std::time::Instant::now();
207
208        let tool = self.get(name);
209
210        let result = match tool {
211            Some(tool) => {
212                let mut output = tool.execute(args, ctx).await?;
213                let original_content = output.content.clone();
214                let truncated = truncate_tool_output_with_artifact(name, &output.content);
215                output.content = truncated.content;
216                if let Some(artifact) = truncated.artifact {
217                    self.store_tool_artifact(name, &original_content, &artifact);
218                    output.metadata = Some(merge_tool_output_artifact_metadata(
219                        output.metadata,
220                        &artifact,
221                    ));
222                }
223                Ok(ToolResult {
224                    name: name.to_string(),
225                    output: output.content,
226                    exit_code: if output.success { 0 } else { 1 },
227                    metadata: output.metadata,
228                    images: output.images,
229                    error_kind: output.error_kind,
230                })
231            }
232            None => Ok(ToolResult::error(name, format!("Unknown tool: {}", name))),
233        };
234
235        if let Ok(ref r) = result {
236            crate::telemetry::record_tool_result(r.exit_code, start.elapsed());
237            self.record_trace_event(name, r, start.elapsed());
238        }
239
240        result
241    }
242
243    /// Execute a tool and return raw output using the registry's default context
244    pub async fn execute_raw(
245        &self,
246        name: &str,
247        args: &serde_json::Value,
248    ) -> Result<Option<ToolOutput>> {
249        let ctx = self.context();
250        self.execute_raw_with_context(name, args, &ctx).await
251    }
252
253    /// Execute a tool and return raw output with an external context
254    pub async fn execute_raw_with_context(
255        &self,
256        name: &str,
257        args: &serde_json::Value,
258        ctx: &ToolContext,
259    ) -> Result<Option<ToolOutput>> {
260        let tool = self.get(name);
261
262        match tool {
263            Some(tool) => {
264                let mut output = tool.execute(args, ctx).await?;
265                let original_content = output.content.clone();
266                let truncated = truncate_tool_output_with_artifact(name, &output.content);
267                output.content = truncated.content;
268                if let Some(artifact) = truncated.artifact {
269                    self.store_tool_artifact(name, &original_content, &artifact);
270                    output.metadata = Some(merge_tool_output_artifact_metadata(
271                        output.metadata,
272                        &artifact,
273                    ));
274                }
275                Ok(Some(output))
276            }
277            None => Ok(None),
278        }
279    }
280
281    fn store_tool_artifact(&self, tool_name: &str, content: &str, artifact: &ToolOutputArtifact) {
282        self.artifact_store.put(ToolArtifact {
283            artifact_id: artifact.artifact_id.clone(),
284            artifact_uri: artifact.artifact_uri.clone(),
285            tool_name: tool_name.to_string(),
286            content: content.to_string(),
287            original_bytes: artifact.original_bytes,
288            shown_bytes: artifact.shown_bytes,
289        });
290    }
291
292    fn record_trace_event(&self, name: &str, result: &ToolResult, duration: std::time::Duration) {
293        let sink = self.trace_sink();
294        sink.record(TraceEvent::tool_execution(
295            name,
296            result.exit_code == 0,
297            result.exit_code,
298            duration,
299            result.output.len(),
300            result.metadata.as_ref(),
301        ));
302
303        if name == "program" {
304            sink.record(TraceEvent::program_execution(
305                name,
306                result.exit_code == 0,
307                result.exit_code,
308                duration,
309                result.output.len(),
310                result.metadata.as_ref(),
311            ));
312        }
313    }
314}
315
#[cfg(test)]
mod tests {
    use super::*;
    use crate::trace::{InMemoryTraceSink, TraceEventKind};
    use async_trait::async_trait;

    /// JSON schema for a tool that accepts no arguments (shared by the mocks below).
    fn no_args_schema() -> serde_json::Value {
        serde_json::json!({
            "type": "object",
            "additionalProperties": false,
            "properties": {},
            "required": []
        })
    }

    /// Tool with a configurable name that always succeeds with "mock output".
    struct MockTool {
        name: String,
    }

    impl MockTool {
        fn named(name: &str) -> Arc<Self> {
            Arc::new(Self {
                name: name.to_string(),
            })
        }
    }

    #[async_trait]
    impl Tool for MockTool {
        fn name(&self) -> &str {
            &self.name
        }

        fn description(&self) -> &str {
            "A mock tool for testing"
        }

        fn parameters(&self) -> serde_json::Value {
            no_args_schema()
        }

        async fn execute(
            &self,
            _args: &serde_json::Value,
            _ctx: &ToolContext,
        ) -> Result<ToolOutput> {
            Ok(ToolOutput::success("mock output"))
        }
    }

    /// Tool named "failing" that always returns a failed (non-`Err`) output.
    struct FailingTool;

    #[async_trait]
    impl Tool for FailingTool {
        fn name(&self) -> &str {
            "failing"
        }

        fn description(&self) -> &str {
            "A tool that returns failure"
        }

        fn parameters(&self) -> serde_json::Value {
            no_args_schema()
        }

        async fn execute(
            &self,
            _args: &serde_json::Value,
            _ctx: &ToolContext,
        ) -> Result<ToolOutput> {
            Ok(ToolOutput::error("something went wrong"))
        }
    }

    /// Tool whose output is one byte over `MAX_OUTPUT_SIZE`, forcing truncation.
    struct LargeOutputTool;

    #[async_trait]
    impl Tool for LargeOutputTool {
        fn name(&self) -> &str {
            "large_output"
        }

        fn description(&self) -> &str {
            "A tool that returns more than the maximum output size"
        }

        fn parameters(&self) -> serde_json::Value {
            no_args_schema()
        }

        async fn execute(
            &self,
            _args: &serde_json::Value,
            _ctx: &ToolContext,
        ) -> Result<ToolOutput> {
            Ok(ToolOutput::success(
                "x".repeat(super::super::MAX_OUTPUT_SIZE + 1),
            ))
        }
    }

    #[test]
    fn test_registry_register_and_get() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        registry.register(MockTool::named("test"));

        assert!(registry.contains("test"));
        assert!(!registry.contains("nonexistent"));

        let retrieved = registry.get("test");
        assert!(retrieved.is_some());
        assert_eq!(retrieved.unwrap().name(), "test");
    }

    #[test]
    fn test_registry_unregister() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        registry.register(MockTool::named("test"));

        assert!(registry.contains("test"));
        assert!(registry.unregister("test"));
        assert!(!registry.contains("test"));
        assert!(!registry.unregister("test")); // Already removed
    }

    #[test]
    fn test_registry_unregister_by_prefix() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        for name in ["mcp__a", "mcp__b", "local"] {
            registry.register(MockTool::named(name));
        }

        registry.unregister_by_prefix("mcp__");

        assert_eq!(registry.list(), vec!["local".to_string()]);
    }

    #[test]
    fn test_registry_definitions() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        registry.register(MockTool::named("tool1"));
        registry.register(MockTool::named("tool2"));

        let definitions = registry.definitions();
        assert_eq!(definitions.len(), 2);
    }

    #[tokio::test]
    async fn test_registry_execute() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        registry.register(MockTool::named("test"));

        let result = registry
            .execute("test", &serde_json::json!({}))
            .await
            .unwrap();
        assert_eq!(result.exit_code, 0);
        assert_eq!(result.output, "mock output");
    }

    #[tokio::test]
    async fn test_registry_execute_unknown() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));

        let result = registry
            .execute("unknown", &serde_json::json!({}))
            .await
            .unwrap();
        assert_eq!(result.exit_code, 1);
        assert!(result.output.contains("Unknown tool"));
    }

    #[tokio::test]
    async fn test_registry_execute_with_context_success() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        let ctx = ToolContext::new(PathBuf::from("/tmp"));
        let trace_sink = InMemoryTraceSink::default();
        registry.set_trace_sink(Arc::new(trace_sink.clone()));
        registry.register(MockTool::named("my_tool"));

        let result = registry
            .execute_with_context("my_tool", &serde_json::json!({}), &ctx)
            .await
            .unwrap();
        assert_eq!(result.name, "my_tool");
        assert_eq!(result.exit_code, 0);
        assert_eq!(result.output, "mock output");

        let events = trace_sink.events();
        assert_eq!(events.len(), 1);
        assert_eq!(events[0].kind, TraceEventKind::ToolExecution);
        assert_eq!(events[0].name, "my_tool");
        assert!(events[0].success);
        assert_eq!(events[0].output_bytes, "mock output".len());
    }

    #[tokio::test]
    async fn test_registry_execute_with_context_unknown_tool() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        let ctx = ToolContext::new(PathBuf::from("/tmp"));

        let result = registry
            .execute_with_context("nonexistent", &serde_json::json!({}), &ctx)
            .await
            .unwrap();
        assert_eq!(result.exit_code, 1);
        assert!(result.output.contains("Unknown tool: nonexistent"));
    }

    #[tokio::test]
    async fn test_registry_execute_failing_tool() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        registry.register(Arc::new(FailingTool));

        let result = registry
            .execute("failing", &serde_json::json!({}))
            .await
            .unwrap();
        assert_eq!(result.exit_code, 1);
        assert_eq!(result.output, "something went wrong");
    }

    #[tokio::test]
    async fn test_register_cannot_shadow_builtin() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        registry.register_builtin(Arc::new(FailingTool));

        // A dynamic tool reusing the builtin's name must be rejected.
        registry.register(MockTool::named("failing"));
        assert_eq!(registry.len(), 1);

        // The builtin implementation is still the one that runs.
        let result = registry
            .execute("failing", &serde_json::json!({}))
            .await
            .unwrap();
        assert_eq!(result.exit_code, 1);
        assert_eq!(result.output, "something went wrong");
    }

    #[tokio::test]
    async fn test_registry_truncates_large_tool_output() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        let trace_sink = InMemoryTraceSink::default();
        registry.set_trace_sink(Arc::new(trace_sink.clone()));
        registry.register(Arc::new(LargeOutputTool));

        let result = registry
            .execute("large_output", &serde_json::json!({}))
            .await
            .unwrap();

        assert_eq!(result.exit_code, 0);
        assert!(result.output.contains("[tool output truncated:"));
        assert!(result
            .output
            .contains("Full output artifact: a3s://tool-output/large_output/"));
        assert!(result.output.len() < super::super::MAX_OUTPUT_SIZE + 512);
        let metadata = result.metadata.expect("artifact metadata");
        assert_eq!(
            metadata["artifact"]["original_bytes"],
            serde_json::json!(super::super::MAX_OUTPUT_SIZE + 1)
        );
        assert_eq!(
            metadata["artifact"]["shown_bytes"],
            serde_json::json!(super::super::MAX_OUTPUT_SIZE)
        );
        assert!(metadata["artifact"]["artifact_id"]
            .as_str()
            .unwrap()
            .starts_with("tool-output:large_output:"));
        assert!(metadata["artifact"]["artifact_uri"]
            .as_str()
            .unwrap()
            .starts_with("a3s://tool-output/large_output/"));

        let artifact_uri = metadata["artifact"]["artifact_uri"].as_str().unwrap();
        let artifact = registry
            .get_artifact(artifact_uri)
            .expect("full output artifact");
        assert_eq!(artifact.tool_name, "large_output");
        assert_eq!(artifact.original_bytes, super::super::MAX_OUTPUT_SIZE + 1);
        assert_eq!(artifact.shown_bytes, super::super::MAX_OUTPUT_SIZE);
        assert_eq!(
            artifact.content,
            "x".repeat(super::super::MAX_OUTPUT_SIZE + 1)
        );

        let events = trace_sink.events();
        assert_eq!(events.len(), 1);
        assert_eq!(events[0].artifact_uris, vec![artifact_uri]);
    }

    #[tokio::test]
    async fn test_registry_execute_raw_success() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        registry.register(MockTool::named("raw_test"));

        let output = registry
            .execute_raw("raw_test", &serde_json::json!({}))
            .await
            .unwrap();
        assert!(output.is_some());
        let output = output.unwrap();
        assert!(output.success);
        assert_eq!(output.content, "mock output");
    }

    #[tokio::test]
    async fn test_registry_execute_raw_stores_truncated_artifact() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        registry.register(Arc::new(LargeOutputTool));

        let output = registry
            .execute_raw("large_output", &serde_json::json!({}))
            .await
            .unwrap()
            .expect("raw output");

        assert!(output.content.contains("[tool output truncated:"));
        let metadata = output.metadata.expect("artifact metadata");
        let artifact_uri = metadata["artifact"]["artifact_uri"].as_str().unwrap();
        let artifact = registry
            .get_artifact(artifact_uri)
            .expect("full output artifact");
        assert_eq!(artifact.tool_name, "large_output");
        assert_eq!(artifact.content.len(), super::super::MAX_OUTPUT_SIZE + 1);
    }

    #[tokio::test]
    async fn test_registry_execute_raw_unknown() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));

        let output = registry
            .execute_raw("missing", &serde_json::json!({}))
            .await
            .unwrap();
        assert!(output.is_none());
    }

    #[test]
    fn test_registry_list() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        registry.register(MockTool::named("alpha"));
        registry.register(MockTool::named("beta"));

        let names = registry.list();
        assert_eq!(names.len(), 2);
        assert!(names.contains(&"alpha".to_string()));
        assert!(names.contains(&"beta".to_string()));
    }

    #[test]
    fn test_registry_len_and_is_empty() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        assert!(registry.is_empty());
        assert_eq!(registry.len(), 0);

        registry.register(MockTool::named("t"));
        assert!(!registry.is_empty());
        assert_eq!(registry.len(), 1);
    }

    #[test]
    fn test_registry_replace_tool() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        registry.register(MockTool::named("dup"));
        registry.register(MockTool::named("dup"));
        // Should still have only 1 tool (replaced)
        assert_eq!(registry.len(), 1);
    }
}