// claude_agent/tools/registry.rs

//! Tool registry for managing and executing tools.

use std::collections::HashMap;
use std::collections::hash_map::Entry;
use std::path::PathBuf;
use std::sync::Arc;
use std::time::Duration;

use super::ProcessManager;
use super::access::ToolAccess;
use super::builder::ToolRegistryBuilder;
use super::context::ExecutionContext;
use super::env::ToolExecutionEnv;
use super::traits::Tool;
use crate::agent::TaskRegistry;
use crate::permissions::PermissionPolicy;
use crate::session::MemoryPersistence;
use crate::types::{ToolDefinition, ToolOutput, ToolResult};

19#[derive(Clone)]
20pub struct ToolRegistry {
21    tools: HashMap<String, Arc<dyn Tool>>,
22    task_registry: TaskRegistry,
23    env: ToolExecutionEnv,
24}
25
26impl ToolRegistry {
27    pub fn new() -> Self {
28        Self {
29            tools: HashMap::new(),
30            task_registry: TaskRegistry::new(Arc::new(MemoryPersistence::new())),
31            env: ToolExecutionEnv::default(),
32        }
33    }
34
35    pub(crate) fn with_env(task_registry: TaskRegistry, env: ToolExecutionEnv) -> Self {
36        Self {
37            tools: HashMap::new(),
38            task_registry,
39            env,
40        }
41    }
42
43    pub fn builder() -> ToolRegistryBuilder {
44        ToolRegistryBuilder::new()
45    }
46
47    pub fn with_context(context: ExecutionContext) -> Self {
48        Self {
49            tools: HashMap::new(),
50            task_registry: TaskRegistry::new(Arc::new(MemoryPersistence::new())),
51            env: ToolExecutionEnv::new(context),
52        }
53    }
54
55    pub fn default_tools(
56        access: ToolAccess,
57        working_dir: Option<PathBuf>,
58        policy: Option<PermissionPolicy>,
59    ) -> Self {
60        let mut builder = ToolRegistryBuilder::new().access(access);
61        if let Some(dir) = working_dir {
62            builder = builder.working_dir(dir);
63        }
64        if let Some(p) = policy {
65            builder = builder.policy(p);
66        }
67        builder.build()
68    }
69
70    #[inline]
71    pub fn context(&self) -> &ExecutionContext {
72        &self.env.context
73    }
74
75    #[inline]
76    pub fn tool_state(&self) -> Option<&crate::session::session_state::ToolState> {
77        self.env.tool_state.as_ref()
78    }
79
80    #[inline]
81    pub fn process_manager(&self) -> Option<&Arc<ProcessManager>> {
82        self.env.process_manager.as_ref()
83    }
84
85    #[inline]
86    pub fn env(&self) -> &ToolExecutionEnv {
87        &self.env
88    }
89
90    #[inline]
91    pub fn task_registry(&self) -> &TaskRegistry {
92        &self.task_registry
93    }
94
95    pub fn register(&mut self, tool: Arc<dyn Tool>) {
96        self.tools.insert(tool.name().to_string(), tool);
97    }
98
99    #[inline]
100    pub fn get(&self, name: &str) -> Option<&Arc<dyn Tool>> {
101        self.tools.get(name)
102    }
103
104    pub async fn execute(&self, name: &str, input: serde_json::Value) -> ToolResult {
105        let tool = match self.tools.get(name) {
106            Some(t) => t,
107            None => return ToolResult::unknown_tool(name),
108        };
109
110        let decision = self.env.context.check_permission(name, &input);
111        if !decision.is_allowed() {
112            return ToolResult::permission_denied(name, decision.reason);
113        }
114
115        if let Err(e) = self.env.context.validate_security(name, &input) {
116            return ToolResult::security_error(e);
117        }
118
119        let limits = self.env.context.limits_for(name);
120        let timeout_ms = limits.timeout_ms.unwrap_or(120_000);
121
122        let result = tokio::time::timeout(
123            Duration::from_millis(timeout_ms),
124            tool.execute(input, &self.env.context),
125        )
126        .await;
127
128        match result {
129            Ok(tool_result) => self.apply_output_limits(tool_result, &limits),
130            Err(_) => ToolResult::timeout(timeout_ms),
131        }
132    }
133
134    fn apply_output_limits(
135        &self,
136        mut result: ToolResult,
137        limits: &crate::permissions::ToolLimits,
138    ) -> ToolResult {
139        if let Some(max_size) = limits.max_output_size
140            && let ToolOutput::Success(ref content) = result.output
141            && content.len() > max_size
142        {
143            let truncated = format!(
144                "{}...\n(output truncated at {} bytes)",
145                &content[..max_size],
146                max_size
147            );
148            result.output = ToolOutput::Success(truncated);
149        }
150        result
151    }
152
153    pub fn definitions(&self) -> Vec<ToolDefinition> {
154        self.tools.values().map(|t| t.definition()).collect()
155    }
156
157    pub fn names(&self) -> Vec<&str> {
158        self.tools.keys().map(|s| s.as_str()).collect()
159    }
160
161    pub fn contains(&self, name: &str) -> bool {
162        self.tools.contains_key(name)
163    }
164
165    pub fn register_dynamic(&mut self, tool: Arc<dyn Tool>) -> crate::Result<()> {
166        let name = tool.name().to_string();
167        if self.tools.contains_key(&name) {
168            return Err(crate::Error::Config(format!(
169                "Tool already registered: {}",
170                name
171            )));
172        }
173        self.tools.insert(name, tool);
174        Ok(())
175    }
176
177    pub fn register_or_replace(&mut self, tool: Arc<dyn Tool>) -> Option<Arc<dyn Tool>> {
178        let name = tool.name().to_string();
179        self.tools.insert(name, tool)
180    }
181
182    pub fn unregister(&mut self, name: &str) -> Option<Arc<dyn Tool>> {
183        self.tools.remove(name)
184    }
185}
186
187impl Default for ToolRegistry {
188    fn default() -> Self {
189        Self::new()
190    }
191}
192
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tools::access::ToolAccess;

    #[test]
    fn test_tool_output() {
        assert!(!ToolOutput::success("ok").is_error());
        assert!(ToolOutput::error("fail").is_error());
        assert!(!ToolOutput::empty().is_error());
    }

    /// With unrestricted access, every built-in tool is present.
    #[test]
    fn test_default_tools_registered() {
        let registry = ToolRegistry::default_tools(ToolAccess::All, None, None);
        for name in [
            "Read",
            "Write",
            "Edit",
            "Glob",
            "Grep",
            "Bash",
            "KillShell",
            "Task",
            "TaskOutput",
            "TodoWrite",
            "Plan",
            "Skill",
        ] {
            assert!(
                registry.contains(name),
                "default registry is missing tool `{name}`"
            );
        }
    }

    #[test]
    fn test_tool_access_filtering() {
        let registry = ToolRegistry::default_tools(ToolAccess::only(["Read", "Write"]), None, None);
        assert!(registry.contains("Read"));
        assert!(registry.contains("Write"));
        assert!(!registry.contains("Bash"));
    }

    /// A second dynamic registration under the same name is rejected as a config error.
    #[test]
    fn test_register_dynamic() {
        let mut registry = ToolRegistry::new();
        let tool: Arc<dyn Tool> = Arc::new(crate::tools::ReadTool);

        assert!(registry.register_dynamic(tool.clone()).is_ok());
        assert!(registry.contains("Read"));

        let result = registry.register_dynamic(tool);
        assert!(matches!(result, Err(crate::Error::Config(_))));
        assert!(registry.contains("Read"));
    }

    #[test]
    fn test_register_or_replace() {
        let mut registry = ToolRegistry::new();
        let tool1: Arc<dyn Tool> = Arc::new(crate::tools::ReadTool);
        let tool2: Arc<dyn Tool> = Arc::new(crate::tools::ReadTool);

        let old = registry.register_or_replace(tool1);
        assert!(old.is_none());

        let old = registry.register_or_replace(tool2);
        assert!(old.is_some());
    }

    #[test]
    fn test_unregister() {
        let mut registry = ToolRegistry::new();
        let tool: Arc<dyn Tool> = Arc::new(crate::tools::ReadTool);

        registry.register(tool);
        assert!(registry.contains("Read"));

        let removed = registry.unregister("Read");
        assert!(removed.is_some());
        assert!(!registry.contains("Read"));

        let removed = registry.unregister("NonExistent");
        assert!(removed.is_none());
    }
}