1use super::types::{Tool, ToolContext, ToolOutput};
7use super::ToolResult;
8use crate::llm::ToolDefinition;
9use anyhow::Result;
10use std::collections::HashMap;
11use std::path::PathBuf;
12use std::sync::{Arc, RwLock};
13
14pub struct ToolRegistry {
16 tools: RwLock<HashMap<String, Arc<dyn Tool>>>,
17 builtins: RwLock<std::collections::HashSet<String>>,
19 context: RwLock<ToolContext>,
20}
21
22impl ToolRegistry {
23 pub fn new(workspace: PathBuf) -> Self {
25 Self {
26 tools: RwLock::new(HashMap::new()),
27 builtins: RwLock::new(std::collections::HashSet::new()),
28 context: RwLock::new(ToolContext::new(workspace)),
29 }
30 }
31
32 pub fn register_builtin(&self, tool: Arc<dyn Tool>) {
34 let name = tool.name().to_string();
35 let mut tools = self.tools.write().unwrap();
36 let mut builtins = self.builtins.write().unwrap();
37 tracing::debug!("Registering builtin tool: {}", name);
38 tools.insert(name.clone(), tool);
39 builtins.insert(name);
40 }
41
42 pub fn register(&self, tool: Arc<dyn Tool>) {
47 let name = tool.name().to_string();
48 let builtins = self.builtins.read().unwrap();
49 if builtins.contains(&name) {
50 tracing::warn!(
51 "Rejected registration of tool '{}': cannot shadow builtin",
52 name
53 );
54 return;
55 }
56 drop(builtins);
57 let mut tools = self.tools.write().unwrap();
58 tracing::debug!("Registering tool: {}", name);
59 tools.insert(name, tool);
60 }
61
62 pub fn unregister(&self, name: &str) -> bool {
66 let mut tools = self.tools.write().unwrap();
67 tracing::debug!("Unregistering tool: {}", name);
68 tools.remove(name).is_some()
69 }
70
71 pub fn unregister_by_prefix(&self, prefix: &str) {
73 let mut tools = self.tools.write().unwrap();
74 tools.retain(|name, _| !name.starts_with(prefix));
75 tracing::debug!("Unregistered tools with prefix: {}", prefix);
76 }
77
78 pub fn get(&self, name: &str) -> Option<Arc<dyn Tool>> {
80 let tools = self.tools.read().unwrap();
81 tools.get(name).cloned()
82 }
83
84 pub fn contains(&self, name: &str) -> bool {
86 let tools = self.tools.read().unwrap();
87 tools.contains_key(name)
88 }
89
90 pub fn definitions(&self) -> Vec<ToolDefinition> {
92 let tools = self.tools.read().unwrap();
93 tools
94 .values()
95 .map(|tool| ToolDefinition {
96 name: tool.name().to_string(),
97 description: tool.description().to_string(),
98 parameters: tool.parameters(),
99 })
100 .collect()
101 }
102
103 pub fn list(&self) -> Vec<String> {
105 let tools = self.tools.read().unwrap();
106 tools.keys().cloned().collect()
107 }
108
109 pub fn len(&self) -> usize {
111 let tools = self.tools.read().unwrap();
112 tools.len()
113 }
114
115 pub fn is_empty(&self) -> bool {
117 self.len() == 0
118 }
119
120 pub fn context(&self) -> ToolContext {
122 self.context.read().unwrap().clone()
123 }
124
125 pub fn set_search_config(&self, config: crate::config::SearchConfig) {
127 let mut ctx = self.context.write().unwrap();
128 *ctx = ctx.clone().with_search_config(config);
129 }
130
131 pub fn set_sandbox(&self, sandbox: std::sync::Arc<dyn crate::sandbox::BashSandbox>) {
134 let mut ctx = self.context.write().unwrap();
135 *ctx = ctx.clone().with_sandbox(sandbox);
136 }
137
138 pub async fn execute(&self, name: &str, args: &serde_json::Value) -> Result<ToolResult> {
140 let ctx = self.context();
141 self.execute_with_context(name, args, &ctx).await
142 }
143
144 pub async fn execute_with_context(
146 &self,
147 name: &str,
148 args: &serde_json::Value,
149 ctx: &ToolContext,
150 ) -> Result<ToolResult> {
151 let start = std::time::Instant::now();
152
153 let tool = self.get(name);
154
155 let result = match tool {
156 Some(tool) => {
157 let output = tool.execute(args, ctx).await?;
158 Ok(ToolResult {
159 name: name.to_string(),
160 output: output.content,
161 exit_code: if output.success { 0 } else { 1 },
162 metadata: output.metadata,
163 images: output.images,
164 })
165 }
166 None => Ok(ToolResult::error(name, format!("Unknown tool: {}", name))),
167 };
168
169 if let Ok(ref r) = result {
170 crate::telemetry::record_tool_result(r.exit_code, start.elapsed());
171 }
172
173 result
174 }
175
176 pub async fn execute_raw(
178 &self,
179 name: &str,
180 args: &serde_json::Value,
181 ) -> Result<Option<ToolOutput>> {
182 let ctx = self.context();
183 self.execute_raw_with_context(name, args, &ctx).await
184 }
185
186 pub async fn execute_raw_with_context(
188 &self,
189 name: &str,
190 args: &serde_json::Value,
191 ctx: &ToolContext,
192 ) -> Result<Option<ToolOutput>> {
193 let tool = self.get(name);
194
195 match tool {
196 Some(tool) => {
197 let output = tool.execute(args, ctx).await?;
198 Ok(Some(output))
199 }
200 None => Ok(None),
201 }
202 }
203}
204
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;

    /// Tool with a configurable name that always succeeds with "mock output".
    struct MockTool {
        name: String,
    }

    #[async_trait]
    impl Tool for MockTool {
        fn name(&self) -> &str {
            &self.name
        }

        fn description(&self) -> &str {
            "A mock tool for testing"
        }

        fn parameters(&self) -> serde_json::Value {
            serde_json::json!({
                "type": "object",
                "properties": {},
                "required": []
            })
        }

        async fn execute(
            &self,
            _args: &serde_json::Value,
            _ctx: &ToolContext,
        ) -> Result<ToolOutput> {
            Ok(ToolOutput::success("mock output"))
        }
    }

    #[test]
    fn test_registry_register_and_get() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));

        let tool = Arc::new(MockTool {
            name: "test".to_string(),
        });
        registry.register(tool);

        assert!(registry.contains("test"));
        assert!(!registry.contains("nonexistent"));

        let retrieved = registry.get("test");
        assert!(retrieved.is_some());
        assert_eq!(retrieved.unwrap().name(), "test");
    }

    #[test]
    fn test_registry_unregister() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));

        let tool = Arc::new(MockTool {
            name: "test".to_string(),
        });
        registry.register(tool);

        assert!(registry.contains("test"));
        assert!(registry.unregister("test"));
        assert!(!registry.contains("test"));
        assert!(!registry.unregister("test"));
    }

    #[test]
    fn test_registry_unregister_by_prefix() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        for name in ["mcp_a", "mcp_b", "other"] {
            registry.register(Arc::new(MockTool {
                name: name.to_string(),
            }));
        }

        registry.unregister_by_prefix("mcp_");

        assert_eq!(registry.len(), 1);
        assert_eq!(registry.list(), vec!["other".to_string()]);
    }

    #[test]
    fn test_registry_definitions() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));

        registry.register(Arc::new(MockTool {
            name: "tool1".to_string(),
        }));
        registry.register(Arc::new(MockTool {
            name: "tool2".to_string(),
        }));

        let definitions = registry.definitions();
        assert_eq!(definitions.len(), 2);
    }

    #[tokio::test]
    async fn test_registry_execute() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));

        registry.register(Arc::new(MockTool {
            name: "test".to_string(),
        }));

        let result = registry
            .execute("test", &serde_json::json!({}))
            .await
            .unwrap();
        assert_eq!(result.exit_code, 0);
        assert_eq!(result.output, "mock output");
    }

    #[tokio::test]
    async fn test_registry_execute_unknown() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));

        let result = registry
            .execute("unknown", &serde_json::json!({}))
            .await
            .unwrap();
        assert_eq!(result.exit_code, 1);
        assert!(result.output.contains("Unknown tool"));
    }

    #[tokio::test]
    async fn test_registry_execute_with_context_success() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        let ctx = ToolContext::new(PathBuf::from("/tmp"));

        registry.register(Arc::new(MockTool {
            name: "my_tool".to_string(),
        }));

        let result = registry
            .execute_with_context("my_tool", &serde_json::json!({}), &ctx)
            .await
            .unwrap();
        assert_eq!(result.name, "my_tool");
        assert_eq!(result.exit_code, 0);
        assert_eq!(result.output, "mock output");
    }

    #[tokio::test]
    async fn test_registry_execute_with_context_unknown_tool() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        let ctx = ToolContext::new(PathBuf::from("/tmp"));

        let result = registry
            .execute_with_context("nonexistent", &serde_json::json!({}), &ctx)
            .await
            .unwrap();
        assert_eq!(result.exit_code, 1);
        assert!(result.output.contains("Unknown tool: nonexistent"));
    }

    /// Tool named "failing" that always reports a (non-`Err`) failure.
    struct FailingTool;

    #[async_trait]
    impl Tool for FailingTool {
        fn name(&self) -> &str {
            "failing"
        }

        fn description(&self) -> &str {
            "A tool that returns failure"
        }

        fn parameters(&self) -> serde_json::Value {
            serde_json::json!({"type": "object", "properties": {}, "required": []})
        }

        async fn execute(
            &self,
            _args: &serde_json::Value,
            _ctx: &ToolContext,
        ) -> Result<ToolOutput> {
            Ok(ToolOutput::error("something went wrong"))
        }
    }

    #[tokio::test]
    async fn test_registry_execute_failing_tool() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        registry.register(Arc::new(FailingTool));

        let result = registry
            .execute("failing", &serde_json::json!({}))
            .await
            .unwrap();
        assert_eq!(result.exit_code, 1);
        assert_eq!(result.output, "something went wrong");
    }

    #[tokio::test]
    async fn test_register_cannot_shadow_builtin() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        registry.register_builtin(Arc::new(MockTool {
            name: "failing".to_string(),
        }));
        // FailingTool reuses the builtin's name, so this registration must be rejected.
        registry.register(Arc::new(FailingTool));

        assert_eq!(registry.len(), 1);
        let result = registry
            .execute("failing", &serde_json::json!({}))
            .await
            .unwrap();
        assert_eq!(result.exit_code, 0);
        assert_eq!(result.output, "mock output");
    }

    #[tokio::test]
    async fn test_register_builtin_replaces_existing_tool() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        registry.register(Arc::new(FailingTool));
        registry.register_builtin(Arc::new(MockTool {
            name: "failing".to_string(),
        }));

        assert_eq!(registry.len(), 1);
        let result = registry
            .execute("failing", &serde_json::json!({}))
            .await
            .unwrap();
        assert_eq!(result.exit_code, 0);
    }

    #[tokio::test]
    async fn test_registry_execute_raw_success() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        registry.register(Arc::new(MockTool {
            name: "raw_test".to_string(),
        }));

        let output = registry
            .execute_raw("raw_test", &serde_json::json!({}))
            .await
            .unwrap();
        assert!(output.is_some());
        let output = output.unwrap();
        assert!(output.success);
        assert_eq!(output.content, "mock output");
    }

    #[tokio::test]
    async fn test_registry_execute_raw_unknown() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));

        let output = registry
            .execute_raw("missing", &serde_json::json!({}))
            .await
            .unwrap();
        assert!(output.is_none());
    }

    #[test]
    fn test_registry_list() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        registry.register(Arc::new(MockTool {
            name: "alpha".to_string(),
        }));
        registry.register(Arc::new(MockTool {
            name: "beta".to_string(),
        }));

        let names = registry.list();
        assert_eq!(names.len(), 2);
        assert!(names.contains(&"alpha".to_string()));
        assert!(names.contains(&"beta".to_string()));
    }

    #[test]
    fn test_registry_len_and_is_empty() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        assert!(registry.is_empty());
        assert_eq!(registry.len(), 0);

        registry.register(Arc::new(MockTool {
            name: "t".to_string(),
        }));
        assert!(!registry.is_empty());
        assert_eq!(registry.len(), 1);
    }

    #[test]
    fn test_registry_replace_tool() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        registry.register(Arc::new(MockTool {
            name: "dup".to_string(),
        }));
        registry.register(Arc::new(MockTool {
            name: "dup".to_string(),
        }));
        assert_eq!(registry.len(), 1);
    }

    #[tokio::test]
    async fn test_registry_later_registration_wins() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        registry.register(Arc::new(MockTool {
            name: "failing".to_string(),
        }));
        registry.register(Arc::new(FailingTool));

        let result = registry
            .execute("failing", &serde_json::json!({}))
            .await
            .unwrap();
        assert_eq!(result.exit_code, 1);
        assert_eq!(result.output, "something went wrong");
    }
}