// dioxus_provider/injection.rs
use std::any::{Any, TypeId};
use std::collections::hash_map::Entry;
use std::collections::HashMap;
use std::sync::{Arc, OnceLock, RwLock};

12static DEPENDENCY_REGISTRY: OnceLock<DependencyRegistry> = OnceLock::new();
14
/// Type-keyed store of shared dependencies.
///
/// At most one value can be registered per concrete type `T`; lookups hand out
/// `Arc<T>` clones, so every caller shares the same instance.
pub struct DependencyRegistry {
    /// Registered values, keyed by the `TypeId` of their concrete type.
    dependencies: RwLock<HashMap<TypeId, RegisteredDependency>>,
}

/// A registered value together with the name of its concrete type.
///
/// The name is captured at registration time because a `TypeId` cannot be
/// turned back into a readable name, and `list_types` needs one.
struct RegisteredDependency {
    type_name: &'static str,
    value: Arc<dyn Any + Send + Sync>,
}

impl DependencyRegistry {
    /// Creates an empty registry.
    fn new() -> Self {
        Self {
            dependencies: RwLock::new(HashMap::new()),
        }
    }

    /// Registers `dependency` as the single shared instance of type `T`.
    ///
    /// # Errors
    ///
    /// Returns an error if the lock is poisoned, or if a value of type `T` is
    /// already registered (in which case the existing value is kept and
    /// `dependency` is dropped).
    pub fn register<T: Send + Sync + 'static>(&self, dependency: T) -> Result<(), String> {
        let mut deps = self
            .dependencies
            .write()
            .map_err(|_| "Failed to acquire write lock on dependencies")?;

        // Single hash lookup: either reject the duplicate or fill the empty slot.
        let slot = match deps.entry(TypeId::of::<T>()) {
            Entry::Occupied(_) => {
                return Err(format!(
                    "Dependency of type {} already registered",
                    std::any::type_name::<T>()
                ));
            }
            Entry::Vacant(slot) => slot,
        };

        slot.insert(RegisteredDependency {
            type_name: std::any::type_name::<T>(),
            value: Arc::new(dependency),
        });
        Ok(())
    }

    /// Returns a shared handle to the registered instance of type `T`.
    ///
    /// # Errors
    ///
    /// Returns an error if the lock is poisoned, if no `T` is registered, or if
    /// the stored value fails to downcast to `T`. Entries are keyed by
    /// `TypeId::of::<T>()`, so that last case should not occur.
    pub fn get<T: Send + Sync + 'static>(&self) -> Result<Arc<T>, String> {
        let deps = self
            .dependencies
            .read()
            .map_err(|_| "Failed to acquire read lock on dependencies")?;

        let entry = deps.get(&TypeId::of::<T>()).ok_or_else(|| {
            format!(
                "Dependency of type {} not found",
                std::any::type_name::<T>()
            )
        })?;

        Arc::clone(&entry.value).downcast::<T>().map_err(|_| {
            format!(
                "Failed to downcast dependency of type {}",
                std::any::type_name::<T>()
            )
        })
    }

    /// Reports whether a value of type `T` is registered.
    ///
    /// Returns `false` (rather than an error) if the lock is poisoned.
    pub fn contains<T: Send + Sync + 'static>(&self) -> bool {
        let type_id = TypeId::of::<T>();
        self.dependencies
            .read()
            .map(|deps| deps.contains_key(&type_id))
            .unwrap_or(false)
    }

    /// Removes every registered dependency.
    ///
    /// # Errors
    ///
    /// Returns an error if the lock is poisoned.
    pub fn clear(&self) -> Result<(), String> {
        let mut deps = self
            .dependencies
            .write()
            .map_err(|_| "Failed to acquire write lock on dependencies")?;
        let removed = std::mem::take(&mut *deps);

        // Release the write lock before the old values are dropped. A
        // dependency whose `Drop` touches the registry would otherwise
        // deadlock or panic on the re-entrant lock. A panicking `Drop`
        // would also poison the lock for everyone else.
        drop(deps);
        drop(removed);
        Ok(())
    }

    /// Returns the type names of all registered dependencies, sorted.
    ///
    /// Names come from `std::any::type_name`, whose exact text is meant for
    /// diagnostics and is not guaranteed to be stable across compiler versions.
    ///
    /// # Errors
    ///
    /// Returns an error if the lock is poisoned.
    pub fn list_types(&self) -> Result<Vec<String>, String> {
        let deps = self
            .dependencies
            .read()
            .map_err(|_| "Failed to acquire read lock on dependencies")?;

        let mut names: Vec<String> = deps
            .values()
            .map(|entry| entry.type_name.to_owned())
            .collect();
        names.sort_unstable();
        Ok(names)
    }
}

102pub fn init_dependency_injection() {
104 DEPENDENCY_REGISTRY.get_or_init(DependencyRegistry::new);
105}
106
107pub fn register_dependency<T: Send + Sync + 'static>(dependency: T) -> Result<(), String> {
109 let registry = DEPENDENCY_REGISTRY
110 .get()
111 .ok_or("Dependency registry not initialized. Call init_dependency_injection() first.")?;
112 registry.register(dependency)
113}
114
115pub fn inject<T: Send + Sync + 'static>() -> Result<Arc<T>, String> {
117 let registry = DEPENDENCY_REGISTRY
118 .get()
119 .ok_or("Dependency registry not initialized. Call init_dependency_injection() first.")?;
120 registry.get()
121}
122
123pub fn has_dependency<T: Send + Sync + 'static>() -> bool {
125 DEPENDENCY_REGISTRY
126 .get()
127 .map(|registry| registry.contains::<T>())
128 .unwrap_or(false)
129}
130
131pub fn clear_dependencies() -> Result<(), String> {
133 let registry = DEPENDENCY_REGISTRY
134 .get()
135 .ok_or("Dependency registry not initialized")?;
136 registry.clear()
137}
138
/// Resolves the registered instance of `$type`, or returns early from the
/// enclosing function with a `String` error.
///
/// Expands to a call to `inject::<$type>()` followed by `?`. The enclosing
/// function's error type must therefore be convertible from `String` via `From`.
/// NOTE(review): the expansion assumes this file is mounted as `crate::injection`.
#[macro_export]
macro_rules! inject {
    ($type:ty) => {
        $crate::injection::inject::<$type>().map_err(|cause| {
            format!("Dependency injection failed: {}", cause)
        })?
    };
}

/// Registers `$dependency` in the global registry, or returns early from the
/// enclosing function with a `String` error.
///
/// Expands to a call to `register_dependency` followed by `?`. The enclosing
/// function's error type must therefore be convertible from `String` via `From`.
/// NOTE(review): the expansion assumes this file is mounted as `crate::injection`.
#[macro_export]
macro_rules! register {
    ($dependency:expr) => {
        $crate::injection::register_dependency($dependency).map_err(|cause| {
            format!("Dependency registration failed: {}", cause)
        })?
    };
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Mutex, MutexGuard};

    /// Serializes tests that touch the process-wide registry.
    ///
    /// The test harness runs tests on parallel threads by default. Every test
    /// here clears and repopulates the same global `DEPENDENCY_REGISTRY`, so
    /// without this lock one test's `clear_dependencies()` or registration can
    /// interleave with another test's assertions and fail it at random.
    static REGISTRY_TEST_LOCK: Mutex<()> = Mutex::new(());

    /// Takes the test lock and resets the global registry to an empty state.
    ///
    /// Hold the returned guard for the whole test. If another test panicked
    /// while holding the lock, it is poisoned; the guarded data is `()`, so
    /// recovering the guard is always safe.
    fn fresh_registry() -> MutexGuard<'static, ()> {
        let guard = REGISTRY_TEST_LOCK
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        init_dependency_injection();
        clear_dependencies().unwrap();
        guard
    }

    struct TestService {
        name: String,
    }

    impl TestService {
        fn new(name: &str) -> Self {
            Self {
                name: name.to_owned(),
            }
        }

        fn name(&self) -> &str {
            &self.name
        }
    }

    #[test]
    fn test_dependency_injection() {
        let _guard = fresh_registry();

        register_dependency(TestService::new("test")).unwrap();

        let injected: Arc<TestService> = inject().unwrap();
        assert_eq!(injected.name(), "test");

        assert!(has_dependency::<TestService>());
        assert!(!has_dependency::<String>());
    }

    #[test]
    fn test_duplicate_registration() {
        let _guard = fresh_registry();

        assert!(register_dependency(TestService::new("first")).is_ok());
        assert!(register_dependency(TestService::new("second")).is_err());

        // The rejected registration must not replace the original instance.
        let injected: Arc<TestService> = inject().unwrap();
        assert_eq!(injected.name(), "first");
    }

    #[test]
    fn test_missing_dependency() {
        let _guard = fresh_registry();

        let result: Result<Arc<TestService>, String> = inject();
        assert!(result.is_err());
    }
}