Skip to main content

nestforge_core/
container.rs

1use std::{
2    any::{Any, TypeId},
3    collections::{HashMap, HashSet},
4    sync::{Arc, RwLock},
5};
6
7use thiserror::Error;
8
/**
 * Container: a small type-keyed dependency-injection store (v1).
 *
 * Responsibilities:
 * - keep at most one value per concrete type, keyed by its `TypeId`;
 * - accept services/config when they are registered;
 * - hand them back later when they are resolved.
 *
 * Sharing model:
 * - every stored value lives behind its own `Arc`, so many parts of the app
 *   can hold the same instance safely;
 * - each collection sits behind an `RwLock`: registering takes the write side,
 *   resolving takes the read side;
 * - cloning a `Container` only clones the outer `Arc`s, so all clones observe
 *   the same underlying storage.
 */
#[derive(Default, Clone)]
pub struct Container {
    /** Registered values, type-erased to `dyn Any` so unrelated types fit in one map. */
    inner: Arc<RwLock<HashMap<TypeId, Arc<dyn Any + Send + Sync>>>>,
    /** `std::any::type_name` of each registered type, used for name-based checks. */
    names: Arc<RwLock<HashSet<&'static str>>>,
}
31
32#[derive(Debug, Error)]
33pub enum ContainerError {
34    #[error("Container write lock poisoned")]
35    WriteLockPoisoned,
36    #[error("Container read lock poisoned")]
37    ReadLockPoisoned,
38    #[error("Type already registered: {type_name}")]
39    TypeAlreadyRegistered { type_name: &'static str },
40    #[error("Type not registered: {type_name}")]
41    TypeNotRegistered { type_name: &'static str },
42    #[error("Failed to downcast resolved value: {type_name}")]
43    DowncastFailed { type_name: &'static str },
44    #[error("Type not registered: {type_name} (required by module `{module_name}`)")]
45    TypeNotRegisteredInModule {
46        type_name: &'static str,
47        module_name: &'static str,
48    },
49}
50
51impl Container {
52    /**
53     * Nice helper constructor.
54     * Same as Default, just cleaner to read in app code.
55     */
56    pub fn new() -> Self {
57        Self::default()
58    }
59
60    /**
61     * Register a value/service into the container.
62     *
63     * Example:
64     * container.register(AppConfig { ... })?;
65     *
66     * Rules:
67     * - T must be thread-safe (Send + Sync).
68     * - T must be 'static because we store it for the app lifetime.
69     */
70    pub fn register<T>(&self, value: T) -> Result<(), ContainerError>
71    where
72        T: Send + Sync + 'static,
73    {
74        let mut map = self
75            .inner
76            .write()
77            .map_err(|_| ContainerError::WriteLockPoisoned)?;
78
79        let type_id = TypeId::of::<T>();
80
81        /* Prevent accidental duplicate registration of the same type. */
82        if map.contains_key(&type_id) {
83            return Err(ContainerError::TypeAlreadyRegistered {
84                type_name: std::any::type_name::<T>(),
85            });
86        }
87
88        /* Store as Arc<dyn Any> so we can keep different types in one map. */
89        map.insert(type_id, Arc::new(value));
90        self.names
91            .write()
92            .map_err(|_| ContainerError::WriteLockPoisoned)?
93            .insert(std::any::type_name::<T>());
94        Ok(())
95    }
96
97    pub fn replace<T>(&self, value: T) -> Result<(), ContainerError>
98    where
99        T: Send + Sync + 'static,
100    {
101        let mut map = self
102            .inner
103            .write()
104            .map_err(|_| ContainerError::WriteLockPoisoned)?;
105
106        map.insert(TypeId::of::<T>(), Arc::new(value));
107        self.names
108            .write()
109            .map_err(|_| ContainerError::WriteLockPoisoned)?
110            .insert(std::any::type_name::<T>());
111        Ok(())
112    }
113
114    pub fn is_type_registered_name(&self, type_name: &'static str) -> Result<bool, ContainerError> {
115        let names = self
116            .names
117            .read()
118            .map_err(|_| ContainerError::ReadLockPoisoned)?;
119        Ok(names.contains(type_name))
120    }
121
122    /**
123     * Resolve (get back) a registered value/service by type.
124     *
125     * Example:
126     * let config = container.resolve::<AppConfig>()?;
127     *
128     * Returns Arc<T> so the caller can clone/share it cheaply.
129     */
130    pub fn resolve<T>(&self) -> Result<Arc<T>, ContainerError>
131    where
132        T: Send + Sync + 'static,
133    {
134        let map = self
135            .inner
136            .read()
137            .map_err(|_| ContainerError::ReadLockPoisoned)?;
138
139        let value = map
140            .get(&TypeId::of::<T>())
141            .ok_or_else(|| ContainerError::TypeNotRegistered {
142                type_name: std::any::type_name::<T>(),
143            })?
144            .clone();
145
146        /*
147         * We stored the value as dyn Any, so now we downcast it back to the real type T.
148         * If downcast fails, the type in the map doesn’t match what we asked for.
149         */
150        value
151            .downcast::<T>()
152            .map_err(|_| ContainerError::DowncastFailed {
153                type_name: std::any::type_name::<T>(),
154            })
155    }
156
157    pub fn resolve_in_module<T>(&self, module_name: &'static str) -> Result<Arc<T>, ContainerError>
158    where
159        T: Send + Sync + 'static,
160    {
161        self.resolve::<T>().map_err(|err| match err {
162            ContainerError::TypeNotRegistered { type_name } => {
163                ContainerError::TypeNotRegisteredInModule {
164                    type_name,
165                    module_name,
166                }
167            }
168            other => other,
169        })
170    }
171}