1use std::collections::HashMap;
9use std::sync::Arc;
10use thulp_core::{Error, Result, ToolDefinition};
11use tokio::sync::RwLock;
12
13pub struct ToolRegistry {
21 tools: Arc<RwLock<HashMap<String, ToolDefinition>>>,
23
24 tags: Arc<RwLock<HashMap<String, Vec<String>>>>,
26}
27
28impl ToolRegistry {
29 pub fn new() -> Self {
31 Self {
32 tools: Arc::new(RwLock::new(HashMap::new())),
33 tags: Arc::new(RwLock::new(HashMap::new())),
34 }
35 }
36
37 pub async fn register(&self, tool: ToolDefinition) -> Result<()> {
39 let mut tools = self.tools.write().await;
40 tools.insert(tool.name.clone(), tool);
41 Ok(())
42 }
43
44 pub async fn register_many(&self, tools: Vec<ToolDefinition>) -> Result<()> {
46 let mut registry = self.tools.write().await;
47 for tool in tools {
48 registry.insert(tool.name.clone(), tool);
49 }
50 Ok(())
51 }
52
53 pub async fn unregister(&self, name: &str) -> Result<Option<ToolDefinition>> {
55 let mut tools = self.tools.write().await;
56 Ok(tools.remove(name))
57 }
58
59 pub async fn get(&self, name: &str) -> Result<Option<ToolDefinition>> {
61 let tools = self.tools.read().await;
62 Ok(tools.get(name).cloned())
63 }
64
65 pub async fn list(&self) -> Result<Vec<ToolDefinition>> {
67 let tools = self.tools.read().await;
68 Ok(tools.values().cloned().collect())
69 }
70
71 pub async fn count(&self) -> usize {
73 let tools = self.tools.read().await;
74 tools.len()
75 }
76
77 pub async fn clear(&self) {
79 let mut tools = self.tools.write().await;
80 let mut tags = self.tags.write().await;
81 tools.clear();
82 tags.clear();
83 }
84
85 pub async fn contains(&self, name: &str) -> bool {
87 let tools = self.tools.read().await;
88 tools.contains_key(name)
89 }
90
91 pub async fn tag(&self, tool_name: &str, tag: &str) -> Result<()> {
93 let tools = self.tools.read().await;
94 if !tools.contains_key(tool_name) {
95 return Err(Error::InvalidConfig(format!(
96 "Tool '{}' not found in registry",
97 tool_name
98 )));
99 }
100 drop(tools);
101
102 let mut tags = self.tags.write().await;
103 tags.entry(tag.to_string())
104 .or_insert_with(Vec::new)
105 .push(tool_name.to_string());
106 Ok(())
107 }
108
109 pub async fn find_by_tag(&self, tag: &str) -> Result<Vec<ToolDefinition>> {
111 let tags = self.tags.read().await;
112 let tool_names = match tags.get(tag) {
113 Some(names) => names.clone(),
114 None => return Ok(Vec::new()),
115 };
116 drop(tags);
117
118 let tools = self.tools.read().await;
119 let mut results = Vec::new();
120 for name in tool_names {
121 if let Some(tool) = tools.get(&name) {
122 results.push(tool.clone());
123 }
124 }
125 Ok(results)
126 }
127}
128
129impl Default for ToolRegistry {
130 fn default() -> Self {
131 Self::new()
132 }
133}
134
#[cfg(test)]
mod tests {
    use super::*;
    use thulp_core::Parameter;

    /// Builds a minimal tool called `name` with a single required string parameter.
    fn create_test_tool(name: &str) -> ToolDefinition {
        let description = format!("Test tool: {}", name);
        ToolDefinition::builder(name)
            .description(description)
            .parameter(Parameter::required_string("test_param"))
            .build()
    }

    /// Creates a registry and registers one test tool per entry in `names`,
    /// one `register` call at a time.
    async fn registry_with(names: &[&str]) -> ToolRegistry {
        let registry = ToolRegistry::new();
        for &name in names {
            registry.register(create_test_tool(name)).await.unwrap();
        }
        registry
    }

    #[tokio::test]
    async fn registry_creation() {
        let count = ToolRegistry::new().count().await;
        assert_eq!(count, 0);
    }

    #[tokio::test]
    async fn register_and_get_tool() {
        let registry = ToolRegistry::new();
        registry
            .register(create_test_tool("test_tool"))
            .await
            .unwrap();

        let found = registry.get("test_tool").await.unwrap();
        let tool = found.expect("'test_tool' should be registered");
        assert_eq!(tool.name, "test_tool");
    }

    #[tokio::test]
    async fn register_many_tools() {
        let registry = ToolRegistry::new();
        let names = ["tool1", "tool2", "tool3"];
        let batch: Vec<ToolDefinition> = names.iter().map(|n| create_test_tool(n)).collect();

        registry.register_many(batch).await.unwrap();

        assert_eq!(registry.count().await, names.len());
        for name in &names {
            assert!(registry.contains(name).await, "missing tool {}", name);
        }
    }

    #[tokio::test]
    async fn unregister_tool() {
        let registry = registry_with(&["test_tool"]).await;
        assert_eq!(registry.count().await, 1);

        let removed = registry.unregister("test_tool").await.unwrap();
        assert!(removed.is_some());
        assert_eq!(registry.count().await, 0);
    }

    #[tokio::test]
    async fn list_tools() {
        let registry = ToolRegistry::new();
        registry
            .register_many(vec![create_test_tool("tool1"), create_test_tool("tool2")])
            .await
            .unwrap();

        let listed = registry.list().await.unwrap();
        assert_eq!(listed.len(), 2);
    }

    #[tokio::test]
    async fn clear_registry() {
        let registry = ToolRegistry::new();
        let batch = vec![create_test_tool("tool1"), create_test_tool("tool2")];
        registry.register_many(batch).await.unwrap();
        assert_eq!(registry.count().await, 2);

        registry.clear().await;
        assert_eq!(registry.count().await, 0);
    }

    #[tokio::test]
    async fn tag_and_find_tools() {
        let registry = registry_with(&["tool1", "tool2", "tool3"]).await;

        let assignments = [
            ("tool1", "filesystem"),
            ("tool2", "filesystem"),
            ("tool3", "network"),
        ];
        for (tool, tag) in assignments.iter() {
            registry.tag(tool, tag).await.unwrap();
        }

        let filesystem = registry.find_by_tag("filesystem").await.unwrap();
        assert_eq!(filesystem.len(), 2);

        let network = registry.find_by_tag("network").await.unwrap();
        assert_eq!(network.len(), 1);
    }

    #[tokio::test]
    async fn tag_nonexistent_tool() {
        let outcome = ToolRegistry::new().tag("nonexistent", "test").await;
        assert!(outcome.is_err());
    }
}