futures_concurrency/future/try_join/
tuple.rs1use super::TryJoin as TryJoinTrait;
2use crate::utils::{PollArray, WakerArray};
3
4use core::fmt::{self, Debug};
5use core::future::{Future, IntoFuture};
6use core::marker::PhantomData;
7use core::mem::ManuallyDrop;
8use core::mem::MaybeUninit;
9use core::ops::DerefMut;
10use core::pin::Pin;
11use core::task::{Context, Poll};
12
13use pin_project::{pin_project, pinned_drop};
14
/// Polls the future whose tuple position equals `$index` and records the
/// outcome in `$this`.
///
/// The macro walks the future names in lockstep with the literal tuple
/// indices `0..=11` (a tt-muncher), emitting one `if <idx> == $index { .. }`
/// guard per future. The fixed index list is why tuples are capped at 12
/// elements.
///
/// When the selected future returns `Poll::Ready`:
/// - `Ok(v)`: increments `completed`, writes `v` into `outputs.<idx>`, marks
///   `state[<idx>]` ready and drops the future in place.
/// - `Err(e)`: increments `completed`, sets `consumed`, marks `state[<idx>]`
///   none, drops the future in place and `return`s `Poll::Ready(Err(e))`
///   from the *enclosing* function.
///
/// SAFETY: callers must only select a future that is still alive (its state
/// is pending). The macro polls it and may then drop it through
/// `ManuallyDrop` without checking, so treat it like an `unsafe fn`.
///
/// `$LEN` is accepted for call-site symmetry but is not used.
macro_rules! unsafe_poll {
    // Entry point: seed the muncher with every supported tuple index.
    ($index:ident, $this:ident, $futs:ident, $cx:ident, $LEN:ident, $($name:ident,)+) => {
        unsafe_poll!(@inner $index, $this, $futs, $cx, $($name)+ | 0 1 2 3 4 5 6 7 8 9 10 11);
    };

    // Every future name has been consumed; the leftover indices are ignored.
    (@inner $index:ident, $this:ident, $futs:ident, $cx:ident, | $($unused:tt)*) => {};

    // Emit the guard for the first (name, index) pair, then recurse on the rest.
    (@inner $index:ident, $this:ident, $futs:ident, $cx:ident, $name:ident $($names:ident)* | $idx:tt $($idxs:tt)*) => {
        if $idx == $index {
            let polled = unsafe {
                $futs.$name
                    .as_mut()
                    .map_unchecked_mut(|fut| fut.deref_mut())
                    .poll(&mut $cx)
            };

            if let Poll::Ready(result) = polled {
                *$this.completed += 1;

                match result {
                    Ok(value) => {
                        $this.outputs.$idx.write(value);
                        // The output slot is now initialized and this future
                        // will never be polled again. Flag the slot ready
                        // before releasing the future.
                        $this.state[$idx].set_ready();
                        unsafe { ManuallyDrop::drop($futs.$name.as_mut().get_unchecked_mut()) };
                    }
                    Err(err) => {
                        // Short-circuit: the combined future is finished.
                        *$this.consumed = true;
                        // No output was produced and the future is released
                        // right here, so nothing in this slot is left for the
                        // destructor.
                        $this.state[$idx].set_none();
                        unsafe { ManuallyDrop::drop($futs.$name.as_mut().get_unchecked_mut()) };

                        return Poll::Ready(Err(err));
                    }
                }
            }
        }
        unsafe_poll!(@inner $index, $this, $futs, $cx, $($names)* | $($idxs)*);
    };
}
78
/// Drops every output slot whose state is ready, then marks that state none.
///
/// Takes bindings to the `MaybeUninit` output slots in tuple order, followed
/// by `| states`. Slots are paired with state indices `0..=11` through a
/// tt-muncher. Marking a dropped slot none ensures it cannot be dropped a
/// second time.
macro_rules! drop_initialized_values {
    // Entry point: seed the muncher with every supported state index.
    ($($slot:ident,)+ | $states:expr) => {
        drop_initialized_values!(@drop $($slot,)+ | $states, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,);
    };

    // Every slot has been visited; the leftover indices are ignored.
    (@drop | $states:expr, $($unused:tt,)*) => {};

    // Visit the first slot, then recurse on the remaining slots and indices.
    (@drop $slot:ident, $($slots:ident,)* | $states:expr, $idx:tt, $($idxs:tt,)*) => {
        if $states[$idx].is_ready() {
            // SAFETY: a ready state means this slot was written and its value
            // has not been moved out.
            unsafe { $slot.assume_init_drop() };
            $states[$idx].set_none();
        }
        drop_initialized_values!(@drop $($slots,)* | $states, $($idxs,)*);
    };
}
100
/// Drops, in place, every future whose state is still pending.
///
/// Takes the state array, a pinned `&mut` to the futures struct, and the
/// future field names in tuple order. Names are paired with state indices
/// `0..=11` through a tt-muncher. Futures that already completed were released
/// by `unsafe_poll!` and are skipped.
macro_rules! drop_pending_futures {
    // Entry point: seed the muncher with every supported state index.
    ($states:ident, $futures:ident, $($name:ident,)+) => {
        drop_pending_futures!(@inner $states, $futures, $($name)+ | 0 1 2 3 4 5 6 7 8 9 10 11);
    };

    // Every future name has been visited; the leftover indices are ignored.
    (@inner $states:ident, $futures:ident, | $($unused:tt)*) => {};

    // Visit the first future, then recurse on the remaining names and indices.
    (@inner $states:ident, $futures:ident, $name:ident $($names:ident)* | $idx:tt $($idxs:tt)*) => {
        if $states[$idx].is_pending() {
            // SAFETY: the value is dropped in place, never moved out of its
            // pinned location, so bypassing the pin is sound here.
            let all = unsafe { $futures.as_mut().get_unchecked_mut() };
            // SAFETY: a pending state means this future has not been dropped
            // yet. This macro is meant to run during teardown, so the future
            // is not used afterwards.
            unsafe { ManuallyDrop::drop(&mut all.$name) };
        }
        drop_pending_futures!(@inner $states, $futures, $($names)* | $($idxs)*);
    };
}
123
/// Implements `TryJoin` for one tuple arity.
///
/// Invocation forms:
/// - `impl_try_join_tuple! { module StructName }` implements `TryJoin` for the
///   empty tuple `()`. The returned future resolves immediately to `Ok(())`.
/// - `impl_try_join_tuple! { module StructName (F1 T1) (F2 T2) ... }` implements
///   `TryJoin` for `(F1, F2, ...)`, where each `Fi` resolves to
///   `Result<Ti, Err>`. `module` names a private helper module holding the
///   pin-projected futures, and `StructName` is the future returned by
///   `try_join`.
macro_rules! impl_try_join_tuple {
    ($mod_name:ident $StructName:ident) => {
        /// Future returned by `TryJoin::try_join` on the empty tuple `()`.
        ///
        /// It resolves to `Ok(())` on the first poll.
        #[must_use = "futures do nothing unless you `.await` or poll them"]
        #[allow(non_snake_case)]
        pub struct $StructName {}

        impl fmt::Debug for $StructName {
            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
                f.debug_tuple("TryJoin").finish()
            }
        }

        impl Future for $StructName {
            type Output = Result<(), core::convert::Infallible>;

            fn poll(
                self: Pin<&mut Self>, _cx: &mut Context<'_>
            ) -> Poll<Self::Output> {
                Poll::Ready(Ok(()))
            }
        }

        impl TryJoinTrait for () {
            type Output = ();
            type Error = core::convert::Infallible;
            type Future = $StructName;
            fn try_join(self) -> Self::Future {
                $StructName {}
            }
        }
    };

    ($mod_name:ident $StructName:ident $(($F:ident $T:ident))+) => {
        // Private helper module, so the projection types stay out of the
        // public API.
        mod $mod_name {
            use core::mem::ManuallyDrop;

            /// One pinned field per joined future, named after its type
            /// parameter. `ManuallyDrop` lets each future be dropped in place
            /// as soon as it completes.
            #[pin_project::pin_project]
            pub(super) struct Futures<$($F,)+> {$(
                #[pin]
                pub(super) $F: ManuallyDrop<$F>,
            )+}

            /// Position of each future within the tuple (declaration order).
            #[repr(u8)]
            pub(super) enum Indexes { $($F,)+ }

            /// Number of futures in the tuple.
            pub(super) const LEN: usize = [$(Indexes::$F,)+].len();
        }

        /// Future returned by `TryJoin::try_join` on a tuple of futures.
        ///
        /// It polls every future and resolves to `Ok` with a tuple of all
        /// outputs once each future has succeeded. If any future fails first,
        /// it resolves to that future's `Err` immediately. Polling it again
        /// after it has resolved panics.
        #[pin_project(PinnedDrop)]
        #[must_use = "futures do nothing unless you `.await` or poll them"]
        #[allow(non_snake_case)]
        pub struct $StructName<$($F, $T,)+ Err> {
            /// The joined futures. Each one is dropped as soon as it completes.
            #[pin]
            futures: $mod_name::Futures<$($F,)+>,
            /// Successful outputs. Slot `i` holds a value exactly when
            /// `state[i]` is ready.
            outputs: ($(MaybeUninit<$T>,)+),
            /// Per-future lifecycle: pending (future alive), ready (future
            /// dropped, output stored), or none (future and output both gone).
            state: PollArray<{$mod_name::LEN}>,
            /// One waker per future, so only futures that were woken get
            /// polled.
            wakers: WakerArray<{$mod_name::LEN}>,
            /// Number of futures that have resolved.
            completed: usize,
            /// Set once this future has returned `Poll::Ready`.
            consumed: bool,
            _phantom: PhantomData<Err>,
        }

        impl<$($F, $T,)+ Err> Debug for $StructName<$($F, $T,)+ Err>
        where
            $( $F: Future + Debug, )+
        {
            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
                let mut tuple = f.debug_tuple("TryJoin");
                $(
                    // Only futures still marked pending are alive. Completed
                    // ones (whether `Ok` or `Err`) were already destroyed with
                    // `ManuallyDrop::drop`, so formatting them would read a
                    // dropped value.
                    if self.state[$mod_name::Indexes::$F as usize].is_pending() {
                        tuple.field(&self.futures.$F);
                    } else {
                        tuple.field(&format_args!("<completed>"));
                    }
                )+
                tuple.finish()
            }
        }

        #[allow(unused_mut)]
        #[allow(unused_parens)]
        #[allow(unused_variables)]
        impl<$($F, $T,)+ Err> Future for $StructName<$($F, $T,)+ Err>
        where $(
            $F: Future<Output = Result<$T, Err>>,
        )+ {
            type Output = Result<($($T,)+), Err>;

            fn poll(
                self: Pin<&mut Self>, cx: &mut Context<'_>
            ) -> Poll<Self::Output> {
                const LEN: usize = $mod_name::LEN;

                let mut this = self.project();
                assert!(!*this.consumed, "Futures must not be polled after completing");

                let mut futures = this.futures.project();

                let mut readiness = this.wakers.readiness();
                readiness.set_waker(cx.waker());

                for index in 0..LEN {
                    if !readiness.any_ready() {
                        // Nothing else has been woken since the last poll.
                        return Poll::Pending;
                    }
                    if !readiness.clear_ready(index) || this.state[index].is_ready() {
                        // This future was not woken, or it already succeeded.
                        continue;
                    }

                    // Release the readiness handle before polling the child;
                    // it is re-acquired at the bottom of the loop. The clippy
                    // allow presumably covers builds where the handle has no
                    // destructor.
                    #[allow(clippy::drop_non_drop)]
                    drop(readiness);

                    // Poll the child with its own per-index waker.
                    let mut cx = Context::from_waker(this.wakers.get(index).unwrap());

                    // SAFETY: `index` is not ready, and it cannot be none
                    // because that only happens once `consumed` is set (checked
                    // above). So the selected future is still pending and alive.
                    unsafe_poll!(index, this, futures, cx, LEN, $($F,)+);

                    if *this.completed == LEN {
                        // Every future succeeded (an `Err` would have returned
                        // early), so every output slot is initialized. Move
                        // the outputs out, leaving uninitialized slots behind.
                        let out = {
                            let mut out = ($(MaybeUninit::<$T>::uninit(),)+);
                            core::mem::swap(&mut out, this.outputs);
                            let ($($F,)+) = out;
                            unsafe { ($($F.assume_init(),)+) }
                        };

                        // The outputs were moved out and all futures dropped,
                        // so the destructor must touch neither.
                        this.state.set_all_none();
                        *this.consumed = true;

                        return Poll::Ready(Ok(out));
                    }
                    readiness = this.wakers.readiness();
                }

                Poll::Pending
            }
        }

        #[pinned_drop]
        impl<$($F, $T,)+ Err> PinnedDrop for $StructName<$($F, $T,)+ Err> {
            fn drop(self: Pin<&mut Self>) {
                let this = self.project();

                let ($(ref mut $F,)+) = this.outputs;

                let states = this.state;
                let mut futures = this.futures;
                // Release outputs that were stored but never handed out...
                drop_initialized_values!($($F,)+ | states);
                // ...and futures that never completed.
                drop_pending_futures!(states, futures, $($F,)+);
            }
        }

        #[allow(unused_parens)]
        impl<$($F, $T,)+ Err> TryJoinTrait for ($($F,)+)
        where $(
            $F: IntoFuture<Output = Result<$T, Err>>,
        )+ {
            type Output = ($($T,)+);
            type Error = Err;
            type Future = $StructName<$($F::IntoFuture, $T,)+ Err>;

            fn try_join(self) -> Self::Future {
                let ($($F,)+): ($($F,)+) = self;
                $StructName {
                    futures: $mod_name::Futures {$(
                        $F: ManuallyDrop::new($F.into_future()),
                    )+},
                    state: PollArray::new_pending(),
                    outputs: ($(MaybeUninit::<$T>::uninit(),)+),
                    wakers: WakerArray::new(),
                    completed: 0,
                    consumed: false,
                    _phantom: PhantomData,
                }
            }
        }
    };
}
318
319impl_try_join_tuple! { try_join0 TryJoin0 }
320impl_try_join_tuple! { try_join_1 TryJoin1 (A ResA) }
321impl_try_join_tuple! { try_join_2 TryJoin2 (A ResA) (B ResB) }
322impl_try_join_tuple! { try_join_3 TryJoin3 (A ResA) (B ResB) (C ResC) }
323impl_try_join_tuple! { try_join_4 TryJoin4 (A ResA) (B ResB) (C ResC) (D ResD) }
324impl_try_join_tuple! { try_join_5 TryJoin5 (A ResA) (B ResB) (C ResC) (D ResD) (E ResE) }
325impl_try_join_tuple! { try_join_6 TryJoin6 (A ResA) (B ResB) (C ResC) (D ResD) (E ResE) (F ResF) }
326impl_try_join_tuple! { try_join_7 TryJoin7 (A ResA) (B ResB) (C ResC) (D ResD) (E ResE) (F ResF) (G ResG) }
327impl_try_join_tuple! { try_join_8 TryJoin8 (A ResA) (B ResB) (C ResC) (D ResD) (E ResE) (F ResF) (G ResG) (H ResH) }
328impl_try_join_tuple! { try_join_9 TryJoin9 (A ResA) (B ResB) (C ResC) (D ResD) (E ResE) (F ResF) (G ResG) (H ResH) (I ResI) }
329impl_try_join_tuple! { try_join_10 TryJoin10 (A ResA) (B ResB) (C ResC) (D ResD) (E ResE) (F ResF) (G ResG) (H ResH) (I ResI) (J ResJ) }
330impl_try_join_tuple! { try_join_11 TryJoin11 (A ResA) (B ResB) (C ResC) (D ResD) (E ResE) (F ResF) (G ResG) (H ResH) (I ResI) (J ResJ) (K ResK) }
331impl_try_join_tuple! { try_join_12 TryJoin12 (A ResA) (B ResB) (C ResC) (D ResD) (E ResE) (F ResF) (G ResG) (H ResH) (I ResI) (J ResJ) (K ResK) (L ResL) }
332
#[cfg(test)]
mod test {
    use super::*;

    use core::convert::Infallible;
    use core::future;

    /// Every future succeeds: the outputs come back in tuple order.
    #[test]
    fn all_ok() {
        futures_lite::future::block_on(async {
            let text = async { Ok::<_, Infallible>("aaaa") };
            let number = async { Ok::<_, Infallible>(1) };
            let letter = async { Ok::<_, Infallible>('z') };

            let joined = (text, number, letter).try_join().await;
            assert_eq!(joined, Ok(("aaaa", 1, 'z')));
        })
    }

    /// A single failing future makes the whole join fail.
    #[test]
    fn one_err() {
        futures_lite::future::block_on(async {
            let succeeding = future::ready(Ok("hello"));
            let failing = future::ready(Err(()));

            let joined: Result<(_, char), ()> = (succeeding, failing).try_join().await;
            assert_eq!(joined, Err(()));
        })
    }

    /// A future that finishes immediately must not be polled again while a
    /// slower sibling is still running.
    #[test]
    fn issue_135_resume_after_completion() {
        use futures_lite::future::yield_now;
        futures_lite::future::block_on(async {
            let immediate = async { Ok::<_, ()>(()) };
            let delayed = async {
                yield_now().await;
                Ok::<_, ()>(())
            };

            let joined = (immediate, delayed).try_join().await;

            assert_eq!(joined.unwrap(), ((), ()));
        });
    }

    /// Outputs of completed futures are released when the join is dropped
    /// before finishing.
    #[test]
    #[cfg(feature = "std")]
    fn does_not_leak_memory() {
        use core::cell::RefCell;
        use futures_lite::future::pending;

        thread_local! {
            static NOT_LEAKING: RefCell<bool> = const { RefCell::new(false) };
        };

        /// Flips `NOT_LEAKING` to `true` when dropped.
        struct FlipFlagAtDrop;
        impl Drop for FlipFlagAtDrop {
            fn drop(&mut self) {
                NOT_LEAKING.with(|flag| *flag.borrow_mut() = true);
            }
        }

        futures_lite::future::block_on(async {
            let owned_text = future::ready(Result::Ok("memory leak".to_owned()));
            let flag_guard = future::ready(Result::Ok(FlipFlagAtDrop));
            let never_done = pending::<Result<u8, ()>>();

            let unfinished = (owned_text, flag_guard, never_done).try_join();
            _ = futures_lite::future::poll_once(unfinished).await;
        });

        NOT_LEAKING.with(|flag| assert!(*flag.borrow()));
    }
}