Skip to main content

enact_core/callable/
registry.rs

//! Callable Registry - central registry for all callable types.
//!
//! The registry allows registering and looking up any callable type (graphs, agents, workflows,
//! LLM functions, etc.) by name. This enables the unified agent/stream endpoint to execute any
//! callable type through a single interface.
5
6use super::DynCallable;
7use std::collections::HashMap;
8use std::sync::{Arc, RwLock};
9
10/// Thread-safe registry for callables
11///
12/// Allows registering any callable type (LlmCallable, GraphCallable, FnCallable, etc.)
13/// and looking them up by name. Used by the unified agent/stream endpoint.
14pub struct CallableRegistry {
15    callables: Arc<RwLock<HashMap<String, DynCallable>>>,
16}
17
18impl CallableRegistry {
19    /// Create a new empty registry
20    pub fn new() -> Self {
21        Self {
22            callables: Arc::new(RwLock::new(HashMap::new())),
23        }
24    }
25
26    /// Register a callable with a given name
27    ///
28    /// # Arguments
29    /// * `name` - The name to register the callable under
30    /// * `callable` - The callable to register
31    ///
32    /// # Panics
33    /// This will overwrite any existing callable with the same name.
34    pub fn register(&self, name: String, callable: DynCallable) {
35        let mut callables = self.callables.write().unwrap();
36        callables.insert(name, callable);
37    }
38
39    /// Get a callable by name
40    ///
41    /// # Arguments
42    /// * `name` - The name of the callable to retrieve
43    ///
44    /// # Returns
45    /// * `Some(DynCallable)` if found
46    /// * `None` if not found
47    pub fn get(&self, name: &str) -> Option<DynCallable> {
48        let callables = self.callables.read().unwrap();
49        callables.get(name).cloned()
50    }
51
52    /// List all registered callable names
53    ///
54    /// # Returns
55    /// A vector of all registered callable names, sorted alphabetically
56    pub fn list(&self) -> Vec<String> {
57        let callables = self.callables.read().unwrap();
58        let mut names: Vec<String> = callables.keys().cloned().collect();
59        names.sort();
60        names
61    }
62
63    /// Check if a callable is registered
64    ///
65    /// # Arguments
66    /// * `name` - The name to check
67    ///
68    /// # Returns
69    /// `true` if registered, `false` otherwise
70    pub fn contains(&self, name: &str) -> bool {
71        let callables = self.callables.read().unwrap();
72        callables.contains_key(name)
73    }
74
75    /// Remove a callable from the registry
76    ///
77    /// # Arguments
78    /// * `name` - The name of the callable to remove
79    ///
80    /// # Returns
81    /// `true` if the callable was removed, `false` if it wasn't found
82    pub fn remove(&self, name: &str) -> bool {
83        let mut callables = self.callables.write().unwrap();
84        callables.remove(name).is_some()
85    }
86
87    /// Get the number of registered callables
88    pub fn len(&self) -> usize {
89        let callables = self.callables.read().unwrap();
90        callables.len()
91    }
92
93    /// Check if the registry is empty
94    pub fn is_empty(&self) -> bool {
95        let callables = self.callables.read().unwrap();
96        callables.is_empty()
97    }
98}
99
100impl Default for CallableRegistry {
101    fn default() -> Self {
102        Self::new()
103    }
104}
105
106impl Clone for CallableRegistry {
107    fn clone(&self) -> Self {
108        Self {
109            callables: Arc::clone(&self.callables),
110        }
111    }
112}
113
#[cfg(test)]
mod tests {
    use super::*;
    use crate::callable::Callable;
    use async_trait::async_trait;
    use std::sync::Arc;

    /// Minimal `Callable` that reports a fixed name and always returns a fixed output.
    struct MockCallable {
        name: String,
        output: String,
    }

    #[async_trait]
    impl Callable for MockCallable {
        fn name(&self) -> &str {
            &self.name
        }

        async fn run(&self, _input: &str) -> anyhow::Result<String> {
            Ok(self.output.clone())
        }
    }

    /// Build a shared mock callable with the given name and output.
    fn mock(name: &str, output: &str) -> Arc<MockCallable> {
        Arc::new(MockCallable {
            name: name.to_string(),
            output: output.to_string(),
        })
    }

    #[test]
    fn test_register_and_get() {
        let registry = CallableRegistry::new();
        registry.register("test".to_string(), mock("test", "test output"));

        let retrieved = registry.get("test").unwrap();
        assert_eq!(retrieved.name(), "test");
    }

    // The only test that needs an async runtime: it actually awaits the retrieved callable.
    #[tokio::test]
    async fn test_retrieved_callable_runs() {
        let registry = CallableRegistry::new();
        registry.register("echo".to_string(), mock("echo", "echo output"));

        let retrieved = registry.get("echo").unwrap();
        assert_eq!(retrieved.run("ignored").await.unwrap(), "echo output");
    }

    #[test]
    fn test_get_nonexistent() {
        let registry = CallableRegistry::new();
        assert!(registry.get("nonexistent").is_none());
    }

    #[test]
    fn test_register_replaces_existing() {
        let registry = CallableRegistry::new();
        registry.register("slot".to_string(), mock("first", "1"));
        registry.register("slot".to_string(), mock("second", "2"));

        assert_eq!(registry.len(), 1);
        assert_eq!(registry.get("slot").unwrap().name(), "second");
    }

    #[test]
    fn test_list() {
        let registry = CallableRegistry::new();
        // Register out of order so the assertion actually exercises the documented sort.
        registry.register("callable2".to_string(), mock("callable2", "output2"));
        registry.register("callable1".to_string(), mock("callable1", "output1"));

        assert_eq!(
            registry.list(),
            vec!["callable1".to_string(), "callable2".to_string()]
        );
    }

    #[test]
    fn test_contains() {
        let registry = CallableRegistry::new();
        registry.register("test".to_string(), mock("test", "output"));

        assert!(registry.contains("test"));
        assert!(!registry.contains("nonexistent"));
    }

    #[test]
    fn test_remove() {
        let registry = CallableRegistry::new();
        registry.register("test".to_string(), mock("test", "output"));
        assert!(registry.contains("test"));

        assert!(registry.remove("test"));
        assert!(!registry.contains("test"));
        assert!(!registry.remove("test"));
    }

    #[test]
    fn test_len_and_is_empty() {
        let registry = CallableRegistry::new();
        assert!(registry.is_empty());
        assert_eq!(registry.len(), 0);

        registry.register("a".to_string(), mock("a", "out"));
        assert!(!registry.is_empty());
        assert_eq!(registry.len(), 1);

        assert!(registry.remove("a"));
        assert!(registry.is_empty());
    }

    #[test]
    fn test_default_is_empty() {
        assert!(CallableRegistry::default().is_empty());
    }

    #[test]
    fn test_clones_share_state() {
        let registry = CallableRegistry::new();
        let handle = registry.clone();
        handle.register("shared".to_string(), mock("shared", "out"));

        assert!(registry.contains("shared"));
    }
}