futures_waitmap/
lib.rs

1use std::{
2    borrow::Borrow,
3    collections::HashMap,
4    future::Future,
5    hash::Hash,
6    task::{Poll, Waker},
7};
8
9use futures::{
10    lock::{Mutex, MutexLockFuture},
11    FutureExt,
12};
13
/// What is stored under a key: either a delivered value or a cancellation
/// marker left for the task waiting on that key.
enum Event<V> {
    /// A value delivered by `WaitMap::insert`.
    Value(V),
    /// Marker written by `WaitMap::cancel` / `WaitMap::cancel_all`; the waiter
    /// that consumes it resolves to `None`.
    Cancel,
}

/// State protected by `WaitMap`'s mutex.
struct RawMap<K, V> {
    /// Events not yet consumed by a waiter, keyed by the awaited key.
    kv: HashMap<K, Event<V>>,
    /// Waker of the task currently registered as waiting on each key.
    wakers: HashMap<K, Waker>,
}
25
26/// A future-based concurrent event map with `wait` API.
27
28pub struct WaitMap<K, V> {
29    inner: Mutex<RawMap<K, V>>,
30}
31
32impl<K, V> WaitMap<K, V>
33where
34    K: Eq + Hash + Unpin,
35{
36    pub fn new() -> Self {
37        WaitMap {
38            inner: Mutex::new(RawMap {
39                kv: HashMap::new(),
40                wakers: HashMap::new(),
41            }),
42        }
43    }
44    /// Inserts a event-value pair into the map.
45    ///
46    /// If the map did not have this key present, None is returned.
47    pub async fn insert(&self, k: K, v: V) -> Option<V> {
48        let mut raw = self.inner.lock().await;
49
50        if let Some(waker) = raw.wakers.remove(&k) {
51            waker.wake();
52        }
53
54        if let Some(event) = raw.kv.insert(k, Event::Value(v)) {
55            match event {
56                Event::Value(value) => Some(value),
57                Event::Cancel => None,
58            }
59        } else {
60            None
61        }
62    }
63
64    /// Create a key waiting task until a value is put under the key,
65    /// only one waiting task for a key can exist at a time.
66    ///
67    /// Returns the value at the key if the key was inserted into the map,
68    /// or returns `None` if the waiting task is canceled.
69    pub async fn wait(&self, k: &K) -> Option<V>
70    where
71        K: Clone,
72    {
73        {
74            let mut raw = self.inner.lock().await;
75
76            if let Some(v) = raw.kv.remove(&k) {
77                match v {
78                    Event::Value(value) => return Some(value),
79                    Event::Cancel => return None,
80                }
81            }
82        }
83
84        Wait {
85            event_map: self,
86            lock: None,
87            k,
88        }
89        .await
90    }
91
92    /// Cancel other key waiting tasks.
93    pub async fn cancel<Q>(&self, k: &Q) -> bool
94    where
95        K: Borrow<Q>,
96        Q: Hash + Eq,
97    {
98        let mut raw = self.inner.lock().await;
99
100        if let Some((k, waker)) = raw.wakers.remove_entry(k) {
101            raw.kv.insert(k, Event::Cancel);
102            waker.wake();
103            true
104        } else {
105            raw.kv.remove(k);
106            false
107        }
108    }
109
110    /// Cancel all key waiting tasks.
111    pub async fn cancel_all(&self) {
112        let mut raw = self.inner.lock().await;
113
114        let wakers = raw.wakers.drain().collect::<Vec<_>>();
115
116        for (k, waker) in wakers {
117            raw.kv.insert(k, Event::Cancel);
118            waker.wake();
119        }
120    }
121}
122
123struct Wait<'a, K, V> {
124    event_map: &'a WaitMap<K, V>,
125    lock: Option<MutexLockFuture<'a, RawMap<K, V>>>,
126    k: &'a K,
127}
128
129impl<'a, K, V> Future for Wait<'a, K, V>
130where
131    K: Eq + Hash + Unpin + Clone,
132{
133    type Output = Option<V>;
134
135    fn poll(
136        mut self: std::pin::Pin<&mut Self>,
137        cx: &mut std::task::Context<'_>,
138    ) -> std::task::Poll<Self::Output> {
139        let mut lock = if let Some(lock) = self.lock.take() {
140            lock
141        } else {
142            self.event_map.inner.lock()
143        };
144
145        match lock.poll_unpin(cx) {
146            std::task::Poll::Ready(mut inner) => {
147                if let Some(event) = inner.kv.remove(&self.k) {
148                    match event {
149                        Event::Value(value) => return Poll::Ready(Some(value)),
150                        Event::Cancel => return Poll::Ready(None),
151                    }
152                } else {
153                    inner.wakers.insert(self.k.clone(), cx.waker().clone());
154                    return Poll::Pending;
155                }
156            }
157            std::task::Poll::Pending => {
158                self.lock = Some(lock);
159                return Poll::Pending;
160            }
161        }
162    }
163}
164
#[cfg(test)]
mod tests {
    use futures::poll;

    use super::*;

    #[futures_test::test]
    async fn test_event_map() {
        let event_map = WaitMap::<usize, usize>::new();

        event_map.insert(1, 1).await;

        assert_eq!(event_map.wait(&1).await, Some(1));

        let mut wait = Box::pin(event_map.wait(&2));

        assert_eq!(poll!(&mut wait), Poll::Pending);

        // A registered waiter exists, so `cancel` reports that it cancelled it.
        assert!(event_map.cancel(&2).await);

        assert_eq!(poll!(&mut wait), Poll::Ready(None));

        let mut wait = Box::pin(event_map.wait(&2));

        assert_eq!(poll!(&mut wait), Poll::Pending);

        event_map.insert(2, 2).await;

        assert_eq!(poll!(&mut wait), Poll::Ready(Some(2)));
    }

    #[futures_test::test]
    async fn insert_returns_unconsumed_previous_value() {
        let event_map = WaitMap::<usize, usize>::new();

        assert_eq!(event_map.insert(1, 1).await, None);
        assert_eq!(event_map.insert(1, 2).await, Some(1));
        assert_eq!(event_map.wait(&1).await, Some(2));
    }

    #[futures_test::test]
    async fn cancel_without_waiter_discards_stored_value() {
        let event_map = WaitMap::<usize, usize>::new();

        event_map.insert(3, 3).await;
        assert!(!event_map.cancel(&3).await);

        // The value was discarded, so a new waiter has nothing to take.
        let mut wait = Box::pin(event_map.wait(&3));
        assert_eq!(poll!(&mut wait), Poll::Pending);
    }

    #[futures_test::test]
    async fn cancel_all_wakes_every_waiter() {
        let event_map = WaitMap::<usize, usize>::new();
        event_map.insert(5, 5).await;

        let mut first = Box::pin(event_map.wait(&1));
        let mut second = Box::pin(event_map.wait(&2));
        assert_eq!(poll!(&mut first), Poll::Pending);
        assert_eq!(poll!(&mut second), Poll::Pending);

        event_map.cancel_all().await;

        assert_eq!(poll!(&mut first), Poll::Ready(None));
        assert_eq!(poll!(&mut second), Poll::Ready(None));
        // Values nobody was waiting on survive `cancel_all`.
        assert_eq!(event_map.wait(&5).await, Some(5));
    }
}