// futures_concurrency/future/join/tuple.rs

1use super::Join as JoinTrait;
2use crate::utils::{PollArray, WakerArray};
3
4use core::fmt::{self, Debug};
5use core::future::{Future, IntoFuture};
6use core::mem::{ManuallyDrop, MaybeUninit};
7use core::ops::DerefMut;
8use core::pin::Pin;
9use core::task::{Context, Poll};
10
11use pin_project::{pin_project, pinned_drop};
12
/// Expands to the `poll` step for one element of a joined tuple, chosen at runtime.
///
/// For every future field named in the list, this emits a block that polls the field only
/// when its positional index equals `$iteration`. On `Poll::Ready`:
/// 1. the output is written into `$this.outputs.<idx>`;
/// 2. `$this.completed` is incremented;
/// 3. the element's state is marked ready;
/// 4. the future is dropped in place.
///
/// SAFETY: the expansion contains `unsafe` code. Only invoke it for an index whose future
/// has not completed yet (its state is still pending). A completed future has already been
/// dropped in place, so polling it again would be a use-after-drop. If Rust had unsafe
/// macros, this would be one.
//
// Implemented as a tt-muncher. The future names (`$name $($names)*`) and the positional
// indices (`$idx $($indices)*`) are consumed in lock-step, one pair per recursion step.
// The index list is hard-coded to `0..=11`, so tuples of at most 12 elements are supported.
//
// # References
// TT Muncher: https://veykril.github.io/tlborm/decl-macros/patterns/tt-muncher.html
macro_rules! unsafe_poll {
    // recursive step: handle the leading name/index pair, then recurse on the rest
    (@inner $iteration:ident, $this:ident, $futures:ident, $cx:ident, $name:ident $($names:ident)* | $idx:tt $($indices:tt)*) => {
        if $iteration == $idx {
            // SAFETY: `ManuallyDrop<T>` is `repr(transparent)` and never moves its contents,
            // so projecting the pin through it is structural. The caller guarantees this
            // future is still alive (not yet dropped on completion).
            let polled = unsafe {
                $futures
                    .$name
                    .as_mut()
                    .map_unchecked_mut(|slot| slot.deref_mut())
                    .poll(&mut $cx)
            };
            if let Poll::Ready(output) = polled {
                $this.outputs.$idx.write(output);
                *$this.completed += 1;
                $this.state[$idx].set_ready();
                // SAFETY: the state is now "ready", so this future is never polled again
                // and is not dropped by the pending-future cleanup; this is its only drop.
                unsafe { ManuallyDrop::drop($futures.$name.as_mut().get_unchecked_mut()) };
            }
        }
        unsafe_poll!(@inner $iteration, $this, $futures, $cx, $($names)* | $($indices)*);
    };

    // base case: every name consumed; leftover indices are ignored
    (@inner $iteration:ident, $this:ident, $futures:ident, $cx:ident, | $($indices:tt)*) => {};

    // entry point: attach the positional index list (`$LEN` is accepted but unused)
    ($iteration:ident, $this:ident, $futures:ident, $cx:ident, $LEN:ident, $($names:ident,)+) => {
        unsafe_poll!(@inner $iteration, $this, $futures, $cx, $($names)+ | 0 1 2 3 4 5 6 7 8 9 10 11);
    };
}
53
/// Drops every stored output whose state is marked ready, then marks that state as none.
///
/// Invoked as `drop_initialized_values!(a, b, c, | states)`. The output bindings are listed in
/// tuple order and paired with the positional indices `0..=11`, so at most 12 are supported.
macro_rules! drop_initialized_values {
    // recursive step: handle the leading output/index pair, then recurse on the rest
    (@drop $slot:ident, $($slots:ident,)* | $states:expr, $idx:tt, $($indices:tt,)*) => {
        if $states[$idx].is_ready() {
            // SAFETY: the caller guarantees a "ready" state means this output slot holds an
            // initialized value that has not been dropped or moved out; it is dropped once here
            // and the state is cleared so it cannot be dropped again.
            unsafe { $slot.assume_init_drop() };
            $states[$idx].set_none();
        }
        drop_initialized_values!(@drop $($slots,)* | $states, $($indices,)*);
    };

    // base case: every output handled; leftover indices are ignored
    (@drop | $states:expr, $($indices:tt,)*) => {};

    // entry point: attach the positional index list
    ($($slots:ident,)+ | $states:expr) => {
        drop_initialized_values!(@drop $($slots,)+ | $states, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,);
    };
}
75
/// Drops, in place, every future whose state is still pending.
///
/// Invoked as `drop_pending_futures!(states, futures, A, B, C,)`, where `futures` is a pinned
/// mutable reference to a struct holding one `ManuallyDrop` field per listed name. Fields are
/// paired with the positional indices `0..=11` in the order given.
macro_rules! drop_pending_futures {
    // recursive step: handle the leading field/index pair, then recurse on the rest
    (@inner $states:ident, $futures:ident, $field:ident $($fields:ident)* | $idx:tt $($indices:tt)*) => {
        if $states[$idx].is_pending() {
            // SAFETY: the pinned struct is only unwrapped to reach this one field, which is
            // dropped in place and never moved. The caller guarantees a "pending" state means
            // the future was not already dropped on completion, so this is its only drop.
            unsafe {
                let all = $futures.as_mut().get_unchecked_mut();
                ManuallyDrop::drop(&mut all.$field);
            }
        }
        drop_pending_futures!(@inner $states, $futures, $($fields)* | $($indices)*);
    };

    // base case: every field handled; leftover indices are ignored
    (@inner $states:ident, $futures:ident, | $($indices:tt)*) => {};

    // entry point: attach the positional index list
    ($states:ident, $futures:ident, $($fields:ident,)+) => {
        drop_pending_futures!(@inner $states, $futures, $($fields)+ | 0 1 2 3 4 5 6 7 8 9 10 11);
    };
}
98
macro_rules! impl_join_tuple {
    ($mod_name:ident $StructName:ident) => {
        /// A future which joins the empty tuple and completes immediately with `()`.
        ///
        /// This `struct` is created by the [`join`] method on the [`Join`] trait. See
        /// its documentation for more.
        ///
        /// [`join`]: crate::future::Join::join
        /// [`Join`]: crate::future::Join
        #[must_use = "futures do nothing unless you `.await` or poll them"]
        #[allow(non_snake_case)]
        pub struct $StructName {}

        impl fmt::Debug for $StructName {
            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
                f.debug_tuple("Join").finish()
            }
        }

        impl Future for $StructName {
            type Output = ();

            fn poll(
                self: Pin<&mut Self>, _cx: &mut Context<'_>
            ) -> Poll<Self::Output> {
                // Nothing to wait on: every poll is immediately ready.
                Poll::Ready(())
            }
        }

        impl JoinTrait for () {
            type Output = ();
            type Future = $StructName;

            fn join(self) -> Self::Future {
                $StructName {}
            }
        }
    };
    ($mod_name:ident $StructName:ident $($F:ident)+) => {
        mod $mod_name {
            use core::mem::ManuallyDrop;

            // Each future is kept in a `ManuallyDrop` so it can be dropped in place as soon
            // as it completes, independently of the futures that are still running.
            #[pin_project::pin_project]
            pub(super) struct Futures<$($F,)+> {$(
                #[pin]
                pub(super) $F: ManuallyDrop<$F>,
            )+}

            // One fieldless variant per future, in tuple order, so that
            // `Indexes::X as usize` is the tuple position of future `X`.
            #[repr(u8)]
            pub(super) enum Indexes { $($F,)+ }

            // Number of futures in the tuple.
            pub(super) const LEN: usize = [$(Indexes::$F,)+].len();
        }

        /// Waits for every future in a tuple to complete, yielding a tuple of their outputs.
        ///
        /// The futures may have different types. Each one is polled only after it has been
        /// woken, and is dropped as soon as it completes.
        ///
        /// This `struct` is created by the [`join`] method on the [`Join`] trait. See
        /// its documentation for more.
        ///
        /// [`join`]: crate::future::Join::join
        /// [`Join`]: crate::future::Join
        #[pin_project(PinnedDrop)]
        #[must_use = "futures do nothing unless you `.await` or poll them"]
        #[allow(non_snake_case)]
        pub struct $StructName<$($F: Future),+> {
            // The joined futures; each one is dropped in place once it completes.
            #[pin]
            futures: $mod_name::Futures<$($F,)+>,
            // Output slots; slot `i` is initialized exactly while `state[i]` is ready.
            outputs: ($(MaybeUninit<$F::Output>,)+),
            // Tracks, per future, whether the future is still alive (pending), whether its
            // output is stored (ready), or whether there is nothing left to drop (none).
            state: PollArray<{$mod_name::LEN}>,
            // One intermediate waker per future, used to track which futures were woken.
            wakers: WakerArray<{$mod_name::LEN}>,
            // Number of futures that have completed so far.
            completed: usize,
        }

        impl<$($F),+> Debug for $StructName<$($F),+>
        where
            $( $F: Future + Debug, )+
        {
            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
                let mut tuple = f.debug_tuple("Join");
                $(
                    // A completed future has already been dropped in place by the poll
                    // step; reading it here would be a use-after-drop, so only futures
                    // that are still pending are formatted.
                    if self.state[$mod_name::Indexes::$F as usize].is_pending() {
                        tuple.field(&self.futures.$F);
                    } else {
                        tuple.field(&format_args!("<completed>"));
                    }
                )+
                tuple.finish()
            }
        }

        // Lint allowances for macro-generated code whose shape varies with tuple arity.
        #[allow(unused_mut)]
        #[allow(unused_parens)]
        #[allow(unused_variables)]
        impl<$($F: Future),+> Future for $StructName<$($F),+> {
            type Output = ($($F::Output,)+);

            fn poll(
                self: Pin<&mut Self>, cx: &mut Context<'_>
            ) -> Poll<Self::Output> {
                const LEN: usize = $mod_name::LEN;

                let mut this = self.project();
                let not_yet_completed = *this.completed != LEN;
                assert!(not_yet_completed, "Futures must not be polled after completing");

                let mut futures = this.futures.project();

                let mut readiness = this.wakers.readiness();
                readiness.set_waker(cx.waker());

                for index in 0..LEN {
                    if !readiness.any_ready() {
                        // nothing ready yet
                        return Poll::Pending;
                    }
                    if !readiness.clear_ready(index) || this.state[index].is_ready() {
                        // future not ready yet or already polled to completion, skip
                        continue;
                    }

                    // unlock readiness so we don't deadlock when polling
                    #[allow(clippy::drop_non_drop)]
                    drop(readiness);

                    // obtain the intermediate waker
                    let mut cx = Context::from_waker(this.wakers.get(index).unwrap());

                    // generate the needed code to poll `futures.{index}`
                    // SAFETY: the future's state should be "pending", so it's safe to poll
                    unsafe_poll!(index, this, futures, cx, LEN, $($F,)+);

                    if *this.completed == LEN {
                        let out = {
                            let mut out = ($(MaybeUninit::<$F::Output>::uninit(),)+);
                            core::mem::swap(&mut out, this.outputs);
                            let ($($F,)+) = out;
                            // SAFETY: every future completed, so every output slot was written.
                            unsafe { ($($F.assume_init(),)+) }
                        };

                        // Outputs have been moved out; nothing is left for `drop` to release.
                        this.state.set_all_none();

                        return Poll::Ready(out);
                    }
                    readiness = this.wakers.readiness();
                }

                Poll::Pending
            }
        }

        #[pinned_drop]
        impl<$($F: Future),+> PinnedDrop for $StructName<$($F),+> {
            fn drop(self: Pin<&mut Self>) {
                let this = self.project();

                // Match ergonomics bind each name as `&mut MaybeUninit<_>`; an explicit
                // `ref mut` here is redundant (and rejected by edition 2024).
                let ($($F,)+) = this.outputs;

                let states = this.state;
                let mut futures = this.futures;
                // Release outputs that were produced but never handed out...
                drop_initialized_values!($($F,)+ | states);
                // ...and futures that never completed.
                drop_pending_futures!(states, futures, $($F,)+);
            }
        }

        #[allow(unused_parens)]
        impl<$($F),+> JoinTrait for ($($F,)+)
        where $(
            $F: IntoFuture,
        )+ {
            type Output = ($($F::Output,)*);
            type Future = $StructName<$($F::IntoFuture),*>;

            fn join(self) -> Self::Future {
                let ($($F,)+): ($($F,)+) = self;
                $StructName {
                    futures: $mod_name::Futures {$($F: ManuallyDrop::new($F.into_future()),)+},
                    state: PollArray::new_pending(),
                    outputs: ($(MaybeUninit::<$F::Output>::uninit(),)+),
                    wakers: WakerArray::new(),
                    completed: 0,
                }
            }
        }
    };
}
279
280impl_join_tuple! { join0 Join0 }
281impl_join_tuple! { join1 Join1 A }
282impl_join_tuple! { join2 Join2 A B }
283impl_join_tuple! { join3 Join3 A B C }
284impl_join_tuple! { join4 Join4 A B C D }
285impl_join_tuple! { join5 Join5 A B C D E }
286impl_join_tuple! { join6 Join6 A B C D E F }
287impl_join_tuple! { join7 Join7 A B C D E F G }
288impl_join_tuple! { join8 Join8 A B C D E F G H }
289impl_join_tuple! { join9 Join9 A B C D E F G H I }
290impl_join_tuple! { join10 Join10 A B C D E F G H I J }
291impl_join_tuple! { join11 Join11 A B C D E F G H I J K }
292impl_join_tuple! { join12 Join12 A B C D E F G H I J K L }
293
#[cfg(test)]
mod test {
    use super::*;
    use core::future;
    use futures_lite::future::block_on;

    #[test]
    #[allow(clippy::unit_cmp)]
    fn join_0() {
        block_on(async {
            assert_eq!(().join().await, ());
        });
    }

    #[test]
    fn join_1() {
        block_on(async {
            let greeting = future::ready("hello");
            let joined = (greeting,).join().await;
            assert_eq!(joined, ("hello",));
        });
    }

    #[test]
    fn join_2() {
        block_on(async {
            let greeting = future::ready("hello");
            let number = future::ready(12);
            let joined = (greeting, number).join().await;
            assert_eq!(joined, ("hello", 12));
        });
    }

    #[test]
    fn join_3() {
        block_on(async {
            let first = future::ready("hello");
            let second = future::ready("world");
            let number = future::ready(12);
            let joined = (first, second, number).join().await;
            assert_eq!(joined, ("hello", "world", 12));
        });
    }

    #[test]
    #[cfg(feature = "std")]
    fn does_not_leak_memory() {
        use core::cell::Cell;
        use futures_lite::future::{pending, poll_once};

        thread_local! {
            static NOT_LEAKING: Cell<bool> = const { Cell::new(false) };
        }

        /// Sets the thread-local flag when dropped, proving the output was released.
        struct FlipFlagAtDrop;

        impl Drop for FlipFlagAtDrop {
            fn drop(&mut self) {
                NOT_LEAKING.with(|flag| flag.set(true));
            }
        }

        block_on(async {
            // Miri reports this `String` as leaked if the completed output is never dropped.
            let string = future::ready("memory leak".to_owned());
            // The thread-local flag only flips if this completed output is dropped.
            let flip = future::ready(FlipFlagAtDrop);

            // The third future never completes, so the join is dropped while partially done.
            let leak = (string, flip, pending::<u8>()).join();
            _ = poll_once(leak).await;
        });

        assert!(NOT_LEAKING.with(Cell::get));
    }
}
369}