use crate::container::scope::ServiceScope;
use crate::errors::CoreError;
use crate::foundation::traits::Service;
use std::any::{Any, TypeId};
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
/// A single registration stored in the registry.
///
/// The payload is type-erased behind `dyn Any + Send + Sync`; the caller is
/// expected to downcast it back to the concrete service type.
pub enum ServiceEntry {
    /// An already-constructed, reference-counted value.
    Instance(Arc<dyn Any + Send + Sync>),
    /// A closure that produces a new boxed value every time it is called.
    Factory(Box<dyn Fn() -> Box<dyn Any + Send + Sync> + Send + Sync>),
}

/// Debug output names the variant and shows a fixed placeholder instead of
/// the payload, since neither the erased value nor the closure is printable.
impl std::fmt::Debug for ServiceEntry {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let (variant, placeholder) = match self {
            Self::Instance(_) => ("Instance", "<instance>"),
            Self::Factory(_) => ("Factory", "<factory>"),
        };
        f.debug_tuple(variant).field(&placeholder).finish()
    }
}
/// Type-keyed store of service registrations together with their scopes.
///
/// Both maps are keyed by the service's `TypeId`, and each map sits behind
/// its own `Arc<RwLock<..>>`.
#[derive(Debug)]
pub struct ServiceRegistry {
    /// The registered entry (shared instance or factory) for each type.
    services: Arc<RwLock<HashMap<TypeId, ServiceEntry>>>,
    /// The scope recorded for each registered type.
    scopes: Arc<RwLock<HashMap<TypeId, ServiceScope>>>,
}
impl ServiceRegistry {
pub fn new() -> Self {
Self {
services: Arc::new(RwLock::new(HashMap::new())),
scopes: Arc::new(RwLock::new(HashMap::new())),
}
}
pub fn register_service<T>(&mut self, service: T) -> Result<(), CoreError>
where
T: Service + Clone + 'static,
{
let type_id = TypeId::of::<T>();
let arc_service = Arc::new(service);
let mut services = self.services.write().map_err(|_| CoreError::LockError {
resource: "service_registry".to_string(),
})?;
let mut scopes = self.scopes.write().map_err(|_| CoreError::LockError {
resource: "service_scopes".to_string(),
})?;
services.insert(type_id, ServiceEntry::Instance(arc_service));
scopes.insert(type_id, ServiceScope::Singleton);
Ok(())
}
pub fn register_singleton<T>(&mut self, service: T) -> Result<(), CoreError>
where
T: Service + Clone + 'static,
{
let type_id = TypeId::of::<T>();
let arc_service = Arc::new(service);
let mut services = self.services.write().map_err(|_| CoreError::LockError {
resource: "service_registry".to_string(),
})?;
let mut scopes = self.scopes.write().map_err(|_| CoreError::LockError {
resource: "service_scopes".to_string(),
})?;
services.insert(type_id, ServiceEntry::Instance(arc_service));
scopes.insert(type_id, ServiceScope::Singleton);
Ok(())
}
pub fn register_transient<T>(
&mut self,
factory: Box<dyn Fn() -> T + Send + Sync>,
) -> Result<(), CoreError>
where
T: Service + 'static,
{
let type_id = TypeId::of::<T>();
let wrapped_factory: Box<dyn Fn() -> Box<dyn Any + Send + Sync> + Send + Sync> =
Box::new(move || -> Box<dyn Any + Send + Sync> { Box::new(factory()) });
let mut services = self.services.write().map_err(|_| CoreError::LockError {
resource: "service_registry".to_string(),
})?;
let mut scopes = self.scopes.write().map_err(|_| CoreError::LockError {
resource: "service_scopes".to_string(),
})?;
services.insert(type_id, ServiceEntry::Factory(wrapped_factory));
scopes.insert(type_id, ServiceScope::Transient);
Ok(())
}
pub fn resolve<T>(&self) -> Result<Arc<T>, CoreError>
where
T: Service + Clone + 'static,
{
self.try_resolve::<T>()
.ok_or_else(|| CoreError::ServiceNotFound {
service_type: std::any::type_name::<T>().to_string(),
})
}
pub fn try_resolve<T>(&self) -> Option<Arc<T>>
where
T: Service + Clone + 'static,
{
let type_id = TypeId::of::<T>();
let services = self.services.read().ok()?;
match services.get(&type_id)? {
ServiceEntry::Instance(instance) => {
instance.clone().downcast::<T>().ok()
}
ServiceEntry::Factory(factory) => {
let instance = factory();
let boxed = instance.downcast::<T>().ok()?;
Some(Arc::new(*boxed))
}
}
}
pub fn contains<T>(&self) -> bool
where
T: Service + 'static,
{
let type_id = TypeId::of::<T>();
self.services
.read()
.map(|services| services.contains_key(&type_id))
.unwrap_or(false)
}
pub fn service_count(&self) -> usize {
self.services
.read()
.map(|services| services.len())
.unwrap_or(0)
}
pub fn registered_services(&self) -> Vec<TypeId> {
self.services
.read()
.map(|services| services.keys().cloned().collect())
.unwrap_or_default()
}
pub fn validate(&self) -> Result<(), CoreError> {
let _services = self.services.read().map_err(|_| CoreError::LockError {
resource: "service_registry".to_string(),
})?;
let _scopes = self.scopes.read().map_err(|_| CoreError::LockError {
resource: "service_scopes".to_string(),
})?;
Ok(())
}
pub async fn initialize_all(&self) -> Result<(), CoreError> {
self.validate()
}
}
impl Default for ServiceRegistry {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::foundation::traits::Service;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Fixture service. `id` is unique per constructed value (taken from a
    /// process-wide counter); `counter` is shared between clones through the
    /// `Arc`, which lets tests observe whether two handles point at the same
    /// underlying service.
    #[derive(Debug, Clone)]
    struct TestService {
        id: usize,
        counter: Arc<AtomicUsize>,
    }

    impl TestService {
        /// Builds a service with a fresh id and a zeroed counter.
        fn new() -> Self {
            static COUNTER: AtomicUsize = AtomicUsize::new(0);
            let id = COUNTER.fetch_add(1, Ordering::SeqCst);
            let counter = Arc::new(AtomicUsize::new(0));
            Self { id, counter }
        }

        /// Bumps the counter and returns the new value.
        fn increment(&self) -> usize {
            let previous = self.counter.fetch_add(1, Ordering::SeqCst);
            previous + 1
        }

        /// Reads the current counter value.
        fn get_count(&self) -> usize {
            self.counter.load(Ordering::SeqCst)
        }
    }

    impl crate::foundation::traits::FrameworkComponent for TestService {}

    impl Service for TestService {}

    /// Every resolution of a singleton sees the same id and the same counter.
    #[test]
    fn test_singleton_behavior() {
        let mut registry = ServiceRegistry::new();
        let seed = TestService::new();
        let expected_id = seed.id;
        registry.register_singleton(seed).unwrap();

        let handles = [
            registry.resolve::<TestService>().unwrap(),
            registry.resolve::<TestService>().unwrap(),
            registry.resolve::<TestService>().unwrap(),
        ];
        for handle in &handles {
            assert_eq!(handle.id, expected_id);
        }

        // A bump through one handle is visible through the others.
        assert_eq!(handles[0].increment(), 1);
        assert_eq!(handles[1].get_count(), 1);
        assert_eq!(handles[2].get_count(), 1);

        assert_eq!(handles[1].increment(), 2);
        assert_eq!(handles[0].get_count(), 2);
        assert_eq!(handles[2].get_count(), 2);
    }

    /// Two resolutions of a singleton share one allocation.
    #[test]
    fn test_singleton_arc_sharing() {
        let mut registry = ServiceRegistry::new();
        registry.register_singleton(TestService::new()).unwrap();

        let first = registry.resolve::<TestService>().unwrap();
        let second = registry.resolve::<TestService>().unwrap();
        assert!(Arc::ptr_eq(&first, &second));
    }

    /// Each resolution of a transient service yields an independent value.
    #[test]
    fn test_transient_behavior() {
        let mut registry = ServiceRegistry::new();
        registry
            .register_transient::<TestService>(Box::new(TestService::new))
            .unwrap();

        let first = registry.resolve::<TestService>().unwrap();
        let second = registry.resolve::<TestService>().unwrap();
        assert_ne!(first.id, second.id);
        assert!(!Arc::ptr_eq(&first, &second));

        // State changes on one value do not leak into the other.
        first.increment();
        assert_eq!(first.get_count(), 1);
        assert_eq!(second.get_count(), 0);
    }

    /// `contains` / `service_count` reflect registration, and the resolved
    /// service reports the expected id string.
    #[test]
    fn test_service_registry_operations() {
        let mut registry = ServiceRegistry::new();
        let service = TestService::new();

        assert!(!registry.contains::<TestService>());
        assert_eq!(registry.service_count(), 0);

        registry.register_singleton(service).unwrap();

        assert!(registry.contains::<TestService>());
        assert_eq!(registry.service_count(), 1);

        let resolved = registry.resolve::<TestService>().unwrap();
        let reported_id = resolved.service_id();
        assert_eq!(
            reported_id,
            "elif_core::container::registry::tests::TestService"
        );
    }
}