// futures_concurrency/future/try_join/tuple.rs

use super::TryJoin as TryJoinTrait;
use crate::utils::{PollArray, WakerArray};

use core::fmt::{self, Debug};
use core::future::{Future, IntoFuture};
use core::marker::PhantomData;
use core::mem::ManuallyDrop;
use core::mem::MaybeUninit;
use core::ops::DerefMut;
use core::pin::Pin;
use core::task::{Context, Poll};

use pin_project::{pin_project, pinned_drop};

/// Generates the `poll` call for the future at runtime index `$iteration` inside `$futures`.
///
/// When that future returns `Poll::Ready`, `$this.completed` is incremented and then:
/// - on `Ok(value)`, the value is written into `$this.outputs`, the future's state slot is
///   marked ready and the future is dropped in place;
/// - on `Err(err)`, the combinator is marked consumed, the state slot is reset to "none",
///   the future is dropped in place, and the *enclosing function* returns
///   `Poll::Ready(Err(err))`.
///
/// SAFETY: pretty please only call this after having made very sure that the future at
/// `$iteration` is still *pending* (its state slot is not ready). A future that already
/// resolved has been dropped in place via `ManuallyDrop::drop`, so polling it again would be
/// a use-after-drop. If Rust had unsafe macros, this would be one.
//
// This is implemented as a tt-muncher over the future names `$($F:ident)` and the future
// indexes `$($rest)`, taking advantage of the fact that we only support tuples of up to
// 12 elements.
//
// # References
// TT Muncher: https://veykril.github.io/tlborm/decl-macros/patterns/tt-muncher.html
macro_rules! unsafe_poll {
    // recursively iterate
    (@inner $iteration:ident, $this:ident, $futures:ident, $cx:ident, $fut_name:ident $($F:ident)* | $fut_idx:tt $($rest:tt)*) => {
        if $fut_idx == $iteration {
            // SAFETY: the caller guarantees this future is still pending, so it has not been
            // dropped yet; `map_unchecked_mut` only projects through the `ManuallyDrop`
            // wrapper to the inner future without moving it.
            if let Poll::Ready(value) = unsafe {
                $futures.$fut_name.as_mut()
                    .map_unchecked_mut(|t| t.deref_mut())
                    .poll(&mut $cx)
            } {
                *$this.completed += 1;

                // Check the value, short-circuit on error.
                match value {
                    Ok(value) => {
                        $this.outputs.$fut_idx.write(value);

                        // SAFETY: We're marking the state as "ready", which
                        // means the future has been consumed, and data is
                        // now available to be consumed. The future will no
                        // longer be used after this point so it's safe to drop.
                        $this.state[$fut_idx].set_ready();
                        unsafe { ManuallyDrop::drop($futures.$fut_name.as_mut().get_unchecked_mut()) };
                    }
                    Err(err) => {
                        // The future should no longer be polled after we're done here
                        *$this.consumed = true;

                        // SAFETY: We're about to return the error value
                        // from the future, and drop the entire future.
                        // We're marking the future as consumed, and then
                        // proceeding to drop all other futures and
                        // initialized values in the destructor.
                        $this.state[$fut_idx].set_none();
                        unsafe { ManuallyDrop::drop($futures.$fut_name.as_mut().get_unchecked_mut()) };

                        return Poll::Ready(Err(err));
                    }
                }
            }
        }
        unsafe_poll!(@inner $iteration, $this, $futures, $cx, $($F)* | $($rest)*);
    };

    // base condition
    (@inner $iteration:ident, $this:ident, $futures:ident, $cx:ident, | $($rest:tt)*) => {};

    // macro start
    // `$LEN` is not used by the expansion; it is still accepted so existing call sites
    // keep matching this arm.
    ($iteration:ident, $this:ident, $futures:ident, $cx:ident, $LEN:ident, $($F:ident,)+) => {
        unsafe_poll!(@inner $iteration, $this, $futures, $cx, $($F)+ | 0 1 2 3 4 5 6 7 8 9 10 11);
    };
}

/// Drops every output value whose slot in `$states` is marked ready, then resets that slot
/// to "none" so the same value can never be dropped a second time.
///
/// Invoked as `drop_initialized_values!(out0, out1, ..., | states)`. The n-th output
/// identifier is paired with index `n` of `states`, and at most 12 outputs are supported.
macro_rules! drop_initialized_values {
    // Termination: every output identifier has been handled; leftover indexes are ignored.
    (@step | $slots:expr, $($unused:tt,)*) => {};

    // Handle the head output, then recurse on the tail with the remaining indexes.
    (@step $head:ident, $($tail:ident,)* | $slots:expr, $idx:tt, $($next_idx:tt,)*) => {
        {
            let slot = &mut $slots[$idx];
            if slot.is_ready() {
                // SAFETY: a "ready" slot means the matching output was written and not yet
                // moved out, so it is initialized; this is the single place it is dropped.
                unsafe { $head.assume_init_drop() };
                slot.set_none();
            }
        }
        drop_initialized_values!(@step $($tail,)* | $slots, $($next_idx,)*);
    };

    // Entry point: pair the given outputs with the indexes 0..=11.
    ($($outputs:ident,)+ | $slots:expr) => {
        drop_initialized_values!(
            @step $($outputs,)+ | $slots, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
        );
    };
}

/// Drops, in place, every future whose slot in `$states` is still pending.
///
/// Invoked as `drop_pending_futures!(states, pinned_futures, A, B, ...,)`, where
/// `pinned_futures` is a `Pin<&mut _>` to a struct holding one `ManuallyDrop` field per name.
/// The n-th field name is paired with index `n` of `states`; at most 12 fields are supported.
macro_rules! drop_pending_futures {
    // Termination: every field name has been handled; leftover indexes are ignored.
    (@walk $slots:ident, $pinned:ident, | $($unused:tt)*) => {};

    // Handle the head field, then recurse on the tail with the remaining indexes.
    (@walk $slots:ident, $pinned:ident, $head:ident $($tail:ident)* | $idx:tt $($next_idx:tt)*) => {
        if $slots[$idx].is_pending() {
            // SAFETY: the `&mut` is only used to drop one field in place; nothing is moved
            // out of the pinned struct.
            let fields = unsafe { $pinned.as_mut().get_unchecked_mut() };
            // SAFETY: a pending slot means this future has not been dropped yet, so this is
            // its one and only drop.
            unsafe { ManuallyDrop::drop(&mut fields.$head) };
        }
        drop_pending_futures!(@walk $slots, $pinned, $($tail)* | $($next_idx)*);
    };

    // Entry point: pair the given field names with the indexes 0..=11.
    ($slots:ident, $pinned:ident, $($names:ident,)+) => {
        drop_pending_futures!(@walk $slots, $pinned, $($names)+ | 0 1 2 3 4 5 6 7 8 9 10 11);
    };
}

/// Implements `TryJoinTrait` for a tuple of futures.
///
/// - `impl_try_join_tuple! { mod_name StructName }` implements it for the empty tuple `()`;
///   `mod_name` is accepted but unused in that case.
/// - `impl_try_join_tuple! { mod_name StructName (F0 T0) (F1 T1) ... }` implements it for
///   `(F0, F1, ...)` where each `Fi` resolves to `Result<Ti, Err>`, and generates a private
///   module `mod_name` holding the pin-projected storage for the futures.
macro_rules! impl_try_join_tuple {
    // `Impl TryJoin for ()`
    ($mod_name:ident $StructName:ident) => {
        /// A future which immediately resolves to `Ok(())`; it is what [`try_join`] returns
        /// for the empty tuple `()`.
        ///
        /// This `struct` is created by the [`try_join`] method on the [`TryJoin`] trait. See
        /// its documentation for more.
        ///
        /// [`try_join`]: crate::future::TryJoin::try_join
        /// [`TryJoin`]: crate::future::TryJoin
        #[must_use = "futures do nothing unless you `.await` or poll them"]
        #[allow(non_snake_case)]
        pub struct $StructName {}

        impl fmt::Debug for $StructName {
            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
                f.debug_tuple("TryJoin").finish()
            }
        }

        impl Future for $StructName {
            type Output = Result<(), core::convert::Infallible>;

            fn poll(
                self: Pin<&mut Self>, _cx: &mut Context<'_>
            ) -> Poll<Self::Output> {
                Poll::Ready(Ok(()))
            }
        }

        impl TryJoinTrait for () {
            type Output = ();
            type Error = core::convert::Infallible;
            type Future = $StructName;
            fn try_join(self) -> Self::Future {
                $StructName {}
            }
        }
    };

    // `Impl TryJoin for (F..)`
    ($mod_name:ident $StructName:ident $(($F:ident $T:ident))+) => {
        mod $mod_name {
            use core::mem::ManuallyDrop;

            #[pin_project::pin_project]
            pub(super) struct Futures<$($F,)+> {$(
                #[pin]
                pub(super) $F: ManuallyDrop<$F>,
            )+}

            #[repr(u8)]
            pub(super) enum Indexes { $($F,)+ }

            pub(super) const LEN: usize = [$(Indexes::$F,)+].len();
        }

        /// Waits for many futures to complete, or aborts early on the first error.
        ///
        /// This `struct` is created by the [`try_join`] method on the [`TryJoin`] trait. See
        /// its documentation for more.
        ///
        /// [`try_join`]: crate::future::TryJoin::try_join
        /// [`TryJoin`]: crate::future::TryJoin
        #[pin_project(PinnedDrop)]
        #[must_use = "futures do nothing unless you `.await` or poll them"]
        #[allow(non_snake_case)]
        pub struct $StructName<$($F, $T,)+ Err> {
            #[pin]
            futures: $mod_name::Futures<$($F,)+>,
            outputs: ($(MaybeUninit<$T>,)+),
            // trace the state of outputs, marking them as ready or consumed
            // then, drop the non-consumed values, if any
            state: PollArray<{$mod_name::LEN}>,
            wakers: WakerArray<{$mod_name::LEN}>,
            completed: usize,
            consumed: bool,
            _phantom: PhantomData<Err>,
        }

        impl<$($F, $T,)+ Err> Debug for $StructName<$($F, $T,)+ Err>
        where
            $( $F: Future + Debug, )+
        {
            /// Formats every future that is still pending. Futures that have already
            /// resolved (with `Ok` or `Err`) are shown as `<completed>`, because they have
            /// been dropped in place and must not be read again.
            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
                let mut tuple = f.debug_tuple("TryJoin");
                $(
                    // Only a pending slot guarantees the `ManuallyDrop` field still holds a
                    // live future; any other state means it was dropped via
                    // `ManuallyDrop::drop`, and formatting it would be a use-after-drop.
                    if self.state[$mod_name::Indexes::$F as usize].is_pending() {
                        tuple.field(&*self.futures.$F);
                    } else {
                        tuple.field(&format_args!("<completed>"));
                    }
                )+
                tuple.finish()
            }
        }

        #[allow(unused_mut)]
        #[allow(unused_parens)]
        #[allow(unused_variables)]
        impl<$($F, $T,)+ Err> Future for $StructName<$($F, $T,)+ Err>
        where $(
            $F: Future<Output = Result<$T, Err>>,
        )+ {
            type Output = Result<($($T,)+), Err>;

            fn poll(
                self: Pin<&mut Self>, cx: &mut Context<'_>
            ) -> Poll<Self::Output> {
                const LEN: usize = $mod_name::LEN;

                let mut this = self.project();
                assert!(!*this.consumed, "Futures must not be polled after completing");

                let mut futures = this.futures.project();

                let mut readiness = this.wakers.readiness();
                readiness.set_waker(cx.waker());

                for index in 0..LEN {
                    if !readiness.any_ready() {
                        // nothing ready yet
                        return Poll::Pending;
                    }
                    if !readiness.clear_ready(index) || this.state[index].is_ready() {
                        // future not ready yet or already polled to completion, skip
                        continue;
                    }

                    // unlock readiness so we don't deadlock when polling
                    #[allow(clippy::drop_non_drop)]
                    drop(readiness);

                    // obtain the intermediate waker
                    let mut cx = Context::from_waker(this.wakers.get(index).unwrap());

                    // generate the needed code to poll `futures.{index}`
                    // SAFETY: the future's state should be "pending", so it's safe to poll
                    unsafe_poll!(index, this, futures, cx, LEN, $($F,)+);

                    if *this.completed == LEN {
                        let out = {
                            let mut out = ($(MaybeUninit::<$T>::uninit(),)+);
                            core::mem::swap(&mut out, this.outputs);
                            let ($($F,)+) = out;
                            unsafe { ($($F.assume_init(),)+) }
                        };

                        this.state.set_all_none();
                        *this.consumed = true;

                        return Poll::Ready(Ok(out));
                    }
                    readiness = this.wakers.readiness();
                }

                Poll::Pending
            }
        }

        #[pinned_drop]
        impl<$($F, $T,)+ Err> PinnedDrop for $StructName<$($F, $T,)+ Err> {
            fn drop(self: Pin<&mut Self>) {
                let this = self.project();

                let ($(ref mut $F,)+) = this.outputs;

                let states = this.state;
                let mut futures = this.futures;
                drop_initialized_values!($($F,)+ | states);
                drop_pending_futures!(states, futures, $($F,)+);
            }
        }

        #[allow(unused_parens)]
        impl<$($F, $T,)+ Err> TryJoinTrait for ($($F,)+)
        where $(
            $F: IntoFuture<Output = Result<$T, Err>>,
        )+ {
            type Output = ($($T,)+);
            type Error = Err;
            type Future = $StructName<$($F::IntoFuture, $T,)+ Err>;

            fn try_join(self) -> Self::Future {
                let ($($F,)+): ($($F,)+) = self;
                $StructName {
                    futures: $mod_name::Futures {$(
                        $F: ManuallyDrop::new($F.into_future()),
                    )+},
                    state: PollArray::new_pending(),
                    outputs: ($(MaybeUninit::<$T>::uninit(),)+),
                    wakers: WakerArray::new(),
                    completed: 0,
                    consumed: false,
                    _phantom: PhantomData,
                }
            }
        }
    };
}

319impl_try_join_tuple! { try_join0 TryJoin0 }
320impl_try_join_tuple! { try_join_1 TryJoin1 (A ResA) }
321impl_try_join_tuple! { try_join_2 TryJoin2 (A ResA) (B ResB) }
322impl_try_join_tuple! { try_join_3 TryJoin3 (A ResA) (B ResB) (C ResC) }
323impl_try_join_tuple! { try_join_4 TryJoin4 (A ResA) (B ResB) (C ResC) (D ResD) }
324impl_try_join_tuple! { try_join_5 TryJoin5 (A ResA) (B ResB) (C ResC) (D ResD) (E ResE) }
325impl_try_join_tuple! { try_join_6 TryJoin6 (A ResA) (B ResB) (C ResC) (D ResD) (E ResE) (F ResF) }
326impl_try_join_tuple! { try_join_7 TryJoin7 (A ResA) (B ResB) (C ResC) (D ResD) (E ResE) (F ResF) (G ResG) }
327impl_try_join_tuple! { try_join_8 TryJoin8 (A ResA) (B ResB) (C ResC) (D ResD) (E ResE) (F ResF) (G ResG) (H ResH) }
328impl_try_join_tuple! { try_join_9 TryJoin9 (A ResA) (B ResB) (C ResC) (D ResD) (E ResE) (F ResF) (G ResG) (H ResH) (I ResI) }
329impl_try_join_tuple! { try_join_10 TryJoin10 (A ResA) (B ResB) (C ResC) (D ResD) (E ResE) (F ResF) (G ResG) (H ResH) (I ResI) (J ResJ) }
330impl_try_join_tuple! { try_join_11 TryJoin11 (A ResA) (B ResB) (C ResC) (D ResD) (E ResE) (F ResF) (G ResG) (H ResH) (I ResI) (J ResJ) (K ResK) }
331impl_try_join_tuple! { try_join_12 TryJoin12 (A ResA) (B ResB) (C ResC) (D ResD) (E ResE) (F ResF) (G ResG) (H ResH) (I ResI) (J ResJ) (K ResK) (L ResL) }
332
#[cfg(test)]
mod test {
    use super::*;

    use core::convert::Infallible;
    use core::future;

    /// Every future succeeds, so the outputs come back in tuple order.
    #[test]
    fn all_ok() {
        let outcome = futures_lite::future::block_on(async {
            let first = async { Ok::<_, Infallible>("aaaa") };
            let second = async { Ok::<_, Infallible>(1) };
            let third = async { Ok::<_, Infallible>('z') };
            (first, second, third).try_join().await
        });
        assert_eq!(outcome, Ok(("aaaa", 1, 'z')));
    }

    /// A single `Err` makes the whole join resolve to that error.
    #[test]
    fn one_err() {
        futures_lite::future::block_on(async {
            let joined: Result<(_, char), ()> =
                (future::ready(Ok("hello")), future::ready(Err(()))).try_join().await;
            assert_eq!(joined, Err(()));
        })
    }

    /// Regression test for issue #135: one future resolves immediately while the other
    /// yields once before resolving, and the join must still complete with both outputs.
    #[test]
    fn issue_135_resume_after_completion() {
        use futures_lite::future::yield_now;

        let joined = futures_lite::future::block_on(async {
            let immediate = async { Ok::<_, ()>(()) };
            let delayed = async {
                yield_now().await;
                Ok::<_, ()>(())
            };
            (immediate, delayed).try_join().await
        });

        assert_eq!(joined.unwrap(), ((), ()));
    }

    /// Dropping an unfinished join must drop the outputs that were already produced.
    #[test]
    #[cfg(feature = "std")]
    fn does_not_leak_memory() {
        use core::cell::Cell;
        use futures_lite::future::pending;

        thread_local! {
            static NOT_LEAKING: Cell<bool> = const { Cell::new(false) };
        }

        /// Sets `NOT_LEAKING` when it is dropped.
        struct FlipFlagAtDrop;
        impl Drop for FlipFlagAtDrop {
            fn drop(&mut self) {
                NOT_LEAKING.with(|flag| flag.set(true));
            }
        }

        futures_lite::future::block_on(async {
            // Under Miri, this `String` is reported as leaked if its output is never dropped.
            let string = future::ready(Result::Ok("memory leak".to_owned()));

            // Dropping this output flips the thread-local flag checked below.
            let flip = future::ready(Result::Ok(FlipFlagAtDrop));

            let leak = (string, flip, pending::<Result<u8, ()>>()).try_join();

            // Poll the join a single time; the `pending` future keeps it from finishing,
            // so the join is dropped while still incomplete.
            let _ = futures_lite::future::poll_once(leak).await;
        });

        assert!(NOT_LEAKING.with(Cell::get));
    }
}