1use std::any::{Any, TypeId};
2use std::collections::HashMap;
3use std::sync::{Arc, Mutex, OnceLock, RwLock, RwLockWriteGuard, RwLockReadGuard};
4use std::ops::{Deref, DerefMut};
5use tokio::sync::{Mutex as TokioMutex, RwLock as TokioRwLock};
6
7static STD_INSTANCE: OnceLock<Mutex<Di>> = OnceLock::new();
8static AYN_INSTANCE: OnceLock<TokioMutex<TkDi>> = OnceLock::new();
9
10type ThreadSafeAny = Arc<RwLock<dyn Any + Send + Sync + 'static>>;
11
12type AsyncSaftAny = Arc<TokioRwLock<dyn Any + Send + Sync + 'static>>;
13
14pub struct Di {
15 providers: RwLock<HashMap<TypeId, Arc<dyn Provider>>>,
16 single_map: HashMap<TypeId, ThreadSafeAny>,
17}
18
19pub struct TkDi {
20 providers: TokioRwLock<HashMap<TypeId, Arc<dyn TkProvider>>>,
21 async_map: HashMap<TypeId, AsyncSaftAny>,
22}
23
24
25pub struct SingleRef<T> {
26 value: Arc<RwLock<T>>,
27}
28
29impl<T> SingleRef<T> {
30 pub fn get(&self) -> Result<RwLockReadGuard<T>, DiError> {
31 self.value.read().map_err(|_| DiError::LockError)
32 }
33
34 pub fn get_mut(&mut self) -> Result<RwLockWriteGuard<T>, DiError> {
35 self.value.write().map_err(|_| DiError::LockError)
36 }
37}
38
39impl<T> Clone for SingleRef<T> {
40 fn clone(&self) -> Self {
41 SingleRef {
42 value: self.value.clone(),
43 }
44 }
45}
46
47
48pub struct SingleAsyncRef<T> {
49 value: Arc<TokioRwLock<T>>,
50}
51
52impl<T> SingleAsyncRef<T> {
53 pub async fn get(&self) -> tokio::sync::RwLockReadGuard<'_, T> {
54 self.value.read().await
55 }
56
57 pub async fn get_mut(&mut self) -> tokio::sync::RwLockWriteGuard<'_, T> {
58 self.value.write().await
59 }
60}
61
62impl<T> Clone for SingleAsyncRef<T> {
63 fn clone(&self) -> Self {
64 SingleAsyncRef {
65 value: self.value.clone(),
66 }
67 }
68}
69
70impl TkDi {
71 fn get_instance() -> &'static TokioMutex<TkDi> {
72 AYN_INSTANCE.get_or_init(|| TokioMutex::new(TkDi{
73 providers: TokioRwLock::new(HashMap::new()),
74 async_map: HashMap::new(),
75 }))
76 }
77
78 async fn _register<T, F>(&self, factory: F)
79 where
80 T: 'static + Send + Sync,
81 F: Fn() -> T + Send + Sync + 'static,
82 {
83 let provider = FactoryProvider {
84 factory,
85 _marker: std::marker::PhantomData,
86 };
87 let type_id = TypeId::of::<T>();
88 let mut providers = self.providers.write().await;
89 providers.insert(type_id, Arc::new(provider));
90 }
91
92 pub async fn register<T, F>(factory: F)
93 where
94 T: 'static + Send + Sync,
95 F: Fn() -> T + Send + Sync + 'static,
96 {
97 let di = TkDi::get_instance().lock().await;
98 di._register(factory).await;
100 }
101
102 pub async fn get_inner<T: 'static>(&self) -> Result<T, Box<dyn std::error::Error>> {
103 let type_id = TypeId::of::<T>();
104 let providers = self.providers.read().await;
105 let provider = providers.get(&type_id).ok_or("Provider not found")?;
106 let any = provider.provide();
107 let t = any.downcast::<T>().map_err(|_| "Downcast failed")?;
109 Ok(*t)
110 }
111 pub async fn get<T: 'static>() -> Result<T, Box<dyn std::error::Error>> {
112 let di = TkDi::get_instance().lock().await;
113 di.get_inner().await
114 }
115
116 fn _register_single<T>(&mut self, instance: T)
117 where
118 T: 'static + Send + Sync,
119 {
120 let type_id = std::any::TypeId::of::<T>();
121 let any = Arc::new(TokioRwLock::new(instance));
122 self.async_map.insert(type_id, any);
123 }
124
125 pub async fn register_single<T>(instance: T)
126 where
127 T: 'static + Send + Sync,
128 {
129 let mut di = TkDi::get_instance().lock().await;
130 di._register_single(instance);
131 }
132
133 fn _get_single<T: Any + Send + Sync + 'static>(&self) -> Option<SingleAsyncRef<T>> {
134 let type_id = std::any::TypeId::of::<T>();
135 let any = self.async_map.get(&type_id)?;
136 let value = unsafe {
137 let ptr = Arc::into_raw(any.clone());
138 Arc::from_raw(ptr as *const TokioRwLock<T>)
139 };
140 Some(SingleAsyncRef { value })
141 }
142
143 pub async fn get_single<T: Any + Send + Sync + 'static>() -> Option<SingleAsyncRef<T>> {
144 let di = TkDi::get_instance().lock().await;
145 di._get_single::<T>()
146 }
147
148}
149
150
151impl Di {
152 fn get_instance() -> &'static Mutex<Di> {
153 STD_INSTANCE.get_or_init(|| Mutex::new(Di{
154 providers: RwLock::new(HashMap::new()),
155 single_map: HashMap::new(),
156 }))
157 }
158
159 fn _register_single<T>(&mut self, instance: T)
160 where
161 T: 'static + Send + Sync,
162 {
163 let type_id = std::any::TypeId::of::<T>();
164 let any = Arc::new(RwLock::new(instance));
165 self.single_map.insert(type_id, any);
166 }
167
168
169
170 pub fn register_single<T>(instance: T)
171 where
172 T: 'static + Send + Sync,
173 {
174 let mut di = Di::get_instance().lock().unwrap();
175 di._register_single(instance);
176 }
177
178
179 fn _register<T, F>(&self, factory: F)
180 where
181 T: 'static + Send + Sync,
182 F: Fn(&Di) -> T + Send + Sync + 'static,
183 {
184 let provider = FactoryProvider {
185 factory,
186 _marker: std::marker::PhantomData,
187 };
188 let type_id = std::any::TypeId::of::<T>();
189 let mut providers = self.providers.write().unwrap();
190 providers.insert(type_id, Arc::new(provider));
191 }
192
193 pub fn register<T, F>(factory: F)
194 where
195 T: 'static + Send + Sync,
196 F: Fn(&Di) -> T + Send + Sync + 'static,
197 {
198 let di = Di::get_instance().lock().unwrap();
199 di._register(factory);
200 }
201
202 pub fn get_inner<T: 'static>(&self) -> Result<T, Box<dyn std::error::Error>> {
203 let type_id = std::any::TypeId::of::<T>();
204 let providers = self.providers.read().unwrap();
205 let provider = providers.get(&type_id).ok_or("Provider not found")?;
206
207 let any = provider.provide(self);
208 let t = any.downcast::<T>().map_err(|_| "Downcast failed")?;
210 Ok(*t)
211 }
212 pub fn get<T: 'static>() -> Result<T, Box<dyn std::error::Error>> {
213 let di = Di::get_instance().lock().unwrap();
214 di.get_inner()
215 }
216
217 fn _get_single<T: Any + Send + Sync + 'static>(&self) -> Option<SingleRef<T>> {
218 let type_id = std::any::TypeId::of::<T>();
219 let any = self.single_map.get(&type_id)?;
220 if any.type_id() != type_id {
221 return None;
222 }
223 let value = unsafe {
225 let ptr = Arc::into_raw(any.clone());
226 Arc::from_raw(ptr as *const RwLock<T>)
227 };
228 Some(SingleRef { value })
229 }
230 pub fn get_single<T: Any + Send + Sync + 'static>() -> Option<SingleRef<T>> {
231 let di = Di::get_instance().lock().unwrap();
232 di._get_single::<T>()
233 }
234
235}
236
237trait Provider: Send + Sync {
238 fn provide(&self, di: &Di) -> Box<dyn Any>;
239}
240
241trait TkProvider: Send + Sync {
242 fn provide(&self) -> Box<dyn Any>;
243}
244
245struct FactoryProvider<F, T> {
246 factory: F,
247 _marker: std::marker::PhantomData<T>,
248}
249
250impl<F, T> Provider for FactoryProvider<F, T>
251where
252 F: Fn(&Di) -> T + Send + Sync + 'static,
253 T: 'static + Send + Sync,
254{
255 fn provide(&self, di: &Di) -> Box<dyn Any> {
256 Box::new((self.factory)(di))
257 }
258}
259
260
261impl<F, T> TkProvider for FactoryProvider<F, T>
262where
263 F: Fn() -> T + Send + Sync + 'static,
264 T: 'static + Send + Sync,
265{
266 fn provide(&self) -> Box<dyn Any> {
267 Box::new((self.factory)())
268 }
269}
270
/// Errors reported by the container's typed APIs.
///
/// Implements `Display` and `std::error::Error`, so it converts into
/// `Box<dyn std::error::Error>` via `?`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DiError {
    /// No provider was found for the requested type. NOTE(review): `get_inner`
    /// currently reports this case as the string error `"Provider not found"`.
    ProviderNotFound,
    /// A stored value did not have the requested type. NOTE(review): `get_inner`
    /// currently reports this case as the string error `"Downcast failed"`.
    TypeMismatch,
    /// Returned by `SingleRef::get` / `SingleRef::get_mut` when the underlying
    /// `RwLock` cannot be acquired because it is poisoned.
    LockError,
}

impl std::fmt::Display for DiError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            DiError::ProviderNotFound => "provider not found",
            DiError::TypeMismatch => "type mismatch",
            DiError::LockError => "failed to acquire lock",
        };
        f.write_str(message)
    }
}

impl std::error::Error for DiError {}

/// Convenience alias for results whose error is a `DiError`.
pub type DiResult<T> = Result<T, DiError>;
279
#[cfg(test)]
mod tests {
    use super::*;

    struct Configuration {
        port: u16,
    }

    #[derive(Clone)]
    struct Database {
        port: u16,
    }

    #[derive(Clone)]
    struct AppService {
        db: Database,
    }

    /// Exercises factory resolution and singleton mutation on the async container.
    #[tokio::test]
    async fn async_test() {
        TkDi::register(|| Database { port: 3306 }).await;
        let db = TkDi::get::<Database>().await.unwrap();
        assert_eq!(db.port, 3306);

        TkDi::register_single(Configuration { port: 8080 }).await;

        // Mutate through one handle; the write guard is released at the end of the block.
        {
            let mut config_ref = TkDi::get_single::<Configuration>()
                .await
                .expect("Configuration singleton should be registered");
            let mut config = config_ref.get_mut().await;
            assert_eq!(config.port, 8080);
            config.port = 8081;
        }

        // A fresh handle must observe the mutation, because all handles share one lock.
        let config_ref = TkDi::get_single::<Configuration>()
            .await
            .expect("Configuration singleton should be registered");
        assert_eq!(config_ref.get().await.port, 8081);
    }

    /// Exercises nested factory resolution and singleton mutation on the sync container.
    #[test]
    fn it_works() {
        Di::register::<Database, _>(|_| Database { port: 3306 });
        Di::register_single(Configuration { port: 8080 });
        Di::register::<AppService, _>(|di| {
            let db = di.get_inner::<Database>().unwrap();
            AppService { db }
        });

        let result = Di::get::<AppService>().unwrap();
        assert_eq!(result.db.port, 3306);

        // `expect` rather than `if let`: a missing singleton must fail the
        // test instead of silently skipping the assertions.
        {
            let mut config_ref = Di::get_single::<Configuration>()
                .expect("Configuration singleton should be registered");
            let mut config = config_ref.get_mut().unwrap();
            assert_eq!(config.port, 8080);
            config.port = 8081;
        }

        let config_ref = Di::get_single::<Configuration>()
            .expect("Configuration singleton should be registered");
        assert_eq!(config_ref.get().unwrap().port, 8081);
    }
}