futures_concurrency/stream/merge/
vec.rs1use super::Merge as MergeTrait;
2use crate::stream::IntoStream;
3use crate::utils::{self, DynIndexer, PollVec, WakerVec};
4
5#[cfg(all(feature = "alloc", not(feature = "std")))]
6use alloc::vec::Vec;
7
8use core::fmt;
9use core::pin::Pin;
10use core::task::{Context, Poll};
11use futures_core::Stream;
12
13#[pin_project::pin_project]
21pub struct Merge<S>
22where
23 S: Stream,
24{
25 #[pin]
26 streams: Vec<S>,
27 indexer: DynIndexer,
28 complete: usize,
29 wakers: WakerVec,
30 state: PollVec,
31 done: bool,
32}
33
34impl<S> Merge<S>
35where
36 S: Stream,
37{
38 pub(crate) fn new(streams: Vec<S>) -> Self {
39 let len = streams.len();
40 Self {
41 wakers: WakerVec::new(len),
42 state: PollVec::new_pending(len),
43 indexer: DynIndexer::new(len),
44 streams,
45 complete: 0,
46 done: false,
47 }
48 }
49}
50
51impl<S> fmt::Debug for Merge<S>
52where
53 S: Stream + fmt::Debug,
54{
55 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
56 f.debug_list().entries(self.streams.iter()).finish()
57 }
58}
59
60impl<S> Stream for Merge<S>
61where
62 S: Stream,
63{
64 type Item = S::Item;
65
66 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
67 let mut this = self.project();
68
69 if *this.complete == this.streams.len() {
71 return Poll::Ready(None);
72 }
73
74 let mut readiness = this.wakers.readiness();
75 readiness.set_waker(cx.waker());
76
77 for index in this.indexer.iter() {
81 if !readiness.any_ready() {
82 return Poll::Pending;
84 } else if !readiness.clear_ready(index) || this.state[index].is_none() {
85 continue;
86 }
87
88 #[allow(clippy::drop_non_drop)]
90 drop(readiness);
91
92 let mut cx = Context::from_waker(this.wakers.get(index).unwrap());
94
95 let stream = utils::get_pin_mut_from_vec(this.streams.as_mut(), index).unwrap();
96 match stream.poll_next(&mut cx) {
97 Poll::Ready(Some(item)) => {
98 this.wakers.readiness().set_ready(index);
100 return Poll::Ready(Some(item));
101 }
102 Poll::Ready(None) => {
103 *this.complete += 1;
104 this.state[index].set_none();
105 if *this.complete == this.streams.len() {
106 return Poll::Ready(None);
107 }
108 }
109 Poll::Pending => {}
110 }
111
112 readiness = this.wakers.readiness();
114 }
115
116 Poll::Pending
117 }
118}
119
120impl<S> MergeTrait for Vec<S>
121where
122 S: IntoStream,
123{
124 type Item = <Merge<S::IntoStream> as Stream>::Item;
125 type Stream = Merge<S::IntoStream>;
126
127 fn merge(self) -> Self::Stream {
128 Merge::new(self.into_iter().map(|i| i.into_stream()).collect())
129 }
130}
131
#[cfg(test)]
mod tests {
    use alloc::rc::Rc;
    use alloc::vec;
    use core::cell::RefCell;

    use super::*;
    use crate::utils::channel::local_channel;
    use futures::executor::LocalPool;
    use futures::task::LocalSpawnExt;
    use futures_lite::future::block_on;
    use futures_lite::prelude::*;
    use futures_lite::stream;

    use crate::future::join::Join;

    /// Merging an empty vector of streams ends immediately.
    #[test]
    fn empty_vec() {
        block_on(async {
            let streams: Vec<stream::Once<i32>> = vec![];
            let mut s = streams.merge();
            let result = s.next().await;
            assert_eq!(result, None);
        })
    }

    /// Every item from four single-item streams is yielded exactly once.
    #[test]
    fn merge_vec_4() {
        block_on(async {
            let a = stream::once(1);
            let b = stream::once(2);
            let c = stream::once(3);
            let d = stream::once(4);
            let mut s = vec![a, b, c, d].merge();

            let mut counter = 0;
            while let Some(n) = s.next().await {
                counter += n;
            }
            assert_eq!(counter, 10);
        })
    }

    /// Streams yielding more than one item are drained completely.
    #[test]
    fn merge_vec_2x2() {
        block_on(async {
            let a = stream::repeat(1).take(2);
            let b = stream::repeat(2).take(2);
            let mut s = vec![a, b].merge();

            let mut counter = 0;
            while let Some(n) = s.next().await {
                counter += n;
            }
            assert_eq!(counter, 6);
        })
    }

    /// Uses channels so the inner streams return `Pending` from time to time.
    ///
    /// The purpose of this test is to make sure the waking logic works: the
    /// merged stream must be woken by each channel and must finish once all
    /// senders are dropped.
    #[test]
    fn merge_channels() {
        let mut pool = LocalPool::new();

        let done = Rc::new(RefCell::new(false));
        let done2 = done.clone();

        pool.spawner()
            .spawn_local(async move {
                let (send1, receive1) = local_channel();
                let (send2, receive2) = local_channel();
                let (send3, receive3) = local_channel();

                let (count, ()) = (
                    async {
                        vec![receive1, receive2, receive3]
                            .merge()
                            .fold(0, |a, b| a + b)
                            .await
                    },
                    async {
                        for i in 1..=4 {
                            send1.send(i);
                            send2.send(i);
                            send3.send(i);
                        }
                        drop(send1);
                        drop(send2);
                        drop(send3);
                    },
                )
                    .join()
                    .await;

                assert_eq!(count, 30);

                *done2.borrow_mut() = true;
            })
            .unwrap();

        // `run_until_stalled` drives every task until none can make progress.
        // Nothing outside the pool can wake this task, so if it has not
        // finished by then it never will. Fail instead of spinning forever
        // (a lost wake-up would otherwise hang the test suite).
        pool.run_until_stalled();
        assert!(
            *done.borrow(),
            "merged stream stalled before all channels were drained"
        );
    }
}