futures_concurrency/future/race_ok/vec/
mod.rs

1use super::RaceOk as RaceOkTrait;
2use crate::utils::iter_pin_mut;
3use crate::utils::MaybeDone;
4
5#[cfg(all(feature = "alloc", not(feature = "std")))]
6use alloc::{boxed::Box, vec::Vec};
7
8use core::fmt;
9use core::future::{Future, IntoFuture};
10use core::mem;
11use core::pin::Pin;
12use core::task::{Context, Poll};
13
14pub use error::AggregateError;
15
16mod error;
17
18/// A future which waits for the first successful future to complete.
19///
20/// This `struct` is created by the [`race_ok`] method on the [`RaceOk`] trait. See
21/// its documentation for more.
22///
23/// [`race_ok`]: crate::future::RaceOk::race_ok
24/// [`RaceOk`]: crate::future::RaceOk
25#[must_use = "futures do nothing unless you `.await` or poll them"]
26pub struct RaceOk<Fut, T, E>
27where
28    Fut: Future<Output = Result<T, E>>,
29{
30    elems: Pin<Box<[MaybeDone<Fut>]>>,
31}
32
33impl<Fut, T, E> fmt::Debug for RaceOk<Fut, T, E>
34where
35    Fut: Future<Output = Result<T, E>> + fmt::Debug,
36    Fut::Output: fmt::Debug,
37{
38    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
39        f.debug_list().entries(self.elems.iter()).finish()
40    }
41}
42
43impl<Fut, T, E> Future for RaceOk<Fut, T, E>
44where
45    Fut: Future<Output = Result<T, E>>,
46{
47    type Output = Result<T, AggregateError<E>>;
48
49    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
50        let mut all_done = true;
51
52        for mut elem in iter_pin_mut(self.elems.as_mut()) {
53            if elem.as_mut().poll(cx).is_pending() {
54                all_done = false
55            } else if let Some(output) = elem.take_ok() {
56                return Poll::Ready(Ok(output));
57            }
58        }
59
60        if all_done {
61            let mut elems = mem::replace(&mut self.elems, Box::pin([]));
62            let result: Vec<E> = iter_pin_mut(elems.as_mut())
63                .map(|e| match e.take_err() {
64                    Some(err) => err,
65                    // Since all futures are done without any one of them returning `Ok`, they're
66                    // all `Err`s and so `take_err` cannot fail
67                    None => unreachable!(),
68                })
69                .collect();
70            Poll::Ready(Err(AggregateError::new(result)))
71        } else {
72            Poll::Pending
73        }
74    }
75}
76
77impl<Fut, T, E> RaceOkTrait for Vec<Fut>
78where
79    Fut: IntoFuture<Output = Result<T, E>>,
80{
81    type Output = T;
82    type Error = AggregateError<E>;
83    type Future = RaceOk<Fut::IntoFuture, T, E>;
84
85    fn race_ok(self) -> Self::Future {
86        let elems: Box<[_]> = self
87            .into_iter()
88            .map(|fut| MaybeDone::new(fut.into_future()))
89            .collect();
90        RaceOk {
91            elems: elems.into(),
92        }
93    }
94}
95
#[cfg(test)]
mod test {
    use super::*;
    use alloc::vec;
    use core::future;

    /// Racing only successful futures resolves to `Ok`.
    #[test]
    fn all_ok() {
        futures_lite::future::block_on(async {
            let inputs = vec![future::ready(Ok("hello")), future::ready(Ok("world"))];
            let outcome: Result<&str, AggregateError<()>> = inputs.race_ok().await;
            assert!(outcome.is_ok());
        })
    }

    /// A failing future does not stop a successful one from winning the race.
    #[test]
    fn one_err() {
        futures_lite::future::block_on(async {
            let inputs = vec![future::ready(Ok("hello")), future::ready(Err("oh no"))];
            let outcome: Result<&str, AggregateError<_>> = inputs.race_ok().await;
            assert_eq!(outcome.unwrap(), "hello");
        });
    }

    /// When every future fails, the errors are reported in input order.
    #[test]
    fn all_err() {
        futures_lite::future::block_on(async {
            let inputs = vec![future::ready(Err("oops")), future::ready(Err("oh no"))];
            let outcome: Result<&str, AggregateError<_>> = inputs.race_ok().await;
            let failures = outcome.unwrap_err();
            assert_eq!(failures[0], "oops");
            assert_eq!(failures[1], "oh no");
        });
    }
}
136}