// futures_buffered/merge_bounded.rs

1use core::{
2    pin::Pin,
3    task::{Context, Poll},
4};
5
6use futures_core::Stream;
7
8use crate::FuturesUnorderedBounded;
9
10#[deprecated = "use `MergeBounded` instead"]
11pub type Merge<S> = MergeBounded<S>;
12
13/// A combined stream that releases values in any order that they come
14///
15/// # Example
16///
17/// ```
18/// use std::future::ready;
19/// use futures::stream::{self, StreamExt};
20/// use futures::executor::block_on;
21/// use futures_buffered::Merge;
22///
23/// block_on(async {
24///     let a = stream::once(ready(2));
25///     let b = stream::once(ready(3));
26///     let c = stream::once(ready(5));
27///     let d = stream::once(ready(7));
28///     let mut s = Merge::from_iter([a, b, c, d]);
29///
30///     let mut counter = 0;
31///     while let Some(n) = s.next().await {
32///         counter += n;
33///     }
34///     assert_eq!(counter, 2+3+5+7);
35/// })
36/// ```
37pub struct MergeBounded<S> {
38    pub(crate) streams: FuturesUnorderedBounded<S>,
39}
40
41impl<S> MergeBounded<S> {
42    /// Push a stream into the set.
43    ///
44    /// This method adds the given stream to the set. This method will not
45    /// call [`poll_next`](futures_core::Stream::poll_next) on the submitted stream. The caller must
46    /// ensure that [`Merge::poll_next`](Stream::poll_next) is called
47    /// in order to receive wake-up notifications for the given stream.
48    ///
49    /// # Panics
50    /// This method will panic if the buffer is currently full. See [`Merge::try_push`] to get a result instead
51    #[track_caller]
52    pub fn push(&mut self, stream: S) {
53        if self.try_push(stream).is_err() {
54            panic!("attempted to push into a full `Merge`");
55        }
56    }
57
58    /// Push a future into the set.
59    ///
60    /// This method adds the given future to the set. This method will not
61    /// call [`poll`](core::future::Future::poll) on the submitted future. The caller must
62    /// ensure that [`FuturesUnorderedBounded::poll_next`](Stream::poll_next) is called
63    /// in order to receive wake-up notifications for the given future.
64    ///
65    /// # Errors
66    /// This method will error if the buffer is currently full, returning the future back
67    pub fn try_push(&mut self, stream: S) -> Result<(), S> {
68        self.streams.try_push_with(stream, core::convert::identity)
69    }
70}
71
72impl<S: Stream> Stream for MergeBounded<S> {
73    type Item = S::Item;
74
75    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
76        loop {
77            match self.streams.poll_inner_no_remove(cx, S::poll_next) {
78                // if we have a value from the stream, wake up that slot again
79                Poll::Ready(Some((i, Some(x)))) => {
80                    // safety: i is always within capacity
81                    unsafe {
82                        self.streams.shared.push(i);
83                    }
84                    break Poll::Ready(Some(x));
85                }
86                // if a stream completed, remove it from the queue
87                Poll::Ready(Some((i, None))) => {
88                    self.streams.tasks.remove(i);
89                }
90                Poll::Pending => break Poll::Pending,
91                Poll::Ready(None) => break Poll::Ready(None),
92            }
93        }
94    }
95}
96
97impl<S: Stream> FromIterator<S> for MergeBounded<S> {
98    fn from_iter<T>(iter: T) -> Self
99    where
100        T: IntoIterator<Item = S>,
101    {
102        Self {
103            streams: iter.into_iter().collect(),
104        }
105    }
106}
107
#[cfg(test)]
mod tests {
    use core::cell::RefCell;
    use core::task::Waker;

    use super::*;
    use alloc::collections::VecDeque;
    use alloc::rc::Rc;
    use futures::executor::block_on;
    use futures::executor::LocalPool;
    use futures::prelude::*;
    use futures::stream;
    use futures::task::LocalSpawnExt;

    #[test]
    fn merge_tuple_4() {
        block_on(async {
            let a = stream::repeat(2).take(2);
            let b = stream::repeat(3).take(3);
            let c = stream::repeat(5).take(5);
            let d = stream::repeat(7).take(7);
            let mut s: MergeBounded<_> = [a, b, c, d].into_iter().collect();

            let mut counter = 0;
            while let Some(n) = s.next().await {
                counter += n;
            }
            assert_eq!(counter, 4 + 9 + 25 + 49);
        });
    }

    /// This test case uses channels so we'll have streams that return Pending from time to time.
    ///
    /// The purpose of this test is to make sure we have the waking logic working.
    #[test]
    fn merge_channels() {
        /// Shared state of a minimal single-threaded, unbounded channel.
        struct LocalChannel<T> {
            queue: VecDeque<T>,
            // Waker of the receiver's last `Pending` poll, if any.
            waker: Option<Waker>,
            // Set when the sender is dropped.
            closed: bool,
        }

        struct LocalReceiver<T> {
            channel: Rc<RefCell<LocalChannel<T>>>,
        }

        impl<T> Stream for LocalReceiver<T> {
            type Item = T;

            fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
                let mut channel = self.channel.borrow_mut();

                match channel.queue.pop_front() {
                    Some(item) => Poll::Ready(Some(item)),
                    None => {
                        if channel.closed {
                            Poll::Ready(None)
                        } else {
                            channel.waker = Some(cx.waker().clone());
                            Poll::Pending
                        }
                    }
                }
            }
        }

        struct LocalSender<T> {
            channel: Rc<RefCell<LocalChannel<T>>>,
        }

        impl<T> LocalSender<T> {
            fn send(&self, item: T) {
                let waker = {
                    let mut channel = self.channel.borrow_mut();
                    channel.queue.push_back(item);
                    channel.waker.take()
                };

                // Wake only after the `RefCell` borrow is released, so a waker that
                // re-enters the receiver cannot trigger a `BorrowMutError`.
                if let Some(waker) = waker {
                    waker.wake();
                }
            }
        }

        impl<T> Drop for LocalSender<T> {
            fn drop(&mut self) {
                let waker = {
                    let mut channel = self.channel.borrow_mut();
                    channel.closed = true;
                    channel.waker.take()
                };

                // See `send`: wake outside the borrow.
                if let Some(waker) = waker {
                    waker.wake();
                }
            }
        }

        fn local_channel<T>() -> (LocalSender<T>, LocalReceiver<T>) {
            let channel = Rc::new(RefCell::new(LocalChannel {
                queue: VecDeque::new(),
                waker: None,
                closed: false,
            }));

            (
                LocalSender {
                    channel: channel.clone(),
                },
                LocalReceiver { channel },
            )
        }

        let mut pool = LocalPool::new();

        let done = Rc::new(RefCell::new(false));
        let done2 = done.clone();

        pool.spawner()
            .spawn_local(async move {
                let (send1, receive1) = local_channel();
                let (send2, receive2) = local_channel();
                let (send3, receive3) = local_channel();

                let (count, ()) = futures::future::join(
                    async {
                        let s: MergeBounded<_> =
                            [receive1, receive2, receive3].into_iter().collect();
                        s.fold(0, |a, b| async move { a + b }).await
                    },
                    async {
                        for i in 1..=4 {
                            send1.send(i);
                            send2.send(i);
                            send3.send(i);
                        }
                        drop(send1);
                        drop(send2);
                        drop(send3);
                    },
                )
                .await;

                assert_eq!(count, 30);

                *done2.borrow_mut() = true;
            })
            .unwrap();

        // Every wake-up in this test comes from a task running inside `pool`, so once
        // `run_until_stalled` returns no further progress is possible. Asserting here
        // (rather than looping until `done`) turns a lost wake-up into a test failure
        // instead of an infinite hang.
        pool.run_until_stalled();
        assert!(
            *done.borrow(),
            "merged stream stalled before all channels were drained"
        );
    }
}