1use std::{
2 collections::{HashMap, VecDeque},
3 fmt::Debug,
4 future::Future,
5 hash::Hash,
6 sync::{
7 atomic::{AtomicUsize, Ordering},
8 Arc, Mutex,
9 },
10 task::{Context, Poll, Waker},
11};
12
13use cooked_waker::{IntoWaker, WakeRef};
14use futures::{future::BoxFuture, FutureExt, Stream};
15
16struct RawFutureWaitMap<K, R> {
17 futs: HashMap<K, BoxFuture<'static, R>>,
18 ready_queue: VecDeque<K>,
19 waker: Option<Waker>,
20}
21
22impl<K, R> Default for RawFutureWaitMap<K, R> {
23 fn default() -> Self {
24 Self {
25 futs: HashMap::new(),
26 ready_queue: VecDeque::new(),
27 waker: None,
28 }
29 }
30}
31
32pub struct FuturesUnorderedMap<K, R> {
34 len: Arc<AtomicUsize>,
35 inner: Arc<Mutex<RawFutureWaitMap<K, R>>>,
36}
37
38impl<K, R> Clone for FuturesUnorderedMap<K, R> {
39 fn clone(&self) -> Self {
40 Self {
41 len: self.len.clone(),
42 inner: self.inner.clone(),
43 }
44 }
45}
46
47impl<K, R> AsRef<FuturesUnorderedMap<K, R>> for FuturesUnorderedMap<K, R> {
48 fn as_ref(&self) -> &FuturesUnorderedMap<K, R> {
49 self
50 }
51}
52
53impl<K, R> FuturesUnorderedMap<K, R> {
54 pub fn new() -> Self {
56 Self {
57 len: Default::default(),
58 inner: Default::default(),
59 }
60 }
61 pub fn insert<Fut>(&self, k: K, fut: Fut)
63 where
64 Fut: Future<Output = R> + Send + 'static,
65 K: Hash + Eq + Clone,
66 {
67 let mut inner = self.inner.lock().unwrap();
68
69 inner.ready_queue.push_back(k.clone());
70 inner.futs.insert(k, Box::pin(fut));
71 let waker = inner.waker.take();
72
73 drop(inner);
74
75 self.len.fetch_add(1, Ordering::Relaxed);
76
77 if let Some(waker) = waker {
78 waker.wake();
79 }
80 }
81
82 pub fn poll_next(&self, cx: &mut Context<'_>) -> Poll<(K, R)>
83 where
84 K: Hash + Eq + Clone + Send + Sync + 'static + Debug,
85 R: 'static,
86 {
87 let mut inner = self.inner.lock().unwrap();
88
89 inner.waker = Some(cx.waker().clone());
90
91 while let Some(key) = inner.ready_queue.pop_front() {
92 let mut fut = match inner.futs.remove(&key) {
93 Some(fut) => fut,
94 None => continue,
95 };
96
97 drop(inner);
98
99 let waker = Arc::new(FutureWaitMapWaker(key.clone(), self.inner.clone())).into_waker();
100
101 let mut proxy_context = Context::from_waker(&waker);
102
103 match fut.poll_unpin(&mut proxy_context) {
104 Poll::Ready(r) => {
105 self.len.fetch_sub(1, Ordering::Relaxed);
106 return Poll::Ready((key, r));
107 }
108 _ => {
109 inner = self.inner.lock().unwrap();
110 inner.futs.insert(key, fut);
111 }
112 }
113 }
114
115 Poll::Pending
116 }
117
118 pub fn len(&self) -> usize {
120 self.len.load(Ordering::Acquire)
121 }
122
123 pub fn is_empty(&self) -> bool {
125 self.len() == 0
126 }
127}
128
129impl<K, R> Stream for FuturesUnorderedMap<K, R>
130where
131 K: Hash + Eq + Clone + Send + Sync + 'static + Debug,
132 R: 'static,
133{
134 type Item = (K, R);
135
136 fn poll_next(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
137 FuturesUnorderedMap::poll_next(&self, cx).map(Some)
138 }
139}
140
141impl<K, R> Stream for &FuturesUnorderedMap<K, R>
142where
143 K: Hash + Eq + Clone + Send + Sync + 'static + Debug,
144 R: 'static,
145{
146 type Item = (K, R);
147
148 fn poll_next(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
149 FuturesUnorderedMap::poll_next(&self, cx).map(Some)
150 }
151}
152
153struct FutureWaitMapWaker<K, R>(K, Arc<Mutex<RawFutureWaitMap<K, R>>>);
154
155impl<K, R> WakeRef for FutureWaitMapWaker<K, R>
156where
157 K: Hash + Eq + Clone + Debug,
158{
159 fn wake_by_ref(&self) {
160 let mut inner = self.1.lock().unwrap();
161
162 inner.ready_queue.push_back(self.0.clone());
163
164 let waker = inner.waker.take();
165
166 drop(inner);
167
168 if let Some(waker) = waker {
169 waker.wake();
170 }
171 }
172}
173
#[cfg(test)]
mod tests {
    use std::task::Poll;

    use futures::{
        future::{pending, poll_fn},
        poll, StreamExt,
    };

    use super::FuturesUnorderedMap;

    /// A pending future stays pending. Replacing it under the same key with a
    /// ready future makes the in-progress `next()` resolve to the
    /// replacement's output.
    #[futures_test::test]
    async fn test_map() {
        let futures_map = FuturesUnorderedMap::new();
        futures_map.insert(1, pending::<i32>());

        let mut stream = &futures_map;
        let mut next_item = stream.next();
        assert_eq!(poll!(&mut next_item), Poll::Pending);

        futures_map.insert(1, poll_fn(|_| Poll::Ready(2)));
        assert_eq!(poll!(&mut next_item), Poll::Ready(Some((1, 2))));
    }
}