dioxus_provider/
injection.rs1use crate::errors::ProviderError;
9use std::any::{Any, TypeId};
10use std::collections::HashMap;
11use std::sync::{Arc, OnceLock, RwLock};
12
13static DEPENDENCY_REGISTRY: OnceLock<DependencyRegistry> = OnceLock::new();
15
16pub struct DependencyRegistry {
18 dependencies: RwLock<HashMap<TypeId, Arc<dyn Any + Send + Sync>>>,
19}
20
21impl DependencyRegistry {
22 fn new() -> Self {
24 Self {
25 dependencies: RwLock::new(HashMap::new()),
26 }
27 }
28
29 pub fn register<T: Send + Sync + 'static>(&self, dependency: T) -> Result<(), ProviderError> {
31 let type_id = TypeId::of::<T>();
32 let mut deps = self.dependencies.write().map_err(|_| {
33 ProviderError::DependencyInjection(
34 "Failed to acquire write lock on dependencies".to_string(),
35 )
36 })?;
37
38 if deps.contains_key(&type_id) {
39 return Err(ProviderError::DependencyInjection(format!(
40 "Dependency of type {} already registered",
41 std::any::type_name::<T>()
42 )));
43 }
44
45 deps.insert(type_id, Arc::new(dependency));
46 Ok(())
47 }
48
49 pub fn get<T: Send + Sync + 'static>(&self) -> Result<Arc<T>, ProviderError> {
51 let type_id = TypeId::of::<T>();
52 let deps = self.dependencies.read().map_err(|_| {
53 ProviderError::DependencyInjection(
54 "Failed to acquire read lock on dependencies".to_string(),
55 )
56 })?;
57
58 let dependency = deps.get(&type_id).ok_or_else(|| {
59 ProviderError::DependencyInjection(format!(
60 "Dependency of type {} not found. Make sure to register it with register_dependency() first.",
61 std::any::type_name::<T>()
62 ))
63 })?;
64
65 dependency.clone().downcast::<T>().map_err(|_| {
66 ProviderError::DependencyInjection(format!(
67 "Failed to downcast dependency of type {}",
68 std::any::type_name::<T>()
69 ))
70 })
71 }
72
73 pub fn contains<T: Send + Sync + 'static>(&self) -> bool {
75 let type_id = TypeId::of::<T>();
76 self.dependencies
77 .read()
78 .map(|deps| deps.contains_key(&type_id))
79 .unwrap_or(false)
80 }
81
82 pub fn clear(&self) -> Result<(), ProviderError> {
84 let mut deps = self.dependencies.write().map_err(|_| {
85 ProviderError::DependencyInjection(
86 "Failed to acquire write lock on dependencies".to_string(),
87 )
88 })?;
89 deps.clear();
90 Ok(())
91 }
92
93 pub fn list_types(&self) -> Result<Vec<String>, ProviderError> {
95 let deps = self.dependencies.read().map_err(|_| {
96 ProviderError::DependencyInjection(
97 "Failed to acquire read lock on dependencies".to_string(),
98 )
99 })?;
100
101 Ok(vec![format!("{} dependencies registered", deps.len())])
104 }
105}
106
107#[deprecated(
113 since = "0.1.0",
114 note = "Use init() or ProviderConfig::new().with_dependency_injection().init() instead"
115)]
116pub fn init_dependency_injection() {
117 DEPENDENCY_REGISTRY.get_or_init(DependencyRegistry::new);
118}
119
120pub(crate) fn ensure_dependency_injection_initialized() {
124 DEPENDENCY_REGISTRY.get_or_init(DependencyRegistry::new);
125}
126
127pub fn register_dependency<T: Send + Sync + 'static>(dependency: T) -> Result<(), ProviderError> {
129 let registry = DEPENDENCY_REGISTRY.get().ok_or_else(|| {
130 ProviderError::DependencyInjection(
131 "Dependency registry not initialized. Call init_dependency_injection() first."
132 .to_string(),
133 )
134 })?;
135 registry.register(dependency)
136}
137
138pub fn inject<T: Send + Sync + 'static>() -> Result<Arc<T>, ProviderError> {
140 let registry = DEPENDENCY_REGISTRY.get().ok_or_else(|| {
141 ProviderError::DependencyInjection(
142 "Dependency registry not initialized. Call init_dependency_injection() first."
143 .to_string(),
144 )
145 })?;
146 registry.get()
147}
148
149pub fn has_dependency<T: Send + Sync + 'static>() -> bool {
151 DEPENDENCY_REGISTRY
152 .get()
153 .map(|registry| registry.contains::<T>())
154 .unwrap_or(false)
155}
156
157pub fn clear_dependencies() -> Result<(), ProviderError> {
159 let registry = DEPENDENCY_REGISTRY.get().ok_or_else(|| {
160 ProviderError::DependencyInjection("Dependency registry not initialized".to_string())
161 })?;
162 registry.clear()
163}
164
/// Resolves the dependency of type `$type` from the global registry, returning early on failure.
///
/// Expands to `$crate::injection::inject::<$type>()` followed by `?`. Before that, the error
/// is converted to a `String` of the form `"Dependency injection failed: {error}"`. The
/// enclosing function must therefore return a `Result` whose error type implements
/// `From<String>`.
#[macro_export]
macro_rules! inject {
    ($type:ty) => {
        // `::std::format!` is fully qualified so a caller's own `format` macro cannot
        // hijack the expansion.
        $crate::injection::inject::<$type>()
            .map_err(|e| ::std::format!("Dependency injection failed: {}", e))?
    };
}

/// Registers `$dependency` in the global registry, returning early on failure.
///
/// Expands to `$crate::injection::register_dependency($dependency)` followed by `?`. Before
/// that, the error is converted to a `String` of the form
/// `"Dependency registration failed: {error}"`. The enclosing function must therefore return
/// a `Result` whose error type implements `From<String>`.
#[macro_export]
macro_rules! register {
    ($dependency:expr) => {
        $crate::injection::register_dependency($dependency)
            .map_err(|e| ::std::format!("Dependency registration failed: {}", e))?
    };
}
182
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Mutex, MutexGuard, PoisonError};

    /// Serializes access to the process-global registry.
    ///
    /// The test harness runs tests on several threads by default, and every test here clears
    /// and repopulates the same global registry. Without this lock, one test's
    /// `clear_dependencies()` or registration can interleave with another test's assertions.
    static REGISTRY_LOCK: Mutex<()> = Mutex::new(());

    /// Takes the registry lock, makes sure the registry exists, and empties it.
    ///
    /// The returned guard must be kept alive for the whole test.
    #[allow(deprecated)] // exercises the public (legacy) initializer on purpose
    fn isolated_registry() -> MutexGuard<'static, ()> {
        // A panicking test poisons the lock; the registry is cleared below anyway, so
        // recover the guard instead of failing every later test too.
        let guard = REGISTRY_LOCK
            .lock()
            .unwrap_or_else(PoisonError::into_inner);
        init_dependency_injection();
        clear_dependencies().unwrap();
        guard
    }

    struct TestService {
        name: String,
    }

    impl TestService {
        fn new(name: &str) -> Self {
            Self {
                name: name.to_owned(),
            }
        }

        fn name(&self) -> &str {
            &self.name
        }
    }

    #[test]
    fn test_dependency_injection() {
        let _guard = isolated_registry();

        register_dependency(TestService::new("test")).unwrap();

        let injected: Arc<TestService> = inject().unwrap();
        assert_eq!(injected.name(), "test");

        assert!(has_dependency::<TestService>());
        assert!(!has_dependency::<String>());
    }

    #[test]
    fn test_duplicate_registration() {
        let _guard = isolated_registry();

        assert!(register_dependency(TestService::new("first")).is_ok());
        // The second registration must be rejected and must not replace the first value.
        assert!(register_dependency(TestService::new("second")).is_err());

        let injected: Arc<TestService> = inject().unwrap();
        assert_eq!(injected.name(), "first");
    }

    #[test]
    fn test_missing_dependency() {
        let _guard = isolated_registry();

        let result: Result<Arc<TestService>, ProviderError> = inject();
        assert!(result.is_err());
        assert!(!has_dependency::<TestService>());
    }
}
245}