Skip to main content

agent_core_runtime/controller/tools/
registry.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3
4use thiserror::Error;
5use tokio::sync::RwLock;
6
7use super::types::Executable;
8
9/// Error type for tool registry operations.
10#[derive(Error, Debug)]
11pub enum RegistryError {
12    /// Tool with this name already exists.
13    #[error("Tool with name {0:?} already exists")]
14    DuplicateTool(String),
15}
16
17impl From<RegistryError> for String {
18    fn from(err: RegistryError) -> Self {
19        err.to_string()
20    }
21}
22
23/// Thread-safe registry for managing available tools.
24pub struct ToolRegistry {
25    tools: RwLock<HashMap<String, Arc<dyn Executable>>>,
26}
27
28impl ToolRegistry {
29    /// Create a new empty tool registry.
30    pub fn new() -> Self {
31        Self {
32            tools: RwLock::new(HashMap::new()),
33        }
34    }
35
36    /// Register a tool in the registry.
37    /// Returns an error if a tool with the same name already exists.
38    pub async fn register(&self, tool: Arc<dyn Executable>) -> Result<(), RegistryError> {
39        let name = tool.name().to_string();
40        let mut tools = self.tools.write().await;
41
42        if tools.contains_key(&name) {
43            return Err(RegistryError::DuplicateTool(name));
44        }
45
46        tools.insert(name, tool);
47        Ok(())
48    }
49
50    /// Get a tool by name.
51    /// Returns None if the tool is not found.
52    pub async fn get(&self, name: &str) -> Option<Arc<dyn Executable>> {
53        let tools = self.tools.read().await;
54        tools.get(name).cloned()
55    }
56
57    /// Check if a tool exists in the registry.
58    pub async fn has(&self, name: &str) -> bool {
59        let tools = self.tools.read().await;
60        tools.contains_key(name)
61    }
62
63    /// Remove a tool from the registry.
64    pub async fn remove(&self, name: &str) {
65        let mut tools = self.tools.write().await;
66        tools.remove(name);
67    }
68
69    /// List all registered tool names.
70    pub async fn list(&self) -> Vec<String> {
71        let tools = self.tools.read().await;
72        tools.keys().cloned().collect()
73    }
74
75    /// Get all registered tools.
76    pub async fn get_all(&self) -> Vec<Arc<dyn Executable>> {
77        let tools = self.tools.read().await;
78        tools.values().cloned().collect()
79    }
80
81    /// Get the number of registered tools.
82    pub async fn len(&self) -> usize {
83        let tools = self.tools.read().await;
84        tools.len()
85    }
86
87    /// Check if the registry is empty.
88    pub async fn is_empty(&self) -> bool {
89        let tools = self.tools.read().await;
90        tools.is_empty()
91    }
92
93    /// Cleans up session-specific state in all registered tools.
94    ///
95    /// This should be called when a session is removed to prevent
96    /// unbounded memory growth from abandoned session state in tools.
97    pub async fn cleanup_session(&self, session_id: i64) {
98        let tools = self.tools.read().await;
99        for tool in tools.values() {
100            tool.cleanup_session(session_id).await;
101        }
102    }
103}
104
105impl Default for ToolRegistry {
106    fn default() -> Self {
107        Self::new()
108    }
109}
110
#[cfg(test)]
mod tests {
    use super::*;
    use crate::controller::tools::types::{ToolContext, ToolType};
    use std::future::Future;
    use std::pin::Pin;

    /// Minimal `Executable` whose only meaningful property is its name.
    struct MockTool {
        name: String,
    }

    impl MockTool {
        /// Builds a shareable mock tool with the given name, ready to register.
        fn shared(name: &str) -> Arc<Self> {
            Arc::new(Self {
                name: name.to_string(),
            })
        }
    }

    impl Executable for MockTool {
        fn name(&self) -> &str {
            &self.name
        }

        fn description(&self) -> &str {
            "A mock tool for testing"
        }

        fn input_schema(&self) -> &str {
            r#"{"type":"object"}"#
        }

        fn tool_type(&self) -> ToolType {
            ToolType::Custom
        }

        fn execute(
            &self,
            _context: ToolContext,
            _input: HashMap<String, serde_json::Value>,
        ) -> Pin<Box<dyn Future<Output = Result<String, String>> + Send>> {
            Box::pin(async { Ok("mock result".to_string()) })
        }
    }

    #[tokio::test]
    async fn test_register_and_get() {
        let registry = ToolRegistry::new();
        registry.register(MockTool::shared("test_tool")).await.unwrap();

        let retrieved = registry
            .get("test_tool")
            .await
            .expect("a registered tool should be retrievable by name");
        assert_eq!(retrieved.name(), "test_tool");
    }

    #[tokio::test]
    async fn test_get_unknown_returns_none() {
        let registry = ToolRegistry::new();
        registry.register(MockTool::shared("test_tool")).await.unwrap();

        assert!(registry.get("missing_tool").await.is_none());
        assert!(!registry.has("missing_tool").await);
    }

    #[tokio::test]
    async fn test_duplicate_registration() {
        let registry = ToolRegistry::new();
        registry.register(MockTool::shared("test_tool")).await.unwrap();

        let err = registry
            .register(MockTool::shared("test_tool"))
            .await
            .expect_err("registering the same name twice should fail");
        assert!(matches!(&err, RegistryError::DuplicateTool(name) if name == "test_tool"));
        assert_eq!(
            err.to_string(),
            r#"Tool with name "test_tool" already exists"#
        );

        // The rejected registration must neither replace nor add to the existing entry.
        assert_eq!(registry.len().await, 1);
        assert!(registry.has("test_tool").await);
    }

    #[tokio::test]
    async fn test_list_and_remove() {
        let registry = ToolRegistry::new();
        registry.register(MockTool::shared("test_tool")).await.unwrap();
        assert!(registry.has("test_tool").await);
        assert_eq!(registry.list().await, vec!["test_tool".to_string()]);

        registry.remove("test_tool").await;
        assert!(!registry.has("test_tool").await);
        assert!(registry.list().await.is_empty());

        // Removing a name that is no longer registered is a harmless no-op.
        registry.remove("test_tool").await;
        assert!(registry.is_empty().await);
    }

    #[tokio::test]
    async fn test_len_is_empty_and_get_all() {
        let registry = ToolRegistry::default();
        assert!(registry.is_empty().await);
        assert_eq!(registry.len().await, 0);

        registry.register(MockTool::shared("alpha")).await.unwrap();
        registry.register(MockTool::shared("beta")).await.unwrap();
        assert!(!registry.is_empty().await);
        assert_eq!(registry.len().await, 2);

        // Iteration order is unspecified, so compare the sorted names.
        let mut names: Vec<String> = registry
            .get_all()
            .await
            .iter()
            .map(|tool| tool.name().to_string())
            .collect();
        names.sort_unstable();
        assert_eq!(names, ["alpha", "beta"]);
    }
}