// futures_concurrency/future/try_join/vec.rs

use super::TryJoin as TryJoinTrait;
use crate::utils::{FutureVec, OutputVec, PollVec, WakerVec};

#[cfg(all(feature = "alloc", not(feature = "std")))]
use alloc::vec::Vec;

use core::fmt;
use core::future::{Future, IntoFuture};
use core::mem::ManuallyDrop;
use core::ops::DerefMut;
use core::pin::Pin;
use core::task::{Context, Poll};

use pin_project::{pin_project, pinned_drop};

16/// A future which waits for all futures to complete successfully, or abort early on error.
17///
18/// This `struct` is created by the [`try_join`] method on the [`TryJoin`] trait. See
19/// its documentation for more.
20///
21/// [`try_join`]: crate::future::TryJoin::try_join
22/// [`TryJoin`]: crate::future::TryJoin
23#[must_use = "futures do nothing unless you `.await` or poll them"]
24#[pin_project(PinnedDrop)]
25pub struct TryJoin<Fut, T, E>
26where
27    Fut: Future<Output = Result<T, E>>,
28{
29    /// A boolean which holds whether the future has completed
30    consumed: bool,
31    /// The number of futures which are currently still in-flight
32    pending: usize,
33    /// The output data, to be returned after the future completes
34    items: OutputVec<T>,
35    /// A structure holding the waker passed to the future, and the various
36    /// sub-wakers passed to the contained futures.
37    wakers: WakerVec,
38    /// The individual poll state of each future.
39    state: PollVec,
40    #[pin]
41    /// The array of futures passed to the structure.
42    futures: FutureVec<Fut>,
43}
44
45impl<Fut, T, E> TryJoin<Fut, T, E>
46where
47    Fut: Future<Output = Result<T, E>>,
48{
49    #[inline]
50    pub(crate) fn new(futures: Vec<Fut>) -> Self {
51        let len = futures.len();
52        Self {
53            consumed: false,
54            pending: len,
55            items: OutputVec::uninit(len),
56            wakers: WakerVec::new(len),
57            state: PollVec::new_pending(len),
58            futures: FutureVec::new(futures),
59        }
60    }
61}
62
63impl<Fut, T, E> TryJoinTrait for Vec<Fut>
64where
65    Fut: IntoFuture<Output = Result<T, E>>,
66{
67    type Output = Vec<T>;
68    type Error = E;
69    type Future = TryJoin<Fut::IntoFuture, T, E>;
70
71    fn try_join(self) -> Self::Future {
72        TryJoin::new(self.into_iter().map(IntoFuture::into_future).collect())
73    }
74}
75
76impl<Fut, T, E> fmt::Debug for TryJoin<Fut, T, E>
77where
78    Fut: Future<Output = Result<T, E>> + fmt::Debug,
79{
80    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
81        f.debug_list().entries(self.state.iter()).finish()
82    }
83}
84
85impl<Fut, T, E> Future for TryJoin<Fut, T, E>
86where
87    Fut: Future<Output = Result<T, E>>,
88{
89    type Output = Result<Vec<T>, E>;
90
91    #[inline]
92    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
93        let this = self.project();
94
95        assert!(
96            !*this.consumed,
97            "Futures must not be polled after completing"
98        );
99
100        let mut readiness = this.wakers.readiness();
101        readiness.set_waker(cx.waker());
102        if *this.pending != 0 && !readiness.any_ready() {
103            // Nothing is ready yet
104            return Poll::Pending;
105        }
106
107        // Poll all ready futures
108        for (i, mut fut) in this.futures.iter().enumerate() {
109            if this.state[i].is_pending() && readiness.clear_ready(i) {
110                // unlock readiness so we don't deadlock when polling
111                #[allow(clippy::drop_non_drop)]
112                drop(readiness);
113
114                // Obtain the intermediate waker.
115                let mut cx = Context::from_waker(this.wakers.get(i).unwrap());
116
117                // Poll the future
118                // SAFETY: the future's state was "pending", so it's safe to poll
119                if let Poll::Ready(value) = unsafe {
120                    fut.as_mut()
121                        .map_unchecked_mut(|t| t.deref_mut())
122                        .poll(&mut cx)
123                } {
124                    *this.pending -= 1;
125
126                    // Check the value, short-circuit on error.
127                    match value {
128                        Ok(value) => {
129                            this.items.write(i, value);
130
131                            // SAFETY: We're marking the state as "ready", which
132                            // means the future has been consumed, and data is
133                            // now available to be consumed. The future will no
134                            // longer be used after this point so it's safe to drop.
135                            this.state[i].set_ready();
136                            unsafe { ManuallyDrop::drop(fut.get_unchecked_mut()) };
137                        }
138                        Err(err) => {
139                            // The future should no longer be polled after we're done here
140                            *this.consumed = true;
141
142                            // SAFETY: We're about to return the error value
143                            // from the future, and drop the entire future.
144                            // We're marking the future as consumed, and then
145                            // proceeding to drop all other futures and
146                            // initiatlized values in the destructor.
147                            this.state[i].set_none();
148                            unsafe { ManuallyDrop::drop(fut.get_unchecked_mut()) };
149
150                            return Poll::Ready(Err(err));
151                        }
152                    }
153                }
154
155                // Lock readiness so we can use it again
156                readiness = this.wakers.readiness();
157            }
158        }
159
160        // Check whether we're all done now or need to keep going.
161        if *this.pending == 0 {
162            // Mark all data as "consumed" before we take it
163            *this.consumed = true;
164            for state in this.state.iter_mut() {
165                debug_assert!(
166                    state.is_ready(),
167                    "Future should have reached a `Ready` state"
168                );
169                state.set_none();
170            }
171
172            // SAFETY: we've checked with the state that all of our outputs have been
173            // filled, which means we're ready to take the data and assume it's initialized.
174            Poll::Ready(Ok(unsafe { this.items.take() }))
175        } else {
176            Poll::Pending
177        }
178    }
179}
180
181/// Drop the already initialized values on cancellation.
182#[pinned_drop]
183impl<Fut, T, E> PinnedDrop for TryJoin<Fut, T, E>
184where
185    Fut: Future<Output = Result<T, E>>,
186{
187    fn drop(self: Pin<&mut Self>) {
188        let mut this = self.project();
189
190        // Drop all initialized values.
191        for i in this.state.ready_indexes() {
192            // SAFETY: we've just filtered down to *only* the initialized values.
193            // We can assume they're initialized, and this is where we drop them.
194            unsafe { this.items.drop(i) };
195        }
196
197        // Drop all pending futures.
198        for i in this.state.pending_indexes() {
199            // SAFETY: we've just filtered down to *only* the pending futures,
200            // which have not yet been dropped.
201            unsafe { this.futures.as_mut().drop(i) };
202        }
203    }
204}
205
#[cfg(test)]
mod test {
    use super::*;
    use alloc::vec;
    use core::future;
    use futures_lite::future::block_on;

    /// Every future succeeds: the outputs come back in input order.
    #[test]
    fn all_ok() {
        let outcome: Result<_, ()> = block_on(
            vec![future::ready(Ok("hello")), future::ready(Ok("world"))].try_join(),
        );
        assert_eq!(outcome.unwrap(), ["hello", "world"]);
    }

    /// Joining nothing resolves immediately to an empty vector.
    #[test]
    fn empty() {
        let futs: Vec<future::Ready<Result<(), ()>>> = Vec::new();
        let outcome = block_on(futs.try_join());
        assert_eq!(outcome.unwrap(), vec![]);
    }

    /// A single failing future makes the whole join fail with its error.
    #[test]
    fn one_err() {
        let outcome: Result<_, _> = block_on(
            vec![future::ready(Ok("hello")), future::ready(Err("oh no"))].try_join(),
        );
        assert_eq!(outcome.unwrap_err(), "oh no");
    }
}