futures_concurrency/stream/merge/vec.rs

1use super::Merge as MergeTrait;
2use crate::stream::IntoStream;
3use crate::utils::{self, DynIndexer, PollVec, WakerVec};
4
5#[cfg(all(feature = "alloc", not(feature = "std")))]
6use alloc::vec::Vec;
7
8use core::fmt;
9use core::pin::Pin;
10use core::task::{Context, Poll};
11use futures_core::Stream;
12
13/// A stream that merges multiple streams into a single stream.
14///
15/// This `struct` is created by the [`merge`] method on the [`Merge`] trait. See its
16/// documentation for more.
17///
18/// [`merge`]: trait.Merge.html#method.merge
19/// [`Merge`]: trait.Merge.html
20#[pin_project::pin_project]
21pub struct Merge<S>
22where
23    S: Stream,
24{
25    #[pin]
26    streams: Vec<S>,
27    indexer: DynIndexer,
28    complete: usize,
29    wakers: WakerVec,
30    state: PollVec,
31    done: bool,
32}
33
34impl<S> Merge<S>
35where
36    S: Stream,
37{
38    pub(crate) fn new(streams: Vec<S>) -> Self {
39        let len = streams.len();
40        Self {
41            wakers: WakerVec::new(len),
42            state: PollVec::new_pending(len),
43            indexer: DynIndexer::new(len),
44            streams,
45            complete: 0,
46            done: false,
47        }
48    }
49}
50
51impl<S> fmt::Debug for Merge<S>
52where
53    S: Stream + fmt::Debug,
54{
55    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
56        f.debug_list().entries(self.streams.iter()).finish()
57    }
58}
59
60impl<S> Stream for Merge<S>
61where
62    S: Stream,
63{
64    type Item = S::Item;
65
66    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
67        let mut this = self.project();
68
69        // return early if all streams are complete (handles empty vec)
70        if *this.complete == this.streams.len() {
71            return Poll::Ready(None);
72        }
73
74        let mut readiness = this.wakers.readiness();
75        readiness.set_waker(cx.waker());
76
77        // Iterate over our streams one-by-one. If a stream yields a value,
78        // we exit early. By default we'll return `Poll::Ready(None)`, but
79        // this changes if we encounter a `Poll::Pending`.
80        for index in this.indexer.iter() {
81            if !readiness.any_ready() {
82                // Nothing is ready yet
83                return Poll::Pending;
84            } else if !readiness.clear_ready(index) || this.state[index].is_none() {
85                continue;
86            }
87
88            // unlock readiness so we don't deadlock when polling
89            #[allow(clippy::drop_non_drop)]
90            drop(readiness);
91
92            // Obtain the intermediate waker.
93            let mut cx = Context::from_waker(this.wakers.get(index).unwrap());
94
95            let stream = utils::get_pin_mut_from_vec(this.streams.as_mut(), index).unwrap();
96            match stream.poll_next(&mut cx) {
97                Poll::Ready(Some(item)) => {
98                    // Mark ourselves as ready again because we need to poll for the next item.
99                    this.wakers.readiness().set_ready(index);
100                    return Poll::Ready(Some(item));
101                }
102                Poll::Ready(None) => {
103                    *this.complete += 1;
104                    this.state[index].set_none();
105                    if *this.complete == this.streams.len() {
106                        return Poll::Ready(None);
107                    }
108                }
109                Poll::Pending => {}
110            }
111
112            // Lock readiness so we can use it again
113            readiness = this.wakers.readiness();
114        }
115
116        Poll::Pending
117    }
118}
119
120impl<S> MergeTrait for Vec<S>
121where
122    S: IntoStream,
123{
124    type Item = <Merge<S::IntoStream> as Stream>::Item;
125    type Stream = Merge<S::IntoStream>;
126
127    fn merge(self) -> Self::Stream {
128        Merge::new(self.into_iter().map(|i| i.into_stream()).collect())
129    }
130}
131
#[cfg(test)]
mod tests {
    use alloc::rc::Rc;
    use alloc::vec;
    use core::cell::RefCell;

    use super::*;
    use crate::utils::channel::local_channel;
    use futures::executor::LocalPool;
    use futures::task::LocalSpawnExt;
    use futures_lite::future::block_on;
    use futures_lite::prelude::*;
    use futures_lite::stream;

    use crate::future::join::Join;

    /// Merging an empty vec yields no items.
    #[test]
    fn empty_vec() {
        block_on(async {
            let streams: Vec<stream::Once<i32>> = vec![];
            let mut s = streams.merge();
            let result = s.next().await;
            assert_eq!(result, None);
        })
    }

    #[test]
    fn merge_vec_4() {
        block_on(async {
            let a = stream::once(1);
            let b = stream::once(2);
            let c = stream::once(3);
            let d = stream::once(4);
            let mut s = vec![a, b, c, d].merge();

            let mut counter = 0;
            while let Some(n) = s.next().await {
                counter += n;
            }
            assert_eq!(counter, 10);
        })
    }

    #[test]
    fn merge_vec_2x2() {
        block_on(async {
            let a = stream::repeat(1).take(2);
            let b = stream::repeat(2).take(2);
            let mut s = vec![a, b].merge();

            let mut counter = 0;
            while let Some(n) = s.next().await {
                counter += n;
            }
            assert_eq!(counter, 6);
        })
    }

    /// Streams of different lengths, including an empty one, are all drained:
    /// one stream finishing early must not end the merged stream.
    #[test]
    fn merge_vec_uneven_lengths() {
        block_on(async {
            let a = stream::repeat(1).take(1);
            let b = stream::repeat(10).take(3);
            let c = stream::repeat(100).take(0);
            let mut s = vec![a, b, c].merge();

            let mut counter = 0;
            let mut items = 0;
            while let Some(n) = s.next().await {
                counter += n;
                items += 1;
            }
            assert_eq!(counter, 31);
            assert_eq!(items, 4);
        })
    }

    /// Once every stream has completed, further polls keep returning `None`.
    #[test]
    fn poll_after_completion() {
        block_on(async {
            let mut s = vec![stream::once(1), stream::once(2)].merge();

            let mut counter = 0;
            while let Some(n) = s.next().await {
                counter += n;
            }
            assert_eq!(counter, 3);
            assert_eq!(s.next().await, None);
            assert_eq!(s.next().await, None);
        })
    }

    /// This test case uses channels so we'll have streams that return Pending from time to time.
    ///
    /// The purpose of this test is to make sure we have the waking logic working.
    #[test]
    fn merge_channels() {
        let mut pool = LocalPool::new();

        let done = Rc::new(RefCell::new(false));
        let done2 = done.clone();

        pool.spawner()
            .spawn_local(async move {
                let (send1, receive1) = local_channel();
                let (send2, receive2) = local_channel();
                let (send3, receive3) = local_channel();

                let (count, ()) = (
                    async {
                        vec![receive1, receive2, receive3]
                            .merge()
                            .fold(0, |a, b| a + b)
                            .await
                    },
                    async {
                        for i in 1..=4 {
                            send1.send(i);
                            send2.send(i);
                            send3.send(i);
                        }
                        drop(send1);
                        drop(send2);
                        drop(send3);
                    },
                )
                    .join()
                    .await;

                assert_eq!(count, 30);

                *done2.borrow_mut() = true;
            })
            .unwrap();

        // `run_until_stalled` only returns once no task can make further
        // progress, so a single call must drive the task to completion.
        // Looping here would turn a lost wakeup into a hang instead of a
        // test failure.
        pool.run_until_stalled();
        assert!(*done.borrow(), "merged stream stalled before completing");
    }
}