1use std::any::{Any, TypeId};
7use std::collections::HashMap;
8use std::sync::Arc;
9
10use parking_lot::RwLock;
11
/// A type-erased service instance: any `Send + Sync` value behind a shared `Arc`.
type AnyService = Arc<dyn Any + Send + Sync>;

/// A shareable, thread-safe constructor that yields a type-erased service each time it is called.
type Factory = Arc<dyn Fn() -> AnyService + Send + Sync>;
13
14#[derive(Default)]
15pub struct ServiceContainer {
16 providers: RwLock<HashMap<TypeId, Factory>>,
17}
18
19impl ServiceContainer {
20 pub fn new() -> Self {
21 Self::default()
22 }
23
24 pub fn register<T, F>(&self, factory: F)
25 where
26 T: Send + Sync + 'static,
27 F: Fn() -> Arc<T> + Send + Sync + 'static,
28 {
29 let wrapper: Factory = Arc::new(move || factory() as Arc<dyn Any + Send + Sync>);
30 self.providers.write().insert(TypeId::of::<T>(), wrapper);
31 }
32
33 pub fn resolve<T: Send + Sync + 'static>(&self) -> Option<Arc<T>> {
34 let providers = self.providers.read();
35 let factory = providers.get(&TypeId::of::<T>())?;
36 let any = factory();
37 any.downcast::<T>().ok()
38 }
39}
40
#[cfg(test)]
mod tests {
    use super::*;

    struct Hello(&'static str);

    #[test]
    fn resolves_registered_factory() {
        let c = ServiceContainer::new();
        c.register::<Hello, _>(|| Arc::new(Hello("world")));
        let h = c.resolve::<Hello>().unwrap();
        assert_eq!(h.0, "world");
    }

    #[test]
    fn unregistered_type_resolves_to_none() {
        let c = ServiceContainer::new();
        assert!(c.resolve::<Hello>().is_none());
    }

    #[test]
    fn later_registration_replaces_earlier() {
        let c = ServiceContainer::new();
        c.register::<Hello, _>(|| Arc::new(Hello("first")));
        c.register::<Hello, _>(|| Arc::new(Hello("second")));
        assert_eq!(c.resolve::<Hello>().unwrap().0, "second");
    }

    #[test]
    fn factory_runs_on_every_resolve() {
        let c = ServiceContainer::new();
        c.register::<Hello, _>(|| Arc::new(Hello("world")));
        let a = c.resolve::<Hello>().unwrap();
        let b = c.resolve::<Hello>().unwrap();
        // The factory allocates a new Arc per call, so the two handles are distinct.
        assert!(!Arc::ptr_eq(&a, &b));
    }

    #[test]
    fn factory_can_resolve_its_dependencies() {
        struct Greeting(String);

        let c = Arc::new(ServiceContainer::new());
        c.register::<Hello, _>(|| Arc::new(Hello("world")));
        // Capture a Weak handle so the container does not keep itself alive.
        let weak = Arc::downgrade(&c);
        c.register::<Greeting, _>(move || {
            let c = weak.upgrade().expect("container outlives its factories in this test");
            let hello = c.resolve::<Hello>().expect("Hello is registered above");
            Arc::new(Greeting(format!("hello {}", hello.0)))
        });
        assert_eq!(c.resolve::<Greeting>().unwrap().0, "hello world");
    }
}