futures_map/
unordered.rs

1use std::{
2    collections::{HashMap, VecDeque},
3    fmt::Debug,
4    future::Future,
5    hash::Hash,
6    sync::{
7        atomic::{AtomicUsize, Ordering},
8        Arc, Mutex,
9    },
10    task::{Context, Poll, Waker},
11};
12
13use cooked_waker::{IntoWaker, WakeRef};
14use futures::{future::BoxFuture, FutureExt, Stream};
15
16struct RawFutureWaitMap<K, R> {
17    futs: HashMap<K, BoxFuture<'static, R>>,
18    ready_queue: VecDeque<K>,
19    waker: Option<Waker>,
20}
21
22impl<K, R> Default for RawFutureWaitMap<K, R> {
23    fn default() -> Self {
24        Self {
25            futs: HashMap::new(),
26            ready_queue: VecDeque::new(),
27            waker: None,
28        }
29    }
30}
31
32/// A waitable map for futures.
33pub struct FuturesUnorderedMap<K, R> {
34    len: Arc<AtomicUsize>,
35    inner: Arc<Mutex<RawFutureWaitMap<K, R>>>,
36}
37
38impl<K, R> Clone for FuturesUnorderedMap<K, R> {
39    fn clone(&self) -> Self {
40        Self {
41            len: self.len.clone(),
42            inner: self.inner.clone(),
43        }
44    }
45}
46
47impl<K, R> AsRef<FuturesUnorderedMap<K, R>> for FuturesUnorderedMap<K, R> {
48    fn as_ref(&self) -> &FuturesUnorderedMap<K, R> {
49        self
50    }
51}
52
53impl<K, R> FuturesUnorderedMap<K, R> {
54    /// Create a new future `WaitMap` instance.
55    pub fn new() -> Self {
56        Self {
57            len: Default::default(),
58            inner: Default::default(),
59        }
60    }
61    /// Insert a new key / future pair.
62    pub fn insert<Fut>(&self, k: K, fut: Fut)
63    where
64        Fut: Future<Output = R> + Send + 'static,
65        K: Hash + Eq + Clone,
66    {
67        let mut inner = self.inner.lock().unwrap();
68
69        inner.ready_queue.push_back(k.clone());
70        inner.futs.insert(k, Box::pin(fut));
71        let waker = inner.waker.take();
72
73        drop(inner);
74
75        self.len.fetch_add(1, Ordering::Relaxed);
76
77        if let Some(waker) = waker {
78            waker.wake();
79        }
80    }
81
82    pub fn poll_next(&self, cx: &mut Context<'_>) -> Poll<(K, R)>
83    where
84        K: Hash + Eq + Clone + Send + Sync + 'static + Debug,
85        R: 'static,
86    {
87        let mut inner = self.inner.lock().unwrap();
88
89        inner.waker = Some(cx.waker().clone());
90
91        while let Some(key) = inner.ready_queue.pop_front() {
92            let mut fut = match inner.futs.remove(&key) {
93                Some(fut) => fut,
94                None => continue,
95            };
96
97            drop(inner);
98
99            let waker = Arc::new(FutureWaitMapWaker(key.clone(), self.inner.clone())).into_waker();
100
101            let mut proxy_context = Context::from_waker(&waker);
102
103            match fut.poll_unpin(&mut proxy_context) {
104                Poll::Ready(r) => {
105                    self.len.fetch_sub(1, Ordering::Relaxed);
106                    return Poll::Ready((key, r));
107                }
108                _ => {
109                    inner = self.inner.lock().unwrap();
110                    inner.futs.insert(key, fut);
111                }
112            }
113        }
114
115        Poll::Pending
116    }
117
118    /// Returns the map's length.
119    pub fn len(&self) -> usize {
120        self.len.load(Ordering::Acquire)
121    }
122
123    /// Returns true if this map is empty.
124    pub fn is_empty(&self) -> bool {
125        self.len() == 0
126    }
127}
128
129impl<K, R> Stream for FuturesUnorderedMap<K, R>
130where
131    K: Hash + Eq + Clone + Send + Sync + 'static + Debug,
132    R: 'static,
133{
134    type Item = (K, R);
135
136    fn poll_next(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
137        FuturesUnorderedMap::poll_next(&self, cx).map(Some)
138    }
139}
140
141impl<K, R> Stream for &FuturesUnorderedMap<K, R>
142where
143    K: Hash + Eq + Clone + Send + Sync + 'static + Debug,
144    R: 'static,
145{
146    type Item = (K, R);
147
148    fn poll_next(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
149        FuturesUnorderedMap::poll_next(&self, cx).map(Some)
150    }
151}
152
153struct FutureWaitMapWaker<K, R>(K, Arc<Mutex<RawFutureWaitMap<K, R>>>);
154
155impl<K, R> WakeRef for FutureWaitMapWaker<K, R>
156where
157    K: Hash + Eq + Clone + Debug,
158{
159    fn wake_by_ref(&self) {
160        let mut inner = self.1.lock().unwrap();
161
162        inner.ready_queue.push_back(self.0.clone());
163
164        let waker = inner.waker.take();
165
166        drop(inner);
167
168        if let Some(waker) = waker {
169            waker.wake();
170        }
171    }
172}
173
#[cfg(test)]
mod tests {
    use std::task::Poll;

    use futures::{
        future::{pending, poll_fn},
        poll, StreamExt,
    };

    use super::FuturesUnorderedMap;

    /// A never-completing future keeps the stream pending; re-inserting a
    /// ready future under the same key makes the stream yield that key's
    /// new result.
    #[futures_test::test]
    async fn test_map() {
        let map = FuturesUnorderedMap::new();
        map.insert(1, pending::<i32>());

        // Poll through a shared reference so `map` stays usable for `insert`.
        let mut stream = &map;
        let mut next_item = stream.next();

        let first = poll!(&mut next_item);
        assert_eq!(first, Poll::Pending);

        map.insert(1, poll_fn(|_| Poll::Ready(2)));

        let second = poll!(&mut next_item);
        assert_eq!(second, Poll::Ready(Some((1, 2))));
    }
}