use super::DynCallable;
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
/// Name-keyed registry of [`DynCallable`]s, internally synchronised with an
/// `RwLock`.
///
/// Cloning a registry yields another handle to the same underlying map (see
/// the `Clone` impl), so registrations made through one handle are visible
/// through all of them.
pub struct CallableRegistry {
    /// Registered callables keyed by name; the `Arc` is shared between clones.
    callables: Arc<RwLock<HashMap<String, DynCallable>>>,
}

impl std::fmt::Debug for CallableRegistry {
    /// Formats the registry as its sorted list of registered names.
    ///
    /// Uses `try_read` so that formatting never blocks: if the map is
    /// currently write-locked (including by the current thread), the field is
    /// rendered as `<locked>` instead, mirroring `std::sync::Mutex`'s `Debug`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let guard = match self.callables.try_read() {
            Ok(guard) => Some(guard),
            // A poisoned map is still readable; show its contents anyway.
            Err(std::sync::TryLockError::Poisoned(poisoned)) => Some(poisoned.into_inner()),
            Err(std::sync::TryLockError::WouldBlock) => None,
        };
        let mut out = f.debug_struct("CallableRegistry");
        match guard {
            Some(map) => {
                let mut names: Vec<&str> = map.keys().map(String::as_str).collect();
                names.sort_unstable();
                out.field("callables", &names);
            }
            None => {
                out.field("callables", &format_args!("<locked>"));
            }
        }
        out.finish()
    }
}
impl CallableRegistry {
    /// Creates an empty registry.
    pub fn new() -> Self {
        Self {
            callables: Arc::new(RwLock::new(HashMap::new())),
        }
    }

    /// Acquires the shared read lock, recovering the map if the lock is poisoned.
    ///
    /// Every critical section in this impl is a single `HashMap` operation, so a
    /// panicking lock holder cannot leave the map half-updated. Recovering keeps
    /// one panic from turning every later registry call into a panic.
    fn read_map(&self) -> std::sync::RwLockReadGuard<'_, HashMap<String, DynCallable>> {
        self.callables
            .read()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Acquires the exclusive write lock, recovering the map if the lock is
    /// poisoned (same rationale as `read_map`).
    fn write_map(&self) -> std::sync::RwLockWriteGuard<'_, HashMap<String, DynCallable>> {
        self.callables
            .write()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Registers `callable` under `name`, replacing any entry already stored
    /// under that name.
    pub fn register(&self, name: String, callable: DynCallable) {
        let displaced = self.write_map().insert(name, callable);
        // The write guard is a temporary of the statement above and is already
        // released here. Dropping the replaced entry afterwards means a `Drop`
        // impl that touches this registry cannot deadlock on or poison the lock.
        drop(displaced);
    }

    /// Returns a clone of the callable registered under `name`, if any.
    pub fn get(&self, name: &str) -> Option<DynCallable> {
        self.read_map().get(name).cloned()
    }

    /// Returns every registered name in ascending order.
    pub fn list(&self) -> Vec<String> {
        // The read guard is released at the end of this statement, so the sort
        // below runs without holding the lock.
        let mut names: Vec<String> = self.read_map().keys().cloned().collect();
        // Keys are unique, so an unstable sort yields the same order as a stable one.
        names.sort_unstable();
        names
    }

    /// Returns `true` if a callable is registered under `name`.
    pub fn contains(&self, name: &str) -> bool {
        self.read_map().contains_key(name)
    }

    /// Removes the callable registered under `name`.
    ///
    /// Returns `true` if an entry was removed, `false` if none existed.
    pub fn remove(&self, name: &str) -> bool {
        // Bind the removed entry so it is dropped at the end of the function,
        // after the write guard (a temporary of this statement) is released.
        let removed = self.write_map().remove(name);
        removed.is_some()
    }

    /// Returns the number of registered callables.
    pub fn len(&self) -> usize {
        self.read_map().len()
    }

    /// Returns `true` if no callables are registered.
    pub fn is_empty(&self) -> bool {
        self.read_map().is_empty()
    }
}
impl Default for CallableRegistry {
fn default() -> Self {
Self::new()
}
}
impl Clone for CallableRegistry {
fn clone(&self) -> Self {
Self {
callables: Arc::clone(&self.callables),
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::callable::Callable;
    use async_trait::async_trait;
    use std::sync::Arc;

    /// Minimal `Callable` that reports a fixed name and returns a fixed output.
    struct MockCallable {
        name: String,
        output: String,
    }

    #[async_trait]
    impl Callable for MockCallable {
        fn name(&self) -> &str {
            &self.name
        }

        async fn run(&self, _input: &str) -> anyhow::Result<String> {
            Ok(self.output.clone())
        }
    }

    /// Builds a mock whose reported name is `name`.
    fn mock(name: &str) -> Arc<MockCallable> {
        Arc::new(MockCallable {
            name: name.to_string(),
            output: format!("{name} output"),
        })
    }

    // None of these tests await anything, so plain `#[test]` suffices.

    #[test]
    fn test_register_and_get() {
        let registry = CallableRegistry::new();
        let callable = mock("test");
        registry.register("test".to_string(), callable.clone());
        let retrieved = registry.get("test").unwrap();
        assert_eq!(retrieved.name(), "test");
    }

    #[test]
    fn test_get_nonexistent() {
        let registry = CallableRegistry::new();
        assert!(registry.get("nonexistent").is_none());
    }

    #[test]
    fn test_register_overwrites_existing_name() {
        let registry = CallableRegistry::new();
        registry.register("test".to_string(), mock("original"));
        registry.register("test".to_string(), mock("replacement"));
        assert_eq!(registry.len(), 1);
        assert_eq!(registry.get("test").unwrap().name(), "replacement");
    }

    #[test]
    fn test_list_is_sorted() {
        let registry = CallableRegistry::new();
        // Register out of order to prove `list` sorts rather than echoing insertion order.
        registry.register("callable2".to_string(), mock("callable2"));
        registry.register("callable1".to_string(), mock("callable1"));
        registry.register("callable3".to_string(), mock("callable3"));
        assert_eq!(
            registry.list(),
            vec![
                "callable1".to_string(),
                "callable2".to_string(),
                "callable3".to_string(),
            ]
        );
    }

    #[test]
    fn test_list_empty() {
        let registry = CallableRegistry::new();
        assert!(registry.list().is_empty());
    }

    #[test]
    fn test_contains() {
        let registry = CallableRegistry::new();
        registry.register("test".to_string(), mock("test"));
        assert!(registry.contains("test"));
        assert!(!registry.contains("nonexistent"));
    }

    #[test]
    fn test_remove() {
        let registry = CallableRegistry::new();
        registry.register("test".to_string(), mock("test"));
        assert!(registry.contains("test"));
        assert!(registry.remove("test"));
        assert!(!registry.contains("test"));
        assert!(!registry.remove("test"));
    }

    #[test]
    fn test_len_and_is_empty() {
        let registry = CallableRegistry::new();
        assert!(registry.is_empty());
        assert_eq!(registry.len(), 0);
        registry.register("a".to_string(), mock("a"));
        registry.register("b".to_string(), mock("b"));
        assert!(!registry.is_empty());
        assert_eq!(registry.len(), 2);
        registry.remove("a");
        assert_eq!(registry.len(), 1);
    }

    #[test]
    fn test_default_is_empty() {
        let registry = CallableRegistry::default();
        assert!(registry.is_empty());
    }

    #[test]
    fn test_clone_shares_storage() {
        let registry = CallableRegistry::new();
        let handle = registry.clone();
        handle.register("shared".to_string(), mock("shared"));
        assert!(registry.contains("shared"));
        assert!(registry.remove("shared"));
        assert!(!handle.contains("shared"));
    }
}