agent_core/controller/tools/registry.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3
4use thiserror::Error;
5use tokio::sync::RwLock;
6
7use super::types::Executable;
8
9/// Error type for tool registry operations.
10#[derive(Error, Debug)]
11pub enum RegistryError {
12    /// Tool with this name already exists.
13    #[error("Tool with name {0:?} already exists")]
14    DuplicateTool(String),
15}
16
17impl From<RegistryError> for String {
18    fn from(err: RegistryError) -> Self {
19        err.to_string()
20    }
21}
22
23/// Thread-safe registry for managing available tools.
24pub struct ToolRegistry {
25    tools: RwLock<HashMap<String, Arc<dyn Executable>>>,
26}
27
28impl ToolRegistry {
29    /// Create a new empty tool registry.
30    pub fn new() -> Self {
31        Self {
32            tools: RwLock::new(HashMap::new()),
33        }
34    }
35
36    /// Register a tool in the registry.
37    /// Returns an error if a tool with the same name already exists.
38    pub async fn register(&self, tool: Arc<dyn Executable>) -> Result<(), RegistryError> {
39        let name = tool.name().to_string();
40        let mut tools = self.tools.write().await;
41
42        if tools.contains_key(&name) {
43            return Err(RegistryError::DuplicateTool(name));
44        }
45
46        tools.insert(name, tool);
47        Ok(())
48    }
49
50    /// Get a tool by name.
51    /// Returns None if the tool is not found.
52    pub async fn get(&self, name: &str) -> Option<Arc<dyn Executable>> {
53        let tools = self.tools.read().await;
54        tools.get(name).cloned()
55    }
56
57    /// Check if a tool exists in the registry.
58    pub async fn has(&self, name: &str) -> bool {
59        let tools = self.tools.read().await;
60        tools.contains_key(name)
61    }
62
63    /// Remove a tool from the registry.
64    pub async fn remove(&self, name: &str) {
65        let mut tools = self.tools.write().await;
66        tools.remove(name);
67    }
68
69    /// List all registered tool names.
70    pub async fn list(&self) -> Vec<String> {
71        let tools = self.tools.read().await;
72        tools.keys().cloned().collect()
73    }
74
75    /// Get all registered tools.
76    pub async fn get_all(&self) -> Vec<Arc<dyn Executable>> {
77        let tools = self.tools.read().await;
78        tools.values().cloned().collect()
79    }
80
81    /// Get the number of registered tools.
82    pub async fn len(&self) -> usize {
83        let tools = self.tools.read().await;
84        tools.len()
85    }
86
87    /// Check if the registry is empty.
88    pub async fn is_empty(&self) -> bool {
89        let tools = self.tools.read().await;
90        tools.is_empty()
91    }
92}
93
94impl Default for ToolRegistry {
95    fn default() -> Self {
96        Self::new()
97    }
98}
99
#[cfg(test)]
mod tests {
    use super::*;
    use crate::controller::tools::types::{ToolContext, ToolType};
    use std::future::Future;
    use std::pin::Pin;

    /// Minimal [`Executable`] used to exercise the registry; only its name varies.
    struct MockTool {
        name: String,
    }

    impl MockTool {
        /// Build a mock tool already wrapped in an `Arc`, ready to pass to `register`.
        fn named(name: &str) -> Arc<Self> {
            Arc::new(Self {
                name: name.to_string(),
            })
        }
    }

    impl Executable for MockTool {
        fn name(&self) -> &str {
            &self.name
        }

        fn description(&self) -> &str {
            "A mock tool for testing"
        }

        fn input_schema(&self) -> &str {
            r#"{"type":"object"}"#
        }

        fn tool_type(&self) -> ToolType {
            ToolType::Custom
        }

        fn execute(
            &self,
            _context: ToolContext,
            _input: HashMap<String, serde_json::Value>,
        ) -> Pin<Box<dyn Future<Output = Result<String, String>> + Send>> {
            Box::pin(async { Ok("mock result".to_string()) })
        }
    }

    #[tokio::test]
    async fn test_register_and_get() {
        let registry = ToolRegistry::new();
        registry.register(MockTool::named("test_tool")).await.unwrap();

        let retrieved = registry
            .get("test_tool")
            .await
            .expect("tool was just registered");
        assert_eq!(retrieved.name(), "test_tool");
    }

    #[tokio::test]
    async fn test_get_missing_returns_none() {
        let registry = ToolRegistry::new();
        assert!(registry.get("missing").await.is_none());
        assert!(!registry.has("missing").await);
    }

    #[tokio::test]
    async fn test_duplicate_registration() {
        let registry = ToolRegistry::new();
        let original: Arc<dyn Executable> = MockTool::named("test_tool");
        registry.register(Arc::clone(&original)).await.unwrap();

        let err = registry
            .register(MockTool::named("test_tool"))
            .await
            .unwrap_err();
        assert!(matches!(&err, RegistryError::DuplicateTool(name) if name == "test_tool"));
        assert_eq!(
            String::from(err),
            "Tool with name \"test_tool\" already exists"
        );

        // The rejected duplicate must not replace the first tool: the registry still
        // holds the second strong reference to `original`.
        assert_eq!(registry.len().await, 1);
        assert_eq!(Arc::strong_count(&original), 2);
    }

    #[tokio::test]
    async fn test_list_and_remove() {
        let registry = ToolRegistry::new();
        registry.register(MockTool::named("test_tool")).await.unwrap();
        assert!(registry.has("test_tool").await);
        assert_eq!(registry.list().await, ["test_tool"]);

        registry.remove("test_tool").await;
        assert!(!registry.has("test_tool").await);
        assert!(registry.list().await.is_empty());

        // Once removed, the name is free to be registered again.
        registry.register(MockTool::named("test_tool")).await.unwrap();
        assert!(registry.has("test_tool").await);
    }

    #[tokio::test]
    async fn test_remove_missing_is_noop() {
        let registry = ToolRegistry::new();
        registry.register(MockTool::named("keep")).await.unwrap();

        registry.remove("missing").await;
        assert!(registry.has("keep").await);
        assert_eq!(registry.len().await, 1);
    }

    #[tokio::test]
    async fn test_len_is_empty_and_get_all() {
        let registry = ToolRegistry::default();
        assert!(registry.is_empty().await);
        assert_eq!(registry.len().await, 0);
        assert!(registry.get_all().await.is_empty());

        registry.register(MockTool::named("alpha")).await.unwrap();
        registry.register(MockTool::named("beta")).await.unwrap();
        assert!(!registry.is_empty().await);
        assert_eq!(registry.len().await, 2);

        // Compare after sorting so the test does not depend on the registry's ordering.
        let mut names: Vec<String> = registry
            .get_all()
            .await
            .iter()
            .map(|tool| tool.name().to_string())
            .collect();
        names.sort_unstable();
        assert_eq!(names, ["alpha", "beta"]);
    }
}