futures_concurrency/stream/merge/
array.rs1use super::Merge as MergeTrait;
2use crate::stream::IntoStream;
3use crate::utils::{self, Indexer, PollArray, WakerArray};
4
5use core::fmt;
6use core::pin::Pin;
7use core::task::{Context, Poll};
8use futures_core::Stream;
9
10#[pin_project::pin_project]
18pub struct Merge<S, const N: usize>
19where
20 S: Stream,
21{
22 #[pin]
23 streams: [S; N],
24 indexer: Indexer<N>,
25 wakers: WakerArray<N>,
26 state: PollArray<N>,
27 complete: usize,
28 done: bool,
29}
30
31impl<S, const N: usize> Merge<S, N>
32where
33 S: Stream,
34{
35 pub(crate) fn new(streams: [S; N]) -> Self {
36 Self {
37 streams,
38 indexer: Indexer::new(),
39 wakers: WakerArray::new(),
40 state: PollArray::new_pending(),
41 complete: 0,
42 done: false,
43 }
44 }
45}
46
47impl<S, const N: usize> fmt::Debug for Merge<S, N>
48where
49 S: Stream + fmt::Debug,
50{
51 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
52 f.debug_list().entries(self.streams.iter()).finish()
53 }
54}
55
impl<S, const N: usize> Stream for Merge<S, N>
where
    S: Stream,
{
    /// Items are passed through unchanged from the inner streams.
    type Item = S::Item;

    /// Polls the inner streams that are marked ready and yields the first item
    /// any of them produces.
    ///
    /// Returns `Poll::Ready(None)` once every inner stream has returned
    /// `Poll::Ready(None)`, and on every call after that. Returns
    /// `Poll::Pending` when no ready stream produced an item during this call.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let mut this = self.project();

        // Return early in case all streams have been completed.
        if *this.complete == N {
            return Poll::Ready(None);
        }

        // Register the caller's waker with the shared readiness state so that
        // a wakeup of any inner stream can reach the outer task.
        let mut readiness = this.wakers.readiness();
        readiness.set_waker(cx.waker());

        // Visit the streams in the order produced by the indexer. Exit early as
        // soon as one yields an item.
        for index in this.indexer.iter() {
            if !readiness.any_ready() {
                // No stream is marked ready: nothing more to poll right now.
                return Poll::Pending;
            } else if !readiness.clear_ready(index) || this.state[index].is_none() {
                // Skip streams that are not marked ready and streams that have
                // already ended. `clear_ready` is evaluated first, so a
                // completed stream's ready flag is still cleared.
                continue;
            }

            // Release the readiness handle before polling the stream, and
            // re-acquire it at the bottom of the loop.
            // NOTE(review): the lint is allowed because the handle's type
            // apparently has no `Drop` impl in at least some configurations;
            // confirm against `utils::WakerArray`.
            #[allow(clippy::drop_non_drop)]
            drop(readiness);

            // Poll with this stream's own waker so its wakeup marks `index` ready.
            let mut cx = Context::from_waker(this.wakers.get(index).unwrap());

            let stream = utils::get_pin_mut(this.streams.as_mut(), index).unwrap();
            match stream.poll_next(&mut cx) {
                Poll::Ready(Some(item)) => {
                    // Re-mark the stream as ready. After yielding an item it may
                    // have more available without ever calling its waker, so it
                    // must be polled again on the next call.
                    this.wakers.readiness().set_ready(index);
                    return Poll::Ready(Some(item));
                }
                Poll::Ready(None) => {
                    // Record completion so this stream is skipped from now on.
                    *this.complete += 1;
                    this.state[index].set_none();
                    if *this.complete == this.streams.len() {
                        return Poll::Ready(None);
                    }
                }
                Poll::Pending => {}
            }

            // Re-acquire the readiness handle for the next iteration.
            readiness = this.wakers.readiness();
        }

        // Every index was visited and none produced an item.
        Poll::Pending
    }
}

116impl<S, const N: usize> MergeTrait for [S; N]
117where
118 S: IntoStream,
119{
120 type Item = <Merge<S::IntoStream, N> as Stream>::Item;
121 type Stream = Merge<S::IntoStream, N>;
122
123 fn merge(self) -> Self::Stream {
124 Merge::new(self.map(|i| i.into_stream()))
125 }
126}
127
#[cfg(test)]
mod tests {
    use super::*;
    use futures_lite::future::block_on;
    use futures_lite::prelude::*;
    use futures_lite::stream;

    /// Merging zero streams ends immediately.
    #[test]
    fn empty_array() {
        block_on(async {
            let streams: [stream::Once<i32>; 0] = [];
            let mut s = streams.merge();
            let result = s.next().await;
            assert_eq!(result, None);
        })
    }

    /// Every item of every single-item stream is yielded exactly once.
    #[test]
    fn merge_array_4() {
        block_on(async {
            let a = stream::once(1);
            let b = stream::once(2);
            let c = stream::once(3);
            let d = stream::once(4);
            let mut s = [a, b, c, d].merge();

            let mut counter = 0;
            while let Some(n) = s.next().await {
                counter += n;
            }
            assert_eq!(counter, 10);
        })
    }

    /// Multi-item streams are drained completely.
    #[test]
    fn merge_array_2x2() {
        block_on(async {
            let a = stream::repeat(1).take(2);
            let b = stream::repeat(2).take(2);
            let mut s = [a, b].merge();

            let mut counter = 0;
            while let Some(n) = s.next().await {
                counter += n;
            }
            assert_eq!(counter, 6);
        })
    }

    /// Once all inner streams have ended, the merged stream keeps returning `None`.
    #[test]
    fn stays_exhausted() {
        block_on(async {
            let mut s = [stream::once(1), stream::once(2)].merge();

            let mut counter = 0;
            while let Some(n) = s.next().await {
                counter += n;
            }
            assert_eq!(counter, 3);
            assert_eq!(s.next().await, None);
            assert_eq!(s.next().await, None);
        })
    }

    /// Merging channels that another future feeds concurrently yields every sent value.
    #[test]
    #[cfg(feature = "alloc")]
    fn merge_channels() {
        use alloc::rc::Rc;
        use core::cell::RefCell;
        use futures::executor::LocalPool;
        use futures::task::LocalSpawnExt;

        use crate::future::join::Join;
        use crate::utils::channel::local_channel;

        let mut pool = LocalPool::new();

        let done = Rc::new(RefCell::new(false));
        let done2 = done.clone();

        pool.spawner()
            .spawn_local(async move {
                let (send1, receive1) = local_channel();
                let (send2, receive2) = local_channel();
                let (send3, receive3) = local_channel();

                let (count, ()) = (
                    async {
                        [receive1, receive2, receive3]
                            .merge()
                            .fold(0, |a, b| a + b)
                            .await
                    },
                    async {
                        for i in 1..=4 {
                            send1.send(i);
                            send2.send(i);
                            send3.send(i);
                        }
                        drop(send1);
                        drop(send2);
                        drop(send3);
                    },
                )
                    .join()
                    .await;

                assert_eq!(count, 30);

                *done2.borrow_mut() = true;
            })
            .unwrap();

        // Both futures run inside the single spawned task, so every wakeup
        // originates inside the pool. Once the pool stalls, no further progress
        // is possible: assert completion instead of looping forever.
        pool.run_until_stalled();
        assert!(*done.borrow(), "merge_channels task did not run to completion");
    }
}