1use super::types::{Tool, ToolContext, ToolOutput};
7use super::ToolResult;
8use crate::llm::ToolDefinition;
9use anyhow::Result;
10use std::collections::HashMap;
11use std::path::PathBuf;
12use std::sync::{Arc, RwLock};
13
/// Registry of tools addressable by name, together with the context used to run them.
///
/// Every field sits behind its own `RwLock`, so the registry is mutated through `&self`.
pub struct ToolRegistry {
    /// Registered tools, keyed by the name each tool reported when it was registered.
    tools: RwLock<HashMap<String, Arc<dyn Tool>>>,
    /// Names added through `register_builtin`; `register` refuses to insert under these names.
    builtins: RwLock<std::collections::HashSet<String>>,
    /// Context that is cloned for `execute` / `execute_raw` calls that do not pass their own.
    context: RwLock<ToolContext>,
}
21
22impl ToolRegistry {
23 pub fn new(workspace: PathBuf) -> Self {
25 Self {
26 tools: RwLock::new(HashMap::new()),
27 builtins: RwLock::new(std::collections::HashSet::new()),
28 context: RwLock::new(ToolContext::new(workspace)),
29 }
30 }
31
32 pub fn register_builtin(&self, tool: Arc<dyn Tool>) {
34 let name = tool.name().to_string();
35 let mut tools = self.tools.write().unwrap();
36 let mut builtins = self.builtins.write().unwrap();
37 tracing::debug!("Registering builtin tool: {}", name);
38 tools.insert(name.clone(), tool);
39 builtins.insert(name);
40 }
41
42 pub fn register(&self, tool: Arc<dyn Tool>) {
47 let name = tool.name().to_string();
48 let builtins = self.builtins.read().unwrap();
49 if builtins.contains(&name) {
50 tracing::warn!(
51 "Rejected registration of tool '{}': cannot shadow builtin",
52 name
53 );
54 return;
55 }
56 drop(builtins);
57 let mut tools = self.tools.write().unwrap();
58 tracing::debug!("Registering tool: {}", name);
59 tools.insert(name, tool);
60 }
61
62 pub fn unregister(&self, name: &str) -> bool {
66 let mut tools = self.tools.write().unwrap();
67 tracing::debug!("Unregistering tool: {}", name);
68 tools.remove(name).is_some()
69 }
70
71 pub fn unregister_by_prefix(&self, prefix: &str) {
73 let mut tools = self.tools.write().unwrap();
74 tools.retain(|name, _| !name.starts_with(prefix));
75 tracing::debug!("Unregistered tools with prefix: {}", prefix);
76 }
77
78 pub fn get(&self, name: &str) -> Option<Arc<dyn Tool>> {
80 let tools = self.tools.read().unwrap();
81 tools.get(name).cloned()
82 }
83
84 pub fn contains(&self, name: &str) -> bool {
86 let tools = self.tools.read().unwrap();
87 tools.contains_key(name)
88 }
89
90 pub fn definitions(&self) -> Vec<ToolDefinition> {
92 let tools = self.tools.read().unwrap();
93 tools
94 .values()
95 .map(|tool| ToolDefinition {
96 name: tool.name().to_string(),
97 description: tool.description().to_string(),
98 parameters: tool.parameters(),
99 })
100 .collect()
101 }
102
103 pub fn list(&self) -> Vec<String> {
105 let tools = self.tools.read().unwrap();
106 tools.keys().cloned().collect()
107 }
108
109 pub fn len(&self) -> usize {
111 let tools = self.tools.read().unwrap();
112 tools.len()
113 }
114
115 pub fn is_empty(&self) -> bool {
117 self.len() == 0
118 }
119
120 pub fn context(&self) -> ToolContext {
122 self.context.read().unwrap().clone()
123 }
124
125 pub fn set_search_config(&self, config: crate::config::SearchConfig) {
127 let mut ctx = self.context.write().unwrap();
128 *ctx = ctx.clone().with_search_config(config);
129 }
130
131 pub fn set_sandbox(&self, sandbox: std::sync::Arc<dyn crate::sandbox::BashSandbox>) {
134 let mut ctx = self.context.write().unwrap();
135 *ctx = ctx.clone().with_sandbox(sandbox);
136 }
137
138 pub async fn execute(&self, name: &str, args: &serde_json::Value) -> Result<ToolResult> {
140 let ctx = self.context();
141 self.execute_with_context(name, args, &ctx).await
142 }
143
144 pub async fn execute_with_context(
146 &self,
147 name: &str,
148 args: &serde_json::Value,
149 ctx: &ToolContext,
150 ) -> Result<ToolResult> {
151 let start = std::time::Instant::now();
152
153 let tool = self.get(name);
154
155 let result = match tool {
156 Some(tool) => {
157 let output = tool.execute(args, ctx).await?;
158 Ok(ToolResult {
159 name: name.to_string(),
160 output: output.content,
161 exit_code: if output.success { 0 } else { 1 },
162 metadata: output.metadata,
163 images: output.images,
164 })
165 }
166 None => Ok(ToolResult::error(name, format!("Unknown tool: {}", name))),
167 };
168
169 if let Ok(ref r) = result {
170 crate::telemetry::record_tool_result(r.exit_code, start.elapsed());
171 }
172
173 result
174 }
175
176 pub async fn execute_raw(
178 &self,
179 name: &str,
180 args: &serde_json::Value,
181 ) -> Result<Option<ToolOutput>> {
182 let ctx = self.context();
183 self.execute_raw_with_context(name, args, &ctx).await
184 }
185
186 pub async fn execute_raw_with_context(
188 &self,
189 name: &str,
190 args: &serde_json::Value,
191 ctx: &ToolContext,
192 ) -> Result<Option<ToolOutput>> {
193 let tool = self.get(name);
194
195 match tool {
196 Some(tool) => {
197 let output = tool.execute(args, ctx).await?;
198 Ok(Some(output))
199 }
200 None => Ok(None),
201 }
202 }
203}
204
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;

    /// Tool with a configurable name that always succeeds with `"mock output"`.
    struct MockTool {
        name: String,
    }

    #[async_trait]
    impl Tool for MockTool {
        fn name(&self) -> &str {
            self.name.as_str()
        }

        fn description(&self) -> &str {
            "A mock tool for testing"
        }

        fn parameters(&self) -> serde_json::Value {
            serde_json::json!({
                "type": "object",
                "additionalProperties": false,
                "properties": {},
                "required": []
            })
        }

        async fn execute(
            &self,
            _args: &serde_json::Value,
            _ctx: &ToolContext,
        ) -> Result<ToolOutput> {
            Ok(ToolOutput::success("mock output"))
        }
    }

    /// Tool named `"failing"` whose execution reports an unsuccessful output.
    struct FailingTool;

    #[async_trait]
    impl Tool for FailingTool {
        fn name(&self) -> &str {
            "failing"
        }

        fn description(&self) -> &str {
            "A tool that returns failure"
        }

        fn parameters(&self) -> serde_json::Value {
            serde_json::json!({
                "type": "object",
                "additionalProperties": false,
                "properties": {},
                "required": []
            })
        }

        async fn execute(
            &self,
            _args: &serde_json::Value,
            _ctx: &ToolContext,
        ) -> Result<ToolOutput> {
            Ok(ToolOutput::error("something went wrong"))
        }
    }

    /// Fresh, empty registry rooted at `/tmp`.
    fn empty_registry() -> ToolRegistry {
        ToolRegistry::new(PathBuf::from("/tmp"))
    }

    /// Shared handle to a [`MockTool`] with the given name.
    fn mock(name: &str) -> Arc<MockTool> {
        Arc::new(MockTool {
            name: name.to_string(),
        })
    }

    /// Empty JSON object used as the argument payload in every execution test.
    fn no_args() -> serde_json::Value {
        serde_json::json!({})
    }

    #[test]
    fn test_registry_register_and_get() {
        let reg = empty_registry();
        reg.register(mock("test"));

        assert!(reg.contains("test"));
        assert!(!reg.contains("nonexistent"));

        let found = reg
            .get("test")
            .expect("a registered tool should be retrievable");
        assert_eq!(found.name(), "test");
    }

    #[test]
    fn test_registry_unregister() {
        let reg = empty_registry();
        reg.register(mock("test"));
        assert!(reg.contains("test"));

        // First removal succeeds and the tool disappears.
        assert!(reg.unregister("test"));
        assert!(!reg.contains("test"));

        // Second removal has nothing left to remove.
        assert!(!reg.unregister("test"));
    }

    #[test]
    fn test_registry_definitions() {
        let reg = empty_registry();
        for tool_name in ["tool1", "tool2"] {
            reg.register(mock(tool_name));
        }

        assert_eq!(reg.definitions().len(), 2);
    }

    #[tokio::test]
    async fn test_registry_execute() {
        let reg = empty_registry();
        reg.register(mock("test"));

        let outcome = reg.execute("test", &no_args()).await.unwrap();

        assert_eq!(outcome.exit_code, 0);
        assert_eq!(outcome.output, "mock output");
    }

    #[tokio::test]
    async fn test_registry_execute_unknown() {
        let reg = empty_registry();

        let outcome = reg.execute("unknown", &no_args()).await.unwrap();

        assert_eq!(outcome.exit_code, 1);
        assert!(outcome.output.contains("Unknown tool"));
    }

    #[tokio::test]
    async fn test_registry_execute_with_context_success() {
        let reg = empty_registry();
        let custom_ctx = ToolContext::new(PathBuf::from("/tmp"));
        reg.register(mock("my_tool"));

        let outcome = reg
            .execute_with_context("my_tool", &no_args(), &custom_ctx)
            .await
            .unwrap();

        assert_eq!(outcome.name, "my_tool");
        assert_eq!(outcome.exit_code, 0);
        assert_eq!(outcome.output, "mock output");
    }

    #[tokio::test]
    async fn test_registry_execute_with_context_unknown_tool() {
        let reg = empty_registry();
        let custom_ctx = ToolContext::new(PathBuf::from("/tmp"));

        let outcome = reg
            .execute_with_context("nonexistent", &no_args(), &custom_ctx)
            .await
            .unwrap();

        assert_eq!(outcome.exit_code, 1);
        assert!(outcome.output.contains("Unknown tool: nonexistent"));
    }

    #[tokio::test]
    async fn test_registry_execute_failing_tool() {
        let reg = empty_registry();
        reg.register(Arc::new(FailingTool));

        let outcome = reg.execute("failing", &no_args()).await.unwrap();

        assert_eq!(outcome.exit_code, 1);
        assert_eq!(outcome.output, "something went wrong");
    }

    #[tokio::test]
    async fn test_registry_execute_raw_success() {
        let reg = empty_registry();
        reg.register(mock("raw_test"));

        let maybe_output = reg.execute_raw("raw_test", &no_args()).await.unwrap();

        assert!(maybe_output.is_some());
        let raw = maybe_output.unwrap();
        assert!(raw.success);
        assert_eq!(raw.content, "mock output");
    }

    #[tokio::test]
    async fn test_registry_execute_raw_unknown() {
        let reg = empty_registry();

        let maybe_output = reg.execute_raw("missing", &no_args()).await.unwrap();

        assert!(maybe_output.is_none());
    }

    #[test]
    fn test_registry_list() {
        let reg = empty_registry();
        reg.register(mock("alpha"));
        reg.register(mock("beta"));

        let listed = reg.list();

        assert_eq!(listed.len(), 2);
        for expected in ["alpha", "beta"] {
            assert!(listed.contains(&expected.to_string()));
        }
    }

    #[test]
    fn test_registry_len_and_is_empty() {
        let reg = empty_registry();
        assert!(reg.is_empty());
        assert_eq!(reg.len(), 0);

        reg.register(mock("t"));

        assert!(!reg.is_empty());
        assert_eq!(reg.len(), 1);
    }

    #[test]
    fn test_registry_replace_tool() {
        let reg = empty_registry();

        // Registering the same name twice must leave a single entry.
        reg.register(mock("dup"));
        reg.register(mock("dup"));

        assert_eq!(reg.len(), 1);
    }
}