dioxus_provider/
injection.rs

1/*!
2 * Global Dependency Injection System
3 *
4 * Provides a type-safe way to register and access shared dependencies
5 * that don't fit well as provider parameters (e.g., API clients, databases).
6 */
7
8use crate::errors::ProviderError;
9use std::any::{Any, TypeId};
10use std::collections::HashMap;
11use std::sync::{Arc, OnceLock, RwLock};
12
13/// Global registry for dependency injection
14static DEPENDENCY_REGISTRY: OnceLock<DependencyRegistry> = OnceLock::new();
15
16/// Registry that holds all injected dependencies
17pub struct DependencyRegistry {
18    dependencies: RwLock<HashMap<TypeId, Arc<dyn Any + Send + Sync>>>,
19}
20
21impl DependencyRegistry {
22    /// Create a new dependency registry
23    fn new() -> Self {
24        Self {
25            dependencies: RwLock::new(HashMap::new()),
26        }
27    }
28
29    /// Register a dependency of type T
30    pub fn register<T: Send + Sync + 'static>(&self, dependency: T) -> Result<(), ProviderError> {
31        let type_id = TypeId::of::<T>();
32        let mut deps = self.dependencies.write().map_err(|_| {
33            ProviderError::DependencyInjection(
34                "Failed to acquire write lock on dependencies".to_string(),
35            )
36        })?;
37
38        if deps.contains_key(&type_id) {
39            return Err(ProviderError::DependencyInjection(format!(
40                "Dependency of type {} already registered",
41                std::any::type_name::<T>()
42            )));
43        }
44
45        deps.insert(type_id, Arc::new(dependency));
46        Ok(())
47    }
48
49    /// Get a dependency of type T
50    pub fn get<T: Send + Sync + 'static>(&self) -> Result<Arc<T>, ProviderError> {
51        let type_id = TypeId::of::<T>();
52        let deps = self.dependencies.read().map_err(|_| {
53            ProviderError::DependencyInjection(
54                "Failed to acquire read lock on dependencies".to_string(),
55            )
56        })?;
57
58        let dependency = deps.get(&type_id).ok_or_else(|| {
59            ProviderError::DependencyInjection(format!(
60                "Dependency of type {} not found. Make sure to register it with register_dependency() first.",
61                std::any::type_name::<T>()
62            ))
63        })?;
64
65        dependency.clone().downcast::<T>().map_err(|_| {
66            ProviderError::DependencyInjection(format!(
67                "Failed to downcast dependency of type {}",
68                std::any::type_name::<T>()
69            ))
70        })
71    }
72
73    /// Check if a dependency of type T is registered
74    pub fn contains<T: Send + Sync + 'static>(&self) -> bool {
75        let type_id = TypeId::of::<T>();
76        self.dependencies
77            .read()
78            .map(|deps| deps.contains_key(&type_id))
79            .unwrap_or(false)
80    }
81
82    /// Clear all dependencies (mainly for testing)
83    pub fn clear(&self) -> Result<(), ProviderError> {
84        let mut deps = self.dependencies.write().map_err(|_| {
85            ProviderError::DependencyInjection(
86                "Failed to acquire write lock on dependencies".to_string(),
87            )
88        })?;
89        deps.clear();
90        Ok(())
91    }
92
93    /// Get all registered dependency type names (for debugging)
94    pub fn list_types(&self) -> Result<Vec<String>, ProviderError> {
95        let deps = self.dependencies.read().map_err(|_| {
96            ProviderError::DependencyInjection(
97                "Failed to acquire read lock on dependencies".to_string(),
98            )
99        })?;
100
101        // Note: We can't easily get type names from TypeId,
102        // so this is mainly useful for debugging count
103        Ok(vec![format!("{} dependencies registered", deps.len())])
104    }
105}
106
107/// Initialize the global dependency registry
108pub fn init_dependency_injection() {
109    DEPENDENCY_REGISTRY.get_or_init(DependencyRegistry::new);
110}
111
112/// Register a global dependency
113pub fn register_dependency<T: Send + Sync + 'static>(dependency: T) -> Result<(), ProviderError> {
114    let registry = DEPENDENCY_REGISTRY.get().ok_or_else(|| {
115        ProviderError::DependencyInjection(
116            "Dependency registry not initialized. Call init_dependency_injection() first."
117                .to_string(),
118        )
119    })?;
120    registry.register(dependency)
121}
122
123/// Get a global dependency
124pub fn inject<T: Send + Sync + 'static>() -> Result<Arc<T>, ProviderError> {
125    let registry = DEPENDENCY_REGISTRY.get().ok_or_else(|| {
126        ProviderError::DependencyInjection(
127            "Dependency registry not initialized. Call init_dependency_injection() first."
128                .to_string(),
129        )
130    })?;
131    registry.get()
132}
133
134/// Check if a dependency is registered
135pub fn has_dependency<T: Send + Sync + 'static>() -> bool {
136    DEPENDENCY_REGISTRY
137        .get()
138        .map(|registry| registry.contains::<T>())
139        .unwrap_or(false)
140}
141
142/// Clear all dependencies (mainly for testing)
143pub fn clear_dependencies() -> Result<(), ProviderError> {
144    let registry = DEPENDENCY_REGISTRY.get().ok_or_else(|| {
145        ProviderError::DependencyInjection("Dependency registry not initialized".to_string())
146    })?;
147    registry.clear()
148}
149
/// Macro for easy dependency injection in providers.
///
/// `inject!(T)` evaluates to the `Arc<T>` returned by `injection::inject::<T>()`. On
/// failure, the error is rendered as `"Dependency injection failed: {error}"` and
/// propagated with `?`. The enclosing function's error type must therefore be
/// constructible from a `String`.
#[macro_export]
macro_rules! inject {
    ($type:ty) => {{
        let injected = $crate::injection::inject::<$type>();
        injected.map_err(|err| format!("Dependency injection failed: {}", err))?
    }};
}
158
/// Macro for registering dependencies with error handling.
///
/// `register!(value)` calls `injection::register_dependency(value)`. On failure, the
/// error is rendered as `"Dependency registration failed: {error}"` and propagated with
/// `?`. The enclosing function's error type must therefore be constructible from a
/// `String`.
#[macro_export]
macro_rules! register {
    ($dependency:expr) => {{
        let outcome = $crate::injection::register_dependency($dependency);
        outcome.map_err(|err| format!("Dependency registration failed: {}", err))?
    }};
}
167
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Mutex, MutexGuard};

    /// Serializes every test that touches the process-global registry.
    ///
    /// The test harness runs tests on parallel threads, and all of these tests share
    /// `DEPENDENCY_REGISTRY`. Without this lock one test's `clear_dependencies()` or
    /// `register_dependency()` can interleave with another test's assertions. For example,
    /// `test_missing_dependency` could observe the `TestService` registered by
    /// `test_dependency_injection`, or `test_duplicate_registration` could have its
    /// first registration rejected.
    static REGISTRY_LOCK: Mutex<()> = Mutex::new(());

    /// Take exclusive access to the global registry and reset it to an empty state.
    ///
    /// Keep the returned guard alive for the whole test: bind it to a named variable such
    /// as `_guard`, not to `_`, which would drop it immediately.
    fn fresh_registry() -> MutexGuard<'static, ()> {
        // A failing test panics while holding the guard and poisons the mutex. The `()`
        // payload has no state to corrupt, so recover instead of failing every later test.
        let guard = REGISTRY_LOCK
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        init_dependency_injection();
        clear_dependencies().unwrap();
        guard
    }

    struct TestService {
        name: String,
    }

    impl TestService {
        fn new(name: &str) -> Self {
            Self {
                name: name.to_string(),
            }
        }

        fn name(&self) -> &str {
            &self.name
        }
    }

    #[test]
    fn test_dependency_injection() {
        let _guard = fresh_registry();

        // Register a dependency
        register_dependency(TestService::new("test")).unwrap();

        // Inject the dependency
        let injected: Arc<TestService> = inject().unwrap();
        assert_eq!(injected.name(), "test");

        // Check if dependency exists
        assert!(has_dependency::<TestService>());
        assert!(!has_dependency::<String>());
    }

    #[test]
    fn test_duplicate_registration() {
        let _guard = fresh_registry();

        // First registration should succeed
        assert!(register_dependency(TestService::new("first")).is_ok());

        // Second registration of the same type should fail and keep the first value
        assert!(register_dependency(TestService::new("second")).is_err());
        let injected: Arc<TestService> = inject().unwrap();
        assert_eq!(injected.name(), "first");
    }

    #[test]
    fn test_missing_dependency() {
        let _guard = fresh_registry();

        // Try to inject non-existent dependency
        let result: Result<Arc<TestService>, ProviderError> = inject();
        assert!(result.is_err());
    }
}