Skip to main content

nidus_core/container/
mod.rs

1//! Typed dependency container primitives.
2
3mod dependency;
4mod request_scope;
5
6use std::{
7    any::{Any, TypeId, type_name},
8    collections::HashMap,
9    sync::Arc,
10};
11
12use crate::{NidusError, ProviderEntry, ProviderLifetime, Result};
13
14pub use dependency::{Factory, Inject, Lazy, Optional, Scoped};
15pub use request_scope::{RequestScope, SharedRequestScope};
16
17/// Type-indexed dependency container.
18#[derive(Default)]
19pub struct Container {
20    providers: HashMap<TypeId, ProviderEntry>,
21}
22
23impl Container {
24    /// Creates an empty container.
25    pub fn new() -> Self {
26        Self::default()
27    }
28
29    /// Creates a request scope for request-lifetime providers.
30    pub fn request_scope(&self) -> RequestScope<'_> {
31        RequestScope::borrowed(self)
32    }
33
34    /// Registers a concrete singleton value.
35    pub fn register_singleton<T>(&mut self, value: T) -> Result<()>
36    where
37        T: Send + Sync + 'static,
38    {
39        let value = Arc::new(value);
40        self.insert::<T>(ProviderLifetime::Singleton, move |_container| {
41            Ok(Arc::clone(&value) as Arc<dyn Any + Send + Sync>)
42        })
43    }
44
45    /// Replaces a singleton provider, intended for explicit test overrides.
46    pub fn override_singleton<T>(&mut self, value: T) -> Result<()>
47    where
48        T: Send + Sync + 'static,
49    {
50        self.providers.remove(&TypeId::of::<T>());
51        self.register_singleton(value)
52    }
53
54    /// Registers a provider factory.
55    pub fn register_factory<T, F>(&mut self, lifetime: ProviderLifetime, factory: F) -> Result<()>
56    where
57        T: Send + Sync + 'static,
58        F: Fn(&Container) -> Result<T> + Send + Sync + 'static,
59    {
60        self.insert::<T>(lifetime, move |container| {
61            factory(container).map(|value| Arc::new(value) as Arc<dyn Any + Send + Sync>)
62        })
63    }
64
65    /// Registers a singleton provider factory.
66    pub fn register_singleton_factory<T, F>(&mut self, factory: F) -> Result<()>
67    where
68        T: Send + Sync + 'static,
69        F: Fn(&Container) -> Result<T> + Send + Sync + 'static,
70    {
71        self.register_factory::<T, F>(ProviderLifetime::Singleton, factory)
72    }
73
74    /// Registers a transient provider factory.
75    pub fn register_transient<T, F>(&mut self, factory: F) -> Result<()>
76    where
77        T: Send + Sync + 'static,
78        F: Fn(&Container) -> Result<T> + Send + Sync + 'static,
79    {
80        self.register_factory::<T, F>(ProviderLifetime::Transient, factory)
81    }
82
83    /// Registers a request-lifetime provider factory.
84    pub fn register_request<T, F>(&mut self, factory: F) -> Result<()>
85    where
86        T: Send + Sync + 'static,
87        F: Fn(&Container) -> Result<T> + Send + Sync + 'static,
88    {
89        self.register_factory::<T, F>(ProviderLifetime::Request, factory)
90    }
91
92    /// Registers a request-lifetime provider factory that resolves dependencies
93    /// through the active request scope.
94    pub fn register_request_scoped<T, F>(&mut self, factory: F) -> Result<()>
95    where
96        T: Send + Sync + 'static,
97        F: for<'scope> Fn(&RequestScope<'scope>) -> Result<T> + Send + Sync + 'static,
98    {
99        self.insert_request_scoped::<T>(
100            |_container| {
101                Err(NidusError::RequestScopeRequired {
102                    type_name: type_name::<T>(),
103                })
104            },
105            move |scope| factory(scope).map(|value| Arc::new(value) as Arc<dyn Any + Send + Sync>),
106        )
107    }
108
109    /// Resolves a typed dependency reference.
110    pub fn inject<T>(&self) -> Result<Inject<T>>
111    where
112        T: Send + Sync + 'static,
113    {
114        self.resolve::<T>().map(Inject::new)
115    }
116
117    /// Resolves an optional typed dependency reference.
118    ///
119    /// Missing providers become `Optional::new(None)`, while registered providers
120    /// that fail to construct still return their original error.
121    pub fn optional<T>(&self) -> Result<Optional<T>>
122    where
123        T: Send + Sync + 'static,
124    {
125        match self.inject::<T>() {
126            Ok(value) => Ok(Optional::new(Some(value))),
127            Err(NidusError::MissingProvider { .. }) => Ok(Optional::new(None)),
128            Err(error) => Err(error),
129        }
130    }
131
132    /// Resolves a shared typed dependency.
133    pub fn resolve<T>(&self) -> Result<Arc<T>>
134    where
135        T: Send + Sync + 'static,
136    {
137        let entry = self.entry::<T>()?;
138        if entry.lifetime() == ProviderLifetime::Request {
139            return Err(NidusError::RequestScopeRequired {
140                type_name: type_name::<T>(),
141            });
142        }
143        let erased = entry.resolve_erased(self)?;
144        downcast::<T>(erased)
145    }
146
147    /// Eagerly constructs every registered singleton provider and caches it.
148    ///
149    /// Singletons are otherwise constructed lazily on first resolution, which
150    /// uses a blocking `Condvar` wait when two callers race to construct the
151    /// same provider. Calling this at startup pre-constructs each singleton so
152    /// later resolutions (including from async request handlers) hit the cached
153    /// value and never reach that wait, avoiding an async-runtime worker
154    /// stalling on first use. Transient and request-lifetime providers are
155    /// skipped.
156    ///
157    /// A singleton whose factory errors or panics will do so here, failing
158    /// startup fast instead of on first request.
159    pub fn eagerly_resolve_singletons(&self) -> Result<()> {
160        for entry in self.providers.values() {
161            if entry.lifetime() == ProviderLifetime::Singleton {
162                entry.resolve_erased(self)?;
163            }
164        }
165        Ok(())
166    }
167
168    fn insert<T>(
169        &mut self,
170        lifetime: ProviderLifetime,
171        factory: impl Fn(&Container) -> Result<Arc<dyn Any + Send + Sync>> + Send + Sync + 'static,
172    ) -> Result<()>
173    where
174        T: Send + Sync + 'static,
175    {
176        let type_id = TypeId::of::<T>();
177        if self.providers.contains_key(&type_id) {
178            return Err(NidusError::DuplicateProvider {
179                type_name: type_name::<T>(),
180            });
181        }
182
183        self.providers.insert(
184            type_id,
185            ProviderEntry::new(type_id, type_name::<T>(), lifetime, Arc::new(factory)),
186        );
187        Ok(())
188    }
189
190    fn insert_request_scoped<T>(
191        &mut self,
192        factory: impl Fn(&Container) -> Result<Arc<dyn Any + Send + Sync>> + Send + Sync + 'static,
193        request_factory: impl for<'scope> Fn(
194            &RequestScope<'scope>,
195        ) -> Result<Arc<dyn Any + Send + Sync>>
196        + Send
197        + Sync
198        + 'static,
199    ) -> Result<()>
200    where
201        T: Send + Sync + 'static,
202    {
203        let type_id = TypeId::of::<T>();
204        if self.providers.contains_key(&type_id) {
205            return Err(NidusError::DuplicateProvider {
206                type_name: type_name::<T>(),
207            });
208        }
209
210        self.providers.insert(
211            type_id,
212            ProviderEntry::new_request_scoped(
213                type_id,
214                type_name::<T>(),
215                Arc::new(factory),
216                Arc::new(request_factory),
217            ),
218        );
219        Ok(())
220    }
221
222    fn entry<T>(&self) -> Result<&ProviderEntry>
223    where
224        T: Send + Sync + 'static,
225    {
226        self.providers
227            .get(&TypeId::of::<T>())
228            .ok_or_else(|| NidusError::MissingProvider {
229                type_name: type_name::<T>(),
230            })
231    }
232}
233
234fn downcast<T>(erased: Arc<dyn Any + Send + Sync>) -> Result<Arc<T>>
235where
236    T: Send + Sync + 'static,
237{
238    erased
239        .downcast::<T>()
240        .map_err(|_| NidusError::MissingProvider {
241            type_name: type_name::<T>(),
242        })
243}