// ru_di/lib.rs

1use std::any::{Any, TypeId};
2use std::collections::HashMap;
3use std::sync::{Arc, Mutex, OnceLock, RwLock, RwLockWriteGuard, RwLockReadGuard};
4use std::ops::{Deref, DerefMut};
5use tokio::sync::{Mutex as TokioMutex, RwLock as TokioRwLock};
6
7static INSTANCE: OnceLock<Mutex<Di>> = OnceLock::new();
8
9type ThreadSafeAny = Arc<RwLock<dyn Any + Send + Sync + 'static>>;
10
11type AsyncSaftAny = Arc<TokioRwLock<dyn Any + Send + Sync + 'static>>;
12
13pub struct Di {
14    providers: RwLock<HashMap<TypeId, Arc<dyn Provider>>>,
15    single_map: HashMap<TypeId, ThreadSafeAny>,
16    async_map: HashMap<TypeId, AsyncSaftAny>,
17}
18
/// Shared handle to a singleton stored behind a `std::sync::RwLock`.
///
/// Cloning the handle is cheap: every clone points at the same value.
pub struct SingleRef<T> {
    value: Arc<RwLock<T>>,
}

impl<T> SingleRef<T> {
    /// Blocks until shared read access is available and returns the guard.
    ///
    /// # Panics
    ///
    /// Panics if the lock is poisoned, i.e. a previous holder of the write
    /// guard panicked.
    pub fn get(&self) -> RwLockReadGuard<'_, T> {
        self.value.read().unwrap()
    }

    /// Blocks until exclusive write access is available and returns the guard.
    ///
    /// Takes `&mut self`, so a single handle cannot request the write guard
    /// while it still holds a read guard obtained from [`SingleRef::get`].
    /// Other clones of the handle are not restricted this way.
    ///
    /// # Panics
    ///
    /// Panics if the lock is poisoned.
    pub fn get_mut(&mut self) -> RwLockWriteGuard<'_, T> {
        self.value.write().unwrap()
    }
}

// Written by hand on purpose: `#[derive(Clone)]` would add a `T: Clone` bound,
// but only the `Arc` is cloned here, never the `T` itself.
impl<T> Clone for SingleRef<T> {
    fn clone(&self) -> Self {
        Self {
            value: Arc::clone(&self.value),
        }
    }
}
40
41
42pub struct SingleAsyncRef<T> {
43    value: Arc<TokioRwLock<T>>,
44}
45
46impl<T> SingleAsyncRef<T> {
47    pub async fn get(&self) -> tokio::sync::RwLockReadGuard<'_, T> {
48        self.value.read().await
49    }
50
51    pub async fn get_mut(&mut self) -> tokio::sync::RwLockWriteGuard<'_, T> {
52        self.value.write().await
53    }
54}
55
56impl<T> Clone for SingleAsyncRef<T> {
57    fn clone(&self) -> Self {
58        SingleAsyncRef {
59            value: self.value.clone(),
60        }
61    }
62}
63
64
65impl Di {
66    fn get_instance() -> &'static Mutex<Di> {
67        INSTANCE.get_or_init(|| Mutex::new(Di{
68            providers: RwLock::new(HashMap::new()),
69            single_map: HashMap::new(),
70            async_map: HashMap::new(),
71        }))
72    }
73    
74    fn _register_single<T>(&mut self, instance: T)
75    where
76        T: 'static + Send + Sync,
77    {
78        let type_id = std::any::TypeId::of::<T>();
79        let any = Arc::new(RwLock::new(instance));
80        self.single_map.insert(type_id, any);
81    }
82
83    fn _register_async_single<T>(&mut self, instance: T)
84    where
85        T: 'static + Send + Sync,
86    {
87        let type_id = std::any::TypeId::of::<T>();
88        let any = Arc::new(TokioRwLock::new(instance));
89        self.async_map.insert(type_id, any);
90    }
91    
92    pub fn register_single<T>(instance: T)
93    where
94        T: 'static + Send + Sync,
95    {
96        let mut di = Di::get_instance().lock().unwrap();
97        di._register_single(instance);
98    }
99
100    pub fn register_async_single<T>(instance: T)
101    where
102        T: 'static + Send + Sync,
103    {
104        let mut di = Di::get_instance().lock().unwrap();
105        di._register_async_single(instance);
106    }
107    
108    fn _register<T, F>(&self, factory: F)
109    where
110        T: 'static + Send + Sync,
111        F: Fn(&Di) -> T + Send + Sync + 'static,
112    {
113        let provider = FactoryProvider {
114            factory,
115            _marker: std::marker::PhantomData,
116        };
117        let type_id = std::any::TypeId::of::<T>();
118        let mut providers = self.providers.write().unwrap();
119        providers.insert(type_id, Arc::new(provider));
120    }
121    
122    pub fn register<T, F>(factory: F)
123    where
124        T: 'static + Send + Sync,
125        F: Fn(&Di) -> T + Send + Sync + 'static,
126    {
127        let di = Di::get_instance().lock().unwrap();
128        di._register(factory);
129    }
130
131     pub fn get_inner<T: 'static>(&self) -> Result<T, Box<dyn std::error::Error>> {
132        let type_id = std::any::TypeId::of::<T>();
133        let providers = self.providers.read().unwrap();
134        let provider = providers.get(&type_id).ok_or("Provider not found")?;
135
136        let any = provider.provide(self);
137        // 从 Box<dyn Any> 中提取 Arc<T>
138         let t = any.downcast::<T>().map_err(|_| "Downcast failed")?;
139         Ok(*t)
140    }
141    pub fn get<T: 'static>() -> Result<T, Box<dyn std::error::Error>> {
142        let di = Di::get_instance().lock().unwrap();
143        di.get_inner()
144    }
145    
146    fn _get_single<T: Any + Send + Sync + 'static>(&self) -> Option<SingleRef<T>> {
147        let type_id = std::any::TypeId::of::<T>();
148        let any = self.single_map.get(&type_id)?;
149        let value = unsafe {
150            let ptr = Arc::into_raw(any.clone());
151            Arc::from_raw(ptr as *const RwLock<T>)
152        };
153        Some(SingleRef { value })
154    }
155    pub fn get_single<T: Any + Send + Sync + 'static>() -> Option<SingleRef<T>> {
156        let di = Di::get_instance().lock().unwrap();
157        di._get_single::<T>()
158    }
159
160
161    fn _get_async_single<T: Any + Send + Sync + 'static>(&self) -> Option<SingleAsyncRef<T>> {
162        let type_id = std::any::TypeId::of::<T>();
163        let any = self.async_map.get(&type_id)?;
164        let value = unsafe {
165            let ptr = Arc::into_raw(any.clone());
166            Arc::from_raw(ptr as *const TokioRwLock<T>)
167        };
168        Some(SingleAsyncRef { value })
169    }
170    
171    pub fn get_async_single<T: Any + Send + Sync + 'static>() -> Option<SingleAsyncRef<T>> {
172        let di = Di::get_instance().lock().unwrap();
173        di._get_async_single::<T>()
174    }
175}
176
177trait Provider: Send + Sync {
178    fn provide(&self, di: &Di) -> Box<dyn Any>;
179}
180
181struct FactoryProvider<F, T> {
182    factory: F,
183    _marker: std::marker::PhantomData<T>,
184}
185
186impl<F, T> Provider for FactoryProvider<F, T>
187where
188    F: Fn(&Di) -> T + Send + Sync + 'static,
189    T: 'static + Send + Sync,
190{
191    fn provide(&self, di: &Di) -> Box<dyn Any> {
192        Box::new((self.factory)(di))
193    }
194}
195
196
/// Returns the sum of `left` and `right`.
pub fn add(left: u64, right: u64) -> u64 {
    left + right
}
200
#[cfg(test)]
mod tests {
    use super::*;

    /// Registered as a singleton; mutated through `SingleRef`.
    struct Configuration {
        port: u16,
    }

    /// Built by a factory on every lookup.
    #[derive(Clone)]
    struct Database {
        port: u16,
    }

    /// Built by a factory that resolves `Database` from the container.
    #[derive(Clone)]
    struct AppService {
        db: Database,
    }

    /// Exercises factory registration, nested resolution and singleton
    /// mutation against the global container.
    #[test]
    fn it_works() {
        Di::register::<Database, _>(|_| Database { port: 3306 });
        println!("regist database done");

        Di::register_single(Configuration { port: 8080 });

        Di::register::<AppService, _>(|di| {
            // Resolve through the `&Di` passed in; the value is already owned,
            // so it can be moved into the service without cloning.
            let db = di.get_inner::<Database>().unwrap();
            AppService { db }
        });
        println!("regist app done");

        let result = Di::get::<AppService>().unwrap();
        assert_eq!(result.db.port, 3306);

        // `expect` rather than `if let`: a failed lookup must fail the test
        // instead of silently skipping the assertions.
        let mut config = Di::get_single::<Configuration>()
            .expect("Configuration singleton was registered above");
        {
            // Scope the write guard so it is released before the next read.
            let mut guard = config.get_mut();
            assert_eq!(guard.port, 8080);
            guard.port = 8081;
        }

        // A second, independent lookup must observe the mutation.
        let config = Di::get_single::<Configuration>()
            .expect("Configuration singleton was registered above");
        assert_eq!(config.get().port, 8081);
    }
}