dioxus_provider/
injection.rs1use crate::errors::ProviderError;
9use std::any::{Any, TypeId};
10use std::collections::HashMap;
11use std::sync::{Arc, OnceLock, RwLock};
12
13static DEPENDENCY_REGISTRY: OnceLock<DependencyRegistry> = OnceLock::new();
15
16pub struct DependencyRegistry {
18 dependencies: RwLock<HashMap<TypeId, Arc<dyn Any + Send + Sync>>>,
19}
20
21impl DependencyRegistry {
22 fn new() -> Self {
24 Self {
25 dependencies: RwLock::new(HashMap::new()),
26 }
27 }
28
29 pub fn register<T: Send + Sync + 'static>(&self, dependency: T) -> Result<(), ProviderError> {
31 let type_id = TypeId::of::<T>();
32 let mut deps = self.dependencies.write().map_err(|_| {
33 ProviderError::DependencyInjection(
34 "Failed to acquire write lock on dependencies".to_string(),
35 )
36 })?;
37
38 if deps.contains_key(&type_id) {
39 return Err(ProviderError::DependencyInjection(format!(
40 "Dependency of type {} already registered",
41 std::any::type_name::<T>()
42 )));
43 }
44
45 deps.insert(type_id, Arc::new(dependency));
46 Ok(())
47 }
48
49 pub fn get<T: Send + Sync + 'static>(&self) -> Result<Arc<T>, ProviderError> {
51 let type_id = TypeId::of::<T>();
52 let deps = self.dependencies.read().map_err(|_| {
53 ProviderError::DependencyInjection(
54 "Failed to acquire read lock on dependencies".to_string(),
55 )
56 })?;
57
58 let dependency = deps.get(&type_id).ok_or_else(|| {
59 ProviderError::DependencyInjection(format!(
60 "Dependency of type {} not found. Make sure to register it with register_dependency() first.",
61 std::any::type_name::<T>()
62 ))
63 })?;
64
65 dependency.clone().downcast::<T>().map_err(|_| {
66 ProviderError::DependencyInjection(format!(
67 "Failed to downcast dependency of type {}",
68 std::any::type_name::<T>()
69 ))
70 })
71 }
72
73 pub fn contains<T: Send + Sync + 'static>(&self) -> bool {
75 let type_id = TypeId::of::<T>();
76 self.dependencies
77 .read()
78 .map(|deps| deps.contains_key(&type_id))
79 .unwrap_or(false)
80 }
81
82 pub fn clear(&self) -> Result<(), ProviderError> {
84 let mut deps = self.dependencies.write().map_err(|_| {
85 ProviderError::DependencyInjection(
86 "Failed to acquire write lock on dependencies".to_string(),
87 )
88 })?;
89 deps.clear();
90 Ok(())
91 }
92
93 pub fn list_types(&self) -> Result<Vec<String>, ProviderError> {
95 let deps = self.dependencies.read().map_err(|_| {
96 ProviderError::DependencyInjection(
97 "Failed to acquire read lock on dependencies".to_string(),
98 )
99 })?;
100
101 Ok(vec![format!("{} dependencies registered", deps.len())])
104 }
105}
106
107pub fn init_dependency_injection() {
109 DEPENDENCY_REGISTRY.get_or_init(DependencyRegistry::new);
110}
111
112pub fn register_dependency<T: Send + Sync + 'static>(dependency: T) -> Result<(), ProviderError> {
114 let registry = DEPENDENCY_REGISTRY.get().ok_or_else(|| {
115 ProviderError::DependencyInjection(
116 "Dependency registry not initialized. Call init_dependency_injection() first."
117 .to_string(),
118 )
119 })?;
120 registry.register(dependency)
121}
122
123pub fn inject<T: Send + Sync + 'static>() -> Result<Arc<T>, ProviderError> {
125 let registry = DEPENDENCY_REGISTRY.get().ok_or_else(|| {
126 ProviderError::DependencyInjection(
127 "Dependency registry not initialized. Call init_dependency_injection() first."
128 .to_string(),
129 )
130 })?;
131 registry.get()
132}
133
134pub fn has_dependency<T: Send + Sync + 'static>() -> bool {
136 DEPENDENCY_REGISTRY
137 .get()
138 .map(|registry| registry.contains::<T>())
139 .unwrap_or(false)
140}
141
142pub fn clear_dependencies() -> Result<(), ProviderError> {
144 let registry = DEPENDENCY_REGISTRY.get().ok_or_else(|| {
145 ProviderError::DependencyInjection("Dependency registry not initialized".to_string())
146 })?;
147 registry.clear()
148}
149
/// Resolves the dependency of the given type, or returns early from the enclosing
/// function.
///
/// Expands to a call to `$crate::injection::inject::<$type>()`. On success it
/// evaluates to the resolved `Arc`. On failure the error is turned into a `String`
/// prefixed with `"Dependency injection failed: "` and propagated with `?`, so the
/// enclosing function's error type must implement `From<String>`.
#[macro_export]
macro_rules! inject {
    ($dep_ty:ty) => {{
        let resolved = $crate::injection::inject::<$dep_ty>();
        resolved.map_err(|err| format!("Dependency injection failed: {}", err))?
    }};
}
158
/// Registers a dependency, or returns early from the enclosing function.
///
/// Expands to a call to `$crate::injection::register_dependency($dependency)` and
/// evaluates to `()` on success. On failure the error is turned into a `String`
/// prefixed with `"Dependency registration failed: "` and propagated with `?`, so
/// the enclosing function's error type must implement `From<String>`.
#[macro_export]
macro_rules! register {
    ($dependency:expr) => {{
        let outcome = $crate::injection::register_dependency($dependency);
        outcome.map_err(|err| format!("Dependency registration failed: {}", err))?
    }};
}
167
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Mutex, MutexGuard};

    /// Serializes every test that touches the process-wide registry.
    ///
    /// The default test harness runs tests on several threads, and every test here
    /// clears and repopulates the same global registry. Without this lock, one
    /// test's `clear_dependencies()` or `register_dependency()` can interleave with
    /// another test's assertions. For example, `test_missing_dependency` could
    /// observe the `TestService` registered by `test_dependency_injection`.
    static REGISTRY_LOCK: Mutex<()> = Mutex::new(());

    /// Locks the global registry for the calling test and resets it to empty.
    ///
    /// Bind the returned guard to a named variable (not `_`) so that it lives until
    /// the end of the test.
    fn exclusive_empty_registry() -> MutexGuard<'static, ()> {
        // A failing test panics while holding the guard and poisons the mutex. The
        // protected data is `()`, so recover the guard instead of cascading that
        // failure into every later test.
        let guard = REGISTRY_LOCK
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        init_dependency_injection();
        clear_dependencies().unwrap();
        guard
    }

    struct TestService {
        name: String,
    }

    impl TestService {
        fn new(name: String) -> Self {
            Self { name }
        }

        fn get_name(&self) -> &str {
            &self.name
        }
    }

    #[test]
    fn test_dependency_injection() {
        let _registry = exclusive_empty_registry();

        let service = TestService::new("test".to_string());
        register_dependency(service).unwrap();

        let injected: Arc<TestService> = inject().unwrap();
        assert_eq!(injected.get_name(), "test");

        assert!(has_dependency::<TestService>());
        assert!(!has_dependency::<String>());
    }

    #[test]
    fn test_duplicate_registration() {
        let _registry = exclusive_empty_registry();

        let service1 = TestService::new("first".to_string());
        let service2 = TestService::new("second".to_string());

        assert!(register_dependency(service1).is_ok());
        assert!(register_dependency(service2).is_err());

        // The rejected duplicate must not replace the original registration.
        let injected: Arc<TestService> = inject().unwrap();
        assert_eq!(injected.get_name(), "first");
    }

    #[test]
    fn test_missing_dependency() {
        let _registry = exclusive_empty_registry();

        let result: Result<Arc<TestService>, ProviderError> = inject();
        assert!(result.is_err());
    }
}