ferro_events/
dispatcher.rs1use crate::{Error, Event, Listener};
4use parking_lot::RwLock;
5use std::any::TypeId;
6use std::collections::HashMap;
7use std::future::Future;
8use std::pin::Pin;
9use std::sync::Arc;
10use tracing::{debug, error, info};
11
12type ListenerFn<E> =
14 Arc<dyn Fn(&E) -> Pin<Box<dyn Future<Output = Result<(), Error>> + Send>> + Send + Sync>;
15
/// One registered listener, stored without its event type.
struct ListenerEntry {
    /// Ordering key within an event type's list; `listen_with_priority`
    /// keeps higher values earlier in the list.
    priority: i32,
    /// Holds a `ListenerFn<E>` for the event type this entry is filed under;
    /// recovered with `downcast_ref` at dispatch time.
    handler: Box<dyn std::any::Any + Send + Sync>,
}
23
24pub struct EventDispatcher {
30 listeners: RwLock<HashMap<TypeId, Vec<ListenerEntry>>>,
32}
33
34impl Default for EventDispatcher {
35 fn default() -> Self {
36 Self::new()
37 }
38}
39
40impl EventDispatcher {
41 pub fn new() -> Self {
43 Self {
44 listeners: RwLock::new(HashMap::new()),
45 }
46 }
47
48 pub fn listen<E, L>(&self, listener: L)
74 where
75 E: Event,
76 L: Listener<E>,
77 {
78 self.listen_with_priority(listener, 0);
79 }
80
81 pub fn listen_with_priority<E, L>(&self, listener: L, priority: i32)
85 where
86 E: Event,
87 L: Listener<E>,
88 {
89 let listener = Arc::new(listener);
90 let handler: ListenerFn<E> = Arc::new(move |event: &E| {
91 let listener = Arc::clone(&listener);
92 let event = event.clone();
93 Box::pin(async move { listener.handle(&event).await })
94 });
95
96 let entry = ListenerEntry {
97 handler: Box::new(handler),
98 priority,
99 };
100
101 let type_id = TypeId::of::<E>();
102 let mut listeners = self.listeners.write();
103 let list = listeners.entry(type_id).or_default();
104 list.push(entry);
105 list.sort_by(|a, b| b.priority.cmp(&a.priority));
107 }
108
109 pub fn on<E, F, Fut>(&self, handler: F)
129 where
130 E: Event,
131 F: Fn(E) -> Fut + Send + Sync + 'static,
132 Fut: Future<Output = Result<(), Error>> + Send + 'static,
133 {
134 let handler = Arc::new(handler);
135 let listener_fn: ListenerFn<E> = Arc::new(move |event: &E| {
136 let handler = Arc::clone(&handler);
137 let event = event.clone();
138 Box::pin(async move { handler(event).await })
139 });
140
141 let entry = ListenerEntry {
142 handler: Box::new(listener_fn),
143 priority: 0,
144 };
145
146 let type_id = TypeId::of::<E>();
147 let mut listeners = self.listeners.write();
148 listeners.entry(type_id).or_default().push(entry);
149 }
150
151 pub async fn dispatch<E: Event>(&self, event: E) -> Result<(), Error> {
155 let type_id = TypeId::of::<E>();
156 let event_name = event.name();
157
158 debug!(event = event_name, "Dispatching event");
159
160 let handlers: Vec<ListenerFn<E>> = {
161 let listeners = self.listeners.read();
162 match listeners.get(&type_id) {
163 Some(entries) => entries
164 .iter()
165 .filter_map(|entry| entry.handler.downcast_ref::<ListenerFn<E>>().cloned())
166 .collect(),
167 None => {
168 debug!(event = event_name, "No listeners registered");
169 return Ok(());
170 }
171 }
172 };
173
174 info!(
175 event = event_name,
176 listener_count = handlers.len(),
177 "Calling listeners"
178 );
179
180 for handler in handlers {
181 if let Err(e) = handler(&event).await {
182 error!(event = event_name, error = %e, "Listener failed");
183 return Err(e);
184 }
185 }
186
187 debug!(event = event_name, "Event dispatched successfully");
188 Ok(())
189 }
190
191 pub fn dispatch_async<E: Event + 'static>(&self, event: E) {
195 let type_id = TypeId::of::<E>();
196 let event_name = event.name();
197
198 let handlers: Vec<ListenerFn<E>> = {
199 let listeners = self.listeners.read();
200 match listeners.get(&type_id) {
201 Some(entries) => entries
202 .iter()
203 .filter_map(|entry| entry.handler.downcast_ref::<ListenerFn<E>>().cloned())
204 .collect(),
205 None => return,
206 }
207 };
208
209 tokio::spawn(async move {
210 for handler in handlers {
211 if let Err(e) = handler(&event).await {
212 error!(event = event_name, error = %e, "Async listener failed");
213 }
214 }
215 });
216 }
217
218 pub fn has_listeners<E: Event>(&self) -> bool {
220 let type_id = TypeId::of::<E>();
221 let listeners = self.listeners.read();
222 listeners.get(&type_id).is_some_and(|v| !v.is_empty())
223 }
224
225 pub fn forget<E: Event>(&self) {
227 let type_id = TypeId::of::<E>();
228 let mut listeners = self.listeners.write();
229 listeners.remove(&type_id);
230 }
231
232 pub fn flush(&self) {
234 let mut listeners = self.listeners.write();
235 listeners.clear();
236 }
237}
238
239static GLOBAL_DISPATCHER: std::sync::OnceLock<EventDispatcher> = std::sync::OnceLock::new();
241
242pub fn global_dispatcher() -> &'static EventDispatcher {
244 GLOBAL_DISPATCHER.get_or_init(EventDispatcher::new)
245}
246
247pub async fn dispatch<E: Event>(event: E) -> Result<(), Error> {
266 global_dispatcher().dispatch(event).await
267}
268
269pub fn dispatch_sync<E: Event + 'static>(event: E) {
273 global_dispatcher().dispatch_async(event);
274}
275
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};

    #[derive(Clone)]
    struct TestEvent {
        value: u32,
    }

    impl Event for TestEvent {
        fn name(&self) -> &'static str {
            "TestEvent"
        }
    }

    /// A second event type, used to check per-type isolation of the registry.
    #[derive(Clone)]
    struct OtherEvent;

    impl Event for OtherEvent {
        fn name(&self) -> &'static str {
            "OtherEvent"
        }
    }

    #[tokio::test]
    async fn test_dispatch_to_closure() {
        let dispatcher = EventDispatcher::new();
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = Arc::clone(&counter);

        dispatcher.on::<TestEvent, _, _>(move |event| {
            let counter = Arc::clone(&counter_clone);
            async move {
                counter.fetch_add(event.value, Ordering::SeqCst);
                Ok(())
            }
        });

        dispatcher.dispatch(TestEvent { value: 5 }).await.unwrap();
        assert_eq!(counter.load(Ordering::SeqCst), 5);
    }

    #[tokio::test]
    async fn test_multiple_listeners() {
        let dispatcher = EventDispatcher::new();
        let counter = Arc::new(AtomicU32::new(0));

        for _ in 0..3 {
            let counter_clone = Arc::clone(&counter);
            dispatcher.on::<TestEvent, _, _>(move |_| {
                let counter = Arc::clone(&counter_clone);
                async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                    Ok(())
                }
            });
        }

        dispatcher.dispatch(TestEvent { value: 1 }).await.unwrap();
        assert_eq!(counter.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_priority_order() {
        let dispatcher = EventDispatcher::new();
        let order = Arc::new(RwLock::new(Vec::new()));

        for priority in [1, 3, 2] {
            // Entries are inserted directly so the test does not need a `Listener` impl.
            let order_clone = Arc::clone(&order);
            let handler: ListenerFn<TestEvent> = Arc::new(move |_| {
                let order = Arc::clone(&order_clone);
                let p = priority;
                Box::pin(async move {
                    order.write().push(p);
                    Ok(())
                })
            });

            let entry = ListenerEntry {
                handler: Box::new(handler),
                priority,
            };

            let type_id = TypeId::of::<TestEvent>();
            let mut listeners = dispatcher.listeners.write();
            let list = listeners.entry(type_id).or_default();
            list.push(entry);
            list.sort_by(|a, b| b.priority.cmp(&a.priority));
        }

        dispatcher.dispatch(TestEvent { value: 0 }).await.unwrap();

        let result = order.read().clone();
        assert_eq!(result, vec![3, 2, 1]);
    }

    #[tokio::test]
    async fn test_equal_priority_preserves_registration_order() {
        let dispatcher = EventDispatcher::new();
        let order = Arc::new(RwLock::new(Vec::new()));

        for id in 0..3u32 {
            let order = Arc::clone(&order);
            dispatcher.on::<TestEvent, _, _>(move |_| {
                let order = Arc::clone(&order);
                async move {
                    order.write().push(id);
                    Ok(())
                }
            });
        }

        dispatcher.dispatch(TestEvent { value: 0 }).await.unwrap();

        let result = order.read().clone();
        assert_eq!(result, vec![0, 1, 2]);
    }

    #[tokio::test]
    async fn test_has_listeners() {
        let dispatcher = EventDispatcher::new();
        assert!(!dispatcher.has_listeners::<TestEvent>());

        dispatcher.on::<TestEvent, _, _>(|_| async { Ok(()) });
        assert!(dispatcher.has_listeners::<TestEvent>());

        dispatcher.forget::<TestEvent>();
        assert!(!dispatcher.has_listeners::<TestEvent>());
    }

    #[test]
    fn test_forget_is_scoped_to_one_event_type() {
        let dispatcher = EventDispatcher::new();
        dispatcher.on::<TestEvent, _, _>(|_| async { Ok(()) });
        dispatcher.on::<OtherEvent, _, _>(|_| async { Ok(()) });

        dispatcher.forget::<TestEvent>();

        assert!(!dispatcher.has_listeners::<TestEvent>());
        assert!(dispatcher.has_listeners::<OtherEvent>());
    }

    #[test]
    fn test_flush_removes_all_listeners() {
        let dispatcher = EventDispatcher::new();
        dispatcher.on::<TestEvent, _, _>(|_| async { Ok(()) });
        dispatcher.on::<OtherEvent, _, _>(|_| async { Ok(()) });

        dispatcher.flush();

        assert!(!dispatcher.has_listeners::<TestEvent>());
        assert!(!dispatcher.has_listeners::<OtherEvent>());
    }

    #[tokio::test]
    async fn test_no_listeners() {
        let dispatcher = EventDispatcher::new();
        let result = dispatcher.dispatch(TestEvent { value: 1 }).await;
        assert!(result.is_ok());
    }
}