futures_buffered/
merge_unbounded.rs1use alloc::vec::Vec;
2use core::{
3 pin::Pin,
4 task::{Context, Poll},
5};
6
7use futures_core::Stream;
8
9use crate::{futures_unordered::MIN_CAPACITY, FuturesUnorderedBounded, MergeBounded};
10
11pub struct MergeUnbounded<S> {
40 pub(crate) groups: Vec<MergeBounded<S>>,
41 poll_next: usize,
42}
43
44impl<S> Default for MergeUnbounded<S> {
45 fn default() -> Self {
46 Self::new()
47 }
48}
49
50impl<S> MergeUnbounded<S> {
51 pub const fn new() -> Self {
56 Self {
57 groups: Vec::new(),
58 poll_next: 0,
59 }
60 }
61
62 #[track_caller]
69 pub fn push(&mut self, stream: S) {
70 let last = match self.groups.last_mut() {
71 Some(last) => last,
72 None => {
73 self.groups.push(MergeBounded {
74 streams: FuturesUnorderedBounded::new(MIN_CAPACITY),
75 });
76 self.groups.last_mut().unwrap()
77 }
78 };
79 match last.try_push(stream) {
80 Ok(()) => {}
81 Err(stream) => {
82 let mut next = MergeBounded {
83 streams: FuturesUnorderedBounded::new(last.streams.capacity() * 2),
84 };
85 next.push(stream);
86 self.groups.push(next);
87 }
88 }
89 }
90
91 pub fn is_empty(&self) -> bool {
93 self.groups.iter().all(|g| g.streams.is_empty())
94 }
95
96 pub fn len(&self) -> usize {
98 self.groups.iter().map(|g| g.streams.len()).sum()
99 }
100}
101
102impl<S: Stream + Unpin> Stream for MergeUnbounded<S> {
103 type Item = S::Item;
104
105 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
106 let Self { groups, poll_next } = &mut *self;
107 if groups.is_empty() {
108 return Poll::Ready(None);
109 }
110
111 for _ in 0..groups.len() {
112 if *poll_next >= groups.len() {
113 *poll_next = 0;
114 }
115
116 let poll = Pin::new(&mut groups[*poll_next]).poll_next(cx);
117 match poll {
118 Poll::Ready(Some(x)) => {
119 return Poll::Ready(Some(x));
120 }
121 Poll::Ready(None) => {
122 let group = groups.remove(*poll_next);
123 debug_assert!(group.streams.is_empty());
124
125 if groups.is_empty() {
126 groups.push(group);
128 return Poll::Ready(None);
129 }
130
131 if *poll_next == groups.len() {
134 groups.push(group);
135 *poll_next = 0;
136 }
137 }
138 Poll::Pending => {
139 *poll_next += 1;
140 }
141 }
142 }
143 Poll::Pending
144 }
145}
146
147impl<S: Stream + Unpin> FromIterator<S> for MergeUnbounded<S> {
148 fn from_iter<T>(iter: T) -> Self
149 where
150 T: IntoIterator<Item = S>,
151 {
152 let iter = iter.into_iter();
153 let mut this = Self::new();
156 for stream in iter {
157 this.push(stream);
158 }
159 this
160 }
161}
162
#[cfg(test)]
mod tests {
    use core::cell::RefCell;
    use core::task::Waker;

    use super::*;
    use alloc::collections::VecDeque;
    use alloc::rc::Rc;
    use futures::executor::block_on;
    use futures::executor::LocalPool;
    use futures::stream;
    use futures::task::LocalSpawnExt;
    use futures::StreamExt;

    #[test]
    fn merge_tuple_4() {
        block_on(async {
            let a = stream::repeat(2).take(2);
            let b = stream::repeat(3).take(3);
            let c = stream::repeat(5).take(5);
            let d = stream::repeat(7).take(7);
            let mut s: MergeUnbounded<_> = [a, b, c, d].into_iter().collect();

            let mut counter = 0;
            while let Some(n) = s.next().await {
                counter += n;
            }
            assert_eq!(counter, 4 + 9 + 25 + 49);
        });
    }

    #[test]
    fn add_streams() {
        block_on(async {
            let a = stream::repeat(2).take(2);
            let b = stream::repeat(3).take(3);
            let mut s = MergeUnbounded::default();
            assert_eq!(s.next().await, None);
            assert!(s.is_empty());
            assert_eq!(s.len(), 0);

            s.push(a);
            s.push(b);

            assert!(!s.is_empty());
            assert_eq!(s.len(), 2);

            let mut counter = 0;
            while let Some(n) = s.next().await {
                counter += n;
                assert!(!s.is_empty());
            }

            assert!(s.is_empty());
            assert_eq!(s.len(), 0);

            let b = stream::repeat(4).take(4);
            s.push(b);

            assert!(!s.is_empty());
            assert_eq!(s.len(), 1);

            while let Some(n) = s.next().await {
                counter += n;
            }

            assert_eq!(counter, 4 + 9 + 16);

            assert!(s.is_empty());
            assert_eq!(s.len(), 0);
        });
    }

    /// Merges three hand-rolled single-threaded channels whose receivers return `Pending`
    /// until their sender pushes an item or is dropped.
    #[test]
    fn merge_channels() {
        struct LocalChannel<T> {
            queue: VecDeque<T>,
            waker: Option<Waker>,
            closed: bool,
        }

        struct LocalReceiver<T> {
            channel: Rc<RefCell<LocalChannel<T>>>,
        }

        impl<T> Stream for LocalReceiver<T> {
            type Item = T;

            fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
                let mut channel = self.channel.borrow_mut();

                match channel.queue.pop_front() {
                    Some(item) => Poll::Ready(Some(item)),
                    None => {
                        if channel.closed {
                            Poll::Ready(None)
                        } else {
                            channel.waker = Some(cx.waker().clone());
                            Poll::Pending
                        }
                    }
                }
            }
        }

        struct LocalSender<T> {
            channel: Rc<RefCell<LocalChannel<T>>>,
        }

        impl<T> LocalSender<T> {
            fn send(&self, item: T) {
                let waker = {
                    let mut channel = self.channel.borrow_mut();
                    channel.queue.push_back(item);
                    channel.waker.take()
                };
                // Wake after releasing the borrow so a waker that polls synchronously cannot
                // trigger a `RefCell` double borrow.
                if let Some(waker) = waker {
                    waker.wake();
                }
            }
        }

        impl<T> Drop for LocalSender<T> {
            fn drop(&mut self) {
                let waker = {
                    let mut channel = self.channel.borrow_mut();
                    channel.closed = true;
                    channel.waker.take()
                };
                if let Some(waker) = waker {
                    waker.wake();
                }
            }
        }

        fn local_channel<T>() -> (LocalSender<T>, LocalReceiver<T>) {
            let channel = Rc::new(RefCell::new(LocalChannel {
                queue: VecDeque::new(),
                waker: None,
                closed: false,
            }));

            (
                LocalSender {
                    channel: channel.clone(),
                },
                LocalReceiver { channel },
            )
        }

        let mut pool = LocalPool::new();

        let done = Rc::new(RefCell::new(false));
        let done2 = done.clone();

        pool.spawner()
            .spawn_local(async move {
                let (send1, receive1) = local_channel();
                let (send2, receive2) = local_channel();
                let (send3, receive3) = local_channel();

                let (count, ()) = futures::future::join(
                    async {
                        let s: MergeUnbounded<_> =
                            [receive1, receive2, receive3].into_iter().collect();
                        s.fold(0, |a, b| async move { a + b }).await
                    },
                    async {
                        for i in 1..=4 {
                            send1.send(i);
                            send2.send(i);
                            send3.send(i);
                        }
                        drop(send1);
                        drop(send2);
                        drop(send3);
                    },
                )
                .await;

                assert_eq!(count, 30);

                *done2.borrow_mut() = true;
            })
            .unwrap();

        // Nothing outside the pool can wake the task, so a single run either completes it or
        // proves it is stuck; re-running in a loop would spin forever on a stall.
        pool.run_until_stalled();
        assert!(
            *done.borrow(),
            "merged stream stalled before all channels were drained"
        );
    }
}