hala_future/
event_map.rs

1use std::{
2    borrow::Borrow,
3    fmt::Debug,
4    hash::Hash,
5    sync::{
6        atomic::{AtomicU8, Ordering},
7        Arc,
8    },
9    task::{Poll, Waker},
10};
11
12use dashmap::DashMap;
13use hala_sync::{AsyncGuardMut, AsyncLockable};
14
15#[derive(Debug, thiserror::Error, PartialEq)]
16pub enum EventMapError {
17    #[error("Waiting operation canceled by user")]
18    Cancel,
19    #[error("Waiting operation canceled by EventMap to drop `EventMap` self")]
20    Destroy,
21}
22
/// The reason a waiter was woken up.
///
/// The reason travels from the notifier to the waiting future as a `u8` kept in
/// a shared `AtomicU8`, so every discriminant is spelled out explicitly: the
/// numeric values are part of that encoding and must stay stable.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Reason {
    /// No wakeup has been delivered yet; the waiter keeps waiting.
    None = 0,
    /// The awaited event fired.
    On = 1,
    /// The wait was canceled by the user.
    Cancel = 2,
    /// The `EventMap` is being dropped.
    Destroy = 3,
}

impl From<Reason> for u8 {
    /// Encodes `value` as its explicit discriminant
    /// (`None` = 0, `On` = 1, `Cancel` = 2, `Destroy` = 3).
    fn from(value: Reason) -> Self {
        // Lossless: `Reason` is `#[repr(u8)]`, so the cast yields exactly the
        // discriminant declared on the variant.
        value as u8
    }
}
46
47#[derive(Debug)]
48struct WakerWithReason {
49    waker: Waker,
50    reason: Arc<AtomicU8>,
51}
52
53impl WakerWithReason {
54    fn wake(self, reason: Reason) {
55        self.reason.store(reason.into(), Ordering::Release);
56        self.waker.wake();
57    }
58
59    fn wake_by_ref(&self, reason: Reason) {
60        self.reason.store(reason.into(), Ordering::Release);
61        self.waker.wake_by_ref();
62    }
63}
64
65/// The mediator of event notify for futures-aware enviroment.
66#[derive(Debug)]
67pub struct EventMap<E>
68where
69    E: Send + Eq + Hash,
70{
71    wakers: DashMap<E, WakerWithReason>,
72}
73
74impl<E> Drop for EventMap<E>
75where
76    E: Send + Eq + Hash,
77{
78    fn drop(&mut self) {
79        for entry in self.wakers.iter() {
80            entry.value().wake_by_ref(Reason::Destroy);
81        }
82    }
83}
84
85impl<E> Default for EventMap<E>
86where
87    E: Send + Eq + Hash,
88{
89    fn default() -> Self {
90        Self {
91            wakers: DashMap::new(),
92        }
93    }
94}
95
96impl<E> EventMap<E>
97where
98    E: Send + Eq + Hash + Debug + Clone,
99{
100    /// Only remove event waker, without wakeup it.
101    pub fn wait_cancel<Q>(&self, event: Q)
102    where
103        Q: Borrow<E>,
104    {
105        self.wakers.remove(event.borrow());
106    }
107
108    /// Notify one event `E` on.
109    #[inline(always)]
110    pub fn notify_one<Q>(&self, event: Q, reason: Reason) -> bool
111    where
112        Q: Borrow<E>,
113    {
114        if let Some((_, waker)) = self.wakers.remove(event.borrow()) {
115            log::trace!("{:?} wakeup", event.borrow());
116            waker.wake(reason);
117            true
118        } else {
119            log::trace!("{:?} wakeup -- not found", event.borrow());
120            false
121        }
122    }
123
124    /// Notify all event on in the providing `events` list
125    #[inline(always)]
126    pub fn notify_all<L: AsRef<[E]>>(&self, events: L, reason: Reason) {
127        for event in events.as_ref() {
128            self.notify_one(event, reason);
129        }
130    }
131
132    /// Notify all event on in the providing `events` list
133    #[inline(always)]
134    pub fn notify_any(&self, reason: Reason) {
135        let events = self
136            .wakers
137            .iter()
138            .map(|pair| pair.key().clone())
139            .collect::<Vec<_>>();
140
141        self.notify_all(&events, reason);
142    }
143
144    #[inline(always)]
145    pub fn wait<'a, Q, G>(&'a self, event: Q, guard: G) -> Wait<'a, E, G>
146    where
147        G: AsyncGuardMut<'a> + 'a,
148        Q: Borrow<E>,
149    {
150        Wait {
151            event: event.borrow().clone(),
152            guard: Some(guard),
153            event_map: self,
154            reason: Arc::new(AtomicU8::new(Reason::None.into())),
155        }
156    }
157}
158
159pub struct Wait<'a, E, G>
160where
161    E: Send + Eq + Hash,
162    G: AsyncGuardMut<'a> + 'a,
163{
164    event: E,
165    guard: Option<G>,
166    event_map: &'a EventMap<E>,
167    reason: Arc<AtomicU8>,
168}
169
170impl<'a, E, G> std::future::Future for Wait<'a, E, G>
171where
172    E: Send + Eq + Hash + Clone + Unpin + Debug,
173    G: AsyncGuardMut<'a> + Unpin + 'a,
174{
175    type Output = Result<(), EventMapError>;
176
177    #[inline(always)]
178    fn poll(
179        mut self: std::pin::Pin<&mut Self>,
180        cx: &mut std::task::Context<'_>,
181    ) -> std::task::Poll<Self::Output> {
182        if let Some(guard) = self.guard.take() {
183            // insert waker into waiting map.
184            self.event_map.wakers.insert(
185                self.event.clone(),
186                WakerWithReason {
187                    waker: cx.waker().clone(),
188                    reason: self.reason.clone(),
189                },
190            );
191
192            G::Locker::unlock(guard);
193        }
194
195        // Check reason to avoid unexpected `poll` calling.
196        // For example, calling `wait` function in `futures::select!` block
197
198        let reason = self.reason.load(Ordering::SeqCst);
199
200        if reason == Reason::None.into() {
201            return Poll::Pending;
202        } else if reason == Reason::Cancel.into() {
203            return Poll::Ready(Err(EventMapError::Cancel));
204        } else if reason == Reason::Destroy.into() {
205            return Poll::Ready(Err(EventMapError::Destroy));
206        } else {
207            return Poll::Ready(Ok(()));
208        }
209    }
210}
211
#[cfg(test)]
mod tests {
    use super::*;

    use futures::{executor::ThreadPool, task::SpawnExt};
    use hala_sync::AsyncSpinMutex;

    /// A task parked on the map from a thread pool is woken by `notify_one`
    /// called from the test task.
    #[futures_test::test]
    async fn test_across_suspend_point() {
        let pool = ThreadPool::builder().pool_size(10).create().unwrap();

        let event_map = Arc::new(EventMap::<i32>::default());

        let locked_value = Arc::new(AsyncSpinMutex::new(1));

        let waiter = {
            let event_map = event_map.clone();

            pool.spawn_with_handle(async move {
                let guard = locked_value.lock().await;

                event_map.wait(1, guard).await.unwrap();
            })
            .unwrap()
        };

        // `notify_one` returns `false` until the spawned task has registered
        // its waker, so keep retrying until the notification is delivered.
        loop {
            if event_map.notify_one(1, Reason::On) {
                break;
            }
        }

        waiter.await;
    }
}