futures_map/
kv.rs

1use std::{
2    borrow::Borrow,
3    collections::HashMap,
4    fmt::Debug,
5    future::Future,
6    hash::Hash,
7    sync::Mutex,
8    task::{Poll, Waker},
9};
10
/// Outcome stored under a key for a waiting task to pick up.
enum Event<V> {
    /// A value was inserted under the key.
    Value(V),
    /// The wait on the key was canceled; no value is available.
    Cancel,
}
15
16struct RawMap<K, V> {
17    /// kv store.
18    kv: HashMap<K, Event<V>>,
19    /// Waiting future wakers.
20    wakers: HashMap<K, Waker>,
21}
22
23/// A future-based concurrent event map with `wait` API.
24
25pub struct KeyWaitMap<K, V> {
26    inner: Mutex<RawMap<K, V>>,
27}
28
29impl<K, V> Default for KeyWaitMap<K, V> {
30    fn default() -> Self {
31        KeyWaitMap {
32            inner: Mutex::new(RawMap {
33                kv: HashMap::new(),
34                wakers: HashMap::new(),
35            }),
36        }
37    }
38}
39
40impl<K, V> KeyWaitMap<K, V>
41where
42    K: Eq + Hash + Unpin,
43{
44    pub fn new() -> Self {
45        KeyWaitMap {
46            inner: Mutex::new(RawMap {
47                kv: HashMap::new(),
48                wakers: HashMap::new(),
49            }),
50        }
51    }
52    /// Inserts a event-value pair into the map.
53    ///
54    /// If the map did not have this key present, None is returned.
55    pub fn insert(&self, k: K, v: V) -> Option<V> {
56        let mut raw = self.inner.lock().unwrap();
57
58        let waker = raw.wakers.remove(&k);
59
60        let older = raw.kv.insert(k, Event::Value(v));
61
62        drop(raw);
63
64        if let Some(waker) = waker {
65            waker.wake();
66        }
67
68        if let Some(event) = older {
69            match event {
70                Event::Value(value) => Some(value),
71                Event::Cancel => None,
72            }
73        } else {
74            None
75        }
76    }
77
78    /// Inserts a event-value pair into the map.
79    ///
80    /// If the map did not have this key present, None is returned.
81    pub fn batch_insert<I>(&self, kv: I)
82    where
83        I: IntoIterator<Item = (K, V)>,
84        K: Debug,
85    {
86        let mut raw = self.inner.lock().unwrap();
87
88        let mut wakers = vec![];
89
90        for (k, v) in kv.into_iter() {
91            if let Some(waker) = raw.wakers.remove(&k) {
92                log::trace!("wakeup: {:?}", k);
93                wakers.push(waker);
94            } else {
95                log::trace!("wakeup: {:?}, without waiting task", k);
96            }
97
98            raw.kv.insert(k, Event::Value(v));
99        }
100
101        drop(raw);
102
103        for waker in wakers {
104            waker.wake();
105        }
106    }
107
108    /// Create a key waiting task until a value is put under the key,
109    /// only one waiting task for a key can exist at a time.
110    ///
111    /// Returns the value at the key if the key was inserted into the map,
112    /// or returns `None` if the waiting task is canceled.
113    pub async fn wait<L>(&self, k: &K, locker: L) -> Option<V>
114    where
115        K: Clone,
116        L: Unpin,
117    {
118        Wait {
119            event_map: self,
120            k,
121            locker: Some(locker),
122        }
123        .await
124    }
125
126    /// Cancel other key waiting tasks.
127    pub fn cancel<Q>(&self, k: &Q) -> bool
128    where
129        K: Borrow<Q>,
130        Q: Hash + Eq,
131    {
132        let mut raw = self.inner.lock().unwrap();
133
134        if let Some((k, waker)) = raw.wakers.remove_entry(k) {
135            raw.kv.insert(k, Event::Cancel);
136            drop(raw);
137            waker.wake();
138            true
139        } else {
140            raw.kv.remove(k);
141            false
142        }
143    }
144
145    /// Cancel all key waiting tasks.
146    pub fn cancel_all(&self) {
147        let mut raw = self.inner.lock().unwrap();
148
149        let wakers = raw.wakers.drain().collect::<Vec<_>>();
150
151        let mut droping = vec![];
152
153        for (k, waker) in wakers {
154            raw.kv.insert(k, Event::Cancel);
155            droping.push(waker);
156        }
157
158        drop(raw);
159
160        for waker in droping {
161            waker.wake();
162        }
163    }
164}
165
166struct Wait<'a, K, V, L> {
167    event_map: &'a KeyWaitMap<K, V>,
168    k: &'a K,
169    locker: Option<L>,
170}
171
172impl<'a, K, V, L> Future for Wait<'a, K, V, L>
173where
174    K: Eq + Hash + Unpin + Clone,
175    L: Unpin,
176{
177    type Output = Option<V>;
178
179    fn poll(
180        mut self: std::pin::Pin<&mut Self>,
181        cx: &mut std::task::Context<'_>,
182    ) -> std::task::Poll<Self::Output> {
183        let mut inner = self.event_map.inner.lock().unwrap();
184
185        drop(self.locker.take());
186
187        if let Some(event) = inner.kv.remove(&self.k) {
188            match event {
189                Event::Value(value) => return Poll::Ready(Some(value)),
190                Event::Cancel => return Poll::Ready(None),
191            }
192        } else {
193            inner.wakers.insert(self.k.clone(), cx.waker().clone());
194            return Poll::Pending;
195        }
196    }
197}
198
#[cfg(test)]
mod tests {
    use futures::poll;

    use super::*;

    /// Exercises the three outcomes of a wait: value already present,
    /// canceled while pending, and fulfilled while pending.
    #[futures_test::test]
    async fn test_event_map() {
        let map = KeyWaitMap::<usize, usize>::new();

        // A value inserted before waiting is handed out by the first poll.
        map.insert(1, 1);
        assert_eq!(map.wait(&1, ()).await, Some(1));

        // Canceling a pending wait resolves it to `None`.
        let mut canceled = Box::pin(map.wait(&2, ()));
        assert_eq!(poll!(&mut canceled), Poll::Pending);
        map.cancel(&2);
        assert_eq!(poll!(&mut canceled), Poll::Ready(None));

        // Inserting while a wait is pending resolves it to the value.
        let mut fulfilled = Box::pin(map.wait(&2, ()));
        assert_eq!(poll!(&mut fulfilled), Poll::Pending);
        map.insert(2, 2);
        assert_eq!(poll!(&mut fulfilled), Poll::Ready(Some(2)));
    }
}