1use super::artifacts::{ArtifactStore, ArtifactStoreLimits, ToolArtifact};
7use super::types::{Tool, ToolContext, ToolOutput};
8use super::ToolResult;
9use super::{
10 merge_tool_output_artifact_metadata, truncate_tool_output_with_artifact, ToolOutputArtifact,
11};
12use crate::llm::ToolDefinition;
13use crate::trace::{InMemoryTraceSink, TraceEvent, TraceSink};
14use anyhow::Result;
15use std::collections::HashMap;
16use std::path::PathBuf;
17use std::sync::{Arc, RwLock};
18
19pub struct ToolRegistry {
21 tools: RwLock<HashMap<String, Arc<dyn Tool>>>,
22 builtins: RwLock<std::collections::HashSet<String>>,
24 context: RwLock<ToolContext>,
25 artifact_store: ArtifactStore,
26 trace_sink: RwLock<Arc<dyn TraceSink>>,
27}
28
29impl ToolRegistry {
30 pub fn new(workspace: PathBuf) -> Self {
32 Self::with_artifact_limits(workspace, ArtifactStoreLimits::default())
33 }
34
35 pub fn with_artifact_limits(workspace: PathBuf, artifact_limits: ArtifactStoreLimits) -> Self {
37 Self::with_artifact_limits_and_workspace_services(
38 workspace.clone(),
39 artifact_limits,
40 crate::workspace::WorkspaceServices::local(workspace),
41 )
42 }
43
44 pub fn with_artifact_limits_and_workspace_services(
46 workspace: PathBuf,
47 artifact_limits: ArtifactStoreLimits,
48 workspace_services: Arc<crate::workspace::WorkspaceServices>,
49 ) -> Self {
50 let context = ToolContext::new(workspace).with_workspace_services(workspace_services);
51 Self {
52 tools: RwLock::new(HashMap::new()),
53 builtins: RwLock::new(std::collections::HashSet::new()),
54 context: RwLock::new(context),
55 artifact_store: ArtifactStore::with_limits(artifact_limits),
56 trace_sink: RwLock::new(Arc::new(InMemoryTraceSink::default())),
57 }
58 }
59
60 pub fn register_builtin(&self, tool: Arc<dyn Tool>) {
62 let name = tool.name().to_string();
63 let mut tools = self.tools.write().unwrap();
64 let mut builtins = self.builtins.write().unwrap();
65 tracing::debug!("Registering builtin tool: {}", name);
66 tools.insert(name.clone(), tool);
67 builtins.insert(name);
68 }
69
70 pub fn register(&self, tool: Arc<dyn Tool>) {
75 let name = tool.name().to_string();
76 let builtins = self.builtins.read().unwrap();
77 if builtins.contains(&name) {
78 tracing::warn!(
79 "Rejected registration of tool '{}': cannot shadow builtin",
80 name
81 );
82 return;
83 }
84 drop(builtins);
85 let mut tools = self.tools.write().unwrap();
86 tracing::debug!("Registering tool: {}", name);
87 tools.insert(name, tool);
88 }
89
90 pub fn unregister(&self, name: &str) -> bool {
94 let mut tools = self.tools.write().unwrap();
95 tracing::debug!("Unregistering tool: {}", name);
96 tools.remove(name).is_some()
97 }
98
99 pub fn unregister_by_prefix(&self, prefix: &str) {
101 let mut tools = self.tools.write().unwrap();
102 tools.retain(|name, _| !name.starts_with(prefix));
103 tracing::debug!("Unregistered tools with prefix: {}", prefix);
104 }
105
106 pub fn get(&self, name: &str) -> Option<Arc<dyn Tool>> {
108 let tools = self.tools.read().unwrap();
109 tools.get(name).cloned()
110 }
111
112 pub fn contains(&self, name: &str) -> bool {
114 let tools = self.tools.read().unwrap();
115 tools.contains_key(name)
116 }
117
118 pub fn definitions(&self) -> Vec<ToolDefinition> {
120 let tools = self.tools.read().unwrap();
121 tools
122 .values()
123 .map(|tool| ToolDefinition {
124 name: tool.name().to_string(),
125 description: tool.description().to_string(),
126 parameters: tool.parameters(),
127 })
128 .collect()
129 }
130
131 pub fn list(&self) -> Vec<String> {
133 let tools = self.tools.read().unwrap();
134 tools.keys().cloned().collect()
135 }
136
137 pub fn len(&self) -> usize {
139 let tools = self.tools.read().unwrap();
140 tools.len()
141 }
142
143 pub fn is_empty(&self) -> bool {
145 self.len() == 0
146 }
147
148 pub fn context(&self) -> ToolContext {
150 self.context.read().unwrap().clone()
151 }
152
153 pub fn artifact_store(&self) -> ArtifactStore {
155 self.artifact_store.clone()
156 }
157
158 pub fn get_artifact(&self, artifact_uri: &str) -> Option<ToolArtifact> {
160 self.artifact_store.get(artifact_uri)
161 }
162
163 pub fn set_trace_sink(&self, sink: Arc<dyn TraceSink>) {
165 *self.trace_sink.write().unwrap() = sink;
166 }
167
168 pub fn trace_sink(&self) -> Arc<dyn TraceSink> {
170 Arc::clone(&self.trace_sink.read().unwrap())
171 }
172
173 pub fn set_search_config(&self, config: crate::config::SearchConfig) {
175 let mut ctx = self.context.write().unwrap();
176 *ctx = ctx.clone().with_search_config(config);
177 }
178
179 pub fn set_sandbox(&self, sandbox: std::sync::Arc<dyn crate::sandbox::BashSandbox>) {
182 let mut ctx = self.context.write().unwrap();
183 *ctx = ctx.clone().with_sandbox(sandbox);
184 }
185
186 pub fn set_command_env(&self, env: Arc<HashMap<String, String>>) {
189 let mut ctx = self.context.write().unwrap();
190 *ctx = ctx.clone().with_command_env(env);
191 }
192
193 pub async fn execute(&self, name: &str, args: &serde_json::Value) -> Result<ToolResult> {
195 let ctx = self.context();
196 self.execute_with_context(name, args, &ctx).await
197 }
198
199 pub async fn execute_with_context(
201 &self,
202 name: &str,
203 args: &serde_json::Value,
204 ctx: &ToolContext,
205 ) -> Result<ToolResult> {
206 let start = std::time::Instant::now();
207
208 let tool = self.get(name);
209
210 let result = match tool {
211 Some(tool) => {
212 let mut output = tool.execute(args, ctx).await?;
213 let original_content = output.content.clone();
214 let truncated = truncate_tool_output_with_artifact(name, &output.content);
215 output.content = truncated.content;
216 if let Some(artifact) = truncated.artifact {
217 self.store_tool_artifact(name, &original_content, &artifact);
218 output.metadata = Some(merge_tool_output_artifact_metadata(
219 output.metadata,
220 &artifact,
221 ));
222 }
223 Ok(ToolResult {
224 name: name.to_string(),
225 output: output.content,
226 exit_code: if output.success { 0 } else { 1 },
227 metadata: output.metadata,
228 images: output.images,
229 error_kind: output.error_kind,
230 })
231 }
232 None => Ok(ToolResult::error(name, format!("Unknown tool: {}", name))),
233 };
234
235 if let Ok(ref r) = result {
236 crate::telemetry::record_tool_result(r.exit_code, start.elapsed());
237 self.record_trace_event(name, r, start.elapsed());
238 }
239
240 result
241 }
242
243 pub async fn execute_raw(
245 &self,
246 name: &str,
247 args: &serde_json::Value,
248 ) -> Result<Option<ToolOutput>> {
249 let ctx = self.context();
250 self.execute_raw_with_context(name, args, &ctx).await
251 }
252
253 pub async fn execute_raw_with_context(
255 &self,
256 name: &str,
257 args: &serde_json::Value,
258 ctx: &ToolContext,
259 ) -> Result<Option<ToolOutput>> {
260 let tool = self.get(name);
261
262 match tool {
263 Some(tool) => {
264 let mut output = tool.execute(args, ctx).await?;
265 let original_content = output.content.clone();
266 let truncated = truncate_tool_output_with_artifact(name, &output.content);
267 output.content = truncated.content;
268 if let Some(artifact) = truncated.artifact {
269 self.store_tool_artifact(name, &original_content, &artifact);
270 output.metadata = Some(merge_tool_output_artifact_metadata(
271 output.metadata,
272 &artifact,
273 ));
274 }
275 Ok(Some(output))
276 }
277 None => Ok(None),
278 }
279 }
280
281 fn store_tool_artifact(&self, tool_name: &str, content: &str, artifact: &ToolOutputArtifact) {
282 self.artifact_store.put(ToolArtifact {
283 artifact_id: artifact.artifact_id.clone(),
284 artifact_uri: artifact.artifact_uri.clone(),
285 tool_name: tool_name.to_string(),
286 content: content.to_string(),
287 original_bytes: artifact.original_bytes,
288 shown_bytes: artifact.shown_bytes,
289 });
290 }
291
292 fn record_trace_event(&self, name: &str, result: &ToolResult, duration: std::time::Duration) {
293 let sink = self.trace_sink();
294 sink.record(TraceEvent::tool_execution(
295 name,
296 result.exit_code == 0,
297 result.exit_code,
298 duration,
299 result.output.len(),
300 result.metadata.as_ref(),
301 ));
302
303 if name == "program" {
304 sink.record(TraceEvent::program_execution(
305 name,
306 result.exit_code == 0,
307 result.exit_code,
308 duration,
309 result.output.len(),
310 result.metadata.as_ref(),
311 ));
312 }
313 }
314}
315
#[cfg(test)]
mod tests {
    use super::*;
    use crate::trace::{InMemoryTraceSink, TraceEventKind};
    use async_trait::async_trait;

    /// JSON schema for a tool that takes no arguments (shared by all mocks).
    fn empty_object_schema() -> serde_json::Value {
        serde_json::json!({
            "type": "object",
            "additionalProperties": false,
            "properties": {},
            "required": []
        })
    }

    /// Tool with a configurable name that always succeeds with "mock output".
    struct MockTool {
        name: String,
    }

    #[async_trait]
    impl Tool for MockTool {
        fn name(&self) -> &str {
            &self.name
        }

        fn description(&self) -> &str {
            "A mock tool for testing"
        }

        fn parameters(&self) -> serde_json::Value {
            empty_object_schema()
        }

        async fn execute(
            &self,
            _args: &serde_json::Value,
            _ctx: &ToolContext,
        ) -> Result<ToolOutput> {
            Ok(ToolOutput::success("mock output"))
        }
    }

    #[test]
    fn test_registry_register_and_get() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));

        let tool = Arc::new(MockTool {
            name: "test".to_string(),
        });
        registry.register(tool);

        assert!(registry.contains("test"));
        assert!(!registry.contains("nonexistent"));

        let retrieved = registry.get("test");
        assert!(retrieved.is_some());
        assert_eq!(retrieved.unwrap().name(), "test");
    }

    #[test]
    fn test_registry_unregister() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));

        let tool = Arc::new(MockTool {
            name: "test".to_string(),
        });
        registry.register(tool);

        assert!(registry.contains("test"));
        assert!(registry.unregister("test"));
        assert!(!registry.contains("test"));
        assert!(!registry.unregister("test"));
    }

    #[test]
    fn test_registry_unregister_by_prefix() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        for name in ["mcp__fs__read", "mcp__fs__write", "local"].iter() {
            registry.register(Arc::new(MockTool {
                name: name.to_string(),
            }));
        }

        registry.unregister_by_prefix("mcp__fs__");

        assert_eq!(registry.list(), vec!["local".to_string()]);
    }

    #[test]
    fn test_registry_definitions() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));

        registry.register(Arc::new(MockTool {
            name: "tool1".to_string(),
        }));
        registry.register(Arc::new(MockTool {
            name: "tool2".to_string(),
        }));

        let definitions = registry.definitions();
        assert_eq!(definitions.len(), 2);
    }

    #[tokio::test]
    async fn test_registry_execute() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));

        registry.register(Arc::new(MockTool {
            name: "test".to_string(),
        }));

        let result = registry
            .execute("test", &serde_json::json!({}))
            .await
            .unwrap();
        assert_eq!(result.exit_code, 0);
        assert_eq!(result.output, "mock output");
    }

    #[tokio::test]
    async fn test_registry_execute_unknown() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));

        let result = registry
            .execute("unknown", &serde_json::json!({}))
            .await
            .unwrap();
        assert_eq!(result.exit_code, 1);
        assert!(result.output.contains("Unknown tool"));
    }

    #[tokio::test]
    async fn test_registry_execute_with_context_success() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        let ctx = ToolContext::new(PathBuf::from("/tmp"));
        let trace_sink = InMemoryTraceSink::default();
        registry.set_trace_sink(Arc::new(trace_sink.clone()));

        registry.register(Arc::new(MockTool {
            name: "my_tool".to_string(),
        }));

        let result = registry
            .execute_with_context("my_tool", &serde_json::json!({}), &ctx)
            .await
            .unwrap();
        assert_eq!(result.name, "my_tool");
        assert_eq!(result.exit_code, 0);
        assert_eq!(result.output, "mock output");

        let events = trace_sink.events();
        assert_eq!(events.len(), 1);
        assert_eq!(events[0].kind, TraceEventKind::ToolExecution);
        assert_eq!(events[0].name, "my_tool");
        assert!(events[0].success);
        assert_eq!(events[0].output_bytes, "mock output".len());
    }

    #[tokio::test]
    async fn test_registry_execute_with_context_unknown_tool() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        let ctx = ToolContext::new(PathBuf::from("/tmp"));

        let result = registry
            .execute_with_context("nonexistent", &serde_json::json!({}), &ctx)
            .await
            .unwrap();
        assert_eq!(result.exit_code, 1);
        assert!(result.output.contains("Unknown tool: nonexistent"));
    }

    #[tokio::test]
    async fn test_registry_program_tool_records_program_event() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        let trace_sink = InMemoryTraceSink::default();
        registry.set_trace_sink(Arc::new(trace_sink.clone()));
        registry.register(Arc::new(MockTool {
            name: "program".to_string(),
        }));

        let result = registry
            .execute("program", &serde_json::json!({}))
            .await
            .unwrap();
        assert_eq!(result.exit_code, 0);

        // One generic tool-execution event plus one program-specific event.
        let events = trace_sink.events();
        assert_eq!(events.len(), 2);
        assert_eq!(
            events
                .iter()
                .filter(|event| event.kind == TraceEventKind::ToolExecution)
                .count(),
            1
        );
    }

    /// Tool that reports a (non-`Err`) failure.
    struct FailingTool;

    #[async_trait]
    impl Tool for FailingTool {
        fn name(&self) -> &str {
            "failing"
        }

        fn description(&self) -> &str {
            "A tool that returns failure"
        }

        fn parameters(&self) -> serde_json::Value {
            empty_object_schema()
        }

        async fn execute(
            &self,
            _args: &serde_json::Value,
            _ctx: &ToolContext,
        ) -> Result<ToolOutput> {
            Ok(ToolOutput::error("something went wrong"))
        }
    }

    #[tokio::test]
    async fn test_registry_execute_failing_tool() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        registry.register(Arc::new(FailingTool));

        let result = registry
            .execute("failing", &serde_json::json!({}))
            .await
            .unwrap();
        assert_eq!(result.exit_code, 1);
        assert_eq!(result.output, "something went wrong");
    }

    #[tokio::test]
    async fn test_registry_register_cannot_shadow_builtin() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        registry.register_builtin(Arc::new(FailingTool));

        // Same name as the builtin: must be rejected, leaving the builtin in place.
        registry.register(Arc::new(MockTool {
            name: "failing".to_string(),
        }));

        assert_eq!(registry.len(), 1);
        let result = registry
            .execute("failing", &serde_json::json!({}))
            .await
            .unwrap();
        assert_eq!(result.exit_code, 1);
        assert_eq!(result.output, "something went wrong");
    }

    /// Tool whose `execute` returns an `Err` instead of a failed `ToolOutput`.
    struct ErroringTool;

    #[async_trait]
    impl Tool for ErroringTool {
        fn name(&self) -> &str {
            "erroring"
        }

        fn description(&self) -> &str {
            "A tool whose execution returns an error"
        }

        fn parameters(&self) -> serde_json::Value {
            empty_object_schema()
        }

        async fn execute(
            &self,
            _args: &serde_json::Value,
            _ctx: &ToolContext,
        ) -> Result<ToolOutput> {
            anyhow::bail!("tool crashed")
        }
    }

    #[tokio::test]
    async fn test_registry_execute_propagates_tool_error() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        let trace_sink = InMemoryTraceSink::default();
        registry.set_trace_sink(Arc::new(trace_sink.clone()));
        registry.register(Arc::new(ErroringTool));

        let err = registry
            .execute("erroring", &serde_json::json!({}))
            .await
            .err()
            .expect("tool error should propagate");
        assert!(err.to_string().contains("tool crashed"));
        assert!(registry
            .execute_raw("erroring", &serde_json::json!({}))
            .await
            .is_err());

        // Errored executions do not produce trace events.
        assert!(trace_sink.events().is_empty());
    }

    /// Tool whose output is one byte over the maximum output size.
    struct LargeOutputTool;

    #[async_trait]
    impl Tool for LargeOutputTool {
        fn name(&self) -> &str {
            "large_output"
        }

        fn description(&self) -> &str {
            "A tool that returns more than the maximum output size"
        }

        fn parameters(&self) -> serde_json::Value {
            empty_object_schema()
        }

        async fn execute(
            &self,
            _args: &serde_json::Value,
            _ctx: &ToolContext,
        ) -> Result<ToolOutput> {
            Ok(ToolOutput::success(
                "x".repeat(super::super::MAX_OUTPUT_SIZE + 1),
            ))
        }
    }

    #[tokio::test]
    async fn test_registry_truncates_large_tool_output() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        let trace_sink = InMemoryTraceSink::default();
        registry.set_trace_sink(Arc::new(trace_sink.clone()));
        registry.register(Arc::new(LargeOutputTool));

        let result = registry
            .execute("large_output", &serde_json::json!({}))
            .await
            .unwrap();

        assert_eq!(result.exit_code, 0);
        assert!(result.output.contains("[tool output truncated:"));
        assert!(result
            .output
            .contains("Full output artifact: a3s://tool-output/large_output/"));
        assert!(result.output.len() < super::super::MAX_OUTPUT_SIZE + 512);
        let metadata = result.metadata.expect("artifact metadata");
        assert_eq!(
            metadata["artifact"]["original_bytes"],
            serde_json::json!(super::super::MAX_OUTPUT_SIZE + 1)
        );
        assert_eq!(
            metadata["artifact"]["shown_bytes"],
            serde_json::json!(super::super::MAX_OUTPUT_SIZE)
        );
        assert!(metadata["artifact"]["artifact_id"]
            .as_str()
            .unwrap()
            .starts_with("tool-output:large_output:"));
        assert!(metadata["artifact"]["artifact_uri"]
            .as_str()
            .unwrap()
            .starts_with("a3s://tool-output/large_output/"));

        let artifact_uri = metadata["artifact"]["artifact_uri"].as_str().unwrap();
        let artifact = registry
            .get_artifact(artifact_uri)
            .expect("full output artifact");
        assert_eq!(artifact.tool_name, "large_output");
        assert_eq!(artifact.original_bytes, super::super::MAX_OUTPUT_SIZE + 1);
        assert_eq!(artifact.shown_bytes, super::super::MAX_OUTPUT_SIZE);
        assert_eq!(
            artifact.content,
            "x".repeat(super::super::MAX_OUTPUT_SIZE + 1)
        );

        let events = trace_sink.events();
        assert_eq!(events.len(), 1);
        assert_eq!(events[0].artifact_uris, vec![artifact_uri]);
    }

    #[tokio::test]
    async fn test_registry_execute_raw_success() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        registry.register(Arc::new(MockTool {
            name: "raw_test".to_string(),
        }));

        let output = registry
            .execute_raw("raw_test", &serde_json::json!({}))
            .await
            .unwrap();
        assert!(output.is_some());
        let output = output.unwrap();
        assert!(output.success);
        assert_eq!(output.content, "mock output");
    }

    #[tokio::test]
    async fn test_registry_execute_raw_stores_truncated_artifact() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        registry.register(Arc::new(LargeOutputTool));

        let output = registry
            .execute_raw("large_output", &serde_json::json!({}))
            .await
            .unwrap()
            .expect("raw output");

        assert!(output.content.contains("[tool output truncated:"));
        let metadata = output.metadata.expect("artifact metadata");
        let artifact_uri = metadata["artifact"]["artifact_uri"].as_str().unwrap();
        let artifact = registry
            .get_artifact(artifact_uri)
            .expect("full output artifact");
        assert_eq!(artifact.tool_name, "large_output");
        assert_eq!(artifact.content.len(), super::super::MAX_OUTPUT_SIZE + 1);
    }

    #[tokio::test]
    async fn test_registry_execute_raw_unknown() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));

        let output = registry
            .execute_raw("missing", &serde_json::json!({}))
            .await
            .unwrap();
        assert!(output.is_none());
    }

    #[test]
    fn test_registry_list() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        registry.register(Arc::new(MockTool {
            name: "alpha".to_string(),
        }));
        registry.register(Arc::new(MockTool {
            name: "beta".to_string(),
        }));

        let names = registry.list();
        assert_eq!(names.len(), 2);
        assert!(names.contains(&"alpha".to_string()));
        assert!(names.contains(&"beta".to_string()));
    }

    #[test]
    fn test_registry_len_and_is_empty() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        assert!(registry.is_empty());
        assert_eq!(registry.len(), 0);

        registry.register(Arc::new(MockTool {
            name: "t".to_string(),
        }));
        assert!(!registry.is_empty());
        assert_eq!(registry.len(), 1);
    }

    #[test]
    fn test_registry_replace_tool() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        registry.register(Arc::new(MockTool {
            name: "dup".to_string(),
        }));
        registry.register(Arc::new(MockTool {
            name: "dup".to_string(),
        }));
        assert_eq!(registry.len(), 1);
    }
}