claude_agent/tools/
registry.rs1use std::collections::HashMap;
4use std::sync::Arc;
5use std::time::Duration;
6
7use super::ProcessManager;
8use super::access::ToolAccess;
9use super::builder::ToolRegistryBuilder;
10use super::context::ExecutionContext;
11use super::env::ToolExecutionEnv;
12use super::traits::Tool;
13use crate::agent::TaskRegistry;
14use crate::permissions::PermissionPolicy;
15use crate::session::MemoryPersistence;
16use crate::types::{ToolDefinition, ToolOutput, ToolResult};
17use std::path::PathBuf;
18
19#[derive(Clone)]
20pub struct ToolRegistry {
21 tools: HashMap<String, Arc<dyn Tool>>,
22 task_registry: TaskRegistry,
23 env: ToolExecutionEnv,
24}
25
26impl ToolRegistry {
27 pub fn new() -> Self {
28 Self {
29 tools: HashMap::new(),
30 task_registry: TaskRegistry::new(Arc::new(MemoryPersistence::new())),
31 env: ToolExecutionEnv::default(),
32 }
33 }
34
35 pub(crate) fn with_env(task_registry: TaskRegistry, env: ToolExecutionEnv) -> Self {
36 Self {
37 tools: HashMap::new(),
38 task_registry,
39 env,
40 }
41 }
42
43 pub fn builder() -> ToolRegistryBuilder {
44 ToolRegistryBuilder::new()
45 }
46
47 pub fn with_context(context: ExecutionContext) -> Self {
48 Self {
49 tools: HashMap::new(),
50 task_registry: TaskRegistry::new(Arc::new(MemoryPersistence::new())),
51 env: ToolExecutionEnv::new(context),
52 }
53 }
54
55 pub fn default_tools(
56 access: ToolAccess,
57 working_dir: Option<PathBuf>,
58 policy: Option<PermissionPolicy>,
59 ) -> Self {
60 let mut builder = ToolRegistryBuilder::new().access(access);
61 if let Some(dir) = working_dir {
62 builder = builder.working_dir(dir);
63 }
64 if let Some(p) = policy {
65 builder = builder.policy(p);
66 }
67 builder.build()
68 }
69
70 #[inline]
71 pub fn context(&self) -> &ExecutionContext {
72 &self.env.context
73 }
74
75 #[inline]
76 pub fn tool_state(&self) -> Option<&crate::session::session_state::ToolState> {
77 self.env.tool_state.as_ref()
78 }
79
80 #[inline]
81 pub fn process_manager(&self) -> Option<&Arc<ProcessManager>> {
82 self.env.process_manager.as_ref()
83 }
84
85 #[inline]
86 pub fn env(&self) -> &ToolExecutionEnv {
87 &self.env
88 }
89
90 #[inline]
91 pub fn task_registry(&self) -> &TaskRegistry {
92 &self.task_registry
93 }
94
95 pub fn register(&mut self, tool: Arc<dyn Tool>) {
96 self.tools.insert(tool.name().to_string(), tool);
97 }
98
99 #[inline]
100 pub fn get(&self, name: &str) -> Option<&Arc<dyn Tool>> {
101 self.tools.get(name)
102 }
103
104 pub async fn execute(&self, name: &str, input: serde_json::Value) -> ToolResult {
105 let tool = match self.tools.get(name) {
106 Some(t) => t,
107 None => return ToolResult::unknown_tool(name),
108 };
109
110 let decision = self.env.context.check_permission(name, &input);
111 if !decision.is_allowed() {
112 return ToolResult::permission_denied(name, decision.reason);
113 }
114
115 if let Err(e) = self.env.context.validate_security(name, &input) {
116 return ToolResult::security_error(e);
117 }
118
119 let limits = self.env.context.limits_for(name);
120 let timeout_ms = limits.timeout_ms.unwrap_or(120_000);
121
122 let result = tokio::time::timeout(
123 Duration::from_millis(timeout_ms),
124 tool.execute(input, &self.env.context),
125 )
126 .await;
127
128 match result {
129 Ok(tool_result) => self.apply_output_limits(tool_result, &limits),
130 Err(_) => ToolResult::timeout(timeout_ms),
131 }
132 }
133
134 fn apply_output_limits(
135 &self,
136 mut result: ToolResult,
137 limits: &crate::permissions::ToolLimits,
138 ) -> ToolResult {
139 if let Some(max_size) = limits.max_output_size
140 && let ToolOutput::Success(ref content) = result.output
141 && content.len() > max_size
142 {
143 let truncated = format!(
144 "{}...\n(output truncated at {} bytes)",
145 &content[..max_size],
146 max_size
147 );
148 result.output = ToolOutput::Success(truncated);
149 }
150 result
151 }
152
153 pub fn definitions(&self) -> Vec<ToolDefinition> {
154 self.tools.values().map(|t| t.definition()).collect()
155 }
156
157 pub fn names(&self) -> Vec<&str> {
158 self.tools.keys().map(|s| s.as_str()).collect()
159 }
160
161 pub fn contains(&self, name: &str) -> bool {
162 self.tools.contains_key(name)
163 }
164
165 pub fn register_dynamic(&mut self, tool: Arc<dyn Tool>) -> crate::Result<()> {
166 let name = tool.name().to_string();
167 if self.tools.contains_key(&name) {
168 return Err(crate::Error::Config(format!(
169 "Tool already registered: {}",
170 name
171 )));
172 }
173 self.tools.insert(name, tool);
174 Ok(())
175 }
176
177 pub fn register_or_replace(&mut self, tool: Arc<dyn Tool>) -> Option<Arc<dyn Tool>> {
178 let name = tool.name().to_string();
179 self.tools.insert(name, tool)
180 }
181
182 pub fn unregister(&mut self, name: &str) -> Option<Arc<dyn Tool>> {
183 self.tools.remove(name)
184 }
185}
186
187impl Default for ToolRegistry {
188 fn default() -> Self {
189 Self::new()
190 }
191}
192
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tools::access::ToolAccess;

    /// Builds a `ReadTool` trait object; the tests below rely on it being
    /// registered under the name "Read".
    fn read_tool() -> Arc<dyn Tool> {
        Arc::new(crate::tools::ReadTool)
    }

    /// Success and empty outputs are not errors; error outputs are.
    #[test]
    fn test_tool_output() {
        let success = ToolOutput::success("ok");
        let failure = ToolOutput::error("fail");
        let empty = ToolOutput::empty();

        assert!(!success.is_error());
        assert!(failure.is_error());
        assert!(!empty.is_error());
    }

    /// With unrestricted access, every expected built-in tool is registered.
    #[test]
    fn test_default_tools_count() {
        let registry = ToolRegistry::default_tools(ToolAccess::All, None, None);
        let expected = [
            "Read",
            "Write",
            "Edit",
            "Glob",
            "Grep",
            "Bash",
            "KillShell",
            "Task",
            "TaskOutput",
            "TodoWrite",
            "Plan",
            "Skill",
        ];
        for tool_name in expected {
            assert!(registry.contains(tool_name), "missing tool: {tool_name}");
        }
    }

    /// An allow-list keeps the listed tools and leaves the others out.
    #[test]
    fn test_tool_access_filtering() {
        let access = ToolAccess::only(["Read", "Write"]);
        let registry = ToolRegistry::default_tools(access, None, None);

        assert!(registry.contains("Read"));
        assert!(registry.contains("Write"));
        assert!(!registry.contains("Bash"));
    }

    /// `register_dynamic` accepts a new name once and rejects a duplicate.
    #[test]
    fn test_register_dynamic() {
        let mut registry = ToolRegistry::new();
        let tool = read_tool();

        let first_attempt = registry.register_dynamic(Arc::clone(&tool));
        assert!(first_attempt.is_ok());
        assert!(registry.contains("Read"));

        let second_attempt = registry.register_dynamic(tool);
        assert!(second_attempt.is_err());
    }

    /// `register_or_replace` returns the previous tool only when one existed.
    #[test]
    fn test_register_or_replace() {
        let mut registry = ToolRegistry::new();

        let previous = registry.register_or_replace(read_tool());
        assert!(previous.is_none());

        let previous = registry.register_or_replace(read_tool());
        assert!(previous.is_some());
    }

    /// `unregister` removes a known tool and is a no-op for unknown names.
    #[test]
    fn test_unregister() {
        let mut registry = ToolRegistry::new();

        registry.register(read_tool());
        assert!(registry.contains("Read"));

        let removed = registry.unregister("Read");
        assert!(removed.is_some());
        assert!(!registry.contains("Read"));

        let removed = registry.unregister("NonExistent");
        assert!(removed.is_none());
    }
}