use crate::factory::AnyFactory;
use crate::storage::{ServiceStorage, downcast_arc_unchecked};
use crate::{DiError, Injectable, Result};
use std::any::{Any, TypeId};
use std::cell::UnsafeCell;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
#[cfg(feature = "logging")]
use tracing::{debug, trace};
/// Number of slots in each thread's hot cache.
///
/// Must stay a power of two: [`HotCache::slot_for_hash`] picks a slot by masking with
/// `HOT_CACHE_SLOTS - 1`.
const HOT_CACHE_SLOTS: usize = 4;

/// One memoized resolution held by a [`HotCache`].
struct CacheEntry {
    /// Exact identity of the concrete type stored in `service`.
    type_id: TypeId,
    /// Address of the storage the service was resolved through; second half of the cache key.
    storage_ptr: usize,
    /// The resolved service, type-erased.
    service: Arc<dyn Any + Send + Sync>,
}

/// Small direct-mapped cache of recently resolved services.
///
/// Each `(type, storage address)` pair maps to exactly one slot; inserting a colliding pair
/// evicts whatever was there.
struct HotCache {
    entries: [Option<CacheEntry>; HOT_CACHE_SLOTS],
}

/// Pass-through hasher used to obtain a `u64` from a `TypeId` through its public `Hash` impl
/// instead of reinterpreting the bytes of the opaque `TypeId` struct.
struct TypeIdHasher(u64);

impl std::hash::Hasher for TypeIdHasher {
    #[inline(always)]
    fn write(&mut self, bytes: &[u8]) {
        // Fallback (FNV-1a style) for the case where the value is fed in byte-wise.
        for &byte in bytes {
            self.0 = (self.0 ^ u64::from(byte)).wrapping_mul(0x0000_0100_0000_01b3);
        }
    }

    #[inline(always)]
    fn write_u64(&mut self, value: u64) {
        // A single `write_u64` on a fresh hasher yields `value` unchanged.
        self.0 = self.0.rotate_left(32) ^ value;
    }

    #[inline(always)]
    fn finish(&self) -> u64 {
        self.0
    }
}

impl HotCache {
    /// Creates a cache with every slot empty.
    const fn new() -> Self {
        Self {
            entries: [const { None }; HOT_CACHE_SLOTS],
        }
    }

    /// Returns the cached service of type `T` for `storage_ptr`, if its slot currently holds it.
    #[inline(always)]
    fn get<T: Send + Sync + 'static>(&self, storage_ptr: usize) -> Option<Arc<T>> {
        let slot = Self::slot_for_hash(Self::type_hash::<T>(), storage_ptr);
        let entry = self.entries[slot].as_ref()?;
        // Compare the full `TypeId`, not a truncated hash: this check is the only thing
        // standing between the erased value and the unchecked downcast below.
        if entry.type_id != TypeId::of::<T>() || entry.storage_ptr != storage_ptr {
            return None;
        }
        let service = Arc::clone(&entry.service);
        // SAFETY: entries are only created by `insert::<U>`, which stores `TypeId::of::<U>()`
        // alongside an `Arc<U>`. The stored id equals `TypeId::of::<T>()`, so `U == T`.
        Some(unsafe { downcast_arc_unchecked(service) })
    }

    /// Caches `service` for `(T, storage_ptr)` and returns the entry it evicted, if any.
    ///
    /// The evicted entry is returned rather than dropped here so that its destructor runs only
    /// after the caller has given up its `&mut HotCache`. Dropping it in place could run a
    /// service's `Drop` impl, and if that impl resolves from a container it would re-enter the
    /// thread-local cache while it is still mutably borrowed.
    #[inline]
    fn insert<T: Injectable>(&mut self, storage_ptr: usize, service: Arc<T>) -> Option<CacheEntry> {
        let slot = Self::slot_for_hash(Self::type_hash::<T>(), storage_ptr);
        let erased: Arc<dyn Any + Send + Sync> = service;
        self.entries[slot].replace(CacheEntry {
            type_id: TypeId::of::<T>(),
            storage_ptr,
            service: erased,
        })
    }

    /// Empties every slot and returns the previous contents.
    ///
    /// As with [`HotCache::insert`], the old entries are handed back so they are dropped
    /// outside the mutable borrow of the cache.
    #[inline]
    fn clear(&mut self) -> [Option<CacheEntry>; HOT_CACHE_SLOTS] {
        std::mem::replace(&mut self.entries, [const { None }; HOT_CACHE_SLOTS])
    }

    /// Returns a 64-bit hash of `T`'s `TypeId`, used only for slot selection.
    #[inline(always)]
    fn type_hash<T: 'static>() -> u64 {
        let mut hasher = TypeIdHasher(0);
        std::hash::Hash::hash(&TypeId::of::<T>(), &mut hasher);
        std::hash::Hasher::finish(&hasher)
    }

    /// Maps a `(type hash, storage address)` pair to a slot index in `0..HOT_CACHE_SLOTS`.
    #[inline(always)]
    fn slot_for_hash(type_hash: u64, storage_ptr: usize) -> usize {
        // Mix the address into the hash, then scramble with a 64-bit multiplicative constant
        // before masking down to the slot count.
        let mixed = type_hash ^ (storage_ptr as u64).rotate_left(32);
        let slot = mixed.wrapping_mul(0x9e3779b97f4a7c15);
        (slot as usize) & (HOT_CACHE_SLOTS - 1)
    }
}
// Per-thread hot cache of recently resolved services. It lives in an `UnsafeCell`, so the
// `with_hot_cache` / `with_hot_cache_mut` accessors hand out references without any runtime
// borrow tracking. Each thread gets its own instance, so entries cached on one thread are
// invisible to (and cannot be invalidated from) another thread.
thread_local! {
static HOT_CACHE: UnsafeCell<HotCache> = const { UnsafeCell::new(HotCache::new()) };
}
/// Runs `f` with shared access to the calling thread's hot cache and returns its result.
#[inline(always)]
fn with_hot_cache<F, R>(f: F) -> R
where
    F: FnOnce(&HotCache) -> R,
{
    HOT_CACHE.with(|slot| {
        let raw = slot.get();
        // SAFETY: the cell is thread-local, so no other thread can reach it.
        // NOTE(review): soundness also requires that no `&mut HotCache` is live on this thread
        // while `f` runs — confirm callers never nest cache accesses.
        f(unsafe { &*raw })
    })
}
/// Runs `f` with exclusive access to the calling thread's hot cache and returns its result.
#[inline(always)]
fn with_hot_cache_mut<F, R>(f: F) -> R
where
    F: FnOnce(&mut HotCache) -> R,
{
    HOT_CACHE.with(|slot| {
        let raw = slot.get();
        // SAFETY: the cell is thread-local, so no other thread can reach it.
        // NOTE(review): soundness also requires that `f` does not re-enter either cache
        // accessor on this thread (directly or through a destructor it triggers) — confirm.
        f(unsafe { &mut *raw })
    })
}
#[derive(Clone)]
pub struct Container {
storage: Arc<ServiceStorage>,
parent_storage: Option<Arc<ServiceStorage>>,
locked: Arc<AtomicBool>,
depth: u32,
}
impl Container {
#[inline]
pub fn new() -> Self {
#[cfg(feature = "logging")]
debug!(
target: "dependency_injector",
depth = 0,
"Creating new root DI container"
);
Self {
storage: Arc::new(ServiceStorage::new()),
parent_storage: None,
locked: Arc::new(AtomicBool::new(false)),
depth: 0,
}
}
#[inline]
pub fn with_capacity(capacity: usize) -> Self {
Self {
storage: Arc::new(ServiceStorage::with_capacity(capacity)),
parent_storage: None,
locked: Arc::new(AtomicBool::new(false)),
depth: 0,
}
}
#[inline]
pub fn scope(&self) -> Self {
let child_depth = self.depth + 1;
#[cfg(feature = "logging")]
debug!(
target: "dependency_injector",
parent_depth = self.depth,
child_depth = child_depth,
parent_services = self.storage.len(),
"Creating child scope from parent container"
);
Self {
storage: Arc::new(ServiceStorage::with_parent(Arc::clone(&self.storage))),
parent_storage: Some(Arc::clone(&self.storage)), locked: Arc::new(AtomicBool::new(false)),
depth: child_depth,
}
}
#[inline]
pub fn create_scope(&self) -> Self {
self.scope()
}
#[inline]
pub fn singleton<T: Injectable>(&self, instance: T) {
self.check_not_locked();
let type_id = TypeId::of::<T>();
let type_name = std::any::type_name::<T>();
#[cfg(feature = "logging")]
debug!(
target: "dependency_injector",
service = type_name,
lifetime = "singleton",
depth = self.depth,
service_count = self.storage.len() + 1,
"Registering singleton service"
);
self.storage
.insert(type_id, AnyFactory::singleton(instance));
}
#[inline]
pub fn lazy<T: Injectable, F>(&self, factory: F)
where
F: Fn() -> T + Send + Sync + 'static,
{
self.check_not_locked();
let type_id = TypeId::of::<T>();
let type_name = std::any::type_name::<T>();
#[cfg(feature = "logging")]
debug!(
target: "dependency_injector",
service = type_name,
lifetime = "lazy_singleton",
depth = self.depth,
service_count = self.storage.len() + 1,
"Registering lazy singleton service (will be created on first access)"
);
self.storage.insert(type_id, AnyFactory::lazy(factory));
}
#[inline]
pub fn transient<T: Injectable, F>(&self, factory: F)
where
F: Fn() -> T + Send + Sync + 'static,
{
self.check_not_locked();
let type_id = TypeId::of::<T>();
let type_name = std::any::type_name::<T>();
#[cfg(feature = "logging")]
debug!(
target: "dependency_injector",
service = type_name,
lifetime = "transient",
depth = self.depth,
service_count = self.storage.len() + 1,
"Registering transient service (new instance on every resolve)"
);
self.storage.insert(type_id, AnyFactory::transient(factory));
}
#[inline]
pub fn register_factory<T: Injectable, F>(&self, factory: F)
where
F: Fn() -> T + Send + Sync + 'static,
{
self.lazy(factory);
}
#[inline]
pub fn register<T: Injectable>(&self, instance: T) {
self.singleton(instance);
}
#[inline]
#[allow(clippy::boxed_local)]
pub fn register_boxed<T: Injectable>(&self, instance: Box<T>) {
self.singleton(*instance);
}
#[inline]
pub fn register_by_id(&self, type_id: TypeId, instance: Arc<dyn Any + Send + Sync>) {
self.check_not_locked();
self.storage.insert(
type_id,
AnyFactory::Singleton(crate::factory::SingletonFactory { instance }),
);
}
#[inline]
pub fn get<T: Injectable>(&self) -> Result<Arc<T>> {
let storage_ptr = Arc::as_ptr(&self.storage) as usize;
if let Some(cached) = with_hot_cache(|cache| cache.get::<T>(storage_ptr)) {
#[cfg(feature = "logging")]
trace!(
target: "dependency_injector",
service = std::any::type_name::<T>(),
depth = self.depth,
location = "hot_cache",
"Service resolved from thread-local cache"
);
return Ok(cached);
}
self.get_and_cache::<T>(storage_ptr)
}
#[inline]
fn get_and_cache<T: Injectable>(&self, storage_ptr: usize) -> Result<Arc<T>> {
let type_id = TypeId::of::<T>();
#[cfg(feature = "logging")]
let type_name = std::any::type_name::<T>();
#[cfg(feature = "logging")]
trace!(
target: "dependency_injector",
service = type_name,
depth = self.depth,
"Resolving service (cache miss)"
);
if let Some((service, is_transient)) = self.storage.get_with_transient_flag::<T>() {
#[cfg(feature = "logging")]
trace!(
target: "dependency_injector",
service = type_name,
depth = self.depth,
location = "local",
"Service resolved from current scope"
);
if !is_transient {
with_hot_cache_mut(|cache| cache.insert(storage_ptr, Arc::clone(&service)));
}
return Ok(service);
}
if self.depth == 0 {
#[cfg(feature = "logging")]
debug!(
target: "dependency_injector",
service = std::any::type_name::<T>(),
"Service not found in root container"
);
return Err(DiError::not_found::<T>());
}
self.resolve_from_parents::<T>(&type_id, storage_ptr)
}
#[cold]
fn resolve_from_parents<T: Injectable>(
&self,
type_id: &TypeId,
storage_ptr: usize,
) -> Result<Arc<T>> {
let type_name = std::any::type_name::<T>();
#[cfg(feature = "logging")]
trace!(
target: "dependency_injector",
service = type_name,
depth = self.depth,
"Service not in local scope, walking parent chain"
);
let mut current = self.storage.parent();
let mut ancestor_depth = self.depth.saturating_sub(1);
while let Some(storage) = current {
if let Some(arc) = storage.resolve(type_id) {
let typed: Arc<T> = unsafe { downcast_arc_unchecked(arc) };
#[cfg(feature = "logging")]
trace!(
target: "dependency_injector",
service = type_name,
depth = self.depth,
ancestor_depth = ancestor_depth,
location = "ancestor",
"Service resolved from ancestor scope"
);
if !storage.is_transient(type_id) {
with_hot_cache_mut(|cache| cache.insert(storage_ptr, Arc::clone(&typed)));
}
return Ok(typed);
}
current = storage.parent();
ancestor_depth = ancestor_depth.saturating_sub(1);
}
#[cfg(feature = "logging")]
debug!(
target: "dependency_injector",
service = type_name,
depth = self.depth,
"Service not found in container or parent chain"
);
Err(DiError::not_found::<T>())
}
#[inline]
pub fn clear_cache(&self) {
with_hot_cache_mut(|cache| cache.clear());
}
#[inline]
pub fn warm_cache<T: Injectable>(&self) {
let _ = self.get::<T>();
}
#[inline]
pub fn resolve<T: Injectable>(&self) -> Result<Arc<T>> {
self.get::<T>()
}
#[inline]
pub fn try_get<T: Injectable>(&self) -> Option<Arc<T>> {
self.get::<T>().ok()
}
#[inline]
pub fn try_resolve<T: Injectable>(&self) -> Option<Arc<T>> {
self.try_get::<T>()
}
#[inline]
pub fn contains<T: Injectable>(&self) -> bool {
let type_id = TypeId::of::<T>();
self.contains_type_id(&type_id)
}
#[inline]
pub fn has<T: Injectable>(&self) -> bool {
self.contains::<T>()
}
fn contains_type_id(&self, type_id: &TypeId) -> bool {
self.storage.contains_in_chain(type_id)
}
#[inline]
pub fn len(&self) -> usize {
self.storage.len()
}
#[inline]
pub fn is_empty(&self) -> bool {
self.storage.is_empty()
}
pub fn registered_types(&self) -> Vec<TypeId> {
self.storage.type_ids()
}
#[inline]
pub fn depth(&self) -> u32 {
self.depth
}
#[inline]
pub fn lock(&self) {
self.locked.store(true, Ordering::Release);
#[cfg(feature = "logging")]
debug!(
target: "dependency_injector",
depth = self.depth,
service_count = self.storage.len(),
"Container locked - no further registrations allowed"
);
}
#[inline]
pub fn is_locked(&self) -> bool {
self.locked.load(Ordering::Acquire)
}
#[cfg(feature = "perfect-hash")]
#[inline]
pub fn freeze(&self) -> crate::storage::FrozenStorage {
self.lock();
crate::storage::FrozenStorage::from_storage(&self.storage)
}
#[inline]
pub fn clear(&self) {
let count = self.storage.len();
self.storage.clear();
#[cfg(feature = "logging")]
debug!(
target: "dependency_injector",
depth = self.depth,
services_removed = count,
"Container cleared - all services removed from this scope"
);
}
#[inline]
fn check_not_locked(&self) {
if self.locked.load(Ordering::Relaxed) {
panic!("Cannot register services: container is locked");
}
}
#[inline]
pub fn batch<F>(&self, f: F)
where
F: FnOnce(BatchRegistrar<'_>),
{
self.check_not_locked();
#[cfg(feature = "logging")]
let start_count = self.storage.len();
f(BatchRegistrar {
storage: &self.storage,
});
#[cfg(feature = "logging")]
{
let end_count = self.storage.len();
debug!(
target: "dependency_injector",
depth = self.depth,
services_registered = end_count - start_count,
"Batch registration completed"
);
}
}
#[inline]
pub fn register_batch(&self) -> BatchBuilder<'_> {
self.check_not_locked();
BatchBuilder {
storage: &self.storage,
#[cfg(feature = "logging")]
count: 0,
}
}
}
/// Chainable registrar that inserts directly into a storage; each call consumes the builder
/// and returns the next one.
pub struct BatchBuilder<'a> {
    /// Storage that receives the registrations.
    storage: &'a ServiceStorage,
    /// Number of registrations made so far (tracked for log output only).
    #[cfg(feature = "logging")]
    count: usize,
}

impl<'a> BatchBuilder<'a> {
    /// Registers `instance` as a singleton for `T`.
    #[inline]
    pub fn singleton<T: Injectable>(self, instance: T) -> Self {
        let key = TypeId::of::<T>();
        self.storage.insert(key, AnyFactory::singleton(instance));
        self.bumped()
    }

    /// Registers `factory` as a lazy singleton for `T`.
    #[inline]
    pub fn lazy<T: Injectable, F>(self, factory: F) -> Self
    where
        F: Fn() -> T + Send + Sync + 'static,
    {
        let key = TypeId::of::<T>();
        self.storage.insert(key, AnyFactory::lazy(factory));
        self.bumped()
    }

    /// Registers `factory` as a transient service for `T`.
    #[inline]
    pub fn transient<T: Injectable, F>(self, factory: F) -> Self
    where
        F: Fn() -> T + Send + Sync + 'static,
    {
        let key = TypeId::of::<T>();
        self.storage.insert(key, AnyFactory::transient(factory));
        self.bumped()
    }

    /// Finishes the chain; with the `logging` feature it reports how many services were added.
    #[inline]
    pub fn done(self) {
        #[cfg(feature = "logging")]
        debug!(
            target: "dependency_injector",
            services_registered = self.count,
            "Batch registration completed"
        );
    }

    /// Returns the builder for the next link of the chain, with the counter advanced by one.
    #[inline]
    fn bumped(self) -> Self {
        BatchBuilder {
            storage: self.storage,
            #[cfg(feature = "logging")]
            count: self.count + 1,
        }
    }
}
/// Thin handle over a storage, used to register several services in one go.
#[repr(transparent)]
pub struct BatchRegistrar<'a> {
    /// Storage that receives the registrations.
    storage: &'a ServiceStorage,
}

impl<'a> BatchRegistrar<'a> {
    /// Registers `instance` as a singleton for `T`.
    #[inline]
    pub fn singleton<T: Injectable>(&self, instance: T) {
        let key = TypeId::of::<T>();
        let factory = AnyFactory::singleton(instance);
        self.storage.insert(key, factory);
    }

    /// Registers `factory` as a lazy singleton for `T`.
    #[inline]
    pub fn lazy<T: Injectable, F>(&self, factory: F)
    where
        F: Fn() -> T + Send + Sync + 'static,
    {
        let key = TypeId::of::<T>();
        let lazy_factory = AnyFactory::lazy(factory);
        self.storage.insert(key, lazy_factory);
    }

    /// Registers `factory` as a transient service for `T`.
    #[inline]
    pub fn transient<T: Injectable, F>(&self, factory: F)
    where
        F: Fn() -> T + Send + Sync + 'static,
    {
        let key = TypeId::of::<T>();
        let transient_factory = AnyFactory::transient(factory);
        self.storage.insert(key, transient_factory);
    }
}
use std::sync::Mutex;
/// Pool of reusable child scopes of one parent container.
///
/// [`ScopePool::acquire`] hands out a scope backed by a recycled storage when one is
/// available and allocates a new one otherwise. Dropping the returned [`PooledScope`] clears
/// the storage and puts it back, so the pool can grow beyond its initial capacity.
pub struct ScopePool {
    /// Storage of the parent container; every pooled scope is chained to it.
    parent_storage: Arc<ServiceStorage>,
    /// Scopes that are currently idle.
    available: Mutex<Vec<ScopeSlot>>,
    /// Depth of the parent container; pooled scopes report `parent_depth + 1`.
    parent_depth: u32,
}

/// The reusable parts of an idle pooled scope.
struct ScopeSlot {
    storage: Arc<ServiceStorage>,
    locked: Arc<AtomicBool>,
}

impl ScopePool {
    /// Creates a pool for `parent` with `capacity` scopes allocated up front.
    pub fn new(parent: &Container, capacity: usize) -> Self {
        let available: Vec<ScopeSlot> = (0..capacity)
            .map(|_| Self::fresh_slot(&parent.storage))
            .collect();
        #[cfg(feature = "logging")]
        debug!(
            target: "dependency_injector",
            capacity = capacity,
            parent_depth = parent.depth,
            "Created scope pool with pre-allocated scopes"
        );
        Self {
            parent_storage: Arc::clone(&parent.storage),
            available: Mutex::new(available),
            parent_depth: parent.depth,
        }
    }

    /// Builds an unlocked slot whose storage is chained to `parent_storage`.
    #[inline]
    fn fresh_slot(parent_storage: &Arc<ServiceStorage>) -> ScopeSlot {
        ScopeSlot {
            storage: Arc::new(ServiceStorage::with_parent(Arc::clone(parent_storage))),
            locked: Arc::new(AtomicBool::new(false)),
        }
    }

    /// Takes a scope from the pool, creating a new one if none is idle.
    ///
    /// The scope goes back to the pool when the returned guard is dropped.
    ///
    /// # Panics
    ///
    /// Panics if the pool's mutex is poisoned.
    #[inline]
    pub fn acquire(&self) -> PooledScope<'_> {
        // Separate statement so the mutex guard is released before any allocation below.
        let recycled = self.available.lock().unwrap().pop();
        let ScopeSlot { storage, locked } = match recycled {
            Some(slot) => {
                #[cfg(feature = "logging")]
                trace!(
                    target: "dependency_injector",
                    "Acquired scope from pool (reusing storage)"
                );
                slot
            }
            None => {
                #[cfg(feature = "logging")]
                trace!(
                    target: "dependency_injector",
                    "Pool empty, creating new scope"
                );
                Self::fresh_slot(&self.parent_storage)
            }
        };
        let container = Container {
            storage,
            parent_storage: Some(Arc::clone(&self.parent_storage)),
            locked,
            depth: self.parent_depth + 1,
        };
        PooledScope {
            container: Some(container),
            pool: self,
        }
    }

    /// Resets `container` and returns its storage and lock flag to the pool.
    #[inline]
    fn release(&self, container: Container) {
        container.storage.clear();
        // The recycled storage keeps its address, and that address is what the thread-local
        // hot cache is keyed by. Without this, the next scope acquired on this thread would
        // resolve services that were cached for the previous one.
        // NOTE(review): only the releasing thread's cache is reachable from here; entries
        // cached on other threads for this storage are not invalidated.
        container.clear_cache();
        container.locked.store(false, Ordering::Release);
        self.available.lock().unwrap().push(ScopeSlot {
            storage: container.storage,
            locked: container.locked,
        });
        #[cfg(feature = "logging")]
        trace!(
            target: "dependency_injector",
            "Released scope back to pool"
        );
    }

    /// Number of idle scopes currently held by the pool.
    #[inline]
    pub fn available_count(&self) -> usize {
        self.available.lock().unwrap().len()
    }
}
/// Guard around a scope borrowed from a [`ScopePool`]; dropping it hands the scope back.
pub struct PooledScope<'a> {
    /// The borrowed scope. It is `Some` for the guard's whole life and taken out in `drop`.
    container: Option<Container>,
    /// Pool the scope is returned to.
    pool: &'a ScopePool,
}

impl PooledScope<'_> {
    /// Returns the underlying container.
    #[inline]
    pub fn container(&self) -> &Container {
        let held = self.container.as_ref();
        held.unwrap()
    }
}

impl std::ops::Deref for PooledScope<'_> {
    type Target = Container;

    /// Lets the guard be used wherever a `&Container` is expected.
    #[inline]
    fn deref(&self) -> &Self::Target {
        let held = self.container.as_ref();
        held.unwrap()
    }
}

impl Drop for PooledScope<'_> {
    /// Returns the scope to its pool.
    fn drop(&mut self) {
        let Some(container) = self.container.take() else {
            return;
        };
        self.pool.release(container);
    }
}
impl Default for Container {
fn default() -> Self {
Self::new()
}
}
impl std::fmt::Debug for Container {
    /// Prints a summary (service count, depth, whether a parent exists, lock state) instead of
    /// the registered services themselves.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let service_count = self.len();
        let has_parent = self.parent_storage.is_some();
        let locked = self.is_locked();

        let mut summary = f.debug_struct("Container");
        summary.field("service_count", &service_count);
        summary.field("depth", &self.depth);
        summary.field("has_parent", &has_parent);
        summary.field("locked", &locked);
        summary.finish_non_exhaustive()
    }
}
// SAFETY: every field of `Container` is an `Arc<ServiceStorage>`, an
// `Option<Arc<ServiceStorage>>`, an `Arc<AtomicBool>` or a `u32`, so sharing or sending a
// `Container` is sound exactly when `ServiceStorage` is `Send + Sync`.
// NOTE(review): confirm `ServiceStorage: Send + Sync`. If it is, the compiler derives both
// traits automatically and these manual impls are redundant; if it is not, they are unsound.
unsafe impl Send for Container {}
unsafe impl Sync for Container {}
// Unit tests for `Container` and `ScopePool`.
#[cfg(test)]
mod tests {
use super::*;
// Simple service carrying a string payload, used by most tests.
#[derive(Clone)]
struct TestService {
value: String,
}
// Second, unrelated service type for tests that need two distinct registrations.
#[allow(dead_code)]
#[derive(Clone)]
struct AnotherService {
name: String,
}
// A singleton resolves to the same `Arc` on every `get`.
#[test]
fn test_singleton() {
let container = Container::new();
container.singleton(TestService {
value: "test".into(),
});
let s1 = container.get::<TestService>().unwrap();
let s2 = container.get::<TestService>().unwrap();
assert_eq!(s1.value, "test");
assert!(Arc::ptr_eq(&s1, &s2));
}
// A lazy factory must not run at registration time, only on first resolve.
#[test]
fn test_lazy() {
use std::sync::atomic::{AtomicBool, Ordering};
static CREATED: AtomicBool = AtomicBool::new(false);
let container = Container::new();
container.lazy(|| {
CREATED.store(true, Ordering::SeqCst);
TestService {
value: "lazy".into(),
}
});
assert!(!CREATED.load(Ordering::SeqCst));
let s = container.get::<TestService>().unwrap();
assert!(CREATED.load(Ordering::SeqCst));
assert_eq!(s.value, "lazy");
}
// A transient factory runs on every resolve, so two resolves see different counter values.
#[test]
fn test_transient() {
use std::sync::atomic::{AtomicU32, Ordering};
static COUNTER: AtomicU32 = AtomicU32::new(0);
#[derive(Clone)]
struct Counter(u32);
let container = Container::new();
container.transient(|| Counter(COUNTER.fetch_add(1, Ordering::SeqCst)));
let c1 = container.get::<Counter>().unwrap();
let c2 = container.get::<Counter>().unwrap();
assert_ne!(c1.0, c2.0);
}
// A child scope sees its parent's services; the parent does not see the child's.
#[test]
fn test_scope_inheritance() {
let root = Container::new();
root.singleton(TestService {
value: "root".into(),
});
let child = root.scope();
child.singleton(AnotherService {
name: "child".into(),
});
assert!(child.contains::<TestService>());
assert!(child.contains::<AnotherService>());
assert!(root.contains::<TestService>());
assert!(!root.contains::<AnotherService>());
}
// A child registration of the same type shadows the parent's without changing the parent.
#[test]
fn test_scope_override() {
let root = Container::new();
root.singleton(TestService {
value: "root".into(),
});
let child = root.scope();
child.singleton(TestService {
value: "child".into(),
});
let root_service = root.get::<TestService>().unwrap();
let child_service = child.get::<TestService>().unwrap();
assert_eq!(root_service.value, "root");
assert_eq!(child_service.value, "child");
}
// Resolving an unregistered type yields an error rather than panicking.
#[test]
fn test_not_found() {
let container = Container::new();
let result = container.get::<TestService>();
assert!(result.is_err());
}
// `lock` flips the flag reported by `is_locked`.
#[test]
fn test_lock() {
let container = Container::new();
assert!(!container.is_locked());
container.lock();
assert!(container.is_locked());
}
// Registering on a locked container panics with the documented message.
#[test]
#[should_panic(expected = "Cannot register services: container is locked")]
fn test_register_after_lock() {
let container = Container::new();
container.lock();
container.singleton(TestService {
value: "fail".into(),
});
}
// Services registered through `batch` are visible and resolvable afterwards.
#[test]
fn test_batch_registration() {
#[derive(Clone)]
struct ServiceA(i32);
#[allow(dead_code)]
#[derive(Clone)]
struct ServiceB(String);
let container = Container::new();
container.batch(|batch| {
batch.singleton(ServiceA(42));
batch.singleton(ServiceB("test".into()));
batch.lazy(|| TestService {
value: "lazy".into(),
});
});
assert!(container.contains::<ServiceA>());
assert!(container.contains::<ServiceB>());
assert!(container.contains::<TestService>());
let a = container.get::<ServiceA>().unwrap();
assert_eq!(a.0, 42);
}
// Acquiring takes a scope out of the pool; dropping the guard puts it back.
#[test]
fn test_scope_pool_basic() {
#[derive(Clone)]
struct RequestId(u64);
let root = Container::new();
root.singleton(TestService {
value: "root".into(),
});
let pool = ScopePool::new(&root, 2);
assert_eq!(pool.available_count(), 2);
{
let scope = pool.acquire();
assert_eq!(pool.available_count(), 1);
assert!(scope.contains::<TestService>());
scope.singleton(RequestId(123));
assert!(scope.contains::<RequestId>());
let id = scope.get::<RequestId>().unwrap();
assert_eq!(id.0, 123);
}
assert_eq!(pool.available_count(), 2);
}
// A recycled scope starts empty: registrations from its previous use are gone.
#[test]
fn test_scope_pool_reuse() {
#[derive(Clone)]
struct RequestId(u64);
let root = Container::new();
let pool = ScopePool::new(&root, 1);
{
let scope = pool.acquire();
scope.singleton(RequestId(1));
assert!(scope.contains::<RequestId>());
}
{
let scope = pool.acquire();
assert!(!scope.contains::<RequestId>());
scope.singleton(RequestId(2));
let id = scope.get::<RequestId>().unwrap();
assert_eq!(id.0, 2);
}
}
// Acquiring more scopes than the initial capacity grows the pool once they are released.
#[test]
fn test_scope_pool_expansion() {
let root = Container::new();
let pool = ScopePool::new(&root, 1);
let _s1 = pool.acquire();
let _s2 = pool.acquire();
assert_eq!(pool.available_count(), 0);
drop(_s1);
drop(_s2);
assert_eq!(pool.available_count(), 2);
}
// Lookups walk the whole ancestor chain (root -> middle1 -> middle2 -> leaf), and a scope
// never sees services registered in its descendants.
#[test]
fn test_deep_parent_chain() {
#[derive(Clone)]
struct RootService(i32);
#[derive(Clone)]
struct MiddleService(i32);
#[derive(Clone)]
struct LeafService(i32);
let root = Container::new();
root.singleton(RootService(1));
let middle1 = root.scope();
middle1.singleton(MiddleService(2));
let middle2 = middle1.scope();
let leaf = middle2.scope();
leaf.singleton(LeafService(4));
assert!(
leaf.contains::<RootService>(),
"Should find root service in leaf"
);
assert!(
leaf.contains::<MiddleService>(),
"Should find middle service in leaf"
);
assert!(
leaf.contains::<LeafService>(),
"Should find leaf service in leaf"
);
let root_svc = leaf.get::<RootService>().unwrap();
assert_eq!(root_svc.0, 1);
let middle_svc = leaf.get::<MiddleService>().unwrap();
assert_eq!(middle_svc.0, 2);
let leaf_svc = leaf.get::<LeafService>().unwrap();
assert_eq!(leaf_svc.0, 4);
assert!(middle2.contains::<RootService>());
assert!(middle2.contains::<MiddleService>());
assert!(!middle2.contains::<LeafService>()); }
}