use dashmap::DashMap;
use once_cell::sync::Lazy;
use std::any::{Any, TypeId};
use std::collections::HashMap;
use std::sync::{Arc, Mutex, RwLock};
use crate::{
descriptor::ServiceProvider as DescriptorServiceProvider, DiError, DiResult, Lifetime,
ServiceDescriptor, ServiceKey,
};
/// Cache of resolved singleton instances, keyed by service key and stored type-erased.
///
/// This is a process-wide `static`, so it is shared by every `Container` in the process
/// rather than owned by one container.
/// NOTE(review): two containers that register a singleton under the same key will receive
/// the same cached instance, and nothing in this file evicts entries (not even
/// `Container::register_overwrite`) — confirm this sharing is intended.
static SINGLETON_SERVICES: Lazy<DashMap<ServiceKey, Arc<dyn Any + Send + Sync>>> =
    Lazy::new(DashMap::default);
/// Registry of service descriptors plus the bookkeeping used while resolving them.
///
/// Providers and scopes built from a container hold it through an `Arc<Container>`,
/// so both fields below are shared by all of them.
pub struct Container {
/// Registered descriptors, keyed by the `ServiceKey` they were registered under.
services: Arc<RwLock<HashMap<ServiceKey, ServiceDescriptor>>>,
/// Keys whose resolution is currently in progress; `ServiceProvider` consults it to
/// detect circular dependencies.
resolution_stack: Arc<Mutex<Vec<ServiceKey>>>,
}
impl Container {
    /// Creates an empty container with no registered services.
    pub fn new() -> Self {
        Self {
            services: Arc::new(RwLock::new(HashMap::new())),
            resolution_stack: Arc::new(Mutex::new(Vec::new())),
        }
    }

    /// Acquires the registry for reading, mapping lock poisoning to a `DiError`.
    fn read_services(
        &self,
    ) -> DiResult<std::sync::RwLockReadGuard<'_, HashMap<ServiceKey, ServiceDescriptor>>> {
        self.services
            .read()
            .map_err(|_| DiError::generic("Failed to acquire services read lock"))
    }

    /// Acquires the registry for writing, mapping lock poisoning to a `DiError`.
    fn write_services(
        &self,
    ) -> DiResult<std::sync::RwLockWriteGuard<'_, HashMap<ServiceKey, ServiceDescriptor>>> {
        self.services
            .write()
            .map_err(|_| DiError::generic("Failed to acquire services write lock"))
    }

    /// Registers `descriptor` under its `service_key`, refusing to replace an existing entry.
    ///
    /// # Errors
    /// Returns `DiError::Generic` if a descriptor with the same key is already registered
    /// (use [`Container::register_overwrite`] to replace it), or if the registry lock is
    /// poisoned.
    pub fn register(&self, descriptor: ServiceDescriptor) -> DiResult<()> {
        use std::collections::hash_map::Entry;

        let mut services = self.write_services()?;
        // A single entry lookup replaces the previous `contains_key` + `insert` pair.
        if let Entry::Vacant(slot) = services.entry(descriptor.service_key.clone()) {
            slot.insert(descriptor);
            return Ok(());
        }
        Err(DiError::Generic {
            message: format!(
                "Service with key {:?} is already registered",
                descriptor.service_key
            ),
        })
    }

    /// Registers `descriptor`, replacing any descriptor already stored under the same key.
    ///
    /// # Errors
    /// Fails only if the registry lock is poisoned.
    pub fn register_overwrite(&self, descriptor: ServiceDescriptor) -> DiResult<()> {
        self.write_services()?
            .insert(descriptor.service_key.clone(), descriptor);
        Ok(())
    }

    /// Reports whether an un-keyed registration exists for `T`.
    pub fn is_registered<T: 'static>(&self) -> DiResult<bool> {
        self.is_registered_with_key(&ServiceKey::of_type::<T>())
    }

    /// Reports whether a registration for `T` exists under the name `name`.
    pub fn is_keyed_registered<T: 'static>(&self, name: &str) -> DiResult<bool> {
        self.is_registered_with_key(&ServiceKey::named::<T>(name))
    }

    /// Reports whether any descriptor is registered under exactly `key`.
    ///
    /// # Errors
    /// Fails only if the registry lock is poisoned.
    pub fn is_registered_with_key(&self, key: &ServiceKey) -> DiResult<bool> {
        Ok(self.read_services()?.contains_key(key))
    }

    /// Returns a clone of the descriptor registered under `key`, or `None` if absent.
    ///
    /// The clone lets callers run factories without holding the registry lock.
    fn get_descriptor(&self, key: &ServiceKey) -> DiResult<Option<ServiceDescriptor>> {
        Ok(self.read_services()?.get(key).cloned())
    }

    /// Consumes the container and wraps it in a root `ServiceProvider`.
    pub fn build_provider(self) -> ServiceProvider {
        ServiceProvider::new(Arc::new(self))
    }

    /// Alias for [`Container::build_provider`].
    pub fn build(self) -> ServiceProvider {
        self.build_provider()
    }
}
impl Default for Container {
fn default() -> Self {
Self::new()
}
}
/// Per-scope cache of scoped-service instances, keyed by service key and stored type-erased.
type ScopeStorage = Arc<RwLock<HashMap<ServiceKey, Arc<dyn Any + Send + Sync>>>>;
/// Root (unscoped) provider produced by `Container::build` / `Container::build_provider`.
pub struct ServiceProvider {
/// Container shared with every scope this provider creates.
container: Arc<Container>,
}
impl ServiceProvider {
fn new(container: Arc<Container>) -> Self {
Self { container }
}
pub fn get_services<T: 'static + Send + Sync>(&self) -> DiResult<Vec<Arc<T>>> {
let descriptors = self.get_all_descriptors_for_type::<T>()?;
let mut services = Vec::new();
for descriptor in descriptors {
if let Some(service) = self.resolve_service::<T>(&descriptor.service_key, None)? {
services.push(service);
}
}
Ok(services)
}
pub fn create_scope(&self) -> DiResult<ServiceScope> {
ServiceScope::new(Arc::clone(&self.container))
}
fn resolve_service<T: 'static + Send + Sync>(
&self,
key: &ServiceKey,
scope_storage: Option<&ScopeStorage>,
) -> DiResult<Option<Arc<T>>> {
self.begin_resolution(key)?;
let result = self.internal_resolve_service::<T>(key, scope_storage);
self.end_resolution(key)?;
result
}
fn check_circular_dependency(&self, key: &ServiceKey) -> DiResult<()> {
let stack = self
.container
.resolution_stack
.lock()
.map_err(|_| DiError::generic("Failed to acquire resolution stack lock"))?;
if stack.contains(key) {
return Err(DiError::Generic {
message: format!("Circular dependency detected for service key: {key:?}"),
});
}
Ok(())
}
fn begin_resolution(&self, key: &ServiceKey) -> DiResult<()> {
self.check_circular_dependency(key)?;
let mut stack = self
.container
.resolution_stack
.lock()
.map_err(|_| DiError::generic("Failed to acquire resolution stack lock"))?;
stack.push(key.clone());
Ok(())
}
fn end_resolution(&self, key: &ServiceKey) -> DiResult<()> {
let mut stack = self
.container
.resolution_stack
.lock()
.map_err(|_| DiError::generic("Failed to acquire resolution stack lock"))?;
if let Some(pos) = stack.iter().position(|k| k == key) {
stack.remove(pos);
}
Ok(())
}
fn internal_resolve_service<T: 'static + Send + Sync>(
&self,
key: &ServiceKey,
scope_storage: Option<&ScopeStorage>,
) -> DiResult<Option<Arc<T>>> {
let descriptor = match self.container.get_descriptor(key)? {
Some(desc) => desc,
None => return Ok(None),
};
match descriptor.lifetime {
Lifetime::Singleton => self.resolve_singleton::<T>(&descriptor),
Lifetime::Scoped => match scope_storage {
Some(storage) => self.resolve_scoped::<T>(&descriptor, storage),
None => Err(DiError::Generic {
message: format!("Scoped service cannot be resolved without a scope: {key:?}"),
}),
},
Lifetime::Transient => self.resolve_transient::<T>(&descriptor),
}
}
fn resolve_singleton<T: 'static + Send + Sync>(
&self,
descriptor: &ServiceDescriptor,
) -> DiResult<Option<Arc<T>>> {
if let Some(cached) = SINGLETON_SERVICES.get(&descriptor.service_key) {
let any_arc = Arc::clone(&cached);
return self.cast_to_arc::<T>(any_arc);
}
let provider = ContainerServiceProvider::new(Arc::clone(&self.container), None);
let instance = descriptor.create_instance(&provider)?;
let typed_instance = self.box_to_typed_arc::<T>(instance)?;
let any_arc: Arc<dyn Any + Send + Sync> = typed_instance.clone();
SINGLETON_SERVICES.insert(descriptor.service_key.clone(), any_arc);
Ok(Some(typed_instance))
}
fn resolve_scoped<T: 'static + Send + Sync>(
&self,
descriptor: &ServiceDescriptor,
scope_storage: &ScopeStorage,
) -> DiResult<Option<Arc<T>>> {
{
let storage = scope_storage
.read()
.map_err(|_| DiError::generic("Failed to acquire scope storage read lock"))?;
if let Some(cached) = storage.get(&descriptor.service_key) {
let any_arc = Arc::clone(cached);
return self.cast_to_arc::<T>(any_arc);
}
}
let provider =
ContainerServiceProvider::new(Arc::clone(&self.container), Some(scope_storage.clone()));
let instance = descriptor.create_instance(&provider)?;
let typed_instance = self.box_to_typed_arc::<T>(instance)?;
let any_arc: Arc<dyn Any + Send + Sync> = typed_instance.clone();
{
let mut storage = scope_storage
.write()
.map_err(|_| DiError::generic("Failed to acquire scope storage write lock"))?;
storage.insert(descriptor.service_key.clone(), any_arc);
}
Ok(Some(typed_instance))
}
fn resolve_transient<T: 'static + Send + Sync>(
&self,
descriptor: &ServiceDescriptor,
) -> DiResult<Option<Arc<T>>> {
let provider = ContainerServiceProvider::new(Arc::clone(&self.container), None);
let instance = descriptor.create_instance(&provider)?;
let typed_instance = self.box_to_typed_arc::<T>(instance)?;
Ok(Some(typed_instance))
}
fn box_to_typed_arc<T: 'static + Send + Sync>(
&self,
instance: Box<dyn Any + Send + Sync>,
) -> DiResult<Arc<T>> {
match instance.downcast::<T>() {
Ok(boxed) => Ok(Arc::new(*boxed)),
Err(_) => Err(DiError::type_casting_failed::<T>()),
}
}
fn cast_to_arc<T: 'static + Send + Sync>(
&self,
any_arc: Arc<dyn Any + Send + Sync>,
) -> DiResult<Option<Arc<T>>> {
if let Ok(arc_t) = any_arc.downcast::<T>() {
return Ok(Some(arc_t));
}
Err(DiError::type_casting_failed::<T>())
}
fn get_all_descriptors_for_type<T: 'static + Send + Sync>(
&self,
) -> DiResult<Vec<ServiceDescriptor>> {
let services = self
.container
.services
.read()
.map_err(|_| DiError::generic("Failed to acquire services read lock"))?;
let target_type_id = TypeId::of::<T>();
let descriptors: Vec<ServiceDescriptor> = services
.values()
.filter(|desc| desc.service_type == target_type_id)
.cloned()
.collect();
Ok(descriptors)
}
}
/// Resolver passed to service factories via `ServiceDescriptor::create_instance`.
///
/// When attached to a scope's storage it can resolve scoped services; otherwise it cannot.
struct ContainerServiceProvider {
    /// Container whose registrations are consulted.
    container: Arc<Container>,
    /// Storage of the scope this resolver belongs to, if any.
    scope_storage: Option<ScopeStorage>,
}

impl ContainerServiceProvider {
    /// Binds a resolver to `container`, optionally attached to a scope's storage.
    fn new(container: Arc<Container>, scope_storage: Option<ScopeStorage>) -> Self {
        ContainerServiceProvider { container, scope_storage }
    }
}
impl DescriptorServiceProvider for ServiceProvider {
    /// Root-level raw lookup: delegates to a resolver with no scope attached.
    fn get_service_raw(&self, key: &ServiceKey) -> DiResult<Option<Arc<dyn Any + Send + Sync>>> {
        ContainerServiceProvider::new(Arc::clone(&self.container), None).get_service_raw(key)
    }
}
impl DescriptorServiceProvider for ContainerServiceProvider {
    /// Resolves `key` to a type-erased instance according to its registered lifetime.
    ///
    /// Returns `Ok(None)` when nothing is registered under `key`. Where instances live:
    /// - singletons are cached in `SINGLETON_SERVICES`;
    /// - scoped instances are cached in this resolver's scope storage;
    /// - transients are built fresh on every call.
    ///
    /// On a concurrent first resolution of a singleton or scoped service, the first instance
    /// inserted wins and every caller receives it.
    ///
    /// # Errors
    /// Fails when a scoped service is requested without a scope, when a lock is poisoned,
    /// or when the descriptor's factory fails.
    ///
    /// NOTE(review): this path never touches `Container::resolution_stack`, so a dependency
    /// cycle reached through factories is not reported as a circular-dependency error here —
    /// confirm how cycles are meant to be caught on this path.
    fn get_service_raw(&self, key: &ServiceKey) -> DiResult<Option<Arc<dyn Any + Send + Sync>>> {
        let descriptor = match self.container.get_descriptor(key)? {
            Some(desc) => desc,
            None => return Ok(None),
        };
        match descriptor.lifetime {
            Lifetime::Singleton => {
                // Copy the Arc out so no map guard is held beyond this statement.
                let cached = SINGLETON_SERVICES
                    .get(&descriptor.service_key)
                    .map(|entry| Arc::clone(entry.value()));
                if let Some(existing) = cached {
                    return Ok(Some(existing));
                }
                // The factory runs without any map guard held: it may resolve other
                // singletons, which touches the same DashMap.
                let inner_provider =
                    ContainerServiceProvider::new(Arc::clone(&self.container), None);
                let instance = descriptor.create_instance(&inner_provider)?;
                let created: Arc<dyn Any + Send + Sync> = Arc::from(instance);
                let winner = Arc::clone(
                    SINGLETON_SERVICES
                        .entry(descriptor.service_key.clone())
                        .or_insert(created)
                        .value(),
                );
                Ok(Some(winner))
            }
            Lifetime::Scoped => {
                let storage = match &self.scope_storage {
                    Some(storage) => storage,
                    None => {
                        return Err(DiError::Generic {
                            message: format!(
                                "Scoped service cannot be resolved without a scope: {key:?}"
                            ),
                        })
                    }
                };
                let cached = storage
                    .read()
                    .map_err(|_| DiError::generic("Failed to acquire scope storage read lock"))?
                    .get(&descriptor.service_key)
                    .cloned();
                if let Some(existing) = cached {
                    return Ok(Some(existing));
                }
                let inner_provider = ContainerServiceProvider::new(
                    Arc::clone(&self.container),
                    Some(Arc::clone(storage)),
                );
                let instance = descriptor.create_instance(&inner_provider)?;
                let created: Arc<dyn Any + Send + Sync> = Arc::from(instance);
                let mut storage_guard = storage
                    .write()
                    .map_err(|_| DiError::generic("Failed to acquire scope storage write lock"))?;
                let winner: &Arc<dyn Any + Send + Sync> = storage_guard
                    .entry(descriptor.service_key.clone())
                    .or_insert(created);
                Ok(Some(Arc::clone(winner)))
            }
            Lifetime::Transient => {
                // Transients inherit the current scope so their scoped dependencies resolve.
                let inner_provider = ContainerServiceProvider::new(
                    Arc::clone(&self.container),
                    self.scope_storage.clone(),
                );
                let instance = descriptor.create_instance(&inner_provider)?;
                let any_arc: Arc<dyn Any + Send + Sync> = Arc::from(instance);
                Ok(Some(any_arc))
            }
        }
    }
}
/// A resolution scope: scoped services resolved through it are cached in `storage` until
/// the scope is disposed, either explicitly or when it is dropped.
pub struct ServiceScope {
/// Container whose registrations this scope resolves.
container: Arc<Container>,
/// Instances of scoped services created within this scope.
storage: ScopeStorage,
/// Set once `dispose` has run; afterwards resolution fails with `DiError::ScopeDisposed`.
disposed: Arc<Mutex<bool>>,
}
impl ServiceScope {
pub fn new(container: Arc<Container>) -> DiResult<Self> {
Ok(Self {
container,
storage: Arc::new(RwLock::new(HashMap::new())),
disposed: Arc::new(Mutex::new(false)),
})
}
fn ensure_not_disposed(&self) -> DiResult<()> {
let disposed = self
.disposed
.lock()
.map_err(|_| DiError::generic("Failed to acquire disposed lock"))?;
if *disposed {
return Err(DiError::ScopeDisposed);
}
Ok(())
}
pub fn get_services<T: 'static + Send + Sync>(&self) -> DiResult<Vec<Arc<T>>> {
self.ensure_not_disposed()?;
let provider = ServiceProvider::new(Arc::clone(&self.container));
let descriptors = provider.get_all_descriptors_for_type::<T>()?;
let mut services = Vec::new();
for descriptor in descriptors {
if let Some(service) =
provider.resolve_service::<T>(&descriptor.service_key, Some(&self.storage))?
{
services.push(service);
}
}
Ok(services)
}
pub fn create_scope(&self) -> DiResult<ServiceScope> {
self.ensure_not_disposed()?;
ServiceScope::new(Arc::clone(&self.container))
}
pub fn dispose(&mut self) {
if let Ok(mut disposed) = self.disposed.lock() {
if !*disposed {
*disposed = true;
if let Ok(mut storage) = self.storage.write() {
storage.clear();
}
}
}
}
pub fn is_disposed(&self) -> bool {
self.disposed
.lock()
.map(|disposed| *disposed)
.unwrap_or(true)
}
}
impl DescriptorServiceProvider for ServiceScope {
    /// Resolves `key` within this scope's storage; fails with `DiError::ScopeDisposed`
    /// once the scope has been disposed.
    fn get_service_raw(&self, key: &ServiceKey) -> DiResult<Option<Arc<dyn Any + Send + Sync>>> {
        self.ensure_not_disposed()?;
        let scoped_resolver = ContainerServiceProvider::new(
            Arc::clone(&self.container),
            Some(Arc::clone(&self.storage)),
        );
        scoped_resolver.get_service_raw(key)
    }
}
impl Drop for ServiceScope {
fn drop(&mut self) {
self.dispose();
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::descriptor::ServiceProviderExt;
    use crate::ServiceDescriptor;

    /// Minimal service used as both the registered interface and its implementation.
    #[derive(Debug, Clone, PartialEq)]
    struct TestService {
        value: i32,
    }

    /// Fixture for a service with a dependency. None of the tests below construct it,
    /// hence the `dead_code` allowance.
    #[derive(Debug, Clone, PartialEq)]
    #[allow(dead_code)]
    struct DependentService {
        dependency: Arc<TestService>,
    }

    /// Registers each descriptor in order and builds a provider from the resulting container.
    fn provider_with(descriptors: Vec<ServiceDescriptor>) -> ServiceProvider {
        let container = Container::new();
        for descriptor in descriptors {
            container.register(descriptor).unwrap();
        }
        container.build()
    }

    #[test]
    fn test_container_creation() {
        let registered = Container::new().is_registered::<TestService>().unwrap();
        assert!(!registered);
    }

    #[test]
    fn test_service_registration() {
        let container = Container::new();
        let descriptor = ServiceDescriptor::transient::<TestService, TestService>(Box::new(|_| {
            Ok(Box::new(TestService { value: 42 }))
        }));
        container.register(descriptor).unwrap();
        assert!(container.is_registered::<TestService>().unwrap());
    }

    #[test]
    fn test_singleton_service_resolution() {
        let descriptor = ServiceDescriptor::singleton::<TestService, TestService>(Box::new(|_| {
            Ok(Box::new(TestService { value: 100 }))
        }));
        let provider = provider_with(vec![descriptor]);
        for _ in 0..2 {
            let service = provider.get_required_service::<TestService>().unwrap();
            assert_eq!(service.value, 100);
        }
    }

    #[test]
    fn test_transient_service_resolution() {
        let descriptor = ServiceDescriptor::transient::<TestService, TestService>(Box::new(|_| {
            Ok(Box::new(TestService { value: 200 }))
        }));
        let provider = provider_with(vec![descriptor]);
        for _ in 0..2 {
            let service = provider.get_required_service::<TestService>().unwrap();
            assert_eq!(service.value, 200);
        }
    }

    #[test]
    fn test_keyed_service_registration_and_resolution() {
        let container = Container::new();
        let descriptor = ServiceDescriptor::named_singleton::<TestService, TestService>(
            "primary",
            Box::new(|_| Ok(Box::new(TestService { value: 300 }))),
        );
        container.register(descriptor).unwrap();
        assert!(container.is_keyed_registered::<TestService>("primary").unwrap());
        assert!(!container.is_keyed_registered::<TestService>("secondary").unwrap());

        let provider = container.build();
        let primary = provider
            .get_required_keyed_service::<TestService>("primary")
            .unwrap();
        assert_eq!(primary.value, 300);

        let missing = provider.get_keyed_service::<TestService>("nonexistent");
        assert!(matches!(missing, Ok(None)));
    }

    #[test]
    fn test_scoped_service_with_scope() {
        let descriptor = ServiceDescriptor::scoped::<TestService, TestService>(Box::new(|_| {
            Ok(Box::new(TestService { value: 400 }))
        }));
        let provider = provider_with(vec![descriptor]);
        let mut scope = provider.create_scope().unwrap();
        for _ in 0..2 {
            let service = scope.get_required_service::<TestService>().unwrap();
            assert_eq!(service.value, 400);
        }
        scope.dispose();
    }

    #[test]
    fn test_service_collection() {
        let first = ServiceDescriptor::named_transient::<TestService, TestService>(
            "service1",
            Box::new(|_| Ok(Box::new(TestService { value: 1 }))),
        );
        let second = ServiceDescriptor::named_transient::<TestService, TestService>(
            "service2",
            Box::new(|_| Ok(Box::new(TestService { value: 2 }))),
        );
        let provider = provider_with(vec![first, second]);
        let services = provider.get_services::<TestService>().unwrap();
        assert_eq!(services.len(), 2);
        let mut values: Vec<i32> = services.iter().map(|service| service.value).collect();
        values.sort_unstable();
        assert_eq!(values, vec![1, 2]);
    }
}