futures_buffered/
try_join_all.rs

1use alloc::{boxed::Box, vec::Vec};
2use core::{
3    future::Future,
4    mem::MaybeUninit,
5    pin::Pin,
6    task::{Context, Poll},
7};
8
9use crate::{FuturesUnorderedBounded, TryFuture};
10
11#[must_use = "futures do nothing unless you `.await` or poll them"]
12/// Future for the [`try_join_all`] function.
13pub struct TryJoinAll<F: TryFuture> {
14    queue: FuturesUnorderedBounded<F>,
15    output: Box<[MaybeUninit<F::Ok>]>,
16}
17
18impl<F: TryFuture> Unpin for TryJoinAll<F> {}
19
20/// Creates a future which represents a collection of the outputs of the futures
21/// given.
22///
23/// The returned future will drive execution for all of its underlying futures,
24/// collecting the results into a destination `Vec<T>` in the same order as they
25/// were provided.
26///
27/// If any future returns an error then all other futures will be canceled and
28/// an error will be returned immediately. If all futures complete successfully,
29/// however, then the returned future will succeed with a `Vec` of all the
30/// successful results.
31///
32/// # Examples
33///
34/// ```
35/// # futures::executor::block_on(async {
36/// use futures_buffered::try_join_all;
37///
38/// async fn foo(i: u32) -> Result<u32, u32> {
39///     if i < 4 { Ok(i) } else { Err(i) }
40/// }
41///
42/// let futures = vec![foo(1), foo(2), foo(3)];
43/// assert_eq!(try_join_all(futures).await, Ok(vec![1, 2, 3]));
44///
45/// let futures = vec![foo(1), foo(2), foo(3), foo(4)];
46/// assert_eq!(try_join_all(futures).await, Err(4));
47/// # });
48/// ```
49///
50/// See [`join_all`](crate::join_all()) for benchmark results
51pub fn try_join_all<I>(iter: I) -> TryJoinAll<<I as IntoIterator>::Item>
52where
53    I: IntoIterator,
54    <I as IntoIterator>::Item: TryFuture,
55{
56    // create the queue
57    let queue = FuturesUnorderedBounded::from_iter(iter);
58
59    // create the output buffer
60    let mut output = Vec::with_capacity(queue.capacity());
61    output.resize_with(queue.capacity(), MaybeUninit::uninit);
62
63    TryJoinAll {
64        queue,
65        output: output.into_boxed_slice(),
66    }
67}
68
69impl<F: TryFuture> Future for TryJoinAll<F> {
70    type Output = Result<Vec<F::Ok>, F::Err>;
71
72    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
73        loop {
74            match self.as_mut().queue.poll_inner(cx) {
75                Poll::Ready(Some((i, Ok(t)))) => {
76                    self.output[i].write(t);
77                }
78                Poll::Ready(Some((_, Err(e)))) => {
79                    break Poll::Ready(Err(e));
80                }
81                Poll::Ready(None) => {
82                    // SAFETY: for Ready(None) to be returned, we know that every future in the queue
83                    // must be consumed. Since we have a 1:1 mapping in the queue to our output, we
84                    // know that every output entry is init.
85                    let boxed = unsafe {
86                        // take the boxed slice
87                        let boxed =
88                            core::mem::replace(&mut self.output, Vec::new().into_boxed_slice());
89
90                        // Box::assume_init
91                        let raw = Box::into_raw(boxed);
92                        Box::from_raw(raw as *mut [F::Ok])
93                    };
94
95                    break Poll::Ready(Ok(boxed.into_vec()));
96                }
97                Poll::Pending => break Poll::Pending,
98            }
99        }
100    }
101}
102
#[cfg(test)]
mod tests {
    use core::future::ready;

    use futures::executor::block_on;

    /// All-`Ok` inputs come back in submission order; a single `Err` fails the
    /// whole join.
    #[test]
    fn try_join_all() {
        // Ten futures that all succeed immediately.
        let all_ok = (0..10).map(|i| ready(Ok::<_, ()>(i)));
        let values = block_on(crate::try_join_all(all_ok)).unwrap();

        assert_eq!(values, [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]);
        // The result reuses the exactly-sized output buffer.
        assert_eq!(values.capacity(), 10);

        // Same shape, but the last future fails.
        let last_fails = (0..10).map(|i| ready(if i == 9 { Err(()) } else { Ok(i) }));
        block_on(crate::try_join_all(last_fails)).unwrap_err();
    }
}