Skip to main content

atomr_core/event/
event_stream.rs

1//! Typed pub/sub.
2
3use std::any::{Any, TypeId};
4use std::sync::Arc;
5
6use dashmap::DashMap;
7use parking_lot::Mutex;
8
/// Type-erased subscriber callback stored in the subscriber table.
/// The closures built by `EventStream::subscribe*` downcast the incoming event
/// to the concrete type they registered for and do nothing on a mismatch.
type SubFn = Arc<dyn Fn(&(dyn Any + Send + Sync)) + Send + Sync>;
/// Subscriber table: event `TypeId` -> ordered list of `(subscription id, callback)`.
/// The `Arc` lets every `Subscription` handle share the table with its `EventStream`.
type SubMap = Arc<DashMap<TypeId, Mutex<Vec<(u64, SubFn)>>>>;
11
12#[derive(Clone)]
13pub struct Subscription {
14    pub id: u64,
15    type_id: TypeId,
16    map: SubMap,
17}
18
19impl Subscription {
20    pub fn unsubscribe(&self) {
21        if let Some(e) = self.map.get(&self.type_id) {
22            e.lock().retain(|(id, _)| *id != self.id);
23        }
24    }
25}
26
27#[derive(Default)]
28pub struct EventStream {
29    map: SubMap,
30    next_id: std::sync::atomic::AtomicU64,
31}
32
33impl EventStream {
34    pub fn new() -> Self {
35        Self::default()
36    }
37
38    pub fn subscribe<T: Any + Send + Sync>(&self, f: impl Fn(&T) + Send + Sync + 'static) -> Subscription {
39        let id = self.next_id.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
40        let type_id = TypeId::of::<T>();
41        let cb: SubFn = Arc::new(move |any: &(dyn Any + Send + Sync)| {
42            if let Some(t) = any.downcast_ref::<T>() {
43                f(t);
44            }
45        });
46        self.map.entry(type_id).or_default().lock().push((id, cb));
47        Subscription { id, type_id, map: self.map.clone() }
48    }
49
50    /// Subscribe with a predicate filter — only events matching
51    /// `pred(t)` are delivered to `f`. Phase 3.5 of
52    /// `docs/full-port-plan.md`.
53    /// `EventStream.Subscribe(IActorRef, predicate)` analog.
54    pub fn subscribe_filtered<T, P>(&self, pred: P, f: impl Fn(&T) + Send + Sync + 'static) -> Subscription
55    where
56        T: Any + Send + Sync,
57        P: Fn(&T) -> bool + Send + Sync + 'static,
58    {
59        let id = self.next_id.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
60        let type_id = TypeId::of::<T>();
61        let cb: SubFn = Arc::new(move |any: &(dyn Any + Send + Sync)| {
62            if let Some(t) = any.downcast_ref::<T>() {
63                if pred(t) {
64                    f(t);
65                }
66            }
67        });
68        self.map.entry(type_id).or_default().lock().push((id, cb));
69        Subscription { id, type_id, map: self.map.clone() }
70    }
71
72    /// Number of subscribers registered for events of type `T`.
73    pub fn subscriber_count<T: Any>(&self) -> usize {
74        self.map.get(&TypeId::of::<T>()).map(|e| e.lock().len()).unwrap_or(0)
75    }
76
77    pub fn publish<T: Any + Send + Sync>(&self, value: T) {
78        let type_id = TypeId::of::<T>();
79        let value_arc: Arc<dyn Any + Send + Sync> = Arc::new(value);
80        if let Some(entry) = self.map.get(&type_id) {
81            let callbacks: Vec<SubFn> = entry.lock().iter().map(|(_, f)| f.clone()).collect();
82            for f in callbacks {
83                f(&*value_arc);
84            }
85        }
86    }
87}
88
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};

    #[test]
    fn publishes_to_typed_subscribers() {
        let stream = EventStream::new();
        let total = Arc::new(AtomicU32::new(0));
        let sub = {
            let total = Arc::clone(&total);
            stream.subscribe(move |v: &u32| {
                total.fetch_add(*v, Ordering::SeqCst);
            })
        };

        // Both u32 events reach the subscriber; the String event does not.
        for v in [10u32, 20].iter().copied() {
            stream.publish(v);
        }
        stream.publish("ignored".to_string());
        assert_eq!(total.load(Ordering::SeqCst), 30);

        // After unsubscribing, further u32 events are no longer delivered.
        sub.unsubscribe();
        stream.publish(100u32);
        assert_eq!(total.load(Ordering::SeqCst), 30);
    }

    #[test]
    fn subscribe_filtered_delivers_only_matches() {
        let stream = EventStream::new();
        let hits = Arc::new(AtomicU32::new(0));
        let hits_in_cb = Arc::clone(&hits);
        let _sub = stream.subscribe_filtered(
            |v: &u32| *v > 5,
            move |_| {
                hits_in_cb.fetch_add(1, Ordering::SeqCst);
            },
        );

        for v in [1u32, 7, 2, 99].iter().copied() {
            stream.publish(v);
        }
        // Only 7 and 99 satisfy the `> 5` predicate.
        assert_eq!(hits.load(Ordering::SeqCst), 2);
    }

    #[test]
    fn subscriber_count_reflects_registered_subscribers() {
        let stream = EventStream::new();
        assert_eq!(stream.subscriber_count::<u32>(), 0);

        let _plain = stream.subscribe(|_v: &u32| {});
        let _filtered = stream.subscribe_filtered(|_v: &u32| true, |_| {});

        assert_eq!(stream.subscriber_count::<u32>(), 2);
        // Registrations are keyed by event type, so other types are unaffected.
        assert_eq!(stream.subscriber_count::<String>(), 0);
    }
}