1use std::{
2 any::{Any, TypeId},
3 collections::{HashMap, HashSet},
4 sync::{Arc, RwLock},
5};
6
7use thiserror::Error;
8
9#[derive(Clone, Default)]
20pub struct Container {
21 inner: Arc<RwLock<HashMap<TypeId, Arc<dyn Any + Send + Sync>>>>,
26
27 overrides: Arc<RwLock<HashMap<TypeId, Arc<dyn Any + Send + Sync>>>>,
32
33 request_factories: Arc<RwLock<HashMap<TypeId, Arc<RequestFactoryFn>>>>,
34 transient_factories: Arc<RwLock<HashMap<TypeId, Arc<TransientFactoryFn>>>>,
35
36 names: Arc<RwLock<HashSet<&'static str>>>,
41}
42
43type RequestFactoryValue = Arc<dyn Any + Send + Sync>;
44type RequestFactoryFn =
45 dyn Fn(&Container) -> anyhow::Result<RequestFactoryValue> + Send + Sync + 'static;
46type TransientFactoryFn =
47 dyn Fn(&Container) -> anyhow::Result<RequestFactoryValue> + Send + Sync + 'static;
48
49#[derive(Debug, Error)]
50pub enum ContainerError {
51 #[error("Container write lock poisoned")]
52 WriteLockPoisoned,
53 #[error("Container read lock poisoned")]
54 ReadLockPoisoned,
55 #[error("Type already registered: {type_name}")]
56 TypeAlreadyRegistered { type_name: &'static str },
57 #[error("Type not registered: {type_name}")]
58 TypeNotRegistered { type_name: &'static str },
59 #[error("Failed to downcast resolved value: {type_name}")]
60 DowncastFailed { type_name: &'static str },
61 #[error("Request-scoped factory failed for {type_name}: {message}")]
62 RequestFactoryFailed {
63 type_name: &'static str,
64 message: String,
65 },
66 #[error("Type not registered: {type_name} (required by module `{module_name}`)")]
67 TypeNotRegisteredInModule {
68 type_name: &'static str,
69 module_name: &'static str,
70 },
71}
72
73impl Container {
74 pub fn new() -> Self {
78 Self::default()
79 }
80
81 pub fn scoped(&self) -> Self {
89 Self {
90 inner: Arc::clone(&self.inner),
91 overrides: Arc::new(RwLock::new(HashMap::new())),
92 request_factories: Arc::clone(&self.request_factories),
93 transient_factories: Arc::clone(&self.transient_factories),
94 names: Arc::clone(&self.names),
95 }
96 }
97
98 pub fn register<T>(&self, value: T) -> Result<(), ContainerError>
107 where
108 T: Send + Sync + 'static,
109 {
110 let mut map = self
111 .inner
112 .write()
113 .map_err(|_| ContainerError::WriteLockPoisoned)?;
114
115 let type_id = TypeId::of::<T>();
116
117 if map.contains_key(&type_id) {
118 return Err(ContainerError::TypeAlreadyRegistered {
119 type_name: std::any::type_name::<T>(),
120 });
121 }
122
123 map.insert(type_id, Arc::new(value));
124 self.names
125 .write()
126 .map_err(|_| ContainerError::WriteLockPoisoned)?
127 .insert(std::any::type_name::<T>());
128 Ok(())
129 }
130
131 pub fn replace<T>(&self, value: T) -> Result<(), ContainerError>
136 where
137 T: Send + Sync + 'static,
138 {
139 let mut map = self
140 .inner
141 .write()
142 .map_err(|_| ContainerError::WriteLockPoisoned)?;
143
144 map.insert(TypeId::of::<T>(), Arc::new(value));
145 self.names
146 .write()
147 .map_err(|_| ContainerError::WriteLockPoisoned)?
148 .insert(std::any::type_name::<T>());
149 Ok(())
150 }
151
152 pub fn override_value<T>(&self, value: T) -> Result<(), ContainerError>
159 where
160 T: Send + Sync + 'static,
161 {
162 let mut overrides = self
163 .overrides
164 .write()
165 .map_err(|_| ContainerError::WriteLockPoisoned)?;
166
167 overrides.insert(TypeId::of::<T>(), Arc::new(value));
168 self.names
169 .write()
170 .map_err(|_| ContainerError::WriteLockPoisoned)?
171 .insert(std::any::type_name::<T>());
172 Ok(())
173 }
174
175 pub fn is_type_registered_name(&self, type_name: &'static str) -> Result<bool, ContainerError> {
179 let names = self
180 .names
181 .read()
182 .map_err(|_| ContainerError::ReadLockPoisoned)?;
183 Ok(names.contains(type_name))
184 }
185
186 pub fn register_request_factory<T, F>(&self, factory: F) -> Result<(), ContainerError>
191 where
192 T: Send + Sync + 'static,
193 F: Fn(&Container) -> anyhow::Result<T> + Send + Sync + 'static,
194 {
195 let type_id = TypeId::of::<T>();
196 let mut factories = self
197 .request_factories
198 .write()
199 .map_err(|_| ContainerError::WriteLockPoisoned)?;
200
201 if factories.contains_key(&type_id) {
202 return Err(ContainerError::TypeAlreadyRegistered {
203 type_name: std::any::type_name::<T>(),
204 });
205 }
206
207 factories.insert(
208 type_id,
209 Arc::new(move |container| Ok(Arc::new(factory(container)?) as RequestFactoryValue)),
210 );
211 self.names
212 .write()
213 .map_err(|_| ContainerError::WriteLockPoisoned)?
214 .insert(std::any::type_name::<T>());
215 Ok(())
216 }
217
218 pub fn register_transient_factory<T, F>(&self, factory: F) -> Result<(), ContainerError>
223 where
224 T: Send + Sync + 'static,
225 F: Fn(&Container) -> anyhow::Result<T> + Send + Sync + 'static,
226 {
227 let type_id = TypeId::of::<T>();
228 let mut factories = self
229 .transient_factories
230 .write()
231 .map_err(|_| ContainerError::WriteLockPoisoned)?;
232
233 if factories.contains_key(&type_id) {
234 return Err(ContainerError::TypeAlreadyRegistered {
235 type_name: std::any::type_name::<T>(),
236 });
237 }
238
239 factories.insert(
240 type_id,
241 Arc::new(move |container| Ok(Arc::new(factory(container)?) as RequestFactoryValue)),
242 );
243 self.names
244 .write()
245 .map_err(|_| ContainerError::WriteLockPoisoned)?
246 .insert(std::any::type_name::<T>());
247 Ok(())
248 }
249
250 pub fn resolve<T>(&self) -> Result<Arc<T>, ContainerError>
260 where
261 T: Send + Sync + 'static,
262 {
263 if let Some(value) = self.resolve_from_map::<T>(&self.overrides)? {
268 return Ok(value);
269 }
270
271 if let Some(value) = self.resolve_from_map::<T>(&self.inner)? {
276 return Ok(value);
277 }
278
279 if let Some(value) = self.resolve_from_request_factory::<T>()? {
284 return Ok(value);
285 }
286
287 if let Some(value) = self.resolve_from_transient_factory::<T>()? {
292 return Ok(value);
293 }
294
295 Err(ContainerError::TypeNotRegistered {
296 type_name: std::any::type_name::<T>(),
297 })
298 }
299
300 pub fn resolve_in_module<T>(&self, module_name: &'static str) -> Result<Arc<T>, ContainerError>
301 where
302 T: Send + Sync + 'static,
303 {
304 self.resolve::<T>().map_err(|err| match err {
305 ContainerError::TypeNotRegistered { type_name } => {
306 ContainerError::TypeNotRegisteredInModule {
307 type_name,
308 module_name,
309 }
310 }
311 other => other,
312 })
313 }
314
315 fn resolve_from_map<T>(
316 &self,
317 map: &Arc<RwLock<HashMap<TypeId, Arc<dyn Any + Send + Sync>>>>,
318 ) -> Result<Option<Arc<T>>, ContainerError>
319 where
320 T: Send + Sync + 'static,
321 {
322 let map = map.read().map_err(|_| ContainerError::ReadLockPoisoned)?;
323 let Some(value) = map.get(&TypeId::of::<T>()).cloned() else {
324 return Ok(None);
325 };
326
327 let value = value
328 .downcast::<T>()
329 .map_err(|_| ContainerError::DowncastFailed {
330 type_name: std::any::type_name::<T>(),
331 })?;
332
333 Ok(Some(value))
334 }
335
336 fn resolve_from_request_factory<T>(&self) -> Result<Option<Arc<T>>, ContainerError>
337 where
338 T: Send + Sync + 'static,
339 {
340 let factory = {
341 let factories = self
342 .request_factories
343 .read()
344 .map_err(|_| ContainerError::ReadLockPoisoned)?;
345 factories.get(&TypeId::of::<T>()).cloned()
346 };
347
348 let Some(factory) = factory else {
349 return Ok(None);
350 };
351
352 let value = factory(self).map_err(|err| ContainerError::RequestFactoryFailed {
353 type_name: std::any::type_name::<T>(),
354 message: err.to_string(),
355 })?;
356 let typed = value
357 .downcast::<T>()
358 .map_err(|_| ContainerError::DowncastFailed {
359 type_name: std::any::type_name::<T>(),
360 })?;
361
362 self.overrides
363 .write()
364 .map_err(|_| ContainerError::WriteLockPoisoned)?
365 .insert(TypeId::of::<T>(), typed.clone() as RequestFactoryValue);
366
367 Ok(Some(typed))
368 }
369
370 fn resolve_from_transient_factory<T>(&self) -> Result<Option<Arc<T>>, ContainerError>
371 where
372 T: Send + Sync + 'static,
373 {
374 let factory = {
375 let factories = self
376 .transient_factories
377 .read()
378 .map_err(|_| ContainerError::ReadLockPoisoned)?;
379 factories.get(&TypeId::of::<T>()).cloned()
380 };
381
382 let Some(factory) = factory else {
383 return Ok(None);
384 };
385
386 let value = factory(self).map_err(|err| ContainerError::RequestFactoryFailed {
387 type_name: std::any::type_name::<T>(),
388 message: err.to_string(),
389 })?;
390 let typed = value
391 .downcast::<T>()
392 .map_err(|_| ContainerError::DowncastFailed {
393 type_name: std::any::type_name::<T>(),
394 })?;
395
396 Ok(Some(typed))
397 }
398}
399
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicUsize, Ordering};

    use super::*;

    #[derive(Debug, PartialEq, Eq)]
    struct AppConfig {
        app_name: &'static str,
    }

    #[derive(Clone)]
    struct RequestId(String);

    struct RequestGreeting(String);
    struct TransientCounter(usize);

    #[test]
    fn override_value_takes_precedence_over_registered_value() {
        let container = Container::new();

        container
            .register(AppConfig {
                app_name: "default",
            })
            .expect("register should succeed");
        container
            .override_value(AppConfig { app_name: "test" })
            .expect("override should succeed");

        let config = container
            .resolve::<AppConfig>()
            .expect("config should resolve");
        assert_eq!(config.app_name, "test");
    }

    #[test]
    fn scoped_container_resolves_request_factory_without_leaking_to_parent() {
        let container = Container::new();
        container
            .register_request_factory::<RequestGreeting, _>(|scoped| {
                let request_id = scoped.resolve::<RequestId>()?;
                Ok(RequestGreeting(format!("hello {}", request_id.0)))
            })
            .expect("request factory should register");

        let scoped = container.scoped();
        scoped
            .override_value(RequestId("req-1".to_string()))
            .expect("request id should override");

        let greeting = scoped
            .resolve::<RequestGreeting>()
            .expect("request greeting should resolve");

        assert_eq!(greeting.0, "hello req-1");
        assert!(container.resolve::<RequestGreeting>().is_err());
    }

    #[test]
    fn transient_factory_creates_new_instances_per_resolve() {
        let container = Container::new();
        let counter = Arc::new(RwLock::new(0usize));
        let counter_for_factory = Arc::clone(&counter);

        container
            .register_transient_factory::<TransientCounter, _>(move |_| {
                let mut count = counter_for_factory
                    .write()
                    .expect("counter should be writable");
                *count += 1;
                Ok(TransientCounter(*count))
            })
            .expect("transient factory should register");

        let first = container
            .resolve::<TransientCounter>()
            .expect("first transient should resolve");
        let second = container
            .resolve::<TransientCounter>()
            .expect("second transient should resolve");

        assert_eq!(first.0, 1);
        assert_eq!(second.0, 2);
    }

    #[test]
    fn register_rejects_duplicate_type_and_keeps_first_value() {
        let container = Container::new();
        container
            .register(AppConfig { app_name: "first" })
            .expect("first register should succeed");

        let err = container
            .register(AppConfig { app_name: "second" })
            .expect_err("duplicate register should fail");
        assert!(matches!(
            err,
            ContainerError::TypeAlreadyRegistered { type_name }
                if type_name == std::any::type_name::<AppConfig>()
        ));

        let config = container
            .resolve::<AppConfig>()
            .expect("config should resolve");
        assert_eq!(config.app_name, "first");
    }

    #[test]
    fn replace_overwrites_registered_value() {
        let container = Container::new();
        container
            .register(AppConfig {
                app_name: "default",
            })
            .expect("register should succeed");
        container
            .replace(AppConfig {
                app_name: "replaced",
            })
            .expect("replace should succeed");

        let config = container
            .resolve::<AppConfig>()
            .expect("config should resolve");
        assert_eq!(config.app_name, "replaced");
    }

    #[test]
    fn request_factory_result_is_cached_per_scope() {
        let calls = Arc::new(AtomicUsize::new(0));
        let calls_for_factory = Arc::clone(&calls);
        let container = Container::new();
        container
            .register_request_factory::<RequestGreeting, _>(move |_| {
                let call = calls_for_factory.fetch_add(1, Ordering::SeqCst) + 1;
                Ok(RequestGreeting(format!("call {call}")))
            })
            .expect("request factory should register");

        let scope_a = container.scoped();
        let first = scope_a
            .resolve::<RequestGreeting>()
            .expect("first resolve should succeed");
        let second = scope_a
            .resolve::<RequestGreeting>()
            .expect("second resolve should succeed");
        assert!(Arc::ptr_eq(&first, &second));
        assert_eq!(first.0, "call 1");
        assert_eq!(calls.load(Ordering::SeqCst), 1);

        let scope_b = container.scoped();
        let third = scope_b
            .resolve::<RequestGreeting>()
            .expect("resolve in a new scope should succeed");
        assert!(!Arc::ptr_eq(&first, &third));
        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }

    #[test]
    fn failing_factory_error_names_type_and_cause() {
        let container = Container::new();
        container
            .register_transient_factory::<TransientCounter, _>(|_| {
                Err(anyhow::anyhow!("database unavailable"))
            })
            .expect("transient factory should register");

        let err = container
            .resolve::<TransientCounter>()
            .err()
            .expect("failing factory should surface an error");
        match err {
            ContainerError::RequestFactoryFailed { type_name, message } => {
                assert_eq!(type_name, std::any::type_name::<TransientCounter>());
                assert!(message.contains("database unavailable"));
            }
            other => panic!("unexpected error: {other}"),
        }
    }

    #[test]
    fn resolve_in_module_reports_requiring_module() {
        let container = Container::new();

        let err = container
            .resolve_in_module::<AppConfig>("billing")
            .expect_err("unregistered type should fail");
        assert!(matches!(
            err,
            ContainerError::TypeNotRegisteredInModule {
                module_name: "billing",
                ..
            }
        ));
        assert!(err
            .to_string()
            .contains("(required by module `billing`)"));
    }

    #[test]
    fn registered_types_are_reported_by_name() {
        let container = Container::new();
        let name = std::any::type_name::<AppConfig>();

        assert!(!container
            .is_type_registered_name(name)
            .expect("names should be readable"));

        container
            .register(AppConfig {
                app_name: "default",
            })
            .expect("register should succeed");

        assert!(container
            .is_type_registered_name(name)
            .expect("names should be readable"));
    }
}