futures_concurrency/future/race_ok/array/
mod.rs

1use super::RaceOk as RaceOkTrait;
2use crate::utils::array_assume_init;
3use crate::utils::iter_pin_mut;
4use crate::utils::PollArray;
5
6use core::array;
7use core::fmt;
8use core::future::{Future, IntoFuture};
9use core::mem::{self, MaybeUninit};
10use core::pin::Pin;
11use core::task::{Context, Poll};
12
13use pin_project::{pin_project, pinned_drop};
14
15mod error;
16
17pub use error::AggregateError;
18
19/// A future which waits for the first successful future to complete.
20///
21/// This `struct` is created by the [`race_ok`] method on the [`RaceOk`] trait. See
22/// its documentation for more.
23///
24/// [`race_ok`]: crate::future::RaceOk::race_ok
25/// [`RaceOk`]: crate::future::RaceOk
26#[must_use = "futures do nothing unless you `.await` or poll them"]
27#[pin_project(PinnedDrop)]
28pub struct RaceOk<Fut, T, E, const N: usize>
29where
30    Fut: Future<Output = Result<T, E>>,
31{
32    #[pin]
33    futures: [Fut; N],
34    errors: [MaybeUninit<E>; N],
35    error_states: PollArray<N>,
36    completed: usize,
37}
38
39#[pinned_drop]
40impl<Fut, T, E, const N: usize> PinnedDrop for RaceOk<Fut, T, E, N>
41where
42    Fut: Future<Output = Result<T, E>>,
43{
44    fn drop(self: Pin<&mut Self>) {
45        let this = self.project();
46        for (st, err) in this
47            .error_states
48            .iter_mut()
49            .zip(this.errors.iter_mut())
50            .filter(|(st, _err)| st.is_ready())
51        {
52            // SAFETY: we've filtered down to only the `ready`/initialized data
53            unsafe { err.assume_init_drop() };
54            st.set_none();
55        }
56    }
57}
58
59impl<Fut, T, E, const N: usize> fmt::Debug for RaceOk<Fut, T, E, N>
60where
61    Fut: Future<Output = Result<T, E>> + fmt::Debug,
62    Fut::Output: fmt::Debug,
63{
64    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
65        f.debug_list().entries(self.futures.iter()).finish()
66    }
67}
68
69impl<Fut, T, E, const N: usize> Future for RaceOk<Fut, T, E, N>
70where
71    Fut: Future<Output = Result<T, E>>,
72{
73    type Output = Result<T, AggregateError<E, N>>;
74
75    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
76        let this = self.project();
77
78        let futures = iter_pin_mut(this.futures);
79
80        for ((fut, out), st) in futures
81            .zip(this.errors.iter_mut())
82            .zip(this.error_states.iter_mut())
83        {
84            if st.is_ready() {
85                continue;
86            }
87            if let Poll::Ready(output) = fut.poll(cx) {
88                match output {
89                    Ok(ok) => return Poll::Ready(Ok(ok)),
90                    Err(err) => {
91                        *out = MaybeUninit::new(err);
92                        *this.completed += 1;
93                        st.set_ready();
94                    }
95                }
96            }
97        }
98
99        let all_completed = *this.completed == N;
100        if all_completed {
101            let mut errors = array::from_fn(|_| MaybeUninit::uninit());
102            mem::swap(&mut errors, this.errors);
103            this.error_states.set_all_none();
104
105            // SAFETY: we know that all futures are properly initialized because they're all completed
106            let result = unsafe { array_assume_init(errors) };
107
108            Poll::Ready(Err(AggregateError::new(result)))
109        } else {
110            Poll::Pending
111        }
112    }
113}
114
115impl<Fut, T, E, const N: usize> RaceOkTrait for [Fut; N]
116where
117    Fut: IntoFuture<Output = Result<T, E>>,
118{
119    type Output = T;
120    type Error = AggregateError<E, N>;
121    type Future = RaceOk<Fut::IntoFuture, T, E, N>;
122
123    fn race_ok(self) -> Self::Future {
124        RaceOk {
125            futures: self.map(|fut| fut.into_future()),
126            errors: array::from_fn(|_| MaybeUninit::uninit()),
127            error_states: PollArray::new_pending(),
128            completed: 0,
129        }
130    }
131}
132
#[cfg(test)]
mod test {
    use super::*;
    use core::future;

    #[test]
    fn all_ok() {
        futures_lite::future::block_on(async {
            let res: Result<&str, AggregateError<(), 2>> =
                [future::ready(Ok("hello")), future::ready(Ok("world"))]
                    .race_ok()
                    .await;
            assert!(res.is_ok());
        })
    }

    #[test]
    fn one_err() {
        futures_lite::future::block_on(async {
            let res: Result<&str, AggregateError<_, 2>> =
                [future::ready(Ok("hello")), future::ready(Err("oh no"))]
                    .race_ok()
                    .await;
            assert_eq!(res.unwrap(), "hello");
        });
    }

    #[test]
    fn all_err() {
        futures_lite::future::block_on(async {
            let res: Result<&str, AggregateError<_, 2>> =
                [future::ready(Err("oops")), future::ready(Err("oh no"))]
                    .race_ok()
                    .await;
            let errs = res.unwrap_err();
            assert_eq!(errs[0], "oops");
            assert_eq!(errs[1], "oh no");
        });
    }

    /// With no futures to race, `completed == N == 0` holds on the first poll,
    /// so the race resolves immediately to an (empty) aggregate error.
    #[test]
    fn empty_array() {
        futures_lite::future::block_on(async {
            let futs: [future::Ready<Result<(), ()>>; 0] = [];
            let res: Result<(), AggregateError<(), 0>> = futs.race_ok().await;
            assert!(res.is_err());
        });
    }

    /// Errors are reported at the index of the future that produced them,
    /// not in the order the futures happened to fail.
    #[test]
    fn errors_keep_input_order() {
        use futures_lite::future::yield_now;
        futures_lite::future::block_on(async {
            let fut = |delay: bool, msg: &'static str| async move {
                if delay {
                    yield_now().await;
                }
                Err::<(), _>(msg)
            };

            // The second future fails on the first poll, the first one later.
            let errs = [fut(true, "first"), fut(false, "second")]
                .race_ok()
                .await
                .unwrap_err();
            assert_eq!(errs[0], "first");
            assert_eq!(errs[1], "second");
        });
    }

    #[test]
    fn resume_after_completion() {
        use futures_lite::future::yield_now;
        futures_lite::future::block_on(async {
            let fut = |ok| async move {
                if ok {
                    yield_now().await;
                    yield_now().await;
                    Ok(())
                } else {
                    Err(())
                }
            };

            let res = [fut(true), fut(false)].race_ok().await;
            assert_eq!(res.ok().unwrap(), ());
        });
    }

    #[test]
    fn drop_errors() {
        use futures_lite::future::yield_now;

        struct Droper<'a>(&'a core::cell::Cell<usize>);
        impl Drop for Droper<'_> {
            fn drop(&mut self) {
                self.0.set(self.0.get() + 1);
            }
        }

        futures_lite::future::block_on(async {
            let drop_count = Default::default();
            let fut = |ok| {
                let drop_count = &drop_count;
                async move {
                    if ok {
                        yield_now().await;
                        yield_now().await;
                        Ok(())
                    } else {
                        Err(Droper(drop_count))
                    }
                }
            };
            let res = [fut(true), fut(false)].race_ok().await;
            assert_eq!(drop_count.get(), 1);
            assert_eq!(res.ok().unwrap(), ());

            drop_count.set(0);
            let res = [fut(false), fut(false)].race_ok().await;
            assert!(res.is_err());
            assert_eq!(drop_count.get(), 0);
            drop(res);
            assert_eq!(drop_count.get(), 2);
        })
    }
}