rasi_ext/future/
event_map.rs

1//! A mediator pattern implementation for rust async rt.
2
3use std::{
4    borrow::Borrow,
5    collections::HashMap,
6    fmt::Debug,
7    hash::Hash,
8    task::{Poll, Waker},
9};
10
11use crate::utils::{Lockable, SpinMutex};
12
/// The status of an event listener registered in an [`EventMap`].
///
/// Each variant has an explicit `u8` discriminant (`#[repr(u8)]`).
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EventStatus {
    /// The listener is registered and has not been resolved yet.
    Pending = 0,
    /// The event fired; the waiting [`WaitKey`] resolves to `Ok(())`.
    Ready = 1,
    /// The wait was cancelled; the waiting [`WaitKey`] resolves to `Err(Cancel)`.
    Cancel = 2,
    /// The map was torn down (this is what [`EventMap::close`] sets on every
    /// registered listener); the waiting [`WaitKey`] resolves to `Err(Destroy)`.
    Destroy = 3,
}
22
23impl From<EventStatus> for u8 {
24    fn from(value: EventStatus) -> Self {
25        value as u8
26    }
27}
28
29impl From<u8> for EventStatus {
30    fn from(value: u8) -> Self {
31        match value {
32            0 => EventStatus::Pending,
33            1 => EventStatus::Ready,
34            2 => EventStatus::Cancel,
35            3 => EventStatus::Destroy,
36            _ => panic!("invalid status value: {}", value),
37        }
38    }
39}
40
41/// A mediator pattern implementation for rust async rt.
42///
43pub struct EventMap<E>
44where
45    E: Eq + Hash,
46{
47    listeners: SpinMutex<(bool, HashMap<E, Listener>)>,
48}
49
50impl<E> Default for EventMap<E>
51where
52    E: Eq + Hash + Unpin + Debug,
53{
54    fn default() -> Self {
55        Self::new()
56    }
57}
58
59impl<E> EventMap<E>
60where
61    E: Eq + Hash + Unpin + Debug + Unpin,
62{
63    /// Create new [`EventMap<E>`](EventMap) instance with default config.
64    pub fn new() -> Self {
65        Self {
66            listeners: Default::default(),
67        }
68    }
69
70    /// Listens for the `event` to be triggered once.
71    ///
72    /// # Parameters
73    /// - guard: An RAII guard returned by some lock primitives.
74    pub async fn once<G>(&self, event: E, guard: G) -> Result<(), EventStatus>
75    where
76        G: Unpin,
77        E: Clone,
78    {
79        WaitKey::new(self, event, guard).await
80    }
81
82    /// Notify `event` listener, and set the listener status to `status`.
83    pub fn notify<Q: Borrow<E>>(&self, event: Q, status: EventStatus) -> bool {
84        let mut inner = self.listeners.lock();
85
86        if let Some(listener) = inner.1.get_mut(event.borrow()) {
87            listener.status = status;
88
89            listener.waker.wake_by_ref();
90
91            log::trace!("notify {:?}", event.borrow());
92
93            true
94        } else {
95            log::trace!("notify {:?}, not found", event.borrow());
96
97            false
98        }
99    }
100
101    /// Notify all provided event listeners on `event_list`, and set the listener status to `status`.
102    pub fn notify_all<Q: Borrow<E>, L: AsRef<[Q]>>(&self, event_list: L, status: EventStatus) {
103        let mut inner = self.listeners.lock();
104
105        for event in event_list.as_ref() {
106            if let Some(listener) = inner.1.get_mut(event.borrow()) {
107                listener.status = status;
108
109                listener.waker.wake_by_ref();
110
111                log::trace!("notify {:?}", event.borrow());
112            } else {
113                log::trace!("notify {:?}, not found", event.borrow());
114            }
115        }
116    }
117
118    pub fn close(&self) {
119        let mut inner = self.listeners.lock();
120
121        if inner.0 {
122            return;
123        }
124
125        inner.0 = true;
126
127        for (_, listener) in inner.1.iter_mut() {
128            listener.status = EventStatus::Destroy;
129
130            listener.waker.wake_by_ref();
131        }
132    }
133}
134
135struct Listener {
136    waker: Waker,
137    status: EventStatus,
138}
139
140/// The type of future that is waiting for a specific event to be notified.
141#[must_use = "if unused, the event listener will never actually register."]
142pub struct WaitKey<'a, E, G>
143where
144    E: Eq + Hash + Unpin,
145{
146    event: E,
147    event_map: &'a EventMap<E>,
148    guard: Option<G>,
149}
150
151impl<'a, E, G> WaitKey<'a, E, G>
152where
153    E: Eq + Hash + Unpin + Debug,
154{
155    fn new(event_map: &'a EventMap<E>, event: E, guard: G) -> Self {
156        Self {
157            guard: Some(guard),
158            event,
159            event_map,
160        }
161    }
162}
163
164impl<'a, E, G> futures::Future for WaitKey<'a, E, G>
165where
166    E: Eq + Hash + Unpin + Clone + Debug,
167    G: Unpin,
168{
169    type Output = Result<(), EventStatus>;
170
171    fn poll(
172        mut self: std::pin::Pin<&mut Self>,
173        cx: &mut std::task::Context<'_>,
174    ) -> std::task::Poll<Self::Output> {
175        let mut raw = self.event_map.listeners.lock();
176
177        let status = if let Some(listener) = raw.1.remove(&self.event) {
178            listener.status
179        } else {
180            EventStatus::Pending
181        };
182
183        self.guard.take();
184
185        match status {
186            EventStatus::Pending => {
187                raw.1.insert(
188                    self.event.clone(),
189                    Listener {
190                        waker: cx.waker().clone(),
191                        status,
192                    },
193                );
194                // This future may wrapped by select!, so runtime may call this future's poll without really call `Waker::wake()`.
195                Poll::Pending
196            }
197            EventStatus::Ready => Poll::Ready(Ok(())),
198            _ => Poll::Ready(Err(status)),
199        }
200    }
201}
202
#[cfg(test)]
mod tests {
    use std::{sync::Arc, thread::sleep, time::Duration};

    use futures::{lock, task::SpawnExt};

    use super::*;

    /// Blocks the current thread until `event_map` has exactly `count` registered
    /// listeners.
    ///
    /// A blocking sleep is acceptable here because the waiters run on a separate
    /// `ThreadPool`, not on the test's own executor.
    fn wait_for_listeners(event_map: &EventMap<i32>, count: usize) {
        while event_map.listeners.lock().1.len() != count {
            sleep(Duration::from_millis(10));
        }
    }

    #[futures_test::test]
    async fn test_with_future_aware_mutex() {
        let event_map = Arc::new(EventMap::<i32>::new());

        let locker = Arc::new(lock::Mutex::new(()));

        let guard = locker.lock().await;

        let thread_pool = futures::executor::ThreadPool::new().unwrap();

        let event_map_cloned = event_map.clone();

        let locker_cloned = locker.clone();

        thread_pool
            .spawn(async move {
                // Only succeeds once `once` has dropped `guard`, which happens after
                // the listener is registered.
                let _guard = locker_cloned.lock().await;
                event_map_cloned.notify(1, EventStatus::Ready);
            })
            .unwrap();

        event_map.once(1, guard).await.unwrap();

        let _guard = locker.lock().await;
    }

    #[futures_test::test]
    async fn test_with_std_mutex() {
        let event_map = Arc::new(EventMap::<i32>::new());

        let locker = Arc::new(std::sync::Mutex::new(()));

        let guard = locker.lock().unwrap();

        let thread_pool = futures::executor::ThreadPool::new().unwrap();

        let event_map_cloned = event_map.clone();

        let locker_cloned = locker.clone();

        thread_pool
            .spawn(async move {
                let _guard = locker_cloned.lock().unwrap();
                event_map_cloned.notify(1, EventStatus::Ready);
            })
            .unwrap();

        event_map.once(1, guard).await.unwrap();

        let _guard = locker.lock().unwrap();
    }

    #[futures_test::test]
    async fn test_notify_all() {
        let event_map = Arc::new(EventMap::<i32>::new());

        let thread_pool = futures::executor::ThreadPool::new().unwrap();

        let loops = 100;

        let handles = (0..loops)
            .map(|i| {
                let event_map = event_map.clone();

                thread_pool
                    .spawn_with_handle(async move {
                        let locker = lock::Mutex::new(());

                        let guard = locker.lock().await;

                        event_map.once(i, guard).await.unwrap();
                    })
                    .unwrap()
            })
            .collect::<Vec<_>>();

        // Wait until every spawned task has registered its listener.
        wait_for_listeners(&event_map, loops as usize);

        event_map.notify_all((0..loops).collect::<Vec<_>>(), EventStatus::Ready);

        for handle in handles {
            handle.await;
        }
    }

    #[futures_test::test]
    async fn test_close_destroys_pending_listeners() {
        let event_map = Arc::new(EventMap::<i32>::new());

        let thread_pool = futures::executor::ThreadPool::new().unwrap();

        let event_map_cloned = event_map.clone();

        let handle = thread_pool
            .spawn_with_handle(async move { event_map_cloned.once(1, ()).await })
            .unwrap();

        wait_for_listeners(&event_map, 1);

        event_map.close();

        assert!(matches!(handle.await, Err(EventStatus::Destroy)));
    }
}