flow_di/
container.rs

1use dashmap::DashMap;
2use once_cell::sync::Lazy;
3use std::any::{Any, TypeId};
4use std::collections::HashMap;
5use std::sync::{Arc, Mutex, RwLock};
6
7use crate::{
8    descriptor::ServiceProvider as DescriptorServiceProvider, DiError, DiResult, Lifetime,
9    ServiceDescriptor, ServiceKey,
10};
11
12/// Global singleton service storage
13static SINGLETON_SERVICES: Lazy<DashMap<ServiceKey, Arc<dyn Any + Send + Sync>>> =
14    Lazy::new(DashMap::new);
15
16/// Dependency injection container
17pub struct Container {
18    /// Service descriptor storage
19    services: Arc<RwLock<HashMap<ServiceKey, ServiceDescriptor>>>,
20    /// Circular dependency detection stack
21    resolution_stack: Arc<Mutex<Vec<ServiceKey>>>,
22}
23
24impl Container {
25    /// Create a new container instance
26    pub fn new() -> Self {
27        Self {
28            services: Arc::new(RwLock::new(HashMap::new())),
29            resolution_stack: Arc::new(Mutex::new(Vec::new())),
30        }
31    }
32
33    /// Register a service descriptor
34    pub fn register(&self, descriptor: ServiceDescriptor) -> DiResult<()> {
35        let mut services = self
36            .services
37            .write()
38            .map_err(|_| DiError::generic("Failed to acquire services write lock"))?;
39
40        // Check if the same service key is already registered
41        if services.contains_key(&descriptor.service_key) {
42            return Err(DiError::Generic {
43                message: format!(
44                    "Service with key {:?} is already registered",
45                    descriptor.service_key
46                ),
47            });
48        }
49
50        services.insert(descriptor.service_key.clone(), descriptor);
51        Ok(())
52    }
53
54    /// Register a service descriptor (allow overwrite)
55    pub fn register_overwrite(&self, descriptor: ServiceDescriptor) -> DiResult<()> {
56        let mut services = self
57            .services
58            .write()
59            .map_err(|_| DiError::generic("Failed to acquire services write lock"))?;
60
61        services.insert(descriptor.service_key.clone(), descriptor);
62        Ok(())
63    }
64
65    /// Check if a service is registered
66    pub fn is_registered<T: 'static>(&self) -> DiResult<bool> {
67        let key = ServiceKey::of_type::<T>();
68        self.is_registered_with_key(&key)
69    }
70
71    /// 检查带键服务是否已注册
72    pub fn is_keyed_registered<T: 'static>(&self, name: &str) -> DiResult<bool> {
73        let key = ServiceKey::named::<T>(name);
74        self.is_registered_with_key(&key)
75    }
76
77    /// 检查指定键的服务是否已注册
78    pub fn is_registered_with_key(&self, key: &ServiceKey) -> DiResult<bool> {
79        let services = self
80            .services
81            .read()
82            .map_err(|_| DiError::generic("Failed to acquire services read lock"))?;
83
84        Ok(services.contains_key(key))
85    }
86
87    /// 获取服务描述符
88    fn get_descriptor(&self, key: &ServiceKey) -> DiResult<Option<ServiceDescriptor>> {
89        let services = self
90            .services
91            .read()
92            .map_err(|_| DiError::generic("Failed to acquire services read lock"))?;
93
94        Ok(services.get(key).cloned())
95    }
96
97    /// 构建服务提供者
98    pub fn build_provider(self) -> ServiceProvider {
99        ServiceProvider::new(Arc::new(self))
100    }
101
102    /// 构建默认方法,返回提供者
103    pub fn build(self) -> ServiceProvider {
104        self.build_provider()
105    }
106}
107
108impl Default for Container {
109    fn default() -> Self {
110        Self::new()
111    }
112}
113
114/// 作用域存储类型
115type ScopeStorage = Arc<RwLock<HashMap<ServiceKey, Arc<dyn Any + Send + Sync>>>>;
116
117/// 服务提供者
118pub struct ServiceProvider {
119    container: Arc<Container>,
120}
121
122impl ServiceProvider {
123    fn new(container: Arc<Container>) -> Self {
124        Self { container }
125    }
126
127    /// 获取指定类型的所有服务
128    pub fn get_services<T: 'static + Send + Sync>(&self) -> DiResult<Vec<Arc<T>>> {
129        let descriptors = self.get_all_descriptors_for_type::<T>()?;
130        let mut services = Vec::new();
131
132        for descriptor in descriptors {
133            if let Some(service) = self.resolve_service::<T>(&descriptor.service_key, None)? {
134                services.push(service);
135            }
136        }
137
138        Ok(services)
139    }
140
141    /// 创建服务作用域
142    pub fn create_scope(&self) -> DiResult<ServiceScope> {
143        ServiceScope::new(Arc::clone(&self.container))
144    }
145
146    /// 解析服务实例
147    fn resolve_service<T: 'static + Send + Sync>(
148        &self,
149        key: &ServiceKey,
150        scope_storage: Option<&ScopeStorage>,
151    ) -> DiResult<Option<Arc<T>>> {
152        // 开始解析
153        self.begin_resolution(key)?;
154
155        let result = self.internal_resolve_service::<T>(key, scope_storage);
156
157        // 结束解析
158        self.end_resolution(key)?;
159
160        result
161    }
162
163    /// 检测循环依赖
164    fn check_circular_dependency(&self, key: &ServiceKey) -> DiResult<()> {
165        let stack = self
166            .container
167            .resolution_stack
168            .lock()
169            .map_err(|_| DiError::generic("Failed to acquire resolution stack lock"))?;
170
171        if stack.contains(key) {
172            return Err(DiError::Generic {
173                message: format!("Circular dependency detected for service key: {key:?}"),
174            });
175        }
176
177        Ok(())
178    }
179
180    /// 开始解析服务(添加到循环依赖检测栈)
181    fn begin_resolution(&self, key: &ServiceKey) -> DiResult<()> {
182        self.check_circular_dependency(key)?;
183
184        let mut stack = self
185            .container
186            .resolution_stack
187            .lock()
188            .map_err(|_| DiError::generic("Failed to acquire resolution stack lock"))?;
189
190        stack.push(key.clone());
191        Ok(())
192    }
193
194    /// 结束解析服务(从循环依赖检测栈中移除)
195    fn end_resolution(&self, key: &ServiceKey) -> DiResult<()> {
196        let mut stack = self
197            .container
198            .resolution_stack
199            .lock()
200            .map_err(|_| DiError::generic("Failed to acquire resolution stack lock"))?;
201
202        if let Some(pos) = stack.iter().position(|k| k == key) {
203            stack.remove(pos);
204        }
205
206        Ok(())
207    }
208
209    /// 内部解析服务实例
210    fn internal_resolve_service<T: 'static + Send + Sync>(
211        &self,
212        key: &ServiceKey,
213        scope_storage: Option<&ScopeStorage>,
214    ) -> DiResult<Option<Arc<T>>> {
215        let descriptor = match self.container.get_descriptor(key)? {
216            Some(desc) => desc,
217            None => return Ok(None),
218        };
219
220        match descriptor.lifetime {
221            Lifetime::Singleton => self.resolve_singleton::<T>(&descriptor),
222            Lifetime::Scoped => match scope_storage {
223                Some(storage) => self.resolve_scoped::<T>(&descriptor, storage),
224                None => Err(DiError::Generic {
225                    message: format!("Scoped service cannot be resolved without a scope: {key:?}"),
226                }),
227            },
228            Lifetime::Transient => self.resolve_transient::<T>(&descriptor),
229        }
230    }
231
232    /// 解析单例服务
233    fn resolve_singleton<T: 'static + Send + Sync>(
234        &self,
235        descriptor: &ServiceDescriptor,
236    ) -> DiResult<Option<Arc<T>>> {
237        // 检查全局单例缓存
238        if let Some(cached) = SINGLETON_SERVICES.get(&descriptor.service_key) {
239            let any_arc = Arc::clone(&cached);
240            return self.cast_to_arc::<T>(any_arc);
241        }
242
243        // 创建新实例
244        let provider = ContainerServiceProvider::new(Arc::clone(&self.container), None);
245        let instance = descriptor.create_instance(&provider)?;
246
247        // 转换并缓存
248        let typed_instance = self.box_to_typed_arc::<T>(instance)?;
249        let any_arc: Arc<dyn Any + Send + Sync> = typed_instance.clone();
250        SINGLETON_SERVICES.insert(descriptor.service_key.clone(), any_arc);
251
252        Ok(Some(typed_instance))
253    }
254
255    /// 解析作用域服务
256    fn resolve_scoped<T: 'static + Send + Sync>(
257        &self,
258        descriptor: &ServiceDescriptor,
259        scope_storage: &ScopeStorage,
260    ) -> DiResult<Option<Arc<T>>> {
261        // 首先检查作用域缓存
262        {
263            let storage = scope_storage
264                .read()
265                .map_err(|_| DiError::generic("Failed to acquire scope storage read lock"))?;
266
267            if let Some(cached) = storage.get(&descriptor.service_key) {
268                let any_arc = Arc::clone(cached);
269                return self.cast_to_arc::<T>(any_arc);
270            }
271        }
272
273        // 创建新实例
274        let provider =
275            ContainerServiceProvider::new(Arc::clone(&self.container), Some(scope_storage.clone()));
276        let instance = descriptor.create_instance(&provider)?;
277
278        // 转换并缓存到作用域
279        let typed_instance = self.box_to_typed_arc::<T>(instance)?;
280        let any_arc: Arc<dyn Any + Send + Sync> = typed_instance.clone();
281
282        {
283            let mut storage = scope_storage
284                .write()
285                .map_err(|_| DiError::generic("Failed to acquire scope storage write lock"))?;
286            storage.insert(descriptor.service_key.clone(), any_arc);
287        }
288
289        Ok(Some(typed_instance))
290    }
291
292    /// 解析瞬时服务
293    fn resolve_transient<T: 'static + Send + Sync>(
294        &self,
295        descriptor: &ServiceDescriptor,
296    ) -> DiResult<Option<Arc<T>>> {
297        let provider = ContainerServiceProvider::new(Arc::clone(&self.container), None);
298        let instance = descriptor.create_instance(&provider)?;
299        let typed_instance = self.box_to_typed_arc::<T>(instance)?;
300        Ok(Some(typed_instance))
301    }
302
303    /// 将Box转换为Arc<T>
304    fn box_to_typed_arc<T: 'static + Send + Sync>(
305        &self,
306        instance: Box<dyn Any + Send + Sync>,
307    ) -> DiResult<Arc<T>> {
308        match instance.downcast::<T>() {
309            Ok(boxed) => Ok(Arc::new(*boxed)),
310            Err(_) => Err(DiError::type_casting_failed::<T>()),
311        }
312    }
313
314    /// 类型转换辅助方法
315    fn cast_to_arc<T: 'static + Send + Sync>(
316        &self,
317        any_arc: Arc<dyn Any + Send + Sync>,
318    ) -> DiResult<Option<Arc<T>>> {
319        // 尝试从Arc中提取
320        if let Ok(arc_t) = any_arc.downcast::<T>() {
321            return Ok(Some(arc_t));
322        }
323
324        Err(DiError::type_casting_failed::<T>())
325    }
326
327    /// 获取指定类型的所有服务描述符
328    fn get_all_descriptors_for_type<T: 'static + Send + Sync>(
329        &self,
330    ) -> DiResult<Vec<ServiceDescriptor>> {
331        let services = self
332            .container
333            .services
334            .read()
335            .map_err(|_| DiError::generic("Failed to acquire services read lock"))?;
336
337        let target_type_id = TypeId::of::<T>();
338        let descriptors: Vec<ServiceDescriptor> = services
339            .values()
340            .filter(|desc| desc.service_type == target_type_id)
341            .cloned()
342            .collect();
343
344        Ok(descriptors)
345    }
346}
347
348/// 容器内部服务提供者
349struct ContainerServiceProvider {
350    container: Arc<Container>,
351    scope_storage: Option<ScopeStorage>,
352}
353
354impl ContainerServiceProvider {
355    fn new(container: Arc<Container>, scope_storage: Option<ScopeStorage>) -> Self {
356        Self {
357            container,
358            scope_storage,
359        }
360    }
361}
362
363impl DescriptorServiceProvider for ServiceProvider {
364    fn get_service_raw(&self, key: &ServiceKey) -> DiResult<Option<Arc<dyn Any + Send + Sync>>> {
365        let inner_provider = ContainerServiceProvider::new(Arc::clone(&self.container), None);
366        inner_provider.get_service_raw(key)
367    }
368}
369
370impl DescriptorServiceProvider for ContainerServiceProvider {
371    fn get_service_raw(&self, key: &ServiceKey) -> DiResult<Option<Arc<dyn Any + Send + Sync>>> {
372        // 获取服务描述符
373        let descriptor = match self.container.get_descriptor(key)? {
374            Some(desc) => desc,
375            None => return Ok(None),
376        };
377
378        // 根据生命周期解析服务
379        match descriptor.lifetime {
380            Lifetime::Singleton => {
381                // 检查全局单例缓存
382                if let Some(cached) = SINGLETON_SERVICES.get(&descriptor.service_key) {
383                    return Ok(Some(Arc::clone(&cached)));
384                }
385
386                // 创建新实例
387                let inner_provider =
388                    ContainerServiceProvider::new(Arc::clone(&self.container), None);
389                let instance = descriptor.create_instance(&inner_provider)?;
390                let any_arc: Arc<dyn Any + Send + Sync> = Arc::from(instance);
391                SINGLETON_SERVICES.insert(descriptor.service_key.clone(), Arc::clone(&any_arc));
392                Ok(Some(any_arc))
393            }
394            Lifetime::Scoped => {
395                if let Some(storage) = &self.scope_storage {
396                    // 首先检查作用域缓存
397                    {
398                        let storage_guard = storage.read().map_err(|_| {
399                            DiError::generic("Failed to acquire scope storage read lock")
400                        })?;
401
402                        if let Some(cached) = storage_guard.get(&descriptor.service_key) {
403                            return Ok(Some(Arc::clone(cached)));
404                        }
405                    }
406
407                    // 创建新实例
408                    let inner_provider = ContainerServiceProvider::new(
409                        Arc::clone(&self.container),
410                        Some(storage.clone()),
411                    );
412                    let instance = descriptor.create_instance(&inner_provider)?;
413                    let any_arc: Arc<dyn Any + Send + Sync> = Arc::from(instance);
414
415                    {
416                        let mut storage_guard = storage.write().map_err(|_| {
417                            DiError::generic("Failed to acquire scope storage write lock")
418                        })?;
419                        storage_guard.insert(descriptor.service_key.clone(), Arc::clone(&any_arc));
420                    }
421
422                    Ok(Some(any_arc))
423                } else {
424                    Err(DiError::Generic {
425                        message: format!(
426                            "Scoped service cannot be resolved without a scope: {key:?}"
427                        ),
428                    })
429                }
430            }
431            Lifetime::Transient => {
432                let inner_provider = ContainerServiceProvider::new(
433                    Arc::clone(&self.container),
434                    self.scope_storage.clone(),
435                );
436                let instance = descriptor.create_instance(&inner_provider)?;
437                let any_arc: Arc<dyn Any + Send + Sync> = Arc::from(instance);
438                Ok(Some(any_arc))
439            }
440        }
441    }
442}
443
444/// 服务作用域
445pub struct ServiceScope {
446    container: Arc<Container>,
447    storage: ScopeStorage,
448    disposed: Arc<Mutex<bool>>,
449}
450
451impl ServiceScope {
452    /// 创建新的服务作用域
453    pub fn new(container: Arc<Container>) -> DiResult<Self> {
454        Ok(Self {
455            container,
456            storage: Arc::new(RwLock::new(HashMap::new())),
457            disposed: Arc::new(Mutex::new(false)),
458        })
459    }
460
461    /// 检查作用域是否可用
462    fn ensure_not_disposed(&self) -> DiResult<()> {
463        let disposed = self
464            .disposed
465            .lock()
466            .map_err(|_| DiError::generic("Failed to acquire disposed lock"))?;
467
468        if *disposed {
469            return Err(DiError::ScopeDisposed);
470        }
471
472        Ok(())
473    }
474
475    /// 获取指定类型的所有服务
476    pub fn get_services<T: 'static + Send + Sync>(&self) -> DiResult<Vec<Arc<T>>> {
477        self.ensure_not_disposed()?;
478        let provider = ServiceProvider::new(Arc::clone(&self.container));
479
480        let descriptors = provider.get_all_descriptors_for_type::<T>()?;
481        let mut services = Vec::new();
482
483        for descriptor in descriptors {
484            if let Some(service) =
485                provider.resolve_service::<T>(&descriptor.service_key, Some(&self.storage))?
486            {
487                services.push(service);
488            }
489        }
490
491        Ok(services)
492    }
493
494    /// 创建嵌套作用域
495    pub fn create_scope(&self) -> DiResult<ServiceScope> {
496        self.ensure_not_disposed()?;
497        ServiceScope::new(Arc::clone(&self.container))
498    }
499
500    /// 释放作用域资源
501    pub fn dispose(&mut self) {
502        if let Ok(mut disposed) = self.disposed.lock() {
503            if !*disposed {
504                *disposed = true;
505
506                // 清理作用域中的服务
507                if let Ok(mut storage) = self.storage.write() {
508                    storage.clear();
509                }
510            }
511        }
512    }
513
514    /// 检查作用域是否已释放
515    pub fn is_disposed(&self) -> bool {
516        self.disposed
517            .lock()
518            .map(|disposed| *disposed)
519            .unwrap_or(true)
520    }
521}
522
523impl DescriptorServiceProvider for ServiceScope {
524    fn get_service_raw(&self, key: &ServiceKey) -> DiResult<Option<Arc<dyn Any + Send + Sync>>> {
525        self.ensure_not_disposed()?;
526        let inner_provider =
527            ContainerServiceProvider::new(Arc::clone(&self.container), Some(self.storage.clone()));
528        inner_provider.get_service_raw(key)
529    }
530}
531
532impl Drop for ServiceScope {
533    fn drop(&mut self) {
534        self.dispose();
535    }
536}
537
#[cfg(test)]
mod tests {
    use super::*;
    use crate::descriptor::ServiceProviderExt;
    use crate::ServiceDescriptor;

    /// Minimal service carrying a marker value so registrations can be told apart.
    #[derive(Debug, Clone, PartialEq)]
    struct TestService {
        value: i32,
    }

    /// Service that holds another service; not exercised by the tests below.
    #[derive(Debug, Clone, PartialEq)]
    #[allow(dead_code)]
    struct DependentService {
        dependency: Arc<TestService>,
    }

    /// A fresh container has nothing registered.
    #[test]
    fn test_container_creation() {
        let empty = Container::new();
        let registered = empty.is_registered::<TestService>().unwrap();
        assert!(!registered);
    }

    /// Registering a descriptor makes `is_registered` report it.
    #[test]
    fn test_service_registration() {
        let factory = Box::new(|_| Ok(Box::new(TestService { value: 42 })));
        let transient = ServiceDescriptor::transient::<TestService, TestService>(factory);

        let container = Container::new();
        container.register(transient).unwrap();

        assert!(container.is_registered::<TestService>().unwrap());
    }

    /// A singleton registration resolves to the factory's value on repeated requests.
    #[test]
    fn test_singleton_service_resolution() {
        let container = Container::new();
        container
            .register(ServiceDescriptor::singleton::<TestService, TestService>(
                Box::new(|_| Ok(Box::new(TestService { value: 100 }))),
            ))
            .unwrap();

        let provider = container.build();
        let first = provider.get_required_service::<TestService>().unwrap();
        let second = provider.get_required_service::<TestService>().unwrap();

        assert_eq!(first.value, 100);
        assert_eq!(second.value, 100);
    }

    /// A transient registration resolves to the factory's value on repeated requests.
    #[test]
    fn test_transient_service_resolution() {
        let container = Container::new();
        container
            .register(ServiceDescriptor::transient::<TestService, TestService>(
                Box::new(|_| Ok(Box::new(TestService { value: 200 }))),
            ))
            .unwrap();

        let provider = container.build();
        let first = provider.get_required_service::<TestService>().unwrap();
        let second = provider.get_required_service::<TestService>().unwrap();

        assert_eq!(first.value, 200);
        assert_eq!(second.value, 200);
    }

    /// Named registrations are found only under their own name.
    #[test]
    fn test_keyed_service_registration_and_resolution() {
        let container = Container::new();
        let primary = ServiceDescriptor::named_singleton::<TestService, TestService>(
            "primary",
            Box::new(|_| Ok(Box::new(TestService { value: 300 }))),
        );
        container.register(primary).unwrap();

        let has_primary = container
            .is_keyed_registered::<TestService>("primary")
            .unwrap();
        let has_secondary = container
            .is_keyed_registered::<TestService>("secondary")
            .unwrap();
        assert!(has_primary);
        assert!(!has_secondary);

        let provider = container.build();
        let resolved = provider
            .get_required_keyed_service::<TestService>("primary")
            .unwrap();
        assert_eq!(resolved.value, 300);

        // An unknown name is not an error; it simply resolves to nothing.
        let missing = provider.get_keyed_service::<TestService>("nonexistent");
        assert!(missing.is_ok());
        assert!(missing.unwrap().is_none());
    }

    /// A scoped registration resolves inside a scope, which can then be disposed.
    #[test]
    fn test_scoped_service_with_scope() {
        let container = Container::new();
        container
            .register(ServiceDescriptor::scoped::<TestService, TestService>(
                Box::new(|_| Ok(Box::new(TestService { value: 400 }))),
            ))
            .unwrap();

        let provider = container.build();
        let mut scope = provider.create_scope().unwrap();

        let first = scope.get_required_service::<TestService>().unwrap();
        let second = scope.get_required_service::<TestService>().unwrap();

        assert_eq!(first.value, 400);
        assert_eq!(second.value, 400);

        scope.dispose();
    }

    /// `get_services` returns one instance per named registration of the type.
    #[test]
    fn test_service_collection() {
        let container = Container::new();

        let first = ServiceDescriptor::named_transient::<TestService, TestService>(
            "service1",
            Box::new(|_| Ok(Box::new(TestService { value: 1 }))),
        );
        container.register(first).unwrap();

        let second = ServiceDescriptor::named_transient::<TestService, TestService>(
            "service2",
            Box::new(|_| Ok(Box::new(TestService { value: 2 }))),
        );
        container.register(second).unwrap();

        let provider = container.build();
        let resolved = provider.get_services::<TestService>().unwrap();
        assert_eq!(resolved.len(), 2);

        // Order is unspecified, so only membership is checked.
        let seen: Vec<i32> = resolved.iter().map(|service| service.value).collect();
        assert!(seen.contains(&1));
        assert!(seen.contains(&2));
    }
}
681        assert!(values.contains(&2));
682    }
683}