futures_concurrency/future/race_ok/tuple/
mod.rs

1use super::RaceOk;
2use crate::utils::{self, PollArray};
3
4use core::array;
5use core::fmt;
6use core::future::{Future, IntoFuture};
7use core::mem::{self, MaybeUninit};
8use core::pin::Pin;
9use core::task::{Context, Poll};
10
11use pin_project::{pin_project, pinned_drop};
12
13mod error;
14pub(crate) use error::AggregateError;
15
/// Generates the `RaceOk` machinery for one tuple arity.
///
/// Given a struct name and one identifier per tuple element, this expands to:
/// - a private `usize` constant (sharing the struct's name) holding the tuple length,
/// - the future struct itself, with one pinned field per tuple element,
/// - `Debug`, `RaceOk`, `Future` and `PinnedDrop` implementations for it.
macro_rules! impl_race_ok_tuple {
    ($StructName:ident $($F:ident)+) => {
        /// A workaround to avoid calling the recursive macro several times. Since it's for private
        /// use only, we don't care about capitalization, so we reuse `$StructName` for simplicity
        /// (renaming it as `const LEN: usize = ...` when in a function, for clarity).
        #[allow(non_upper_case_globals)]
        const $StructName: usize = utils::tuple_len!($($F,)*);

        /// A future which waits for the first successful future to complete.
        ///
        /// This `struct` is created by the [`race_ok`] method on the [`RaceOk`] trait. See
        /// its documentation for more.
        ///
        /// [`race_ok`]: crate::future::RaceOk::race_ok
        /// [`RaceOk`]: crate::future::RaceOk
        #[must_use = "futures do nothing unless you `.await` or poll them"]
        #[allow(non_snake_case)]
        #[pin_project(PinnedDrop)]
        pub struct $StructName<T, ERR, $($F),*>
        where
            $( $F: Future<Output = Result<T, ERR>>, )*
            ERR: fmt::Debug,
        {
            // Number of sub-futures that have resolved so far (with either `Ok` or `Err`).
            completed: usize,
            // Set once a final result has been handed out; polling afterwards panics.
            done: bool,
            // Yields the indexes of the sub-futures to visit on each poll.
            indexer: utils::Indexer,
            // One slot per sub-future; a slot is initialized only while the matching entry in
            // `errors_states` is marked ready.
            errors: [MaybeUninit<ERR>; $StructName],
            // Tracks which slots of `errors` currently hold an initialized error.
            errors_states: PollArray<{ $StructName }>,
            $( #[pin] $F: $F, )*
        }

        impl<T, ERR, $($F),*> fmt::Debug for $StructName<T, ERR, $($F),*>
        where
            $( $F: Future<Output = Result<T, ERR>> + fmt::Debug, )*
            ERR: fmt::Debug,
        {
            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
                // Label the output with the actual type name rather than a hard-coded string,
                // so it cannot drift from (or be confused with) another combinator.
                f.debug_tuple(stringify!($StructName))
                    $(.field(&self.$F))*
                    .finish()
            }
        }

        impl<T, ERR, $($F),*> RaceOk for ($($F,)*)
        where
            $( $F: IntoFuture<Output = Result<T, ERR>>, )*
            ERR: fmt::Debug,
        {
            type Output = T;
            type Error = AggregateError<ERR, { $StructName }>;
            type Future = $StructName<T, ERR, $($F::IntoFuture),*>;

            fn race_ok(self) -> Self::Future {
                let ($($F,)*): ($($F,)*) = self;
                $StructName {
                    completed: 0,
                    done: false,
                    indexer: utils::Indexer::new($StructName),
                    // No error has been produced yet, so every slot starts uninitialized...
                    errors: array::from_fn(|_| MaybeUninit::uninit()),
                    // ...and every state starts as pending to match.
                    errors_states: PollArray::new_pending(),
                    $($F: $F.into_future()),*
                }
            }
        }

        impl<T, ERR, $($F),*> Future for $StructName<T, ERR, $($F),*>
        where
            $( $F: Future<Output = Result<T, ERR>>, )*
            ERR: fmt::Debug,
        {
            type Output = Result<T, AggregateError<ERR, { $StructName }>>;

            /// Polls the sub-futures, resolving with the first `Ok` value, or with every
            /// collected error once all sub-futures have failed.
            ///
            /// # Panics
            ///
            /// Panics if polled again after having returned `Poll::Ready`.
            fn poll(
                self: Pin<&mut Self>, cx: &mut Context<'_>
            ) -> Poll<Self::Output> {
                const LEN: usize = $StructName;

                let mut this = self.project();

                assert!(!*this.done, "Futures must not be polled after completing");

                // Maps each tuple element to its positional index.
                #[repr(usize)]
                enum Indexes {
                    $($F),*
                }

                for i in this.indexer.iter() {
                    // This sub-future already failed; its error is stored, don't poll it again.
                    if this.errors_states[i].is_ready() {
                        continue;
                    }
                    utils::gen_conditions!(i, this, cx, poll, $((Indexes::$F as usize; $F, {
                        Poll::Ready(output) => match output {
                            Ok(output) => {
                                *this.done = true;
                                *this.completed += 1;
                                return Poll::Ready(Ok(output));
                            },
                            Err(err) => {
                                this.errors[i] = MaybeUninit::new(err);
                                this.errors_states[i].set_ready();
                                *this.completed += 1;
                                continue;
                            },
                        },
                        _ => continue,
                    }))*);
                }

                if *this.completed == LEN {
                    // Mark all error states as consumed before we return them, so that the
                    // destructor does not drop the errors a second time.
                    this.errors_states.set_all_none();

                    // Move the stored errors out, leaving uninitialized slots behind.
                    let errors = mem::replace(
                        this.errors,
                        array::from_fn(|_| MaybeUninit::uninit()),
                    );

                    // SAFETY: an `Ok` result returns immediately and sets `done`, so reaching
                    // this point with `completed == LEN` means every sub-future failed. Each
                    // index is skipped once its state is ready, so every one of the `LEN`
                    // failures wrote its error to a distinct slot: all slots are initialized.
                    let result = unsafe { utils::array_assume_init(errors) };

                    *this.done = true;
                    return Poll::Ready(Err(AggregateError::new(result)));
                }

                Poll::Pending
            }
        }

        #[pinned_drop]
        impl<T, ERR, $($F,)*> PinnedDrop for $StructName<T, ERR, $($F,)*>
        where
            $( $F: Future<Output = Result<T, ERR>>, )*
            ERR: fmt::Debug,
        {
            /// Drops any errors that were stored but never handed out to the caller.
            fn drop(self: Pin<&mut Self>) {
                let this = self.project();

                this
                    .errors_states
                    .iter_mut()
                    .zip(this.errors.iter_mut())
                    .filter(|(st, _err)| st.is_ready())
                    .for_each(|(st, err)| {
                        // SAFETY: we've filtered down to only the `ready`/initialized data
                        unsafe { err.assume_init_drop() };
                        st.set_none();
                    });
            }
        }
    };
}
166
// Implement `RaceOk` for tuples of 1 through 12 futures. Each invocation generates one
// `RaceOkN` future type (plus its same-named length constant) for that arity.
impl_race_ok_tuple! { RaceOk1 A }
impl_race_ok_tuple! { RaceOk2 A B }
impl_race_ok_tuple! { RaceOk3 A B C }
impl_race_ok_tuple! { RaceOk4 A B C D }
impl_race_ok_tuple! { RaceOk5 A B C D E }
impl_race_ok_tuple! { RaceOk6 A B C D E F }
impl_race_ok_tuple! { RaceOk7 A B C D E F G }
impl_race_ok_tuple! { RaceOk8 A B C D E F G H }
impl_race_ok_tuple! { RaceOk9 A B C D E F G H I }
impl_race_ok_tuple! { RaceOk10 A B C D E F G H I J }
impl_race_ok_tuple! { RaceOk11 A B C D E F G H I J K }
impl_race_ok_tuple! { RaceOk12 A B C D E F G H I J K L }
179
#[cfg(test)]
mod test {
    use super::*;
    use core::future;
    use futures_lite::future::{block_on, yield_now};

    /// A single successful future resolves with its value.
    #[test]
    fn race_ok_1() {
        block_on(async {
            let only = async { Ok::<_, ()>("world") };
            let outcome = (only,).race_ok().await;
            assert!(matches!(outcome, Ok("world")));
        });
    }

    /// A never-resolving future does not prevent the successful one from winning.
    #[test]
    fn race_ok_2() {
        block_on(async {
            let stuck = future::pending();
            let winner = async { Ok::<_, ()>("world") };
            let outcome = (stuck, winner).race_ok().await;
            assert!(matches!(outcome, Ok("world")));
        });
    }

    /// With two successful futures, either value is an acceptable result.
    #[test]
    fn race_ok_3() {
        block_on(async {
            let stuck = future::pending();
            let first = async { Ok::<_, ()>("hello") };
            let second = async { Ok::<_, ()>("world") };
            let outcome = (stuck, first, second).race_ok().await;
            assert!(matches!(outcome, Ok("hello") | Ok("world")));
        });
    }

    /// When every future fails, the errors come back in tuple order.
    #[test]
    fn race_ok_err() {
        block_on(async {
            let first = async { Err::<(), _>("hello") };
            let second = async { Err::<(), _>("world") };
            let failures = (first, second).race_ok().await.unwrap_err();
            assert_eq!(failures[0], "hello");
            assert_eq!(failures[1], "world");
        });
    }

    /// A future that already failed must not be polled again while a slower
    /// future is still making progress.
    #[test]
    fn race_ok_resume_after_completion() {
        block_on(async {
            let slow_ok = async {
                yield_now().await;
                yield_now().await;
                Ok::<_, ()>(())
            };
            let fast_err = async { Err::<(), _>(()) };

            let outcome = (slow_ok, fast_err).race_ok().await;
            let value = outcome.ok().unwrap();

            assert_eq!(value, ());
        });
    }
}