futures_concurrency/stream/zip/
tuple.rs

1use core::fmt;
2use core::mem::MaybeUninit;
3use core::pin::Pin;
4use core::task::{Context, Poll};
5
6use futures_core::Stream;
7
8use super::Zip;
9use crate::utils::{PollArray, WakerArray};
10
/// Generates a `Zip` stream implementation for one tuple arity.
///
/// * `$mod_name` - name of a private helper module holding the per-arity
///   output buffer, the per-stream index constants and `LEN`.
/// * `$StructName` - name of the generated stream type.
/// * `$F ...` - one identifier per stream in the tuple; each identifier is
///   used both as the generic type parameter and as the field name.
macro_rules! impl_zip_for_tuple {
    ($mod_name: ident $StructName: ident $($F: ident)+) => {
        mod $mod_name {
            /// Holds at most one buffered item per stream.
            ///
            /// Which slots are initialized is tracked outside of this type,
            /// by the `state` array of the stream struct.
            pub(super) struct Output<$($F,)+>
            where
                $($F: super::Stream,)+
            {
                $(pub(super) $F: core::mem::MaybeUninit<<$F as super::Stream>::Item>,)+
            }

            impl<$($F,)+> Default for Output<$($F,)+>
            where
                $($F: super::Stream,)+
            {
                /// Creates a buffer in which every slot is uninitialized.
                fn default() -> Self {
                    Self {
                        $($F: core::mem::MaybeUninit::uninit(),)+
                    }
                }
            }

            // Only used to assign consecutive `usize` indexes to the streams.
            #[repr(usize)]
            enum Indexes {
                $($F,)+
            }

            // One index constant per stream, named after the stream.
            $(
                pub(super) const $F: usize = Indexes::$F as usize;
            )+

            /// Number of streams in the tuple.
            pub(super) const LEN: usize = [$(Indexes::$F,)+].len();
        }

        /// A stream that yields tuples made of one item from each inner stream.
        ///
        /// Created by the [`Zip::zip`] implementation for tuples. The stream
        /// ends as soon as any inner stream ends.
        ///
        /// # Panics
        ///
        /// Polling this stream again after it has returned `None` panics.
        #[pin_project::pin_project(PinnedDrop)]
        pub struct $StructName<$($F,)+>
        where
            $($F: Stream,)+
        {
            // Set once an inner stream has returned `None`.
            done: bool,
            // Items received so far for the tuple currently being assembled.
            output: $mod_name::Output<$($F,)+>,
            // Per stream: whether `output` currently holds an item for it.
            state: PollArray<{ $mod_name::LEN }>,
            // Per-stream wakers plus the shared readiness bookkeeping.
            wakers: WakerArray<{ $mod_name::LEN }>,
            $( #[pin] $F: $F,)+

        }

        impl<$($F,)+> fmt::Debug for $StructName<$($F,)+>
        where
            $($F: Stream + fmt::Debug,)+
        {
            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
                f.debug_tuple("Zip")
                    $(.field(&self.$F))+
                    .finish()
            }
        }

        impl<$($F,)+> Stream for $StructName<$($F,)+>
        where
            $($F: Stream,)+
        {
            type Item = (
                $(<$F as Stream>::Item,)+
            );

            fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
                let mut this = self.project();

                const LEN: usize = $mod_name::LEN;

                assert!(!*this.done, "Stream should not be polled after completion");

                let mut readiness = this.wakers.readiness();
                readiness.set_waker(cx.waker());

                for index in 0..LEN {
                    if !readiness.any_ready() {
                        // Nothing is ready yet
                        return Poll::Pending;
                    } else if this.state[index].is_ready() || !readiness.clear_ready(index) {
                        // We already have data stored for this stream,
                        // Or this waker isn't ready yet
                        continue;
                    }

                    // unlock readiness so we don't deadlock when polling
                    // (the lint is silenced because the value may not implement `Drop`)
                    #[allow(clippy::drop_non_drop)]
                    drop(readiness);

                    // Obtain the intermediate waker.
                    let mut cx = Context::from_waker(this.wakers.get(index).unwrap());

                    let all_ready = match index {
                        $(
                            $mod_name::$F => {
                                // The projected field is already pinned, so a
                                // plain reborrow is enough; no `unsafe` needed.
                                match this.$F.as_mut().poll_next(&mut cx) {
                                    Poll::Pending => false,
                                    Poll::Ready(None) => {
                                        // If one stream returns `None`, we can no longer return
                                        // pairs - meaning the stream is over.
                                        *this.done = true;
                                        return Poll::Ready(None);
                                    }
                                    Poll::Ready(Some(item)) => {
                                        // The slot is known to be empty here: streams whose
                                        // state is already ready are skipped above.
                                        this.output.$F = MaybeUninit::new(item);
                                        this.state[$mod_name::$F].set_ready();

                                        this.state.iter().all(|state| state.is_ready())
                                    }
                                }
                            },
                        )+
                        _ => unreachable!(),
                    };

                    if all_ready {
                        // Reset the stream's state for the next tuple.
                        readiness = this.wakers.readiness();
                        readiness.set_all_ready();
                        this.state.set_all_pending();

                        // Take the output, leaving an all-uninitialized buffer behind.
                        let $mod_name::Output { $($F,)+ } = core::mem::take(this.output);

                        return Poll::Ready(Some((
                            // SAFETY: we just validated all our data is populated, meaning
                            // we can assume this is initialized.
                            $(unsafe { $F.assume_init() },)+
                        )));
                    }

                    // Lock readiness so we can use it again
                    readiness = this.wakers.readiness();
                }

                Poll::Pending
            }
        }

        impl<$($F,)+> Zip for ($($F,)+)
        where
            $($F: Stream,)+
        {
            type Item = (
                $(<$F as Stream>::Item,)+
            );

            type Stream = $StructName<$($F,)+>;

            fn zip(self) -> Self::Stream {
                let ($($F,)*): ($($F,)*) = self;
                Self::Stream {
                    done: false,
                    output: Default::default(),
                    state: PollArray::new_pending(),
                    wakers: WakerArray::new(),
                    $($F,)+
                }
            }
        }

        #[pin_project::pinned_drop]
        impl<$($F,)+> PinnedDrop for $StructName<$($F,)+>
        where
            $($F: Stream,)+
        {
            // Drops any items that were buffered but never handed out as part
            // of a complete tuple.
            fn drop(self: Pin<&mut Self>) {
                let this = self.project();

                $(
                    if this.state[$mod_name::$F].is_ready() {
                        // SAFETY: we've just filtered down to *only* the initialized values.
                        unsafe { this.output.$F.assume_init_drop() };
                    }
                )+
            }
        }
    };
}
198
// Instantiate the zip implementation for tuples of one through twelve streams.
// Each invocation names its helper module, its stream struct, and one
// identifier per tuple element.
impl_zip_for_tuple! { zip_1 Zip1 A }
impl_zip_for_tuple! { zip_2 Zip2 A B }
impl_zip_for_tuple! { zip_3 Zip3 A B C }
impl_zip_for_tuple! { zip_4 Zip4 A B C D }
impl_zip_for_tuple! { zip_5 Zip5 A B C D E }
impl_zip_for_tuple! { zip_6 Zip6 A B C D E F }
impl_zip_for_tuple! { zip_7 Zip7 A B C D E F G }
impl_zip_for_tuple! { zip_8 Zip8 A B C D E F G H }
impl_zip_for_tuple! { zip_9 Zip9 A B C D E F G H I }
impl_zip_for_tuple! { zip_10 Zip10 A B C D E F G H I J }
impl_zip_for_tuple! { zip_11 Zip11 A B C D E F G H I J K }
impl_zip_for_tuple! { zip_12 Zip12 A B C D E F G H I J K L }
211
#[cfg(test)]
mod tests {
    use futures_lite::future::block_on;
    use futures_lite::prelude::*;
    use futures_lite::stream;

    use crate::stream::Zip;

    // Zipping three two-item streams yields two triples and then ends.
    #[test]
    fn zip_tuple_3() {
        block_on(async {
            let numbers = stream::repeat(1).take(2);
            let greetings = stream::repeat("hello").take(2);
            let pairs = stream::repeat(("a", "b")).take(2);
            let mut zipped = Zip::zip((numbers, greetings, pairs));

            let expected = (1, "hello", ("a", "b"));
            for _ in 0..2 {
                assert_eq!(zipped.next().await, Some(expected));
            }
            assert_eq!(zipped.next().await, None);
        })
    }
}