nidus_core/container/
mod.rs1mod dependency;
4mod request_scope;
5
6use std::{
7 any::{Any, TypeId, type_name},
8 collections::HashMap,
9 sync::Arc,
10};
11
12use crate::{NidusError, ProviderEntry, ProviderLifetime, Result};
13
14pub use dependency::{Factory, Inject, Lazy, Optional, Scoped};
15pub use request_scope::{RequestScope, SharedRequestScope};
16
17#[derive(Default)]
19pub struct Container {
20 providers: HashMap<TypeId, ProviderEntry>,
21}
22
23impl Container {
24 pub fn new() -> Self {
26 Self::default()
27 }
28
29 pub fn request_scope(&self) -> RequestScope<'_> {
31 RequestScope::borrowed(self)
32 }
33
34 pub fn register_singleton<T>(&mut self, value: T) -> Result<()>
36 where
37 T: Send + Sync + 'static,
38 {
39 let value = Arc::new(value);
40 self.insert::<T>(ProviderLifetime::Singleton, move |_container| {
41 Ok(Arc::clone(&value) as Arc<dyn Any + Send + Sync>)
42 })
43 }
44
45 pub fn override_singleton<T>(&mut self, value: T) -> Result<()>
47 where
48 T: Send + Sync + 'static,
49 {
50 self.providers.remove(&TypeId::of::<T>());
51 self.register_singleton(value)
52 }
53
54 pub fn register_factory<T, F>(&mut self, lifetime: ProviderLifetime, factory: F) -> Result<()>
56 where
57 T: Send + Sync + 'static,
58 F: Fn(&Container) -> Result<T> + Send + Sync + 'static,
59 {
60 self.insert::<T>(lifetime, move |container| {
61 factory(container).map(|value| Arc::new(value) as Arc<dyn Any + Send + Sync>)
62 })
63 }
64
65 pub fn register_singleton_factory<T, F>(&mut self, factory: F) -> Result<()>
67 where
68 T: Send + Sync + 'static,
69 F: Fn(&Container) -> Result<T> + Send + Sync + 'static,
70 {
71 self.register_factory::<T, F>(ProviderLifetime::Singleton, factory)
72 }
73
74 pub fn register_transient<T, F>(&mut self, factory: F) -> Result<()>
76 where
77 T: Send + Sync + 'static,
78 F: Fn(&Container) -> Result<T> + Send + Sync + 'static,
79 {
80 self.register_factory::<T, F>(ProviderLifetime::Transient, factory)
81 }
82
83 pub fn register_request<T, F>(&mut self, factory: F) -> Result<()>
85 where
86 T: Send + Sync + 'static,
87 F: Fn(&Container) -> Result<T> + Send + Sync + 'static,
88 {
89 self.register_factory::<T, F>(ProviderLifetime::Request, factory)
90 }
91
92 pub fn register_request_scoped<T, F>(&mut self, factory: F) -> Result<()>
95 where
96 T: Send + Sync + 'static,
97 F: for<'scope> Fn(&RequestScope<'scope>) -> Result<T> + Send + Sync + 'static,
98 {
99 self.insert_request_scoped::<T>(
100 |_container| {
101 Err(NidusError::RequestScopeRequired {
102 type_name: type_name::<T>(),
103 })
104 },
105 move |scope| factory(scope).map(|value| Arc::new(value) as Arc<dyn Any + Send + Sync>),
106 )
107 }
108
109 pub fn inject<T>(&self) -> Result<Inject<T>>
111 where
112 T: Send + Sync + 'static,
113 {
114 self.resolve::<T>().map(Inject::new)
115 }
116
117 pub fn optional<T>(&self) -> Result<Optional<T>>
122 where
123 T: Send + Sync + 'static,
124 {
125 match self.inject::<T>() {
126 Ok(value) => Ok(Optional::new(Some(value))),
127 Err(NidusError::MissingProvider { .. }) => Ok(Optional::new(None)),
128 Err(error) => Err(error),
129 }
130 }
131
132 pub fn resolve<T>(&self) -> Result<Arc<T>>
134 where
135 T: Send + Sync + 'static,
136 {
137 let entry = self.entry::<T>()?;
138 if entry.lifetime() == ProviderLifetime::Request {
139 return Err(NidusError::RequestScopeRequired {
140 type_name: type_name::<T>(),
141 });
142 }
143 let erased = entry.resolve_erased(self)?;
144 downcast::<T>(erased)
145 }
146
147 pub fn eagerly_resolve_singletons(&self) -> Result<()> {
160 for entry in self.providers.values() {
161 if entry.lifetime() == ProviderLifetime::Singleton {
162 entry.resolve_erased(self)?;
163 }
164 }
165 Ok(())
166 }
167
168 fn insert<T>(
169 &mut self,
170 lifetime: ProviderLifetime,
171 factory: impl Fn(&Container) -> Result<Arc<dyn Any + Send + Sync>> + Send + Sync + 'static,
172 ) -> Result<()>
173 where
174 T: Send + Sync + 'static,
175 {
176 let type_id = TypeId::of::<T>();
177 if self.providers.contains_key(&type_id) {
178 return Err(NidusError::DuplicateProvider {
179 type_name: type_name::<T>(),
180 });
181 }
182
183 self.providers.insert(
184 type_id,
185 ProviderEntry::new(type_id, type_name::<T>(), lifetime, Arc::new(factory)),
186 );
187 Ok(())
188 }
189
190 fn insert_request_scoped<T>(
191 &mut self,
192 factory: impl Fn(&Container) -> Result<Arc<dyn Any + Send + Sync>> + Send + Sync + 'static,
193 request_factory: impl for<'scope> Fn(
194 &RequestScope<'scope>,
195 ) -> Result<Arc<dyn Any + Send + Sync>>
196 + Send
197 + Sync
198 + 'static,
199 ) -> Result<()>
200 where
201 T: Send + Sync + 'static,
202 {
203 let type_id = TypeId::of::<T>();
204 if self.providers.contains_key(&type_id) {
205 return Err(NidusError::DuplicateProvider {
206 type_name: type_name::<T>(),
207 });
208 }
209
210 self.providers.insert(
211 type_id,
212 ProviderEntry::new_request_scoped(
213 type_id,
214 type_name::<T>(),
215 Arc::new(factory),
216 Arc::new(request_factory),
217 ),
218 );
219 Ok(())
220 }
221
222 fn entry<T>(&self) -> Result<&ProviderEntry>
223 where
224 T: Send + Sync + 'static,
225 {
226 self.providers
227 .get(&TypeId::of::<T>())
228 .ok_or_else(|| NidusError::MissingProvider {
229 type_name: type_name::<T>(),
230 })
231 }
232}
233
234fn downcast<T>(erased: Arc<dyn Any + Send + Sync>) -> Result<Arc<T>>
235where
236 T: Send + Sync + 'static,
237{
238 erased
239 .downcast::<T>()
240 .map_err(|_| NidusError::MissingProvider {
241 type_name: type_name::<T>(),
242 })
243}