// dioxus_provider/injection.rs

/*!
 * Global Dependency Injection System
 *
 * Provides a type-safe way to register and access shared dependencies
 * that don't fit well as provider parameters (e.g., API clients, databases).
 */

use crate::errors::ProviderError;
use std::any::{Any, TypeId};
use std::collections::{hash_map::Entry, HashMap};
use std::sync::{Arc, OnceLock, RwLock};

13/// Global registry for dependency injection
14static DEPENDENCY_REGISTRY: OnceLock<DependencyRegistry> = OnceLock::new();
15
16/// Registry that holds all injected dependencies
17pub struct DependencyRegistry {
18    dependencies: RwLock<HashMap<TypeId, Arc<dyn Any + Send + Sync>>>,
19}
20
21impl DependencyRegistry {
22    /// Create a new dependency registry
23    fn new() -> Self {
24        Self {
25            dependencies: RwLock::new(HashMap::new()),
26        }
27    }
28
29    /// Register a dependency of type T
30    pub fn register<T: Send + Sync + 'static>(&self, dependency: T) -> Result<(), ProviderError> {
31        let type_id = TypeId::of::<T>();
32        let mut deps = self.dependencies.write().map_err(|_| {
33            ProviderError::DependencyInjection(
34                "Failed to acquire write lock on dependencies".to_string(),
35            )
36        })?;
37
38        if deps.contains_key(&type_id) {
39            return Err(ProviderError::DependencyInjection(format!(
40                "Dependency of type {} already registered",
41                std::any::type_name::<T>()
42            )));
43        }
44
45        deps.insert(type_id, Arc::new(dependency));
46        Ok(())
47    }
48
49    /// Get a dependency of type T
50    pub fn get<T: Send + Sync + 'static>(&self) -> Result<Arc<T>, ProviderError> {
51        let type_id = TypeId::of::<T>();
52        let deps = self.dependencies.read().map_err(|_| {
53            ProviderError::DependencyInjection(
54                "Failed to acquire read lock on dependencies".to_string(),
55            )
56        })?;
57
58        let dependency = deps.get(&type_id).ok_or_else(|| {
59            ProviderError::DependencyInjection(format!(
60                "Dependency of type {} not found. Make sure to register it with register_dependency() first.",
61                std::any::type_name::<T>()
62            ))
63        })?;
64
65        dependency.clone().downcast::<T>().map_err(|_| {
66            ProviderError::DependencyInjection(format!(
67                "Failed to downcast dependency of type {}",
68                std::any::type_name::<T>()
69            ))
70        })
71    }
72
73    /// Check if a dependency of type T is registered
74    pub fn contains<T: Send + Sync + 'static>(&self) -> bool {
75        let type_id = TypeId::of::<T>();
76        self.dependencies
77            .read()
78            .map(|deps| deps.contains_key(&type_id))
79            .unwrap_or(false)
80    }
81
82    /// Clear all dependencies (mainly for testing)
83    pub fn clear(&self) -> Result<(), ProviderError> {
84        let mut deps = self.dependencies.write().map_err(|_| {
85            ProviderError::DependencyInjection(
86                "Failed to acquire write lock on dependencies".to_string(),
87            )
88        })?;
89        deps.clear();
90        Ok(())
91    }
92
93    /// Get all registered dependency type names (for debugging)
94    pub fn list_types(&self) -> Result<Vec<String>, ProviderError> {
95        let deps = self.dependencies.read().map_err(|_| {
96            ProviderError::DependencyInjection(
97                "Failed to acquire read lock on dependencies".to_string(),
98            )
99        })?;
100
101        // Note: We can't easily get type names from TypeId,
102        // so this is mainly useful for debugging count
103        Ok(vec![format!("{} dependencies registered", deps.len())])
104    }
105}
106
107/// Initialize the global dependency registry
108///
109/// # Deprecated
110/// Use `init()` or `ProviderConfig::new().with_dependency_injection().init()` instead.
111/// The new initialization system automatically handles dependency injection setup.
112#[deprecated(
113    since = "0.1.0",
114    note = "Use init() or ProviderConfig::new().with_dependency_injection().init() instead"
115)]
116pub fn init_dependency_injection() {
117    DEPENDENCY_REGISTRY.get_or_init(DependencyRegistry::new);
118}
119
120/// Ensure the dependency injection registry is initialized (non-deprecated helper)
121///
122/// This is used internally by the new unified initialization path.
123pub(crate) fn ensure_dependency_injection_initialized() {
124    DEPENDENCY_REGISTRY.get_or_init(DependencyRegistry::new);
125}
126
127/// Register a global dependency
128pub fn register_dependency<T: Send + Sync + 'static>(dependency: T) -> Result<(), ProviderError> {
129    let registry = DEPENDENCY_REGISTRY.get().ok_or_else(|| {
130        ProviderError::DependencyInjection(
131            "Dependency registry not initialized. Call init_dependency_injection() first."
132                .to_string(),
133        )
134    })?;
135    registry.register(dependency)
136}
137
138/// Get a global dependency
139pub fn inject<T: Send + Sync + 'static>() -> Result<Arc<T>, ProviderError> {
140    let registry = DEPENDENCY_REGISTRY.get().ok_or_else(|| {
141        ProviderError::DependencyInjection(
142            "Dependency registry not initialized. Call init_dependency_injection() first."
143                .to_string(),
144        )
145    })?;
146    registry.get()
147}
148
149/// Check if a dependency is registered
150pub fn has_dependency<T: Send + Sync + 'static>() -> bool {
151    DEPENDENCY_REGISTRY
152        .get()
153        .map(|registry| registry.contains::<T>())
154        .unwrap_or(false)
155}
156
157/// Clear all dependencies (mainly for testing)
158pub fn clear_dependencies() -> Result<(), ProviderError> {
159    let registry = DEPENDENCY_REGISTRY.get().ok_or_else(|| {
160        ProviderError::DependencyInjection("Dependency registry not initialized".to_string())
161    })?;
162    registry.clear()
163}
164
/// Macro for easy dependency injection in providers.
///
/// `inject!(T)` calls `$crate::injection::inject::<T>()` and yields the
/// successful value. On failure it returns early via `?` with a `String` error
/// prefixed by "Dependency injection failed: ", so the calling function's error
/// type must be constructible from a `String`.
#[macro_export]
macro_rules! inject {
    ($dep_type:ty) => {
        $crate::injection::inject::<$dep_type>()
            .map_err(|err| format!("Dependency injection failed: {}", err))?
    };
}

/// Macro for registering dependencies with error handling.
///
/// `register!(value)` calls `$crate::injection::register_dependency(value)`. On
/// failure it returns early via `?` with a `String` error prefixed by
/// "Dependency registration failed: ", so the calling function's error type must
/// be constructible from a `String`.
#[macro_export]
macro_rules! register {
    ($value:expr) => {
        $crate::injection::register_dependency($value)
            .map_err(|err| format!("Dependency registration failed: {}", err))?
    };
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Simple service used as a dependency in the tests below.
    struct TestService {
        name: String,
    }

    impl TestService {
        fn new(name: &str) -> Self {
            Self {
                name: name.to_string(),
            }
        }

        fn name(&self) -> &str {
            &self.name
        }
    }

    // Registry behaviour is tested on private `DependencyRegistry` instances
    // rather than the process-wide one. The test harness runs tests in parallel,
    // so tests that cleared and re-populated the global registry raced with each
    // other.

    #[test]
    fn registry_register_and_get() {
        let registry = DependencyRegistry::new();
        registry.register(TestService::new("test")).unwrap();

        let injected = registry.get::<TestService>().unwrap();
        assert_eq!(injected.name(), "test");
        assert!(registry.contains::<TestService>());
        assert!(!registry.contains::<String>());
    }

    #[test]
    fn registry_rejects_duplicate_registration() {
        let registry = DependencyRegistry::new();
        assert!(registry.register(TestService::new("first")).is_ok());
        assert!(registry.register(TestService::new("second")).is_err());
        // The first registration wins.
        assert_eq!(registry.get::<TestService>().unwrap().name(), "first");
    }

    #[test]
    fn registry_missing_dependency() {
        let registry = DependencyRegistry::new();
        assert!(registry.get::<TestService>().is_err());
    }

    #[test]
    fn registry_clear_removes_everything() {
        let registry = DependencyRegistry::new();
        registry.register(TestService::new("test")).unwrap();
        registry.clear().unwrap();
        assert!(!registry.contains::<TestService>());
        assert!(registry.get::<TestService>().is_err());
    }

    #[test]
    fn global_register_and_inject() {
        // A type private to this test, so no other test can register or observe
        // it. This test deliberately does not call `clear_dependencies()`, which
        // would race with any other test using the global registry.
        struct GlobalOnlyService(u32);

        ensure_dependency_injection_initialized();
        assert!(!has_dependency::<GlobalOnlyService>());
        register_dependency(GlobalOnlyService(42)).unwrap();
        assert!(has_dependency::<GlobalOnlyService>());

        let injected: Arc<GlobalOnlyService> = inject().unwrap();
        assert_eq!(injected.0, 42);
    }

    #[test]
    #[allow(deprecated)] // exercises the deprecated entry point on purpose
    fn deprecated_init_is_idempotent() {
        init_dependency_injection();
        init_dependency_injection();
        assert!(DEPENDENCY_REGISTRY.get().is_some());
    }
}