1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
use std::sync::Arc;
use dashmap::DashMap;
use std::any::Any;
use std::any::type_name;
use std::ops::Deref;
use once_cell::sync::OnceCell;
use std::sync::Mutex;
use std::error::Error;
pub use autowired_derive::Component;

/// Global lock that serializes component registration.
///
/// The guarded `u64` is incremented once per successful registration.
fn component_mutex() -> &'static Mutex<u64> {
    static REGISTRATION_LOCK: OnceCell<Mutex<u64>> = OnceCell::new();
    REGISTRATION_LOCK.get_or_init(|| Mutex::new(0))
}

/// Global registry of component instances, keyed by `type_name` of the component type.
fn component_dashmap() -> &'static DashMap<String, Arc<dyn Any + 'static + Send + Sync>> {
    type Registry = DashMap<String, Arc<dyn Any + 'static + Send + Sync>>;

    static REGISTRY: OnceCell<Registry> = OnceCell::new();
    REGISTRY.get_or_init(Registry::new)
}

/// Look up the registered instance of component `T`.
///
/// Returns `None` when nothing is registered under `type_name::<T>()`, or when the
/// stored value is not actually a `T` (the downcast fails).
fn get_component<T: Component>() -> Option<Arc<T>> {
    component_dashmap()
        .get(type_name::<T>())
        .and_then(|entry| Arc::clone(entry.value()).downcast::<T>().ok())
}

/// Returns `true` if a component of type `T` has already been registered.
pub fn exist_component<T: Component>() -> bool {
    let key = type_name::<T>();
    component_dashmap().contains_key(key)
}

pub trait Component: Any + 'static + Send + Sync {
    /// Create a new component instance.
    ///
    /// This is called by [`Component::register`] while the global registration lock
    /// is held. That lock is a non-reentrant `std::sync::Mutex`, so an implementation
    /// must not register other components itself. This includes dereferencing an
    /// `Autowired` whose component is not yet registered. On the same thread this may
    /// deadlock or panic.
    ///
    /// # Errors
    /// Returns an error if the component cannot be constructed; `register` logs it and
    /// leaves the component unregistered.
    fn new_instance() -> Result<Arc<Self>, Box<dyn Error>>;

    /// Call `new_instance` to create the component, then add it to the global map
    /// under `type_name::<Self>()`.
    ///
    /// Does nothing if a component with that type name is already registered. On
    /// success, `after_register` is invoked exactly once. It runs after the
    /// registration lock has been released.
    fn register()
    where
        Self: Sized,
    {
        let name = type_name::<Self>();

        let component: Arc<Self> = {
            // Hold the lock across check-and-insert so concurrent callers cannot
            // initialize the same component twice. A poisoned lock means an earlier
            // registration panicked part-way. The guarded value is only a counter, so
            // recover the guard instead of silently skipping registration forever.
            let mut count = component_mutex()
                .lock()
                .unwrap_or_else(|poisoned| poisoned.into_inner());

            if component_dashmap().contains_key(name) {
                return;
            }

            let component = match Self::new_instance() {
                Ok(v) => v,
                Err(e) => {
                    log::error!("[Component] register failure, {}", e);
                    return;
                }
            };
            component_dashmap().insert(name.to_string(), component.clone());
            *count += 1;
            component
            // The lock guard is dropped here.
        };

        log::debug!("[Component] register, name={}", name);
        // Run the hook without holding the registration lock, so it may itself
        // register or autowire other components without self-deadlocking.
        component.after_register();
    }

    /// Hook run once after the component has been registered. Default: no-op.
    fn after_register(&self) {}
}

/// Lazily-resolved handle to the globally registered component `T`.
///
/// Constructing an `Autowired` does no work. On the first dereference the component
/// is looked up, and registered via `T::register()` if it is missing. The result is
/// then cached here for later accesses.
pub struct Autowired<T> {
    // Filled on first `deref`; holds a shared handle to the registry's instance.
    inner: OnceCell<Arc<T>>,
}

impl<T> Autowired<T> {
    /// Create an empty handle; the component is resolved on first dereference.
    ///
    /// Being `const`, this can initialize a `static`.
    pub const fn new() -> Self {
        Self {
            inner: OnceCell::new(),
        }
    }
}

impl<T: Component> Deref for Autowired<T> {
    type Target = Arc<T>;

    fn deref(&self) -> &Self::Target {
        self.inner.get_or_init(|| {
            if !exist_component::<T>() {
                T::register()
            }
            get_component::<T>().unwrap_or_else(||
                panic!(format!("[Autowired] not found component {}", type_name::<T>()))
            )
        })
    }
}

impl<T: Component> Default for Autowired<T> {
    fn default() -> Self {
        Autowired::new()
    }
}

#[cfg(test)]
mod tests {
    use crate::{Autowired, Component};
    use std::error::Error;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Arc;

    const TEST_STRING: &str = "1234567890";

    /// Counts how many times `Foo::after_register` has run.
    static FOO_AFTER_REGISTER_CALLS: AtomicU32 = AtomicU32::new(0);

    /// Component with a hand-written `Component` implementation.
    #[derive(Default)]
    struct Foo {
        value: String,
    }

    impl Component for Foo {
        fn new_instance() -> Result<Arc<Self>, Box<dyn Error>> {
            let foo = Foo {
                value: String::from(TEST_STRING),
            };
            Ok(Arc::new(foo))
        }

        fn after_register(&self) {
            FOO_AFTER_REGISTER_CALLS.fetch_add(1, Ordering::SeqCst);
        }
    }

    /// Component whose `Component` implementation comes from the derive macro; the
    /// test below expects it to produce default field values.
    #[derive(Default, Component)]
    struct Bar {
        name: String,
        age: u32,
    }

    #[test]
    fn register_foo() {
        let calls = || FOO_AFTER_REGISTER_CALLS.load(Ordering::SeqCst);
        assert_eq!(0, calls());

        let foo = Autowired::<Foo>::new();

        // The first dereference lazily registers `Foo`, which fires its hook once.
        assert_eq!(TEST_STRING, foo.value);
        assert_eq!(1, calls());
    }

    #[test]
    fn register_bar() {
        let bar: Autowired<Bar> = Autowired::new();

        assert_eq!(String::default(), bar.name);
        assert_eq!(u32::default(), bar.age);
    }
}