futures_concurrency/stream/merge/tuple.rs

1use super::Merge as MergeTrait;
2use crate::stream::IntoStream;
3use crate::utils::{self, PollArray, WakerArray};
4
5use core::fmt;
6use core::pin::Pin;
7use core::task::{Context, Poll};
8use futures_core::Stream;
9
/// Polls one member of a tuple of streams if its index matches the current iteration.
///
/// Parameters:
/// - `$stream_idx`: the index assigned to this stream.
/// - `$iteration`: the index currently being visited by the caller's loop.
/// - `$this`: the projected merge state; must expose `wakers`, `state` and `completed`.
/// - `$streams . $stream_member`: the stream to poll. It must be a pinned projection
///   (`Pin<&mut S>`), which is what `#[pin]` fields yield after `project()`.
/// - `$cx`: a `Context` binding (not a reference) used to poll the stream.
/// - `$len_streams`: the total number of streams in the merge.
///
/// Early-returns `Poll::Ready(Some(item))` when the stream yields an item. It also
/// early-returns `Poll::Ready(None)` when this stream's completion makes every stream
/// complete. Otherwise it falls through so the caller can continue its loop.
macro_rules! poll_stream {
    ($stream_idx:tt, $iteration:ident, $this:ident, $streams:ident . $stream_member:ident, $cx:ident, $len_streams:ident) => {
        if $stream_idx == $iteration {
            // The member is already `Pin<&mut S>`, so reborrow it with `as_mut` rather than
            // re-pinning a reference to the `Pin` through `unsafe`.
            match $streams.$stream_member.as_mut().poll_next(&mut $cx) {
                Poll::Ready(Some(item)) => {
                    // Mark ourselves as ready again because we need to poll for the next item.
                    $this.wakers.readiness().set_ready($stream_idx);
                    return Poll::Ready(Some(item));
                }
                Poll::Ready(None) => {
                    // This stream is exhausted: count it and mark it so it is never polled again.
                    *$this.completed += 1;
                    $this.state[$stream_idx].set_none();
                    if *$this.completed == $len_streams {
                        return Poll::Ready(None);
                    }
                }
                Poll::Pending => {}
            }
        }
    };
}
31
/// Generates a merge stream type for one tuple arity, plus the `Merge` trait impl for
/// the corresponding tuple type.
///
/// Invocation forms:
/// - `impl_merge_tuple! { ignored StructName }`: the empty tuple `()`, whose merged stream
///   ends on its first poll. The first ident is unused.
/// - `impl_merge_tuple! { mod_name StructName A B ... }`: a non-empty tuple. The generic
///   idents double as field names of the private `mod_name::Streams` struct and as variants
///   of `mod_name::Indexes`, whose discriminants are the per-stream indexes.
macro_rules! impl_merge_tuple {
    ($ignore:ident $StructName:ident) => {
        /// A stream that merges multiple streams into a single stream.
        ///
        /// This `struct` is created by the [`merge`] method on the [`Merge`] trait. See its
        /// documentation for more.
        ///
        /// [`merge`]: trait.Merge.html#method.merge
        /// [`Merge`]: trait.Merge.html
        pub struct $StructName {}

        impl fmt::Debug for $StructName {
            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
                f.debug_tuple("Merge").finish()
            }
        }

        impl Stream for $StructName {
            type Item = core::convert::Infallible; // TODO: convert to `never` type in the stdlib

            // Merging zero streams has nothing to yield, so the stream is finished immediately.
            fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
                Poll::Ready(None)
            }
        }

        impl MergeTrait for () {
            type Item = core::convert::Infallible; // TODO: convert to `never` type in the stdlib
            type Stream = $StructName;

            fn merge(self) -> Self::Stream {
                $StructName {}
            }
        }
    };
    ($mod_name:ident $StructName:ident $($F:ident)+) => {
        mod $mod_name {
            // The generic idents (`A`, `B`, ...) are reused as field names, which are not
            // snake case.
            #[pin_project::pin_project]
            #[allow(non_snake_case)]
            pub(super) struct Streams<$($F,)+> { $(#[pin] pub(super) $F: $F),+ }

            /// One variant per stream; the discriminant is that stream's index.
            #[repr(usize)]
            pub(super) enum Indexes { $($F),+ }

            /// Number of streams being merged.
            pub(super) const LEN: usize = [$(Indexes::$F),+].len();
        }

        /// A stream that merges multiple streams into a single stream.
        ///
        /// This `struct` is created by the [`merge`] method on the [`Merge`] trait. See its
        /// documentation for more.
        ///
        /// [`merge`]: trait.Merge.html#method.merge
        /// [`Merge`]: trait.Merge.html
        #[pin_project::pin_project]
        pub struct $StructName<T, $($F),*>
        where $(
            $F: Stream<Item = T>,
        )* {
            #[pin] streams: $mod_name::Streams<$($F,)+>,
            // Yields the order in which stream indexes are visited on each poll
            // (see `utils::Indexer` for the exact policy).
            indexer: utils::Indexer,
            // Per-stream wakers, plus the readiness flags recording which streams to poll next.
            wakers: WakerArray<{$mod_name::LEN}>,
            // Per-stream state; an entry is set to none once its stream has finished.
            state: PollArray<{$mod_name::LEN}>,
            // Number of streams that have returned `Poll::Ready(None)`.
            completed: usize,
        }

        impl<T, $($F),*> fmt::Debug for $StructName<T, $($F),*>
        where
            $( $F: Stream<Item = T> + fmt::Debug, )*
        {
            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
                f.debug_tuple("Merge")
                    $( .field(&self.streams.$F) )* // Hides implementation detail of Streams struct
                    .finish()
            }
        }

        impl<T, $($F),*> Stream for $StructName<T, $($F),*>
        where $(
            $F: Stream<Item = T>,
        )* {
            type Item = T;

            fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
                let this = self.project();

                let mut readiness = this.wakers.readiness();
                readiness.set_waker(cx.waker());

                // The merged stream ends once this many inner streams have ended.
                const LEN: usize = $mod_name::LEN;

                let mut streams = this.streams.project();

                // Visit the streams in the order given by the indexer. Poll only those that
                // are marked ready and have not finished yet. The first item yielded is
                // returned immediately. `Poll::Ready(None)` is returned only once every
                // stream has finished. Otherwise we return `Poll::Pending` and rely on the
                // per-stream wakers to schedule us again.
                for index in this.indexer.iter() {
                    if !readiness.any_ready() {
                        // Nothing is ready yet
                        return Poll::Pending;
                    } else if !readiness.clear_ready(index) || this.state[index].is_none() {
                        continue;
                    }

                    // Unlock readiness so we don't deadlock when polling: the
                    // `Ready(Some)` path below locks it again to re-mark the stream.
                    #[allow(clippy::drop_non_drop)]
                    drop(readiness);

                    // Obtain the intermediate waker for this stream.
                    let mut cx = Context::from_waker(this.wakers.get(index).unwrap());

                    $(
                        let stream_index = $mod_name::Indexes::$F as usize;
                        poll_stream!(
                            stream_index,
                            index,
                            this,
                            streams . $F,
                            cx,
                            LEN
                        );
                    )+

                    // Lock readiness so we can use it again
                    readiness = this.wakers.readiness();
                }

                Poll::Pending
            }
        }

        impl<T, $($F),*> MergeTrait for ($($F,)*)
        where $(
            $F: IntoStream<Item = T>,
        )* {
            type Item = T;
            type Stream = $StructName<T, $($F::IntoStream),*>;

            // The tuple is destructured into bindings named after the generic idents.
            #[allow(non_snake_case)]
            fn merge(self) -> Self::Stream {
                let ($($F,)*): ($($F,)*) = self;
                $StructName {
                    streams: $mod_name::Streams { $($F: $F.into_stream()),+ },
                    indexer: utils::Indexer::new($mod_name::LEN),
                    wakers: WakerArray::new(),
                    state: PollArray::new_pending(),
                    completed: 0,
                }
            }
        }
    };
}
181
182impl_merge_tuple! { merge0 Merge0  }
183impl_merge_tuple! { merge1 Merge1  A }
184impl_merge_tuple! { merge2 Merge2  A B }
185impl_merge_tuple! { merge3 Merge3  A B C }
186impl_merge_tuple! { merge4 Merge4  A B C D }
187impl_merge_tuple! { merge5 Merge5  A B C D E }
188impl_merge_tuple! { merge6 Merge6  A B C D E F }
189impl_merge_tuple! { merge7 Merge7  A B C D E F G }
190impl_merge_tuple! { merge8 Merge8  A B C D E F G H }
191impl_merge_tuple! { merge9 Merge9  A B C D E F G H I }
192impl_merge_tuple! { merge10 Merge10 A B C D E F G H I J }
193impl_merge_tuple! { merge11 Merge11 A B C D E F G H I J K }
194impl_merge_tuple! { merge12 Merge12 A B C D E F G H I J K L }
195
#[cfg(test)]
mod tests {
    use super::*;
    use futures_lite::future::block_on;
    use futures_lite::prelude::*;
    use futures_lite::stream;

    /// Merging the empty tuple must produce a stream that ends without yielding anything.
    #[test]
    fn merge_tuple_0() {
        block_on(async {
            let items_seen = ().merge().count().await;
            assert_eq!(items_seen, 0);
        })
    }

    #[test]
    fn merge_tuple_1() {
        block_on(async {
            let total = (stream::once(1),)
                .merge()
                .fold(0, |sum, n| sum + n)
                .await;
            assert_eq!(total, 1);
        })
    }

    #[test]
    fn merge_tuple_2() {
        block_on(async {
            let total = (stream::once(1), stream::once(2))
                .merge()
                .fold(0, |sum, n| sum + n)
                .await;
            assert_eq!(total, 3);
        })
    }

    #[test]
    fn merge_tuple_3() {
        block_on(async {
            let total = (stream::once(1), stream::once(2), stream::once(3))
                .merge()
                .fold(0, |sum, n| sum + n)
                .await;
            assert_eq!(total, 6);
        })
    }

    #[test]
    fn merge_tuple_4() {
        block_on(async {
            let total = (
                stream::once(1),
                stream::once(2),
                stream::once(3),
                stream::once(4),
            )
                .merge()
                .fold(0, |sum, n| sum + n)
                .await;
            assert_eq!(total, 10);
        })
    }

    /// Exercises the waking logic: channel receivers return `Pending` until their sender
    /// has pushed an item, so the merge must be woken correctly to see every value.
    #[test]
    #[cfg(feature = "alloc")]
    fn merge_channels() {
        use alloc::rc::Rc;
        use core::cell::RefCell;

        use futures::executor::LocalPool;
        use futures::task::LocalSpawnExt;

        use crate::future::Join;
        use crate::utils::channel::local_channel;

        let finished = Rc::new(RefCell::new(false));
        let finished_in_task = Rc::clone(&finished);

        let mut pool = LocalPool::new();

        pool.spawner()
            .spawn_local(async move {
                let (tx1, rx1) = local_channel();
                let (tx2, rx2) = local_channel();
                let (tx3, rx3) = local_channel();

                let consumer = async { (rx1, rx2, rx3).merge().fold(0, |sum, n| sum + n).await };

                let producer = async {
                    for i in 1..=4 {
                        for tx in [&tx1, &tx2, &tx3] {
                            tx.send(i);
                        }
                    }
                    drop(tx1);
                    drop(tx2);
                    drop(tx3);
                };

                let (count, ()) = (consumer, producer).join().await;
                assert_eq!(count, 30);

                *finished_in_task.borrow_mut() = true;
            })
            .unwrap();

        while !*finished.borrow() {
            pool.run_until_stalled()
        }
    }
}