Skip to main content

atomr_core/actor/
extensions.rs

//! Extensions — per-`ActorSystem` singletons keyed by type.
//! Ported from akka.net's `Actor/Extensions.cs`.
3
4use std::any::{Any, TypeId};
5use std::sync::Arc;
6
7use dashmap::DashMap;
8
/// Marker trait for values that can be stored in [`Extensions`].
///
/// Every `'static` type that is `Send + Sync` qualifies automatically through the
/// blanket implementation below, so no manual `impl` is ever required.
pub trait Extension: Any + Send + Sync {}

impl<T> Extension for T where T: Any + Send + Sync {}
13
14/// Identifier trait mirroring akka.net's `IExtensionId<T>`.
15pub trait ExtensionId<E: Extension>: Send + Sync {
16    fn create(&self) -> E;
17}
18
19#[derive(Debug, Default)]
20pub struct Extensions {
21    inner: DashMap<TypeId, Arc<dyn Any + Send + Sync>>,
22}
23
24impl Extensions {
25    pub fn new() -> Self {
26        Self::default()
27    }
28
29    pub fn register<E: Extension>(&self, ext: E) {
30        self.inner.insert(TypeId::of::<E>(), Arc::new(ext));
31    }
32
33    pub fn get<E: Extension>(&self) -> Option<Arc<E>> {
34        self.inner.get(&TypeId::of::<E>()).and_then(|e| e.clone().downcast::<E>().ok())
35    }
36
37    pub fn get_or_create<E: Extension, I: ExtensionId<E>>(&self, id: &I) -> Arc<E> {
38        if let Some(e) = self.get::<E>() {
39            return e;
40        }
41        let ext = id.create();
42        self.register(ext);
43        self.get::<E>().expect("just inserted")
44    }
45}
46
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    struct Metrics(u32);

    struct MetricsId;

    impl ExtensionId<Metrics> for MetricsId {
        fn create(&self) -> Metrics {
            Metrics(99)
        }
    }

    /// Factory that counts how many times it was asked to build an instance.
    struct CountingId(AtomicUsize);

    impl ExtensionId<Metrics> for CountingId {
        fn create(&self) -> Metrics {
            self.0.fetch_add(1, Ordering::SeqCst);
            Metrics(7)
        }
    }

    #[test]
    fn create_and_get() {
        let e = Extensions::new();
        let m = e.get_or_create::<Metrics, _>(&MetricsId);
        assert_eq!(m.0, 99);
        assert!(e.get::<Metrics>().is_some());
    }

    #[test]
    fn get_missing_is_none() {
        let e = Extensions::new();
        assert!(e.get::<Metrics>().is_none());
    }

    #[test]
    fn get_or_create_returns_same_instance_and_creates_once() {
        let e = Extensions::new();
        let id = CountingId(AtomicUsize::new(0));
        let first = e.get_or_create::<Metrics, _>(&id);
        let second = e.get_or_create::<Metrics, _>(&id);
        assert!(Arc::ptr_eq(&first, &second));
        assert_eq!(id.0.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn registered_instance_takes_precedence() {
        let e = Extensions::new();
        e.register(Metrics(1));
        let m = e.get_or_create::<Metrics, _>(&MetricsId);
        assert_eq!(m.0, 1);
    }
}