Skip to main content

nestforge_core/
container.rs

use std::{
    any::{Any, TypeId},
    collections::{hash_map::Entry, HashMap, HashSet},
    sync::{Arc, RwLock},
};

use thiserror::Error;
8
9/// The Dependency Injection (DI) Container.
10///
11/// This is the core registry for all providers, services, and configuration in a NestForge application.
12/// It mimics the behavior of the NestJS container but is adapted for Rust's ownership and thread-safety models.
13///
14/// ### Core Features
15/// - **Singleton Registry:** By default, registered services are singletons (Arc<T>).
16/// - **Thread Safety:** Uses `RwLock` to allow concurrent reads (resolving) and exclusive writes (registering).
17/// - **Type-Based Resolution:** Services are stored and retrieved by their `TypeId`.
18/// - **Scoped & Transient:** Supports request-scoped and transient factories for more complex lifecycles.
19#[derive(Clone, Default)]
20pub struct Container {
21    /*
22    We use Arc<RwLock<...>> because the container itself is cloned and shared across every request.
23    The inner HashMap holds the singleton instances as `Arc<dyn Any>`.
24    */
25    inner: Arc<RwLock<HashMap<TypeId, Arc<dyn Any + Send + Sync>>>>,
26
27    /*
28    Overrides are checked before the main registry. This is primarily used for testing
29    or for request-scoped sub-containers that need to shadow parent providers.
30    */
31    overrides: Arc<RwLock<HashMap<TypeId, Arc<dyn Any + Send + Sync>>>>,
32
33    request_factories: Arc<RwLock<HashMap<TypeId, Arc<RequestFactoryFn>>>>,
34    transient_factories: Arc<RwLock<HashMap<TypeId, Arc<TransientFactoryFn>>>>,
35
36    /*
37    We keep a set of registered type names mainly for debugging and error reporting.
38    It helps us tell the user *which* provider is missing by name.
39    */
40    names: Arc<RwLock<HashSet<&'static str>>>,
41}
42
43type RequestFactoryValue = Arc<dyn Any + Send + Sync>;
44type RequestFactoryFn =
45    dyn Fn(&Container) -> anyhow::Result<RequestFactoryValue> + Send + Sync + 'static;
46type TransientFactoryFn =
47    dyn Fn(&Container) -> anyhow::Result<RequestFactoryValue> + Send + Sync + 'static;
48
49#[derive(Debug, Error)]
50pub enum ContainerError {
51    #[error("Container write lock poisoned")]
52    WriteLockPoisoned,
53    #[error("Container read lock poisoned")]
54    ReadLockPoisoned,
55    #[error("Type already registered: {type_name}")]
56    TypeAlreadyRegistered { type_name: &'static str },
57    #[error("Type not registered: {type_name}")]
58    TypeNotRegistered { type_name: &'static str },
59    #[error("Failed to downcast resolved value: {type_name}")]
60    DowncastFailed { type_name: &'static str },
61    #[error("Request-scoped factory failed for {type_name}: {message}")]
62    RequestFactoryFailed {
63        type_name: &'static str,
64        message: String,
65    },
66    #[error("Type not registered: {type_name} (required by module `{module_name}`)")]
67    TypeNotRegisteredInModule {
68        type_name: &'static str,
69        module_name: &'static str,
70    },
71}
72
73impl Container {
74    /// Creates a new, empty container.
75    ///
76    /// This is equivalent to `Container::default()`.
77    pub fn new() -> Self {
78        Self::default()
79    }
80
81    /// Creates a "scoped" child container.
82    ///
83    /// A scoped container shares the underlying singleton registry (`inner`) with its parent
84    /// but has its own empty `overrides` map.
85    ///
86    /// This is used during HTTP requests to create a context where request-scoped providers
87    /// can be cached for the duration of that single request without affecting the global state.
88    pub fn scoped(&self) -> Self {
89        Self {
90            inner: Arc::clone(&self.inner),
91            overrides: Arc::new(RwLock::new(HashMap::new())),
92            request_factories: Arc::clone(&self.request_factories),
93            transient_factories: Arc::clone(&self.transient_factories),
94            names: Arc::clone(&self.names),
95        }
96    }
97
98    /// Registers a value (singleton) into the container.
99    ///
100    /// The value must be thread-safe (`Send + Sync`) and `'static`.
101    ///
102    /// # Example
103    /// ```rust
104    /// container.register(AppConfig::default())?;
105    /// ```
106    pub fn register<T>(&self, value: T) -> Result<(), ContainerError>
107    where
108        T: Send + Sync + 'static,
109    {
110        let mut map = self
111            .inner
112            .write()
113            .map_err(|_| ContainerError::WriteLockPoisoned)?;
114
115        let type_id = TypeId::of::<T>();
116
117        if map.contains_key(&type_id) {
118            return Err(ContainerError::TypeAlreadyRegistered {
119                type_name: std::any::type_name::<T>(),
120            });
121        }
122
123        map.insert(type_id, Arc::new(value));
124        self.names
125            .write()
126            .map_err(|_| ContainerError::WriteLockPoisoned)?
127            .insert(std::any::type_name::<T>());
128        Ok(())
129    }
130
131    /// Replaces an existing registration with a new value.
132    ///
133    /// Unlike `register`, this will not error if the type is already present.
134    /// It effectively updates the singleton instance.
135    pub fn replace<T>(&self, value: T) -> Result<(), ContainerError>
136    where
137        T: Send + Sync + 'static,
138    {
139        let mut map = self
140            .inner
141            .write()
142            .map_err(|_| ContainerError::WriteLockPoisoned)?;
143
144        map.insert(TypeId::of::<T>(), Arc::new(value));
145        self.names
146            .write()
147            .map_err(|_| ContainerError::WriteLockPoisoned)?
148            .insert(std::any::type_name::<T>());
149        Ok(())
150    }
151
152    /// Overrides a value in the current scope.
153    ///
154    /// If called on a global container, it works like `replace` but stores the value
155    /// in the `overrides` map, which takes precedence over `inner`.
156    ///
157    /// If called on a `scoped()` container, the override only exists for that scope.
158    pub fn override_value<T>(&self, value: T) -> Result<(), ContainerError>
159    where
160        T: Send + Sync + 'static,
161    {
162        let mut overrides = self
163            .overrides
164            .write()
165            .map_err(|_| ContainerError::WriteLockPoisoned)?;
166
167        overrides.insert(TypeId::of::<T>(), Arc::new(value));
168        self.names
169            .write()
170            .map_err(|_| ContainerError::WriteLockPoisoned)?
171            .insert(std::any::type_name::<T>());
172        Ok(())
173    }
174
175    /// Checks if a type with the given name is registered.
176    ///
177    /// This relies on `std::any::type_name` matching what was stored.
178    pub fn is_type_registered_name(&self, type_name: &'static str) -> Result<bool, ContainerError> {
179        let names = self
180            .names
181            .read()
182            .map_err(|_| ContainerError::ReadLockPoisoned)?;
183        Ok(names.contains(type_name))
184    }
185
186    /// Registers a factory for request-scoped providers.
187    ///
188    /// A request-scoped provider is created once per `scoped()` container (i.e., once per request)
189    /// and then cached within that scope.
190    pub fn register_request_factory<T, F>(&self, factory: F) -> Result<(), ContainerError>
191    where
192        T: Send + Sync + 'static,
193        F: Fn(&Container) -> anyhow::Result<T> + Send + Sync + 'static,
194    {
195        let type_id = TypeId::of::<T>();
196        let mut factories = self
197            .request_factories
198            .write()
199            .map_err(|_| ContainerError::WriteLockPoisoned)?;
200
201        if factories.contains_key(&type_id) {
202            return Err(ContainerError::TypeAlreadyRegistered {
203                type_name: std::any::type_name::<T>(),
204            });
205        }
206
207        factories.insert(
208            type_id,
209            Arc::new(move |container| Ok(Arc::new(factory(container)?) as RequestFactoryValue)),
210        );
211        self.names
212            .write()
213            .map_err(|_| ContainerError::WriteLockPoisoned)?
214            .insert(std::any::type_name::<T>());
215        Ok(())
216    }
217
218    /// Registers a factory for transient providers.
219    ///
220    /// A transient provider is created anew every single time it is resolved.
221    /// It is never cached.
222    pub fn register_transient_factory<T, F>(&self, factory: F) -> Result<(), ContainerError>
223    where
224        T: Send + Sync + 'static,
225        F: Fn(&Container) -> anyhow::Result<T> + Send + Sync + 'static,
226    {
227        let type_id = TypeId::of::<T>();
228        let mut factories = self
229            .transient_factories
230            .write()
231            .map_err(|_| ContainerError::WriteLockPoisoned)?;
232
233        if factories.contains_key(&type_id) {
234            return Err(ContainerError::TypeAlreadyRegistered {
235                type_name: std::any::type_name::<T>(),
236            });
237        }
238
239        factories.insert(
240            type_id,
241            Arc::new(move |container| Ok(Arc::new(factory(container)?) as RequestFactoryValue)),
242        );
243        self.names
244            .write()
245            .map_err(|_| ContainerError::WriteLockPoisoned)?
246            .insert(std::any::type_name::<T>());
247        Ok(())
248    }
249
250    /// Resolves (retrieves) a registered provider.
251    ///
252    /// The search order is:
253    /// 1. Overrides (scoped instances)
254    /// 2. Singletons (global instances)
255    /// 3. Request-scoped factories (create and cache in overrides if found)
256    /// 4. Transient factories (create new instance)
257    ///
258    /// Returns an `Arc<T>` so the service can be cheaply shared.
259    pub fn resolve<T>(&self) -> Result<Arc<T>, ContainerError>
260    where
261        T: Send + Sync + 'static,
262    {
263        /*
264        Step 1: Check overrides.
265        If we are in a request scope, this is where request-scoped instances live.
266        */
267        if let Some(value) = self.resolve_from_map::<T>(&self.overrides)? {
268            return Ok(value);
269        }
270
271        /*
272        Step 2: Check global singletons.
273        This is the most common case for stateless services.
274        */
275        if let Some(value) = self.resolve_from_map::<T>(&self.inner)? {
276            return Ok(value);
277        }
278
279        /*
280        Step 3: Check request-scoped factories.
281        If found, we run the factory, cache the result in `overrides`, and return it.
282        */
283        if let Some(value) = self.resolve_from_request_factory::<T>()? {
284            return Ok(value);
285        }
286
287        /*
288        Step 4: Check transient factories.
289        If found, we run the factory and return a fresh instance.
290        */
291        if let Some(value) = self.resolve_from_transient_factory::<T>()? {
292            return Ok(value);
293        }
294
295        Err(ContainerError::TypeNotRegistered {
296            type_name: std::any::type_name::<T>(),
297        })
298    }
299
300    pub fn resolve_in_module<T>(&self, module_name: &'static str) -> Result<Arc<T>, ContainerError>
301    where
302        T: Send + Sync + 'static,
303    {
304        self.resolve::<T>().map_err(|err| match err {
305            ContainerError::TypeNotRegistered { type_name } => {
306                ContainerError::TypeNotRegisteredInModule {
307                    type_name,
308                    module_name,
309                }
310            }
311            other => other,
312        })
313    }
314
315    fn resolve_from_map<T>(
316        &self,
317        map: &Arc<RwLock<HashMap<TypeId, Arc<dyn Any + Send + Sync>>>>,
318    ) -> Result<Option<Arc<T>>, ContainerError>
319    where
320        T: Send + Sync + 'static,
321    {
322        let map = map.read().map_err(|_| ContainerError::ReadLockPoisoned)?;
323        let Some(value) = map.get(&TypeId::of::<T>()).cloned() else {
324            return Ok(None);
325        };
326
327        let value = value
328            .downcast::<T>()
329            .map_err(|_| ContainerError::DowncastFailed {
330                type_name: std::any::type_name::<T>(),
331            })?;
332
333        Ok(Some(value))
334    }
335
336    fn resolve_from_request_factory<T>(&self) -> Result<Option<Arc<T>>, ContainerError>
337    where
338        T: Send + Sync + 'static,
339    {
340        let factory = {
341            let factories = self
342                .request_factories
343                .read()
344                .map_err(|_| ContainerError::ReadLockPoisoned)?;
345            factories.get(&TypeId::of::<T>()).cloned()
346        };
347
348        let Some(factory) = factory else {
349            return Ok(None);
350        };
351
352        let value = factory(self).map_err(|err| ContainerError::RequestFactoryFailed {
353            type_name: std::any::type_name::<T>(),
354            message: err.to_string(),
355        })?;
356        let typed = value
357            .downcast::<T>()
358            .map_err(|_| ContainerError::DowncastFailed {
359                type_name: std::any::type_name::<T>(),
360            })?;
361
362        self.overrides
363            .write()
364            .map_err(|_| ContainerError::WriteLockPoisoned)?
365            .insert(TypeId::of::<T>(), typed.clone() as RequestFactoryValue);
366
367        Ok(Some(typed))
368    }
369
370    fn resolve_from_transient_factory<T>(&self) -> Result<Option<Arc<T>>, ContainerError>
371    where
372        T: Send + Sync + 'static,
373    {
374        let factory = {
375            let factories = self
376                .transient_factories
377                .read()
378                .map_err(|_| ContainerError::ReadLockPoisoned)?;
379            factories.get(&TypeId::of::<T>()).cloned()
380        };
381
382        let Some(factory) = factory else {
383            return Ok(None);
384        };
385
386        let value = factory(self).map_err(|err| ContainerError::RequestFactoryFailed {
387            type_name: std::any::type_name::<T>(),
388            message: err.to_string(),
389        })?;
390        let typed = value
391            .downcast::<T>()
392            .map_err(|_| ContainerError::DowncastFailed {
393                type_name: std::any::type_name::<T>(),
394            })?;
395
396        Ok(Some(typed))
397    }
398}
399
#[cfg(test)]
mod tests {
    use super::*;

    #[derive(Debug, PartialEq, Eq)]
    struct AppConfig {
        app_name: &'static str,
    }

    #[derive(Clone)]
    struct RequestId(String);

    struct RequestGreeting(String);
    struct TransientCounter(usize);
    struct RequestCounter(usize);

    #[test]
    fn override_value_takes_precedence_over_registered_value() {
        let container = Container::new();

        container
            .register(AppConfig {
                app_name: "default",
            })
            .expect("register should succeed");
        container
            .override_value(AppConfig { app_name: "test" })
            .expect("override should succeed");

        let config = container
            .resolve::<AppConfig>()
            .expect("config should resolve");
        assert_eq!(config.app_name, "test");
    }

    #[test]
    fn register_rejects_duplicate_type_and_keeps_original() {
        let container = Container::new();
        container
            .register(AppConfig { app_name: "first" })
            .expect("first register should succeed");

        let err = container
            .register(AppConfig { app_name: "second" })
            .expect_err("second register should fail");
        assert!(matches!(
            err,
            ContainerError::TypeAlreadyRegistered { type_name }
                if type_name == std::any::type_name::<AppConfig>()
        ));

        // The rejected registration must not clobber the existing singleton.
        let config = container
            .resolve::<AppConfig>()
            .expect("config should resolve");
        assert_eq!(config.app_name, "first");
    }

    #[test]
    fn replace_overwrites_singleton_and_records_name() {
        let container = Container::new();
        container
            .register(AppConfig { app_name: "old" })
            .expect("register should succeed");
        container
            .replace(AppConfig { app_name: "new" })
            .expect("replace should succeed");

        let config = container
            .resolve::<AppConfig>()
            .expect("config should resolve");
        assert_eq!(config.app_name, "new");
        assert!(container
            .is_type_registered_name(std::any::type_name::<AppConfig>())
            .expect("name lookup should succeed"));
    }

    #[test]
    fn scoped_container_resolves_request_factory_without_leaking_to_parent() {
        let container = Container::new();
        container
            .register_request_factory::<RequestGreeting, _>(|scoped| {
                let request_id = scoped.resolve::<RequestId>()?;
                Ok(RequestGreeting(format!("hello {}", request_id.0)))
            })
            .expect("request factory should register");

        let scoped = container.scoped();
        scoped
            .override_value(RequestId("req-1".to_string()))
            .expect("request id should override");

        let greeting = scoped
            .resolve::<RequestGreeting>()
            .expect("request greeting should resolve");

        assert_eq!(greeting.0, "hello req-1");
        assert!(container.resolve::<RequestGreeting>().is_err());
    }

    #[test]
    fn request_factory_runs_once_per_scope() {
        let container = Container::new();
        let calls = Arc::new(RwLock::new(0usize));
        let calls_for_factory = Arc::clone(&calls);

        container
            .register_request_factory::<RequestCounter, _>(move |_| {
                let mut count = calls_for_factory
                    .write()
                    .expect("counter should be writable");
                *count += 1;
                Ok(RequestCounter(*count))
            })
            .expect("request factory should register");

        let first_scope = container.scoped();
        let a = first_scope
            .resolve::<RequestCounter>()
            .expect("first resolve should succeed");
        let b = first_scope
            .resolve::<RequestCounter>()
            .expect("second resolve should succeed");
        assert!(Arc::ptr_eq(&a, &b), "same scope must reuse the cached instance");
        assert_eq!(a.0, 1);

        let second_scope = container.scoped();
        let c = second_scope
            .resolve::<RequestCounter>()
            .expect("resolve in new scope should succeed");
        assert_eq!(c.0, 2, "a new scope must run the factory again");
        assert_eq!(*calls.read().expect("counter should be readable"), 2);
    }

    #[test]
    fn transient_factory_creates_new_instances_per_resolve() {
        let container = Container::new();
        let counter = Arc::new(RwLock::new(0usize));
        let counter_for_factory = Arc::clone(&counter);

        container
            .register_transient_factory::<TransientCounter, _>(move |_| {
                let mut count = counter_for_factory
                    .write()
                    .expect("counter should be writable");
                *count += 1;
                Ok(TransientCounter(*count))
            })
            .expect("transient factory should register");

        let first = container
            .resolve::<TransientCounter>()
            .expect("first transient should resolve");
        let second = container
            .resolve::<TransientCounter>()
            .expect("second transient should resolve");

        assert_eq!(first.0, 1);
        assert_eq!(second.0, 2);
    }

    #[test]
    fn resolve_in_module_reports_requesting_module() {
        let container = Container::new();

        let err = container
            .resolve_in_module::<AppConfig>("ConfigModule")
            .expect_err("unregistered type should fail");
        assert!(matches!(
            err,
            ContainerError::TypeNotRegisteredInModule {
                module_name: "ConfigModule",
                ..
            }
        ));
    }
}