agent_core/controller/tools/
registry.rs1use std::collections::HashMap;
2use std::sync::Arc;
3
4use thiserror::Error;
5use tokio::sync::RwLock;
6
7use super::types::Executable;
8
9#[derive(Error, Debug)]
11pub enum RegistryError {
12 #[error("Tool with name {0:?} already exists")]
14 DuplicateTool(String),
15}
16
17impl From<RegistryError> for String {
18 fn from(err: RegistryError) -> Self {
19 err.to_string()
20 }
21}
22
23pub struct ToolRegistry {
25 tools: RwLock<HashMap<String, Arc<dyn Executable>>>,
26}
27
28impl ToolRegistry {
29 pub fn new() -> Self {
31 Self {
32 tools: RwLock::new(HashMap::new()),
33 }
34 }
35
36 pub async fn register(&self, tool: Arc<dyn Executable>) -> Result<(), RegistryError> {
39 let name = tool.name().to_string();
40 let mut tools = self.tools.write().await;
41
42 if tools.contains_key(&name) {
43 return Err(RegistryError::DuplicateTool(name));
44 }
45
46 tools.insert(name, tool);
47 Ok(())
48 }
49
50 pub async fn get(&self, name: &str) -> Option<Arc<dyn Executable>> {
53 let tools = self.tools.read().await;
54 tools.get(name).cloned()
55 }
56
57 pub async fn has(&self, name: &str) -> bool {
59 let tools = self.tools.read().await;
60 tools.contains_key(name)
61 }
62
63 pub async fn remove(&self, name: &str) {
65 let mut tools = self.tools.write().await;
66 tools.remove(name);
67 }
68
69 pub async fn list(&self) -> Vec<String> {
71 let tools = self.tools.read().await;
72 tools.keys().cloned().collect()
73 }
74
75 pub async fn get_all(&self) -> Vec<Arc<dyn Executable>> {
77 let tools = self.tools.read().await;
78 tools.values().cloned().collect()
79 }
80
81 pub async fn len(&self) -> usize {
83 let tools = self.tools.read().await;
84 tools.len()
85 }
86
87 pub async fn is_empty(&self) -> bool {
89 let tools = self.tools.read().await;
90 tools.is_empty()
91 }
92}
93
94impl Default for ToolRegistry {
95 fn default() -> Self {
96 Self::new()
97 }
98}
99
#[cfg(test)]
mod tests {
    use super::*;
    use crate::controller::tools::types::{ToolContext, ToolType};
    use std::future::Future;
    use std::pin::Pin;

    /// Minimal `Executable` used to populate the registry; only its name varies.
    struct MockTool {
        name: String,
    }

    impl MockTool {
        /// Builds a shared mock tool called `name`, ready to pass to `register`.
        fn shared(name: &str) -> Arc<Self> {
            Arc::new(Self {
                name: name.to_string(),
            })
        }
    }

    impl Executable for MockTool {
        fn name(&self) -> &str {
            &self.name
        }

        fn description(&self) -> &str {
            "A mock tool for testing"
        }

        fn input_schema(&self) -> &str {
            r#"{"type":"object"}"#
        }

        fn tool_type(&self) -> ToolType {
            ToolType::Custom
        }

        fn execute(
            &self,
            _context: ToolContext,
            _input: HashMap<String, serde_json::Value>,
        ) -> Pin<Box<dyn Future<Output = Result<String, String>> + Send>> {
            Box::pin(async { Ok("mock result".to_string()) })
        }
    }

    #[tokio::test]
    async fn test_register_and_get() {
        let registry = ToolRegistry::new();
        registry.register(MockTool::shared("test_tool")).await.unwrap();

        let retrieved = registry
            .get("test_tool")
            .await
            .expect("registered tool should be retrievable");
        assert_eq!(retrieved.name(), "test_tool");
    }

    #[tokio::test]
    async fn test_get_unknown_tool_returns_none() {
        let registry = ToolRegistry::new();
        assert!(registry.get("missing").await.is_none());
        assert!(!registry.has("missing").await);
    }

    #[tokio::test]
    async fn test_duplicate_registration() {
        let registry = ToolRegistry::new();
        registry.register(MockTool::shared("test_tool")).await.unwrap();

        let err = registry
            .register(MockTool::shared("test_tool"))
            .await
            .expect_err("second registration under the same name must fail");
        assert!(matches!(&err, RegistryError::DuplicateTool(name) if name == "test_tool"));
        assert_eq!(
            String::from(err),
            "Tool with name \"test_tool\" already exists"
        );
        // The rejected registration must not add a second entry.
        assert_eq!(registry.len().await, 1);
    }

    #[tokio::test]
    async fn test_list_and_remove() {
        let registry = ToolRegistry::new();
        registry.register(MockTool::shared("test_tool")).await.unwrap();
        assert!(registry.has("test_tool").await);
        assert_eq!(registry.list().await, vec!["test_tool".to_string()]);

        registry.remove("test_tool").await;
        assert!(!registry.has("test_tool").await);
        assert!(registry.list().await.is_empty());

        // Removing a name that is no longer registered is a silent no-op.
        registry.remove("test_tool").await;
        assert!(registry.is_empty().await);
    }

    #[tokio::test]
    async fn test_len_is_empty_and_get_all() {
        let registry = ToolRegistry::default();
        assert!(registry.is_empty().await);
        assert_eq!(registry.len().await, 0);

        registry.register(MockTool::shared("beta")).await.unwrap();
        registry.register(MockTool::shared("alpha")).await.unwrap();
        assert!(!registry.is_empty().await);
        assert_eq!(registry.len().await, 2);

        // Sort before comparing so these checks hold regardless of iteration order.
        let mut names = registry.list().await;
        names.sort();
        assert_eq!(names, ["alpha", "beta"]);

        let mut tool_names: Vec<String> = registry
            .get_all()
            .await
            .iter()
            .map(|tool| tool.name().to_string())
            .collect();
        tool_names.sort();
        assert_eq!(tool_names, ["alpha", "beta"]);
    }
}