futures_concurrency/stream/merge/
vec.rs1use super::Merge as MergeTrait;
2use crate::stream::IntoStream;
3use crate::utils::{self, Indexer, PollVec, WakerVec};
4
5#[cfg(all(feature = "alloc", not(feature = "std")))]
6use alloc::vec::Vec;
7
8use core::fmt;
9use core::pin::Pin;
10use core::task::{Context, Poll};
11use futures_core::Stream;
12
13#[pin_project::pin_project]
21pub struct Merge<S>
22where
23 S: Stream,
24{
25 #[pin]
26 streams: Vec<S>,
27 indexer: Indexer,
28 complete: usize,
29 wakers: WakerVec,
30 state: PollVec,
31 done: bool,
32}
33
34impl<S> Merge<S>
35where
36 S: Stream,
37{
38 pub(crate) fn new(streams: Vec<S>) -> Self {
39 let len = streams.len();
40 Self {
41 wakers: WakerVec::new(len),
42 state: PollVec::new_pending(len),
43 indexer: Indexer::new(len),
44 streams,
45 complete: 0,
46 done: false,
47 }
48 }
49}
50
51impl<S> fmt::Debug for Merge<S>
52where
53 S: Stream + fmt::Debug,
54{
55 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
56 f.debug_list().entries(self.streams.iter()).finish()
57 }
58}
59
60impl<S> Stream for Merge<S>
61where
62 S: Stream,
63{
64 type Item = S::Item;
65
66 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
67 let mut this = self.project();
68
69 let mut readiness = this.wakers.readiness();
70 readiness.set_waker(cx.waker());
71
72 for index in this.indexer.iter() {
76 if !readiness.any_ready() {
77 return Poll::Pending;
79 } else if !readiness.clear_ready(index) || this.state[index].is_none() {
80 continue;
81 }
82
83 #[allow(clippy::drop_non_drop)]
85 drop(readiness);
86
87 let mut cx = Context::from_waker(this.wakers.get(index).unwrap());
89
90 let stream = utils::get_pin_mut_from_vec(this.streams.as_mut(), index).unwrap();
91 match stream.poll_next(&mut cx) {
92 Poll::Ready(Some(item)) => {
93 this.wakers.readiness().set_ready(index);
95 return Poll::Ready(Some(item));
96 }
97 Poll::Ready(None) => {
98 *this.complete += 1;
99 this.state[index].set_none();
100 if *this.complete == this.streams.len() {
101 return Poll::Ready(None);
102 }
103 }
104 Poll::Pending => {}
105 }
106
107 readiness = this.wakers.readiness();
109 }
110
111 Poll::Pending
112 }
113}
114
115impl<S> MergeTrait for Vec<S>
116where
117 S: IntoStream,
118{
119 type Item = <Merge<S::IntoStream> as Stream>::Item;
120 type Stream = Merge<S::IntoStream>;
121
122 fn merge(self) -> Self::Stream {
123 Merge::new(self.into_iter().map(|i| i.into_stream()).collect())
124 }
125}
126
#[cfg(test)]
mod tests {
    use alloc::rc::Rc;
    use alloc::vec;
    use core::cell::RefCell;

    use super::*;
    use crate::future::join::Join;
    use crate::utils::channel::local_channel;
    use futures::executor::LocalPool;
    use futures::task::LocalSpawnExt;
    use futures_lite::future::block_on;
    use futures_lite::prelude::*;
    use futures_lite::stream;

    /// Every item from four single-item streams comes out of the merge.
    #[test]
    fn merge_vec_4() {
        let sum = block_on(async {
            let merged = vec![
                stream::once(1),
                stream::once(2),
                stream::once(3),
                stream::once(4),
            ]
            .merge();
            merged.fold(0, |acc, n| acc + n).await
        });
        assert_eq!(sum, 10);
    }

    /// Streams yielding more than one item are drained completely.
    #[test]
    fn merge_vec_2x2() {
        let sum = block_on(
            vec![stream::repeat(1).take(2), stream::repeat(2).take(2)]
                .merge()
                .fold(0, |acc, n| acc + n),
        );
        assert_eq!(sum, 6);
    }

    /// Channels make the inner streams return `Pending` between sends, so
    /// this exercises the wake-up path rather than only the ready path.
    #[test]
    fn merge_channels() {
        let mut pool = LocalPool::new();

        let finished = Rc::new(RefCell::new(false));
        let finished_flag = Rc::clone(&finished);

        pool.spawner()
            .spawn_local(async move {
                let (tx1, rx1) = local_channel();
                let (tx2, rx2) = local_channel();
                let (tx3, rx3) = local_channel();

                let consumer = async {
                    vec![rx1, rx2, rx3]
                        .merge()
                        .fold(0, |acc, n| acc + n)
                        .await
                };
                let producer = async {
                    for i in 1..=4 {
                        for tx in [&tx1, &tx2, &tx3] {
                            tx.send(i);
                        }
                    }
                    // Closing every sender lets each receiver stream end.
                    drop(tx1);
                    drop(tx2);
                    drop(tx3);
                };

                let (total, ()) = (consumer, producer).join().await;
                assert_eq!(total, 30);

                *finished_flag.borrow_mut() = true;
            })
            .unwrap();

        while !*finished.borrow() {
            pool.run_until_stalled();
        }
    }
}