// ru_di/lib.rs

use std::any::{Any, TypeId};
use std::collections::HashMap;
use std::ops::{Deref, DerefMut};
use std::sync::{
    Arc, Mutex, MutexGuard, OnceLock, PoisonError, RwLock, RwLockReadGuard, RwLockWriteGuard,
};
use tokio::sync::{Mutex as TokioMutex, RwLock as TokioRwLock};

7static STD_INSTANCE: OnceLock<Mutex<Di>> = OnceLock::new();
8static AYN_INSTANCE: OnceLock<TokioMutex<TkDi>> = OnceLock::new();
9
10type ThreadSafeAny = Arc<RwLock<dyn Any + Send + Sync + 'static>>;
11
12type AsyncSaftAny = Arc<TokioRwLock<dyn Any + Send + Sync + 'static>>;
13
14pub struct Di {
15    providers: RwLock<HashMap<TypeId, Arc<dyn Provider>>>,
16    single_map: HashMap<TypeId, ThreadSafeAny>,
17}
18
19pub struct TkDi {
20    providers: TokioRwLock<HashMap<TypeId, Arc<dyn TkProvider>>>,
21    async_map: HashMap<TypeId, AsyncSaftAny>,
22}
23
24
/// Cloneable, typed handle to a singleton registered through [`Di::register_single`].
///
/// All clones share one `RwLock`, so a write through any handle is visible
/// through every other handle.
pub struct SingleRef<T> {
    /// Shared lock around the singleton value.
    value: Arc<RwLock<T>>,
}

29impl<T> SingleRef<T> {
30    pub fn get(&self) -> Result<RwLockReadGuard<T>, DiError> {
31        self.value.read().map_err(|_| DiError::LockError)
32    }
33
34    pub fn get_mut(&mut self) -> Result<RwLockWriteGuard<T>, DiError> {
35        self.value.write().map_err(|_| DiError::LockError)
36    }
37}
38
39impl<T> Clone for SingleRef<T> {
40    fn clone(&self) -> Self {
41        SingleRef {
42            value: self.value.clone(),
43        }
44    }
45}
46
47
48pub struct SingleAsyncRef<T> {
49    value: Arc<TokioRwLock<T>>,
50}
51
52impl<T> SingleAsyncRef<T> {
53    pub async fn get(&self) -> tokio::sync::RwLockReadGuard<'_, T> {
54        self.value.read().await
55    }
56
57    pub async fn get_mut(&mut self) -> tokio::sync::RwLockWriteGuard<'_, T> {
58        self.value.write().await
59    }
60}
61
62impl<T> Clone for SingleAsyncRef<T> {
63    fn clone(&self) -> Self {
64        SingleAsyncRef {
65            value: self.value.clone(),
66        }
67    }
68}
69
70impl TkDi {
71    fn get_instance() -> &'static TokioMutex<TkDi> {
72        AYN_INSTANCE.get_or_init(|| TokioMutex::new(TkDi{
73            providers: TokioRwLock::new(HashMap::new()),
74            async_map: HashMap::new(),
75        }))
76    }
77
78    async fn _register<T, F>(&self, factory: F)
79    where
80        T: 'static + Send + Sync,
81        F: Fn() -> T + Send + Sync + 'static,
82    {
83        let provider = FactoryProvider {
84            factory,
85            _marker: std::marker::PhantomData,
86        };
87        let type_id = TypeId::of::<T>();
88        let mut providers = self.providers.write().await;
89        providers.insert(type_id, Arc::new(provider));
90    }
91    
92    pub async fn register<T, F>(factory: F)
93    where
94        T: 'static + Send + Sync,
95        F: Fn() -> T + Send + Sync + 'static,
96    {
97        let di = TkDi::get_instance().lock().await;
98        //println!("reg got di instance");
99        di._register(factory).await;
100    }
101    
102    pub async fn get_inner<T: 'static>(&self) -> Result<T, Box<dyn std::error::Error>> {
103        let type_id = TypeId::of::<T>();
104        let providers = self.providers.read().await;
105        let provider = providers.get(&type_id).ok_or("Provider not found")?;
106        let any = provider.provide();
107        // 从 Box<dyn Any> 中提取 Arc<T>
108        let t = any.downcast::<T>().map_err(|_| "Downcast failed")?;
109        Ok(*t)
110    }
111    pub async fn get<T: 'static>() -> Result<T, Box<dyn std::error::Error>> {
112        let di = TkDi::get_instance().lock().await;
113        di.get_inner().await
114    }
115
116    fn _register_single<T>(&mut self, instance: T)
117    where
118        T: 'static + Send + Sync,
119    {
120        let type_id = std::any::TypeId::of::<T>();
121        let any = Arc::new(TokioRwLock::new(instance));
122        self.async_map.insert(type_id, any);
123    }
124
125    pub async fn register_single<T>(instance: T)
126    where
127        T: 'static + Send + Sync,
128    {
129        let mut di = TkDi::get_instance().lock().await;
130        di._register_single(instance);
131    }
132
133    fn _get_single<T: Any + Send + Sync + 'static>(&self) -> Option<SingleAsyncRef<T>> {
134        let type_id = std::any::TypeId::of::<T>();
135        let any = self.async_map.get(&type_id)?;
136        let value = unsafe {
137            let ptr = Arc::into_raw(any.clone());
138            Arc::from_raw(ptr as *const TokioRwLock<T>)
139        };
140        Some(SingleAsyncRef { value })
141    }
142
143    pub async fn get_single<T: Any + Send + Sync + 'static>() -> Option<SingleAsyncRef<T>> {
144        let di = TkDi::get_instance().lock().await;
145        di._get_single::<T>()
146    }
147    
148}
149
150
151impl Di {
152    fn get_instance() -> &'static Mutex<Di> {
153        STD_INSTANCE.get_or_init(|| Mutex::new(Di{
154            providers: RwLock::new(HashMap::new()),
155            single_map: HashMap::new(),
156        }))
157    }
158    
159    fn _register_single<T>(&mut self, instance: T)
160    where
161        T: 'static + Send + Sync,
162    {
163        let type_id = std::any::TypeId::of::<T>();
164        let any = Arc::new(RwLock::new(instance));
165        self.single_map.insert(type_id, any);
166    }
167
168   
169    
170    pub fn register_single<T>(instance: T)
171    where
172        T: 'static + Send + Sync,
173    {
174        let mut di = Di::get_instance().lock().unwrap();
175        di._register_single(instance);
176    }
177    
178    
179    fn _register<T, F>(&self, factory: F)
180    where
181        T: 'static + Send + Sync,
182        F: Fn(&Di) -> T + Send + Sync + 'static,
183    {
184        let provider = FactoryProvider {
185            factory,
186            _marker: std::marker::PhantomData,
187        };
188        let type_id = std::any::TypeId::of::<T>();
189        let mut providers = self.providers.write().unwrap();
190        providers.insert(type_id, Arc::new(provider));
191    }
192    
193    pub fn register<T, F>(factory: F)
194    where
195        T: 'static + Send + Sync,
196        F: Fn(&Di) -> T + Send + Sync + 'static,
197    {
198        let di = Di::get_instance().lock().unwrap();
199        di._register(factory);
200    }
201
202     pub fn get_inner<T: 'static>(&self) -> Result<T, Box<dyn std::error::Error>> {
203        let type_id = std::any::TypeId::of::<T>();
204        let providers = self.providers.read().unwrap();
205        let provider = providers.get(&type_id).ok_or("Provider not found")?;
206
207        let any = provider.provide(self);
208        // 从 Box<dyn Any> 中提取 Arc<T>
209         let t = any.downcast::<T>().map_err(|_| "Downcast failed")?;
210         Ok(*t)
211    }
212    pub fn get<T: 'static>() -> Result<T, Box<dyn std::error::Error>> {
213        let di = Di::get_instance().lock().unwrap();
214        di.get_inner()
215    }
216    
217    fn _get_single<T: Any + Send + Sync + 'static>(&self) -> Option<SingleRef<T>> {
218        let type_id = std::any::TypeId::of::<T>();
219        let any = self.single_map.get(&type_id)?;
220        if any.type_id() != type_id {
221            return None;
222        }
223        // 安全转换
224        let value = unsafe {
225            let ptr = Arc::into_raw(any.clone());
226            Arc::from_raw(ptr as *const RwLock<T>)
227        };
228        Some(SingleRef { value })
229    }
230    pub fn get_single<T: Any + Send + Sync + 'static>() -> Option<SingleRef<T>> {
231        let di = Di::get_instance().lock().unwrap();
232        di._get_single::<T>()
233    }
234    
235}
236
237trait Provider: Send + Sync {
238    fn provide(&self, di: &Di) -> Box<dyn Any>;
239}
240
/// Object-safe factory stored by the async [`TkDi`] container.
///
/// `provide` takes no container argument and returns the newly built value
/// type-erased as `Box<dyn Any>`.
trait TkProvider: Sync + Send {
    fn provide(&self) -> Box<dyn Any>;
}

/// Adapts a closure into a [`Provider`] (for `Fn(&Di) -> T`) or a
/// [`TkProvider`] (for `Fn() -> T`).
///
/// `T` appears only in the marker. `PhantomData<T>` lets the trait impls name
/// the closure's output type as an impl parameter.
struct FactoryProvider<F, T> {
    /// Closure invoked on every `provide` call.
    factory: F,
    /// Zero-sized record of the produced type.
    _marker: std::marker::PhantomData<T>,
}

250impl<F, T> Provider for FactoryProvider<F, T>
251where
252    F: Fn(&Di) -> T + Send + Sync + 'static,
253    T: 'static + Send + Sync,
254{
255    fn provide(&self, di: &Di) -> Box<dyn Any> {
256        Box::new((self.factory)(di))
257    }
258}
259
260
261impl<F, T> TkProvider for FactoryProvider<F, T>
262where
263    F: Fn() -> T + Send + Sync + 'static,
264    T: 'static + Send + Sync,
265{
266    fn provide(&self) -> Box<dyn Any> {
267        Box::new((self.factory)())
268    }
269}
270
/// Errors reported by the container's handle types.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DiError {
    /// Not constructed in this file; presumably meant for a type with no
    /// registered provider.
    ProviderNotFound,
    /// Not constructed in this file; presumably meant for a stored value whose
    /// type differs from the requested one.
    TypeMismatch,
    /// A lock guarding a singleton was poisoned by a panicking holder.
    LockError,
}

impl std::fmt::Display for DiError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            DiError::ProviderNotFound => "provider not found",
            DiError::TypeMismatch => "type mismatch",
            DiError::LockError => "lock poisoned",
        };
        f.write_str(message)
    }
}

/// Lets `DiError` be boxed as `Box<dyn Error>` and propagated with `?`.
impl std::error::Error for DiError {}

/// Convenience alias for results that fail with a [`DiError`].
pub type DiResult<T> = Result<T, DiError>;

#[cfg(test)]
mod tests {
    use super::*;

    struct Configuration {
        port: u16,
    }

    #[derive(Clone)]
    struct Database {
        port: u16,
    }

    #[derive(Clone)]
    struct AppService {
        db: Database,
    }

    #[tokio::test]
    async fn async_test() {
        TkDi::register(|| Database { port: 3306 }).await;
        let db = TkDi::get::<Database>().await.unwrap();
        assert_eq!(db.port, 3306);

        // TkDi factories take no arguments, so a dependent service has to build
        // its dependencies itself.
        TkDi::register(|| AppService {
            db: Database { port: 3306 },
        })
        .await;
        let app = TkDi::get::<AppService>().await.unwrap();
        assert_eq!(app.db.port, 3306);

        TkDi::register_single(Configuration { port: 8080 }).await;

        let mut config = TkDi::get_single::<Configuration>()
            .await
            .expect("Configuration singleton was registered");
        {
            let mut guard = config.get_mut().await;
            assert_eq!(guard.port, 8080);
            guard.port = 8081;
        }

        // A fresh handle observes the mutation made through the first one.
        let config = TkDi::get_single::<Configuration>()
            .await
            .expect("Configuration singleton was registered");
        assert_eq!(config.get().await.port, 8081);
    }

    #[test]
    fn it_works() {
        Di::register::<Database, _>(|_| Database { port: 3306 });
        Di::register_single(Configuration { port: 8080 });
        Di::register::<AppService, _>(|di| {
            let db = di.get_inner::<Database>().unwrap();
            AppService { db }
        });

        let result = Di::get::<AppService>().unwrap();
        assert_eq!(result.db.port, 3306);

        // NOTE(review): these `if let` blocks are silently skipped when
        // `get_single` returns `None`. Once the lookup is known to succeed they
        // should use `expect`, so that a missing singleton fails the test.
        if let Some(mut config) = Di::get_single::<Configuration>() {
            let mut config = config.get_mut().unwrap();
            assert_eq!(config.port, 8080);
            config.port = 8081;
        }
        if let Some(config) = Di::get_single::<Configuration>() {
            let config = config.get().unwrap();
            assert_eq!(config.port, 8081);
        }
    }
}