portaldi_core/
container.rs

//! DI container: a registry of shared component instances keyed by their type.
2
3use crate::{traits::DITarget, types::DI};
4use std::{any::Any, collections::HashMap, future::Future};
5
6#[cfg(all(target_arch = "wasm32", not(feature = "multi-thread")))]
7use std::cell::RefCell;
8#[cfg(any(not(target_arch = "wasm32"), feature = "multi-thread"))]
9use std::sync::RwLock;
10
11/// DI container holds component refs.
12#[derive(Debug)]
13pub struct DIContainer {
14    /// Hold components by its type name (FQTN).
15    #[cfg(all(target_arch = "wasm32", not(feature = "multi-thread")))]
16    components: RefCell<HashMap<String, DI<dyn Any>>>,
17    #[cfg(any(not(target_arch = "wasm32"), feature = "multi-thread"))]
18    components: RwLock<HashMap<String, DI<dyn Any + Send + Sync>>>,
19}
20
21impl DIContainer {
22    /// Create new instance.
23    pub fn new() -> DIContainer {
24        DIContainer {
25            #[cfg(all(target_arch = "wasm32", not(feature = "multi-thread")))]
26            components: RefCell::new(HashMap::new()),
27            #[cfg(any(not(target_arch = "wasm32"), feature = "multi-thread"))]
28            components: RwLock::new(HashMap::new()),
29        }
30    }
31
32    /// Get a component by type.
33    pub fn get<T: DITarget>(&self) -> Option<DI<T>> {
34        #[cfg(all(target_arch = "wasm32", not(feature = "multi-thread")))]
35        let comps = self.components.borrow();
36        #[cfg(any(not(target_arch = "wasm32"), feature = "multi-thread"))]
37        let comps = self.components.read().unwrap();
38        comps
39            .get(std::any::type_name::<T>())
40            .map(|c| c.clone().downcast::<T>().unwrap())
41    }
42
43    /// Put a component into the container.
44    pub fn put_if_absent<T: DITarget>(&self, c: &DI<T>) -> DI<T> {
45        #[cfg(all(target_arch = "wasm32", not(feature = "multi-thread")))]
46        let mut components = self.components.borrow_mut();
47        #[cfg(any(not(target_arch = "wasm32"), feature = "multi-thread"))]
48        let mut components = self.components.write().unwrap();
49        let key = std::any::type_name::<T>();
50        let value = components
51            .get(key)
52            .map(|c| c.clone().downcast::<T>().unwrap());
53        if let Some(c) = value {
54            c
55        } else {
56            components.insert(key.into(), c.clone());
57            c.clone()
58        }
59    }
60
61    /// Get a component by type with a initialization.
62    /// If a target component does not exists, create and put into the container.
63    pub fn get_or_init<T, F>(&self, init: F) -> DI<T>
64    where
65        T: DITarget,
66        F: Fn() -> T,
67    {
68        if let Some(c) = self.get::<T>() {
69            c
70        } else {
71            let c = DI::new(init());
72            self.put_if_absent(&c)
73        }
74    }
75
76    /// Get a component by type with a async initialization.
77    /// If a target component does not exists, create and put into the container.
78    pub async fn get_or_init_async<T, F, Fut>(&self, init: F) -> DI<T>
79    where
80        T: DITarget,
81        F: Fn() -> Fut,
82        Fut: Future<Output = T>,
83    {
84        if let Some(c) = self.get::<T>() {
85            c
86        } else {
87            let v = init().await;
88            let c = DI::new(v);
89            self.put_if_absent(&c)
90        }
91    }
92}