futures_buffered/futures_unordered_bounded.rs
1use core::{
2    fmt,
3    future::Future,
4    pin::Pin,
5    task::{Context, Poll},
6};
7
8use crate::{slot_map::PinSlotMap, waker_list::WakerList};
9use futures_core::{FusedStream, Stream};
10
11/// A set of futures which may complete in any order.
12///
13/// Much like [`futures::stream::FuturesUnordered`](https://docs.rs/futures/0.3.25/futures/stream/struct.FuturesUnordered.html),
14/// this is a thread-safe, `Pin` friendly, lifetime friendly, concurrent processing stream.
15///
/// This is different from `FuturesUnordered` in that `FuturesUnorderedBounded` has a fixed capacity for the number of futures it can process at once.
17/// This means it's less flexible, but produces better memory efficiency.
18///
19/// ## Benchmarks
20///
21/// ### Speed
22///
23/// Running 65536 100us timers with 256 concurrent jobs in a single threaded tokio runtime:
24///
25/// ```text
26/// FuturesUnordered         time:   [420.47 ms 422.21 ms 423.99 ms]
27/// FuturesUnorderedBounded  time:   [366.02 ms 367.54 ms 369.05 ms]
28/// ```
29///
30/// ### Memory usage
31///
32/// Running 512000 `Ready<i32>` futures with 256 concurrent jobs.
33///
34/// - count: the number of times alloc/dealloc was called
35/// - alloc: the number of cumulative bytes allocated
36/// - dealloc: the number of cumulative bytes deallocated
37///
38/// ```text
39/// FuturesUnordered
40///     count:    1024002
41///     alloc:    40960144 B
42///     dealloc:  40960000 B
43///
44/// FuturesUnorderedBounded
45///     count:    2
46///     alloc:    8264 B
47///     dealloc:  0 B
48/// ```
49///
50/// ### Conclusion
51///
/// As you can see, `FuturesUnorderedBounded` massively reduces your memory overhead while providing a significant performance gain.
/// This makes it a good fit when you want a fixed batch size.
54///
55/// # Example
56///
57/// Making 1024 total HTTP requests, with a max concurrency of 128
58///
59/// ```
60/// use futures::future::Future;
61/// use futures::stream::StreamExt;
62/// use futures_buffered::FuturesUnorderedBounded;
63/// use hyper::client::conn::http1::{handshake, SendRequest};
64/// use hyper::body::Incoming;
65/// use hyper::{Request, Response};
66/// use hyper_util::rt::TokioIo;
67/// use tokio::net::TcpStream;
68///
69/// # #[cfg(miri)] fn main() {}
70/// # #[cfg(not(miri))] #[tokio::main]
71/// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
72/// // create a tcp connection
73/// let stream = TcpStream::connect("example.com:80").await?;
74///
75/// // perform the http handshakes
76/// let (mut rs, conn) = handshake(TokioIo::new(stream)).await?;
77/// tokio::spawn(conn);
78///
79/// /// make http request to example.com and read the response
80/// fn make_req(rs: &mut SendRequest<String>) -> impl Future<Output = hyper::Result<Response<Incoming>>> {
81///     let req = Request::builder()
82///         .header("Host", "example.com")
83///         .method("GET")
84///         .body(String::new())
85///         .unwrap();
86///     rs.send_request(req)
87/// }
88///
89/// // create a queue that can hold 128 concurrent requests
90/// let mut queue = FuturesUnorderedBounded::new(128);
91///
92/// // start up 128 requests
93/// for _ in 0..128 {
94///     queue.push(make_req(&mut rs));
95/// }
96/// // wait for a request to finish and start another to fill its place - up to 1024 total requests
97/// for _ in 128..1024 {
98///     queue.next().await;
99///     queue.push(make_req(&mut rs));
100/// }
101/// // wait for the tail end to finish
102/// for _ in 0..128 {
103///     queue.next().await;
104/// }
105/// # Ok(()) }
106/// ```
107pub struct FuturesUnorderedBounded<F> {
108    pub(crate) tasks: PinSlotMap<F>,
109    pub(crate) shared: WakerList,
110}
111
112impl<F> Unpin for FuturesUnorderedBounded<F> {}
113
114impl<F> FuturesUnorderedBounded<F> {
115    /// Constructs a new, empty [`FuturesUnorderedBounded`] with the given fixed capacity.
116    ///
117    /// The returned [`FuturesUnorderedBounded`] does not contain any futures.
118    /// In this state, [`FuturesUnorderedBounded::poll_next`](Stream::poll_next) will
119    /// return [`Poll::Ready(None)`](Poll::Ready).
120    pub fn new(cap: usize) -> Self {
121        Self {
122            tasks: PinSlotMap::new(cap),
123            shared: WakerList::new(cap),
124        }
125    }
126
127    /// Push a future into the set.
128    ///
129    /// This method adds the given future to the set. This method will not
130    /// call [`poll`](core::future::Future::poll) on the submitted future. The caller must
131    /// ensure that [`FuturesUnorderedBounded::poll_next`](Stream::poll_next) is called
132    /// in order to receive wake-up notifications for the given future.
133    ///
134    /// # Panics
135    /// This method will panic if the buffer is currently full. See [`FuturesUnorderedBounded::try_push`] to get a result instead
136    #[track_caller]
137    pub fn push(&mut self, fut: F) {
138        if self.try_push(fut).is_err() {
139            panic!("attempted to push into a full `FuturesUnorderedBounded`");
140        }
141    }
142
143    /// Push a future into the set.
144    ///
145    /// This method adds the given future to the set. This method will not
146    /// call [`poll`](core::future::Future::poll) on the submitted future. The caller must
147    /// ensure that [`FuturesUnorderedBounded::poll_next`](Stream::poll_next) is called
148    /// in order to receive wake-up notifications for the given future.
149    ///
150    /// # Errors
151    /// This method will error if the buffer is currently full, returning the future back
152    pub fn try_push(&mut self, fut: F) -> Result<(), F> {
153        self.try_push_with(fut, core::convert::identity)
154    }
155
156    #[inline]
157    pub(crate) fn try_push_with<T>(&mut self, t: T, f: impl FnMut(T) -> F) -> Result<(), T> {
158        let i = self.tasks.insert_with(t, f)?;
159        // safety: i is always within capacity
160        unsafe {
161            self.shared.push(i);
162        }
163        Ok(())
164    }
165
166    /// Returns `true` if the set contains no futures.
167    pub fn is_empty(&self) -> bool {
168        self.tasks.is_empty()
169    }
170
171    /// Returns the number of futures contained in the set.
172    ///
173    /// This represents the total number of in-flight futures.
174    pub fn len(&self) -> usize {
175        self.tasks.len()
176    }
177
178    /// Returns the number of futures that can be contained in the set.
179    pub fn capacity(&self) -> usize {
180        self.tasks.capacity()
181    }
182}
183
184type PollFn<F, O> = fn(Pin<&mut F>, cx: &mut Context<'_>) -> Poll<O>;
185
186impl<F> FuturesUnorderedBounded<F> {
187    pub(crate) fn poll_inner_no_remove<O>(
188        &mut self,
189        cx: &mut Context<'_>,
190        poll_fn: PollFn<F, O>,
191    ) -> Poll<Option<(usize, O)>> {
192        const MAX: usize = 61;
193
194        if self.is_empty() {
195            return Poll::Ready(None);
196        }
197
198        self.shared.register(cx.waker());
199
200        let mut count = 0;
201        loop {
202            count += 1;
203            // if we are in a pending only loop - let's break out.
204            if count > MAX {
205                cx.waker().wake_by_ref();
206                return Poll::Pending;
207            }
208
209            match unsafe { self.shared.pop() } {
210                crate::waker_list::ReadySlot::None => return Poll::Pending,
211                crate::waker_list::ReadySlot::Inconsistent => {
212                    cx.waker().wake_by_ref();
213                    return Poll::Pending;
214                }
215                crate::waker_list::ReadySlot::Ready((i, waker)) => {
216                    if let Some(task) = self.tasks.get(i) {
217                        let mut cx = Context::from_waker(&waker);
218
219                        let res = poll_fn(task, &mut cx);
220
221                        if let Poll::Ready(x) = res {
222                            return Poll::Ready(Some((i, x)));
223                        }
224                    }
225                }
226            }
227        }
228    }
229}
230
231impl<F: Future> FuturesUnorderedBounded<F> {
232    pub(crate) fn poll_inner(&mut self, cx: &mut Context<'_>) -> Poll<Option<(usize, F::Output)>> {
233        match self.poll_inner_no_remove(cx, F::poll) {
234            Poll::Ready(Some((i, x))) => {
235                self.tasks.remove(i);
236                Poll::Ready(Some((i, x)))
237            }
238            p => p,
239        }
240    }
241}
242
243impl<F: Future> Stream for FuturesUnorderedBounded<F> {
244    type Item = F::Output;
245
246    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
247        match self.poll_inner(cx) {
248            Poll::Ready(Some((_, x))) => Poll::Ready(Some(x)),
249            Poll::Ready(None) => Poll::Ready(None),
250            Poll::Pending => Poll::Pending,
251        }
252    }
253
254    fn size_hint(&self) -> (usize, Option<usize>) {
255        let len = self.len();
256        (len, Some(len))
257    }
258}
259
260impl<F> FromIterator<F> for FuturesUnorderedBounded<F> {
261    /// Constructs a new, empty [`FuturesUnorderedBounded`] with a fixed capacity that is the length of the iterator.
262    ///
263    /// # Example
264    ///
265    /// Making 1024 total HTTP requests, with a max concurrency of 128
266    ///
267    /// ```
268    /// use futures::future::Future;
269    /// use futures::stream::StreamExt;
270    /// use futures_buffered::FuturesUnorderedBounded;
271    /// use hyper::client::conn::http1::{handshake, SendRequest};
272    /// use hyper::body::Incoming;
273    /// use hyper::{Request, Response};
274    /// use hyper_util::rt::TokioIo;
275    /// use tokio::net::TcpStream;
276    ///
277    /// # #[cfg(miri)] fn main() {}
278    /// # #[cfg(not(miri))] #[tokio::main]
279    /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
280    /// // create a tcp connection
281    /// let stream = TcpStream::connect("example.com:80").await?;
282    ///
283    /// // perform the http handshakes
284    /// let (mut rs, conn) = handshake(TokioIo::new(stream)).await?;
285    /// tokio::spawn(conn);
286    ///
287    /// /// make http request to example.com and read the response
288    /// fn make_req(rs: &mut SendRequest<String>) -> impl Future<Output = hyper::Result<Response<Incoming>>> {
289    ///     let req = Request::builder()
290    ///         .header("Host", "example.com")
291    ///         .method("GET")
292    ///         .body(String::new())
293    ///         .unwrap();
294    ///     rs.send_request(req)
295    /// }
296    ///
297    /// // create a queue with an initial 128 concurrent requests
298    /// let mut queue: FuturesUnorderedBounded<_> = (0..128).map(|_| make_req(&mut rs)).collect();
299    ///
300    /// // wait for a request to finish and start another to fill its place - up to 1024 total requests
301    /// for _ in 128..1024 {
302    ///     queue.next().await;
303    ///     queue.push(make_req(&mut rs));
304    /// }
305    /// // wait for the tail end to finish
306    /// for _ in 0..128 {
307    ///     queue.next().await;
308    /// }
309    /// # Ok(()) }
310    /// ```
311    fn from_iter<T: IntoIterator<Item = F>>(iter: T) -> Self {
312        // store the futures in our task list
313        let tasks = PinSlotMap::from_iter(iter);
314
315        // determine the actual capacity and create the shared state
316        let cap = tasks.len();
317        let shared = WakerList::new(cap);
318
319        for i in 0..cap {
320            // safety: i is always within capacity
321            unsafe {
322                shared.push(i);
323            }
324        }
325
326        // create the queue
327        Self { tasks, shared }
328    }
329}
330
331impl<Fut: Future> FusedStream for FuturesUnorderedBounded<Fut> {
332    fn is_terminated(&self) -> bool {
333        self.is_empty()
334    }
335}
336
337impl<Fut> fmt::Debug for FuturesUnorderedBounded<Fut> {
338    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
339        f.debug_struct("FuturesUnorderedBounded")
340            .field("len", &self.tasks.len())
341            .finish_non_exhaustive()
342    }
343}
344
345#[cfg(test)]
346mod tests {
347    use super::*;
348    use core::{
349        cell::Cell,
350        future::{poll_fn, ready},
351        time::Duration,
352    };
353    use futures::{channel::oneshot, StreamExt};
354    use futures_test::task::noop_context;
355    use pin_project_lite::pin_project;
356    use std::time::Instant;
357
358    pin_project!(
359        struct PollCounter<'c, F> {
360            count: &'c Cell<usize>,
361            #[pin]
362            inner: F,
363        }
364    );
365
366    impl<F: Future> Future for PollCounter<'_, F> {
367        type Output = F::Output;
368        fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
369            self.count.set(self.count.get() + 1);
370            self.project().inner.poll(cx)
371        }
372    }
373
374    struct Yield {
375        done: bool,
376    }
377    impl Unpin for Yield {}
378    impl Future for Yield {
379        type Output = ();
380
381        fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
382            if self.as_mut().done {
383                Poll::Ready(())
384            } else {
385                cx.waker().wake_by_ref();
386                self.as_mut().done = true;
387                Poll::Pending
388            }
389        }
390    }
391
392    fn yield_now(count: &Cell<usize>) -> PollCounter<'_, Yield> {
393        PollCounter {
394            count,
395            inner: Yield { done: false },
396        }
397    }
398
399    #[test]
400    fn single() {
401        let c = Cell::new(0);
402
403        let mut buffer = FuturesUnorderedBounded::new(10);
404        buffer.push(yield_now(&c));
405        futures::executor::block_on(buffer.next());
406
407        drop(buffer);
408        assert_eq!(c.into_inner(), 2);
409    }
410
411    #[test]
412    #[should_panic(expected = "attempted to push into a full `FuturesUnorderedBounded`")]
413    fn full() {
414        let mut buffer = FuturesUnorderedBounded::new(1);
415        buffer.push(ready(()));
416        buffer.push(ready(()));
417    }
418
419    #[test]
420    fn len() {
421        let mut buffer = FuturesUnorderedBounded::new(1);
422
423        assert_eq!(buffer.len(), 0);
424        assert!(buffer.is_empty());
425        assert_eq!(buffer.capacity(), 1);
426        assert_eq!(buffer.size_hint(), (0, Some(0)));
427        assert!(buffer.is_terminated());
428
429        buffer.push(ready(()));
430
431        assert_eq!(buffer.len(), 1);
432        assert!(!buffer.is_empty());
433        assert_eq!(buffer.capacity(), 1);
434        assert_eq!(buffer.size_hint(), (1, Some(1)));
435        assert!(!buffer.is_terminated());
436
437        futures::executor::block_on(buffer.next());
438
439        assert_eq!(buffer.len(), 0);
440        assert!(buffer.is_empty());
441        assert_eq!(buffer.capacity(), 1);
442        assert_eq!(buffer.size_hint(), (0, Some(0)));
443        assert!(buffer.is_terminated());
444    }
445
446    #[test]
447    fn from_iter() {
448        let buffer = FuturesUnorderedBounded::from_iter((0..10).map(|_| ready(())));
449
450        assert_eq!(buffer.len(), 10);
451        assert_eq!(buffer.capacity(), 10);
452        assert_eq!(buffer.size_hint(), (10, Some(10)));
453    }
454
455    #[test]
456    fn drop_while_waiting() {
457        let mut buffer = FuturesUnorderedBounded::new(10);
458        let waker = Cell::new(None);
459        buffer.push(poll_fn(|cx| {
460            waker.set(Some(cx.waker().clone()));
461            Poll::<()>::Pending
462        }));
463
464        assert_eq!(buffer.poll_next_unpin(&mut noop_context()), Poll::Pending);
465        drop(buffer);
466
467        let cx = waker.take().unwrap();
468        drop(cx);
469    }
470
471    #[test]
472    fn multi() {
473        fn wait(count: &Cell<usize>) -> PollCounter<'_, Yield> {
474            yield_now(count)
475        }
476
477        let c = Cell::new(0);
478
479        let mut buffer = FuturesUnorderedBounded::new(10);
480        // build up
481        for _ in 0..10 {
482            buffer.push(wait(&c));
483        }
484        // poll and insert
485        for _ in 0..100 {
486            assert!(futures::executor::block_on(buffer.next()).is_some());
487            buffer.push(wait(&c));
488        }
489        // drain down
490        for _ in 0..10 {
491            assert!(futures::executor::block_on(buffer.next()).is_some());
492        }
493
494        let count = c.into_inner();
495        assert_eq!(count, 220);
496    }
497
498    #[test]
499    fn very_slow_task() {
500        let c = Cell::new(0);
501
502        let now = Instant::now();
503
504        let mut buffer = FuturesUnorderedBounded::new(10);
505        // build up
506        for _ in 0..9 {
507            buffer.push(yield_now(&c));
508        }
509        // spawn a slow future among a bunch of fast ones.
510        // the test is to make sure this doesn't block the rest getting completed
511        buffer.push(yield_now(&c));
512        // poll and insert
513        for _ in 0..100 {
514            assert!(futures::executor::block_on(buffer.next()).is_some());
515            buffer.push(yield_now(&c));
516        }
517        // drain down
518        for _ in 0..10 {
519            assert!(futures::executor::block_on(buffer.next()).is_some());
520        }
521
522        let dur = now.elapsed();
523        assert!(dur < Duration::from_millis(2050));
524
525        let count = c.into_inner();
526        assert_eq!(count, 220);
527    }
528
529    #[cfg(not(miri))]
530    #[tokio::test]
531    async fn unordered_large() {
532        for i in 0..256 {
533            let mut queue: FuturesUnorderedBounded<_> = ((0..i).map(|_| async move {
534                tokio::time::sleep(Duration::from_nanos(1)).await;
535            }))
536            .collect();
537            for _ in 0..i {
538                queue.next().await.unwrap();
539            }
540        }
541    }
542
543    #[test]
544    fn correct_fairer_order() {
545        const LEN: usize = 256;
546
547        let mut buffer = FuturesUnorderedBounded::new(LEN);
548        let mut txs = vec![];
549        for _ in 0..LEN {
550            let (tx, rx) = oneshot::channel();
551            buffer.push(rx);
552            txs.push(tx);
553        }
554
555        for _ in 0..=(LEN / 61) {
556            assert!(buffer.poll_next_unpin(&mut noop_context()).is_pending());
557        }
558
559        for (i, tx) in txs.into_iter().enumerate() {
560            let _ = tx.send(i);
561        }
562
563        for i in 0..LEN {
564            let poll = buffer.poll_next_unpin(&mut noop_context());
565            assert_eq!(poll, Poll::Ready(Some(Ok(i))));
566        }
567    }
568}