// atomr_core/event/event_stream.rs

use std::any::{Any, TypeId};
use std::sync::Arc;

use dashmap::DashMap;
use parking_lot::Mutex;
9type SubFn = Arc<dyn Fn(&(dyn Any + Send + Sync)) + Send + Sync>;
10type SubMap = Arc<DashMap<TypeId, Mutex<Vec<(u64, SubFn)>>>>;
11
12#[derive(Clone)]
13pub struct Subscription {
14 pub id: u64,
15 type_id: TypeId,
16 map: SubMap,
17}
18
19impl Subscription {
20 pub fn unsubscribe(&self) {
21 if let Some(e) = self.map.get(&self.type_id) {
22 e.lock().retain(|(id, _)| *id != self.id);
23 }
24 }
25}
26
27#[derive(Default)]
28pub struct EventStream {
29 map: SubMap,
30 next_id: std::sync::atomic::AtomicU64,
31}
32
33impl EventStream {
34 pub fn new() -> Self {
35 Self::default()
36 }
37
38 pub fn subscribe<T: Any + Send + Sync>(&self, f: impl Fn(&T) + Send + Sync + 'static) -> Subscription {
39 let id = self.next_id.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
40 let type_id = TypeId::of::<T>();
41 let cb: SubFn = Arc::new(move |any: &(dyn Any + Send + Sync)| {
42 if let Some(t) = any.downcast_ref::<T>() {
43 f(t);
44 }
45 });
46 self.map.entry(type_id).or_default().lock().push((id, cb));
47 Subscription { id, type_id, map: self.map.clone() }
48 }
49
50 pub fn subscribe_filtered<T, P>(&self, pred: P, f: impl Fn(&T) + Send + Sync + 'static) -> Subscription
55 where
56 T: Any + Send + Sync,
57 P: Fn(&T) -> bool + Send + Sync + 'static,
58 {
59 let id = self.next_id.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
60 let type_id = TypeId::of::<T>();
61 let cb: SubFn = Arc::new(move |any: &(dyn Any + Send + Sync)| {
62 if let Some(t) = any.downcast_ref::<T>() {
63 if pred(t) {
64 f(t);
65 }
66 }
67 });
68 self.map.entry(type_id).or_default().lock().push((id, cb));
69 Subscription { id, type_id, map: self.map.clone() }
70 }
71
72 pub fn subscriber_count<T: Any>(&self) -> usize {
74 self.map.get(&TypeId::of::<T>()).map(|e| e.lock().len()).unwrap_or(0)
75 }
76
77 pub fn publish<T: Any + Send + Sync>(&self, value: T) {
78 let type_id = TypeId::of::<T>();
79 let value_arc: Arc<dyn Any + Send + Sync> = Arc::new(value);
80 if let Some(entry) = self.map.get(&type_id) {
81 let callbacks: Vec<SubFn> = entry.lock().iter().map(|(_, f)| f.clone()).collect();
82 for f in callbacks {
83 f(&*value_arc);
84 }
85 }
86 }
87}
88
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};

    #[test]
    fn publishes_to_typed_subscribers() {
        let bus = EventStream::new();
        let total = Arc::new(AtomicU32::new(0));
        let sink = Arc::clone(&total);
        let sub = bus.subscribe(move |v: &u32| {
            sink.fetch_add(*v, Ordering::SeqCst);
        });

        bus.publish(10u32);
        bus.publish(20u32);
        // A different event type must not reach the u32 subscriber.
        bus.publish("ignored".to_string());
        assert_eq!(total.load(Ordering::SeqCst), 30);

        // After unsubscribing, further publishes are dropped.
        sub.unsubscribe();
        bus.publish(100u32);
        assert_eq!(total.load(Ordering::SeqCst), 30);
    }

    #[test]
    fn subscribe_filtered_delivers_only_matches() {
        let bus = EventStream::new();
        let hits = Arc::new(AtomicU32::new(0));
        let counter = Arc::clone(&hits);
        let _sub = bus.subscribe_filtered(
            |v: &u32| *v > 5,
            move |_| {
                counter.fetch_add(1, Ordering::SeqCst);
            },
        );

        for v in [1u32, 7, 2, 99] {
            bus.publish(v);
        }
        // Only 7 and 99 pass the `> 5` predicate.
        assert_eq!(hits.load(Ordering::SeqCst), 2);
    }

    #[test]
    fn subscriber_count_reflects_registered_subscribers() {
        let bus = EventStream::new();
        assert_eq!(bus.subscriber_count::<u32>(), 0);

        let _s1 = bus.subscribe(|_v: &u32| {});
        let _s2 = bus.subscribe_filtered(|_v: &u32| true, |_| {});
        assert_eq!(bus.subscriber_count::<u32>(), 2);
        assert_eq!(bus.subscriber_count::<String>(), 0);
    }
}