1use std::any::{Any, TypeId};
2use std::collections::HashMap;
3use std::sync::{Arc, Mutex, OnceLock, RwLock, RwLockWriteGuard, RwLockReadGuard};
4use std::ops::{Deref, DerefMut};
5use tokio::sync::{Mutex as TokioMutex, RwLock as TokioRwLock};
6
7static INSTANCE: OnceLock<Mutex<Di>> = OnceLock::new();
8
9type ThreadSafeAny = Arc<RwLock<dyn Any + Send + Sync + 'static>>;
10
11type AsyncSaftAny = Arc<TokioRwLock<dyn Any + Send + Sync + 'static>>;
12
13pub struct Di {
14 providers: RwLock<HashMap<TypeId, Arc<dyn Provider>>>,
15 single_map: HashMap<TypeId, ThreadSafeAny>,
16 async_map: HashMap<TypeId, AsyncSaftAny>,
17}
18
/// Shared handle to a synchronous singleton registered with `Di::register_single`.
///
/// Every clone points at the same `RwLock<T>`, so a write made through one
/// handle is visible through all others.
pub struct SingleRef<T> {
    value: Arc<RwLock<T>>,
}

impl<T> SingleRef<T> {
    /// Acquires a shared read guard, blocking the current thread until no
    /// writer holds the lock.
    ///
    /// # Panics
    /// Panics if the lock is poisoned (a writer panicked while holding it).
    pub fn get(&self) -> RwLockReadGuard<'_, T> {
        self.value.read().unwrap()
    }

    /// Acquires an exclusive write guard, blocking the current thread until
    /// no other reader or writer holds the lock.
    ///
    /// Takes `&self` because exclusivity is provided by the lock itself;
    /// clones of this handle share the same lock, so `&mut self` would not
    /// guarantee anything. Callers holding `&mut SingleRef<T>` still work.
    ///
    /// # Panics
    /// Panics if the lock is poisoned (a writer panicked while holding it).
    pub fn get_mut(&self) -> RwLockWriteGuard<'_, T> {
        self.value.write().unwrap()
    }
}

impl<T> Clone for SingleRef<T> {
    /// Creates another handle to the same singleton (only the `Arc` refcount is bumped).
    fn clone(&self) -> Self {
        SingleRef {
            value: Arc::clone(&self.value),
        }
    }
}
40
41
42pub struct SingleAsyncRef<T> {
43 value: Arc<TokioRwLock<T>>,
44}
45
46impl<T> SingleAsyncRef<T> {
47 pub async fn get(&self) -> tokio::sync::RwLockReadGuard<'_, T> {
48 self.value.read().await
49 }
50
51 pub async fn get_mut(&mut self) -> tokio::sync::RwLockWriteGuard<'_, T> {
52 self.value.write().await
53 }
54}
55
56impl<T> Clone for SingleAsyncRef<T> {
57 fn clone(&self) -> Self {
58 SingleAsyncRef {
59 value: self.value.clone(),
60 }
61 }
62}
63
64
65impl Di {
66 fn get_instance() -> &'static Mutex<Di> {
67 INSTANCE.get_or_init(|| Mutex::new(Di{
68 providers: RwLock::new(HashMap::new()),
69 single_map: HashMap::new(),
70 async_map: HashMap::new(),
71 }))
72 }
73
74 fn _register_single<T>(&mut self, instance: T)
75 where
76 T: 'static + Send + Sync,
77 {
78 let type_id = std::any::TypeId::of::<T>();
79 let any = Arc::new(RwLock::new(instance));
80 self.single_map.insert(type_id, any);
81 }
82
83 fn _register_async_single<T>(&mut self, instance: T)
84 where
85 T: 'static + Send + Sync,
86 {
87 let type_id = std::any::TypeId::of::<T>();
88 let any = Arc::new(TokioRwLock::new(instance));
89 self.async_map.insert(type_id, any);
90 }
91
92 pub fn register_single<T>(instance: T)
93 where
94 T: 'static + Send + Sync,
95 {
96 let mut di = Di::get_instance().lock().unwrap();
97 di._register_single(instance);
98 }
99
100 pub fn register_async_single<T>(instance: T)
101 where
102 T: 'static + Send + Sync,
103 {
104 let mut di = Di::get_instance().lock().unwrap();
105 di._register_async_single(instance);
106 }
107
108 fn _register<T, F>(&self, factory: F)
109 where
110 T: 'static + Send + Sync,
111 F: Fn(&Di) -> T + Send + Sync + 'static,
112 {
113 let provider = FactoryProvider {
114 factory,
115 _marker: std::marker::PhantomData,
116 };
117 let type_id = std::any::TypeId::of::<T>();
118 let mut providers = self.providers.write().unwrap();
119 providers.insert(type_id, Arc::new(provider));
120 }
121
122 pub fn register<T, F>(factory: F)
123 where
124 T: 'static + Send + Sync,
125 F: Fn(&Di) -> T + Send + Sync + 'static,
126 {
127 let di = Di::get_instance().lock().unwrap();
128 di._register(factory);
129 }
130
131 pub fn get_inner<T: 'static>(&self) -> Result<T, Box<dyn std::error::Error>> {
132 let type_id = std::any::TypeId::of::<T>();
133 let providers = self.providers.read().unwrap();
134 let provider = providers.get(&type_id).ok_or("Provider not found")?;
135
136 let any = provider.provide(self);
137 let t = any.downcast::<T>().map_err(|_| "Downcast failed")?;
139 Ok(*t)
140 }
141 pub fn get<T: 'static>() -> Result<T, Box<dyn std::error::Error>> {
142 let di = Di::get_instance().lock().unwrap();
143 di.get_inner()
144 }
145
146 fn _get_single<T: Any + Send + Sync + 'static>(&self) -> Option<SingleRef<T>> {
147 let type_id = std::any::TypeId::of::<T>();
148 let any = self.single_map.get(&type_id)?;
149 let value = unsafe {
150 let ptr = Arc::into_raw(any.clone());
151 Arc::from_raw(ptr as *const RwLock<T>)
152 };
153 Some(SingleRef { value })
154 }
155 pub fn get_single<T: Any + Send + Sync + 'static>() -> Option<SingleRef<T>> {
156 let di = Di::get_instance().lock().unwrap();
157 di._get_single::<T>()
158 }
159
160
161 fn _get_async_single<T: Any + Send + Sync + 'static>(&self) -> Option<SingleAsyncRef<T>> {
162 let type_id = std::any::TypeId::of::<T>();
163 let any = self.async_map.get(&type_id)?;
164 let value = unsafe {
165 let ptr = Arc::into_raw(any.clone());
166 Arc::from_raw(ptr as *const TokioRwLock<T>)
167 };
168 Some(SingleAsyncRef { value })
169 }
170
171 pub fn get_async_single<T: Any + Send + Sync + 'static>() -> Option<SingleAsyncRef<T>> {
172 let di = Di::get_instance().lock().unwrap();
173 di._get_async_single::<T>()
174 }
175}
176
177trait Provider: Send + Sync {
178 fn provide(&self, di: &Di) -> Box<dyn Any>;
179}
180
181struct FactoryProvider<F, T> {
182 factory: F,
183 _marker: std::marker::PhantomData<T>,
184}
185
186impl<F, T> Provider for FactoryProvider<F, T>
187where
188 F: Fn(&Di) -> T + Send + Sync + 'static,
189 T: 'static + Send + Sync,
190{
191 fn provide(&self, di: &Di) -> Box<dyn Any> {
192 Box::new((self.factory)(di))
193 }
194}
195
196
/// Returns the sum of two `u64` values.
///
/// Uses plain `+`, so overflow panics in debug builds and wraps in release builds.
pub fn add(left: u64, right: u64) -> u64 {
    right + left
}
200
#[cfg(test)]
mod tests {
    use super::*;

    struct Configuration {
        port: u16,
    }

    #[derive(Clone)]
    struct Database {
        port: u16,
    }

    #[derive(Clone)]
    struct AppService {
        db: Database,
    }

    /// End-to-end check of the container. It covers factory registration,
    /// including a factory that resolves its own dependency through
    /// `get_inner`. It also checks that a singleton mutated through one handle
    /// is seen by a later lookup.
    #[test]
    fn it_works() {
        const MISSING: &str = "Configuration singleton should be registered";

        Di::register::<Database, _>(|_| Database { port: 3306 });
        Di::register_single(Configuration { port: 8080 });
        // Dependencies are resolved through the `&Di` passed to the factory;
        // calling the static `Di::get` here would re-lock the global mutex.
        Di::register::<AppService, _>(|di| AppService {
            db: di.get_inner::<Database>().unwrap(),
        });

        let result = Di::get::<AppService>().unwrap();
        assert_eq!(result.db.port, 3306);

        // `expect` (instead of `if let`) makes a missing singleton fail the
        // test rather than silently skip the assertions.
        assert_eq!(Di::get_single::<Configuration>().expect(MISSING).get().port, 8080);

        // Write through one handle; the handle and guard are temporaries
        // dropped at the end of this statement, releasing the write lock.
        Di::get_single::<Configuration>().expect(MISSING).get_mut().port = 8081;

        // A fresh lookup must observe the write: singletons are shared, not copied.
        assert_eq!(Di::get_single::<Configuration>().expect(MISSING).get().port, 8081);
    }
}