// nestforge_core/container.rs

1use std::{
2    any::{Any, TypeId},
3    collections::{HashMap, HashSet},
4    sync::{Arc, RwLock},
5};
6
7use thiserror::Error;
8
9/**
10* Container = our tiny dependency injection store (v1).
11*
12* What it does:
13* - stores values by type (TypeId).
14* - lets us register services/config once.
15* - lets us resolve them later.
16*
17* Why Arc?
18* - so multiple parts of the app can share the same service safely.
19*
20* Why RwLock?
21* - allows safe reads/writes across threads.
22* - write when registering.
23* - read when resolving.
24*/
25
26#[derive(Clone, Default)]
27pub struct Container {
28    inner: Arc<RwLock<HashMap<TypeId, Arc<dyn Any + Send + Sync>>>>,
29    overrides: Arc<RwLock<HashMap<TypeId, Arc<dyn Any + Send + Sync>>>>,
30    request_factories: Arc<RwLock<HashMap<TypeId, Arc<RequestFactoryFn>>>>,
31    transient_factories: Arc<RwLock<HashMap<TypeId, Arc<TransientFactoryFn>>>>,
32    names: Arc<RwLock<HashSet<&'static str>>>,
33}
34
35type RequestFactoryValue = Arc<dyn Any + Send + Sync>;
36type RequestFactoryFn =
37    dyn Fn(&Container) -> anyhow::Result<RequestFactoryValue> + Send + Sync + 'static;
38type TransientFactoryFn =
39    dyn Fn(&Container) -> anyhow::Result<RequestFactoryValue> + Send + Sync + 'static;
40
41#[derive(Debug, Error)]
42pub enum ContainerError {
43    #[error("Container write lock poisoned")]
44    WriteLockPoisoned,
45    #[error("Container read lock poisoned")]
46    ReadLockPoisoned,
47    #[error("Type already registered: {type_name}")]
48    TypeAlreadyRegistered { type_name: &'static str },
49    #[error("Type not registered: {type_name}")]
50    TypeNotRegistered { type_name: &'static str },
51    #[error("Failed to downcast resolved value: {type_name}")]
52    DowncastFailed { type_name: &'static str },
53    #[error("Request-scoped factory failed for {type_name}: {message}")]
54    RequestFactoryFailed {
55        type_name: &'static str,
56        message: String,
57    },
58    #[error("Type not registered: {type_name} (required by module `{module_name}`)")]
59    TypeNotRegisteredInModule {
60        type_name: &'static str,
61        module_name: &'static str,
62    },
63}
64
65impl Container {
66    /**
67     * Nice helper constructor.
68     * Same as Default, just cleaner to read in app code.
69     */
70    pub fn new() -> Self {
71        Self::default()
72    }
73
74    pub fn scoped(&self) -> Self {
75        Self {
76            inner: Arc::clone(&self.inner),
77            overrides: Arc::new(RwLock::new(HashMap::new())),
78            request_factories: Arc::clone(&self.request_factories),
79            transient_factories: Arc::clone(&self.transient_factories),
80            names: Arc::clone(&self.names),
81        }
82    }
83
84    /**
85     * Register a value/service into the container.
86     *
87     * Example:
88     * container.register(AppConfig { ... })?;
89     *
90     * Rules:
91     * - T must be thread-safe (Send + Sync).
92     * - T must be 'static because we store it for the app lifetime.
93     */
94    pub fn register<T>(&self, value: T) -> Result<(), ContainerError>
95    where
96        T: Send + Sync + 'static,
97    {
98        let mut map = self
99            .inner
100            .write()
101            .map_err(|_| ContainerError::WriteLockPoisoned)?;
102
103        let type_id = TypeId::of::<T>();
104
105        if map.contains_key(&type_id) {
106            return Err(ContainerError::TypeAlreadyRegistered {
107                type_name: std::any::type_name::<T>(),
108            });
109        }
110
111        map.insert(type_id, Arc::new(value));
112        self.names
113            .write()
114            .map_err(|_| ContainerError::WriteLockPoisoned)?
115            .insert(std::any::type_name::<T>());
116        Ok(())
117    }
118
119    pub fn replace<T>(&self, value: T) -> Result<(), ContainerError>
120    where
121        T: Send + Sync + 'static,
122    {
123        let mut map = self
124            .inner
125            .write()
126            .map_err(|_| ContainerError::WriteLockPoisoned)?;
127
128        map.insert(TypeId::of::<T>(), Arc::new(value));
129        self.names
130            .write()
131            .map_err(|_| ContainerError::WriteLockPoisoned)?
132            .insert(std::any::type_name::<T>());
133        Ok(())
134    }
135
136    pub fn override_value<T>(&self, value: T) -> Result<(), ContainerError>
137    where
138        T: Send + Sync + 'static,
139    {
140        let mut overrides = self
141            .overrides
142            .write()
143            .map_err(|_| ContainerError::WriteLockPoisoned)?;
144
145        overrides.insert(TypeId::of::<T>(), Arc::new(value));
146        self.names
147            .write()
148            .map_err(|_| ContainerError::WriteLockPoisoned)?
149            .insert(std::any::type_name::<T>());
150        Ok(())
151    }
152
153    pub fn is_type_registered_name(&self, type_name: &'static str) -> Result<bool, ContainerError> {
154        let names = self
155            .names
156            .read()
157            .map_err(|_| ContainerError::ReadLockPoisoned)?;
158        Ok(names.contains(type_name))
159    }
160
161    pub fn register_request_factory<T, F>(&self, factory: F) -> Result<(), ContainerError>
162    where
163        T: Send + Sync + 'static,
164        F: Fn(&Container) -> anyhow::Result<T> + Send + Sync + 'static,
165    {
166        let type_id = TypeId::of::<T>();
167        let mut factories = self
168            .request_factories
169            .write()
170            .map_err(|_| ContainerError::WriteLockPoisoned)?;
171
172        if factories.contains_key(&type_id) {
173            return Err(ContainerError::TypeAlreadyRegistered {
174                type_name: std::any::type_name::<T>(),
175            });
176        }
177
178        factories.insert(
179            type_id,
180            Arc::new(move |container| Ok(Arc::new(factory(container)?) as RequestFactoryValue)),
181        );
182        self.names
183            .write()
184            .map_err(|_| ContainerError::WriteLockPoisoned)?
185            .insert(std::any::type_name::<T>());
186        Ok(())
187    }
188
189    pub fn register_transient_factory<T, F>(&self, factory: F) -> Result<(), ContainerError>
190    where
191        T: Send + Sync + 'static,
192        F: Fn(&Container) -> anyhow::Result<T> + Send + Sync + 'static,
193    {
194        let type_id = TypeId::of::<T>();
195        let mut factories = self
196            .transient_factories
197            .write()
198            .map_err(|_| ContainerError::WriteLockPoisoned)?;
199
200        if factories.contains_key(&type_id) {
201            return Err(ContainerError::TypeAlreadyRegistered {
202                type_name: std::any::type_name::<T>(),
203            });
204        }
205
206        factories.insert(
207            type_id,
208            Arc::new(move |container| Ok(Arc::new(factory(container)?) as RequestFactoryValue)),
209        );
210        self.names
211            .write()
212            .map_err(|_| ContainerError::WriteLockPoisoned)?
213            .insert(std::any::type_name::<T>());
214        Ok(())
215    }
216
217    /**
218     * Resolve (get back) a registered value/service by type.
219     *
220     * Example:
221     * let config = container.resolve::<AppConfig>()?;
222     *
223     * Returns Arc<T> so the caller can clone/share it cheaply.
224     */
225    pub fn resolve<T>(&self) -> Result<Arc<T>, ContainerError>
226    where
227        T: Send + Sync + 'static,
228    {
229        if let Some(value) = self.resolve_from_map::<T>(&self.overrides)? {
230            return Ok(value);
231        }
232
233        if let Some(value) = self.resolve_from_map::<T>(&self.inner)? {
234            return Ok(value);
235        }
236
237        if let Some(value) = self.resolve_from_request_factory::<T>()? {
238            return Ok(value);
239        }
240
241        if let Some(value) = self.resolve_from_transient_factory::<T>()? {
242            return Ok(value);
243        }
244
245        Err(ContainerError::TypeNotRegistered {
246            type_name: std::any::type_name::<T>(),
247        })
248    }
249
250    pub fn resolve_in_module<T>(&self, module_name: &'static str) -> Result<Arc<T>, ContainerError>
251    where
252        T: Send + Sync + 'static,
253    {
254        self.resolve::<T>().map_err(|err| match err {
255            ContainerError::TypeNotRegistered { type_name } => {
256                ContainerError::TypeNotRegisteredInModule {
257                    type_name,
258                    module_name,
259                }
260            }
261            other => other,
262        })
263    }
264
265    fn resolve_from_map<T>(
266        &self,
267        map: &Arc<RwLock<HashMap<TypeId, Arc<dyn Any + Send + Sync>>>>,
268    ) -> Result<Option<Arc<T>>, ContainerError>
269    where
270        T: Send + Sync + 'static,
271    {
272        let map = map.read().map_err(|_| ContainerError::ReadLockPoisoned)?;
273        let Some(value) = map.get(&TypeId::of::<T>()).cloned() else {
274            return Ok(None);
275        };
276
277        let value = value
278            .downcast::<T>()
279            .map_err(|_| ContainerError::DowncastFailed {
280                type_name: std::any::type_name::<T>(),
281            })?;
282
283        Ok(Some(value))
284    }
285
286    fn resolve_from_request_factory<T>(&self) -> Result<Option<Arc<T>>, ContainerError>
287    where
288        T: Send + Sync + 'static,
289    {
290        let factory = {
291            let factories = self
292                .request_factories
293                .read()
294                .map_err(|_| ContainerError::ReadLockPoisoned)?;
295            factories.get(&TypeId::of::<T>()).cloned()
296        };
297
298        let Some(factory) = factory else {
299            return Ok(None);
300        };
301
302        let value = factory(self).map_err(|err| ContainerError::RequestFactoryFailed {
303            type_name: std::any::type_name::<T>(),
304            message: err.to_string(),
305        })?;
306        let typed = value
307            .downcast::<T>()
308            .map_err(|_| ContainerError::DowncastFailed {
309                type_name: std::any::type_name::<T>(),
310            })?;
311
312        self.overrides
313            .write()
314            .map_err(|_| ContainerError::WriteLockPoisoned)?
315            .insert(TypeId::of::<T>(), typed.clone() as RequestFactoryValue);
316
317        Ok(Some(typed))
318    }
319
320    fn resolve_from_transient_factory<T>(&self) -> Result<Option<Arc<T>>, ContainerError>
321    where
322        T: Send + Sync + 'static,
323    {
324        let factory = {
325            let factories = self
326                .transient_factories
327                .read()
328                .map_err(|_| ContainerError::ReadLockPoisoned)?;
329            factories.get(&TypeId::of::<T>()).cloned()
330        };
331
332        let Some(factory) = factory else {
333            return Ok(None);
334        };
335
336        let value = factory(self).map_err(|err| ContainerError::RequestFactoryFailed {
337            type_name: std::any::type_name::<T>(),
338            message: err.to_string(),
339        })?;
340        let typed = value
341            .downcast::<T>()
342            .map_err(|_| ContainerError::DowncastFailed {
343                type_name: std::any::type_name::<T>(),
344            })?;
345
346        Ok(Some(typed))
347    }
348}
349
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    #[derive(Debug, PartialEq, Eq)]
    struct AppConfig {
        app_name: &'static str,
    }

    #[derive(Clone)]
    struct RequestId(String);

    struct RequestGreeting(String);
    struct TransientCounter(usize);

    #[test]
    fn override_value_takes_precedence_over_registered_value() {
        let container = Container::new();

        container
            .register(AppConfig {
                app_name: "default",
            })
            .expect("register should succeed");
        container
            .override_value(AppConfig { app_name: "test" })
            .expect("override should succeed");

        let config = container
            .resolve::<AppConfig>()
            .expect("config should resolve");
        assert_eq!(config.app_name, "test");
    }

    #[test]
    fn register_rejects_duplicate_type_and_keeps_first_value() {
        let container = Container::new();

        container
            .register(AppConfig { app_name: "first" })
            .expect("first registration should succeed");
        let err = container
            .register(AppConfig { app_name: "second" })
            .expect_err("duplicate registration should fail");
        assert!(matches!(err, ContainerError::TypeAlreadyRegistered { .. }));

        let config = container
            .resolve::<AppConfig>()
            .expect("config should resolve");
        assert_eq!(*config, AppConfig { app_name: "first" });
        assert!(container
            .is_type_registered_name(std::any::type_name::<AppConfig>())
            .expect("names lock should be readable"));
    }

    #[test]
    fn resolve_in_module_names_the_requiring_module() {
        let container = Container::new();

        let err = container
            .resolve_in_module::<AppConfig>("ConfigModule")
            .expect_err("unregistered type should not resolve");
        match err {
            ContainerError::TypeNotRegisteredInModule {
                type_name,
                module_name,
            } => {
                assert_eq!(module_name, "ConfigModule");
                assert!(type_name.ends_with("AppConfig"));
            }
            other => panic!("unexpected error: {other}"),
        }
    }

    #[test]
    fn scoped_container_resolves_request_factory_without_leaking_to_parent() {
        let container = Container::new();
        container
            .register_request_factory::<RequestGreeting, _>(|scoped| {
                let request_id = scoped.resolve::<RequestId>()?;
                Ok(RequestGreeting(format!("hello {}", request_id.0)))
            })
            .expect("request factory should register");

        let scoped = container.scoped();
        scoped
            .override_value(RequestId("req-1".to_string()))
            .expect("request id should override");

        let greeting = scoped
            .resolve::<RequestGreeting>()
            .expect("request greeting should resolve");

        assert_eq!(greeting.0, "hello req-1");
        assert!(container.resolve::<RequestGreeting>().is_err());
    }

    #[test]
    fn request_factory_builds_one_instance_per_scope() {
        let container = Container::new();
        let calls = Arc::new(AtomicUsize::new(0));
        let calls_for_factory = Arc::clone(&calls);

        container
            .register_request_factory::<RequestGreeting, _>(move |_| {
                let call = calls_for_factory.fetch_add(1, Ordering::SeqCst) + 1;
                Ok(RequestGreeting(format!("call {call}")))
            })
            .expect("request factory should register");

        let scope_a = container.scoped();
        let first = scope_a
            .resolve::<RequestGreeting>()
            .expect("first resolve in scope A should succeed");
        let second = scope_a
            .resolve::<RequestGreeting>()
            .expect("second resolve in scope A should succeed");
        assert!(Arc::ptr_eq(&first, &second));
        assert_eq!(first.0, "call 1");

        let scope_b = container.scoped();
        let third = scope_b
            .resolve::<RequestGreeting>()
            .expect("resolve in scope B should succeed");
        assert!(!Arc::ptr_eq(&first, &third));
        assert_eq!(third.0, "call 2");

        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }

    #[test]
    fn transient_factory_creates_new_instances_per_resolve() {
        let container = Container::new();
        let counter = Arc::new(AtomicUsize::new(0));
        let counter_for_factory = Arc::clone(&counter);

        container
            .register_transient_factory::<TransientCounter, _>(move |_| {
                let count = counter_for_factory.fetch_add(1, Ordering::SeqCst) + 1;
                Ok(TransientCounter(count))
            })
            .expect("transient factory should register");

        let first = container
            .resolve::<TransientCounter>()
            .expect("first transient should resolve");
        let second = container
            .resolve::<TransientCounter>()
            .expect("second transient should resolve");

        assert_eq!(first.0, 1);
        assert_eq!(second.0, 2);
        assert_eq!(counter.load(Ordering::SeqCst), 2);
    }
}