futures_concurrency/stream/merge/
vec.rs

1use super::Merge as MergeTrait;
2use crate::stream::IntoStream;
3use crate::utils::{self, Indexer, PollVec, WakerVec};
4
5#[cfg(all(feature = "alloc", not(feature = "std")))]
6use alloc::vec::Vec;
7
8use core::fmt;
9use core::pin::Pin;
10use core::task::{Context, Poll};
11use futures_core::Stream;
12
13/// A stream that merges multiple streams into a single stream.
14///
15/// This `struct` is created by the [`merge`] method on the [`Merge`] trait. See its
16/// documentation for more.
17///
18/// [`merge`]: trait.Merge.html#method.merge
19/// [`Merge`]: trait.Merge.html
20#[pin_project::pin_project]
21pub struct Merge<S>
22where
23    S: Stream,
24{
25    #[pin]
26    streams: Vec<S>,
27    indexer: Indexer,
28    complete: usize,
29    wakers: WakerVec,
30    state: PollVec,
31    done: bool,
32}
33
34impl<S> Merge<S>
35where
36    S: Stream,
37{
38    pub(crate) fn new(streams: Vec<S>) -> Self {
39        let len = streams.len();
40        Self {
41            wakers: WakerVec::new(len),
42            state: PollVec::new_pending(len),
43            indexer: Indexer::new(len),
44            streams,
45            complete: 0,
46            done: false,
47        }
48    }
49}
50
51impl<S> fmt::Debug for Merge<S>
52where
53    S: Stream + fmt::Debug,
54{
55    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
56        f.debug_list().entries(self.streams.iter()).finish()
57    }
58}
59
60impl<S> Stream for Merge<S>
61where
62    S: Stream,
63{
64    type Item = S::Item;
65
66    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
67        let mut this = self.project();
68
69        let mut readiness = this.wakers.readiness();
70        readiness.set_waker(cx.waker());
71
72        // Iterate over our streams one-by-one. If a stream yields a value,
73        // we exit early. By default we'll return `Poll::Ready(None)`, but
74        // this changes if we encounter a `Poll::Pending`.
75        for index in this.indexer.iter() {
76            if !readiness.any_ready() {
77                // Nothing is ready yet
78                return Poll::Pending;
79            } else if !readiness.clear_ready(index) || this.state[index].is_none() {
80                continue;
81            }
82
83            // unlock readiness so we don't deadlock when polling
84            #[allow(clippy::drop_non_drop)]
85            drop(readiness);
86
87            // Obtain the intermediate waker.
88            let mut cx = Context::from_waker(this.wakers.get(index).unwrap());
89
90            let stream = utils::get_pin_mut_from_vec(this.streams.as_mut(), index).unwrap();
91            match stream.poll_next(&mut cx) {
92                Poll::Ready(Some(item)) => {
93                    // Mark ourselves as ready again because we need to poll for the next item.
94                    this.wakers.readiness().set_ready(index);
95                    return Poll::Ready(Some(item));
96                }
97                Poll::Ready(None) => {
98                    *this.complete += 1;
99                    this.state[index].set_none();
100                    if *this.complete == this.streams.len() {
101                        return Poll::Ready(None);
102                    }
103                }
104                Poll::Pending => {}
105            }
106
107            // Lock readiness so we can use it again
108            readiness = this.wakers.readiness();
109        }
110
111        Poll::Pending
112    }
113}
114
115impl<S> MergeTrait for Vec<S>
116where
117    S: IntoStream,
118{
119    type Item = <Merge<S::IntoStream> as Stream>::Item;
120    type Stream = Merge<S::IntoStream>;
121
122    fn merge(self) -> Self::Stream {
123        Merge::new(self.into_iter().map(|i| i.into_stream()).collect())
124    }
125}
126
#[cfg(test)]
mod tests {
    use alloc::rc::Rc;
    use alloc::vec;
    use core::cell::RefCell;

    use super::*;
    use crate::future::join::Join;
    use crate::utils::channel::local_channel;
    use futures::executor::LocalPool;
    use futures::task::LocalSpawnExt;
    use futures_lite::future::block_on;
    use futures_lite::prelude::*;
    use futures_lite::stream;

    /// Four single-item streams merge into one stream yielding all four items.
    #[test]
    fn merge_vec_4() {
        block_on(async {
            let first = stream::once(1);
            let second = stream::once(2);
            let third = stream::once(3);
            let fourth = stream::once(4);
            let mut merged = vec![first, second, third, fourth].merge();

            let mut total = 0;
            while let Some(value) = merged.next().await {
                total += value;
            }
            assert_eq!(total, 10);
        })
    }

    /// Two streams of two items each merge into one stream yielding all four items.
    #[test]
    fn merge_vec_2x2() {
        block_on(async {
            let ones = stream::repeat(1).take(2);
            let twos = stream::repeat(2).take(2);
            let mut merged = vec![ones, twos].merge();

            let mut total = 0;
            while let Some(value) = merged.next().await {
                total += value;
            }
            assert_eq!(total, 6);
        })
    }

    /// This test case uses channels so we'll have streams that return Pending from time to time.
    ///
    /// The purpose of this test is to make sure we have the waking logic working.
    #[test]
    fn merge_channels() {
        let mut pool = LocalPool::new();

        // Shared flag the spawned task sets once its assertion has passed.
        let finished = Rc::new(RefCell::new(false));
        let finished_flag = Rc::clone(&finished);

        let task = async move {
            let (tx_a, rx_a) = local_channel();
            let (tx_b, rx_b) = local_channel();
            let (tx_c, rx_c) = local_channel();

            // Sums everything received across the three merged channels.
            let consumer = async {
                vec![rx_a, rx_b, rx_c]
                    .merge()
                    .fold(0, |acc, item| acc + item)
                    .await
            };

            // Sends 1..=4 on every channel, then closes all of them.
            let producer = async {
                for value in 1..=4 {
                    tx_a.send(value);
                    tx_b.send(value);
                    tx_c.send(value);
                }
                drop(tx_a);
                drop(tx_b);
                drop(tx_c);
            };

            let (sum, ()) = (consumer, producer).join().await;

            assert_eq!(sum, 30);

            *finished_flag.borrow_mut() = true;
        };

        pool.spawner().spawn_local(task).unwrap();

        while !*finished.borrow() {
            pool.run_until_stalled()
        }
    }
}