1use super::artifacts::{ArtifactStore, ArtifactStoreLimits, ToolArtifact};
7use super::types::{Tool, ToolContext, ToolOutput};
8use super::ToolResult;
9use super::{
10 merge_tool_output_artifact_metadata, truncate_tool_output_with_artifact, ToolOutputArtifact,
11};
12use crate::llm::ToolDefinition;
13use crate::trace::{InMemoryTraceSink, TraceEvent, TraceSink};
14use anyhow::Result;
15use std::collections::HashMap;
16use std::path::PathBuf;
17use std::sync::{Arc, RwLock};
18
19pub struct ToolRegistry {
21 tools: RwLock<HashMap<String, Arc<dyn Tool>>>,
22 builtins: RwLock<std::collections::HashSet<String>>,
24 context: RwLock<ToolContext>,
25 artifact_store: ArtifactStore,
26 trace_sink: RwLock<Arc<dyn TraceSink>>,
27}
28
29impl ToolRegistry {
30 pub fn new(workspace: PathBuf) -> Self {
32 Self::with_artifact_limits(workspace, ArtifactStoreLimits::default())
33 }
34
35 pub fn with_artifact_limits(workspace: PathBuf, artifact_limits: ArtifactStoreLimits) -> Self {
37 Self::with_artifact_limits_and_workspace_services(
38 workspace.clone(),
39 artifact_limits,
40 crate::workspace::WorkspaceServices::local(workspace),
41 )
42 }
43
44 pub fn with_artifact_limits_and_workspace_services(
46 workspace: PathBuf,
47 artifact_limits: ArtifactStoreLimits,
48 workspace_services: Arc<crate::workspace::WorkspaceServices>,
49 ) -> Self {
50 let context = ToolContext::new(workspace).with_workspace_services(workspace_services);
51 Self {
52 tools: RwLock::new(HashMap::new()),
53 builtins: RwLock::new(std::collections::HashSet::new()),
54 context: RwLock::new(context),
55 artifact_store: ArtifactStore::with_limits(artifact_limits),
56 trace_sink: RwLock::new(Arc::new(InMemoryTraceSink::default())),
57 }
58 }
59
60 pub fn register_builtin(&self, tool: Arc<dyn Tool>) {
62 let name = tool.name().to_string();
63 let mut tools = self.tools.write().unwrap();
64 let mut builtins = self.builtins.write().unwrap();
65 tracing::debug!("Registering builtin tool: {}", name);
66 tools.insert(name.clone(), tool);
67 builtins.insert(name);
68 }
69
70 pub fn register(&self, tool: Arc<dyn Tool>) {
75 let name = tool.name().to_string();
76 let builtins = self.builtins.read().unwrap();
77 if builtins.contains(&name) {
78 tracing::warn!(
79 "Rejected registration of tool '{}': cannot shadow builtin",
80 name
81 );
82 return;
83 }
84 drop(builtins);
85 let mut tools = self.tools.write().unwrap();
86 tracing::debug!("Registering tool: {}", name);
87 tools.insert(name, tool);
88 }
89
90 pub fn unregister(&self, name: &str) -> bool {
94 let mut tools = self.tools.write().unwrap();
95 tracing::debug!("Unregistering tool: {}", name);
96 tools.remove(name).is_some()
97 }
98
99 pub fn unregister_by_prefix(&self, prefix: &str) {
101 let mut tools = self.tools.write().unwrap();
102 tools.retain(|name, _| !name.starts_with(prefix));
103 tracing::debug!("Unregistered tools with prefix: {}", prefix);
104 }
105
106 pub fn get(&self, name: &str) -> Option<Arc<dyn Tool>> {
108 let tools = self.tools.read().unwrap();
109 tools.get(name).cloned()
110 }
111
112 pub fn contains(&self, name: &str) -> bool {
114 let tools = self.tools.read().unwrap();
115 tools.contains_key(name)
116 }
117
118 pub fn definitions(&self) -> Vec<ToolDefinition> {
120 let tools = self.tools.read().unwrap();
121 tools
122 .values()
123 .map(|tool| ToolDefinition {
124 name: tool.name().to_string(),
125 description: tool.description().to_string(),
126 parameters: tool.parameters(),
127 })
128 .collect()
129 }
130
131 pub fn list(&self) -> Vec<String> {
133 let tools = self.tools.read().unwrap();
134 tools.keys().cloned().collect()
135 }
136
137 pub fn len(&self) -> usize {
139 let tools = self.tools.read().unwrap();
140 tools.len()
141 }
142
143 pub fn is_empty(&self) -> bool {
145 self.len() == 0
146 }
147
148 pub fn context(&self) -> ToolContext {
150 self.context.read().unwrap().clone()
151 }
152
153 pub fn artifact_store(&self) -> ArtifactStore {
155 self.artifact_store.clone()
156 }
157
158 pub fn get_artifact(&self, artifact_uri: &str) -> Option<ToolArtifact> {
160 self.artifact_store.get(artifact_uri)
161 }
162
163 pub fn set_trace_sink(&self, sink: Arc<dyn TraceSink>) {
165 *self.trace_sink.write().unwrap() = sink;
166 }
167
168 pub fn trace_sink(&self) -> Arc<dyn TraceSink> {
170 Arc::clone(&self.trace_sink.read().unwrap())
171 }
172
173 pub fn set_search_config(&self, config: crate::config::SearchConfig) {
175 let mut ctx = self.context.write().unwrap();
176 *ctx = ctx.clone().with_search_config(config);
177 }
178
179 pub fn set_sandbox(&self, sandbox: std::sync::Arc<dyn crate::sandbox::BashSandbox>) {
182 let mut ctx = self.context.write().unwrap();
183 *ctx = ctx.clone().with_sandbox(sandbox);
184 }
185
186 pub fn set_command_env(&self, env: Arc<HashMap<String, String>>) {
189 let mut ctx = self.context.write().unwrap();
190 *ctx = ctx.clone().with_command_env(env);
191 }
192
193 pub async fn execute(&self, name: &str, args: &serde_json::Value) -> Result<ToolResult> {
195 let ctx = self.context();
196 self.execute_with_context(name, args, &ctx).await
197 }
198
199 pub async fn execute_with_context(
201 &self,
202 name: &str,
203 args: &serde_json::Value,
204 ctx: &ToolContext,
205 ) -> Result<ToolResult> {
206 let start = std::time::Instant::now();
207
208 let tool = self.get(name);
209
210 let result = match tool {
211 Some(tool) => {
212 let mut output = tool.execute(args, ctx).await?;
213 let original_content = output.content.clone();
214 let truncated = truncate_tool_output_with_artifact(name, &output.content);
215 output.content = truncated.content;
216 if let Some(artifact) = truncated.artifact {
217 self.store_tool_artifact(name, &original_content, &artifact);
218 output.metadata = Some(merge_tool_output_artifact_metadata(
219 output.metadata,
220 &artifact,
221 ));
222 }
223 Ok(ToolResult {
224 name: name.to_string(),
225 output: output.content,
226 exit_code: if output.success { 0 } else { 1 },
227 metadata: output.metadata,
228 images: output.images,
229 })
230 }
231 None => Ok(ToolResult::error(name, format!("Unknown tool: {}", name))),
232 };
233
234 if let Ok(ref r) = result {
235 crate::telemetry::record_tool_result(r.exit_code, start.elapsed());
236 self.record_trace_event(name, r, start.elapsed());
237 }
238
239 result
240 }
241
242 pub async fn execute_raw(
244 &self,
245 name: &str,
246 args: &serde_json::Value,
247 ) -> Result<Option<ToolOutput>> {
248 let ctx = self.context();
249 self.execute_raw_with_context(name, args, &ctx).await
250 }
251
252 pub async fn execute_raw_with_context(
254 &self,
255 name: &str,
256 args: &serde_json::Value,
257 ctx: &ToolContext,
258 ) -> Result<Option<ToolOutput>> {
259 let tool = self.get(name);
260
261 match tool {
262 Some(tool) => {
263 let mut output = tool.execute(args, ctx).await?;
264 let original_content = output.content.clone();
265 let truncated = truncate_tool_output_with_artifact(name, &output.content);
266 output.content = truncated.content;
267 if let Some(artifact) = truncated.artifact {
268 self.store_tool_artifact(name, &original_content, &artifact);
269 output.metadata = Some(merge_tool_output_artifact_metadata(
270 output.metadata,
271 &artifact,
272 ));
273 }
274 Ok(Some(output))
275 }
276 None => Ok(None),
277 }
278 }
279
280 fn store_tool_artifact(&self, tool_name: &str, content: &str, artifact: &ToolOutputArtifact) {
281 self.artifact_store.put(ToolArtifact {
282 artifact_id: artifact.artifact_id.clone(),
283 artifact_uri: artifact.artifact_uri.clone(),
284 tool_name: tool_name.to_string(),
285 content: content.to_string(),
286 original_bytes: artifact.original_bytes,
287 shown_bytes: artifact.shown_bytes,
288 });
289 }
290
291 fn record_trace_event(&self, name: &str, result: &ToolResult, duration: std::time::Duration) {
292 let sink = self.trace_sink();
293 sink.record(TraceEvent::tool_execution(
294 name,
295 result.exit_code == 0,
296 result.exit_code,
297 duration,
298 result.output.len(),
299 result.metadata.as_ref(),
300 ));
301
302 if name == "program" {
303 sink.record(TraceEvent::program_execution(
304 name,
305 result.exit_code == 0,
306 result.exit_code,
307 duration,
308 result.output.len(),
309 result.metadata.as_ref(),
310 ));
311 }
312 }
313}
314
#[cfg(test)]
mod tests {
    // Unit tests for `ToolRegistry`: registration rules, lookup, execution,
    // truncation/artifact storage and trace recording.
    use super::*;
    use crate::trace::{InMemoryTraceSink, TraceEventKind};
    use async_trait::async_trait;

    /// Tool with a configurable name that always succeeds with `"mock output"`.
    struct MockTool {
        name: String,
    }

    #[async_trait]
    impl Tool for MockTool {
        fn name(&self) -> &str {
            &self.name
        }

        fn description(&self) -> &str {
            "A mock tool for testing"
        }

        fn parameters(&self) -> serde_json::Value {
            serde_json::json!({
                "type": "object",
                "additionalProperties": false,
                "properties": {},
                "required": []
            })
        }

        async fn execute(
            &self,
            _args: &serde_json::Value,
            _ctx: &ToolContext,
        ) -> Result<ToolOutput> {
            Ok(ToolOutput::success("mock output"))
        }
    }

    #[test]
    fn test_registry_register_and_get() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));

        let tool = Arc::new(MockTool {
            name: "test".to_string(),
        });
        registry.register(tool);

        assert!(registry.contains("test"));
        assert!(!registry.contains("nonexistent"));

        let retrieved = registry.get("test");
        assert!(retrieved.is_some());
        assert_eq!(retrieved.unwrap().name(), "test");
    }

    #[test]
    fn test_registry_unregister() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));

        let tool = Arc::new(MockTool {
            name: "test".to_string(),
        });
        registry.register(tool);

        assert!(registry.contains("test"));
        assert!(registry.unregister("test"));
        assert!(!registry.contains("test"));
        // A second removal reports that nothing was removed.
        assert!(!registry.unregister("test"));
    }

    #[test]
    fn test_registry_definitions() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));

        registry.register(Arc::new(MockTool {
            name: "tool1".to_string(),
        }));
        registry.register(Arc::new(MockTool {
            name: "tool2".to_string(),
        }));

        let definitions = registry.definitions();
        assert_eq!(definitions.len(), 2);
    }

    #[tokio::test]
    async fn test_registry_execute() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));

        registry.register(Arc::new(MockTool {
            name: "test".to_string(),
        }));

        let result = registry
            .execute("test", &serde_json::json!({}))
            .await
            .unwrap();
        assert_eq!(result.exit_code, 0);
        assert_eq!(result.output, "mock output");
    }

    #[tokio::test]
    async fn test_registry_execute_unknown() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));

        let result = registry
            .execute("unknown", &serde_json::json!({}))
            .await
            .unwrap();
        assert_eq!(result.exit_code, 1);
        assert!(result.output.contains("Unknown tool"));
    }

    #[tokio::test]
    async fn test_registry_execute_with_context_success() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        let ctx = ToolContext::new(PathBuf::from("/tmp"));
        let trace_sink = InMemoryTraceSink::default();
        registry.set_trace_sink(Arc::new(trace_sink.clone()));

        registry.register(Arc::new(MockTool {
            name: "my_tool".to_string(),
        }));

        let result = registry
            .execute_with_context("my_tool", &serde_json::json!({}), &ctx)
            .await
            .unwrap();
        assert_eq!(result.name, "my_tool");
        assert_eq!(result.exit_code, 0);
        assert_eq!(result.output, "mock output");

        let events = trace_sink.events();
        assert_eq!(events.len(), 1);
        assert_eq!(events[0].kind, TraceEventKind::ToolExecution);
        assert_eq!(events[0].name, "my_tool");
        assert!(events[0].success);
        assert_eq!(events[0].output_bytes, "mock output".len());
    }

    #[tokio::test]
    async fn test_registry_execute_with_context_unknown_tool() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        let ctx = ToolContext::new(PathBuf::from("/tmp"));

        let result = registry
            .execute_with_context("nonexistent", &serde_json::json!({}), &ctx)
            .await
            .unwrap();
        assert_eq!(result.exit_code, 1);
        assert!(result.output.contains("Unknown tool: nonexistent"));
    }

    /// Tool named `"failing"` that completes but reports failure.
    struct FailingTool;

    #[async_trait]
    impl Tool for FailingTool {
        fn name(&self) -> &str {
            "failing"
        }

        fn description(&self) -> &str {
            "A tool that returns failure"
        }

        fn parameters(&self) -> serde_json::Value {
            serde_json::json!({
                "type": "object",
                "additionalProperties": false,
                "properties": {},
                "required": []
            })
        }

        async fn execute(
            &self,
            _args: &serde_json::Value,
            _ctx: &ToolContext,
        ) -> Result<ToolOutput> {
            Ok(ToolOutput::error("something went wrong"))
        }
    }

    #[tokio::test]
    async fn test_registry_execute_failing_tool() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        registry.register(Arc::new(FailingTool));

        let result = registry
            .execute("failing", &serde_json::json!({}))
            .await
            .unwrap();
        assert_eq!(result.exit_code, 1);
        assert_eq!(result.output, "something went wrong");
    }

    /// Tool whose output is one byte longer than the truncation limit.
    struct LargeOutputTool;

    #[async_trait]
    impl Tool for LargeOutputTool {
        fn name(&self) -> &str {
            "large_output"
        }

        fn description(&self) -> &str {
            "A tool that returns more than the maximum output size"
        }

        fn parameters(&self) -> serde_json::Value {
            serde_json::json!({
                "type": "object",
                "additionalProperties": false,
                "properties": {},
                "required": []
            })
        }

        async fn execute(
            &self,
            _args: &serde_json::Value,
            _ctx: &ToolContext,
        ) -> Result<ToolOutput> {
            Ok(ToolOutput::success(
                "x".repeat(super::super::MAX_OUTPUT_SIZE + 1),
            ))
        }
    }

    #[tokio::test]
    async fn test_registry_truncates_large_tool_output() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        let trace_sink = InMemoryTraceSink::default();
        registry.set_trace_sink(Arc::new(trace_sink.clone()));
        registry.register(Arc::new(LargeOutputTool));

        let result = registry
            .execute("large_output", &serde_json::json!({}))
            .await
            .unwrap();

        assert_eq!(result.exit_code, 0);
        assert!(result.output.contains("[tool output truncated:"));
        assert!(result
            .output
            .contains("Full output artifact: a3s://tool-output/large_output/"));
        assert!(result.output.len() < super::super::MAX_OUTPUT_SIZE + 512);
        let metadata = result.metadata.expect("artifact metadata");
        assert_eq!(
            metadata["artifact"]["original_bytes"],
            serde_json::json!(super::super::MAX_OUTPUT_SIZE + 1)
        );
        assert_eq!(
            metadata["artifact"]["shown_bytes"],
            serde_json::json!(super::super::MAX_OUTPUT_SIZE)
        );
        assert!(metadata["artifact"]["artifact_id"]
            .as_str()
            .unwrap()
            .starts_with("tool-output:large_output:"));
        assert!(metadata["artifact"]["artifact_uri"]
            .as_str()
            .unwrap()
            .starts_with("a3s://tool-output/large_output/"));

        let artifact_uri = metadata["artifact"]["artifact_uri"].as_str().unwrap();
        let artifact = registry
            .get_artifact(artifact_uri)
            .expect("full output artifact");
        assert_eq!(artifact.tool_name, "large_output");
        assert_eq!(artifact.original_bytes, super::super::MAX_OUTPUT_SIZE + 1);
        assert_eq!(artifact.shown_bytes, super::super::MAX_OUTPUT_SIZE);
        assert_eq!(
            artifact.content,
            "x".repeat(super::super::MAX_OUTPUT_SIZE + 1)
        );

        let events = trace_sink.events();
        assert_eq!(events.len(), 1);
        assert_eq!(events[0].artifact_uris, vec![artifact_uri]);
    }

    #[tokio::test]
    async fn test_registry_execute_raw_success() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        registry.register(Arc::new(MockTool {
            name: "raw_test".to_string(),
        }));

        let output = registry
            .execute_raw("raw_test", &serde_json::json!({}))
            .await
            .unwrap();
        assert!(output.is_some());
        let output = output.unwrap();
        assert!(output.success);
        assert_eq!(output.content, "mock output");
    }

    #[tokio::test]
    async fn test_registry_execute_raw_stores_truncated_artifact() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        registry.register(Arc::new(LargeOutputTool));

        let output = registry
            .execute_raw("large_output", &serde_json::json!({}))
            .await
            .unwrap()
            .expect("raw output");

        assert!(output.content.contains("[tool output truncated:"));
        let metadata = output.metadata.expect("artifact metadata");
        let artifact_uri = metadata["artifact"]["artifact_uri"].as_str().unwrap();
        let artifact = registry
            .get_artifact(artifact_uri)
            .expect("full output artifact");
        assert_eq!(artifact.tool_name, "large_output");
        assert_eq!(artifact.content.len(), super::super::MAX_OUTPUT_SIZE + 1);
    }

    #[tokio::test]
    async fn test_registry_execute_raw_unknown() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));

        let output = registry
            .execute_raw("missing", &serde_json::json!({}))
            .await
            .unwrap();
        assert!(output.is_none());
    }

    #[test]
    fn test_registry_list() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        registry.register(Arc::new(MockTool {
            name: "alpha".to_string(),
        }));
        registry.register(Arc::new(MockTool {
            name: "beta".to_string(),
        }));

        let names = registry.list();
        assert_eq!(names.len(), 2);
        assert!(names.contains(&"alpha".to_string()));
        assert!(names.contains(&"beta".to_string()));
    }

    #[test]
    fn test_registry_len_and_is_empty() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        assert!(registry.is_empty());
        assert_eq!(registry.len(), 0);

        registry.register(Arc::new(MockTool {
            name: "t".to_string(),
        }));
        assert!(!registry.is_empty());
        assert_eq!(registry.len(), 1);
    }

    #[test]
    fn test_registry_replace_tool() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        registry.register(Arc::new(MockTool {
            name: "dup".to_string(),
        }));
        registry.register(Arc::new(MockTool {
            name: "dup".to_string(),
        }));
        assert_eq!(registry.len(), 1);
    }

    #[test]
    fn test_registry_register_cannot_shadow_builtin() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        registry.register_builtin(Arc::new(MockTool {
            name: "failing".to_string(),
        }));

        // `FailingTool` is also named "failing"; registering it must be rejected.
        registry.register(Arc::new(FailingTool));

        assert_eq!(registry.len(), 1);
        assert_eq!(
            registry.get("failing").unwrap().description(),
            "A mock tool for testing"
        );
    }

    #[test]
    fn test_registry_register_builtin_replaces_plain_tool() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        registry.register(Arc::new(FailingTool));
        registry.register_builtin(Arc::new(MockTool {
            name: "failing".to_string(),
        }));

        assert_eq!(registry.len(), 1);
        assert_eq!(
            registry.get("failing").unwrap().description(),
            "A mock tool for testing"
        );
    }

    #[test]
    fn test_registry_unregister_by_prefix() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        for name in ["mcp__a", "mcp__b", "local"] {
            registry.register(Arc::new(MockTool {
                name: name.to_string(),
            }));
        }

        registry.unregister_by_prefix("mcp__");

        assert_eq!(registry.list(), vec!["local".to_string()]);
    }

    #[tokio::test]
    async fn test_registry_program_tool_records_extra_trace_event() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        let trace_sink = InMemoryTraceSink::default();
        registry.set_trace_sink(Arc::new(trace_sink.clone()));
        registry.register(Arc::new(MockTool {
            name: "program".to_string(),
        }));

        let result = registry
            .execute("program", &serde_json::json!({}))
            .await
            .unwrap();
        assert_eq!(result.exit_code, 0);

        // One generic tool-execution event plus one program-execution event.
        let events = trace_sink.events();
        assert_eq!(events.len(), 2);
        assert_eq!(events[0].kind, TraceEventKind::ToolExecution);
    }
}