Skip to main content

a3s_code_core/tools/
registry.rs

1//! Tool Registry
2//!
3//! Central registry for all tools (built-in and dynamic).
4//! Provides thread-safe registration, lookup, and execution.
5
6use super::artifacts::{ArtifactStore, ArtifactStoreLimits, ToolArtifact};
7use super::types::{Tool, ToolContext, ToolOutput};
8use super::ToolResult;
9use super::{
10    merge_tool_output_artifact_metadata, truncate_tool_output_with_artifact, ToolOutputArtifact,
11};
12use crate::llm::ToolDefinition;
13use crate::trace::{InMemoryTraceSink, TraceEvent, TraceSink};
14use anyhow::Result;
15use std::collections::HashMap;
16use std::path::PathBuf;
17use std::sync::{Arc, RwLock};
18
19/// Tool registry for managing all available tools
20pub struct ToolRegistry {
21    tools: RwLock<HashMap<String, Arc<dyn Tool>>>,
22    /// Names of builtin tools that cannot be overridden
23    builtins: RwLock<std::collections::HashSet<String>>,
24    context: RwLock<ToolContext>,
25    artifact_store: ArtifactStore,
26    trace_sink: RwLock<Arc<dyn TraceSink>>,
27}
28
29impl ToolRegistry {
30    /// Create a new tool registry
31    pub fn new(workspace: PathBuf) -> Self {
32        Self::with_artifact_limits(workspace, ArtifactStoreLimits::default())
33    }
34
35    /// Create a new tool registry with custom artifact retention limits.
36    pub fn with_artifact_limits(workspace: PathBuf, artifact_limits: ArtifactStoreLimits) -> Self {
37        Self {
38            tools: RwLock::new(HashMap::new()),
39            builtins: RwLock::new(std::collections::HashSet::new()),
40            context: RwLock::new(ToolContext::new(workspace)),
41            artifact_store: ArtifactStore::with_limits(artifact_limits),
42            trace_sink: RwLock::new(Arc::new(InMemoryTraceSink::default())),
43        }
44    }
45
46    /// Register a builtin tool (cannot be overridden by dynamic tools)
47    pub fn register_builtin(&self, tool: Arc<dyn Tool>) {
48        let name = tool.name().to_string();
49        let mut tools = self.tools.write().unwrap();
50        let mut builtins = self.builtins.write().unwrap();
51        tracing::debug!("Registering builtin tool: {}", name);
52        tools.insert(name.clone(), tool);
53        builtins.insert(name);
54    }
55
56    /// Register a tool
57    ///
58    /// If a tool with the same name already exists as a builtin, the registration
59    /// is rejected to prevent shadowing of core tools.
60    pub fn register(&self, tool: Arc<dyn Tool>) {
61        let name = tool.name().to_string();
62        let builtins = self.builtins.read().unwrap();
63        if builtins.contains(&name) {
64            tracing::warn!(
65                "Rejected registration of tool '{}': cannot shadow builtin",
66                name
67            );
68            return;
69        }
70        drop(builtins);
71        let mut tools = self.tools.write().unwrap();
72        tracing::debug!("Registering tool: {}", name);
73        tools.insert(name, tool);
74    }
75
76    /// Unregister a tool by name
77    ///
78    /// Returns true if the tool was found and removed.
79    pub fn unregister(&self, name: &str) -> bool {
80        let mut tools = self.tools.write().unwrap();
81        tracing::debug!("Unregistering tool: {}", name);
82        tools.remove(name).is_some()
83    }
84
85    /// Unregister all tools whose names start with the given prefix.
86    pub fn unregister_by_prefix(&self, prefix: &str) {
87        let mut tools = self.tools.write().unwrap();
88        tools.retain(|name, _| !name.starts_with(prefix));
89        tracing::debug!("Unregistered tools with prefix: {}", prefix);
90    }
91
92    /// Get a tool by name
93    pub fn get(&self, name: &str) -> Option<Arc<dyn Tool>> {
94        let tools = self.tools.read().unwrap();
95        tools.get(name).cloned()
96    }
97
98    /// Check if a tool exists
99    pub fn contains(&self, name: &str) -> bool {
100        let tools = self.tools.read().unwrap();
101        tools.contains_key(name)
102    }
103
104    /// Get all tool definitions for LLM
105    pub fn definitions(&self) -> Vec<ToolDefinition> {
106        let tools = self.tools.read().unwrap();
107        tools
108            .values()
109            .map(|tool| ToolDefinition {
110                name: tool.name().to_string(),
111                description: tool.description().to_string(),
112                parameters: tool.parameters(),
113            })
114            .collect()
115    }
116
117    /// List all registered tool names
118    pub fn list(&self) -> Vec<String> {
119        let tools = self.tools.read().unwrap();
120        tools.keys().cloned().collect()
121    }
122
123    /// Get the number of registered tools
124    pub fn len(&self) -> usize {
125        let tools = self.tools.read().unwrap();
126        tools.len()
127    }
128
129    /// Check if registry is empty
130    pub fn is_empty(&self) -> bool {
131        self.len() == 0
132    }
133
134    /// Get the tool context
135    pub fn context(&self) -> ToolContext {
136        self.context.read().unwrap().clone()
137    }
138
139    /// Return a clone of the registry's artifact store handle.
140    pub fn artifact_store(&self) -> ArtifactStore {
141        self.artifact_store.clone()
142    }
143
144    /// Get a stored tool artifact by URI.
145    pub fn get_artifact(&self, artifact_uri: &str) -> Option<ToolArtifact> {
146        self.artifact_store.get(artifact_uri)
147    }
148
149    /// Replace the trace sink used for compact tool/program execution events.
150    pub fn set_trace_sink(&self, sink: Arc<dyn TraceSink>) {
151        *self.trace_sink.write().unwrap() = sink;
152    }
153
154    /// Return the current trace sink.
155    pub fn trace_sink(&self) -> Arc<dyn TraceSink> {
156        Arc::clone(&self.trace_sink.read().unwrap())
157    }
158
159    /// Set the search configuration for the tool context
160    pub fn set_search_config(&self, config: crate::config::SearchConfig) {
161        let mut ctx = self.context.write().unwrap();
162        *ctx = ctx.clone().with_search_config(config);
163    }
164
165    /// Set a sandbox executor so that `bash` tool calls use the sandbox even
166    /// when executed without an explicit `ToolContext` (i.e., via `execute()`).
167    pub fn set_sandbox(&self, sandbox: std::sync::Arc<dyn crate::sandbox::BashSandbox>) {
168        let mut ctx = self.context.write().unwrap();
169        *ctx = ctx.clone().with_sandbox(sandbox);
170    }
171
172    /// Execute a tool by name using the registry's default context
173    pub async fn execute(&self, name: &str, args: &serde_json::Value) -> Result<ToolResult> {
174        let ctx = self.context();
175        self.execute_with_context(name, args, &ctx).await
176    }
177
178    /// Execute a tool by name with an external context
179    pub async fn execute_with_context(
180        &self,
181        name: &str,
182        args: &serde_json::Value,
183        ctx: &ToolContext,
184    ) -> Result<ToolResult> {
185        let start = std::time::Instant::now();
186
187        let tool = self.get(name);
188
189        let result = match tool {
190            Some(tool) => {
191                let mut output = tool.execute(args, ctx).await?;
192                let original_content = output.content.clone();
193                let truncated = truncate_tool_output_with_artifact(name, &output.content);
194                output.content = truncated.content;
195                if let Some(artifact) = truncated.artifact {
196                    self.store_tool_artifact(name, &original_content, &artifact);
197                    output.metadata = Some(merge_tool_output_artifact_metadata(
198                        output.metadata,
199                        &artifact,
200                    ));
201                }
202                Ok(ToolResult {
203                    name: name.to_string(),
204                    output: output.content,
205                    exit_code: if output.success { 0 } else { 1 },
206                    metadata: output.metadata,
207                    images: output.images,
208                })
209            }
210            None => Ok(ToolResult::error(name, format!("Unknown tool: {}", name))),
211        };
212
213        if let Ok(ref r) = result {
214            crate::telemetry::record_tool_result(r.exit_code, start.elapsed());
215            self.record_trace_event(name, r, start.elapsed());
216        }
217
218        result
219    }
220
221    /// Execute a tool and return raw output using the registry's default context
222    pub async fn execute_raw(
223        &self,
224        name: &str,
225        args: &serde_json::Value,
226    ) -> Result<Option<ToolOutput>> {
227        let ctx = self.context();
228        self.execute_raw_with_context(name, args, &ctx).await
229    }
230
231    /// Execute a tool and return raw output with an external context
232    pub async fn execute_raw_with_context(
233        &self,
234        name: &str,
235        args: &serde_json::Value,
236        ctx: &ToolContext,
237    ) -> Result<Option<ToolOutput>> {
238        let tool = self.get(name);
239
240        match tool {
241            Some(tool) => {
242                let mut output = tool.execute(args, ctx).await?;
243                let original_content = output.content.clone();
244                let truncated = truncate_tool_output_with_artifact(name, &output.content);
245                output.content = truncated.content;
246                if let Some(artifact) = truncated.artifact {
247                    self.store_tool_artifact(name, &original_content, &artifact);
248                    output.metadata = Some(merge_tool_output_artifact_metadata(
249                        output.metadata,
250                        &artifact,
251                    ));
252                }
253                Ok(Some(output))
254            }
255            None => Ok(None),
256        }
257    }
258
259    fn store_tool_artifact(&self, tool_name: &str, content: &str, artifact: &ToolOutputArtifact) {
260        self.artifact_store.put(ToolArtifact {
261            artifact_id: artifact.artifact_id.clone(),
262            artifact_uri: artifact.artifact_uri.clone(),
263            tool_name: tool_name.to_string(),
264            content: content.to_string(),
265            original_bytes: artifact.original_bytes,
266            shown_bytes: artifact.shown_bytes,
267        });
268    }
269
270    fn record_trace_event(&self, name: &str, result: &ToolResult, duration: std::time::Duration) {
271        let sink = self.trace_sink();
272        sink.record(TraceEvent::tool_execution(
273            name,
274            result.exit_code == 0,
275            result.exit_code,
276            duration,
277            result.output.len(),
278            result.metadata.as_ref(),
279        ));
280
281        if name == "program" {
282            sink.record(TraceEvent::program_execution(
283                name,
284                result.exit_code == 0,
285                result.exit_code,
286                duration,
287                result.output.len(),
288                result.metadata.as_ref(),
289            ));
290        }
291    }
292}
293
#[cfg(test)]
mod tests {
    use super::*;
    use crate::trace::{InMemoryTraceSink, TraceEventKind};
    use async_trait::async_trait;

    /// Tool with a configurable name that always succeeds with "mock output".
    struct MockTool {
        name: String,
    }

    #[async_trait]
    impl Tool for MockTool {
        fn name(&self) -> &str {
            &self.name
        }

        fn description(&self) -> &str {
            "A mock tool for testing"
        }

        fn parameters(&self) -> serde_json::Value {
            serde_json::json!({
                "type": "object",
                "additionalProperties": false,
                "properties": {},
                "required": []
            })
        }

        async fn execute(
            &self,
            _args: &serde_json::Value,
            _ctx: &ToolContext,
        ) -> Result<ToolOutput> {
            Ok(ToolOutput::success("mock output"))
        }
    }

    #[test]
    fn test_registry_register_and_get() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));

        let tool = Arc::new(MockTool {
            name: "test".to_string(),
        });
        registry.register(tool);

        assert!(registry.contains("test"));
        assert!(!registry.contains("nonexistent"));

        let retrieved = registry.get("test");
        assert!(retrieved.is_some());
        assert_eq!(retrieved.unwrap().name(), "test");
    }

    #[test]
    fn test_registry_unregister() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));

        let tool = Arc::new(MockTool {
            name: "test".to_string(),
        });
        registry.register(tool);

        assert!(registry.contains("test"));
        assert!(registry.unregister("test"));
        assert!(!registry.contains("test"));
        assert!(!registry.unregister("test")); // Already removed
    }

    #[test]
    fn test_registry_unregister_by_prefix() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        for name in &["mcp__srv__one", "mcp__srv__two", "local"] {
            registry.register(Arc::new(MockTool {
                name: name.to_string(),
            }));
        }

        registry.unregister_by_prefix("mcp__srv__");

        assert_eq!(registry.list(), vec!["local".to_string()]);
    }

    #[test]
    fn test_registry_definitions() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));

        registry.register(Arc::new(MockTool {
            name: "tool1".to_string(),
        }));
        registry.register(Arc::new(MockTool {
            name: "tool2".to_string(),
        }));

        let definitions = registry.definitions();
        assert_eq!(definitions.len(), 2);
    }

    #[tokio::test]
    async fn test_registry_execute() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));

        registry.register(Arc::new(MockTool {
            name: "test".to_string(),
        }));

        let result = registry
            .execute("test", &serde_json::json!({}))
            .await
            .unwrap();
        assert_eq!(result.exit_code, 0);
        assert_eq!(result.output, "mock output");
    }

    #[tokio::test]
    async fn test_registry_execute_unknown() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));

        let result = registry
            .execute("unknown", &serde_json::json!({}))
            .await
            .unwrap();
        assert_eq!(result.exit_code, 1);
        assert!(result.output.contains("Unknown tool"));
    }

    #[tokio::test]
    async fn test_registry_execute_with_context_success() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        let ctx = ToolContext::new(PathBuf::from("/tmp"));
        let trace_sink = InMemoryTraceSink::default();
        registry.set_trace_sink(Arc::new(trace_sink.clone()));

        registry.register(Arc::new(MockTool {
            name: "my_tool".to_string(),
        }));

        let result = registry
            .execute_with_context("my_tool", &serde_json::json!({}), &ctx)
            .await
            .unwrap();
        assert_eq!(result.name, "my_tool");
        assert_eq!(result.exit_code, 0);
        assert_eq!(result.output, "mock output");

        let events = trace_sink.events();
        assert_eq!(events.len(), 1);
        assert_eq!(events[0].kind, TraceEventKind::ToolExecution);
        assert_eq!(events[0].name, "my_tool");
        assert!(events[0].success);
        assert_eq!(events[0].output_bytes, "mock output".len());
    }

    #[tokio::test]
    async fn test_registry_execute_program_records_two_trace_events() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        let trace_sink = InMemoryTraceSink::default();
        registry.set_trace_sink(Arc::new(trace_sink.clone()));
        registry.register(Arc::new(MockTool {
            name: "program".to_string(),
        }));

        let result = registry
            .execute("program", &serde_json::json!({}))
            .await
            .unwrap();
        assert_eq!(result.exit_code, 0);

        // One generic tool-execution event plus one program-specific event.
        let events = trace_sink.events();
        assert_eq!(events.len(), 2);
        let tool_events = events
            .iter()
            .filter(|event| event.kind == TraceEventKind::ToolExecution)
            .count();
        assert_eq!(tool_events, 1);
    }

    #[tokio::test]
    async fn test_registry_execute_with_context_unknown_tool() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        let ctx = ToolContext::new(PathBuf::from("/tmp"));

        let result = registry
            .execute_with_context("nonexistent", &serde_json::json!({}), &ctx)
            .await
            .unwrap();
        assert_eq!(result.exit_code, 1);
        assert!(result.output.contains("Unknown tool: nonexistent"));
    }

    /// Tool named "failing" that always reports an error output.
    struct FailingTool;

    #[async_trait]
    impl Tool for FailingTool {
        fn name(&self) -> &str {
            "failing"
        }

        fn description(&self) -> &str {
            "A tool that returns failure"
        }

        fn parameters(&self) -> serde_json::Value {
            serde_json::json!({
                "type": "object",
                "additionalProperties": false,
                "properties": {},
                "required": []
            })
        }

        async fn execute(
            &self,
            _args: &serde_json::Value,
            _ctx: &ToolContext,
        ) -> Result<ToolOutput> {
            Ok(ToolOutput::error("something went wrong"))
        }
    }

    #[tokio::test]
    async fn test_registry_execute_failing_tool() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        registry.register(Arc::new(FailingTool));

        let result = registry
            .execute("failing", &serde_json::json!({}))
            .await
            .unwrap();
        assert_eq!(result.exit_code, 1);
        assert_eq!(result.output, "something went wrong");
    }

    #[tokio::test]
    async fn test_registry_register_cannot_shadow_builtin() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        registry.register_builtin(Arc::new(FailingTool));

        // A dynamic tool with the builtin's name must be rejected.
        registry.register(Arc::new(MockTool {
            name: "failing".to_string(),
        }));
        assert_eq!(registry.len(), 1);

        // Executing the name still runs the builtin, not the mock.
        let result = registry
            .execute("failing", &serde_json::json!({}))
            .await
            .unwrap();
        assert_eq!(result.exit_code, 1);
        assert_eq!(result.output, "something went wrong");
    }

    /// Tool whose output exceeds the registry's maximum output size by one byte.
    struct LargeOutputTool;

    #[async_trait]
    impl Tool for LargeOutputTool {
        fn name(&self) -> &str {
            "large_output"
        }

        fn description(&self) -> &str {
            "A tool that returns more than the maximum output size"
        }

        fn parameters(&self) -> serde_json::Value {
            serde_json::json!({
                "type": "object",
                "additionalProperties": false,
                "properties": {},
                "required": []
            })
        }

        async fn execute(
            &self,
            _args: &serde_json::Value,
            _ctx: &ToolContext,
        ) -> Result<ToolOutput> {
            Ok(ToolOutput::success(
                "x".repeat(super::super::MAX_OUTPUT_SIZE + 1),
            ))
        }
    }

    #[tokio::test]
    async fn test_registry_truncates_large_tool_output() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        let trace_sink = InMemoryTraceSink::default();
        registry.set_trace_sink(Arc::new(trace_sink.clone()));
        registry.register(Arc::new(LargeOutputTool));

        let result = registry
            .execute("large_output", &serde_json::json!({}))
            .await
            .unwrap();

        assert_eq!(result.exit_code, 0);
        assert!(result.output.contains("[tool output truncated:"));
        assert!(result
            .output
            .contains("Full output artifact: a3s://tool-output/large_output/"));
        assert!(result.output.len() < super::super::MAX_OUTPUT_SIZE + 512);
        let metadata = result.metadata.expect("artifact metadata");
        assert_eq!(
            metadata["artifact"]["original_bytes"],
            serde_json::json!(super::super::MAX_OUTPUT_SIZE + 1)
        );
        assert_eq!(
            metadata["artifact"]["shown_bytes"],
            serde_json::json!(super::super::MAX_OUTPUT_SIZE)
        );
        assert!(metadata["artifact"]["artifact_id"]
            .as_str()
            .unwrap()
            .starts_with("tool-output:large_output:"));
        assert!(metadata["artifact"]["artifact_uri"]
            .as_str()
            .unwrap()
            .starts_with("a3s://tool-output/large_output/"));

        let artifact_uri = metadata["artifact"]["artifact_uri"].as_str().unwrap();
        let artifact = registry
            .get_artifact(artifact_uri)
            .expect("full output artifact");
        assert_eq!(artifact.tool_name, "large_output");
        assert_eq!(artifact.original_bytes, super::super::MAX_OUTPUT_SIZE + 1);
        assert_eq!(artifact.shown_bytes, super::super::MAX_OUTPUT_SIZE);
        assert_eq!(
            artifact.content,
            "x".repeat(super::super::MAX_OUTPUT_SIZE + 1)
        );

        let events = trace_sink.events();
        assert_eq!(events.len(), 1);
        assert_eq!(events[0].artifact_uris, vec![artifact_uri]);
    }

    #[tokio::test]
    async fn test_registry_execute_raw_success() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        registry.register(Arc::new(MockTool {
            name: "raw_test".to_string(),
        }));

        let output = registry
            .execute_raw("raw_test", &serde_json::json!({}))
            .await
            .unwrap();
        assert!(output.is_some());
        let output = output.unwrap();
        assert!(output.success);
        assert_eq!(output.content, "mock output");
    }

    #[tokio::test]
    async fn test_registry_execute_raw_stores_truncated_artifact() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        registry.register(Arc::new(LargeOutputTool));

        let output = registry
            .execute_raw("large_output", &serde_json::json!({}))
            .await
            .unwrap()
            .expect("raw output");

        assert!(output.content.contains("[tool output truncated:"));
        let metadata = output.metadata.expect("artifact metadata");
        let artifact_uri = metadata["artifact"]["artifact_uri"].as_str().unwrap();
        let artifact = registry
            .get_artifact(artifact_uri)
            .expect("full output artifact");
        assert_eq!(artifact.tool_name, "large_output");
        assert_eq!(artifact.content.len(), super::super::MAX_OUTPUT_SIZE + 1);
    }

    #[tokio::test]
    async fn test_registry_execute_raw_unknown() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));

        let output = registry
            .execute_raw("missing", &serde_json::json!({}))
            .await
            .unwrap();
        assert!(output.is_none());
    }

    #[test]
    fn test_registry_list() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        registry.register(Arc::new(MockTool {
            name: "alpha".to_string(),
        }));
        registry.register(Arc::new(MockTool {
            name: "beta".to_string(),
        }));

        let names = registry.list();
        assert_eq!(names.len(), 2);
        assert!(names.contains(&"alpha".to_string()));
        assert!(names.contains(&"beta".to_string()));
    }

    #[test]
    fn test_registry_len_and_is_empty() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        assert!(registry.is_empty());
        assert_eq!(registry.len(), 0);

        registry.register(Arc::new(MockTool {
            name: "t".to_string(),
        }));
        assert!(!registry.is_empty());
        assert_eq!(registry.len(), 1);
    }

    #[test]
    fn test_registry_replace_tool() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        registry.register(Arc::new(MockTool {
            name: "dup".to_string(),
        }));
        registry.register(Arc::new(MockTool {
            name: "dup".to_string(),
        }));
        // Should still have only 1 tool (replaced)
        assert_eq!(registry.len(), 1);
    }
}