1use super::artifacts::{ArtifactStore, ArtifactStoreLimits, ToolArtifact};
7use super::types::{Tool, ToolContext, ToolOutput};
8use super::ToolResult;
9use super::{
10 merge_tool_output_artifact_metadata, truncate_tool_output_with_artifact, ToolOutputArtifact,
11};
12use crate::llm::ToolDefinition;
13use crate::trace::{InMemoryTraceSink, TraceEvent, TraceSink};
14use anyhow::Result;
15use std::collections::HashMap;
16use std::path::PathBuf;
17use std::sync::{Arc, RwLock};
18
19pub struct ToolRegistry {
21 tools: RwLock<HashMap<String, Arc<dyn Tool>>>,
22 builtins: RwLock<std::collections::HashSet<String>>,
24 context: RwLock<ToolContext>,
25 artifact_store: ArtifactStore,
26 trace_sink: RwLock<Arc<dyn TraceSink>>,
27}
28
29impl ToolRegistry {
30 pub fn new(workspace: PathBuf) -> Self {
32 Self::with_artifact_limits(workspace, ArtifactStoreLimits::default())
33 }
34
35 pub fn with_artifact_limits(workspace: PathBuf, artifact_limits: ArtifactStoreLimits) -> Self {
37 Self {
38 tools: RwLock::new(HashMap::new()),
39 builtins: RwLock::new(std::collections::HashSet::new()),
40 context: RwLock::new(ToolContext::new(workspace)),
41 artifact_store: ArtifactStore::with_limits(artifact_limits),
42 trace_sink: RwLock::new(Arc::new(InMemoryTraceSink::default())),
43 }
44 }
45
46 pub fn register_builtin(&self, tool: Arc<dyn Tool>) {
48 let name = tool.name().to_string();
49 let mut tools = self.tools.write().unwrap();
50 let mut builtins = self.builtins.write().unwrap();
51 tracing::debug!("Registering builtin tool: {}", name);
52 tools.insert(name.clone(), tool);
53 builtins.insert(name);
54 }
55
56 pub fn register(&self, tool: Arc<dyn Tool>) {
61 let name = tool.name().to_string();
62 let builtins = self.builtins.read().unwrap();
63 if builtins.contains(&name) {
64 tracing::warn!(
65 "Rejected registration of tool '{}': cannot shadow builtin",
66 name
67 );
68 return;
69 }
70 drop(builtins);
71 let mut tools = self.tools.write().unwrap();
72 tracing::debug!("Registering tool: {}", name);
73 tools.insert(name, tool);
74 }
75
76 pub fn unregister(&self, name: &str) -> bool {
80 let mut tools = self.tools.write().unwrap();
81 tracing::debug!("Unregistering tool: {}", name);
82 tools.remove(name).is_some()
83 }
84
85 pub fn unregister_by_prefix(&self, prefix: &str) {
87 let mut tools = self.tools.write().unwrap();
88 tools.retain(|name, _| !name.starts_with(prefix));
89 tracing::debug!("Unregistered tools with prefix: {}", prefix);
90 }
91
92 pub fn get(&self, name: &str) -> Option<Arc<dyn Tool>> {
94 let tools = self.tools.read().unwrap();
95 tools.get(name).cloned()
96 }
97
98 pub fn contains(&self, name: &str) -> bool {
100 let tools = self.tools.read().unwrap();
101 tools.contains_key(name)
102 }
103
104 pub fn definitions(&self) -> Vec<ToolDefinition> {
106 let tools = self.tools.read().unwrap();
107 tools
108 .values()
109 .map(|tool| ToolDefinition {
110 name: tool.name().to_string(),
111 description: tool.description().to_string(),
112 parameters: tool.parameters(),
113 })
114 .collect()
115 }
116
117 pub fn list(&self) -> Vec<String> {
119 let tools = self.tools.read().unwrap();
120 tools.keys().cloned().collect()
121 }
122
123 pub fn len(&self) -> usize {
125 let tools = self.tools.read().unwrap();
126 tools.len()
127 }
128
129 pub fn is_empty(&self) -> bool {
131 self.len() == 0
132 }
133
134 pub fn context(&self) -> ToolContext {
136 self.context.read().unwrap().clone()
137 }
138
139 pub fn artifact_store(&self) -> ArtifactStore {
141 self.artifact_store.clone()
142 }
143
144 pub fn get_artifact(&self, artifact_uri: &str) -> Option<ToolArtifact> {
146 self.artifact_store.get(artifact_uri)
147 }
148
149 pub fn set_trace_sink(&self, sink: Arc<dyn TraceSink>) {
151 *self.trace_sink.write().unwrap() = sink;
152 }
153
154 pub fn trace_sink(&self) -> Arc<dyn TraceSink> {
156 Arc::clone(&self.trace_sink.read().unwrap())
157 }
158
159 pub fn set_search_config(&self, config: crate::config::SearchConfig) {
161 let mut ctx = self.context.write().unwrap();
162 *ctx = ctx.clone().with_search_config(config);
163 }
164
165 pub fn set_sandbox(&self, sandbox: std::sync::Arc<dyn crate::sandbox::BashSandbox>) {
168 let mut ctx = self.context.write().unwrap();
169 *ctx = ctx.clone().with_sandbox(sandbox);
170 }
171
172 pub async fn execute(&self, name: &str, args: &serde_json::Value) -> Result<ToolResult> {
174 let ctx = self.context();
175 self.execute_with_context(name, args, &ctx).await
176 }
177
178 pub async fn execute_with_context(
180 &self,
181 name: &str,
182 args: &serde_json::Value,
183 ctx: &ToolContext,
184 ) -> Result<ToolResult> {
185 let start = std::time::Instant::now();
186
187 let tool = self.get(name);
188
189 let result = match tool {
190 Some(tool) => {
191 let mut output = tool.execute(args, ctx).await?;
192 let original_content = output.content.clone();
193 let truncated = truncate_tool_output_with_artifact(name, &output.content);
194 output.content = truncated.content;
195 if let Some(artifact) = truncated.artifact {
196 self.store_tool_artifact(name, &original_content, &artifact);
197 output.metadata = Some(merge_tool_output_artifact_metadata(
198 output.metadata,
199 &artifact,
200 ));
201 }
202 Ok(ToolResult {
203 name: name.to_string(),
204 output: output.content,
205 exit_code: if output.success { 0 } else { 1 },
206 metadata: output.metadata,
207 images: output.images,
208 })
209 }
210 None => Ok(ToolResult::error(name, format!("Unknown tool: {}", name))),
211 };
212
213 if let Ok(ref r) = result {
214 crate::telemetry::record_tool_result(r.exit_code, start.elapsed());
215 self.record_trace_event(name, r, start.elapsed());
216 }
217
218 result
219 }
220
221 pub async fn execute_raw(
223 &self,
224 name: &str,
225 args: &serde_json::Value,
226 ) -> Result<Option<ToolOutput>> {
227 let ctx = self.context();
228 self.execute_raw_with_context(name, args, &ctx).await
229 }
230
231 pub async fn execute_raw_with_context(
233 &self,
234 name: &str,
235 args: &serde_json::Value,
236 ctx: &ToolContext,
237 ) -> Result<Option<ToolOutput>> {
238 let tool = self.get(name);
239
240 match tool {
241 Some(tool) => {
242 let mut output = tool.execute(args, ctx).await?;
243 let original_content = output.content.clone();
244 let truncated = truncate_tool_output_with_artifact(name, &output.content);
245 output.content = truncated.content;
246 if let Some(artifact) = truncated.artifact {
247 self.store_tool_artifact(name, &original_content, &artifact);
248 output.metadata = Some(merge_tool_output_artifact_metadata(
249 output.metadata,
250 &artifact,
251 ));
252 }
253 Ok(Some(output))
254 }
255 None => Ok(None),
256 }
257 }
258
259 fn store_tool_artifact(&self, tool_name: &str, content: &str, artifact: &ToolOutputArtifact) {
260 self.artifact_store.put(ToolArtifact {
261 artifact_id: artifact.artifact_id.clone(),
262 artifact_uri: artifact.artifact_uri.clone(),
263 tool_name: tool_name.to_string(),
264 content: content.to_string(),
265 original_bytes: artifact.original_bytes,
266 shown_bytes: artifact.shown_bytes,
267 });
268 }
269
270 fn record_trace_event(&self, name: &str, result: &ToolResult, duration: std::time::Duration) {
271 let sink = self.trace_sink();
272 sink.record(TraceEvent::tool_execution(
273 name,
274 result.exit_code == 0,
275 result.exit_code,
276 duration,
277 result.output.len(),
278 result.metadata.as_ref(),
279 ));
280
281 if name == "program" {
282 sink.record(TraceEvent::program_execution(
283 name,
284 result.exit_code == 0,
285 result.exit_code,
286 duration,
287 result.output.len(),
288 result.metadata.as_ref(),
289 ));
290 }
291 }
292}
293
#[cfg(test)]
mod tests {
    use super::*;
    use crate::trace::{InMemoryTraceSink, TraceEventKind};
    use async_trait::async_trait;

    /// JSON schema for a tool that accepts no arguments; shared by all mock tools.
    fn empty_object_schema() -> serde_json::Value {
        serde_json::json!({
            "type": "object",
            "additionalProperties": false,
            "properties": {},
            "required": []
        })
    }

    /// Fresh registry with `/tmp` as its workspace.
    fn tmp_registry() -> ToolRegistry {
        ToolRegistry::new(PathBuf::from("/tmp"))
    }

    /// Empty JSON object used as tool arguments.
    fn no_args() -> serde_json::Value {
        serde_json::json!({})
    }

    /// Output size limit declared by the parent module.
    fn max_output_size() -> usize {
        super::super::MAX_OUTPUT_SIZE
    }

    /// Tool that always succeeds with the text `mock output`.
    struct MockTool {
        name: String,
    }

    /// Builds a `MockTool` registered under `name`.
    fn mock_tool(name: &str) -> Arc<dyn Tool> {
        Arc::new(MockTool {
            name: name.to_string(),
        })
    }

    #[async_trait]
    impl Tool for MockTool {
        fn name(&self) -> &str {
            &self.name
        }

        fn description(&self) -> &str {
            "A mock tool for testing"
        }

        fn parameters(&self) -> serde_json::Value {
            empty_object_schema()
        }

        async fn execute(
            &self,
            _args: &serde_json::Value,
            _ctx: &ToolContext,
        ) -> Result<ToolOutput> {
            Ok(ToolOutput::success("mock output"))
        }
    }

    /// Tool that always reports failure with the text `something went wrong`.
    struct FailingTool;

    #[async_trait]
    impl Tool for FailingTool {
        fn name(&self) -> &str {
            "failing"
        }

        fn description(&self) -> &str {
            "A tool that returns failure"
        }

        fn parameters(&self) -> serde_json::Value {
            empty_object_schema()
        }

        async fn execute(
            &self,
            _args: &serde_json::Value,
            _ctx: &ToolContext,
        ) -> Result<ToolOutput> {
            Ok(ToolOutput::error("something went wrong"))
        }
    }

    /// Tool whose output is one byte longer than the output size limit.
    struct LargeOutputTool;

    #[async_trait]
    impl Tool for LargeOutputTool {
        fn name(&self) -> &str {
            "large_output"
        }

        fn description(&self) -> &str {
            "A tool that returns more than the maximum output size"
        }

        fn parameters(&self) -> serde_json::Value {
            empty_object_schema()
        }

        async fn execute(
            &self,
            _args: &serde_json::Value,
            _ctx: &ToolContext,
        ) -> Result<ToolOutput> {
            Ok(ToolOutput::success("x".repeat(max_output_size() + 1)))
        }
    }

    #[test]
    fn test_registry_register_and_get() {
        let registry = tmp_registry();
        registry.register(mock_tool("test"));

        assert!(registry.contains("test"));
        assert!(!registry.contains("nonexistent"));

        let found = registry
            .get("test")
            .expect("registered tool should be retrievable");
        assert_eq!(found.name(), "test");
    }

    #[test]
    fn test_registry_unregister() {
        let registry = tmp_registry();
        registry.register(mock_tool("test"));

        assert!(registry.contains("test"));
        assert!(registry.unregister("test"));
        assert!(!registry.contains("test"));
        // A second removal finds nothing.
        assert!(!registry.unregister("test"));
    }

    #[test]
    fn test_registry_definitions() {
        let registry = tmp_registry();
        for name in ["tool1", "tool2"] {
            registry.register(mock_tool(name));
        }

        assert_eq!(registry.definitions().len(), 2);
    }

    #[tokio::test]
    async fn test_registry_execute() {
        let registry = tmp_registry();
        registry.register(mock_tool("test"));

        let result = registry.execute("test", &no_args()).await.unwrap();
        assert_eq!(result.exit_code, 0);
        assert_eq!(result.output, "mock output");
    }

    #[tokio::test]
    async fn test_registry_execute_unknown() {
        let registry = tmp_registry();

        let result = registry.execute("unknown", &no_args()).await.unwrap();
        assert_eq!(result.exit_code, 1);
        assert!(result.output.contains("Unknown tool"));
    }

    #[tokio::test]
    async fn test_registry_execute_with_context_success() {
        let registry = tmp_registry();
        let ctx = ToolContext::new(PathBuf::from("/tmp"));
        let trace_sink = InMemoryTraceSink::default();
        registry.set_trace_sink(Arc::new(trace_sink.clone()));
        registry.register(mock_tool("my_tool"));

        let result = registry
            .execute_with_context("my_tool", &no_args(), &ctx)
            .await
            .unwrap();
        assert_eq!(result.name, "my_tool");
        assert_eq!(result.exit_code, 0);
        assert_eq!(result.output, "mock output");

        let events = trace_sink.events();
        assert_eq!(events.len(), 1);
        let event = &events[0];
        assert_eq!(event.kind, TraceEventKind::ToolExecution);
        assert_eq!(event.name, "my_tool");
        assert!(event.success);
        assert_eq!(event.output_bytes, "mock output".len());
    }

    #[tokio::test]
    async fn test_registry_execute_with_context_unknown_tool() {
        let registry = tmp_registry();
        let ctx = ToolContext::new(PathBuf::from("/tmp"));

        let result = registry
            .execute_with_context("nonexistent", &no_args(), &ctx)
            .await
            .unwrap();
        assert_eq!(result.exit_code, 1);
        assert!(result.output.contains("Unknown tool: nonexistent"));
    }

    #[tokio::test]
    async fn test_registry_execute_failing_tool() {
        let registry = tmp_registry();
        registry.register(Arc::new(FailingTool));

        let result = registry.execute("failing", &no_args()).await.unwrap();
        assert_eq!(result.exit_code, 1);
        assert_eq!(result.output, "something went wrong");
    }

    #[tokio::test]
    async fn test_registry_truncates_large_tool_output() {
        let limit = max_output_size();
        let registry = tmp_registry();
        let trace_sink = InMemoryTraceSink::default();
        registry.set_trace_sink(Arc::new(trace_sink.clone()));
        registry.register(Arc::new(LargeOutputTool));

        let result = registry
            .execute("large_output", &no_args())
            .await
            .unwrap();

        // The visible output is truncated and points at the artifact.
        assert_eq!(result.exit_code, 0);
        assert!(result.output.contains("[tool output truncated:"));
        assert!(result
            .output
            .contains("Full output artifact: a3s://tool-output/large_output/"));
        assert!(result.output.len() < limit + 512);

        // The metadata describes the artifact.
        let metadata = result.metadata.expect("artifact metadata");
        let artifact_meta = &metadata["artifact"];
        assert_eq!(
            artifact_meta["original_bytes"],
            serde_json::json!(limit + 1)
        );
        assert_eq!(artifact_meta["shown_bytes"], serde_json::json!(limit));
        assert!(artifact_meta["artifact_id"]
            .as_str()
            .unwrap()
            .starts_with("tool-output:large_output:"));
        assert!(artifact_meta["artifact_uri"]
            .as_str()
            .unwrap()
            .starts_with("a3s://tool-output/large_output/"));

        // The stored artifact holds the untruncated content.
        let artifact_uri = artifact_meta["artifact_uri"].as_str().unwrap();
        let artifact = registry
            .get_artifact(artifact_uri)
            .expect("full output artifact");
        assert_eq!(artifact.tool_name, "large_output");
        assert_eq!(artifact.original_bytes, limit + 1);
        assert_eq!(artifact.shown_bytes, limit);
        assert_eq!(artifact.content, "x".repeat(limit + 1));

        // The trace event references the artifact.
        let events = trace_sink.events();
        assert_eq!(events.len(), 1);
        assert_eq!(events[0].artifact_uris, vec![artifact_uri]);
    }

    #[tokio::test]
    async fn test_registry_execute_raw_success() {
        let registry = tmp_registry();
        registry.register(mock_tool("raw_test"));

        let output = registry
            .execute_raw("raw_test", &no_args())
            .await
            .unwrap();
        assert!(output.is_some());
        let output = output.unwrap();
        assert!(output.success);
        assert_eq!(output.content, "mock output");
    }

    #[tokio::test]
    async fn test_registry_execute_raw_stores_truncated_artifact() {
        let registry = tmp_registry();
        registry.register(Arc::new(LargeOutputTool));

        let output = registry
            .execute_raw("large_output", &no_args())
            .await
            .unwrap()
            .expect("raw output");

        assert!(output.content.contains("[tool output truncated:"));
        let metadata = output.metadata.expect("artifact metadata");
        let artifact_uri = metadata["artifact"]["artifact_uri"].as_str().unwrap();
        let artifact = registry
            .get_artifact(artifact_uri)
            .expect("full output artifact");
        assert_eq!(artifact.tool_name, "large_output");
        assert_eq!(artifact.content.len(), max_output_size() + 1);
    }

    #[tokio::test]
    async fn test_registry_execute_raw_unknown() {
        let registry = tmp_registry();

        let output = registry
            .execute_raw("missing", &no_args())
            .await
            .unwrap();
        assert!(output.is_none());
    }

    #[test]
    fn test_registry_list() {
        let registry = tmp_registry();
        for name in ["alpha", "beta"] {
            registry.register(mock_tool(name));
        }

        let names = registry.list();
        assert_eq!(names.len(), 2);
        assert!(names.contains(&"alpha".to_string()));
        assert!(names.contains(&"beta".to_string()));
    }

    #[test]
    fn test_registry_len_and_is_empty() {
        let registry = tmp_registry();
        assert!(registry.is_empty());
        assert_eq!(registry.len(), 0);

        registry.register(mock_tool("t"));
        assert!(!registry.is_empty());
        assert_eq!(registry.len(), 1);
    }

    #[test]
    fn test_registry_replace_tool() {
        let registry = tmp_registry();
        // Registering the same name twice keeps a single entry.
        registry.register(mock_tool("dup"));
        registry.register(mock_tool("dup"));
        assert_eq!(registry.len(), 1);
    }
}