futures_concurrency/future/race_ok/array/
mod.rs1use super::RaceOk as RaceOkTrait;
2use crate::utils::array_assume_init;
3use crate::utils::iter_pin_mut;
4use crate::utils::PollArray;
5
6use core::array;
7use core::fmt;
8use core::future::{Future, IntoFuture};
9use core::mem::{self, MaybeUninit};
10use core::pin::Pin;
11use core::task::{Context, Poll};
12
13use pin_project::{pin_project, pinned_drop};
14
15mod error;
16
17pub use error::AggregateError;
18
19#[must_use = "futures do nothing unless you `.await` or poll them"]
27#[pin_project(PinnedDrop)]
28pub struct RaceOk<Fut, T, E, const N: usize>
29where
30 Fut: Future<Output = Result<T, E>>,
31{
32 #[pin]
33 futures: [Fut; N],
34 errors: [MaybeUninit<E>; N],
35 error_states: PollArray<N>,
36 completed: usize,
37}
38
39#[pinned_drop]
40impl<Fut, T, E, const N: usize> PinnedDrop for RaceOk<Fut, T, E, N>
41where
42 Fut: Future<Output = Result<T, E>>,
43{
44 fn drop(self: Pin<&mut Self>) {
45 let this = self.project();
46 for (st, err) in this
47 .error_states
48 .iter_mut()
49 .zip(this.errors.iter_mut())
50 .filter(|(st, _err)| st.is_ready())
51 {
52 unsafe { err.assume_init_drop() };
54 st.set_none();
55 }
56 }
57}
58
59impl<Fut, T, E, const N: usize> fmt::Debug for RaceOk<Fut, T, E, N>
60where
61 Fut: Future<Output = Result<T, E>> + fmt::Debug,
62 Fut::Output: fmt::Debug,
63{
64 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
65 f.debug_list().entries(self.futures.iter()).finish()
66 }
67}
68
69impl<Fut, T, E, const N: usize> Future for RaceOk<Fut, T, E, N>
70where
71 Fut: Future<Output = Result<T, E>>,
72{
73 type Output = Result<T, AggregateError<E, N>>;
74
75 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
76 let this = self.project();
77
78 let futures = iter_pin_mut(this.futures);
79
80 for ((fut, out), st) in futures
81 .zip(this.errors.iter_mut())
82 .zip(this.error_states.iter_mut())
83 {
84 if st.is_ready() {
85 continue;
86 }
87 if let Poll::Ready(output) = fut.poll(cx) {
88 match output {
89 Ok(ok) => return Poll::Ready(Ok(ok)),
90 Err(err) => {
91 *out = MaybeUninit::new(err);
92 *this.completed += 1;
93 st.set_ready();
94 }
95 }
96 }
97 }
98
99 let all_completed = *this.completed == N;
100 if all_completed {
101 let mut errors = array::from_fn(|_| MaybeUninit::uninit());
102 mem::swap(&mut errors, this.errors);
103 this.error_states.set_all_none();
104
105 let result = unsafe { array_assume_init(errors) };
107
108 Poll::Ready(Err(AggregateError::new(result)))
109 } else {
110 Poll::Pending
111 }
112 }
113}
114
115impl<Fut, T, E, const N: usize> RaceOkTrait for [Fut; N]
116where
117 Fut: IntoFuture<Output = Result<T, E>>,
118{
119 type Output = T;
120 type Error = AggregateError<E, N>;
121 type Future = RaceOk<Fut::IntoFuture, T, E, N>;
122
123 fn race_ok(self) -> Self::Future {
124 RaceOk {
125 futures: self.map(|fut| fut.into_future()),
126 errors: array::from_fn(|_| MaybeUninit::uninit()),
127 error_states: PollArray::new_pending(),
128 completed: 0,
129 }
130 }
131}
132
#[cfg(test)]
mod test {
    use super::*;
    use core::cell::Cell;
    use core::future;
    use futures_lite::future::{block_on, yield_now};

    #[test]
    fn all_ok() {
        block_on(async {
            let outcome: Result<&str, AggregateError<(), 2>> =
                [future::ready(Ok("hello")), future::ready(Ok("world"))]
                    .race_ok()
                    .await;
            assert!(outcome.is_ok());
        })
    }

    #[test]
    fn one_err() {
        block_on(async {
            let outcome: Result<&str, AggregateError<_, 2>> =
                [future::ready(Ok("hello")), future::ready(Err("oh no"))]
                    .race_ok()
                    .await;
            assert_eq!(outcome.unwrap(), "hello");
        });
    }

    #[test]
    fn all_err() {
        block_on(async {
            let outcome: Result<&str, AggregateError<_, 2>> =
                [future::ready(Err("oops")), future::ready(Err("oh no"))]
                    .race_ok()
                    .await;
            let errors = outcome.unwrap_err();
            assert_eq!(errors[0], "oops");
            assert_eq!(errors[1], "oh no");
        });
    }

    #[test]
    fn resume_after_completion() {
        block_on(async {
            // The succeeding future yields twice, so the failing one finishes
            // first and the race must keep polling after recording its error.
            let task = |succeed: bool| async move {
                if !succeed {
                    return Err(());
                }
                yield_now().await;
                yield_now().await;
                Ok(())
            };

            let outcome = [task(true), task(false)].race_ok().await;
            assert!(matches!(outcome, Ok(())));
        });
    }

    #[test]
    fn drop_errors() {
        struct DropCounter<'a>(&'a Cell<usize>);

        impl Drop for DropCounter<'_> {
            fn drop(&mut self) {
                self.0.set(self.0.get() + 1);
            }
        }

        block_on(async {
            let drops = Cell::new(0);
            let task = |succeed: bool| {
                let drops = &drops;
                async move {
                    if !succeed {
                        return Err(DropCounter(drops));
                    }
                    yield_now().await;
                    yield_now().await;
                    Ok(())
                }
            };

            // An error stored before another future succeeds is dropped
            // together with the race future.
            let outcome = [task(true), task(false)].race_ok().await;
            assert_eq!(drops.get(), 1);
            assert!(matches!(outcome, Ok(())));

            // Errors returned in the `AggregateError` belong to the caller
            // and are dropped only when it drops the result.
            drops.set(0);
            let outcome = [task(false), task(false)].race_ok().await;
            assert!(outcome.is_err());
            assert_eq!(drops.get(), 0);
            drop(outcome);
            assert_eq!(drops.get(), 2);
        })
    }
}