futures_concurrency/future/try_join/
array.rs

1use super::TryJoin as TryJoinTrait;
2use crate::utils::{FutureArray, OutputArray, PollArray, WakerArray};
3
4use core::fmt;
5use core::future::{Future, IntoFuture};
6use core::mem::ManuallyDrop;
7use core::ops::DerefMut;
8use core::pin::Pin;
9use core::task::{Context, Poll};
10
11use pin_project::{pin_project, pinned_drop};
12
13/// A future which waits for all futures to complete successfully, or abort early on error.
14///
15/// This `struct` is created by the [`try_join`] method on the [`TryJoin`] trait. See
16/// its documentation for more.
17///
18/// [`try_join`]: crate::future::TryJoin::try_join
19/// [`TryJoin`]: crate::future::TryJoin
20#[must_use = "futures do nothing unless you `.await` or poll them"]
21#[pin_project(PinnedDrop)]
22pub struct TryJoin<Fut, T, E, const N: usize>
23where
24    Fut: Future<Output = Result<T, E>>,
25{
26    /// A boolean which holds whether the future has completed
27    consumed: bool,
28    /// The number of futures which are currently still in-flight
29    pending: usize,
30    /// The output data, to be returned after the future completes
31    items: OutputArray<T, N>,
32    /// A structure holding the waker passed to the future, and the various
33    /// sub-wakers passed to the contained futures.
34    wakers: WakerArray<N>,
35    /// The individual poll state of each future.
36    state: PollArray<N>,
37    #[pin]
38    /// The array of futures passed to the structure.
39    futures: FutureArray<Fut, N>,
40}
41
42impl<Fut, T, E, const N: usize> TryJoin<Fut, T, E, N>
43where
44    Fut: Future<Output = Result<T, E>>,
45{
46    #[inline]
47    pub(crate) fn new(futures: [Fut; N]) -> Self {
48        Self {
49            consumed: false,
50            pending: N,
51            items: OutputArray::uninit(),
52            wakers: WakerArray::new(),
53            state: PollArray::new_pending(),
54            futures: FutureArray::new(futures),
55        }
56    }
57}
58
59impl<Fut, T, E, const N: usize> TryJoinTrait for [Fut; N]
60where
61    Fut: IntoFuture<Output = Result<T, E>>,
62{
63    type Output = [T; N];
64    type Error = E;
65    type Future = TryJoin<Fut::IntoFuture, T, E, N>;
66
67    fn try_join(self) -> Self::Future {
68        TryJoin::new(self.map(IntoFuture::into_future))
69    }
70}
71
72impl<Fut, T, E, const N: usize> fmt::Debug for TryJoin<Fut, T, E, N>
73where
74    Fut: Future<Output = Result<T, E>> + fmt::Debug,
75{
76    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
77        f.debug_list().entries(self.state.iter()).finish()
78    }
79}
80
81impl<Fut, T, E, const N: usize> Future for TryJoin<Fut, T, E, N>
82where
83    Fut: Future<Output = Result<T, E>>,
84{
85    type Output = Result<[T; N], E>;
86
87    #[inline]
88    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
89        let this = self.project();
90
91        assert!(
92            !*this.consumed,
93            "Futures must not be polled after completing"
94        );
95
96        let mut readiness = this.wakers.readiness();
97        readiness.set_waker(cx.waker());
98        if *this.pending != 0 && !readiness.any_ready() {
99            // Nothing is ready yet
100            return Poll::Pending;
101        }
102
103        // Poll all ready futures
104        for (i, mut fut) in this.futures.iter().enumerate() {
105            if this.state[i].is_pending() && readiness.clear_ready(i) {
106                // unlock readiness so we don't deadlock when polling
107                #[allow(clippy::drop_non_drop)]
108                drop(readiness);
109
110                // Obtain the intermediate waker.
111                let mut cx = Context::from_waker(this.wakers.get(i).unwrap());
112
113                // Poll the future
114                // SAFETY: the future's state was "pending", so it's safe to poll
115                if let Poll::Ready(value) = unsafe {
116                    fut.as_mut()
117                        .map_unchecked_mut(|t| t.deref_mut())
118                        .poll(&mut cx)
119                } {
120                    *this.pending -= 1;
121
122                    // Check the value, short-circuit on error.
123                    match value {
124                        Ok(value) => {
125                            this.items.write(i, value);
126
127                            // SAFETY: We're marking the state as "ready", which
128                            // means the future has been consumed, and data is
129                            // now available to be consumed. The future will no
130                            // longer be used after this point so it's safe to drop.
131                            this.state[i].set_ready();
132                            unsafe { ManuallyDrop::drop(fut.get_unchecked_mut()) };
133                        }
134                        Err(err) => {
135                            // The future should no longer be polled after we're done here
136                            *this.consumed = true;
137
138                            // SAFETY: We're about to return the error value
139                            // from the future, and drop the entire future.
140                            // We're marking the future as consumed, and then
141                            // proceeding to drop all other futures and
142                            // initiatlized values in the destructor.
143                            this.state[i].set_none();
144                            unsafe { ManuallyDrop::drop(fut.get_unchecked_mut()) };
145
146                            return Poll::Ready(Err(err));
147                        }
148                    }
149                }
150
151                // Lock readiness so we can use it again
152                readiness = this.wakers.readiness();
153            }
154        }
155
156        // Check whether we're all done now or need to keep going.
157        if *this.pending == 0 {
158            // Mark all data as "consumed" before we take it
159            *this.consumed = true;
160
161            // SAFETY: we check with the state that all of our outputs have been
162            // filled, which means we're ready to take the data and assume it's initialized.
163            debug_assert!(this.state.iter().all(|entry| entry.is_ready()));
164            this.state.set_all_none();
165            Poll::Ready(Ok(unsafe { this.items.take() }))
166        } else {
167            Poll::Pending
168        }
169    }
170}
171
172/// Drop the already initialized values on cancellation.
173#[pinned_drop]
174impl<Fut, T, E, const N: usize> PinnedDrop for TryJoin<Fut, T, E, N>
175where
176    Fut: Future<Output = Result<T, E>>,
177{
178    fn drop(self: Pin<&mut Self>) {
179        let mut this = self.project();
180
181        // Drop all initialized values.
182        for i in this.state.ready_indexes() {
183            // SAFETY: we've just filtered down to *only* the initialized values.
184            // We can assume they're initialized, and this is where we drop them.
185            unsafe { this.items.drop(i) };
186        }
187
188        // Drop all pending futures.
189        for i in this.state.pending_indexes() {
190            // SAFETY: we've just filtered down to *only* the pending futures,
191            // which have not yet been dropped.
192            unsafe { this.futures.as_mut().drop(i) };
193        }
194    }
195}
196
#[cfg(test)]
mod test {
    use super::*;
    use core::future;

    #[test]
    fn all_ok() {
        futures_lite::future::block_on(async {
            let futures = [future::ready(Ok("hello")), future::ready(Ok("world"))];
            let outcome: Result<_, ()> = futures.try_join().await;
            assert_eq!(outcome.unwrap(), ["hello", "world"]);
        })
    }

    #[test]
    fn empty() {
        futures_lite::future::block_on(async {
            let futures: [future::Ready<Result<(), ()>>; 0] = [];
            let outcome = futures.try_join().await;
            assert_eq!(outcome.unwrap(), []);
        });
    }

    #[test]
    fn one_err() {
        futures_lite::future::block_on(async {
            let futures = [future::ready(Ok("hello")), future::ready(Err("oh no"))];
            let outcome: Result<_, _> = futures.try_join().await;
            assert_eq!(outcome.unwrap_err(), "oh no");
        });
    }
}