dioxus_provider/
injection.rs

/*!
 * Global Dependency Injection System
 *
 * Provides a type-safe way to register and access shared dependencies
 * that don't fit well as provider parameters (e.g., API clients, databases).
 * Call `init_dependency_injection()` once before registering or injecting.
 */
7
8use std::any::{Any, TypeId};
9use std::collections::HashMap;
10use std::sync::{Arc, OnceLock, RwLock};
11
/// Process-wide registry backing the free functions in this module.
///
/// Stays uninitialized until [`init_dependency_injection`] is called.
static DEPENDENCY_REGISTRY: OnceLock<DependencyRegistry> = OnceLock::new();

/// One registered dependency plus the type name captured at registration time
/// (a `TypeId` alone cannot be turned back into a readable name).
struct Registration {
    type_name: &'static str,
    value: Arc<dyn Any + Send + Sync>,
}

/// Registry that holds all injected dependencies, keyed by their concrete type.
///
/// At most one value per type can be registered. Values are stored behind
/// `Arc`, so [`DependencyRegistry::get`] hands out shared handles.
pub struct DependencyRegistry {
    dependencies: RwLock<HashMap<TypeId, Registration>>,
}

impl DependencyRegistry {
    /// Create a new, empty dependency registry.
    fn new() -> Self {
        Self {
            dependencies: RwLock::new(HashMap::new()),
        }
    }

    /// Register a dependency of type `T`.
    ///
    /// # Errors
    ///
    /// Returns `Err` if a value of type `T` is already registered (the existing
    /// value is kept) or if the internal lock is poisoned.
    pub fn register<T: Send + Sync + 'static>(&self, dependency: T) -> Result<(), String> {
        use std::collections::hash_map::Entry;

        let mut deps = self
            .dependencies
            .write()
            .map_err(|_| "Failed to acquire write lock on dependencies")?;

        // One hash lookup covers both the duplicate check and the insert.
        let Entry::Vacant(slot) = deps.entry(TypeId::of::<T>()) else {
            return Err(format!(
                "Dependency of type {} already registered",
                std::any::type_name::<T>()
            ));
        };
        slot.insert(Registration {
            type_name: std::any::type_name::<T>(),
            value: Arc::new(dependency),
        });
        Ok(())
    }

    /// Get a shared handle to the dependency of type `T`.
    ///
    /// # Errors
    ///
    /// Returns `Err` if no value of type `T` is registered, if the stored value
    /// cannot be downcast to `T`, or if the internal lock is poisoned.
    pub fn get<T: Send + Sync + 'static>(&self) -> Result<Arc<T>, String> {
        let deps = self
            .dependencies
            .read()
            .map_err(|_| "Failed to acquire read lock on dependencies")?;

        let registration = deps.get(&TypeId::of::<T>()).ok_or_else(|| {
            format!(
                "Dependency of type {} not found",
                std::any::type_name::<T>()
            )
        })?;

        Arc::clone(&registration.value)
            .downcast::<T>()
            .map_err(|_| {
                format!(
                    "Failed to downcast dependency of type {}",
                    std::any::type_name::<T>()
                )
            })
    }

    /// Check if a dependency of type `T` is registered.
    ///
    /// Returns `false` both when `T` is absent and when the lock is poisoned.
    pub fn contains<T: Send + Sync + 'static>(&self) -> bool {
        self.dependencies
            .read()
            .map(|deps| deps.contains_key(&TypeId::of::<T>()))
            .unwrap_or(false)
    }

    /// Remove all dependencies (mainly for testing).
    ///
    /// # Errors
    ///
    /// Returns `Err` if the internal lock is poisoned.
    pub fn clear(&self) -> Result<(), String> {
        let mut deps = self
            .dependencies
            .write()
            .map_err(|_| "Failed to acquire write lock on dependencies")?;
        let removed = std::mem::take(&mut *deps);
        // Release the write lock before running the dependencies' destructors.
        // A panicking `Drop` then cannot poison the registry, and a `Drop` that
        // touches the registry does not wait on a lock this thread already holds.
        drop(deps);
        drop(removed);
        Ok(())
    }

    /// Get the type names of all registered dependencies (for debugging).
    ///
    /// Names come from [`std::any::type_name`], captured when each dependency was
    /// registered, and are returned sorted so the output is deterministic. The
    /// exact text of a type name is not guaranteed stable across compiler
    /// versions, so use this for diagnostics only.
    ///
    /// # Errors
    ///
    /// Returns `Err` if the internal lock is poisoned.
    pub fn list_types(&self) -> Result<Vec<String>, String> {
        let deps = self
            .dependencies
            .read()
            .map_err(|_| "Failed to acquire read lock on dependencies")?;

        let mut names: Vec<String> = deps
            .values()
            .map(|registration| registration.type_name.to_owned())
            .collect();
        drop(deps);
        names.sort_unstable();
        Ok(names)
    }
}
101
102/// Initialize the global dependency registry
103pub fn init_dependency_injection() {
104    DEPENDENCY_REGISTRY.get_or_init(DependencyRegistry::new);
105}
106
107/// Register a global dependency
108pub fn register_dependency<T: Send + Sync + 'static>(dependency: T) -> Result<(), String> {
109    let registry = DEPENDENCY_REGISTRY
110        .get()
111        .ok_or("Dependency registry not initialized. Call init_dependency_injection() first.")?;
112    registry.register(dependency)
113}
114
115/// Get a global dependency
116pub fn inject<T: Send + Sync + 'static>() -> Result<Arc<T>, String> {
117    let registry = DEPENDENCY_REGISTRY
118        .get()
119        .ok_or("Dependency registry not initialized. Call init_dependency_injection() first.")?;
120    registry.get()
121}
122
123/// Check if a dependency is registered
124pub fn has_dependency<T: Send + Sync + 'static>() -> bool {
125    DEPENDENCY_REGISTRY
126        .get()
127        .map(|registry| registry.contains::<T>())
128        .unwrap_or(false)
129}
130
131/// Clear all dependencies (mainly for testing)
132pub fn clear_dependencies() -> Result<(), String> {
133    let registry = DEPENDENCY_REGISTRY
134        .get()
135        .ok_or("Dependency registry not initialized")?;
136    registry.clear()
137}
138
/// Inject a global dependency of the given type, propagating failure with `?`.
///
/// Expands to `$crate::injection::inject::<$type>()` followed by `?`. On
/// failure the error `String` is prefixed with `"Dependency injection failed: "`,
/// so the macro can only be used where `?` can convert a `String` error (e.g. a
/// function returning `Result<_, String>`).
#[macro_export]
macro_rules! inject {
    ($type:ty) => {
        $crate::injection::inject::<$type>()
            .map_err(|e| ::std::format!("Dependency injection failed: {}", e))?
    };
}

/// Register a global dependency, propagating failure with `?`.
///
/// Expands to `$crate::injection::register_dependency($dependency)` followed by
/// `?`. On failure the error `String` is prefixed with
/// `"Dependency registration failed: "`, so the macro can only be used where `?`
/// can convert a `String` error.
#[macro_export]
macro_rules! register {
    ($dependency:expr) => {
        $crate::injection::register_dependency($dependency)
            .map_err(|e| ::std::format!("Dependency registration failed: {}", e))?
    };
}
156
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Mutex, MutexGuard};

    /// The registry is process-global and the test harness runs tests in
    /// parallel, so every test here holds this lock for its whole body.
    /// NOTE(review): this only serializes tests in this module; other modules'
    /// tests that use the global registry would need to share the same lock.
    static REGISTRY_LOCK: Mutex<()> = Mutex::new(());

    /// Acquire the test lock, initialize the registry and empty it.
    ///
    /// A poisoned lock (an earlier test panicked) is recovered so one failing
    /// test does not cascade into every later one.
    fn fresh_registry() -> MutexGuard<'static, ()> {
        let guard = REGISTRY_LOCK
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        init_dependency_injection();
        clear_dependencies().unwrap();
        guard
    }

    struct TestService {
        name: String,
    }

    impl TestService {
        fn new(name: &str) -> Self {
            Self {
                name: name.to_owned(),
            }
        }

        fn name(&self) -> &str {
            &self.name
        }
    }

    #[test]
    fn test_dependency_injection() {
        let _guard = fresh_registry();

        register_dependency(TestService::new("test")).unwrap();

        let injected: Arc<TestService> = inject().unwrap();
        assert_eq!(injected.name(), "test");

        assert!(has_dependency::<TestService>());
        assert!(!has_dependency::<String>());
    }

    #[test]
    fn test_duplicate_registration() {
        let _guard = fresh_registry();

        // First registration should succeed, the second should fail...
        assert!(register_dependency(TestService::new("first")).is_ok());
        assert!(register_dependency(TestService::new("second")).is_err());

        // ...and leave the original registration untouched.
        let injected: Arc<TestService> = inject().unwrap();
        assert_eq!(injected.name(), "first");
    }

    #[test]
    fn test_missing_dependency() {
        let _guard = fresh_registry();

        let result: Result<Arc<TestService>, String> = inject();
        assert!(result.is_err());
        assert!(!has_dependency::<TestService>());
    }
}
219}