1use crate::error::AgentError;
2use crate::tool::Tool;
3use serde_json::Value;
4use std::collections::HashMap;
5use std::sync::Arc;
6
7pub struct ToolRegistry {
9 tools: HashMap<String, Arc<dyn Tool>>,
10}
11
12impl ToolRegistry {
13 pub fn new() -> Self {
15 ToolRegistry {
16 tools: HashMap::new(),
17 }
18 }
19
20 pub fn register<T>(&mut self, tool: T) -> Result<(), AgentError>
22 where
23 T: Tool + 'static,
24 {
25 let name = tool.name().to_string();
26 self.tools.insert(name, Arc::new(tool));
27 Ok(())
28 }
29
30 pub fn get(&self, name: &str) -> Option<Arc<dyn Tool>> {
32 self.tools.get(name).cloned()
33 }
34
35 pub fn list(&self) -> Vec<String> {
37 let mut names: Vec<String> = self.tools.keys().cloned().collect();
38 names.sort();
39 names
40 }
41
42 pub async fn execute(&self, name: &str, args: Value) -> Result<Value, AgentError> {
44 let tool = self
45 .get(name)
46 .ok_or_else(|| AgentError::ToolError(format!("Tool '{}' not found", name)))?;
47 tool.execute(args).await
48 }
49}
50
51impl Default for ToolRegistry {
52 fn default() -> Self {
53 Self::new()
54 }
55}
56
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tool::EchoTool;
    use async_trait::async_trait;

    /// Second tool with a fixed response, shared by every test that needs
    /// more than one registered tool.
    struct AnotherTool;

    #[async_trait]
    impl Tool for AnotherTool {
        fn name(&self) -> &str {
            "another"
        }

        async fn execute(&self, _args: Value) -> Result<Value, AgentError> {
            Ok(serde_json::json!({"status": "ok"}))
        }
    }

    // Tests that never await run on the plain test harness; only the
    // `execute` tests need a tokio runtime.

    #[test]
    fn test_registry_new() {
        let registry = ToolRegistry::new();
        assert!(registry.list().is_empty());
    }

    #[test]
    fn test_registry_default() {
        let registry = ToolRegistry::default();
        assert!(registry.list().is_empty());
    }

    #[test]
    fn test_registry_register() {
        let mut registry = ToolRegistry::new();
        registry.register(EchoTool::new()).unwrap();

        assert_eq!(registry.list(), vec!["echo"]);
    }

    #[test]
    fn test_registry_register_same_name_replaces() {
        // Registering a second tool under an existing name must replace the
        // entry, not add a duplicate.
        let mut registry = ToolRegistry::new();
        registry.register(EchoTool::new()).unwrap();
        registry.register(EchoTool::new()).unwrap();

        assert_eq!(registry.list(), vec!["echo"]);
    }

    #[test]
    fn test_registry_get() {
        let mut registry = ToolRegistry::new();
        registry.register(EchoTool::new()).unwrap();

        let tool = registry
            .get("echo")
            .expect("echo tool should be registered");
        assert_eq!(tool.name(), "echo");
    }

    #[test]
    fn test_registry_get_nonexistent() {
        let registry = ToolRegistry::new();
        assert!(registry.get("nonexistent").is_none());
    }

    #[test]
    fn test_registry_list() {
        // `list` promises sorted output; with two tools this pins the order
        // rather than only the count.
        let mut registry = ToolRegistry::new();
        registry.register(EchoTool::new()).unwrap();
        registry.register(AnotherTool).unwrap();

        assert_eq!(registry.list(), vec!["another", "echo"]);
    }

    #[tokio::test]
    async fn test_registry_execute() {
        let mut registry = ToolRegistry::new();
        registry.register(EchoTool::new()).unwrap();

        let args = serde_json::json!({"test": "value"});
        let result = registry.execute("echo", args.clone()).await.unwrap();
        assert_eq!(result, args);
    }

    #[tokio::test]
    async fn test_registry_execute_nonexistent() {
        let registry = ToolRegistry::new();

        let args = serde_json::json!({"test": "value"});
        match registry.execute("nonexistent", args).await {
            Err(AgentError::ToolError(msg)) => {
                assert_eq!(msg, "Tool 'nonexistent' not found");
            }
            other => panic!("expected ToolError for a missing tool, got {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_registry_multiple_tools() {
        let mut registry = ToolRegistry::new();
        registry.register(EchoTool::new()).unwrap();
        registry.register(AnotherTool).unwrap();

        let names = registry.list();
        assert_eq!(names.len(), 2);
        assert!(names.contains(&"echo".to_string()));
        assert!(names.contains(&"another".to_string()));

        // Each name must dispatch to its own tool.
        let args = serde_json::json!({"test": "value"});
        let echoed = registry.execute("echo", args.clone()).await.unwrap();
        assert_eq!(echoed, args);
        let fixed = registry.execute("another", args).await.unwrap();
        assert_eq!(fixed, serde_json::json!({"status": "ok"}));
    }
}