// rust_api/di.rs

//! Dependency Injection Container (optional utility)
//!
//! **Primary DI in RustAPI:** pass `Arc<Service>` directly to
//! `RouterPipeline::mount::<Controller>(Arc::new(MyService::new()))`.
//! No container is needed for the standard use case.
//!
//! This module provides an optional [`Container`] — a type-safe service
//! registry useful for larger applications with dynamic or plugin-driven
//! service graphs, where you want late-binding resolution by type.
//!
//! # When to Use
//!
//! - You have a plugin system that registers services dynamically
//! - You want to swap implementations at runtime (e.g. test doubles via trait
//!   objects)
//! - You have many services and prefer centralized registration over explicit
//!   wiring
//!
//! # When NOT to Use
//!
//! For most APIs: construct `Arc<MyService>` directly in `main` and pass it to
//! `pipeline.mount()`. This is explicit, compile-time-checked, and requires no
//! type-map machinery.
//!
//! # Immutability Discipline
//!
//! Services registered in the container are stored as `Arc<T>` and shared
//! read-only across the application. Service types should be immutable after
//! construction — model state changes via `Atomic*` primitives or channels,
//! not `Mutex<T>` fields. The container supports this by only ever handing
//! out shared `Arc<T>` handles; it never wraps a service in a lock
//! (`Arc<Mutex<T>>`) or exposes `&mut T`.
//!
//! # Async Initialization
//!
//! Use [`Container::register_async_factory`] for services that require async
//! initialization (database connections, config loading, etc.). The async
//! effect is contained at container-construction time; all subsequent
//! framework operations remain synchronous.
//!
//! Alternatively, without a container:
//! ```ignore
//! let svc = Arc::new(MyService::connect("postgres://...").await?);
//! RouterPipeline::new().mount::<MyController>(svc).build()?
//! ```

use std::{
    any::{Any, TypeId},
    collections::HashMap,
    future::Future,
    sync::Arc,
};

use crate::error::Result;

/// Marker trait that all injectable services must implement.
///
/// The trait has no methods. Its supertrait bounds (`Send + Sync + 'static`)
/// are what allow a service to be stored type-erased behind `dyn Any` and
/// shared through an `Arc`.
///
/// # Example
///
/// ```ignore
/// struct MyService;
/// impl Injectable for MyService {}
/// ```
pub trait Injectable: Send + Sync + 'static {}

/// Type-erased service storage using `Any`.
///
/// Every service is kept behind this handle; the concrete type is recovered
/// by downcasting when the service is looked up again.
type ServiceBox = Arc<dyn Any + Send + Sync>;

/// Dependency injection container
///
/// Stores services as Arc-wrapped values and provides type-safe retrieval.
/// Services are singletons - only one instance exists per type.
///
/// Cloning a container is shallow: the clone holds the same `Arc` service
/// instances as the original. The `Debug` output shows the registered
/// `TypeId`s; the service values themselves are opaque.
///
/// # Example
///
/// ```ignore
/// let mut container = Container::new();
/// container.register(Arc::new(DatabaseService::new()));
///
/// let db: Arc<DatabaseService> = container.resolve().unwrap();
/// ```
#[derive(Clone, Debug, Default)]
pub struct Container {
    /// Registered services, keyed by the `TypeId` of their concrete type.
    services: HashMap<TypeId, ServiceBox>,
}

79impl Container {
80    /// Create a new empty container
81    pub fn new() -> Self {
82        Self {
83            services: HashMap::new(),
84        }
85    }
86
87    /// Register a service in the container
88    ///
89    /// The service must be wrapped in an Arc. If a service of this type
90    /// already exists, it will be replaced.
91    ///
92    /// # Example
93    ///
94    /// ```ignore
95    /// container.register(Arc::new(MyService::new()));
96    /// ```
97    pub fn register<T: Injectable>(&mut self, service: Arc<T>) {
98        let type_id = self.get_type_id::<T>();
99        self.insert_service(type_id, service);
100    }
101
102    // get the TypeId for a given type T
103    fn get_type_id<T: Injectable>(&self) -> TypeId {
104        TypeId::of::<T>()
105    }
106
107    // insert a service into the storage map
108    fn insert_service<T: Injectable>(&mut self, type_id: TypeId, service: Arc<T>) {
109        self.services.insert(type_id, service as ServiceBox);
110    }
111
112    /// Register a service from a constructor function
113    ///
114    /// This is a convenience method that creates the Arc for you.
115    ///
116    /// # Example
117    ///
118    /// ```ignore
119    /// container.register_factory(|| MyService::new());
120    /// ```
121    pub fn register_factory<T: Injectable, F>(&mut self, factory: F)
122    where
123        F: FnOnce() -> T,
124    {
125        let service = self.create_service(factory);
126        self.register(service);
127    }
128
129    /// Register a service from an **async** constructor function.
130    ///
131    /// Use this for services that require async initialization — database
132    /// connections, config loading from remote sources, etc. The async effect
133    /// is contained here; once registered, the service is a plain `Arc<T>`
134    /// and all subsequent framework operations remain synchronous.
135    ///
136    /// Returns `Err` if the factory future resolves to an error.
137    ///
138    /// # Example
139    ///
140    /// ```ignore
141    /// container
142    ///     .register_async_factory(|| async {
143    ///         let db = Database::connect("postgres://localhost/mydb").await?;
144    ///         Ok(DbService::new(db))
145    ///     })
146    ///     .await?;
147    /// ```
148    pub async fn register_async_factory<T, F, Fut>(&mut self, factory: F) -> Result<()>
149    where
150        T: Injectable,
151        F: FnOnce() -> Fut,
152        Fut: Future<Output = Result<T>>,
153    {
154        let service = factory().await?;
155        self.register(Arc::new(service));
156        Ok(())
157    }
158
159    // create a service instance from a factory function
160    fn create_service<T: Injectable, F>(&self, factory: F) -> Arc<T>
161    where
162        F: FnOnce() -> T,
163    {
164        Arc::new(factory())
165    }
166
167    /// Resolve a service from the container
168    ///
169    /// Returns None if the service hasn't been registered.
170    ///
171    /// # Example
172    ///
173    /// ```ignore
174    /// let service: Arc<MyService> = container.resolve().unwrap();
175    /// ```
176    pub fn resolve<T: Injectable>(&self) -> Option<Arc<T>> {
177        let type_id = self.get_type_id::<T>();
178        self.lookup_service(type_id)
179    }
180
181    // lookup a service by TypeId and downcast it
182    fn lookup_service<T: Injectable>(&self, type_id: TypeId) -> Option<Arc<T>> {
183        self.services
184            .get(&type_id)
185            .and_then(|boxed| self.downcast_service(boxed))
186    }
187
188    // downcast a type-erased service to the concrete type
189    fn downcast_service<T: Injectable>(&self, boxed: &ServiceBox) -> Option<Arc<T>> {
190        boxed.clone().downcast::<T>().ok()
191    }
192
193    /// Resolve a service or panic if not found
194    ///
195    /// # Panics
196    ///
197    /// Panics if the service hasn't been registered.
198    pub fn resolve_or_panic<T: Injectable>(&self) -> Arc<T> {
199        self.resolve()
200            .unwrap_or_else(|| panic!("Service {} not registered", std::any::type_name::<T>()))
201    }
202
203    /// Check if a service is registered
204    pub fn contains<T: Injectable>(&self) -> bool {
205        let type_id = TypeId::of::<T>();
206        self.services.contains_key(&type_id)
207    }
208
209    /// Get the number of registered services
210    pub fn len(&self) -> usize {
211        self.services.len()
212    }
213
214    /// Check if the container is empty
215    pub fn is_empty(&self) -> bool {
216        self.services.is_empty()
217    }
218
219    /// Clear all services from the container
220    pub fn clear(&mut self) {
221        self.services.clear();
222    }
223}
224
#[cfg(test)]
mod tests {
    use std::{
        future::Future,
        task::{Context, Poll, Wake, Waker},
    };

    use super::*;

    /// Minimal stand-in for a database service.
    struct MockDatabase {
        connection_string: String,
    }

    impl Injectable for MockDatabase {}

    impl MockDatabase {
        fn new(conn: &str) -> Self {
            Self {
                connection_string: conn.to_string(),
            }
        }
    }

    /// Service that holds a shared handle to a [`MockDatabase`], used to
    /// exercise a two-level dependency chain.
    struct MockUserService {
        db: Arc<MockDatabase>,
    }

    impl Injectable for MockUserService {}

    impl MockUserService {
        fn new(db: Arc<MockDatabase>) -> Self {
            Self { db }
        }
    }

    /// Waker that does nothing; enough for futures that never suspend.
    struct NoopWaker;

    impl Wake for NoopWaker {
        fn wake(self: Arc<Self>) {}
    }

    /// Polls `future` exactly once and returns its output.
    ///
    /// Panics if the future is not ready on the first poll, which keeps the
    /// tests free of any async runtime dependency.
    fn poll_ready<F: Future>(future: F) -> F::Output {
        let waker = Waker::from(Arc::new(NoopWaker));
        let mut cx = Context::from_waker(&waker);
        let mut future = Box::pin(future);
        match future.as_mut().poll(&mut cx) {
            Poll::Ready(output) => output,
            Poll::Pending => panic!("future did not complete on first poll"),
        }
    }

    #[test]
    fn test_register_and_resolve() {
        let mut container = Container::new();
        let db = Arc::new(MockDatabase::new("postgres://localhost"));

        container.register(Arc::clone(&db));

        let resolved: Arc<MockDatabase> = container.resolve().unwrap();
        assert_eq!(resolved.connection_string, "postgres://localhost");
        // The container hands back the registered instance, not a copy.
        assert!(Arc::ptr_eq(&resolved, &db));
    }

    #[test]
    fn test_register_factory() {
        let mut container = Container::new();

        container.register_factory(|| MockDatabase::new("sqlite::memory"));

        let resolved: Arc<MockDatabase> = container.resolve().unwrap();
        assert_eq!(resolved.connection_string, "sqlite::memory");
    }

    #[test]
    fn test_register_async_factory() {
        let mut container = Container::new();

        let outcome = poll_ready(
            container.register_async_factory(|| async { Ok(MockDatabase::new("async://db")) }),
        );
        assert!(outcome.is_ok());

        let resolved: Arc<MockDatabase> = container.resolve().unwrap();
        assert_eq!(resolved.connection_string, "async://db");
    }

    #[test]
    fn test_register_replaces_existing() {
        let mut container = Container::new();

        container.register_factory(|| MockDatabase::new("first"));
        container.register_factory(|| MockDatabase::new("second"));

        assert_eq!(container.len(), 1);
        let resolved: Arc<MockDatabase> = container.resolve().unwrap();
        assert_eq!(resolved.connection_string, "second");
    }

    #[test]
    fn test_resolve_missing_service() {
        let container = Container::new();
        let result: Option<Arc<MockDatabase>> = container.resolve();
        assert!(result.is_none());
    }

    #[test]
    #[should_panic(expected = "Service")]
    fn test_resolve_or_panic() {
        let container = Container::new();
        let _: Arc<MockDatabase> = container.resolve_or_panic();
    }

    #[test]
    fn test_dependency_chain() {
        let mut container = Container::new();

        // Register database first
        let db = Arc::new(MockDatabase::new("postgres://localhost"));
        container.register(Arc::clone(&db));

        // Then register service that depends on it
        let user_service = Arc::new(MockUserService::new(db));
        container.register(user_service);

        // Resolve both
        let resolved_db: Arc<MockDatabase> = container.resolve().unwrap();
        let resolved_service: Arc<MockUserService> = container.resolve().unwrap();

        assert_eq!(resolved_db.connection_string, "postgres://localhost");
        assert_eq!(
            resolved_service.db.connection_string,
            "postgres://localhost"
        );
        // Both paths lead to the same database instance.
        assert!(Arc::ptr_eq(&resolved_db, &resolved_service.db));
    }

    #[test]
    fn test_clone_shares_services() {
        let mut container = Container::new();
        container.register_factory(|| MockDatabase::new("shared"));

        let copy = container.clone();

        let from_original: Arc<MockDatabase> = container.resolve().unwrap();
        let from_copy: Arc<MockDatabase> = copy.resolve().unwrap();
        assert!(Arc::ptr_eq(&from_original, &from_copy));
    }

    #[test]
    fn test_contains() {
        let mut container = Container::new();
        assert!(!container.contains::<MockDatabase>());

        container.register_factory(|| MockDatabase::new("test"));
        assert!(container.contains::<MockDatabase>());
    }

    #[test]
    fn test_len_and_clear() {
        let mut container = Container::new();
        assert_eq!(container.len(), 0);
        assert!(container.is_empty());

        container.register_factory(|| MockDatabase::new("test"));
        assert_eq!(container.len(), 1);
        assert!(!container.is_empty());

        container.clear();
        assert_eq!(container.len(), 0);
        assert!(container.is_empty());
    }
}