futures_buffered/
merge_bounded.rs1use core::{
2 pin::Pin,
3 task::{Context, Poll},
4};
5
6use futures_core::Stream;
7
8use crate::FuturesUnorderedBounded;
9
10#[deprecated = "use `MergeBounded` instead"]
11pub type Merge<S> = MergeBounded<S>;
12
13pub struct MergeBounded<S> {
38 pub(crate) streams: FuturesUnorderedBounded<S>,
39}
40
41impl<S> MergeBounded<S> {
42 #[track_caller]
52 pub fn push(&mut self, stream: S) {
53 if self.try_push(stream).is_err() {
54 panic!("attempted to push into a full `Merge`");
55 }
56 }
57
58 pub fn try_push(&mut self, stream: S) -> Result<(), S> {
68 self.streams.try_push_with(stream, core::convert::identity)
69 }
70}
71
72impl<S: Stream> Stream for MergeBounded<S> {
73 type Item = S::Item;
74
75 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
76 loop {
77 match self.streams.poll_inner_no_remove(cx, S::poll_next) {
78 Poll::Ready(Some((i, Some(x)))) => {
80 unsafe {
82 self.streams.shared.push(i);
83 }
84 break Poll::Ready(Some(x));
85 }
86 Poll::Ready(Some((i, None))) => {
88 self.streams.tasks.remove(i);
89 }
90 Poll::Pending => break Poll::Pending,
91 Poll::Ready(None) => break Poll::Ready(None),
92 }
93 }
94 }
95}
96
97impl<S: Stream> FromIterator<S> for MergeBounded<S> {
98 fn from_iter<T>(iter: T) -> Self
99 where
100 T: IntoIterator<Item = S>,
101 {
102 Self {
103 streams: iter.into_iter().collect(),
104 }
105 }
106}
107
#[cfg(test)]
mod tests {
    use core::cell::RefCell;
    use core::task::Waker;

    use super::*;
    use alloc::collections::VecDeque;
    use alloc::rc::Rc;
    use futures::executor::{block_on, LocalPool};
    use futures::prelude::*;
    use futures::stream;
    use futures::task::LocalSpawnExt;

    /// Merging four finite streams yields every item of every stream.
    #[test]
    fn merge_tuple_4() {
        let total = block_on(async {
            let twos = stream::repeat(2).take(2);
            let threes = stream::repeat(3).take(3);
            let fives = stream::repeat(5).take(5);
            let sevens = stream::repeat(7).take(7);
            let mut merged = [twos, threes, fives, sevens]
                .into_iter()
                .collect::<MergeBounded<_>>();

            let mut sum = 0;
            while let Some(value) = merged.next().await {
                sum += value;
            }
            sum
        });

        assert_eq!(total, 4 + 9 + 25 + 49);
    }

    /// Merging three single-threaded channels delivers every sent item and
    /// terminates once all senders have been dropped.
    #[test]
    fn merge_channels() {
        /// State shared between one sender and one receiver.
        struct LocalChannel<T> {
            queue: VecDeque<T>,
            waker: Option<Waker>,
            closed: bool,
        }

        struct LocalReceiver<T> {
            channel: Rc<RefCell<LocalChannel<T>>>,
        }

        impl<T> Stream for LocalReceiver<T> {
            type Item = T;

            fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
                let mut state = self.channel.borrow_mut();

                // Queued items are delivered first, even after the sender closed.
                if let Some(item) = state.queue.pop_front() {
                    return Poll::Ready(Some(item));
                }
                if state.closed {
                    return Poll::Ready(None);
                }

                // Nothing queued yet: remember who to wake on the next send/close.
                state.waker = Some(cx.waker().clone());
                Poll::Pending
            }
        }

        struct LocalSender<T> {
            channel: Rc<RefCell<LocalChannel<T>>>,
        }

        impl<T> LocalSender<T> {
            /// Queues `item` and wakes the receiver if it is waiting.
            fn send(&self, item: T) {
                let mut state = self.channel.borrow_mut();

                state.queue.push_back(item);

                if let Some(waker) = state.waker.take() {
                    waker.wake();
                }
            }
        }

        impl<T> Drop for LocalSender<T> {
            /// Marks the channel closed and wakes the receiver if it is waiting.
            fn drop(&mut self) {
                let mut state = self.channel.borrow_mut();
                state.closed = true;
                if let Some(waker) = state.waker.take() {
                    waker.wake();
                }
            }
        }

        /// Creates a connected sender/receiver pair over fresh shared state.
        fn local_channel<T>() -> (LocalSender<T>, LocalReceiver<T>) {
            let shared = Rc::new(RefCell::new(LocalChannel {
                queue: VecDeque::new(),
                waker: None,
                closed: false,
            }));

            let sender = LocalSender {
                channel: shared.clone(),
            };
            let receiver = LocalReceiver { channel: shared };

            (sender, receiver)
        }

        let mut pool = LocalPool::new();

        // Set by the spawned task once its assertion has passed.
        let done = Rc::new(RefCell::new(false));
        let done_flag = done.clone();

        pool.spawner()
            .spawn_local(async move {
                let (send1, receive1) = local_channel();
                let (send2, receive2) = local_channel();
                let (send3, receive3) = local_channel();

                let consume = async {
                    [receive1, receive2, receive3]
                        .into_iter()
                        .collect::<MergeBounded<_>>()
                        .fold(0, |acc, item| async move { acc + item })
                        .await
                };

                let produce = async {
                    let senders = [send1, send2, send3];
                    for i in 1..=4 {
                        for sender in &senders {
                            sender.send(i);
                        }
                    }
                    // Dropping the array drops the senders in order, closing
                    // each channel.
                    drop(senders);
                };

                let (count, ()) = futures::future::join(consume, produce).await;

                assert_eq!(count, 30);

                *done_flag.borrow_mut() = true;
            })
            .unwrap();

        loop {
            if *done.borrow() {
                break;
            }
            pool.run_until_stalled();
        }
    }
}