1use std::{
4 borrow::Borrow,
5 collections::HashMap,
6 fmt::Debug,
7 hash::Hash,
8 task::{Poll, Waker},
9};
10
11use crate::utils::{Lockable, SpinMutex};
12
/// Status stored on a listener registered through [`EventMap::once`].
///
/// The discriminants are fixed so a status can round-trip through a `u8`
/// (see the `From` conversions below).
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum EventStatus {
    /// Not notified yet; the waiter re-registers and keeps waiting.
    Pending = 0,
    /// Notified successfully; the waiter resolves to `Ok(())`.
    Ready = 1,
    /// Reported to the waiter as `Err(EventStatus::Cancel)`.
    Cancel = 2,
    /// Set on every registered listener by [`EventMap::close`]; reported to the waiter as
    /// `Err(EventStatus::Destroy)`.
    Destroy = 3,
}

impl From<EventStatus> for u8 {
    /// Returns the status's `repr(u8)` discriminant.
    fn from(value: EventStatus) -> Self {
        value as u8
    }
}

impl From<u8> for EventStatus {
    /// Decodes a discriminant produced by `u8::from(EventStatus)`.
    ///
    /// # Panics
    ///
    /// Panics if `value` is not one of `0..=3`.
    fn from(value: u8) -> Self {
        match value {
            0 => EventStatus::Pending,
            1 => EventStatus::Ready,
            2 => EventStatus::Cancel,
            3 => EventStatus::Destroy,
            _ => panic!("invalid status value: {}", value),
        }
    }
}
40
41pub struct EventMap<E>
44where
45 E: Eq + Hash,
46{
47 listeners: SpinMutex<(bool, HashMap<E, Listener>)>,
48}
49
50impl<E> Default for EventMap<E>
51where
52 E: Eq + Hash + Unpin + Debug,
53{
54 fn default() -> Self {
55 Self::new()
56 }
57}
58
59impl<E> EventMap<E>
60where
61 E: Eq + Hash + Unpin + Debug + Unpin,
62{
63 pub fn new() -> Self {
65 Self {
66 listeners: Default::default(),
67 }
68 }
69
70 pub async fn once<G>(&self, event: E, guard: G) -> Result<(), EventStatus>
75 where
76 G: Unpin,
77 E: Clone,
78 {
79 WaitKey::new(self, event, guard).await
80 }
81
82 pub fn notify<Q: Borrow<E>>(&self, event: Q, status: EventStatus) -> bool {
84 let mut inner = self.listeners.lock();
85
86 if let Some(listener) = inner.1.get_mut(event.borrow()) {
87 listener.status = status;
88
89 listener.waker.wake_by_ref();
90
91 log::trace!("notify {:?}", event.borrow());
92
93 true
94 } else {
95 log::trace!("notify {:?}, not found", event.borrow());
96
97 false
98 }
99 }
100
101 pub fn notify_all<Q: Borrow<E>, L: AsRef<[Q]>>(&self, event_list: L, status: EventStatus) {
103 let mut inner = self.listeners.lock();
104
105 for event in event_list.as_ref() {
106 if let Some(listener) = inner.1.get_mut(event.borrow()) {
107 listener.status = status;
108
109 listener.waker.wake_by_ref();
110
111 log::trace!("notify {:?}", event.borrow());
112 } else {
113 log::trace!("notify {:?}, not found", event.borrow());
114 }
115 }
116 }
117
118 pub fn close(&self) {
119 let mut inner = self.listeners.lock();
120
121 if inner.0 {
122 return;
123 }
124
125 inner.0 = true;
126
127 for (_, listener) in inner.1.iter_mut() {
128 listener.status = EventStatus::Destroy;
129
130 listener.waker.wake_by_ref();
131 }
132 }
133}
134
135struct Listener {
136 waker: Waker,
137 status: EventStatus,
138}
139
140#[must_use = "if unused, the event listener will never actually register."]
142pub struct WaitKey<'a, E, G>
143where
144 E: Eq + Hash + Unpin,
145{
146 event: E,
147 event_map: &'a EventMap<E>,
148 guard: Option<G>,
149}
150
151impl<'a, E, G> WaitKey<'a, E, G>
152where
153 E: Eq + Hash + Unpin + Debug,
154{
155 fn new(event_map: &'a EventMap<E>, event: E, guard: G) -> Self {
156 Self {
157 guard: Some(guard),
158 event,
159 event_map,
160 }
161 }
162}
163
164impl<'a, E, G> futures::Future for WaitKey<'a, E, G>
165where
166 E: Eq + Hash + Unpin + Clone + Debug,
167 G: Unpin,
168{
169 type Output = Result<(), EventStatus>;
170
171 fn poll(
172 mut self: std::pin::Pin<&mut Self>,
173 cx: &mut std::task::Context<'_>,
174 ) -> std::task::Poll<Self::Output> {
175 let mut raw = self.event_map.listeners.lock();
176
177 let status = if let Some(listener) = raw.1.remove(&self.event) {
178 listener.status
179 } else {
180 EventStatus::Pending
181 };
182
183 self.guard.take();
184
185 match status {
186 EventStatus::Pending => {
187 raw.1.insert(
188 self.event.clone(),
189 Listener {
190 waker: cx.waker().clone(),
191 status,
192 },
193 );
194 Poll::Pending
196 }
197 EventStatus::Ready => Poll::Ready(Ok(())),
198 _ => Poll::Ready(Err(status)),
199 }
200 }
201}
202
#[cfg(test)]
mod tests {
    use std::{
        sync::Arc,
        thread::sleep,
        time::{Duration, Instant},
    };

    use futures::{lock, task::SpawnExt};

    use super::*;

    /// Blocks the test thread until exactly `count` listeners are registered.
    ///
    /// Panics after a deadline so a regression fails the test instead of hanging it.
    /// Blocking is acceptable here because the waiters run on a separate `ThreadPool`.
    fn wait_for_listeners(event_map: &EventMap<i32>, count: usize) {
        let deadline = Instant::now() + Duration::from_secs(10);

        while event_map.listeners.lock().1.len() != count {
            assert!(
                Instant::now() < deadline,
                "timed out waiting for {} listeners to register",
                count
            );

            sleep(Duration::from_millis(10));
        }
    }

    #[futures_test::test]
    async fn test_with_future_aware_mutex() {
        let event_map = Arc::new(EventMap::<i32>::new());

        let locker = Arc::new(lock::Mutex::new(()));

        let guard = locker.lock().await;

        let thread_pool = futures::executor::ThreadPool::new().unwrap();

        let event_map_cloned = event_map.clone();

        let locker_cloned = locker.clone();

        thread_pool
            .spawn(async move {
                // Only acquirable once `once` has registered and released `guard`.
                let _guard = locker_cloned.lock().await;
                event_map_cloned.notify(1, EventStatus::Ready);
            })
            .unwrap();

        event_map.once(1, guard).await.unwrap();

        let _guard = locker.lock().await;
    }

    #[futures_test::test]
    async fn test_with_std_mutex() {
        let event_map = Arc::new(EventMap::<i32>::new());

        let locker = Arc::new(std::sync::Mutex::new(()));

        let guard = locker.lock().unwrap();

        let thread_pool = futures::executor::ThreadPool::new().unwrap();

        let event_map_cloned = event_map.clone();

        let locker_cloned = locker.clone();

        thread_pool
            .spawn(async move {
                let _guard = locker_cloned.lock().unwrap();
                event_map_cloned.notify(1, EventStatus::Ready);
            })
            .unwrap();

        event_map.once(1, guard).await.unwrap();

        let _guard = locker.lock().unwrap();
    }

    #[futures_test::test]
    async fn test_notify_all() {
        let event_map = Arc::new(EventMap::<i32>::new());

        let thread_pool = futures::executor::ThreadPool::new().unwrap();

        let loops: i32 = 100;

        let handles: Vec<_> = (0..loops)
            .map(|i| {
                let event_map = event_map.clone();

                thread_pool
                    .spawn_with_handle(async move {
                        let locker = lock::Mutex::new(());

                        // Pass a real, held guard rather than the un-awaited lock future.
                        let guard = locker.lock().await;

                        event_map.once(i, guard).await.unwrap();
                    })
                    .unwrap()
            })
            .collect();

        wait_for_listeners(&event_map, usize::try_from(loops).unwrap());

        event_map.notify_all((0..loops).collect::<Vec<_>>(), EventStatus::Ready);

        for handle in handles {
            handle.await;
        }
    }

    #[futures_test::test]
    async fn test_close_wakes_pending_listener() {
        let event_map = Arc::new(EventMap::<i32>::new());

        let thread_pool = futures::executor::ThreadPool::new().unwrap();

        let event_map_cloned = event_map.clone();

        let handle = thread_pool
            .spawn_with_handle(async move { event_map_cloned.once(7, ()).await })
            .unwrap();

        wait_for_listeners(&event_map, 1);

        event_map.close();

        assert!(matches!(handle.await, Err(EventStatus::Destroy)));
    }
}