1use std::{
2 any::{Any, TypeId},
3 collections::{HashMap, HashSet},
4 sync::{Arc, RwLock},
5};
6
7use thiserror::Error;
8
9#[derive(Clone, Default)]
27pub struct Container {
28 inner: Arc<RwLock<HashMap<TypeId, Arc<dyn Any + Send + Sync>>>>,
29 overrides: Arc<RwLock<HashMap<TypeId, Arc<dyn Any + Send + Sync>>>>,
30 request_factories: Arc<
31 RwLock<HashMap<TypeId, Arc<RequestFactoryFn>>>,
32 >,
33 transient_factories: Arc<
34 RwLock<HashMap<TypeId, Arc<TransientFactoryFn>>>,
35 >,
36 names: Arc<RwLock<HashSet<&'static str>>>,
37}
38
39type RequestFactoryValue = Arc<dyn Any + Send + Sync>;
40type RequestFactoryFn =
41 dyn Fn(&Container) -> anyhow::Result<RequestFactoryValue> + Send + Sync + 'static;
42type TransientFactoryFn =
43 dyn Fn(&Container) -> anyhow::Result<RequestFactoryValue> + Send + Sync + 'static;
44
45#[derive(Debug, Error)]
46pub enum ContainerError {
47 #[error("Container write lock poisoned")]
48 WriteLockPoisoned,
49 #[error("Container read lock poisoned")]
50 ReadLockPoisoned,
51 #[error("Type already registered: {type_name}")]
52 TypeAlreadyRegistered { type_name: &'static str },
53 #[error("Type not registered: {type_name}")]
54 TypeNotRegistered { type_name: &'static str },
55 #[error("Failed to downcast resolved value: {type_name}")]
56 DowncastFailed { type_name: &'static str },
57 #[error("Request-scoped factory failed for {type_name}: {message}")]
58 RequestFactoryFailed {
59 type_name: &'static str,
60 message: String,
61 },
62 #[error("Type not registered: {type_name} (required by module `{module_name}`)")]
63 TypeNotRegisteredInModule {
64 type_name: &'static str,
65 module_name: &'static str,
66 },
67}
68
69impl Container {
70 pub fn new() -> Self {
75 Self::default()
76 }
77
78 pub fn scoped(&self) -> Self {
79 Self {
80 inner: Arc::clone(&self.inner),
81 overrides: Arc::new(RwLock::new(HashMap::new())),
82 request_factories: Arc::clone(&self.request_factories),
83 transient_factories: Arc::clone(&self.transient_factories),
84 names: Arc::clone(&self.names),
85 }
86 }
87
88 pub fn register<T>(&self, value: T) -> Result<(), ContainerError>
99 where
100 T: Send + Sync + 'static,
101 {
102 let mut map = self
103 .inner
104 .write()
105 .map_err(|_| ContainerError::WriteLockPoisoned)?;
106
107 let type_id = TypeId::of::<T>();
108
109 if map.contains_key(&type_id) {
110 return Err(ContainerError::TypeAlreadyRegistered {
111 type_name: std::any::type_name::<T>(),
112 });
113 }
114
115 map.insert(type_id, Arc::new(value));
116 self.names
117 .write()
118 .map_err(|_| ContainerError::WriteLockPoisoned)?
119 .insert(std::any::type_name::<T>());
120 Ok(())
121 }
122
123 pub fn replace<T>(&self, value: T) -> Result<(), ContainerError>
124 where
125 T: Send + Sync + 'static,
126 {
127 let mut map = self
128 .inner
129 .write()
130 .map_err(|_| ContainerError::WriteLockPoisoned)?;
131
132 map.insert(TypeId::of::<T>(), Arc::new(value));
133 self.names
134 .write()
135 .map_err(|_| ContainerError::WriteLockPoisoned)?
136 .insert(std::any::type_name::<T>());
137 Ok(())
138 }
139
140 pub fn override_value<T>(&self, value: T) -> Result<(), ContainerError>
141 where
142 T: Send + Sync + 'static,
143 {
144 let mut overrides = self
145 .overrides
146 .write()
147 .map_err(|_| ContainerError::WriteLockPoisoned)?;
148
149 overrides.insert(TypeId::of::<T>(), Arc::new(value));
150 self.names
151 .write()
152 .map_err(|_| ContainerError::WriteLockPoisoned)?
153 .insert(std::any::type_name::<T>());
154 Ok(())
155 }
156
157 pub fn is_type_registered_name(&self, type_name: &'static str) -> Result<bool, ContainerError> {
158 let names = self
159 .names
160 .read()
161 .map_err(|_| ContainerError::ReadLockPoisoned)?;
162 Ok(names.contains(type_name))
163 }
164
165 pub fn register_request_factory<T, F>(&self, factory: F) -> Result<(), ContainerError>
166 where
167 T: Send + Sync + 'static,
168 F: Fn(&Container) -> anyhow::Result<T> + Send + Sync + 'static,
169 {
170 let type_id = TypeId::of::<T>();
171 let mut factories = self
172 .request_factories
173 .write()
174 .map_err(|_| ContainerError::WriteLockPoisoned)?;
175
176 if factories.contains_key(&type_id) {
177 return Err(ContainerError::TypeAlreadyRegistered {
178 type_name: std::any::type_name::<T>(),
179 });
180 }
181
182 factories.insert(
183 type_id,
184 Arc::new(move |container| Ok(Arc::new(factory(container)?) as RequestFactoryValue)),
185 );
186 self.names
187 .write()
188 .map_err(|_| ContainerError::WriteLockPoisoned)?
189 .insert(std::any::type_name::<T>());
190 Ok(())
191 }
192
193 pub fn register_transient_factory<T, F>(&self, factory: F) -> Result<(), ContainerError>
194 where
195 T: Send + Sync + 'static,
196 F: Fn(&Container) -> anyhow::Result<T> + Send + Sync + 'static,
197 {
198 let type_id = TypeId::of::<T>();
199 let mut factories = self
200 .transient_factories
201 .write()
202 .map_err(|_| ContainerError::WriteLockPoisoned)?;
203
204 if factories.contains_key(&type_id) {
205 return Err(ContainerError::TypeAlreadyRegistered {
206 type_name: std::any::type_name::<T>(),
207 });
208 }
209
210 factories.insert(
211 type_id,
212 Arc::new(move |container| Ok(Arc::new(factory(container)?) as RequestFactoryValue)),
213 );
214 self.names
215 .write()
216 .map_err(|_| ContainerError::WriteLockPoisoned)?
217 .insert(std::any::type_name::<T>());
218 Ok(())
219 }
220
221 pub fn resolve<T>(&self) -> Result<Arc<T>, ContainerError>
230 where
231 T: Send + Sync + 'static,
232 {
233 if let Some(value) = self.resolve_from_map::<T>(&self.overrides)? {
234 return Ok(value);
235 }
236
237 if let Some(value) = self.resolve_from_map::<T>(&self.inner)? {
238 return Ok(value);
239 }
240
241 if let Some(value) = self.resolve_from_request_factory::<T>()? {
242 return Ok(value);
243 }
244
245 if let Some(value) = self.resolve_from_transient_factory::<T>()? {
246 return Ok(value);
247 }
248
249 Err(ContainerError::TypeNotRegistered {
250 type_name: std::any::type_name::<T>(),
251 })
252 }
253
254 pub fn resolve_in_module<T>(&self, module_name: &'static str) -> Result<Arc<T>, ContainerError>
255 where
256 T: Send + Sync + 'static,
257 {
258 self.resolve::<T>().map_err(|err| match err {
259 ContainerError::TypeNotRegistered { type_name } => {
260 ContainerError::TypeNotRegisteredInModule {
261 type_name,
262 module_name,
263 }
264 }
265 other => other,
266 })
267 }
268
269 fn resolve_from_map<T>(
270 &self,
271 map: &Arc<RwLock<HashMap<TypeId, Arc<dyn Any + Send + Sync>>>>,
272 ) -> Result<Option<Arc<T>>, ContainerError>
273 where
274 T: Send + Sync + 'static,
275 {
276 let map = map.read().map_err(|_| ContainerError::ReadLockPoisoned)?;
277 let Some(value) = map.get(&TypeId::of::<T>()).cloned() else {
278 return Ok(None);
279 };
280
281 let value = value.downcast::<T>().map_err(|_| ContainerError::DowncastFailed {
282 type_name: std::any::type_name::<T>(),
283 })?;
284
285 Ok(Some(value))
286 }
287
288 fn resolve_from_request_factory<T>(&self) -> Result<Option<Arc<T>>, ContainerError>
289 where
290 T: Send + Sync + 'static,
291 {
292 let factory = {
293 let factories = self
294 .request_factories
295 .read()
296 .map_err(|_| ContainerError::ReadLockPoisoned)?;
297 factories.get(&TypeId::of::<T>()).cloned()
298 };
299
300 let Some(factory) = factory else {
301 return Ok(None);
302 };
303
304 let value = factory(self).map_err(|err| ContainerError::RequestFactoryFailed {
305 type_name: std::any::type_name::<T>(),
306 message: err.to_string(),
307 })?;
308 let typed = value.downcast::<T>().map_err(|_| ContainerError::DowncastFailed {
309 type_name: std::any::type_name::<T>(),
310 })?;
311
312 self.overrides
313 .write()
314 .map_err(|_| ContainerError::WriteLockPoisoned)?
315 .insert(TypeId::of::<T>(), typed.clone() as RequestFactoryValue);
316
317 Ok(Some(typed))
318 }
319
320 fn resolve_from_transient_factory<T>(&self) -> Result<Option<Arc<T>>, ContainerError>
321 where
322 T: Send + Sync + 'static,
323 {
324 let factory = {
325 let factories = self
326 .transient_factories
327 .read()
328 .map_err(|_| ContainerError::ReadLockPoisoned)?;
329 factories.get(&TypeId::of::<T>()).cloned()
330 };
331
332 let Some(factory) = factory else {
333 return Ok(None);
334 };
335
336 let value = factory(self).map_err(|err| ContainerError::RequestFactoryFailed {
337 type_name: std::any::type_name::<T>(),
338 message: err.to_string(),
339 })?;
340 let typed = value.downcast::<T>().map_err(|_| ContainerError::DowncastFailed {
341 type_name: std::any::type_name::<T>(),
342 })?;
343
344 Ok(Some(typed))
345 }
346}
347
#[cfg(test)]
mod tests {
    use super::*;

    // ---- fixtures ---------------------------------------------------------

    /// Plain value used to exercise register/override precedence.
    #[derive(Debug, PartialEq, Eq)]
    struct AppConfig {
        app_name: &'static str,
    }

    /// Per-scope input that the request factory depends on.
    #[derive(Clone)]
    struct RequestId(String);

    /// Output of the request factory.
    struct RequestGreeting(String);

    /// Output of the transient factory; carries the invocation count.
    struct TransientCounter(usize);

    // ---- tests ------------------------------------------------------------

    #[test]
    fn override_value_takes_precedence_over_registered_value() {
        let root = Container::new();
        root.register(AppConfig {
            app_name: "default",
        })
        .expect("register should succeed");
        root.override_value(AppConfig { app_name: "test" })
            .expect("override should succeed");

        let resolved = root
            .resolve::<AppConfig>()
            .expect("config should resolve");

        assert_eq!(resolved.app_name, "test");
    }

    #[test]
    fn scoped_container_resolves_request_factory_without_leaking_to_parent() {
        let root = Container::new();
        root.register_request_factory::<RequestGreeting, _>(|scope| {
            let id = scope.resolve::<RequestId>()?;
            Ok(RequestGreeting(format!("hello {}", id.0)))
        })
        .expect("request factory should register");

        // Only the child scope knows the request id the factory needs.
        let request_scope = root.scoped();
        request_scope
            .override_value(RequestId("req-1".to_string()))
            .expect("request id should override");

        let greeting = request_scope
            .resolve::<RequestGreeting>()
            .expect("request greeting should resolve");
        assert_eq!(greeting.0, "hello req-1");

        // The root has no RequestId, so resolving through it must fail.
        assert!(root.resolve::<RequestGreeting>().is_err());
    }

    #[test]
    fn transient_factory_creates_new_instances_per_resolve() {
        let root = Container::new();
        let calls = Arc::new(RwLock::new(0usize));
        let calls_in_factory = Arc::clone(&calls);

        root.register_transient_factory::<TransientCounter, _>(move |_| {
            let mut seen = calls_in_factory
                .write()
                .expect("counter should be writable");
            *seen += 1;
            Ok(TransientCounter(*seen))
        })
        .expect("transient factory should register");

        let first = root
            .resolve::<TransientCounter>()
            .expect("first transient should resolve");
        let second = root
            .resolve::<TransientCounter>()
            .expect("second transient should resolve");

        // Each resolve ran the factory again, so the counts differ.
        assert_eq!(first.0, 1);
        assert_eq!(second.0, 2);
    }
}
431}