use std::{
    any::{Any, TypeId},
    collections::{hash_map::Entry, HashMap, HashSet},
    sync::{Arc, RwLock, RwLockWriteGuard},
};

use thiserror::Error;

9#[derive(Clone, Default)]
27pub struct Container {
28 inner: Arc<RwLock<HashMap<TypeId, Arc<dyn Any + Send + Sync>>>>,
29 overrides: Arc<RwLock<HashMap<TypeId, Arc<dyn Any + Send + Sync>>>>,
30 request_factories: Arc<RwLock<HashMap<TypeId, Arc<RequestFactoryFn>>>>,
31 transient_factories: Arc<RwLock<HashMap<TypeId, Arc<TransientFactoryFn>>>>,
32 names: Arc<RwLock<HashSet<&'static str>>>,
33}
34
35type RequestFactoryValue = Arc<dyn Any + Send + Sync>;
36type RequestFactoryFn =
37 dyn Fn(&Container) -> anyhow::Result<RequestFactoryValue> + Send + Sync + 'static;
38type TransientFactoryFn =
39 dyn Fn(&Container) -> anyhow::Result<RequestFactoryValue> + Send + Sync + 'static;
40
41#[derive(Debug, Error)]
42pub enum ContainerError {
43 #[error("Container write lock poisoned")]
44 WriteLockPoisoned,
45 #[error("Container read lock poisoned")]
46 ReadLockPoisoned,
47 #[error("Type already registered: {type_name}")]
48 TypeAlreadyRegistered { type_name: &'static str },
49 #[error("Type not registered: {type_name}")]
50 TypeNotRegistered { type_name: &'static str },
51 #[error("Failed to downcast resolved value: {type_name}")]
52 DowncastFailed { type_name: &'static str },
53 #[error("Request-scoped factory failed for {type_name}: {message}")]
54 RequestFactoryFailed {
55 type_name: &'static str,
56 message: String,
57 },
58 #[error("Type not registered: {type_name} (required by module `{module_name}`)")]
59 TypeNotRegisteredInModule {
60 type_name: &'static str,
61 module_name: &'static str,
62 },
63}
64
65impl Container {
66 pub fn new() -> Self {
71 Self::default()
72 }
73
74 pub fn scoped(&self) -> Self {
75 Self {
76 inner: Arc::clone(&self.inner),
77 overrides: Arc::new(RwLock::new(HashMap::new())),
78 request_factories: Arc::clone(&self.request_factories),
79 transient_factories: Arc::clone(&self.transient_factories),
80 names: Arc::clone(&self.names),
81 }
82 }
83
84 pub fn register<T>(&self, value: T) -> Result<(), ContainerError>
95 where
96 T: Send + Sync + 'static,
97 {
98 let mut map = self
99 .inner
100 .write()
101 .map_err(|_| ContainerError::WriteLockPoisoned)?;
102
103 let type_id = TypeId::of::<T>();
104
105 if map.contains_key(&type_id) {
106 return Err(ContainerError::TypeAlreadyRegistered {
107 type_name: std::any::type_name::<T>(),
108 });
109 }
110
111 map.insert(type_id, Arc::new(value));
112 self.names
113 .write()
114 .map_err(|_| ContainerError::WriteLockPoisoned)?
115 .insert(std::any::type_name::<T>());
116 Ok(())
117 }
118
119 pub fn replace<T>(&self, value: T) -> Result<(), ContainerError>
120 where
121 T: Send + Sync + 'static,
122 {
123 let mut map = self
124 .inner
125 .write()
126 .map_err(|_| ContainerError::WriteLockPoisoned)?;
127
128 map.insert(TypeId::of::<T>(), Arc::new(value));
129 self.names
130 .write()
131 .map_err(|_| ContainerError::WriteLockPoisoned)?
132 .insert(std::any::type_name::<T>());
133 Ok(())
134 }
135
136 pub fn override_value<T>(&self, value: T) -> Result<(), ContainerError>
137 where
138 T: Send + Sync + 'static,
139 {
140 let mut overrides = self
141 .overrides
142 .write()
143 .map_err(|_| ContainerError::WriteLockPoisoned)?;
144
145 overrides.insert(TypeId::of::<T>(), Arc::new(value));
146 self.names
147 .write()
148 .map_err(|_| ContainerError::WriteLockPoisoned)?
149 .insert(std::any::type_name::<T>());
150 Ok(())
151 }
152
153 pub fn is_type_registered_name(&self, type_name: &'static str) -> Result<bool, ContainerError> {
154 let names = self
155 .names
156 .read()
157 .map_err(|_| ContainerError::ReadLockPoisoned)?;
158 Ok(names.contains(type_name))
159 }
160
161 pub fn register_request_factory<T, F>(&self, factory: F) -> Result<(), ContainerError>
162 where
163 T: Send + Sync + 'static,
164 F: Fn(&Container) -> anyhow::Result<T> + Send + Sync + 'static,
165 {
166 let type_id = TypeId::of::<T>();
167 let mut factories = self
168 .request_factories
169 .write()
170 .map_err(|_| ContainerError::WriteLockPoisoned)?;
171
172 if factories.contains_key(&type_id) {
173 return Err(ContainerError::TypeAlreadyRegistered {
174 type_name: std::any::type_name::<T>(),
175 });
176 }
177
178 factories.insert(
179 type_id,
180 Arc::new(move |container| Ok(Arc::new(factory(container)?) as RequestFactoryValue)),
181 );
182 self.names
183 .write()
184 .map_err(|_| ContainerError::WriteLockPoisoned)?
185 .insert(std::any::type_name::<T>());
186 Ok(())
187 }
188
189 pub fn register_transient_factory<T, F>(&self, factory: F) -> Result<(), ContainerError>
190 where
191 T: Send + Sync + 'static,
192 F: Fn(&Container) -> anyhow::Result<T> + Send + Sync + 'static,
193 {
194 let type_id = TypeId::of::<T>();
195 let mut factories = self
196 .transient_factories
197 .write()
198 .map_err(|_| ContainerError::WriteLockPoisoned)?;
199
200 if factories.contains_key(&type_id) {
201 return Err(ContainerError::TypeAlreadyRegistered {
202 type_name: std::any::type_name::<T>(),
203 });
204 }
205
206 factories.insert(
207 type_id,
208 Arc::new(move |container| Ok(Arc::new(factory(container)?) as RequestFactoryValue)),
209 );
210 self.names
211 .write()
212 .map_err(|_| ContainerError::WriteLockPoisoned)?
213 .insert(std::any::type_name::<T>());
214 Ok(())
215 }
216
217 pub fn resolve<T>(&self) -> Result<Arc<T>, ContainerError>
226 where
227 T: Send + Sync + 'static,
228 {
229 if let Some(value) = self.resolve_from_map::<T>(&self.overrides)? {
230 return Ok(value);
231 }
232
233 if let Some(value) = self.resolve_from_map::<T>(&self.inner)? {
234 return Ok(value);
235 }
236
237 if let Some(value) = self.resolve_from_request_factory::<T>()? {
238 return Ok(value);
239 }
240
241 if let Some(value) = self.resolve_from_transient_factory::<T>()? {
242 return Ok(value);
243 }
244
245 Err(ContainerError::TypeNotRegistered {
246 type_name: std::any::type_name::<T>(),
247 })
248 }
249
250 pub fn resolve_in_module<T>(&self, module_name: &'static str) -> Result<Arc<T>, ContainerError>
251 where
252 T: Send + Sync + 'static,
253 {
254 self.resolve::<T>().map_err(|err| match err {
255 ContainerError::TypeNotRegistered { type_name } => {
256 ContainerError::TypeNotRegisteredInModule {
257 type_name,
258 module_name,
259 }
260 }
261 other => other,
262 })
263 }
264
265 fn resolve_from_map<T>(
266 &self,
267 map: &Arc<RwLock<HashMap<TypeId, Arc<dyn Any + Send + Sync>>>>,
268 ) -> Result<Option<Arc<T>>, ContainerError>
269 where
270 T: Send + Sync + 'static,
271 {
272 let map = map.read().map_err(|_| ContainerError::ReadLockPoisoned)?;
273 let Some(value) = map.get(&TypeId::of::<T>()).cloned() else {
274 return Ok(None);
275 };
276
277 let value = value
278 .downcast::<T>()
279 .map_err(|_| ContainerError::DowncastFailed {
280 type_name: std::any::type_name::<T>(),
281 })?;
282
283 Ok(Some(value))
284 }
285
286 fn resolve_from_request_factory<T>(&self) -> Result<Option<Arc<T>>, ContainerError>
287 where
288 T: Send + Sync + 'static,
289 {
290 let factory = {
291 let factories = self
292 .request_factories
293 .read()
294 .map_err(|_| ContainerError::ReadLockPoisoned)?;
295 factories.get(&TypeId::of::<T>()).cloned()
296 };
297
298 let Some(factory) = factory else {
299 return Ok(None);
300 };
301
302 let value = factory(self).map_err(|err| ContainerError::RequestFactoryFailed {
303 type_name: std::any::type_name::<T>(),
304 message: err.to_string(),
305 })?;
306 let typed = value
307 .downcast::<T>()
308 .map_err(|_| ContainerError::DowncastFailed {
309 type_name: std::any::type_name::<T>(),
310 })?;
311
312 self.overrides
313 .write()
314 .map_err(|_| ContainerError::WriteLockPoisoned)?
315 .insert(TypeId::of::<T>(), typed.clone() as RequestFactoryValue);
316
317 Ok(Some(typed))
318 }
319
320 fn resolve_from_transient_factory<T>(&self) -> Result<Option<Arc<T>>, ContainerError>
321 where
322 T: Send + Sync + 'static,
323 {
324 let factory = {
325 let factories = self
326 .transient_factories
327 .read()
328 .map_err(|_| ContainerError::ReadLockPoisoned)?;
329 factories.get(&TypeId::of::<T>()).cloned()
330 };
331
332 let Some(factory) = factory else {
333 return Ok(None);
334 };
335
336 let value = factory(self).map_err(|err| ContainerError::RequestFactoryFailed {
337 type_name: std::any::type_name::<T>(),
338 message: err.to_string(),
339 })?;
340 let typed = value
341 .downcast::<T>()
342 .map_err(|_| ContainerError::DowncastFailed {
343 type_name: std::any::type_name::<T>(),
344 })?;
345
346 Ok(Some(typed))
347 }
348}
349
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    #[derive(Debug, PartialEq, Eq)]
    struct AppConfig {
        app_name: &'static str,
    }

    #[derive(Clone)]
    struct RequestId(String);

    struct RequestGreeting(String);
    struct TransientCounter(usize);

    #[test]
    fn override_value_takes_precedence_over_registered_value() {
        let container = Container::new();

        container
            .register(AppConfig {
                app_name: "default",
            })
            .expect("register should succeed");
        container
            .override_value(AppConfig { app_name: "test" })
            .expect("override should succeed");

        let config = container
            .resolve::<AppConfig>()
            .expect("config should resolve");
        assert_eq!(config.app_name, "test");
    }

    #[test]
    fn register_rejects_duplicate_type_and_keeps_first_value() {
        let container = Container::new();
        container
            .register(AppConfig { app_name: "first" })
            .expect("first register should succeed");

        let err = container
            .register(AppConfig { app_name: "second" })
            .expect_err("duplicate register should fail");
        assert!(matches!(err, ContainerError::TypeAlreadyRegistered { .. }));

        let config = container
            .resolve::<AppConfig>()
            .expect("config should resolve");
        assert_eq!(config.app_name, "first");
    }

    #[test]
    fn replace_overwrites_registered_value() {
        let container = Container::new();
        container
            .register(AppConfig { app_name: "old" })
            .expect("register should succeed");
        container
            .replace(AppConfig { app_name: "new" })
            .expect("replace should succeed");

        let config = container
            .resolve::<AppConfig>()
            .expect("config should resolve");
        assert_eq!(config.app_name, "new");
    }

    #[test]
    fn resolve_unregistered_type_reports_its_name() {
        let err = Container::new()
            .resolve::<AppConfig>()
            .expect_err("unregistered type should fail");
        assert!(matches!(
            err,
            ContainerError::TypeNotRegistered { type_name }
                if type_name == std::any::type_name::<AppConfig>()
        ));
    }

    #[test]
    fn resolve_in_module_names_the_requiring_module() {
        let err = Container::new()
            .resolve_in_module::<AppConfig>("billing")
            .expect_err("unregistered type should fail");
        assert!(matches!(
            err,
            ContainerError::TypeNotRegisteredInModule {
                module_name: "billing",
                ..
            }
        ));
    }

    #[test]
    fn registration_records_type_name() {
        let container = Container::new();
        let name = std::any::type_name::<AppConfig>();

        assert!(!container
            .is_type_registered_name(name)
            .expect("name lookup should succeed"));
        container
            .register(AppConfig { app_name: "x" })
            .expect("register should succeed");
        assert!(container
            .is_type_registered_name(name)
            .expect("name lookup should succeed"));
    }

    #[test]
    fn scoped_container_resolves_request_factory_without_leaking_to_parent() {
        let container = Container::new();
        container
            .register_request_factory::<RequestGreeting, _>(|scoped| {
                let request_id = scoped.resolve::<RequestId>()?;
                Ok(RequestGreeting(format!("hello {}", request_id.0)))
            })
            .expect("request factory should register");

        let scoped = container.scoped();
        scoped
            .override_value(RequestId("req-1".to_string()))
            .expect("request id should override");

        let greeting = scoped
            .resolve::<RequestGreeting>()
            .expect("request greeting should resolve");

        assert_eq!(greeting.0, "hello req-1");
        assert!(container.resolve::<RequestGreeting>().is_err());
    }

    #[test]
    fn request_factory_result_is_cached_per_scope() {
        let container = Container::new();
        let calls = Arc::new(AtomicUsize::new(0));
        let calls_for_factory = Arc::clone(&calls);
        container
            .register_request_factory::<RequestGreeting, _>(move |_| {
                let call = calls_for_factory.fetch_add(1, Ordering::SeqCst) + 1;
                Ok(RequestGreeting(format!("hello {call}")))
            })
            .expect("request factory should register");

        let scope_a = container.scoped();
        let first = scope_a
            .resolve::<RequestGreeting>()
            .expect("first resolve should succeed");
        let again = scope_a
            .resolve::<RequestGreeting>()
            .expect("second resolve should succeed");
        assert!(Arc::ptr_eq(&first, &again));
        assert_eq!(calls.load(Ordering::SeqCst), 1);

        let scope_b = container.scoped();
        let other = scope_b
            .resolve::<RequestGreeting>()
            .expect("resolve in a new scope should succeed");
        assert!(!Arc::ptr_eq(&first, &other));
        assert_eq!(other.0, "hello 2");
        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }

    #[test]
    fn failing_request_factory_is_reported() {
        let container = Container::new();
        container
            .register_request_factory::<RequestGreeting, _>(|_| Err(anyhow::anyhow!("boom")))
            .expect("request factory should register");

        let err = container
            .scoped()
            .resolve::<RequestGreeting>()
            .err()
            .expect("failing factory should surface an error");
        assert!(matches!(
            err,
            ContainerError::RequestFactoryFailed { message, .. } if message.contains("boom")
        ));
    }

    #[test]
    fn transient_factory_creates_new_instances_per_resolve() {
        let container = Container::new();
        let created = Arc::new(AtomicUsize::new(0));
        let created_for_factory = Arc::clone(&created);

        container
            .register_transient_factory::<TransientCounter, _>(move |_| {
                Ok(TransientCounter(
                    created_for_factory.fetch_add(1, Ordering::SeqCst) + 1,
                ))
            })
            .expect("transient factory should register");

        let first = container
            .resolve::<TransientCounter>()
            .expect("first transient should resolve");
        let second = container
            .resolve::<TransientCounter>()
            .expect("second transient should resolve");

        assert_eq!(first.0, 1);
        assert_eq!(second.0, 2);
        assert!(!Arc::ptr_eq(&first, &second));
        assert_eq!(created.load(Ordering::SeqCst), 2);
    }
}