use crate::core::platform::container::paladin::Paladin;
use paladin_ports::output::paladin_registry::{PaladinRegistry, RegistryError};
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
/// In-memory [`PaladinRegistry`] that stores paladins in a `HashMap` keyed by
/// their registry ID.
///
/// The map sits behind an [`RwLock`]: lookups take a shared read lock, while
/// registration and clearing take the exclusive write lock.
#[derive(Default, Debug)]
pub struct HashMapPaladinRegistry {
    /// Registered paladins, keyed by the ID they were registered under.
    paladins: RwLock<HashMap<String, Arc<Paladin>>>,
}
impl HashMapPaladinRegistry {
    /// Creates an empty registry.
    pub fn new() -> Self {
        Self::default()
    }

    /// Creates a registry pre-populated with the given `paladins`.
    ///
    /// The map is adopted as-is: its keys are not validated the way
    /// [`PaladinRegistry::register`] validates IDs (e.g. an empty key is kept).
    pub fn from_map(paladins: HashMap<String, Arc<Paladin>>) -> Self {
        let lock = RwLock::new(paladins);
        Self { paladins: lock }
    }

    /// Returns how many paladins are currently registered.
    ///
    /// Reports `0` if the lock is poisoned.
    pub fn count(&self) -> usize {
        match self.paladins.read() {
            Ok(guard) => guard.len(),
            Err(_) => 0,
        }
    }

    /// Removes every registered paladin.
    ///
    /// Does nothing if the lock is poisoned.
    pub fn clear(&self) {
        let Ok(mut guard) = self.paladins.write() else {
            return;
        };
        guard.clear();
    }
}
impl PaladinRegistry for HashMapPaladinRegistry {
    /// Registers `paladin` under `id`.
    ///
    /// # Errors
    ///
    /// - [`RegistryError::InvalidId`] if `id` is empty.
    /// - [`RegistryError::AccessFailed`] if the write lock is poisoned.
    /// - [`RegistryError::DuplicateId`] if a paladin is already registered under
    ///   `id`; the existing entry is left untouched.
    fn register(&self, id: String, paladin: Arc<Paladin>) -> Result<(), RegistryError> {
        use std::collections::hash_map::Entry;

        if id.is_empty() {
            return Err(RegistryError::InvalidId("ID cannot be empty".to_string()));
        }
        let mut map = self
            .paladins
            .write()
            .map_err(|e| RegistryError::AccessFailed(format!("Write lock poisoned: {}", e)))?;
        // One `entry` lookup replaces the `contains_key` + `insert` pair, so the
        // key is hashed and probed only once.
        match map.entry(id) {
            Entry::Occupied(existing) => Err(RegistryError::DuplicateId(existing.key().clone())),
            Entry::Vacant(slot) => {
                slot.insert(paladin);
                Ok(())
            }
        }
    }

    /// Returns a shared handle to the paladin registered under `id`.
    ///
    /// Returns `None` if no such paladin exists or the lock is poisoned.
    fn get(&self, id: &str) -> Option<Arc<Paladin>> {
        let map = self.paladins.read().ok()?;
        map.get(id).cloned()
    }

    /// Reports whether a paladin is registered under `id`.
    ///
    /// Returns `false` if the lock is poisoned.
    fn contains(&self, id: &str) -> bool {
        let Ok(map) = self.paladins.read() else {
            return false;
        };
        map.contains_key(id)
    }

    /// Returns the IDs of all registered paladins, in unspecified (hash map) order.
    ///
    /// Returns an empty vector if the lock is poisoned.
    fn list_ids(&self) -> Vec<String> {
        let Ok(map) = self.paladins.read() else {
            return Vec::new();
        };
        map.keys().cloned().collect()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::base::entity::node::Node;
    use crate::core::platform::container::paladin::PaladinData;
    use std::thread;

    /// Builds a paladin whose name and system prompt are derived from `id`.
    fn create_test_paladin(id: &str) -> Paladin {
        let data = PaladinData {
            system_prompt: format!("Test paladin {}", id),
            name: id.to_string(),
            ..Default::default()
        };
        Node::new(data, Some(id.to_string()))
    }

    #[test]
    fn test_registry_register_and_get() {
        let registry = HashMapPaladinRegistry::new();
        let paladin = create_test_paladin("test_paladin");
        let result = registry.register("test_paladin".to_string(), Arc::new(paladin));
        assert!(result.is_ok(), "Registration should succeed");
        let retrieved = registry.get("test_paladin");
        assert!(retrieved.is_some(), "Paladin should be retrievable");
        assert_eq!(retrieved.unwrap().name, Some("test_paladin".to_string()));
    }

    #[test]
    fn test_registry_get_missing() {
        let registry = HashMapPaladinRegistry::new();
        assert!(
            registry.get("missing").is_none(),
            "Unknown ID should yield None"
        );
    }

    #[test]
    fn test_registry_duplicate_id_error() {
        let registry = HashMapPaladinRegistry::new();
        let paladin1 = create_test_paladin("duplicate");
        // A differently-named second paladin lets us verify the first one survives.
        let paladin2 = create_test_paladin("duplicate_replacement");
        assert!(
            registry
                .register("duplicate".to_string(), Arc::new(paladin1))
                .is_ok()
        );
        let result = registry.register("duplicate".to_string(), Arc::new(paladin2));
        assert!(result.is_err(), "Duplicate registration should fail");
        match result.unwrap_err() {
            RegistryError::DuplicateId(id) => assert_eq!(id, "duplicate"),
            _ => panic!("Expected DuplicateId error"),
        }
        assert_eq!(registry.count(), 1);
        assert_eq!(
            registry.get("duplicate").unwrap().name,
            Some("duplicate".to_string()),
            "Original registration must not be overwritten"
        );
    }

    #[test]
    fn test_registry_contains() {
        let registry = HashMapPaladinRegistry::new();
        let paladin = create_test_paladin("exists");
        assert!(!registry.contains("exists"), "Should not exist initially");
        registry
            .register("exists".to_string(), Arc::new(paladin))
            .unwrap();
        assert!(
            registry.contains("exists"),
            "Should exist after registration"
        );
        assert!(
            !registry.contains("nonexistent"),
            "Non-existent should return false"
        );
    }

    #[test]
    fn test_registry_list_ids() {
        let registry = HashMapPaladinRegistry::new();
        assert_eq!(registry.list_ids().len(), 0);
        registry
            .register(
                "paladin1".to_string(),
                Arc::new(create_test_paladin("paladin1")),
            )
            .unwrap();
        registry
            .register(
                "paladin2".to_string(),
                Arc::new(create_test_paladin("paladin2")),
            )
            .unwrap();
        registry
            .register(
                "paladin3".to_string(),
                Arc::new(create_test_paladin("paladin3")),
            )
            .unwrap();
        let ids = registry.list_ids();
        assert_eq!(ids.len(), 3);
        assert!(ids.contains(&"paladin1".to_string()));
        assert!(ids.contains(&"paladin2".to_string()));
        assert!(ids.contains(&"paladin3".to_string()));
    }

    #[test]
    fn test_registry_thread_safety() {
        let registry = Arc::new(HashMapPaladinRegistry::new());
        let mut handles = vec![];
        for i in 0..10 {
            let registry_clone = Arc::clone(&registry);
            let handle = thread::spawn(move || {
                let id = format!("paladin_{}", i);
                let paladin = create_test_paladin(&id);
                registry_clone.register(id, Arc::new(paladin)).unwrap();
            });
            handles.push(handle);
        }
        for handle in handles {
            handle.join().unwrap();
        }
        assert_eq!(registry.count(), 10);
        for i in 0..10 {
            let id = format!("paladin_{}", i);
            assert!(registry.contains(&id), "Paladin {} should exist", id);
        }
    }

    #[test]
    fn test_registry_invalid_id() {
        let registry = HashMapPaladinRegistry::new();
        let paladin = create_test_paladin("test");
        let result = registry.register("".to_string(), Arc::new(paladin));
        assert!(result.is_err(), "Empty ID should fail");
        match result.unwrap_err() {
            RegistryError::InvalidId(_) => {}
            _ => panic!("Expected InvalidId error"),
        }
        assert_eq!(registry.count(), 0, "Rejected paladin must not be stored");
    }

    #[test]
    fn test_registry_count() {
        let registry = HashMapPaladinRegistry::new();
        assert_eq!(registry.count(), 0);
        registry
            .register("p1".to_string(), Arc::new(create_test_paladin("p1")))
            .unwrap();
        assert_eq!(registry.count(), 1);
        registry
            .register("p2".to_string(), Arc::new(create_test_paladin("p2")))
            .unwrap();
        assert_eq!(registry.count(), 2);
    }

    #[test]
    fn test_registry_clear() {
        let registry = HashMapPaladinRegistry::new();
        registry
            .register("p1".to_string(), Arc::new(create_test_paladin("p1")))
            .unwrap();
        registry
            .register("p2".to_string(), Arc::new(create_test_paladin("p2")))
            .unwrap();
        assert_eq!(registry.count(), 2);
        registry.clear();
        assert_eq!(registry.count(), 0);
        assert!(!registry.contains("p1"));
        assert!(!registry.contains("p2"));
    }

    #[test]
    fn test_registry_from_map() {
        let mut initial = HashMap::new();
        initial.insert("p1".to_string(), Arc::new(create_test_paladin("p1")));
        initial.insert("p2".to_string(), Arc::new(create_test_paladin("p2")));
        let registry = HashMapPaladinRegistry::from_map(initial);
        assert_eq!(registry.count(), 2);
        assert!(registry.contains("p1"));
        assert!(registry.contains("p2"));
    }
}