// agent_core/controller/tools/registry.rs

use std::collections::hash_map::Entry;
use std::collections::HashMap;
use std::sync::Arc;

use tokio::sync::RwLock;

use super::types::Executable;

8/// Thread-safe registry for managing available tools.
9pub struct ToolRegistry {
10    tools: RwLock<HashMap<String, Arc<dyn Executable>>>,
11}
12
13impl ToolRegistry {
14    /// Create a new empty tool registry.
15    pub fn new() -> Self {
16        Self {
17            tools: RwLock::new(HashMap::new()),
18        }
19    }
20
21    /// Register a tool in the registry.
22    /// Returns an error if a tool with the same name already exists.
23    pub async fn register(&self, tool: Arc<dyn Executable>) -> Result<(), String> {
24        let name = tool.name().to_string();
25        let mut tools = self.tools.write().await;
26
27        if tools.contains_key(&name) {
28            return Err(format!("tool with name {:?} already exists", name));
29        }
30
31        tools.insert(name, tool);
32        Ok(())
33    }
34
35    /// Get a tool by name.
36    /// Returns None if the tool is not found.
37    pub async fn get(&self, name: &str) -> Option<Arc<dyn Executable>> {
38        let tools = self.tools.read().await;
39        tools.get(name).cloned()
40    }
41
42    /// Check if a tool exists in the registry.
43    pub async fn has(&self, name: &str) -> bool {
44        let tools = self.tools.read().await;
45        tools.contains_key(name)
46    }
47
48    /// Remove a tool from the registry.
49    pub async fn remove(&self, name: &str) {
50        let mut tools = self.tools.write().await;
51        tools.remove(name);
52    }
53
54    /// List all registered tool names.
55    pub async fn list(&self) -> Vec<String> {
56        let tools = self.tools.read().await;
57        tools.keys().cloned().collect()
58    }
59
60    /// Get all registered tools.
61    pub async fn get_all(&self) -> Vec<Arc<dyn Executable>> {
62        let tools = self.tools.read().await;
63        tools.values().cloned().collect()
64    }
65
66    /// Get the number of registered tools.
67    pub async fn len(&self) -> usize {
68        let tools = self.tools.read().await;
69        tools.len()
70    }
71
72    /// Check if the registry is empty.
73    pub async fn is_empty(&self) -> bool {
74        let tools = self.tools.read().await;
75        tools.is_empty()
76    }
77}
78
79impl Default for ToolRegistry {
80    fn default() -> Self {
81        Self::new()
82    }
83}
84
#[cfg(test)]
mod tests {
    use super::*;
    use crate::controller::tools::types::{ToolContext, ToolType};
    use std::future::Future;
    use std::pin::Pin;

    /// Minimal `Executable` whose instances differ only by name.
    struct MockTool {
        name: String,
    }

    impl Executable for MockTool {
        fn name(&self) -> &str {
            &self.name
        }

        fn description(&self) -> &str {
            "A mock tool for testing"
        }

        fn input_schema(&self) -> &str {
            r#"{"type":"object"}"#
        }

        fn tool_type(&self) -> ToolType {
            ToolType::Custom
        }

        fn execute(
            &self,
            _context: ToolContext,
            _input: HashMap<String, serde_json::Value>,
        ) -> Pin<Box<dyn Future<Output = Result<String, String>> + Send>> {
            Box::pin(async { Ok("mock result".to_string()) })
        }
    }

    /// Build a shared mock tool with the given name.
    fn mock(name: &str) -> Arc<MockTool> {
        Arc::new(MockTool {
            name: name.to_string(),
        })
    }

    #[tokio::test]
    async fn test_register_and_get() {
        let registry = ToolRegistry::new();
        registry.register(mock("test_tool")).await.unwrap();

        let retrieved = registry
            .get("test_tool")
            .await
            .expect("a registered tool should be retrievable by name");
        assert_eq!(retrieved.name(), "test_tool");
    }

    #[tokio::test]
    async fn test_get_missing_returns_none() {
        let registry = ToolRegistry::new();
        assert!(registry.get("missing").await.is_none());
        assert!(!registry.has("missing").await);
    }

    #[tokio::test]
    async fn test_duplicate_registration() {
        let registry = ToolRegistry::new();
        registry.register(mock("test_tool")).await.unwrap();

        let result = registry.register(mock("test_tool")).await;
        assert!(result.is_err());
        // A rejected duplicate must not add a second entry or drop the first.
        assert_eq!(registry.len().await, 1);
        assert!(registry.has("test_tool").await);
    }

    #[tokio::test]
    async fn test_list_and_remove() {
        let registry = ToolRegistry::new();
        registry.register(mock("test_tool")).await.unwrap();
        assert!(registry.has("test_tool").await);
        assert_eq!(registry.list().await, ["test_tool"]);

        registry.remove("test_tool").await;
        assert!(!registry.has("test_tool").await);
        assert!(registry.list().await.is_empty());
    }

    #[tokio::test]
    async fn test_remove_missing_is_noop() {
        let registry = ToolRegistry::new();
        registry.register(mock("keep")).await.unwrap();

        registry.remove("absent").await;
        assert!(registry.has("keep").await);
        assert_eq!(registry.len().await, 1);
    }

    #[tokio::test]
    async fn test_len_is_empty_and_get_all() {
        let registry = ToolRegistry::new();
        assert!(registry.is_empty().await);
        assert_eq!(registry.len().await, 0);

        registry.register(mock("a")).await.unwrap();
        registry.register(mock("b")).await.unwrap();
        assert!(!registry.is_empty().await);
        assert_eq!(registry.len().await, 2);

        // Sort locally so this test does not depend on the registry's ordering.
        let mut names: Vec<String> = registry
            .get_all()
            .await
            .iter()
            .map(|tool| tool.name().to_string())
            .collect();
        names.sort();
        assert_eq!(names, ["a", "b"]);
    }
}