1use super::types::{Tool, ToolContext, ToolOutput};
7use super::ToolResult;
8use crate::llm::ToolDefinition;
9use anyhow::Result;
10use std::collections::HashMap;
11use std::path::PathBuf;
12use std::sync::{Arc, RwLock};
13
/// Registry of tools addressable by name, together with the context they run in.
///
/// Every field sits behind its own `RwLock`, which is why all methods on the
/// registry take `&self`.
pub struct ToolRegistry {
    /// Registered tools, keyed by the name each tool reported when it was added.
    tools: RwLock<HashMap<String, Arc<dyn Tool>>>,
    /// Names added through `register_builtin`; `register` refuses to overwrite these.
    builtins: RwLock<std::collections::HashSet<String>>,
    /// Default context handed to tools by `execute` / `execute_raw`; replaced
    /// wholesale by the `set_*` methods.
    context: RwLock<ToolContext>,
}
21
22impl ToolRegistry {
23 pub fn new(workspace: PathBuf) -> Self {
25 Self {
26 tools: RwLock::new(HashMap::new()),
27 builtins: RwLock::new(std::collections::HashSet::new()),
28 context: RwLock::new(ToolContext::new(workspace)),
29 }
30 }
31
32 pub fn register_builtin(&self, tool: Arc<dyn Tool>) {
34 let name = tool.name().to_string();
35 let mut tools = self.tools.write().unwrap();
36 let mut builtins = self.builtins.write().unwrap();
37 tracing::debug!("Registering builtin tool: {}", name);
38 tools.insert(name.clone(), tool);
39 builtins.insert(name);
40 }
41
42 pub fn register(&self, tool: Arc<dyn Tool>) {
47 let name = tool.name().to_string();
48 let builtins = self.builtins.read().unwrap();
49 if builtins.contains(&name) {
50 tracing::warn!(
51 "Rejected registration of tool '{}': cannot shadow builtin",
52 name
53 );
54 return;
55 }
56 drop(builtins);
57 let mut tools = self.tools.write().unwrap();
58 tracing::debug!("Registering tool: {}", name);
59 tools.insert(name, tool);
60 }
61
62 pub fn unregister(&self, name: &str) -> bool {
66 let mut tools = self.tools.write().unwrap();
67 tracing::debug!("Unregistering tool: {}", name);
68 tools.remove(name).is_some()
69 }
70
71 pub fn unregister_by_prefix(&self, prefix: &str) {
73 let mut tools = self.tools.write().unwrap();
74 tools.retain(|name, _| !name.starts_with(prefix));
75 tracing::debug!("Unregistered tools with prefix: {}", prefix);
76 }
77
78 pub fn get(&self, name: &str) -> Option<Arc<dyn Tool>> {
80 let tools = self.tools.read().unwrap();
81 tools.get(name).cloned()
82 }
83
84 pub fn contains(&self, name: &str) -> bool {
86 let tools = self.tools.read().unwrap();
87 tools.contains_key(name)
88 }
89
90 pub fn definitions(&self) -> Vec<ToolDefinition> {
92 let tools = self.tools.read().unwrap();
93 tools
94 .values()
95 .map(|tool| ToolDefinition {
96 name: tool.name().to_string(),
97 description: tool.description().to_string(),
98 parameters: tool.parameters(),
99 })
100 .collect()
101 }
102
103 pub fn list(&self) -> Vec<String> {
105 let tools = self.tools.read().unwrap();
106 tools.keys().cloned().collect()
107 }
108
109 pub fn len(&self) -> usize {
111 let tools = self.tools.read().unwrap();
112 tools.len()
113 }
114
115 pub fn is_empty(&self) -> bool {
117 self.len() == 0
118 }
119
120 pub fn context(&self) -> ToolContext {
122 self.context.read().unwrap().clone()
123 }
124
125 pub fn set_search_config(&self, config: crate::config::SearchConfig) {
127 let mut ctx = self.context.write().unwrap();
128 *ctx = ctx.clone().with_search_config(config);
129 }
130
131 pub fn set_agentic_search_config(&self, config: crate::config::AgenticSearchConfig) {
133 let mut ctx = self.context.write().unwrap();
134 *ctx = ctx.clone().with_agentic_search_config(config);
135 }
136
137 pub fn set_agentic_parse_config(&self, config: crate::config::AgenticParseConfig) {
139 let mut ctx = self.context.write().unwrap();
140 *ctx = ctx.clone().with_agentic_parse_config(config);
141 }
142
143 pub fn set_document_parser_config(&self, config: crate::config::DocumentParserConfig) {
145 let mut ctx = self.context.write().unwrap();
146 *ctx = ctx.clone().with_document_parser_config(config);
147 }
148
149 pub fn set_document_parsers(
151 &self,
152 registry: std::sync::Arc<crate::document_parser::DocumentParserRegistry>,
153 ) {
154 let mut ctx = self.context.write().unwrap();
155 *ctx = ctx.clone().with_document_parsers(registry);
156 }
157
158 pub(crate) fn set_document_pipeline(
160 &self,
161 registry: std::sync::Arc<crate::document_pipeline::DocumentPipelineRegistry>,
162 ) {
163 let mut ctx = self.context.write().unwrap();
164 *ctx = ctx.clone().with_document_pipeline(registry);
165 }
166
167 pub fn set_sandbox(&self, sandbox: std::sync::Arc<dyn crate::sandbox::BashSandbox>) {
170 let mut ctx = self.context.write().unwrap();
171 *ctx = ctx.clone().with_sandbox(sandbox);
172 }
173
174 pub async fn execute(&self, name: &str, args: &serde_json::Value) -> Result<ToolResult> {
176 let ctx = self.context();
177 self.execute_with_context(name, args, &ctx).await
178 }
179
180 pub async fn execute_with_context(
182 &self,
183 name: &str,
184 args: &serde_json::Value,
185 ctx: &ToolContext,
186 ) -> Result<ToolResult> {
187 let start = std::time::Instant::now();
188
189 let tool = self.get(name);
190
191 let result = match tool {
192 Some(tool) => {
193 let output = tool.execute(args, ctx).await?;
194 Ok(ToolResult {
195 name: name.to_string(),
196 output: output.content,
197 exit_code: if output.success { 0 } else { 1 },
198 metadata: output.metadata,
199 images: output.images,
200 })
201 }
202 None => Ok(ToolResult::error(name, format!("Unknown tool: {}", name))),
203 };
204
205 if let Ok(ref r) = result {
206 crate::telemetry::record_tool_result(r.exit_code, start.elapsed());
207 }
208
209 result
210 }
211
212 pub async fn execute_raw(
214 &self,
215 name: &str,
216 args: &serde_json::Value,
217 ) -> Result<Option<ToolOutput>> {
218 let ctx = self.context();
219 self.execute_raw_with_context(name, args, &ctx).await
220 }
221
222 pub async fn execute_raw_with_context(
224 &self,
225 name: &str,
226 args: &serde_json::Value,
227 ctx: &ToolContext,
228 ) -> Result<Option<ToolOutput>> {
229 let tool = self.get(name);
230
231 match tool {
232 Some(tool) => {
233 let output = tool.execute(args, ctx).await?;
234 Ok(Some(output))
235 }
236 None => Ok(None),
237 }
238 }
239}
240
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;

    /// JSON schema shared by the mock tools: an object with no properties.
    fn empty_object_schema() -> serde_json::Value {
        serde_json::json!({
            "type": "object",
            "additionalProperties": false,
            "properties": {},
            "required": []
        })
    }

    /// Tool with a configurable name that always succeeds with "mock output".
    struct MockTool {
        name: String,
    }

    #[async_trait]
    impl Tool for MockTool {
        fn name(&self) -> &str {
            &self.name
        }

        fn description(&self) -> &str {
            "A mock tool for testing"
        }

        fn parameters(&self) -> serde_json::Value {
            empty_object_schema()
        }

        async fn execute(
            &self,
            _args: &serde_json::Value,
            _ctx: &ToolContext,
        ) -> Result<ToolOutput> {
            Ok(ToolOutput::success("mock output"))
        }
    }

    /// Tool named "failing" that always reports an unsuccessful output.
    struct FailingTool;

    #[async_trait]
    impl Tool for FailingTool {
        fn name(&self) -> &str {
            "failing"
        }

        fn description(&self) -> &str {
            "A tool that returns failure"
        }

        fn parameters(&self) -> serde_json::Value {
            empty_object_schema()
        }

        async fn execute(
            &self,
            _args: &serde_json::Value,
            _ctx: &ToolContext,
        ) -> Result<ToolOutput> {
            Ok(ToolOutput::error("something went wrong"))
        }
    }

    /// Fresh registry rooted at `/tmp`.
    fn new_registry() -> ToolRegistry {
        ToolRegistry::new(PathBuf::from("/tmp"))
    }

    /// Shared handle to a `MockTool` called `name`.
    fn mock(name: &str) -> Arc<dyn Tool> {
        Arc::new(MockTool {
            name: name.to_string(),
        })
    }

    #[test]
    fn test_registry_register_and_get() {
        let registry = new_registry();
        registry.register(mock("test"));

        assert!(registry.contains("test"));
        assert!(!registry.contains("nonexistent"));

        let retrieved = registry.get("test");
        assert!(retrieved.is_some());
        assert_eq!(retrieved.unwrap().name(), "test");
    }

    #[test]
    fn test_registry_unregister() {
        let registry = new_registry();
        registry.register(mock("test"));

        assert!(registry.contains("test"));
        assert!(registry.unregister("test"));
        assert!(!registry.contains("test"));
        // A second removal finds nothing.
        assert!(!registry.unregister("test"));
    }

    #[test]
    fn test_registry_unregister_by_prefix() {
        let registry = new_registry();
        registry.register(mock("mcp__a"));
        registry.register(mock("mcp__b"));
        registry.register(mock("local"));

        registry.unregister_by_prefix("mcp__");

        assert_eq!(registry.list(), vec!["local".to_string()]);
    }

    #[test]
    fn test_registry_definitions() {
        let registry = new_registry();
        registry.register(mock("tool1"));
        registry.register(mock("tool2"));

        let definitions = registry.definitions();
        assert_eq!(definitions.len(), 2);
    }

    #[test]
    fn test_registry_set_document_parser_config() {
        let registry = new_registry();
        registry.set_document_parser_config(crate::config::DocumentParserConfig {
            enabled: true,
            max_file_size_mb: 48,
            ocr: None,
            ..Default::default()
        });

        assert_eq!(
            registry
                .context()
                .document_parser_config
                .as_ref()
                .unwrap()
                .max_file_size_mb,
            48
        );
    }

    #[tokio::test]
    async fn test_registry_execute() {
        let registry = new_registry();
        registry.register(mock("test"));

        let result = registry
            .execute("test", &serde_json::json!({}))
            .await
            .unwrap();
        assert_eq!(result.exit_code, 0);
        assert_eq!(result.output, "mock output");
    }

    #[tokio::test]
    async fn test_registry_execute_unknown() {
        let registry = new_registry();

        let result = registry
            .execute("unknown", &serde_json::json!({}))
            .await
            .unwrap();
        assert_eq!(result.exit_code, 1);
        assert!(result.output.contains("Unknown tool"));
    }

    #[tokio::test]
    async fn test_registry_execute_with_context_success() {
        let registry = new_registry();
        let ctx = ToolContext::new(PathBuf::from("/tmp"));
        registry.register(mock("my_tool"));

        let result = registry
            .execute_with_context("my_tool", &serde_json::json!({}), &ctx)
            .await
            .unwrap();
        assert_eq!(result.name, "my_tool");
        assert_eq!(result.exit_code, 0);
        assert_eq!(result.output, "mock output");
    }

    #[tokio::test]
    async fn test_registry_execute_with_context_unknown_tool() {
        let registry = new_registry();
        let ctx = ToolContext::new(PathBuf::from("/tmp"));

        let result = registry
            .execute_with_context("nonexistent", &serde_json::json!({}), &ctx)
            .await
            .unwrap();
        assert_eq!(result.exit_code, 1);
        assert!(result.output.contains("Unknown tool: nonexistent"));
    }

    #[tokio::test]
    async fn test_registry_execute_failing_tool() {
        let registry = new_registry();
        registry.register(Arc::new(FailingTool));

        let result = registry
            .execute("failing", &serde_json::json!({}))
            .await
            .unwrap();
        assert_eq!(result.exit_code, 1);
        assert_eq!(result.output, "something went wrong");
    }

    #[tokio::test]
    async fn test_registry_execute_raw_success() {
        let registry = new_registry();
        registry.register(mock("raw_test"));

        let output = registry
            .execute_raw("raw_test", &serde_json::json!({}))
            .await
            .unwrap();
        assert!(output.is_some());
        let output = output.unwrap();
        assert!(output.success);
        assert_eq!(output.content, "mock output");
    }

    #[tokio::test]
    async fn test_registry_execute_raw_unknown() {
        let registry = new_registry();

        let output = registry
            .execute_raw("missing", &serde_json::json!({}))
            .await
            .unwrap();
        assert!(output.is_none());
    }

    #[test]
    fn test_registry_list() {
        let registry = new_registry();
        registry.register(mock("alpha"));
        registry.register(mock("beta"));

        let names = registry.list();
        assert_eq!(names.len(), 2);
        assert!(names.contains(&"alpha".to_string()));
        assert!(names.contains(&"beta".to_string()));
    }

    #[test]
    fn test_registry_len_and_is_empty() {
        let registry = new_registry();
        assert!(registry.is_empty());
        assert_eq!(registry.len(), 0);

        registry.register(mock("t"));
        assert!(!registry.is_empty());
        assert_eq!(registry.len(), 1);
    }

    #[test]
    fn test_registry_replace_tool() {
        let registry = new_registry();
        registry.register(mock("dup"));
        registry.register(mock("dup"));
        assert_eq!(registry.len(), 1);

        // The later registration must win: `FailingTool` reports the name
        // "failing", so it replaces the mock stored under that name.
        registry.register(mock("failing"));
        registry.register(Arc::new(FailingTool));
        assert_eq!(registry.len(), 2);
        assert_eq!(
            registry.get("failing").unwrap().description(),
            "A tool that returns failure"
        );
    }

    #[test]
    fn test_registry_register_rejects_builtin_shadow() {
        let registry = new_registry();
        registry.register_builtin(mock("failing"));

        // `FailingTool` shares the builtin's name and must be refused.
        registry.register(Arc::new(FailingTool));

        assert_eq!(registry.len(), 1);
        assert_eq!(
            registry.get("failing").unwrap().description(),
            "A mock tool for testing"
        );
    }

    #[test]
    fn test_registry_register_builtin_replaces_existing() {
        let registry = new_registry();
        registry.register(mock("failing"));

        // Builtin registration is unconditional and overrides the earlier tool.
        registry.register_builtin(Arc::new(FailingTool));

        assert_eq!(registry.len(), 1);
        assert_eq!(
            registry.get("failing").unwrap().description(),
            "A tool that returns failure"
        );
    }
}