1use std::{
2 borrow::Borrow,
3 collections::HashMap,
4 fmt::Debug,
5 future::Future,
6 hash::Hash,
7 sync::Mutex,
8 task::{Poll, Waker},
9};
10
/// An event stored under a key until the task waiting on that key consumes it.
enum Event<V> {
    /// Stored by `insert` / `batch_insert`; the waiting future resolves to `Some(value)`.
    Value(V),
    /// Stored by `cancel` / `cancel_all`; the waiting future resolves to `None`.
    Cancel,
}
15
16struct RawMap<K, V> {
17 kv: HashMap<K, Event<V>>,
19 wakers: HashMap<K, Waker>,
21}
22
23pub struct KeyWaitMap<K, V> {
26 inner: Mutex<RawMap<K, V>>,
27}
28
29impl<K, V> Default for KeyWaitMap<K, V> {
30 fn default() -> Self {
31 KeyWaitMap {
32 inner: Mutex::new(RawMap {
33 kv: HashMap::new(),
34 wakers: HashMap::new(),
35 }),
36 }
37 }
38}
39
40impl<K, V> KeyWaitMap<K, V>
41where
42 K: Eq + Hash + Unpin,
43{
44 pub fn new() -> Self {
45 KeyWaitMap {
46 inner: Mutex::new(RawMap {
47 kv: HashMap::new(),
48 wakers: HashMap::new(),
49 }),
50 }
51 }
52 pub fn insert(&self, k: K, v: V) -> Option<V> {
56 let mut raw = self.inner.lock().unwrap();
57
58 let waker = raw.wakers.remove(&k);
59
60 let older = raw.kv.insert(k, Event::Value(v));
61
62 drop(raw);
63
64 if let Some(waker) = waker {
65 waker.wake();
66 }
67
68 if let Some(event) = older {
69 match event {
70 Event::Value(value) => Some(value),
71 Event::Cancel => None,
72 }
73 } else {
74 None
75 }
76 }
77
78 pub fn batch_insert<I>(&self, kv: I)
82 where
83 I: IntoIterator<Item = (K, V)>,
84 K: Debug,
85 {
86 let mut raw = self.inner.lock().unwrap();
87
88 let mut wakers = vec![];
89
90 for (k, v) in kv.into_iter() {
91 if let Some(waker) = raw.wakers.remove(&k) {
92 log::trace!("wakeup: {:?}", k);
93 wakers.push(waker);
94 } else {
95 log::trace!("wakeup: {:?}, without waiting task", k);
96 }
97
98 raw.kv.insert(k, Event::Value(v));
99 }
100
101 drop(raw);
102
103 for waker in wakers {
104 waker.wake();
105 }
106 }
107
108 pub async fn wait<L>(&self, k: &K, locker: L) -> Option<V>
114 where
115 K: Clone,
116 L: Unpin,
117 {
118 Wait {
119 event_map: self,
120 k,
121 locker: Some(locker),
122 }
123 .await
124 }
125
126 pub fn cancel<Q>(&self, k: &Q) -> bool
128 where
129 K: Borrow<Q>,
130 Q: Hash + Eq,
131 {
132 let mut raw = self.inner.lock().unwrap();
133
134 if let Some((k, waker)) = raw.wakers.remove_entry(k) {
135 raw.kv.insert(k, Event::Cancel);
136 drop(raw);
137 waker.wake();
138 true
139 } else {
140 raw.kv.remove(k);
141 false
142 }
143 }
144
145 pub fn cancel_all(&self) {
147 let mut raw = self.inner.lock().unwrap();
148
149 let wakers = raw.wakers.drain().collect::<Vec<_>>();
150
151 let mut droping = vec![];
152
153 for (k, waker) in wakers {
154 raw.kv.insert(k, Event::Cancel);
155 droping.push(waker);
156 }
157
158 drop(raw);
159
160 for waker in droping {
161 waker.wake();
162 }
163 }
164}
165
166struct Wait<'a, K, V, L> {
167 event_map: &'a KeyWaitMap<K, V>,
168 k: &'a K,
169 locker: Option<L>,
170}
171
172impl<'a, K, V, L> Future for Wait<'a, K, V, L>
173where
174 K: Eq + Hash + Unpin + Clone,
175 L: Unpin,
176{
177 type Output = Option<V>;
178
179 fn poll(
180 mut self: std::pin::Pin<&mut Self>,
181 cx: &mut std::task::Context<'_>,
182 ) -> std::task::Poll<Self::Output> {
183 let mut inner = self.event_map.inner.lock().unwrap();
184
185 drop(self.locker.take());
186
187 if let Some(event) = inner.kv.remove(&self.k) {
188 match event {
189 Event::Value(value) => return Poll::Ready(Some(value)),
190 Event::Cancel => return Poll::Ready(None),
191 }
192 } else {
193 inner.wakers.insert(self.k.clone(), cx.waker().clone());
194 return Poll::Pending;
195 }
196 }
197}
198
#[cfg(test)]
mod tests {
    use futures::poll;

    use super::*;

    /// Covers three cases: a value delivered before anyone waits, cancellation of a
    /// parked waiter, and delivery to a parked waiter.
    #[futures_test::test]
    async fn test_event_map() {
        let map = KeyWaitMap::<usize, usize>::new();

        // A value stored before anyone waits is returned on the first poll.
        map.insert(1, 1);
        let ready = map.wait(&1, ()).await;
        assert_eq!(ready, Some(1));

        // Cancelling a parked waiter makes it resolve to `None`.
        let mut cancelled = Box::pin(map.wait(&2, ()));
        assert_eq!(poll!(&mut cancelled), Poll::Pending);
        map.cancel(&2);
        assert_eq!(poll!(&mut cancelled), Poll::Ready(None));

        // Inserting for a parked waiter makes it resolve to the value.
        let mut delivered = Box::pin(map.wait(&2, ()));
        assert_eq!(poll!(&mut delivered), Poll::Pending);
        map.insert(2, 2);
        assert_eq!(poll!(&mut delivered), Poll::Ready(Some(2)));
    }
}