futures_concurrency/stream/merge/
tuple.rs1use super::Merge as MergeTrait;
2use crate::stream::IntoStream;
3use crate::utils::{self, PollArray, WakerArray};
4
5use core::fmt;
6use core::pin::Pin;
7use core::task::{Context, Poll};
8use futures_core::Stream;
9
/// Polls one member of a tuple of streams if it is the member whose turn it is.
///
/// The expansion is an `if` guarded by `$stream_idx == $iteration`. The caller
/// emits one invocation per tuple element, and only the matching one runs.
///
/// - `Poll::Ready(Some(item))`: the stream is marked ready again in
///   `$this.wakers` and the item is returned from the enclosing function.
/// - `Poll::Ready(None)`: `$this.completed` is incremented and the stream's
///   slot in `$this.state` is cleared. Once `$len_streams` streams have ended,
///   the enclosing function returns `Poll::Ready(None)`.
/// - `Poll::Pending`: nothing happens here.
///
/// `$streams.$stream_member` must be a pin-projected field (`Pin<&mut S>` with
/// `S: Stream`), and `$streams` must be a mutable binding.
macro_rules! poll_stream {
    ($stream_idx:tt, $iteration:ident, $this:ident, $streams:ident . $stream_member:ident, $cx:ident, $len_streams:ident) => {
        if $stream_idx == $iteration {
            // The projected field is already `Pin<&mut S>`; reborrowing it with
            // `as_mut` polls the stream directly, with no `unsafe` re-pinning.
            match $streams.$stream_member.as_mut().poll_next(&mut $cx) {
                Poll::Ready(Some(item)) => {
                    // Mark this stream ready again: it may have more items,
                    // and only another poll can tell.
                    $this.wakers.readiness().set_ready($stream_idx);
                    return Poll::Ready(Some(item));
                }
                Poll::Ready(None) => {
                    *$this.completed += 1;
                    // Clear the slot so this stream is skipped from now on.
                    $this.state[$stream_idx].set_none();
                    if *$this.completed == $len_streams {
                        return Poll::Ready(None);
                    }
                }
                Poll::Pending => {}
            }
        }
    };
}
31
/// Generates a tuple `Merge` stream type and the `MergeTrait` impl for one
/// tuple arity.
///
/// Invoked as `impl_merge_tuple! { module_name StructName A B C ... }`.
/// - With no type parameters, it generates the empty case for `()`: a stream
///   whose first poll returns `Poll::Ready(None)`.
/// - Otherwise, it generates:
///   - a private module holding the pinned streams and an index enum;
///   - the public `StructName` stream;
///   - `MergeTrait` for the tuple type.
macro_rules! impl_merge_tuple {
    ($ignore:ident $StructName:ident) => {
        /// A stream that merges an empty tuple of streams.
        ///
        /// It never yields an item; every poll returns `Poll::Ready(None)`.
        /// Created by calling `merge` on `()`.
        pub struct $StructName {}

        impl fmt::Debug for $StructName {
            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
                f.debug_tuple("Merge").finish()
            }
        }

        impl Stream for $StructName {
            // `Infallible` has no values, so the item type itself records that
            // this stream can never produce one.
            type Item = core::convert::Infallible;

            fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
                Poll::Ready(None)
            }
        }

        impl MergeTrait for () {
            type Item = core::convert::Infallible;
            type Stream = $StructName;

            fn merge(self) -> Self::Stream {
                $StructName {}
            }
        }
    };
    ($mod_name:ident $StructName:ident $($F:ident)+) => {
        mod $mod_name {
            /// The merged streams, stored as named fields so each one can be
            /// pin-projected on its own.
            #[pin_project::pin_project]
            pub(super) struct Streams<$($F,)+> { $(#[pin] pub(super) $F: $F),+ }

            /// Maps each stream's field name to its position in the tuple.
            #[repr(usize)]
            pub(super) enum Indexes { $($F),+ }

            /// Number of streams in the tuple.
            pub(super) const LEN: usize = [$(Indexes::$F),+].len();
        }

        /// A stream that merges multiple streams into a single stream.
        ///
        /// Created by calling `merge` on a tuple of streams. Each item is
        /// returned as soon as one member stream yields it. The merged stream
        /// ends once every member stream has ended.
        #[pin_project::pin_project]
        pub struct $StructName<T, $($F),*>
        where $(
            $F: Stream<Item = T>,
        )* {
            #[pin] streams: $mod_name::Streams<$($F,)+>,
            // Produces the order in which streams are visited on each poll.
            indexer: utils::Indexer,
            // Per-stream wakers (`get(index)`) plus a readiness set (`readiness()`).
            wakers: WakerArray<{$mod_name::LEN}>,
            // Per-stream slot; set to none once that stream has ended.
            state: PollArray<{$mod_name::LEN}>,
            // Number of member streams that have ended so far.
            completed: u8,
        }

        impl<T, $($F),*> fmt::Debug for $StructName<T, $($F),*>
        where
            $( $F: Stream<Item = T> + fmt::Debug, )*
        {
            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
                f.debug_tuple("Merge")
                    $( .field(&self.streams.$F) )*
                    .finish()
            }
        }

        impl<T, $($F),*> Stream for $StructName<T, $($F),*>
        where $(
            $F: Stream<Item = T>,
        )* {
            type Item = T;

            fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
                let this = self.project();

                // Hand the caller's waker to the waker array. Presumably the
                // per-stream wakers forward to it, so the task polling this
                // merge is woken when any member becomes ready.
                let mut readiness = this.wakers.readiness();
                readiness.set_waker(cx.waker());

                const LEN: u8 = $mod_name::LEN as u8;

                let mut streams = this.streams.project();

                // Visit the streams in the indexer's order. The first stream that
                // yields an item returns it immediately. A stream is skipped when
                // its readiness flag was not set (per `clear_ready`) or it has
                // already ended.
                for index in this.indexer.iter() {
                    if !readiness.any_ready() {
                        // No stream is marked ready, so there is nothing to poll.
                        return Poll::Pending;
                    } else if !readiness.clear_ready(index) || this.state[index].is_none() {
                        continue;
                    }

                    // Release the readiness handle before polling:
                    // `this.wakers.get` and `poll_stream!` access `this.wakers`
                    // again. The handle may not implement `Drop` (hence the
                    // clippy allow), but the explicit drop marks the boundary.
                    #[allow(clippy::drop_non_drop)]
                    drop(readiness);

                    // Poll the stream with its own waker rather than the caller's.
                    // Assumes the indexer only yields indices below LEN.
                    let mut cx = Context::from_waker(this.wakers.get(index).unwrap());

                    $(
                        let stream_index = $mod_name::Indexes::$F as usize;
                        poll_stream!(
                            stream_index,
                            index,
                            this,
                            streams . $F,
                            cx,
                            LEN
                        );
                    )+

                    // Re-acquire the readiness handle for the next iteration.
                    readiness = this.wakers.readiness();
                }

                Poll::Pending
            }
        }

        impl<T, $($F),*> MergeTrait for ($($F,)*)
        where $(
            $F: IntoStream<Item = T>,
        )* {
            type Item = T;
            type Stream = $StructName<T, $($F::IntoStream),*>;

            fn merge(self) -> Self::Stream {
                let ($($F,)*): ($($F,)*) = self;
                $StructName {
                    streams: $mod_name::Streams { $($F: $F.into_stream()),+ },
                    indexer: utils::Indexer::new(utils::tuple_len!($($F,)*)),
                    wakers: WakerArray::new(),
                    state: PollArray::new_pending(),
                    completed: 0,
                }
            }
        }
    };
}
181
182impl_merge_tuple! { merge0 Merge0 }
183impl_merge_tuple! { merge1 Merge1 A }
184impl_merge_tuple! { merge2 Merge2 A B }
185impl_merge_tuple! { merge3 Merge3 A B C }
186impl_merge_tuple! { merge4 Merge4 A B C D }
187impl_merge_tuple! { merge5 Merge5 A B C D E }
188impl_merge_tuple! { merge6 Merge6 A B C D E F }
189impl_merge_tuple! { merge7 Merge7 A B C D E F G }
190impl_merge_tuple! { merge8 Merge8 A B C D E F G H }
191impl_merge_tuple! { merge9 Merge9 A B C D E F G H I }
192impl_merge_tuple! { merge10 Merge10 A B C D E F G H I J }
193impl_merge_tuple! { merge11 Merge11 A B C D E F G H I J K }
194impl_merge_tuple! { merge12 Merge12 A B C D E F G H I J K L }
195
#[cfg(test)]
mod tests {
    use super::*;
    use futures_lite::future::block_on;
    use futures_lite::prelude::*;
    use futures_lite::stream;

    /// Merging the empty tuple yields no items.
    #[test]
    fn merge_tuple_0() {
        block_on(async {
            let mut s = ().merge();

            let mut called = false;
            while s.next().await.is_some() {
                called = true;
            }
            assert!(!called);
        })
    }

    #[test]
    fn merge_tuple_1() {
        block_on(async {
            let a = stream::once(1);
            let mut s = (a,).merge();

            let mut counter = 0;
            while let Some(n) = s.next().await {
                counter += n;
            }
            assert_eq!(counter, 1);
        })
    }

    #[test]
    fn merge_tuple_2() {
        block_on(async {
            let a = stream::once(1);
            let b = stream::once(2);
            let mut s = (a, b).merge();

            let mut counter = 0;
            while let Some(n) = s.next().await {
                counter += n;
            }
            assert_eq!(counter, 3);
        })
    }

    #[test]
    fn merge_tuple_3() {
        block_on(async {
            let a = stream::once(1);
            let b = stream::once(2);
            let c = stream::once(3);
            let mut s = (a, b, c).merge();

            let mut counter = 0;
            while let Some(n) = s.next().await {
                counter += n;
            }
            assert_eq!(counter, 6);
        })
    }

    #[test]
    fn merge_tuple_4() {
        block_on(async {
            let a = stream::once(1);
            let b = stream::once(2);
            let c = stream::once(3);
            let d = stream::once(4);
            let mut s = (a, b, c, d).merge();

            let mut counter = 0;
            while let Some(n) = s.next().await {
                counter += n;
            }
            assert_eq!(counter, 10);
        })
    }

    /// Streams with several items must be polled again after each item until
    /// they are exhausted, so every item from every member is delivered.
    #[test]
    fn merge_tuple_multi_item() {
        block_on(async {
            let a = stream::iter([1, 2, 3]);
            let b = stream::iter([10, 20]);
            let mut s = (a, b).merge();

            let mut counter = 0;
            let mut items = 0;
            while let Some(n) = s.next().await {
                counter += n;
                items += 1;
            }
            assert_eq!(items, 5);
            assert_eq!(counter, 36);
        })
    }

    /// Channels fed by a concurrently running future become ready
    /// asynchronously; the merge must drain all of them before it ends.
    #[test]
    #[cfg(feature = "alloc")]
    fn merge_channels() {
        use alloc::rc::Rc;
        use core::cell::RefCell;

        use futures::executor::LocalPool;
        use futures::task::LocalSpawnExt;

        use crate::future::Join;
        use crate::utils::channel::local_channel;

        let mut pool = LocalPool::new();

        let done = Rc::new(RefCell::new(false));
        let done2 = done.clone();

        pool.spawner()
            .spawn_local(async move {
                let (send1, receive1) = local_channel();
                let (send2, receive2) = local_channel();
                let (send3, receive3) = local_channel();

                let (count, ()) = (
                    async {
                        (receive1, receive2, receive3)
                            .merge()
                            .fold(0, |a, b| a + b)
                            .await
                    },
                    async {
                        for i in 1..=4 {
                            send1.send(i);
                            send2.send(i);
                            send3.send(i);
                        }
                        drop(send1);
                        drop(send2);
                        drop(send3);
                    },
                )
                    .join()
                    .await;

                assert_eq!(count, 30);

                *done2.borrow_mut() = true;
            })
            .unwrap();

        // Every wakeup in this test comes from inside the single-threaded
        // pool. Once the pool stalls, nothing can make further progress, so
        // check completion instead of looping. A lost wakeup then fails the
        // test rather than spinning forever.
        pool.run_until_stalled();
        assert!(
            *done.borrow(),
            "merged stream stalled before all channels were drained"
        );
    }
}