// futures_buffered/merge_unbounded.rs

1use alloc::vec::Vec;
2use core::{
3    pin::Pin,
4    task::{Context, Poll},
5};
6
7use futures_core::Stream;
8
9use crate::{futures_unordered::MIN_CAPACITY, FuturesUnorderedBounded, MergeBounded};
10
11/// A combined stream that releases values in any order that they come.
12///
13/// This differs from [`crate::Merge`] in that [`MergeUnbounded`] does not have a fixed capacity
14/// but instead grows on demand. It uses [`crate::FuturesUnordered`] under the hood.
15///
16/// # Example
17///
18/// ```
19/// use std::future::ready;
20/// use futures::stream::{self, StreamExt};
21/// use futures::executor::block_on;
22/// use futures_buffered::MergeUnbounded;
23///
24/// block_on(async {
25///     let a = stream::once(ready(2));
26///     let b = stream::once(ready(3));
27///     let mut s = MergeUnbounded::from_iter([a, b]);
28///
29///     let mut counter = 0;
30///     while let Some(n) = s.next().await {
31///         if n == 3 {
32///             s.push(stream::once(ready(4)));
33///         }
34///         counter += n;
35///     }
36///     assert_eq!(counter, 2+3+4);
37/// })
38/// ```
39pub struct MergeUnbounded<S> {
40    pub(crate) groups: Vec<MergeBounded<S>>,
41    poll_next: usize,
42}
43
44impl<S> Default for MergeUnbounded<S> {
45    fn default() -> Self {
46        Self::new()
47    }
48}
49
50impl<S> MergeUnbounded<S> {
51    /// Create a new, empty [`MergeUnbounded`].
52    ///
53    /// Calling [`poll_next`](futures_core::Stream::poll_next) will return `Poll::Ready(None)`
54    /// until a stream is added with [`Self::push`].
55    pub const fn new() -> Self {
56        Self {
57            groups: Vec::new(),
58            poll_next: 0,
59        }
60    }
61
62    /// Push a stream into the set.
63    ///
64    /// This method adds the given stream to the set. This method will not
65    /// call [`poll_next`](futures_core::Stream::poll_next) on the submitted stream. The caller must
66    /// ensure that [`MergeUnbounded::poll_next`](Stream::poll_next) is called
67    /// in order to receive wake-up notifications for the given stream.
68    #[track_caller]
69    pub fn push(&mut self, stream: S) {
70        let last = match self.groups.last_mut() {
71            Some(last) => last,
72            None => {
73                self.groups.push(MergeBounded {
74                    streams: FuturesUnorderedBounded::new(MIN_CAPACITY),
75                });
76                self.groups.last_mut().unwrap()
77            }
78        };
79        match last.try_push(stream) {
80            Ok(()) => {}
81            Err(stream) => {
82                let mut next = MergeBounded {
83                    streams: FuturesUnorderedBounded::new(last.streams.capacity() * 2),
84                };
85                next.push(stream);
86                self.groups.push(next);
87            }
88        }
89    }
90
91    /// Returns `true` if there are no streams in the set.
92    pub fn is_empty(&self) -> bool {
93        self.groups.iter().all(|g| g.streams.is_empty())
94    }
95
96    /// Returns the number of streams currently in the set.
97    pub fn len(&self) -> usize {
98        self.groups.iter().map(|g| g.streams.len()).sum()
99    }
100}
101
102impl<S: Stream + Unpin> Stream for MergeUnbounded<S> {
103    type Item = S::Item;
104
105    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
106        let Self { groups, poll_next } = &mut *self;
107        if groups.is_empty() {
108            return Poll::Ready(None);
109        }
110
111        for _ in 0..groups.len() {
112            if *poll_next >= groups.len() {
113                *poll_next = 0;
114            }
115
116            let poll = Pin::new(&mut groups[*poll_next]).poll_next(cx);
117            match poll {
118                Poll::Ready(Some(x)) => {
119                    return Poll::Ready(Some(x));
120                }
121                Poll::Ready(None) => {
122                    let group = groups.remove(*poll_next);
123                    debug_assert!(group.streams.is_empty());
124
125                    if groups.is_empty() {
126                        // group should contain at least 1 set
127                        groups.push(group);
128                        return Poll::Ready(None);
129                    }
130
131                    // we do not want to drop the last set as it contains
132                    // the largest allocation that we want to keep a hold of
133                    if *poll_next == groups.len() {
134                        groups.push(group);
135                        *poll_next = 0;
136                    }
137                }
138                Poll::Pending => {
139                    *poll_next += 1;
140                }
141            }
142        }
143        Poll::Pending
144    }
145}
146
147impl<S: Stream + Unpin> FromIterator<S> for MergeUnbounded<S> {
148    fn from_iter<T>(iter: T) -> Self
149    where
150        T: IntoIterator<Item = S>,
151    {
152        let iter = iter.into_iter();
153        // let mut this =
154        //     Self::with_capacity(usize::max(iter.size_hint().0, MIN_CAPACITY));
155        let mut this = Self::new();
156        for stream in iter {
157            this.push(stream);
158        }
159        this
160    }
161}
162
#[cfg(test)]
mod tests {
    use core::cell::RefCell;
    use core::task::Waker;

    use super::*;
    use alloc::collections::VecDeque;
    use alloc::rc::Rc;
    use futures::executor::block_on;
    use futures::executor::LocalPool;
    use futures::stream;
    use futures::task::LocalSpawnExt;
    use futures::StreamExt;

    /// Merges four finite streams and checks that every item is delivered exactly once.
    #[test]
    fn merge_tuple_4() {
        block_on(async {
            let sources = [
                stream::repeat(2).take(2),
                stream::repeat(3).take(3),
                stream::repeat(5).take(5),
                stream::repeat(7).take(7),
            ];
            let mut merged = sources.into_iter().collect::<MergeUnbounded<_>>();

            let mut total = 0;
            while let Some(value) = merged.next().await {
                total += value;
            }
            assert_eq!(total, 4 + 9 + 25 + 49);
        });
    }

    /// Streams can be pushed both before polling starts and after the merged stream has
    /// already reported exhaustion.
    #[test]
    fn add_streams() {
        block_on(async {
            let twos = stream::repeat(2).take(2);
            let threes = stream::repeat(3).take(3);
            let mut merged = MergeUnbounded::default();

            // A fresh set is empty and immediately exhausted.
            assert_eq!(merged.next().await, None);
            assert!(merged.is_empty());
            assert_eq!(merged.len(), 0);

            merged.push(twos);
            merged.push(threes);

            assert!(!merged.is_empty());
            assert_eq!(merged.len(), 2);

            let mut total = 0;
            while let Some(value) = merged.next().await {
                total += value;
                assert!(!merged.is_empty());
            }

            assert!(merged.is_empty());
            assert_eq!(merged.len(), 0);

            // Pushing after exhaustion makes the merged stream produce items again.
            merged.push(stream::repeat(4).take(4));

            assert!(!merged.is_empty());
            assert_eq!(merged.len(), 1);

            while let Some(value) = merged.next().await {
                total += value;
            }

            assert_eq!(total, 4 + 9 + 16);

            assert!(merged.is_empty());
            assert_eq!(merged.len(), 0);
        });
    }

    /// Uses hand-rolled single-threaded channels so the merged streams return `Pending` from
    /// time to time. This exercises the wake-up path rather than only the always-ready path.
    #[test]
    fn merge_channels() {
        struct LocalChannel<T> {
            queue: VecDeque<T>,
            waker: Option<Waker>,
            closed: bool,
        }

        impl<T> LocalChannel<T> {
            /// Wakes the receiver if it registered a waker while waiting.
            fn wake_receiver(&mut self) {
                if let Some(waker) = self.waker.take() {
                    waker.wake();
                }
            }
        }

        struct LocalReceiver<T> {
            channel: Rc<RefCell<LocalChannel<T>>>,
        }

        impl<T> Stream for LocalReceiver<T> {
            type Item = T;

            fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
                let mut channel = self.channel.borrow_mut();

                if let Some(item) = channel.queue.pop_front() {
                    return Poll::Ready(Some(item));
                }
                if channel.closed {
                    return Poll::Ready(None);
                }
                channel.waker = Some(cx.waker().clone());
                Poll::Pending
            }
        }

        struct LocalSender<T> {
            channel: Rc<RefCell<LocalChannel<T>>>,
        }

        impl<T> LocalSender<T> {
            fn send(&self, item: T) {
                let mut channel = self.channel.borrow_mut();
                channel.queue.push_back(item);
                channel.wake_receiver();
            }
        }

        impl<T> Drop for LocalSender<T> {
            fn drop(&mut self) {
                let mut channel = self.channel.borrow_mut();
                channel.closed = true;
                channel.wake_receiver();
            }
        }

        fn local_channel<T>() -> (LocalSender<T>, LocalReceiver<T>) {
            let shared = Rc::new(RefCell::new(LocalChannel {
                queue: VecDeque::new(),
                waker: None,
                closed: false,
            }));
            let sender = LocalSender {
                channel: Rc::clone(&shared),
            };
            (sender, LocalReceiver { channel: shared })
        }

        let mut pool = LocalPool::new();

        let finished = Rc::new(RefCell::new(false));
        let finished_in_task = Rc::clone(&finished);

        pool.spawner()
            .spawn_local(async move {
                let (send1, receive1) = local_channel();
                let (send2, receive2) = local_channel();
                let (send3, receive3) = local_channel();

                let consumer = async {
                    let merged: MergeUnbounded<_> =
                        [receive1, receive2, receive3].into_iter().collect();
                    merged.fold(0, |acc, item| async move { acc + item }).await
                };
                let producer = async {
                    for i in 1..=4 {
                        send1.send(i);
                        send2.send(i);
                        send3.send(i);
                    }
                    drop(send1);
                    drop(send2);
                    drop(send3);
                };

                let (count, ()) = futures::future::join(consumer, producer).await;

                assert_eq!(count, 30);

                *finished_in_task.borrow_mut() = true;
            })
            .unwrap();

        while !*finished.borrow() {
            pool.run_until_stalled();
        }
    }
}