enact-core 0.0.2

Core agent runtime for Enact - Graph-Native AI agents
Documentation
//! Callable Registry - Central registry for all callable types
//!
//! The registry allows registering and looking up any callable type (graphs, agents, workflows, LLM functions, etc.)
//! by name. This enables the unified agent/stream endpoint to execute any callable type through a single interface.

use super::DynCallable;
use std::collections::HashMap;
use std::sync::{Arc, PoisonError, RwLock, RwLockReadGuard, RwLockWriteGuard};

/// Thread-safe registry for callables
///
/// Allows registering any callable type (LlmCallable, GraphCallable, FnCallable, etc.)
/// and looking them up by name. Used by the unified agent/stream endpoint.
pub struct CallableRegistry {
    callables: Arc<RwLock<HashMap<String, DynCallable>>>,
}

impl CallableRegistry {
    /// Create a new empty registry
    pub fn new() -> Self {
        Self {
            callables: Arc::new(RwLock::new(HashMap::new())),
        }
    }

    /// Register a callable with a given name
    ///
    /// # Arguments
    /// * `name` - The name to register the callable under
    /// * `callable` - The callable to register
    ///
    /// # Panics
    /// This will overwrite any existing callable with the same name.
    pub fn register(&self, name: String, callable: DynCallable) {
        let mut callables = self.callables.write().unwrap();
        callables.insert(name, callable);
    }

    /// Get a callable by name
    ///
    /// # Arguments
    /// * `name` - The name of the callable to retrieve
    ///
    /// # Returns
    /// * `Some(DynCallable)` if found
    /// * `None` if not found
    pub fn get(&self, name: &str) -> Option<DynCallable> {
        let callables = self.callables.read().unwrap();
        callables.get(name).cloned()
    }

    /// List all registered callable names
    ///
    /// # Returns
    /// A vector of all registered callable names, sorted alphabetically
    pub fn list(&self) -> Vec<String> {
        let callables = self.callables.read().unwrap();
        let mut names: Vec<String> = callables.keys().cloned().collect();
        names.sort();
        names
    }

    /// Check if a callable is registered
    ///
    /// # Arguments
    /// * `name` - The name to check
    ///
    /// # Returns
    /// `true` if registered, `false` otherwise
    pub fn contains(&self, name: &str) -> bool {
        let callables = self.callables.read().unwrap();
        callables.contains_key(name)
    }

    /// Remove a callable from the registry
    ///
    /// # Arguments
    /// * `name` - The name of the callable to remove
    ///
    /// # Returns
    /// `true` if the callable was removed, `false` if it wasn't found
    pub fn remove(&self, name: &str) -> bool {
        let mut callables = self.callables.write().unwrap();
        callables.remove(name).is_some()
    }

    /// Get the number of registered callables
    pub fn len(&self) -> usize {
        let callables = self.callables.read().unwrap();
        callables.len()
    }

    /// Check if the registry is empty
    pub fn is_empty(&self) -> bool {
        let callables = self.callables.read().unwrap();
        callables.is_empty()
    }
}

impl Default for CallableRegistry {
    fn default() -> Self {
        Self::new()
    }
}

impl Clone for CallableRegistry {
    fn clone(&self) -> Self {
        Self {
            callables: Arc::clone(&self.callables),
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::callable::Callable;
    use async_trait::async_trait;
    use std::sync::Arc;

    /// Minimal `Callable` that reports a fixed name and returns a fixed output from `run`.
    struct MockCallable {
        name: String,
        output: String,
    }

    #[async_trait]
    impl Callable for MockCallable {
        fn name(&self) -> &str {
            &self.name
        }

        async fn run(&self, _input: &str) -> anyhow::Result<String> {
            Ok(self.output.clone())
        }
    }

    /// Build a shared `MockCallable` ready to be passed to `CallableRegistry::register`.
    fn mock(name: &str, output: &str) -> Arc<MockCallable> {
        Arc::new(MockCallable {
            name: name.to_string(),
            output: output.to_string(),
        })
    }

    #[tokio::test]
    async fn test_register_and_get() {
        let registry = CallableRegistry::new();
        registry.register("test".to_string(), mock("test", "test output"));

        let retrieved = registry.get("test").unwrap();
        assert_eq!(retrieved.name(), "test");
        // Make sure the handle we got back is the callable we registered.
        assert_eq!(retrieved.run("ignored").await.unwrap(), "test output");
    }

    #[tokio::test]
    async fn test_get_nonexistent() {
        let registry = CallableRegistry::new();
        assert!(registry.get("nonexistent").is_none());
    }

    #[tokio::test]
    async fn test_list() {
        let registry = CallableRegistry::new();
        // Register out of order so the documented alphabetical sort is actually exercised.
        registry.register("callable2".to_string(), mock("callable2", "output2"));
        registry.register("callable1".to_string(), mock("callable1", "output1"));

        let names = registry.list();
        assert_eq!(
            names,
            vec!["callable1".to_string(), "callable2".to_string()]
        );
    }

    #[tokio::test]
    async fn test_contains() {
        let registry = CallableRegistry::new();
        registry.register("test".to_string(), mock("test", "output"));

        assert!(registry.contains("test"));
        assert!(!registry.contains("nonexistent"));
    }

    #[tokio::test]
    async fn test_remove() {
        let registry = CallableRegistry::new();
        registry.register("test".to_string(), mock("test", "output"));
        assert!(registry.contains("test"));

        assert!(registry.remove("test"));
        assert!(!registry.contains("test"));
        assert!(!registry.remove("test"));
    }

    #[tokio::test]
    async fn test_register_overwrites_existing() {
        let registry = CallableRegistry::new();
        registry.register("test".to_string(), mock("first", "output1"));
        registry.register("test".to_string(), mock("second", "output2"));

        assert_eq!(registry.len(), 1);
        assert_eq!(registry.get("test").unwrap().name(), "second");
    }

    #[tokio::test]
    async fn test_len_and_is_empty() {
        let registry = CallableRegistry::new();
        assert!(registry.is_empty());
        assert_eq!(registry.len(), 0);

        registry.register("a".to_string(), mock("a", "output"));
        registry.register("b".to_string(), mock("b", "output"));
        assert!(!registry.is_empty());
        assert_eq!(registry.len(), 2);

        assert!(registry.remove("a"));
        assert!(registry.remove("b"));
        assert!(registry.is_empty());
    }

    #[tokio::test]
    async fn test_clone_shares_state() {
        let registry = CallableRegistry::default();
        let handle = registry.clone();
        handle.register("shared".to_string(), mock("shared", "output"));

        assert!(registry.contains("shared"));
        assert!(registry.remove("shared"));
        assert!(!handle.contains("shared"));
    }
}