1use super::types::{Tool, ToolContext, ToolOutput};
7use super::ToolResult;
8use crate::llm::ToolDefinition;
9use anyhow::Result;
10use std::collections::HashMap;
11use std::path::PathBuf;
12use std::sync::{Arc, RwLock};
13
14pub struct ToolRegistry {
16 tools: RwLock<HashMap<String, Arc<dyn Tool>>>,
17 builtins: RwLock<std::collections::HashSet<String>>,
19 context: RwLock<ToolContext>,
20}
21
22impl ToolRegistry {
23 pub fn new(workspace: PathBuf) -> Self {
25 Self {
26 tools: RwLock::new(HashMap::new()),
27 builtins: RwLock::new(std::collections::HashSet::new()),
28 context: RwLock::new(ToolContext::new(workspace)),
29 }
30 }
31
32 pub fn register_builtin(&self, tool: Arc<dyn Tool>) {
34 let name = tool.name().to_string();
35 let mut tools = self.tools.write().unwrap();
36 let mut builtins = self.builtins.write().unwrap();
37 tracing::debug!("Registering builtin tool: {}", name);
38 tools.insert(name.clone(), tool);
39 builtins.insert(name);
40 }
41
42 pub fn register(&self, tool: Arc<dyn Tool>) {
47 let name = tool.name().to_string();
48 let builtins = self.builtins.read().unwrap();
49 if builtins.contains(&name) {
50 tracing::warn!(
51 "Rejected registration of tool '{}': cannot shadow builtin",
52 name
53 );
54 return;
55 }
56 drop(builtins);
57 let mut tools = self.tools.write().unwrap();
58 tracing::debug!("Registering tool: {}", name);
59 tools.insert(name, tool);
60 }
61
62 pub fn unregister(&self, name: &str) -> bool {
66 let mut tools = self.tools.write().unwrap();
67 tracing::debug!("Unregistering tool: {}", name);
68 tools.remove(name).is_some()
69 }
70
71 pub fn unregister_by_prefix(&self, prefix: &str) {
73 let mut tools = self.tools.write().unwrap();
74 tools.retain(|name, _| !name.starts_with(prefix));
75 tracing::debug!("Unregistered tools with prefix: {}", prefix);
76 }
77
78 pub fn get(&self, name: &str) -> Option<Arc<dyn Tool>> {
80 let tools = self.tools.read().unwrap();
81 tools.get(name).cloned()
82 }
83
84 pub fn contains(&self, name: &str) -> bool {
86 let tools = self.tools.read().unwrap();
87 tools.contains_key(name)
88 }
89
90 pub fn definitions(&self) -> Vec<ToolDefinition> {
92 let tools = self.tools.read().unwrap();
93 tools
94 .values()
95 .map(|tool| ToolDefinition {
96 name: tool.name().to_string(),
97 description: tool.description().to_string(),
98 parameters: tool.parameters(),
99 })
100 .collect()
101 }
102
103 pub fn list(&self) -> Vec<String> {
105 let tools = self.tools.read().unwrap();
106 tools.keys().cloned().collect()
107 }
108
109 pub fn len(&self) -> usize {
111 let tools = self.tools.read().unwrap();
112 tools.len()
113 }
114
115 pub fn is_empty(&self) -> bool {
117 self.len() == 0
118 }
119
120 pub fn context(&self) -> ToolContext {
122 self.context.read().unwrap().clone()
123 }
124
125 pub fn set_search_config(&self, config: crate::config::SearchConfig) {
127 let mut ctx = self.context.write().unwrap();
128 *ctx = ctx.clone().with_search_config(config);
129 }
130
131 pub fn set_agentic_search_config(&self, config: crate::config::AgenticSearchConfig) {
133 let mut ctx = self.context.write().unwrap();
134 *ctx = ctx.clone().with_agentic_search_config(config);
135 }
136
137 pub fn set_agentic_parse_config(&self, config: crate::config::AgenticParseConfig) {
139 let mut ctx = self.context.write().unwrap();
140 *ctx = ctx.clone().with_agentic_parse_config(config);
141 }
142
143 pub fn set_default_parser_config(&self, config: crate::config::DefaultParserConfig) {
145 let mut ctx = self.context.write().unwrap();
146 *ctx = ctx.clone().with_default_parser_config(config);
147 }
148
149 pub fn set_document_parsers(
151 &self,
152 registry: std::sync::Arc<crate::document_parser::DocumentParserRegistry>,
153 ) {
154 let mut ctx = self.context.write().unwrap();
155 *ctx = ctx.clone().with_document_parsers(registry);
156 }
157
158 pub fn set_sandbox(&self, sandbox: std::sync::Arc<dyn crate::sandbox::BashSandbox>) {
161 let mut ctx = self.context.write().unwrap();
162 *ctx = ctx.clone().with_sandbox(sandbox);
163 }
164
165 pub async fn execute(&self, name: &str, args: &serde_json::Value) -> Result<ToolResult> {
167 let ctx = self.context();
168 self.execute_with_context(name, args, &ctx).await
169 }
170
171 pub async fn execute_with_context(
173 &self,
174 name: &str,
175 args: &serde_json::Value,
176 ctx: &ToolContext,
177 ) -> Result<ToolResult> {
178 let start = std::time::Instant::now();
179
180 let tool = self.get(name);
181
182 let result = match tool {
183 Some(tool) => {
184 let output = tool.execute(args, ctx).await?;
185 Ok(ToolResult {
186 name: name.to_string(),
187 output: output.content,
188 exit_code: if output.success { 0 } else { 1 },
189 metadata: output.metadata,
190 images: output.images,
191 })
192 }
193 None => Ok(ToolResult::error(name, format!("Unknown tool: {}", name))),
194 };
195
196 if let Ok(ref r) = result {
197 crate::telemetry::record_tool_result(r.exit_code, start.elapsed());
198 }
199
200 result
201 }
202
203 pub async fn execute_raw(
205 &self,
206 name: &str,
207 args: &serde_json::Value,
208 ) -> Result<Option<ToolOutput>> {
209 let ctx = self.context();
210 self.execute_raw_with_context(name, args, &ctx).await
211 }
212
213 pub async fn execute_raw_with_context(
215 &self,
216 name: &str,
217 args: &serde_json::Value,
218 ctx: &ToolContext,
219 ) -> Result<Option<ToolOutput>> {
220 let tool = self.get(name);
221
222 match tool {
223 Some(tool) => {
224 let output = tool.execute(args, ctx).await?;
225 Ok(Some(output))
226 }
227 None => Ok(None),
228 }
229 }
230}
231
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;

    /// Tool with a configurable name that always succeeds with "mock output".
    struct MockTool {
        name: String,
    }

    #[async_trait]
    impl Tool for MockTool {
        fn name(&self) -> &str {
            &self.name
        }

        fn description(&self) -> &str {
            "A mock tool for testing"
        }

        fn parameters(&self) -> serde_json::Value {
            serde_json::json!({
                "type": "object",
                "additionalProperties": false,
                "properties": {},
                "required": []
            })
        }

        async fn execute(
            &self,
            _args: &serde_json::Value,
            _ctx: &ToolContext,
        ) -> Result<ToolOutput> {
            Ok(ToolOutput::success("mock output"))
        }
    }

    #[test]
    fn test_registry_register_and_get() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));

        let tool = Arc::new(MockTool {
            name: "test".to_string(),
        });
        registry.register(tool);

        assert!(registry.contains("test"));
        assert!(!registry.contains("nonexistent"));

        let retrieved = registry.get("test");
        assert!(retrieved.is_some());
        assert_eq!(retrieved.unwrap().name(), "test");
    }

    #[test]
    fn test_registry_register_builtin() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));

        registry.register_builtin(Arc::new(MockTool {
            name: "core".to_string(),
        }));

        assert!(registry.contains("core"));
        assert_eq!(registry.len(), 1);
    }

    #[test]
    fn test_registry_register_cannot_shadow_builtin() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));

        registry.register_builtin(Arc::new(FailingTool));
        // Same name as the builtin: must be rejected, leaving the builtin in place.
        registry.register(Arc::new(MockTool {
            name: "failing".to_string(),
        }));

        assert_eq!(registry.len(), 1);
        let kept = registry.get("failing").unwrap();
        assert_eq!(kept.description(), "A tool that returns failure");
    }

    #[test]
    fn test_registry_unregister() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));

        let tool = Arc::new(MockTool {
            name: "test".to_string(),
        });
        registry.register(tool);

        assert!(registry.contains("test"));
        assert!(registry.unregister("test"));
        assert!(!registry.contains("test"));
        // A second removal finds nothing.
        assert!(!registry.unregister("test"));
    }

    #[test]
    fn test_registry_unregister_by_prefix() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));

        for name in ["mcp__a", "mcp__b", "local"] {
            registry.register(Arc::new(MockTool {
                name: name.to_string(),
            }));
        }

        registry.unregister_by_prefix("mcp__");

        assert_eq!(registry.list(), vec!["local".to_string()]);
    }

    #[test]
    fn test_registry_definitions() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));

        registry.register(Arc::new(MockTool {
            name: "tool1".to_string(),
        }));
        registry.register(Arc::new(MockTool {
            name: "tool2".to_string(),
        }));

        let definitions = registry.definitions();
        assert_eq!(definitions.len(), 2);
    }

    #[tokio::test]
    async fn test_registry_execute() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));

        registry.register(Arc::new(MockTool {
            name: "test".to_string(),
        }));

        let result = registry
            .execute("test", &serde_json::json!({}))
            .await
            .unwrap();
        assert_eq!(result.exit_code, 0);
        assert_eq!(result.output, "mock output");
    }

    #[tokio::test]
    async fn test_registry_execute_unknown() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));

        let result = registry
            .execute("unknown", &serde_json::json!({}))
            .await
            .unwrap();
        assert_eq!(result.exit_code, 1);
        assert!(result.output.contains("Unknown tool"));
    }

    #[tokio::test]
    async fn test_registry_execute_with_context_success() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        let ctx = ToolContext::new(PathBuf::from("/tmp"));

        registry.register(Arc::new(MockTool {
            name: "my_tool".to_string(),
        }));

        let result = registry
            .execute_with_context("my_tool", &serde_json::json!({}), &ctx)
            .await
            .unwrap();
        assert_eq!(result.name, "my_tool");
        assert_eq!(result.exit_code, 0);
        assert_eq!(result.output, "mock output");
    }

    #[tokio::test]
    async fn test_registry_execute_with_context_unknown_tool() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        let ctx = ToolContext::new(PathBuf::from("/tmp"));

        let result = registry
            .execute_with_context("nonexistent", &serde_json::json!({}), &ctx)
            .await
            .unwrap();
        assert_eq!(result.exit_code, 1);
        assert!(result.output.contains("Unknown tool: nonexistent"));
    }

    /// Tool named "failing" whose execution reports an unsuccessful output.
    struct FailingTool;

    #[async_trait]
    impl Tool for FailingTool {
        fn name(&self) -> &str {
            "failing"
        }

        fn description(&self) -> &str {
            "A tool that returns failure"
        }

        fn parameters(&self) -> serde_json::Value {
            serde_json::json!({
                "type": "object",
                "additionalProperties": false,
                "properties": {},
                "required": []
            })
        }

        async fn execute(
            &self,
            _args: &serde_json::Value,
            _ctx: &ToolContext,
        ) -> Result<ToolOutput> {
            Ok(ToolOutput::error("something went wrong"))
        }
    }

    #[tokio::test]
    async fn test_registry_execute_failing_tool() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        registry.register(Arc::new(FailingTool));

        let result = registry
            .execute("failing", &serde_json::json!({}))
            .await
            .unwrap();
        assert_eq!(result.exit_code, 1);
        assert_eq!(result.output, "something went wrong");
    }

    #[tokio::test]
    async fn test_registry_execute_raw_success() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        registry.register(Arc::new(MockTool {
            name: "raw_test".to_string(),
        }));

        let output = registry
            .execute_raw("raw_test", &serde_json::json!({}))
            .await
            .unwrap();
        assert!(output.is_some());
        let output = output.unwrap();
        assert!(output.success);
        assert_eq!(output.content, "mock output");
    }

    #[tokio::test]
    async fn test_registry_execute_raw_unknown() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));

        let output = registry
            .execute_raw("missing", &serde_json::json!({}))
            .await
            .unwrap();
        assert!(output.is_none());
    }

    #[test]
    fn test_registry_list() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        registry.register(Arc::new(MockTool {
            name: "alpha".to_string(),
        }));
        registry.register(Arc::new(MockTool {
            name: "beta".to_string(),
        }));

        let names = registry.list();
        assert_eq!(names.len(), 2);
        assert!(names.contains(&"alpha".to_string()));
        assert!(names.contains(&"beta".to_string()));
    }

    #[test]
    fn test_registry_len_and_is_empty() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        assert!(registry.is_empty());
        assert_eq!(registry.len(), 0);

        registry.register(Arc::new(MockTool {
            name: "t".to_string(),
        }));
        assert!(!registry.is_empty());
        assert_eq!(registry.len(), 1);
    }

    #[test]
    fn test_registry_replace_tool() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        registry.register(Arc::new(MockTool {
            name: "dup".to_string(),
        }));
        registry.register(Arc::new(MockTool {
            name: "dup".to_string(),
        }));
        // Re-registering the same non-builtin name replaces the entry.
        assert_eq!(registry.len(), 1);
    }
}