rok_container/
container.rs1use std::any::{Any, TypeId};
2use std::collections::HashMap;
3use std::sync::{Arc, RwLock};
4
5use crate::error::ContainerError;
6
/// Type-erased, thread-shareable handle to a registered value.
///
/// Resolution turns it back into a concrete `Arc<T>` with `Arc::downcast`.
type AnyArc = Arc<dyn Any + Sync + Send>;
8
9enum Binding {
10 Singleton(AnyArc),
11 Factory(Box<dyn Fn() -> AnyArc + Send + Sync>),
12}
13
14pub struct Container {
36 bindings: RwLock<HashMap<TypeId, Binding>>,
37}
38
39impl Default for Container {
40 fn default() -> Self {
41 Self::new()
42 }
43}
44
45impl Container {
46 pub fn new() -> Self {
48 Self {
49 bindings: RwLock::new(HashMap::new()),
50 }
51 }
52
53 pub fn bind<T, F>(&self, factory: F)
58 where
59 T: Any + Send + Sync + 'static,
60 F: Fn() -> T + Send + Sync + 'static,
61 {
62 let mut map = self.bindings.write().expect("container lock poisoned");
63 map.insert(
64 TypeId::of::<T>(),
65 Binding::Factory(Box::new(move || Arc::new(factory()))),
66 );
67 }
68
69 pub fn singleton<T>(&self, instance: T)
73 where
74 T: Any + Send + Sync + 'static,
75 {
76 let mut map = self.bindings.write().expect("container lock poisoned");
77 map.insert(TypeId::of::<T>(), Binding::Singleton(Arc::new(instance)));
78 }
79
80 pub fn make<T>(&self) -> Result<Arc<T>, ContainerError>
86 where
87 T: Any + Send + Sync + 'static,
88 {
89 let map = self.bindings.read().expect("container lock poisoned");
90 match map.get(&TypeId::of::<T>()) {
91 Some(Binding::Singleton(arc)) => arc
92 .clone()
93 .downcast::<T>()
94 .map_err(|_| ContainerError::TypeMismatch(std::any::type_name::<T>())),
95 Some(Binding::Factory(f)) => f()
96 .downcast::<T>()
97 .map_err(|_| ContainerError::TypeMismatch(std::any::type_name::<T>())),
98 None => Err(ContainerError::NotRegistered(std::any::type_name::<T>())),
99 }
100 }
101
102 pub fn extend<T, F>(&self, extender: F) -> Result<(), ContainerError>
110 where
111 T: Any + Send + Sync + 'static,
112 F: FnOnce(Arc<T>) -> T,
113 {
114 let existing = self.make::<T>()?;
115 let new_instance = extender(existing);
116 self.singleton(new_instance);
117 Ok(())
118 }
119
120 pub fn swap<T>(&self, instance: T)
124 where
125 T: Any + Send + Sync + 'static,
126 {
127 self.singleton(instance);
128 }
129}