1use super::types::{Tool, ToolContext, ToolOutput};
7use super::ToolResult;
8use crate::llm::ToolDefinition;
9use anyhow::Result;
10use std::collections::HashMap;
11use std::path::PathBuf;
12use std::sync::{Arc, RwLock};
13
14pub struct ToolRegistry {
16 tools: RwLock<HashMap<String, Arc<dyn Tool>>>,
17 builtins: RwLock<std::collections::HashSet<String>>,
19 context: RwLock<ToolContext>,
20}
21
22impl ToolRegistry {
23 pub fn new(workspace: PathBuf) -> Self {
25 Self {
26 tools: RwLock::new(HashMap::new()),
27 builtins: RwLock::new(std::collections::HashSet::new()),
28 context: RwLock::new(ToolContext::new(workspace)),
29 }
30 }
31
32 pub fn register_builtin(&self, tool: Arc<dyn Tool>) {
34 let name = tool.name().to_string();
35 let mut tools = self.tools.write().unwrap();
36 let mut builtins = self.builtins.write().unwrap();
37 tracing::debug!("Registering builtin tool: {}", name);
38 tools.insert(name.clone(), tool);
39 builtins.insert(name);
40 }
41
42 pub fn register(&self, tool: Arc<dyn Tool>) {
47 let name = tool.name().to_string();
48 let builtins = self.builtins.read().unwrap();
49 if builtins.contains(&name) {
50 tracing::warn!(
51 "Rejected registration of tool '{}': cannot shadow builtin",
52 name
53 );
54 return;
55 }
56 drop(builtins);
57 let mut tools = self.tools.write().unwrap();
58 tracing::debug!("Registering tool: {}", name);
59 tools.insert(name, tool);
60 }
61
62 pub fn unregister(&self, name: &str) -> bool {
66 let mut tools = self.tools.write().unwrap();
67 tracing::debug!("Unregistering tool: {}", name);
68 tools.remove(name).is_some()
69 }
70
71 pub fn get(&self, name: &str) -> Option<Arc<dyn Tool>> {
73 let tools = self.tools.read().unwrap();
74 tools.get(name).cloned()
75 }
76
77 pub fn contains(&self, name: &str) -> bool {
79 let tools = self.tools.read().unwrap();
80 tools.contains_key(name)
81 }
82
83 pub fn definitions(&self) -> Vec<ToolDefinition> {
85 let tools = self.tools.read().unwrap();
86 tools
87 .values()
88 .map(|tool| ToolDefinition {
89 name: tool.name().to_string(),
90 description: tool.description().to_string(),
91 parameters: tool.parameters(),
92 })
93 .collect()
94 }
95
96 pub fn list(&self) -> Vec<String> {
98 let tools = self.tools.read().unwrap();
99 tools.keys().cloned().collect()
100 }
101
102 pub fn len(&self) -> usize {
104 let tools = self.tools.read().unwrap();
105 tools.len()
106 }
107
108 pub fn is_empty(&self) -> bool {
110 self.len() == 0
111 }
112
113 pub fn context(&self) -> ToolContext {
115 self.context.read().unwrap().clone()
116 }
117
118 pub fn set_search_config(&self, config: crate::config::SearchConfig) {
120 let mut ctx = self.context.write().unwrap();
121 *ctx = ctx.clone().with_search_config(config);
122 }
123
124 pub fn set_sandbox(&self, sandbox: std::sync::Arc<dyn crate::sandbox::BashSandbox>) {
127 let mut ctx = self.context.write().unwrap();
128 *ctx = ctx.clone().with_sandbox(sandbox);
129 }
130
131 pub async fn execute(&self, name: &str, args: &serde_json::Value) -> Result<ToolResult> {
133 let ctx = self.context();
134 self.execute_with_context(name, args, &ctx).await
135 }
136
137 pub async fn execute_with_context(
139 &self,
140 name: &str,
141 args: &serde_json::Value,
142 ctx: &ToolContext,
143 ) -> Result<ToolResult> {
144 let start = std::time::Instant::now();
145
146 let tool = self.get(name);
147
148 let result = match tool {
149 Some(tool) => {
150 let output = tool.execute(args, ctx).await?;
151 Ok(ToolResult {
152 name: name.to_string(),
153 output: output.content,
154 exit_code: if output.success { 0 } else { 1 },
155 metadata: output.metadata,
156 images: output.images,
157 })
158 }
159 None => Ok(ToolResult::error(name, format!("Unknown tool: {}", name))),
160 };
161
162 if let Ok(ref r) = result {
163 crate::telemetry::record_tool_result(r.exit_code, start.elapsed());
164 }
165
166 result
167 }
168
169 pub async fn execute_raw(
171 &self,
172 name: &str,
173 args: &serde_json::Value,
174 ) -> Result<Option<ToolOutput>> {
175 let ctx = self.context();
176 self.execute_raw_with_context(name, args, &ctx).await
177 }
178
179 pub async fn execute_raw_with_context(
181 &self,
182 name: &str,
183 args: &serde_json::Value,
184 ctx: &ToolContext,
185 ) -> Result<Option<ToolOutput>> {
186 let tool = self.get(name);
187
188 match tool {
189 Some(tool) => {
190 let output = tool.execute(args, ctx).await?;
191 Ok(Some(output))
192 }
193 None => Ok(None),
194 }
195 }
196}
197
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;

    /// Tool that always succeeds with `"mock output"`; its name is configurable.
    struct MockTool {
        name: String,
    }

    #[async_trait]
    impl Tool for MockTool {
        fn name(&self) -> &str {
            &self.name
        }

        fn description(&self) -> &str {
            "A mock tool for testing"
        }

        fn parameters(&self) -> serde_json::Value {
            serde_json::json!({
                "type": "object",
                "properties": {},
                "required": []
            })
        }

        async fn execute(
            &self,
            _args: &serde_json::Value,
            _ctx: &ToolContext,
        ) -> Result<ToolOutput> {
            Ok(ToolOutput::success("mock output"))
        }
    }

    #[test]
    fn test_registry_register_and_get() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));

        let tool = Arc::new(MockTool {
            name: "test".to_string(),
        });
        registry.register(tool);

        assert!(registry.contains("test"));
        assert!(!registry.contains("nonexistent"));

        let retrieved = registry.get("test");
        assert!(retrieved.is_some());
        assert_eq!(retrieved.unwrap().name(), "test");
    }

    #[test]
    fn test_registry_unregister() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));

        let tool = Arc::new(MockTool {
            name: "test".to_string(),
        });
        registry.register(tool);

        assert!(registry.contains("test"));
        assert!(registry.unregister("test"));
        assert!(!registry.contains("test"));
        assert!(!registry.unregister("test"));
    }

    #[test]
    fn test_registry_definitions() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));

        registry.register(Arc::new(MockTool {
            name: "tool1".to_string(),
        }));
        registry.register(Arc::new(MockTool {
            name: "tool2".to_string(),
        }));

        let definitions = registry.definitions();
        assert_eq!(definitions.len(), 2);

        // Compare names order-independently.
        let mut names: Vec<&str> = definitions.iter().map(|d| d.name.as_str()).collect();
        names.sort_unstable();
        assert_eq!(names, ["tool1", "tool2"]);
    }

    #[tokio::test]
    async fn test_registry_execute() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));

        registry.register(Arc::new(MockTool {
            name: "test".to_string(),
        }));

        let result = registry
            .execute("test", &serde_json::json!({}))
            .await
            .unwrap();
        assert_eq!(result.exit_code, 0);
        assert_eq!(result.output, "mock output");
    }

    #[tokio::test]
    async fn test_registry_execute_unknown() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));

        let result = registry
            .execute("unknown", &serde_json::json!({}))
            .await
            .unwrap();
        assert_eq!(result.exit_code, 1);
        assert!(result.output.contains("Unknown tool"));
    }

    #[tokio::test]
    async fn test_registry_execute_with_context_success() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        let ctx = ToolContext::new(PathBuf::from("/tmp"));

        registry.register(Arc::new(MockTool {
            name: "my_tool".to_string(),
        }));

        let result = registry
            .execute_with_context("my_tool", &serde_json::json!({}), &ctx)
            .await
            .unwrap();
        assert_eq!(result.name, "my_tool");
        assert_eq!(result.exit_code, 0);
        assert_eq!(result.output, "mock output");
    }

    #[tokio::test]
    async fn test_registry_execute_with_context_unknown_tool() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        let ctx = ToolContext::new(PathBuf::from("/tmp"));

        let result = registry
            .execute_with_context("nonexistent", &serde_json::json!({}), &ctx)
            .await
            .unwrap();
        assert_eq!(result.exit_code, 1);
        assert!(result.output.contains("Unknown tool: nonexistent"));
    }

    /// Tool named `"failing"` that always reports failure via `ToolOutput::error`.
    struct FailingTool;

    #[async_trait]
    impl Tool for FailingTool {
        fn name(&self) -> &str {
            "failing"
        }

        fn description(&self) -> &str {
            "A tool that returns failure"
        }

        fn parameters(&self) -> serde_json::Value {
            serde_json::json!({"type": "object", "properties": {}, "required": []})
        }

        async fn execute(
            &self,
            _args: &serde_json::Value,
            _ctx: &ToolContext,
        ) -> Result<ToolOutput> {
            Ok(ToolOutput::error("something went wrong"))
        }
    }

    #[tokio::test]
    async fn test_registry_execute_failing_tool() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        registry.register(Arc::new(FailingTool));

        let result = registry
            .execute("failing", &serde_json::json!({}))
            .await
            .unwrap();
        assert_eq!(result.exit_code, 1);
        assert_eq!(result.output, "something went wrong");
    }

    #[tokio::test]
    async fn test_registry_execute_raw_success() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        registry.register(Arc::new(MockTool {
            name: "raw_test".to_string(),
        }));

        let output = registry
            .execute_raw("raw_test", &serde_json::json!({}))
            .await
            .unwrap();
        assert!(output.is_some());
        let output = output.unwrap();
        assert!(output.success);
        assert_eq!(output.content, "mock output");
    }

    #[tokio::test]
    async fn test_registry_execute_raw_unknown() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));

        let output = registry
            .execute_raw("missing", &serde_json::json!({}))
            .await
            .unwrap();
        assert!(output.is_none());
    }

    #[test]
    fn test_registry_list() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        registry.register(Arc::new(MockTool {
            name: "alpha".to_string(),
        }));
        registry.register(Arc::new(MockTool {
            name: "beta".to_string(),
        }));

        let names = registry.list();
        assert_eq!(names.len(), 2);
        assert!(names.contains(&"alpha".to_string()));
        assert!(names.contains(&"beta".to_string()));
    }

    #[test]
    fn test_registry_len_and_is_empty() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        assert!(registry.is_empty());
        assert_eq!(registry.len(), 0);

        registry.register(Arc::new(MockTool {
            name: "t".to_string(),
        }));
        assert!(!registry.is_empty());
        assert_eq!(registry.len(), 1);
    }

    #[test]
    fn test_registry_replace_tool() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        registry.register(Arc::new(MockTool {
            name: "dup".to_string(),
        }));
        registry.register(Arc::new(MockTool {
            name: "dup".to_string(),
        }));
        assert_eq!(registry.len(), 1);
    }

    #[test]
    fn test_registry_register_builtin() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        registry.register_builtin(Arc::new(MockTool {
            name: "core".to_string(),
        }));

        assert!(registry.contains("core"));
        assert_eq!(registry.len(), 1);
        assert_eq!(registry.definitions().len(), 1);
    }

    #[tokio::test]
    async fn test_registry_register_cannot_shadow_builtin() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        registry.register_builtin(Arc::new(FailingTool));

        // Same name as the builtin: the registration must be rejected and the
        // builtin must keep serving calls.
        registry.register(Arc::new(MockTool {
            name: "failing".to_string(),
        }));
        assert_eq!(registry.len(), 1);

        let result = registry
            .execute("failing", &serde_json::json!({}))
            .await
            .unwrap();
        assert_eq!(result.exit_code, 1);
        assert_eq!(result.output, "something went wrong");
    }
}