1use std::future::Future;
2use std::hash::Hash;
3use std::pin::Pin;
4use std::task::{Context, Poll, Waker};
5use std::time::Duration;
6use std::{future, mem};
7
8use futures_util::future::BoxFuture;
9use futures_util::stream::FuturesUnordered;
10use futures_util::{FutureExt, StreamExt};
11
12use crate::{Delay, PushError, Timeout};
13
14pub struct FuturesMap<ID, O> {
18 make_delay: Box<dyn Fn() -> Delay + Send + Sync>,
19 capacity: usize,
20 inner: FuturesUnordered<TaggedFuture<ID, TimeoutFuture<BoxFuture<'static, O>>>>,
21 empty_waker: Option<Waker>,
22 full_waker: Option<Waker>,
23}
24
25impl<ID, O> FuturesMap<ID, O> {
26 pub fn new(make_delay: impl Fn() -> Delay + Send + Sync + 'static, capacity: usize) -> Self {
27 Self {
28 make_delay: Box::new(make_delay),
29 capacity,
30 inner: Default::default(),
31 empty_waker: None,
32 full_waker: None,
33 }
34 }
35}
36
37impl<ID, O> FuturesMap<ID, O>
38where
39 ID: Clone + Hash + Eq + Send + Unpin + 'static,
40 O: 'static,
41{
42 pub fn try_push<F>(&mut self, future_id: ID, future: F) -> Result<(), PushError<BoxFuture<O>>>
50 where
51 F: Future<Output = O> + Send + 'static,
52 {
53 if self.inner.len() >= self.capacity {
54 return Err(PushError::BeyondCapacity(future.boxed()));
55 }
56
57 if let Some(waker) = self.empty_waker.take() {
58 waker.wake();
59 }
60
61 let old = self.remove(future_id.clone());
62 self.inner.push(TaggedFuture {
63 tag: future_id,
64 inner: TimeoutFuture {
65 inner: future.boxed(),
66 timeout: (self.make_delay)(),
67 cancelled: false,
68 },
69 });
70 match old {
71 None => Ok(()),
72 Some(old) => Err(PushError::Replaced(old)),
73 }
74 }
75
76 pub fn remove(&mut self, id: ID) -> Option<BoxFuture<'static, O>> {
77 let tagged = self.inner.iter_mut().find(|s| s.tag == id)?;
78
79 let inner = mem::replace(&mut tagged.inner.inner, future::pending().boxed());
80 tagged.inner.cancelled = true;
81
82 Some(inner)
83 }
84
85 pub fn contains(&self, id: ID) -> bool {
86 self.inner.iter().any(|f| f.tag == id && !f.inner.cancelled)
87 }
88
89 pub fn len(&self) -> usize {
90 self.inner.len()
91 }
92
93 pub fn is_empty(&self) -> bool {
94 self.inner.is_empty()
95 }
96
97 #[allow(unknown_lints, clippy::needless_pass_by_ref_mut)] pub fn poll_ready_unpin(&mut self, cx: &mut Context<'_>) -> Poll<()> {
99 if self.inner.len() < self.capacity {
100 return Poll::Ready(());
101 }
102
103 self.full_waker = Some(cx.waker().clone());
104
105 Poll::Pending
106 }
107
108 pub fn poll_unpin(&mut self, cx: &mut Context<'_>) -> Poll<(ID, Result<O, Timeout>)> {
109 loop {
110 let maybe_result = futures_util::ready!(self.inner.poll_next_unpin(cx));
111
112 match maybe_result {
113 None => {
114 self.empty_waker = Some(cx.waker().clone());
115 return Poll::Pending;
116 }
117 Some((id, Ok(output))) => return Poll::Ready((id, Ok(output))),
118 Some((id, Err(TimeoutError::Timeout(dur)))) => {
119 return Poll::Ready((id, Err(Timeout::new(dur))))
120 }
121 Some((_, Err(TimeoutError::Cancelled))) => continue,
122 }
123 }
124 }
125}
126
127struct TimeoutFuture<F> {
128 inner: F,
129 timeout: Delay,
130
131 cancelled: bool,
132}
133
134impl<F> Future for TimeoutFuture<F>
135where
136 F: Future + Unpin,
137{
138 type Output = Result<F::Output, TimeoutError>;
139
140 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
141 if self.cancelled {
142 return Poll::Ready(Err(TimeoutError::Cancelled));
143 }
144
145 if let Poll::Ready(duration) = self.timeout.poll_unpin(cx) {
146 return Poll::Ready(Err(TimeoutError::Timeout(duration)));
147 }
148
149 self.inner.poll_unpin(cx).map(Ok)
150 }
151}
152
/// Why a `TimeoutFuture` resolved without producing its inner future's output.
enum TimeoutError {
    /// The entry was flagged as cancelled (removed, or replaced by a push with the same ID).
    Cancelled,
    /// The delay resolved first; carries the `Duration` the delay future produced.
    Timeout(Duration),
}
157
/// Pairs a future with a tag that is returned alongside the future's output.
struct TaggedFuture<T, F> {
    /// Identifier handed back together with the output.
    tag: T,
    /// The future being driven.
    inner: F,
}
162
163impl<T, F> Future for TaggedFuture<T, F>
164where
165 T: Clone + Unpin,
166 F: Future + Unpin,
167{
168 type Output = (T, F::Output);
169
170 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
171 let output = futures_util::ready!(self.inner.poll_unpin(cx));
172
173 Poll::Ready((self.tag.clone(), output))
174 }
175}
176
#[cfg(all(test, feature = "futures-timer"))]
mod tests {
    use futures::channel::oneshot;
    use futures_util::task::noop_waker_ref;
    use std::future::{pending, poll_fn, ready};
    use std::pin::Pin;
    use std::time::Instant;

    use super::*;

    #[test]
    fn cannot_push_more_than_capacity_tasks() {
        let mut futures = FuturesMap::new(|| Delay::futures_timer(Duration::from_secs(10)), 1);

        assert!(futures.try_push("ID_1", ready(())).is_ok());
        // `matches!` only evaluates to a `bool`; without `assert!` nothing would be checked.
        assert!(matches!(
            futures.try_push("ID_2", ready(())),
            Err(PushError::BeyondCapacity(_))
        ));
    }

    #[test]
    fn cannot_push_the_same_id_few_times() {
        let mut futures = FuturesMap::new(|| Delay::futures_timer(Duration::from_secs(10)), 5);

        assert!(futures.try_push("ID", ready(())).is_ok());
        assert!(matches!(
            futures.try_push("ID", ready(())),
            Err(PushError::Replaced(_))
        ));
    }

    #[tokio::test]
    async fn futures_timeout() {
        let mut futures = FuturesMap::new(|| Delay::futures_timer(Duration::from_millis(100)), 1);

        let _ = futures.try_push("ID", pending::<()>());
        futures_timer::Delay::new(Duration::from_millis(150)).await;
        let (_, result) = poll_fn(|cx| futures.poll_unpin(cx)).await;

        assert!(result.is_err())
    }

    #[test]
    fn resources_of_removed_future_are_cleaned_up() {
        let mut futures = FuturesMap::new(|| Delay::futures_timer(Duration::from_millis(100)), 1);

        let _ = futures.try_push("ID", pending::<()>());
        futures.remove("ID");

        let poll = futures.poll_unpin(&mut Context::from_waker(noop_waker_ref()));
        assert!(poll.is_pending());

        assert_eq!(futures.len(), 0);
    }

    #[tokio::test]
    async fn replaced_pending_future_is_polled() {
        let mut streams = FuturesMap::new(|| Delay::futures_timer(Duration::from_millis(100)), 3);

        let (_tx1, rx1) = oneshot::channel();
        let (tx2, rx2) = oneshot::channel();

        let _ = streams.try_push("ID1", rx1);
        let _ = streams.try_push("ID2", rx2);

        let _ = tx2.send(2);
        let (id, res) = poll_fn(|cx| streams.poll_unpin(cx)).await;
        assert_eq!(id, "ID2");
        assert_eq!(res.unwrap().unwrap(), 2);

        let (new_tx1, new_rx1) = oneshot::channel();
        let replaced = streams.try_push("ID1", new_rx1);
        assert!(matches!(replaced.unwrap_err(), PushError::Replaced(_)));

        let _ = new_tx1.send(4);
        let (id, res) = poll_fn(|cx| streams.poll_unpin(cx)).await;

        assert_eq!(id, "ID1");
        assert_eq!(res.unwrap().unwrap(), 4);
    }

    #[tokio::test]
    async fn backpressure() {
        const DELAY: Duration = Duration::from_millis(100);
        const NUM_FUTURES: u32 = 10;

        let start = Instant::now();
        Task::new(DELAY, NUM_FUTURES, 1).await;
        let duration = start.elapsed();

        assert!(duration >= DELAY * NUM_FUTURES);
    }

    #[test]
    fn contains() {
        let mut futures = FuturesMap::new(|| Delay::futures_timer(Duration::from_secs(10)), 1);
        _ = futures.try_push("ID", pending::<()>());
        assert!(futures.contains("ID"));
        _ = futures.remove("ID");
        assert!(!futures.contains("ID"));
    }

    /// Test driver that pushes `num_futures` futures, each finishing after `future`, through a
    /// map of the given capacity, using `poll_ready_unpin` to wait for free slots.
    struct Task {
        future: Duration,
        num_futures: usize,
        num_processed: usize,
        inner: FuturesMap<u8, ()>,
    }

    impl Task {
        fn new(future: Duration, num_futures: u32, capacity: usize) -> Self {
            Self {
                future,
                num_futures: num_futures as usize,
                num_processed: 0,
                inner: FuturesMap::new(|| Delay::futures_timer(Duration::from_secs(60)), capacity),
            }
        }
    }

    impl Future for Task {
        type Output = ();

        fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
            let this = self.get_mut();

            while this.num_processed < this.num_futures {
                if let Poll::Ready((_, result)) = this.inner.poll_unpin(cx) {
                    if result.is_err() {
                        panic!("Timeout is greater than future delay")
                    }

                    this.num_processed += 1;
                    continue;
                }

                if let Poll::Ready(()) = this.inner.poll_ready_unpin(cx) {
                    let maybe_future = this
                        .inner
                        .try_push(1u8, futures_timer::Delay::new(this.future));
                    assert!(maybe_future.is_ok(), "we polled for readiness");

                    continue;
                }

                return Poll::Pending;
            }

            Poll::Ready(())
        }
    }
}