// agent_core/controller/tools/registry.rs
use std::collections::hash_map::Entry;
use std::collections::HashMap;
use std::sync::Arc;

use tokio::sync::RwLock;

use super::types::Executable;

8pub struct ToolRegistry {
10 tools: RwLock<HashMap<String, Arc<dyn Executable>>>,
11}
12
13impl ToolRegistry {
14 pub fn new() -> Self {
16 Self {
17 tools: RwLock::new(HashMap::new()),
18 }
19 }
20
21 pub async fn register(&self, tool: Arc<dyn Executable>) -> Result<(), String> {
24 let name = tool.name().to_string();
25 let mut tools = self.tools.write().await;
26
27 if tools.contains_key(&name) {
28 return Err(format!("tool with name {:?} already exists", name));
29 }
30
31 tools.insert(name, tool);
32 Ok(())
33 }
34
35 pub async fn get(&self, name: &str) -> Option<Arc<dyn Executable>> {
38 let tools = self.tools.read().await;
39 tools.get(name).cloned()
40 }
41
42 pub async fn has(&self, name: &str) -> bool {
44 let tools = self.tools.read().await;
45 tools.contains_key(name)
46 }
47
48 pub async fn remove(&self, name: &str) {
50 let mut tools = self.tools.write().await;
51 tools.remove(name);
52 }
53
54 pub async fn list(&self) -> Vec<String> {
56 let tools = self.tools.read().await;
57 tools.keys().cloned().collect()
58 }
59
60 pub async fn get_all(&self) -> Vec<Arc<dyn Executable>> {
62 let tools = self.tools.read().await;
63 tools.values().cloned().collect()
64 }
65
66 pub async fn len(&self) -> usize {
68 let tools = self.tools.read().await;
69 tools.len()
70 }
71
72 pub async fn is_empty(&self) -> bool {
74 let tools = self.tools.read().await;
75 tools.is_empty()
76 }
77}
78
79impl Default for ToolRegistry {
80 fn default() -> Self {
81 Self::new()
82 }
83}
84
#[cfg(test)]
mod tests {
    use super::*;
    use crate::controller::tools::types::{ToolContext, ToolType};
    use std::future::Future;
    use std::pin::Pin;

    /// Minimal [`Executable`] whose only meaningful property is its name.
    struct MockTool {
        name: String,
    }

    impl MockTool {
        /// Builds a mock tool already wrapped in an `Arc`, ready for `register`.
        fn shared(name: &str) -> Arc<Self> {
            Arc::new(Self {
                name: name.to_string(),
            })
        }
    }

    impl Executable for MockTool {
        fn name(&self) -> &str {
            &self.name
        }

        fn description(&self) -> &str {
            "A mock tool for testing"
        }

        fn input_schema(&self) -> &str {
            r#"{"type":"object"}"#
        }

        fn tool_type(&self) -> ToolType {
            ToolType::Custom
        }

        fn execute(
            &self,
            _context: ToolContext,
            _input: HashMap<String, serde_json::Value>,
        ) -> Pin<Box<dyn Future<Output = Result<String, String>> + Send>> {
            Box::pin(async { Ok("mock result".to_string()) })
        }
    }

    #[tokio::test]
    async fn test_register_and_get() {
        let registry = ToolRegistry::new();
        registry.register(MockTool::shared("test_tool")).await.unwrap();

        let retrieved = registry
            .get("test_tool")
            .await
            .expect("registered tool should be retrievable");
        assert_eq!(retrieved.name(), "test_tool");
    }

    #[tokio::test]
    async fn test_get_missing_returns_none() {
        let registry = ToolRegistry::new();
        assert!(registry.get("missing").await.is_none());
        assert!(!registry.has("missing").await);
    }

    #[tokio::test]
    async fn test_duplicate_registration() {
        let registry = ToolRegistry::new();
        registry.register(MockTool::shared("test_tool")).await.unwrap();

        let err = registry
            .register(MockTool::shared("test_tool"))
            .await
            .expect_err("second registration under the same name must fail");
        assert!(
            err.contains("test_tool"),
            "error should name the duplicate tool: {}",
            err
        );
        // The rejected registration must not add a second entry.
        assert_eq!(registry.len().await, 1);
    }

    #[tokio::test]
    async fn test_list_and_remove() {
        let registry = ToolRegistry::new();
        registry.register(MockTool::shared("test_tool")).await.unwrap();
        assert!(registry.has("test_tool").await);
        assert_eq!(registry.list().await, vec!["test_tool".to_string()]);

        registry.remove("test_tool").await;
        assert!(!registry.has("test_tool").await);
        assert!(registry.list().await.is_empty());
    }

    #[tokio::test]
    async fn test_remove_missing_is_noop() {
        let registry = ToolRegistry::new();
        registry.register(MockTool::shared("keep")).await.unwrap();

        registry.remove("missing").await;
        assert_eq!(registry.len().await, 1);
        assert!(registry.has("keep").await);
    }

    #[tokio::test]
    async fn test_len_is_empty_and_get_all() {
        let registry = ToolRegistry::new();
        assert!(registry.is_empty().await);
        assert_eq!(registry.len().await, 0);

        registry.register(MockTool::shared("a")).await.unwrap();
        registry.register(MockTool::shared("b")).await.unwrap();
        assert!(!registry.is_empty().await);
        assert_eq!(registry.len().await, 2);

        // Sort locally so this assertion does not depend on the registry's ordering.
        let mut names: Vec<String> = registry
            .get_all()
            .await
            .iter()
            .map(|tool| tool.name().to_string())
            .collect();
        names.sort();
        assert_eq!(names, vec!["a".to_string(), "b".to_string()]);
    }
}