Skip to main content

rok_container/
container.rs

1use std::any::{Any, TypeId};
2use std::collections::HashMap;
3use std::sync::{Arc, RwLock};
4
5use crate::error::ContainerError;
6
7type AnyArc = Arc<dyn Any + Send + Sync>;
8
9enum Binding {
10    Singleton(AnyArc),
11    Factory(Box<dyn Fn() -> AnyArc + Send + Sync>),
12}
13
14/// Type-map IoC service container.
15///
16/// Intended to be shared as `Arc<Container>` and mounted as an Axum `Extension`.
17///
18/// # Example
19///
20/// ```rust,ignore
21/// use std::sync::Arc;
22/// use rok_container::Container;
23///
24/// let container = Arc::new(Container::new());
25///
26/// // singleton — same Arc<T> returned every call
27/// container.singleton(MyService::new());
28///
29/// // factory — fresh instance on every make()
30/// container.bind::<Config>(|| Config::from_env());
31///
32/// // resolve
33/// let svc: Arc<MyService> = container.make::<MyService>().unwrap();
34/// ```
35pub struct Container {
36    bindings: RwLock<HashMap<TypeId, Binding>>,
37}
38
39impl Default for Container {
40    fn default() -> Self {
41        Self::new()
42    }
43}
44
45impl Container {
46    /// Create an empty container.
47    pub fn new() -> Self {
48        Self {
49            bindings: RwLock::new(HashMap::new()),
50        }
51    }
52
53    /// Register a factory for `T`.
54    ///
55    /// Every call to [`make::<T>`](Container::make) invokes `factory` and
56    /// returns a new `Arc<T>`.
57    pub fn bind<T, F>(&self, factory: F)
58    where
59        T: Any + Send + Sync + 'static,
60        F: Fn() -> T + Send + Sync + 'static,
61    {
62        let mut map = self.bindings.write().expect("container lock poisoned");
63        map.insert(
64            TypeId::of::<T>(),
65            Binding::Factory(Box::new(move || Arc::new(factory()))),
66        );
67    }
68
69    /// Register a singleton instance for `T`.
70    ///
71    /// Every call to [`make::<T>`](Container::make) clones the same `Arc<T>`.
72    pub fn singleton<T>(&self, instance: T)
73    where
74        T: Any + Send + Sync + 'static,
75    {
76        let mut map = self.bindings.write().expect("container lock poisoned");
77        map.insert(TypeId::of::<T>(), Binding::Singleton(Arc::new(instance)));
78    }
79
80    /// Resolve `T`, returning `Arc<T>`.
81    ///
82    /// # Errors
83    ///
84    /// Returns [`ContainerError::NotRegistered`] if `T` was never bound.
85    pub fn make<T>(&self) -> Result<Arc<T>, ContainerError>
86    where
87        T: Any + Send + Sync + 'static,
88    {
89        let map = self.bindings.read().expect("container lock poisoned");
90        match map.get(&TypeId::of::<T>()) {
91            Some(Binding::Singleton(arc)) => arc
92                .clone()
93                .downcast::<T>()
94                .map_err(|_| ContainerError::TypeMismatch(std::any::type_name::<T>())),
95            Some(Binding::Factory(f)) => f()
96                .downcast::<T>()
97                .map_err(|_| ContainerError::TypeMismatch(std::any::type_name::<T>())),
98            None => Err(ContainerError::NotRegistered(std::any::type_name::<T>())),
99        }
100    }
101
102    /// Decorate an existing singleton by transforming it.
103    ///
104    /// Useful for wrapping a service with a decorator or logging layer.
105    ///
106    /// # Errors
107    ///
108    /// Returns an error if `T` is not registered.
109    pub fn extend<T, F>(&self, extender: F) -> Result<(), ContainerError>
110    where
111        T: Any + Send + Sync + 'static,
112        F: FnOnce(Arc<T>) -> T,
113    {
114        let existing = self.make::<T>()?;
115        let new_instance = extender(existing);
116        self.singleton(new_instance);
117        Ok(())
118    }
119
120    /// Replace an existing binding with a new singleton.
121    ///
122    /// Useful in tests to swap a real service for a test double.
123    pub fn swap<T>(&self, instance: T)
124    where
125        T: Any + Send + Sync + 'static,
126    {
127        self.singleton(instance);
128    }
129}