1use super::types::{Tool, ToolContext, ToolOutput};
7use super::ToolResult;
8use crate::llm::ToolDefinition;
9use anyhow::Result;
10use std::collections::HashMap;
11use std::path::PathBuf;
12use std::sync::{Arc, RwLock};
13
14pub struct ToolRegistry {
16 tools: RwLock<HashMap<String, Arc<dyn Tool>>>,
17 builtins: RwLock<std::collections::HashSet<String>>,
19 context: ToolContext,
20}
21
22impl ToolRegistry {
23 pub fn new(workspace: PathBuf) -> Self {
25 Self {
26 tools: RwLock::new(HashMap::new()),
27 builtins: RwLock::new(std::collections::HashSet::new()),
28 context: ToolContext::new(workspace),
29 }
30 }
31
32 pub fn register_builtin(&self, tool: Arc<dyn Tool>) {
34 let name = tool.name().to_string();
35 let mut tools = self.tools.write().unwrap();
36 let mut builtins = self.builtins.write().unwrap();
37 tracing::debug!("Registering builtin tool: {}", name);
38 tools.insert(name.clone(), tool);
39 builtins.insert(name);
40 }
41
42 pub fn register(&self, tool: Arc<dyn Tool>) {
47 let name = tool.name().to_string();
48 let builtins = self.builtins.read().unwrap();
49 if builtins.contains(&name) {
50 tracing::warn!(
51 "Rejected registration of tool '{}': cannot shadow builtin",
52 name
53 );
54 return;
55 }
56 drop(builtins);
57 let mut tools = self.tools.write().unwrap();
58 tracing::debug!("Registering tool: {}", name);
59 tools.insert(name, tool);
60 }
61
62 pub fn unregister(&self, name: &str) -> bool {
66 let mut tools = self.tools.write().unwrap();
67 tracing::debug!("Unregistering tool: {}", name);
68 tools.remove(name).is_some()
69 }
70
71 pub fn get(&self, name: &str) -> Option<Arc<dyn Tool>> {
73 let tools = self.tools.read().unwrap();
74 tools.get(name).cloned()
75 }
76
77 pub fn contains(&self, name: &str) -> bool {
79 let tools = self.tools.read().unwrap();
80 tools.contains_key(name)
81 }
82
83 pub fn definitions(&self) -> Vec<ToolDefinition> {
85 let tools = self.tools.read().unwrap();
86 tools
87 .values()
88 .map(|tool| ToolDefinition {
89 name: tool.name().to_string(),
90 description: tool.description().to_string(),
91 parameters: tool.parameters(),
92 })
93 .collect()
94 }
95
96 pub fn list(&self) -> Vec<String> {
98 let tools = self.tools.read().unwrap();
99 tools.keys().cloned().collect()
100 }
101
102 pub fn len(&self) -> usize {
104 let tools = self.tools.read().unwrap();
105 tools.len()
106 }
107
108 pub fn is_empty(&self) -> bool {
110 self.len() == 0
111 }
112
113 pub fn context(&self) -> &ToolContext {
115 &self.context
116 }
117
118 pub async fn execute(&self, name: &str, args: &serde_json::Value) -> Result<ToolResult> {
120 self.execute_with_context(name, args, &self.context).await
121 }
122
123 pub async fn execute_with_context(
125 &self,
126 name: &str,
127 args: &serde_json::Value,
128 ctx: &ToolContext,
129 ) -> Result<ToolResult> {
130 let span = tracing::info_span!(
131 "a3s.tool.execute",
132 "a3s.tool.name" = %name,
133 "a3s.tool.success" = tracing::field::Empty,
134 "a3s.tool.exit_code" = tracing::field::Empty,
135 "a3s.tool.duration_ms" = tracing::field::Empty,
136 );
137 let _guard = span.enter();
138 let start = std::time::Instant::now();
139
140 let tool = self.get(name);
141
142 let result = match tool {
143 Some(tool) => {
144 let output = tool.execute(args, ctx).await?;
145 Ok(ToolResult {
146 name: name.to_string(),
147 output: output.content,
148 exit_code: if output.success { 0 } else { 1 },
149 metadata: output.metadata,
150 })
151 }
152 None => Ok(ToolResult::error(name, format!("Unknown tool: {}", name))),
153 };
154
155 if let Ok(ref r) = result {
156 crate::telemetry::record_tool_result(r.exit_code, start.elapsed());
157 }
158
159 result
160 }
161
162 pub async fn execute_raw(
164 &self,
165 name: &str,
166 args: &serde_json::Value,
167 ) -> Result<Option<ToolOutput>> {
168 self.execute_raw_with_context(name, args, &self.context)
169 .await
170 }
171
172 pub async fn execute_raw_with_context(
174 &self,
175 name: &str,
176 args: &serde_json::Value,
177 ctx: &ToolContext,
178 ) -> Result<Option<ToolOutput>> {
179 let tool = self.get(name);
180
181 match tool {
182 Some(tool) => {
183 let output = tool.execute(args, ctx).await?;
184 Ok(Some(output))
185 }
186 None => Ok(None),
187 }
188 }
189}
190
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;

    /// Minimal tool that always succeeds with `"mock output"`.
    struct MockTool {
        name: String,
    }

    #[async_trait]
    impl Tool for MockTool {
        fn name(&self) -> &str {
            &self.name
        }

        fn description(&self) -> &str {
            "A mock tool for testing"
        }

        fn parameters(&self) -> serde_json::Value {
            serde_json::json!({
                "type": "object",
                "properties": {},
                "required": []
            })
        }

        async fn execute(
            &self,
            _args: &serde_json::Value,
            _ctx: &ToolContext,
        ) -> Result<ToolOutput> {
            Ok(ToolOutput::success("mock output"))
        }
    }

    #[test]
    fn test_registry_register_and_get() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));

        let tool = Arc::new(MockTool {
            name: "test".to_string(),
        });
        registry.register(tool);

        assert!(registry.contains("test"));
        assert!(!registry.contains("nonexistent"));

        let retrieved = registry.get("test");
        assert!(retrieved.is_some());
        assert_eq!(retrieved.unwrap().name(), "test");
    }

    #[test]
    fn test_registry_unregister() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));

        let tool = Arc::new(MockTool {
            name: "test".to_string(),
        });
        registry.register(tool);

        assert!(registry.contains("test"));
        assert!(registry.unregister("test"));
        assert!(!registry.contains("test"));
        assert!(!registry.unregister("test"));
    }

    #[test]
    fn test_registry_definitions() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));

        registry.register(Arc::new(MockTool {
            name: "tool1".to_string(),
        }));
        registry.register(Arc::new(MockTool {
            name: "tool2".to_string(),
        }));

        let definitions = registry.definitions();
        assert_eq!(definitions.len(), 2);

        // Compare order-independently: only the set of names is checked here.
        let mut names: Vec<String> = definitions.iter().map(|d| d.name.clone()).collect();
        names.sort();
        assert_eq!(names, vec!["tool1", "tool2"]);
    }

    #[tokio::test]
    async fn test_registry_execute() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));

        registry.register(Arc::new(MockTool {
            name: "test".to_string(),
        }));

        let result = registry
            .execute("test", &serde_json::json!({}))
            .await
            .unwrap();
        assert_eq!(result.exit_code, 0);
        assert_eq!(result.output, "mock output");
    }

    #[tokio::test]
    async fn test_registry_execute_unknown() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));

        let result = registry
            .execute("unknown", &serde_json::json!({}))
            .await
            .unwrap();
        assert_eq!(result.exit_code, 1);
        assert!(result.output.contains("Unknown tool"));
    }

    #[tokio::test]
    async fn test_registry_execute_with_context_success() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        let ctx = ToolContext::new(PathBuf::from("/tmp"));

        registry.register(Arc::new(MockTool {
            name: "my_tool".to_string(),
        }));

        let result = registry
            .execute_with_context("my_tool", &serde_json::json!({}), &ctx)
            .await
            .unwrap();
        assert_eq!(result.name, "my_tool");
        assert_eq!(result.exit_code, 0);
        assert_eq!(result.output, "mock output");
    }

    #[tokio::test]
    async fn test_registry_execute_with_context_unknown_tool() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        let ctx = ToolContext::new(PathBuf::from("/tmp"));

        let result = registry
            .execute_with_context("nonexistent", &serde_json::json!({}), &ctx)
            .await
            .unwrap();
        assert_eq!(result.exit_code, 1);
        assert!(result.output.contains("Unknown tool: nonexistent"));
    }

    /// Tool named `"failing"` whose execution reports failure (not an `Err`).
    struct FailingTool;

    #[async_trait]
    impl Tool for FailingTool {
        fn name(&self) -> &str {
            "failing"
        }

        fn description(&self) -> &str {
            "A tool that returns failure"
        }

        fn parameters(&self) -> serde_json::Value {
            serde_json::json!({"type": "object", "properties": {}, "required": []})
        }

        async fn execute(
            &self,
            _args: &serde_json::Value,
            _ctx: &ToolContext,
        ) -> Result<ToolOutput> {
            Ok(ToolOutput::error("something went wrong"))
        }
    }

    #[tokio::test]
    async fn test_registry_execute_failing_tool() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        registry.register(Arc::new(FailingTool));

        let result = registry
            .execute("failing", &serde_json::json!({}))
            .await
            .unwrap();
        assert_eq!(result.exit_code, 1);
        assert_eq!(result.output, "something went wrong");
    }

    #[tokio::test]
    async fn test_registry_execute_raw_success() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        registry.register(Arc::new(MockTool {
            name: "raw_test".to_string(),
        }));

        let output = registry
            .execute_raw("raw_test", &serde_json::json!({}))
            .await
            .unwrap();
        assert!(output.is_some());
        let output = output.unwrap();
        assert!(output.success);
        assert_eq!(output.content, "mock output");
    }

    #[tokio::test]
    async fn test_registry_execute_raw_unknown() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));

        let output = registry
            .execute_raw("missing", &serde_json::json!({}))
            .await
            .unwrap();
        assert!(output.is_none());
    }

    #[test]
    fn test_registry_list() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        registry.register(Arc::new(MockTool {
            name: "alpha".to_string(),
        }));
        registry.register(Arc::new(MockTool {
            name: "beta".to_string(),
        }));

        let names = registry.list();
        assert_eq!(names.len(), 2);
        assert!(names.contains(&"alpha".to_string()));
        assert!(names.contains(&"beta".to_string()));
    }

    #[test]
    fn test_registry_len_and_is_empty() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        assert!(registry.is_empty());
        assert_eq!(registry.len(), 0);

        registry.register(Arc::new(MockTool {
            name: "t".to_string(),
        }));
        assert!(!registry.is_empty());
        assert_eq!(registry.len(), 1);
    }

    #[test]
    fn test_registry_replace_tool() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        registry.register(Arc::new(MockTool {
            name: "dup".to_string(),
        }));
        registry.register(Arc::new(MockTool {
            name: "dup".to_string(),
        }));
        assert_eq!(registry.len(), 1);
    }

    #[test]
    fn test_register_builtin_is_visible() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        registry.register_builtin(Arc::new(MockTool {
            name: "builtin".to_string(),
        }));

        assert!(registry.contains("builtin"));
        assert_eq!(registry.len(), 1);
    }

    #[tokio::test]
    async fn test_register_cannot_shadow_builtin() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        registry.register_builtin(Arc::new(FailingTool));
        // Same name as the builtin: must be rejected.
        registry.register(Arc::new(MockTool {
            name: "failing".to_string(),
        }));

        assert_eq!(registry.len(), 1);
        // The builtin (which reports failure) must still be the one that runs.
        let result = registry
            .execute("failing", &serde_json::json!({}))
            .await
            .unwrap();
        assert_eq!(result.exit_code, 1);
        assert_eq!(result.output, "something went wrong");
    }

    #[tokio::test]
    async fn test_register_builtin_replaces_user_tool() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        registry.register(Arc::new(MockTool {
            name: "failing".to_string(),
        }));
        registry.register_builtin(Arc::new(FailingTool));

        assert_eq!(registry.len(), 1);
        let result = registry
            .execute("failing", &serde_json::json!({}))
            .await
            .unwrap();
        assert_eq!(result.exit_code, 1);
        assert_eq!(result.output, "something went wrong");
    }
}