1use super::types::{Tool, ToolContext, ToolOutput};
7use super::ToolResult;
8use crate::llm::ToolDefinition;
9use anyhow::Result;
10use std::collections::HashMap;
11use std::path::PathBuf;
12use std::sync::{Arc, RwLock};
13
14pub struct ToolRegistry {
16 tools: RwLock<HashMap<String, Arc<dyn Tool>>>,
17 builtins: RwLock<std::collections::HashSet<String>>,
19 context: ToolContext,
20}
21
22impl ToolRegistry {
23 pub fn new(workspace: PathBuf) -> Self {
25 Self {
26 tools: RwLock::new(HashMap::new()),
27 builtins: RwLock::new(std::collections::HashSet::new()),
28 context: ToolContext::new(workspace),
29 }
30 }
31
32 pub fn register_builtin(&self, tool: Arc<dyn Tool>) {
34 let name = tool.name().to_string();
35 let mut tools = self.tools.write().unwrap();
36 let mut builtins = self.builtins.write().unwrap();
37 tracing::debug!("Registering builtin tool: {}", name);
38 tools.insert(name.clone(), tool);
39 builtins.insert(name);
40 }
41
42 pub fn register(&self, tool: Arc<dyn Tool>) {
47 let name = tool.name().to_string();
48 let builtins = self.builtins.read().unwrap();
49 if builtins.contains(&name) {
50 tracing::warn!(
51 "Rejected registration of tool '{}': cannot shadow builtin",
52 name
53 );
54 return;
55 }
56 drop(builtins);
57 let mut tools = self.tools.write().unwrap();
58 tracing::debug!("Registering tool: {}", name);
59 tools.insert(name, tool);
60 }
61
62 pub fn unregister(&self, name: &str) -> bool {
66 let mut tools = self.tools.write().unwrap();
67 tracing::debug!("Unregistering tool: {}", name);
68 tools.remove(name).is_some()
69 }
70
71 pub fn get(&self, name: &str) -> Option<Arc<dyn Tool>> {
73 let tools = self.tools.read().unwrap();
74 tools.get(name).cloned()
75 }
76
77 pub fn contains(&self, name: &str) -> bool {
79 let tools = self.tools.read().unwrap();
80 tools.contains_key(name)
81 }
82
83 pub fn definitions(&self) -> Vec<ToolDefinition> {
85 let tools = self.tools.read().unwrap();
86 tools
87 .values()
88 .map(|tool| ToolDefinition {
89 name: tool.name().to_string(),
90 description: tool.description().to_string(),
91 parameters: tool.parameters(),
92 })
93 .collect()
94 }
95
96 pub fn list(&self) -> Vec<String> {
98 let tools = self.tools.read().unwrap();
99 tools.keys().cloned().collect()
100 }
101
102 pub fn len(&self) -> usize {
104 let tools = self.tools.read().unwrap();
105 tools.len()
106 }
107
108 pub fn is_empty(&self) -> bool {
110 self.len() == 0
111 }
112
113 pub fn context(&self) -> &ToolContext {
115 &self.context
116 }
117
118 pub async fn execute(&self, name: &str, args: &serde_json::Value) -> Result<ToolResult> {
120 self.execute_with_context(name, args, &self.context).await
121 }
122
123 pub async fn execute_with_context(
125 &self,
126 name: &str,
127 args: &serde_json::Value,
128 ctx: &ToolContext,
129 ) -> Result<ToolResult> {
130 let start = std::time::Instant::now();
131
132 let tool = self.get(name);
133
134 let result = match tool {
135 Some(tool) => {
136 let output = tool.execute(args, ctx).await?;
137 Ok(ToolResult {
138 name: name.to_string(),
139 output: output.content,
140 exit_code: if output.success { 0 } else { 1 },
141 metadata: output.metadata,
142 })
143 }
144 None => Ok(ToolResult::error(name, format!("Unknown tool: {}", name))),
145 };
146
147 if let Ok(ref r) = result {
148 crate::telemetry::record_tool_result(r.exit_code, start.elapsed());
149 }
150
151 result
152 }
153
154 pub async fn execute_raw(
156 &self,
157 name: &str,
158 args: &serde_json::Value,
159 ) -> Result<Option<ToolOutput>> {
160 self.execute_raw_with_context(name, args, &self.context)
161 .await
162 }
163
164 pub async fn execute_raw_with_context(
166 &self,
167 name: &str,
168 args: &serde_json::Value,
169 ctx: &ToolContext,
170 ) -> Result<Option<ToolOutput>> {
171 let tool = self.get(name);
172
173 match tool {
174 Some(tool) => {
175 let output = tool.execute(args, ctx).await?;
176 Ok(Some(output))
177 }
178 None => Ok(None),
179 }
180 }
181}
182
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;

    /// Minimal tool with a configurable name that always succeeds with "mock output".
    struct MockTool {
        name: String,
    }

    #[async_trait]
    impl Tool for MockTool {
        fn name(&self) -> &str {
            &self.name
        }

        fn description(&self) -> &str {
            "A mock tool for testing"
        }

        fn parameters(&self) -> serde_json::Value {
            serde_json::json!({
                "type": "object",
                "properties": {},
                "required": []
            })
        }

        async fn execute(
            &self,
            _args: &serde_json::Value,
            _ctx: &ToolContext,
        ) -> Result<ToolOutput> {
            Ok(ToolOutput::success("mock output"))
        }
    }

    #[test]
    fn test_registry_register_and_get() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));

        let tool = Arc::new(MockTool {
            name: "test".to_string(),
        });
        registry.register(tool);

        assert!(registry.contains("test"));
        assert!(!registry.contains("nonexistent"));

        let retrieved = registry.get("test");
        assert!(retrieved.is_some());
        assert_eq!(retrieved.unwrap().name(), "test");
    }

    #[test]
    fn test_registry_unregister() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));

        let tool = Arc::new(MockTool {
            name: "test".to_string(),
        });
        registry.register(tool);

        assert!(registry.contains("test"));
        assert!(registry.unregister("test"));
        assert!(!registry.contains("test"));
        // A second removal finds nothing.
        assert!(!registry.unregister("test"));
    }

    #[test]
    fn test_registry_definitions() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));

        registry.register(Arc::new(MockTool {
            name: "tool1".to_string(),
        }));
        registry.register(Arc::new(MockTool {
            name: "tool2".to_string(),
        }));

        let definitions = registry.definitions();
        assert_eq!(definitions.len(), 2);
        // Order-agnostic check: both tools are described.
        assert!(definitions.iter().any(|d| d.name == "tool1"));
        assert!(definitions.iter().any(|d| d.name == "tool2"));
    }

    #[tokio::test]
    async fn test_registry_execute() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));

        registry.register(Arc::new(MockTool {
            name: "test".to_string(),
        }));

        let result = registry
            .execute("test", &serde_json::json!({}))
            .await
            .unwrap();
        assert_eq!(result.exit_code, 0);
        assert_eq!(result.output, "mock output");
    }

    #[tokio::test]
    async fn test_registry_execute_unknown() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));

        let result = registry
            .execute("unknown", &serde_json::json!({}))
            .await
            .unwrap();
        assert_eq!(result.exit_code, 1);
        assert!(result.output.contains("Unknown tool"));
    }

    #[tokio::test]
    async fn test_registry_execute_with_context_success() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        let ctx = ToolContext::new(PathBuf::from("/tmp"));

        registry.register(Arc::new(MockTool {
            name: "my_tool".to_string(),
        }));

        let result = registry
            .execute_with_context("my_tool", &serde_json::json!({}), &ctx)
            .await
            .unwrap();
        assert_eq!(result.name, "my_tool");
        assert_eq!(result.exit_code, 0);
        assert_eq!(result.output, "mock output");
    }

    #[tokio::test]
    async fn test_registry_execute_with_context_unknown_tool() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        let ctx = ToolContext::new(PathBuf::from("/tmp"));

        let result = registry
            .execute_with_context("nonexistent", &serde_json::json!({}), &ctx)
            .await
            .unwrap();
        assert_eq!(result.exit_code, 1);
        assert!(result.output.contains("Unknown tool: nonexistent"));
    }

    /// Tool named "failing" whose output always reports failure.
    struct FailingTool;

    #[async_trait]
    impl Tool for FailingTool {
        fn name(&self) -> &str {
            "failing"
        }

        fn description(&self) -> &str {
            "A tool that returns failure"
        }

        fn parameters(&self) -> serde_json::Value {
            serde_json::json!({"type": "object", "properties": {}, "required": []})
        }

        async fn execute(
            &self,
            _args: &serde_json::Value,
            _ctx: &ToolContext,
        ) -> Result<ToolOutput> {
            Ok(ToolOutput::error("something went wrong"))
        }
    }

    #[tokio::test]
    async fn test_registry_execute_failing_tool() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        registry.register(Arc::new(FailingTool));

        let result = registry
            .execute("failing", &serde_json::json!({}))
            .await
            .unwrap();
        assert_eq!(result.exit_code, 1);
        assert_eq!(result.output, "something went wrong");
    }

    #[tokio::test]
    async fn test_registry_execute_raw_success() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        registry.register(Arc::new(MockTool {
            name: "raw_test".to_string(),
        }));

        let output = registry
            .execute_raw("raw_test", &serde_json::json!({}))
            .await
            .unwrap();
        assert!(output.is_some());
        let output = output.unwrap();
        assert!(output.success);
        assert_eq!(output.content, "mock output");
    }

    #[tokio::test]
    async fn test_registry_execute_raw_unknown() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));

        let output = registry
            .execute_raw("missing", &serde_json::json!({}))
            .await
            .unwrap();
        assert!(output.is_none());
    }

    #[test]
    fn test_registry_list() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        registry.register(Arc::new(MockTool {
            name: "alpha".to_string(),
        }));
        registry.register(Arc::new(MockTool {
            name: "beta".to_string(),
        }));

        let names = registry.list();
        assert_eq!(names.len(), 2);
        assert!(names.contains(&"alpha".to_string()));
        assert!(names.contains(&"beta".to_string()));
    }

    #[test]
    fn test_registry_len_and_is_empty() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        assert!(registry.is_empty());
        assert_eq!(registry.len(), 0);

        registry.register(Arc::new(MockTool {
            name: "t".to_string(),
        }));
        assert!(!registry.is_empty());
        assert_eq!(registry.len(), 1);
    }

    #[test]
    fn test_registry_replace_tool() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        registry.register(Arc::new(MockTool {
            name: "dup".to_string(),
        }));
        registry.register(Arc::new(MockTool {
            name: "dup".to_string(),
        }));
        assert_eq!(registry.len(), 1);
    }

    #[test]
    fn test_registry_builtin_cannot_be_shadowed() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        registry.register_builtin(Arc::new(FailingTool));

        // A dynamic tool with the builtin's name must be rejected.
        registry.register(Arc::new(MockTool {
            name: "failing".to_string(),
        }));

        assert_eq!(registry.len(), 1);
        let tool = registry.get("failing").expect("builtin should remain registered");
        assert_eq!(tool.description(), "A tool that returns failure");
    }

    #[test]
    fn test_registry_builtin_replaces_existing_tool() {
        let registry = ToolRegistry::new(PathBuf::from("/tmp"));
        registry.register(Arc::new(MockTool {
            name: "failing".to_string(),
        }));

        // Registering a builtin overrides a dynamic tool of the same name.
        registry.register_builtin(Arc::new(FailingTool));

        assert_eq!(registry.len(), 1);
        let tool = registry.get("failing").expect("builtin should be registered");
        assert_eq!(tool.description(), "A tool that returns failure");
    }
}