futures_concurrency/stream/merge/array.rs

1use super::Merge as MergeTrait;
2use crate::stream::IntoStream;
3use crate::utils::{self, Indexer, PollArray, WakerArray};
4
5use core::fmt;
6use core::pin::Pin;
7use core::task::{Context, Poll};
8use futures_core::Stream;
9
10/// A stream that merges multiple streams into a single stream.
11///
12/// This `struct` is created by the [`merge`] method on the [`Merge`] trait. See its
13/// documentation for more.
14///
15/// [`merge`]: trait.Merge.html#method.merge
16/// [`Merge`]: trait.Merge.html
17#[pin_project::pin_project]
18pub struct Merge<S, const N: usize>
19where
20    S: Stream,
21{
22    #[pin]
23    streams: [S; N],
24    indexer: Indexer<N>,
25    wakers: WakerArray<N>,
26    state: PollArray<N>,
27    complete: usize,
28    done: bool,
29}
30
31impl<S, const N: usize> Merge<S, N>
32where
33    S: Stream,
34{
35    pub(crate) fn new(streams: [S; N]) -> Self {
36        Self {
37            streams,
38            indexer: Indexer::new(),
39            wakers: WakerArray::new(),
40            state: PollArray::new_pending(),
41            complete: 0,
42            done: false,
43        }
44    }
45}
46
47impl<S, const N: usize> fmt::Debug for Merge<S, N>
48where
49    S: Stream + fmt::Debug,
50{
51    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
52        f.debug_list().entries(self.streams.iter()).finish()
53    }
54}
55
56impl<S, const N: usize> Stream for Merge<S, N>
57where
58    S: Stream,
59{
60    type Item = S::Item;
61
62    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
63        let mut this = self.project();
64
65        // return early if all streams are complete (handles empty array)
66        if *this.complete == N {
67            return Poll::Ready(None);
68        }
69
70        let mut readiness = this.wakers.readiness();
71        readiness.set_waker(cx.waker());
72
73        // Iterate over our streams one-by-one. If a stream yields a value,
74        // we exit early. By default we'll return `Poll::Ready(None)`, but
75        // this changes if we encounter a `Poll::Pending`.
76        for index in this.indexer.iter() {
77            if !readiness.any_ready() {
78                // Nothing is ready yet
79                return Poll::Pending;
80            } else if !readiness.clear_ready(index) || this.state[index].is_none() {
81                continue;
82            }
83
84            // unlock readiness so we don't deadlock when polling
85            #[allow(clippy::drop_non_drop)]
86            drop(readiness);
87
88            // Obtain the intermediate waker.
89            let mut cx = Context::from_waker(this.wakers.get(index).unwrap());
90
91            let stream = utils::get_pin_mut(this.streams.as_mut(), index).unwrap();
92            match stream.poll_next(&mut cx) {
93                Poll::Ready(Some(item)) => {
94                    // Mark ourselves as ready again because we need to poll for the next item.
95                    this.wakers.readiness().set_ready(index);
96                    return Poll::Ready(Some(item));
97                }
98                Poll::Ready(None) => {
99                    *this.complete += 1;
100                    this.state[index].set_none();
101                    if *this.complete == this.streams.len() {
102                        return Poll::Ready(None);
103                    }
104                }
105                Poll::Pending => {}
106            }
107
108            // Lock readiness so we can use it again
109            readiness = this.wakers.readiness();
110        }
111
112        Poll::Pending
113    }
114}
115
116impl<S, const N: usize> MergeTrait for [S; N]
117where
118    S: IntoStream,
119{
120    type Item = <Merge<S::IntoStream, N> as Stream>::Item;
121    type Stream = Merge<S::IntoStream, N>;
122
123    fn merge(self) -> Self::Stream {
124        Merge::new(self.map(|i| i.into_stream()))
125    }
126}
127
#[cfg(test)]
mod tests {
    use super::*;
    use futures_lite::future::block_on;
    use futures_lite::prelude::*;
    use futures_lite::stream;

    #[test]
    fn empty_array() {
        block_on(async {
            let streams: [stream::Once<i32>; 0] = [];
            let mut s = streams.merge();
            let result = s.next().await;
            assert_eq!(result, None);
        })
    }

    #[test]
    fn merge_array_4() {
        block_on(async {
            let a = stream::once(1);
            let b = stream::once(2);
            let c = stream::once(3);
            let d = stream::once(4);
            let mut s = [a, b, c, d].merge();

            let mut counter = 0;
            while let Some(n) = s.next().await {
                counter += n;
            }
            assert_eq!(counter, 10);
        })
    }

    #[test]
    fn merge_array_2x2() {
        block_on(async {
            let a = stream::repeat(1).take(2);
            let b = stream::repeat(2).take(2);
            let mut s = [a, b].merge();

            let mut counter = 0;
            while let Some(n) = s.next().await {
                counter += n;
            }
            assert_eq!(counter, 6);
        })
    }

    /// Once every stream has completed, polling the merge again must keep
    /// returning `None` instead of touching the finished streams.
    #[test]
    fn exhausted_merge_keeps_returning_none() {
        block_on(async {
            let mut s = [stream::once(1), stream::once(2)].merge();

            let mut sum = 0;
            while let Some(n) = s.next().await {
                sum += n;
            }
            assert_eq!(sum, 3);
            assert_eq!(s.next().await, None);
            assert_eq!(s.next().await, None);
        })
    }

    /// This test case uses channels so we'll have streams that return Pending from time to time.
    ///
    /// The purpose of this test is to make sure we have the waking logic working.
    #[test]
    #[cfg(feature = "alloc")]
    fn merge_channels() {
        use alloc::rc::Rc;
        use core::cell::RefCell;
        use futures::executor::LocalPool;
        use futures::task::LocalSpawnExt;

        use crate::future::join::Join;
        use crate::utils::channel::local_channel;

        let mut pool = LocalPool::new();

        let done = Rc::new(RefCell::new(false));
        let done2 = done.clone();

        pool.spawner()
            .spawn_local(async move {
                let (send1, receive1) = local_channel();
                let (send2, receive2) = local_channel();
                let (send3, receive3) = local_channel();

                let (count, ()) = (
                    async {
                        [receive1, receive2, receive3]
                            .merge()
                            .fold(0, |a, b| a + b)
                            .await
                    },
                    async {
                        for i in 1..=4 {
                            send1.send(i);
                            send2.send(i);
                            send3.send(i);
                        }
                        drop(send1);
                        drop(send2);
                        drop(send3);
                    },
                )
                    .join()
                    .await;

                assert_eq!(count, 30);

                *done2.borrow_mut() = true;
            })
            .unwrap();

        // Every wake-up in this test originates from tasks on this pool, so once
        // the pool stalls no further progress is possible. Asserting (instead of
        // looping on `run_until_stalled`) turns a lost wake-up into a test
        // failure rather than an infinite hang.
        pool.run_until_stalled();
        assert!(
            *done.borrow(),
            "merge_channels: task stalled before the merged stream completed"
        );
    }
}