1use std::{
2 borrow::Borrow,
3 fmt::Debug,
4 hash::Hash,
5 sync::{
6 atomic::{AtomicU8, Ordering},
7 Arc,
8 },
9 task::{Poll, Waker},
10};
11
12use dashmap::DashMap;
13use hala_sync::{AsyncGuardMut, AsyncLockable};
14
15#[derive(Debug, thiserror::Error, PartialEq)]
16pub enum EventMapError {
17 #[error("Waiting operation canceled by user")]
18 Cancel,
19 #[error("Waiting operation canceled by EventMap to drop `EventMap` self")]
20 Destroy,
21}
22
/// Why a waiting task was woken.
///
/// The reason is passed to the waker registration and stored as a `u8` code
/// (see `From<Reason> for u8`) in a slot shared with the matching `Wait` future.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Reason {
    /// Not notified yet. A `Wait` that observes this value returns `Poll::Pending`.
    /// Notifying with `Reason::None` therefore removes the registration but
    /// leaves the wait pending.
    None,
    /// The event fired; the wait resolves with `Ok(())`.
    On,
    /// The wait was canceled; it resolves with `Err(EventMapError::Cancel)`.
    Cancel,
    /// The map is being destroyed; the wait resolves with `Err(EventMapError::Destroy)`.
    Destroy,
}
35
36impl From<Reason> for u8 {
37 fn from(value: Reason) -> Self {
38 match value {
39 Reason::None => 0,
40 Reason::On => 1,
41 Reason::Cancel => 2,
42 Reason::Destroy => 3,
43 }
44 }
45}
46
/// A registered waiter: the task's waker plus the reason slot shared with the
/// `Wait` future that registered it.
#[derive(Debug)]
struct WakerWithReason {
    /// Waker of the task that polled the `Wait` future.
    waker: Waker,
    /// Clone of the owning `Wait`'s `reason` slot; holds the `u8` code of a `Reason`.
    reason: Arc<AtomicU8>,
}
52
53impl WakerWithReason {
54 fn wake(self, reason: Reason) {
55 self.reason.store(reason.into(), Ordering::Release);
56 self.waker.wake();
57 }
58
59 fn wake_by_ref(&self, reason: Reason) {
60 self.reason.store(reason.into(), Ordering::Release);
61 self.waker.wake_by_ref();
62 }
63}
64
/// Registry mapping each event key to the task currently waiting on it.
///
/// Waiters register through `EventMap::wait`; notifiers wake them with
/// `notify_one`, `notify_all` or `notify_any`.
#[derive(Debug)]
pub struct EventMap<E>
where
    E: Send + Eq + Hash,
{
    /// Registered waiters keyed by event. The map holds one value per key, so a
    /// second registration for the same event replaces the first.
    wakers: DashMap<E, WakerWithReason>,
}
73
74impl<E> Drop for EventMap<E>
75where
76 E: Send + Eq + Hash,
77{
78 fn drop(&mut self) {
79 for entry in self.wakers.iter() {
80 entry.value().wake_by_ref(Reason::Destroy);
81 }
82 }
83}
84
85impl<E> Default for EventMap<E>
86where
87 E: Send + Eq + Hash,
88{
89 fn default() -> Self {
90 Self {
91 wakers: DashMap::new(),
92 }
93 }
94}
95
96impl<E> EventMap<E>
97where
98 E: Send + Eq + Hash + Debug + Clone,
99{
100 pub fn wait_cancel<Q>(&self, event: Q)
102 where
103 Q: Borrow<E>,
104 {
105 self.wakers.remove(event.borrow());
106 }
107
108 #[inline(always)]
110 pub fn notify_one<Q>(&self, event: Q, reason: Reason) -> bool
111 where
112 Q: Borrow<E>,
113 {
114 if let Some((_, waker)) = self.wakers.remove(event.borrow()) {
115 log::trace!("{:?} wakeup", event.borrow());
116 waker.wake(reason);
117 true
118 } else {
119 log::trace!("{:?} wakeup -- not found", event.borrow());
120 false
121 }
122 }
123
124 #[inline(always)]
126 pub fn notify_all<L: AsRef<[E]>>(&self, events: L, reason: Reason) {
127 for event in events.as_ref() {
128 self.notify_one(event, reason);
129 }
130 }
131
132 #[inline(always)]
134 pub fn notify_any(&self, reason: Reason) {
135 let events = self
136 .wakers
137 .iter()
138 .map(|pair| pair.key().clone())
139 .collect::<Vec<_>>();
140
141 self.notify_all(&events, reason);
142 }
143
144 #[inline(always)]
145 pub fn wait<'a, Q, G>(&'a self, event: Q, guard: G) -> Wait<'a, E, G>
146 where
147 G: AsyncGuardMut<'a> + 'a,
148 Q: Borrow<E>,
149 {
150 Wait {
151 event: event.borrow().clone(),
152 guard: Some(guard),
153 event_map: self,
154 reason: Arc::new(AtomicU8::new(Reason::None.into())),
155 }
156 }
157}
158
/// Future returned by `EventMap::wait`; resolves once its event is notified.
pub struct Wait<'a, E, G>
where
    E: Send + Eq + Hash,
    G: AsyncGuardMut<'a> + 'a,
{
    /// Event key this future is registered under.
    event: E,
    /// Lock guard that is unlocked right after the waker is registered on the
    /// first poll; `None` from then on.
    guard: Option<G>,
    /// Map the waker is registered in.
    event_map: &'a EventMap<E>,
    /// Wakeup-reason slot (a `Reason` code) shared with the registered `WakerWithReason`.
    reason: Arc<AtomicU8>,
}
169
170impl<'a, E, G> std::future::Future for Wait<'a, E, G>
171where
172 E: Send + Eq + Hash + Clone + Unpin + Debug,
173 G: AsyncGuardMut<'a> + Unpin + 'a,
174{
175 type Output = Result<(), EventMapError>;
176
177 #[inline(always)]
178 fn poll(
179 mut self: std::pin::Pin<&mut Self>,
180 cx: &mut std::task::Context<'_>,
181 ) -> std::task::Poll<Self::Output> {
182 if let Some(guard) = self.guard.take() {
183 self.event_map.wakers.insert(
185 self.event.clone(),
186 WakerWithReason {
187 waker: cx.waker().clone(),
188 reason: self.reason.clone(),
189 },
190 );
191
192 G::Locker::unlock(guard);
193 }
194
195 let reason = self.reason.load(Ordering::SeqCst);
199
200 if reason == Reason::None.into() {
201 return Poll::Pending;
202 } else if reason == Reason::Cancel.into() {
203 return Poll::Ready(Err(EventMapError::Cancel));
204 } else if reason == Reason::Destroy.into() {
205 return Poll::Ready(Err(EventMapError::Destroy));
206 } else {
207 return Poll::Ready(Ok(()));
208 }
209 }
210}
211
#[cfg(test)]
mod tests {
    use super::*;

    use futures::{executor::ThreadPool, task::SpawnExt};
    use hala_sync::AsyncSpinMutex;

    /// Repeatedly notifies `event` with `reason` until a registered waiter receives it.
    ///
    /// The waiter registers asynchronously on a pool thread, so this yields
    /// between attempts instead of hot-spinning a core.
    fn notify_until_delivered(mediator: &EventMap<i32>, event: i32, reason: Reason) {
        while !mediator.notify_one(event, reason) {
            std::thread::yield_now();
        }
    }

    #[futures_test::test]
    async fn test_across_suspend_point() {
        let local_pool = ThreadPool::builder().pool_size(10).create().unwrap();

        let mediator = Arc::new(EventMap::<i32>::default());

        let shared = Arc::new(AsyncSpinMutex::new(1));

        let mediator_cloned = mediator.clone();

        let handle = local_pool
            .spawn_with_handle(async move {
                let shared = shared.lock().await;

                mediator_cloned.wait(1, shared).await.unwrap();
            })
            .unwrap();

        notify_until_delivered(&mediator, 1, Reason::On);

        handle.await;
    }

    #[futures_test::test]
    async fn test_notify_with_cancel_reason() {
        let local_pool = ThreadPool::builder().pool_size(10).create().unwrap();

        let mediator = Arc::new(EventMap::<i32>::default());

        let shared = Arc::new(AsyncSpinMutex::new(1));

        let mediator_cloned = mediator.clone();

        let handle = local_pool
            .spawn_with_handle(async move {
                let shared = shared.lock().await;

                // Bind the result so the `Wait` (which borrows `mediator_cloned`)
                // is dropped before the block's locals.
                let outcome = mediator_cloned.wait(1, shared).await;
                outcome
            })
            .unwrap();

        notify_until_delivered(&mediator, 1, Reason::Cancel);

        assert_eq!(handle.await, Err(EventMapError::Cancel));
    }
}