Skip to main content

nestforge_core/
container.rs

1use std::{
2    any::{Any, TypeId},
3    collections::{HashMap, HashSet},
4    sync::{Arc, RwLock},
5};
6
7use thiserror::Error;
8
9/**
10* Container = our tiny dependency injection store (v1).
11*
12* What it does:
13* - stores values by type (TypeId).
14* - lets us register services/config once.
15* - lets us resolve them later.
16*
17* Why Arc?
18* - so multiple parts of the app can share the same service safely.
19*
20* Why RwLock?
21* - allows safe reads/writes across threads.
22* - write when registering.
23* - read when resolving.
24*/
25
26#[derive(Clone, Default)]
27pub struct Container {
28    inner: Arc<RwLock<HashMap<TypeId, Arc<dyn Any + Send + Sync>>>>,
29    overrides: Arc<RwLock<HashMap<TypeId, Arc<dyn Any + Send + Sync>>>>,
30    request_factories: Arc<
31        RwLock<HashMap<TypeId, Arc<RequestFactoryFn>>>,
32    >,
33    transient_factories: Arc<
34        RwLock<HashMap<TypeId, Arc<TransientFactoryFn>>>,
35    >,
36    names: Arc<RwLock<HashSet<&'static str>>>,
37}
38
39type RequestFactoryValue = Arc<dyn Any + Send + Sync>;
40type RequestFactoryFn =
41    dyn Fn(&Container) -> anyhow::Result<RequestFactoryValue> + Send + Sync + 'static;
42type TransientFactoryFn =
43    dyn Fn(&Container) -> anyhow::Result<RequestFactoryValue> + Send + Sync + 'static;
44
45#[derive(Debug, Error)]
46pub enum ContainerError {
47    #[error("Container write lock poisoned")]
48    WriteLockPoisoned,
49    #[error("Container read lock poisoned")]
50    ReadLockPoisoned,
51    #[error("Type already registered: {type_name}")]
52    TypeAlreadyRegistered { type_name: &'static str },
53    #[error("Type not registered: {type_name}")]
54    TypeNotRegistered { type_name: &'static str },
55    #[error("Failed to downcast resolved value: {type_name}")]
56    DowncastFailed { type_name: &'static str },
57    #[error("Request-scoped factory failed for {type_name}: {message}")]
58    RequestFactoryFailed {
59        type_name: &'static str,
60        message: String,
61    },
62    #[error("Type not registered: {type_name} (required by module `{module_name}`)")]
63    TypeNotRegisteredInModule {
64        type_name: &'static str,
65        module_name: &'static str,
66    },
67}
68
69impl Container {
70    /**
71     * Nice helper constructor.
72     * Same as Default, just cleaner to read in app code.
73     */
74    pub fn new() -> Self {
75        Self::default()
76    }
77
78    pub fn scoped(&self) -> Self {
79        Self {
80            inner: Arc::clone(&self.inner),
81            overrides: Arc::new(RwLock::new(HashMap::new())),
82            request_factories: Arc::clone(&self.request_factories),
83            transient_factories: Arc::clone(&self.transient_factories),
84            names: Arc::clone(&self.names),
85        }
86    }
87
88    /**
89     * Register a value/service into the container.
90     *
91     * Example:
92     * container.register(AppConfig { ... })?;
93     *
94     * Rules:
95     * - T must be thread-safe (Send + Sync).
96     * - T must be 'static because we store it for the app lifetime.
97     */
98    pub fn register<T>(&self, value: T) -> Result<(), ContainerError>
99    where
100        T: Send + Sync + 'static,
101    {
102        let mut map = self
103            .inner
104            .write()
105            .map_err(|_| ContainerError::WriteLockPoisoned)?;
106
107        let type_id = TypeId::of::<T>();
108
109        if map.contains_key(&type_id) {
110            return Err(ContainerError::TypeAlreadyRegistered {
111                type_name: std::any::type_name::<T>(),
112            });
113        }
114
115        map.insert(type_id, Arc::new(value));
116        self.names
117            .write()
118            .map_err(|_| ContainerError::WriteLockPoisoned)?
119            .insert(std::any::type_name::<T>());
120        Ok(())
121    }
122
123    pub fn replace<T>(&self, value: T) -> Result<(), ContainerError>
124    where
125        T: Send + Sync + 'static,
126    {
127        let mut map = self
128            .inner
129            .write()
130            .map_err(|_| ContainerError::WriteLockPoisoned)?;
131
132        map.insert(TypeId::of::<T>(), Arc::new(value));
133        self.names
134            .write()
135            .map_err(|_| ContainerError::WriteLockPoisoned)?
136            .insert(std::any::type_name::<T>());
137        Ok(())
138    }
139
140    pub fn override_value<T>(&self, value: T) -> Result<(), ContainerError>
141    where
142        T: Send + Sync + 'static,
143    {
144        let mut overrides = self
145            .overrides
146            .write()
147            .map_err(|_| ContainerError::WriteLockPoisoned)?;
148
149        overrides.insert(TypeId::of::<T>(), Arc::new(value));
150        self.names
151            .write()
152            .map_err(|_| ContainerError::WriteLockPoisoned)?
153            .insert(std::any::type_name::<T>());
154        Ok(())
155    }
156
157    pub fn is_type_registered_name(&self, type_name: &'static str) -> Result<bool, ContainerError> {
158        let names = self
159            .names
160            .read()
161            .map_err(|_| ContainerError::ReadLockPoisoned)?;
162        Ok(names.contains(type_name))
163    }
164
165    pub fn register_request_factory<T, F>(&self, factory: F) -> Result<(), ContainerError>
166    where
167        T: Send + Sync + 'static,
168        F: Fn(&Container) -> anyhow::Result<T> + Send + Sync + 'static,
169    {
170        let type_id = TypeId::of::<T>();
171        let mut factories = self
172            .request_factories
173            .write()
174            .map_err(|_| ContainerError::WriteLockPoisoned)?;
175
176        if factories.contains_key(&type_id) {
177            return Err(ContainerError::TypeAlreadyRegistered {
178                type_name: std::any::type_name::<T>(),
179            });
180        }
181
182        factories.insert(
183            type_id,
184            Arc::new(move |container| Ok(Arc::new(factory(container)?) as RequestFactoryValue)),
185        );
186        self.names
187            .write()
188            .map_err(|_| ContainerError::WriteLockPoisoned)?
189            .insert(std::any::type_name::<T>());
190        Ok(())
191    }
192
193    pub fn register_transient_factory<T, F>(&self, factory: F) -> Result<(), ContainerError>
194    where
195        T: Send + Sync + 'static,
196        F: Fn(&Container) -> anyhow::Result<T> + Send + Sync + 'static,
197    {
198        let type_id = TypeId::of::<T>();
199        let mut factories = self
200            .transient_factories
201            .write()
202            .map_err(|_| ContainerError::WriteLockPoisoned)?;
203
204        if factories.contains_key(&type_id) {
205            return Err(ContainerError::TypeAlreadyRegistered {
206                type_name: std::any::type_name::<T>(),
207            });
208        }
209
210        factories.insert(
211            type_id,
212            Arc::new(move |container| Ok(Arc::new(factory(container)?) as RequestFactoryValue)),
213        );
214        self.names
215            .write()
216            .map_err(|_| ContainerError::WriteLockPoisoned)?
217            .insert(std::any::type_name::<T>());
218        Ok(())
219    }
220
221    /**
222     * Resolve (get back) a registered value/service by type.
223     *
224     * Example:
225     * let config = container.resolve::<AppConfig>()?;
226     *
227     * Returns Arc<T> so the caller can clone/share it cheaply.
228     */
229    pub fn resolve<T>(&self) -> Result<Arc<T>, ContainerError>
230    where
231        T: Send + Sync + 'static,
232    {
233        if let Some(value) = self.resolve_from_map::<T>(&self.overrides)? {
234            return Ok(value);
235        }
236
237        if let Some(value) = self.resolve_from_map::<T>(&self.inner)? {
238            return Ok(value);
239        }
240
241        if let Some(value) = self.resolve_from_request_factory::<T>()? {
242            return Ok(value);
243        }
244
245        if let Some(value) = self.resolve_from_transient_factory::<T>()? {
246            return Ok(value);
247        }
248
249        Err(ContainerError::TypeNotRegistered {
250            type_name: std::any::type_name::<T>(),
251        })
252    }
253
254    pub fn resolve_in_module<T>(&self, module_name: &'static str) -> Result<Arc<T>, ContainerError>
255    where
256        T: Send + Sync + 'static,
257    {
258        self.resolve::<T>().map_err(|err| match err {
259            ContainerError::TypeNotRegistered { type_name } => {
260                ContainerError::TypeNotRegisteredInModule {
261                    type_name,
262                    module_name,
263                }
264            }
265            other => other,
266        })
267    }
268
269    fn resolve_from_map<T>(
270        &self,
271        map: &Arc<RwLock<HashMap<TypeId, Arc<dyn Any + Send + Sync>>>>,
272    ) -> Result<Option<Arc<T>>, ContainerError>
273    where
274        T: Send + Sync + 'static,
275    {
276        let map = map.read().map_err(|_| ContainerError::ReadLockPoisoned)?;
277        let Some(value) = map.get(&TypeId::of::<T>()).cloned() else {
278            return Ok(None);
279        };
280
281        let value = value.downcast::<T>().map_err(|_| ContainerError::DowncastFailed {
282            type_name: std::any::type_name::<T>(),
283        })?;
284
285        Ok(Some(value))
286    }
287
288    fn resolve_from_request_factory<T>(&self) -> Result<Option<Arc<T>>, ContainerError>
289    where
290        T: Send + Sync + 'static,
291    {
292        let factory = {
293            let factories = self
294                .request_factories
295                .read()
296                .map_err(|_| ContainerError::ReadLockPoisoned)?;
297            factories.get(&TypeId::of::<T>()).cloned()
298        };
299
300        let Some(factory) = factory else {
301            return Ok(None);
302        };
303
304        let value = factory(self).map_err(|err| ContainerError::RequestFactoryFailed {
305            type_name: std::any::type_name::<T>(),
306            message: err.to_string(),
307        })?;
308        let typed = value.downcast::<T>().map_err(|_| ContainerError::DowncastFailed {
309            type_name: std::any::type_name::<T>(),
310        })?;
311
312        self.overrides
313            .write()
314            .map_err(|_| ContainerError::WriteLockPoisoned)?
315            .insert(TypeId::of::<T>(), typed.clone() as RequestFactoryValue);
316
317        Ok(Some(typed))
318    }
319
320    fn resolve_from_transient_factory<T>(&self) -> Result<Option<Arc<T>>, ContainerError>
321    where
322        T: Send + Sync + 'static,
323    {
324        let factory = {
325            let factories = self
326                .transient_factories
327                .read()
328                .map_err(|_| ContainerError::ReadLockPoisoned)?;
329            factories.get(&TypeId::of::<T>()).cloned()
330        };
331
332        let Some(factory) = factory else {
333            return Ok(None);
334        };
335
336        let value = factory(self).map_err(|err| ContainerError::RequestFactoryFailed {
337            type_name: std::any::type_name::<T>(),
338            message: err.to_string(),
339        })?;
340        let typed = value.downcast::<T>().map_err(|_| ContainerError::DowncastFailed {
341            type_name: std::any::type_name::<T>(),
342        })?;
343
344        Ok(Some(typed))
345    }
346}
347
#[cfg(test)]
mod tests {
    use super::*;

    // Fixture types shared by the tests below.

    #[derive(Debug, PartialEq, Eq)]
    struct AppConfig {
        app_name: &'static str,
    }

    #[derive(Clone)]
    struct RequestId(String);

    struct RequestGreeting(String);
    struct TransientCounter(usize);

    /** An override must shadow a value registered for the same type. */
    #[test]
    fn override_value_takes_precedence_over_registered_value() {
        let container = Container::new();

        let registered = container.register(AppConfig {
            app_name: "default",
        });
        registered.expect("register should succeed");

        let overridden = container.override_value(AppConfig { app_name: "test" });
        overridden.expect("override should succeed");

        let config = container
            .resolve::<AppConfig>()
            .expect("config should resolve");
        assert_eq!(config.app_name, "test");
    }

    /**
     * A request factory resolves inside a scope that provides its inputs,
     * while the parent (which lacks the RequestId) still fails to resolve.
     */
    #[test]
    fn scoped_container_resolves_request_factory_without_leaking_to_parent() {
        let root = Container::new();
        root.register_request_factory::<RequestGreeting, _>(|scope| {
            let request_id = scope.resolve::<RequestId>()?;
            Ok(RequestGreeting(format!("hello {}", request_id.0)))
        })
        .expect("request factory should register");

        let request_scope = root.scoped();
        request_scope
            .override_value(RequestId("req-1".to_string()))
            .expect("request id should override");

        let greeting = request_scope
            .resolve::<RequestGreeting>()
            .expect("request greeting should resolve");
        assert_eq!(greeting.0, "hello req-1");

        let from_root = root.resolve::<RequestGreeting>();
        assert!(from_root.is_err());
    }

    /** Every resolve of a transient type must run the factory again. */
    #[test]
    fn transient_factory_creates_new_instances_per_resolve() {
        let container = Container::new();
        let calls = Arc::new(RwLock::new(0usize));
        let calls_in_factory = Arc::clone(&calls);

        container
            .register_transient_factory::<TransientCounter, _>(move |_| {
                let mut seen = calls_in_factory
                    .write()
                    .expect("counter should be writable");
                *seen += 1;
                Ok(TransientCounter(*seen))
            })
            .expect("transient factory should register");

        let first = container
            .resolve::<TransientCounter>()
            .expect("first transient should resolve");
        let second = container
            .resolve::<TransientCounter>()
            .expect("second transient should resolve");

        assert_eq!(first.0, 1);
        assert_eq!(second.0, 2);
    }
}
431}