futures_bounded/
futures_map.rs

1use std::future::Future;
2use std::hash::Hash;
3use std::pin::Pin;
4use std::task::{Context, Poll, Waker};
5use std::time::Duration;
6use std::{future, mem};
7
8use futures_util::future::BoxFuture;
9use futures_util::stream::FuturesUnordered;
10use futures_util::{FutureExt, StreamExt};
11
12use crate::{Delay, PushError, Timeout};
13
14/// Represents a map of [`Future`]s.
15///
16/// Each future must finish within the specified time and the map never outgrows its capacity.
17pub struct FuturesMap<ID, O> {
18    make_delay: Box<dyn Fn() -> Delay + Send + Sync>,
19    capacity: usize,
20    inner: FuturesUnordered<TaggedFuture<ID, TimeoutFuture<BoxFuture<'static, O>>>>,
21    empty_waker: Option<Waker>,
22    full_waker: Option<Waker>,
23}
24
25impl<ID, O> FuturesMap<ID, O> {
26    pub fn new(make_delay: impl Fn() -> Delay + Send + Sync + 'static, capacity: usize) -> Self {
27        Self {
28            make_delay: Box::new(make_delay),
29            capacity,
30            inner: Default::default(),
31            empty_waker: None,
32            full_waker: None,
33        }
34    }
35}
36
37impl<ID, O> FuturesMap<ID, O>
38where
39    ID: Clone + Hash + Eq + Send + Unpin + 'static,
40    O: 'static,
41{
42    /// Push a future into the map.
43    ///
44    /// This method inserts the given future with defined `future_id` to the set.
45    /// If the length of the map is equal to the capacity, this method returns [PushError::BeyondCapacity],
46    /// that contains the passed future. In that case, the future is not inserted to the map.
47    /// If a future with the given `future_id` already exists, then the old future will be replaced by a new one.
48    /// In that case, the returned error [PushError::Replaced] contains the old future.
49    pub fn try_push<F>(&mut self, future_id: ID, future: F) -> Result<(), PushError<BoxFuture<O>>>
50    where
51        F: Future<Output = O> + Send + 'static,
52    {
53        if self.inner.len() >= self.capacity {
54            return Err(PushError::BeyondCapacity(future.boxed()));
55        }
56
57        if let Some(waker) = self.empty_waker.take() {
58            waker.wake();
59        }
60
61        let old = self.remove(future_id.clone());
62        self.inner.push(TaggedFuture {
63            tag: future_id,
64            inner: TimeoutFuture {
65                inner: future.boxed(),
66                timeout: (self.make_delay)(),
67                cancelled: false,
68            },
69        });
70        match old {
71            None => Ok(()),
72            Some(old) => Err(PushError::Replaced(old)),
73        }
74    }
75
76    pub fn remove(&mut self, id: ID) -> Option<BoxFuture<'static, O>> {
77        let tagged = self.inner.iter_mut().find(|s| s.tag == id)?;
78
79        let inner = mem::replace(&mut tagged.inner.inner, future::pending().boxed());
80        tagged.inner.cancelled = true;
81
82        Some(inner)
83    }
84
85    pub fn contains(&self, id: ID) -> bool {
86        self.inner.iter().any(|f| f.tag == id && !f.inner.cancelled)
87    }
88
89    pub fn len(&self) -> usize {
90        self.inner.len()
91    }
92
93    pub fn is_empty(&self) -> bool {
94        self.inner.is_empty()
95    }
96
97    #[allow(unknown_lints, clippy::needless_pass_by_ref_mut)] // &mut Context is idiomatic.
98    pub fn poll_ready_unpin(&mut self, cx: &mut Context<'_>) -> Poll<()> {
99        if self.inner.len() < self.capacity {
100            return Poll::Ready(());
101        }
102
103        self.full_waker = Some(cx.waker().clone());
104
105        Poll::Pending
106    }
107
108    pub fn poll_unpin(&mut self, cx: &mut Context<'_>) -> Poll<(ID, Result<O, Timeout>)> {
109        loop {
110            let maybe_result = futures_util::ready!(self.inner.poll_next_unpin(cx));
111
112            match maybe_result {
113                None => {
114                    self.empty_waker = Some(cx.waker().clone());
115                    return Poll::Pending;
116                }
117                Some((id, Ok(output))) => return Poll::Ready((id, Ok(output))),
118                Some((id, Err(TimeoutError::Timeout(dur)))) => {
119                    return Poll::Ready((id, Err(Timeout::new(dur))))
120                }
121                Some((_, Err(TimeoutError::Cancelled))) => continue,
122            }
123        }
124    }
125}
126
127struct TimeoutFuture<F> {
128    inner: F,
129    timeout: Delay,
130
131    cancelled: bool,
132}
133
134impl<F> Future for TimeoutFuture<F>
135where
136    F: Future + Unpin,
137{
138    type Output = Result<F::Output, TimeoutError>;
139
140    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
141        if self.cancelled {
142            return Poll::Ready(Err(TimeoutError::Cancelled));
143        }
144
145        if let Poll::Ready(duration) = self.timeout.poll_unpin(cx) {
146            return Poll::Ready(Err(TimeoutError::Timeout(duration)));
147        }
148
149        self.inner.poll_unpin(cx).map(Ok)
150    }
151}
152
/// Reason a `TimeoutFuture` resolved without producing the wrapped future's output.
enum TimeoutError {
    /// The future was cancelled (removed or replaced) before it finished.
    Cancelled,
    /// The deadline elapsed first; carries the [`Duration`] yielded by the timeout delay.
    Timeout(Duration),
}
157
/// Pairs a future with the identifier it was registered under.
///
/// This lets the id be reported together with the output once the future completes.
struct TaggedFuture<Tag, Fut> {
    /// Identifier of the wrapped future.
    tag: Tag,
    /// The wrapped future.
    inner: Fut,
}
162
163impl<T, F> Future for TaggedFuture<T, F>
164where
165    T: Clone + Unpin,
166    F: Future + Unpin,
167{
168    type Output = (T, F::Output);
169
170    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
171        let output = futures_util::ready!(self.inner.poll_unpin(cx));
172
173        Poll::Ready((self.tag.clone(), output))
174    }
175}
176
#[cfg(all(test, feature = "futures-timer"))]
mod tests {
    use futures::channel::oneshot;
    use futures_util::task::noop_waker_ref;
    use std::future::{pending, poll_fn, ready};
    use std::pin::Pin;
    use std::time::Instant;

    use super::*;

    #[test]
    fn cannot_push_more_than_capacity_tasks() {
        let mut futures = FuturesMap::new(|| Delay::futures_timer(Duration::from_secs(10)), 1);

        assert!(futures.try_push("ID_1", ready(())).is_ok());
        // `matches!` only evaluates to a bool; it has to be asserted to check anything.
        assert!(matches!(
            futures.try_push("ID_2", ready(())),
            Err(PushError::BeyondCapacity(_))
        ));
    }

    #[test]
    fn cannot_push_the_same_id_few_times() {
        let mut futures = FuturesMap::new(|| Delay::futures_timer(Duration::from_secs(10)), 5);

        assert!(futures.try_push("ID", ready(())).is_ok());
        // `matches!` only evaluates to a bool; it has to be asserted to check anything.
        assert!(matches!(
            futures.try_push("ID", ready(())),
            Err(PushError::Replaced(_))
        ));
    }

    #[tokio::test]
    async fn futures_timeout() {
        let mut futures = FuturesMap::new(|| Delay::futures_timer(Duration::from_millis(100)), 1);

        let _ = futures.try_push("ID", pending::<()>());
        futures_timer::Delay::new(Duration::from_millis(150)).await;
        let (_, result) = poll_fn(|cx| futures.poll_unpin(cx)).await;

        assert!(result.is_err())
    }

    #[test]
    fn resources_of_removed_future_are_cleaned_up() {
        let mut futures = FuturesMap::new(|| Delay::futures_timer(Duration::from_millis(100)), 1);

        let _ = futures.try_push("ID", pending::<()>());
        futures.remove("ID");

        let poll = futures.poll_unpin(&mut Context::from_waker(noop_waker_ref()));
        assert!(poll.is_pending());

        assert_eq!(futures.len(), 0);
    }

    #[tokio::test]
    async fn replaced_pending_future_is_polled() {
        let mut streams = FuturesMap::new(|| Delay::futures_timer(Duration::from_millis(100)), 3);

        let (_tx1, rx1) = oneshot::channel();
        let (tx2, rx2) = oneshot::channel();

        let _ = streams.try_push("ID1", rx1);
        let _ = streams.try_push("ID2", rx2);

        let _ = tx2.send(2);
        let (id, res) = poll_fn(|cx| streams.poll_unpin(cx)).await;
        assert_eq!(id, "ID2");
        assert_eq!(res.unwrap().unwrap(), 2);

        let (new_tx1, new_rx1) = oneshot::channel();
        let replaced = streams.try_push("ID1", new_rx1);
        assert!(matches!(replaced.unwrap_err(), PushError::Replaced(_)));

        let _ = new_tx1.send(4);
        let (id, res) = poll_fn(|cx| streams.poll_unpin(cx)).await;

        assert_eq!(id, "ID1");
        assert_eq!(res.unwrap().unwrap(), 4);
    }

    // Each future causes a delay, `Task` only has a capacity of 1, meaning they must be processed in sequence.
    // We stop after NUM_FUTURES tasks, meaning the overall execution must at least take DELAY * NUM_FUTURES.
    #[tokio::test]
    async fn backpressure() {
        const DELAY: Duration = Duration::from_millis(100);
        const NUM_FUTURES: u32 = 10;

        let start = Instant::now();
        Task::new(DELAY, NUM_FUTURES, 1).await;
        let duration = start.elapsed();

        assert!(duration >= DELAY * NUM_FUTURES);
    }

    #[test]
    fn contains() {
        let mut futures = FuturesMap::new(|| Delay::futures_timer(Duration::from_secs(10)), 1);
        _ = futures.try_push("ID", pending::<()>());
        assert!(futures.contains("ID"));
        _ = futures.remove("ID");
        assert!(!futures.contains("ID"));
    }

    struct Task {
        future: Duration,
        num_futures: usize,
        num_processed: usize,
        inner: FuturesMap<u8, ()>,
    }

    impl Task {
        fn new(future: Duration, num_futures: u32, capacity: usize) -> Self {
            Self {
                future,
                num_futures: num_futures as usize,
                num_processed: 0,
                inner: FuturesMap::new(|| Delay::futures_timer(Duration::from_secs(60)), capacity),
            }
        }
    }

    impl Future for Task {
        type Output = ();

        fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
            let this = self.get_mut();

            while this.num_processed < this.num_futures {
                if let Poll::Ready((_, result)) = this.inner.poll_unpin(cx) {
                    if result.is_err() {
                        panic!("Timeout is great than future delay")
                    }

                    this.num_processed += 1;
                    continue;
                }

                if let Poll::Ready(()) = this.inner.poll_ready_unpin(cx) {
                    // We push the constant future's ID to prove that user can use the same ID
                    // if the future was finished
                    let maybe_future = this
                        .inner
                        .try_push(1u8, futures_timer::Delay::new(this.future));
                    assert!(maybe_future.is_ok(), "we polled for readiness");

                    continue;
                }

                return Poll::Pending;
            }

            Poll::Ready(())
        }
    }
}