1use std::{
2 borrow::Borrow,
3 collections::HashMap,
4 future::Future,
5 hash::Hash,
6 task::{Poll, Waker},
7};
8
9use futures::{
10 lock::{Mutex, MutexLockFuture},
11 FutureExt,
12};
13
/// Outcome recorded for a key, consumed by the task that waits on that key.
enum Event<V> {
    /// A value was inserted; the waiter resolves to `Some(value)`.
    Value(V),
    /// The wait was cancelled; the waiter resolves to `None`.
    Cancel,
}
18
19struct RawMap<K, V> {
20 kv: HashMap<K, Event<V>>,
22 wakers: HashMap<K, Waker>,
24}
25
26pub struct WaitMap<K, V> {
29 inner: Mutex<RawMap<K, V>>,
30}
31
32impl<K, V> WaitMap<K, V>
33where
34 K: Eq + Hash + Unpin,
35{
36 pub fn new() -> Self {
37 WaitMap {
38 inner: Mutex::new(RawMap {
39 kv: HashMap::new(),
40 wakers: HashMap::new(),
41 }),
42 }
43 }
44 pub async fn insert(&self, k: K, v: V) -> Option<V> {
48 let mut raw = self.inner.lock().await;
49
50 if let Some(waker) = raw.wakers.remove(&k) {
51 waker.wake();
52 }
53
54 if let Some(event) = raw.kv.insert(k, Event::Value(v)) {
55 match event {
56 Event::Value(value) => Some(value),
57 Event::Cancel => None,
58 }
59 } else {
60 None
61 }
62 }
63
64 pub async fn wait(&self, k: &K) -> Option<V>
70 where
71 K: Clone,
72 {
73 {
74 let mut raw = self.inner.lock().await;
75
76 if let Some(v) = raw.kv.remove(&k) {
77 match v {
78 Event::Value(value) => return Some(value),
79 Event::Cancel => return None,
80 }
81 }
82 }
83
84 Wait {
85 event_map: self,
86 lock: None,
87 k,
88 }
89 .await
90 }
91
92 pub async fn cancel<Q>(&self, k: &Q) -> bool
94 where
95 K: Borrow<Q>,
96 Q: Hash + Eq,
97 {
98 let mut raw = self.inner.lock().await;
99
100 if let Some((k, waker)) = raw.wakers.remove_entry(k) {
101 raw.kv.insert(k, Event::Cancel);
102 waker.wake();
103 true
104 } else {
105 raw.kv.remove(k);
106 false
107 }
108 }
109
110 pub async fn cancel_all(&self) {
112 let mut raw = self.inner.lock().await;
113
114 let wakers = raw.wakers.drain().collect::<Vec<_>>();
115
116 for (k, waker) in wakers {
117 raw.kv.insert(k, Event::Cancel);
118 waker.wake();
119 }
120 }
121}
122
123struct Wait<'a, K, V> {
124 event_map: &'a WaitMap<K, V>,
125 lock: Option<MutexLockFuture<'a, RawMap<K, V>>>,
126 k: &'a K,
127}
128
129impl<'a, K, V> Future for Wait<'a, K, V>
130where
131 K: Eq + Hash + Unpin + Clone,
132{
133 type Output = Option<V>;
134
135 fn poll(
136 mut self: std::pin::Pin<&mut Self>,
137 cx: &mut std::task::Context<'_>,
138 ) -> std::task::Poll<Self::Output> {
139 let mut lock = if let Some(lock) = self.lock.take() {
140 lock
141 } else {
142 self.event_map.inner.lock()
143 };
144
145 match lock.poll_unpin(cx) {
146 std::task::Poll::Ready(mut inner) => {
147 if let Some(event) = inner.kv.remove(&self.k) {
148 match event {
149 Event::Value(value) => return Poll::Ready(Some(value)),
150 Event::Cancel => return Poll::Ready(None),
151 }
152 } else {
153 inner.wakers.insert(self.k.clone(), cx.waker().clone());
154 return Poll::Pending;
155 }
156 }
157 std::task::Poll::Pending => {
158 self.lock = Some(lock);
159 return Poll::Pending;
160 }
161 }
162 }
163}
164
#[cfg(test)]
mod tests {
    use futures::poll;

    use super::*;

    /// Covers the three ways a wait resolves: a value that is already present,
    /// a cancellation, and a value inserted while the waiter is pending.
    #[futures_test::test]
    async fn test_event_map() {
        let map = WaitMap::<usize, usize>::new();

        // A value inserted before anyone waits is returned straight away.
        map.insert(1, 1).await;
        assert_eq!(map.wait(&1).await, Some(1));

        // Cancelling a pending wait resolves it to `None`.
        let mut cancelled = Box::pin(map.wait(&2));
        assert_eq!(poll!(&mut cancelled), Poll::Pending);
        map.cancel(&2).await;
        assert_eq!(poll!(&mut cancelled), Poll::Ready(None));

        // Inserting while a wait is pending delivers the value to it.
        let mut fulfilled = Box::pin(map.wait(&2));
        assert_eq!(poll!(&mut fulfilled), Poll::Pending);
        map.insert(2, 2).await;
        assert_eq!(poll!(&mut fulfilled), Poll::Ready(Some(2)));
    }
}