futures_rx/stream_ext/
buffer.rs

1use std::{
2    collections::VecDeque,
3    future::Future,
4    pin::Pin,
5    task::{Context, Poll},
6};
7
8use futures::{
9    future::{select, Either},
10    stream::{Fuse, FusedStream},
11    FutureExt, Stream, StreamExt,
12};
13use pin_project_lite::pin_project;
14
pin_project! {
    /// Stream for the [`buffer`](RxStreamExt::buffer) method.
    ///
    /// Collects items from the wrapped stream into a [`VecDeque`] and emits the
    /// collected batch once the future produced by `f` for the most recent item
    /// resolves to `true`, or when the wrapped stream ends.
    #[must_use = "streams do nothing unless polled"]
    pub struct Buffer<S: Stream, Fut, F> {
        // Source stream, fused so that it can still be polled after it has ended.
        #[pin]
        stream: Fuse<S>,
        // Callback invoked with every new item and the batch length including
        // that item; the future it returns decides whether the batch is emitted.
        f: F,
        // Future returned by `f` for the most recently received item, while it
        // is still outstanding.
        #[pin]
        current_interval: Option<Fut>,
        // Items collected since the last emission; `None` while nothing is buffered.
        buffer: Option<VecDeque<S::Item>>,
    }
}
27
28impl<S: Stream, Fut, F> Buffer<S, Fut, F> {
29    pub(crate) fn new(stream: S, f: F) -> Self {
30        Self {
31            stream: stream.fuse(),
32            f,
33            current_interval: None,
34            buffer: None,
35        }
36    }
37}
38
39impl<S: Stream, Fut, F> FusedStream for Buffer<S, Fut, F>
40where
41    F: for<'a> FnMut(&'a S::Item, usize) -> Fut,
42    Fut: Future<Output = bool>,
43{
44    fn is_terminated(&self) -> bool {
45        self.stream.is_terminated()
46    }
47}
48
49impl<S: Stream, Fut, F> Stream for Buffer<S, Fut, F>
50where
51    F: for<'a> FnMut(&'a S::Item, usize) -> Fut,
52    Fut: Future<Output = bool>,
53{
54    type Item = VecDeque<S::Item>;
55
56    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
57        let mut this = self.project();
58
59        if let Some(interval) = this.current_interval.as_mut().as_pin_mut() {
60            match select(interval, this.stream.next()).poll_unpin(cx) {
61                Poll::Ready(it) => match it {
62                    Either::Left((it, _)) => {
63                        this.current_interval.set(None);
64
65                        if it {
66                            Poll::Ready(this.buffer.take())
67                        } else {
68                            cx.waker().wake_by_ref();
69
70                            Poll::Pending
71                        }
72                    }
73                    Either::Right((it, mut interval)) => match it {
74                        Some(item) => {
75                            interval.set((this.f)(
76                                &item,
77                                this.buffer.as_ref().map(|it| it.len()).unwrap_or_default() + 1,
78                            ));
79
80                            if let Some(it) = this.buffer.as_mut() {
81                                it.push_back(item);
82                            } else {
83                                this.buffer.replace(VecDeque::from_iter([item]));
84                            }
85
86                            cx.waker().wake_by_ref();
87
88                            Poll::Pending
89                        }
90                        None => Poll::Ready(this.buffer.take()),
91                    },
92                },
93                Poll::Pending => Poll::Pending,
94            }
95        } else {
96            match this.stream.poll_next(cx) {
97                Poll::Ready(Some(item)) => {
98                    this.current_interval.set(Some((this.f)(
99                        &item,
100                        this.buffer.as_ref().map(|it| it.len()).unwrap_or_default() + 1,
101                    )));
102
103                    if let Some(it) = this.buffer.as_mut() {
104                        it.push_back(item);
105                    } else {
106                        this.buffer.replace(VecDeque::from_iter([item]));
107                    }
108
109                    cx.waker().wake_by_ref();
110
111                    Poll::Pending
112                }
113                Poll::Ready(None) => Poll::Ready(this.buffer.take()),
114                Poll::Pending => Poll::Pending,
115            }
116        }
117    }
118
119    fn size_hint(&self) -> (usize, Option<usize>) {
120        let (lower, upper) = self.stream.size_hint();
121        // we know for sure that the final event (if any) will always emit,
122        // any other events depend on a time interval and must be discarded.
123        let lower = if lower > 0 { 1 } else { 0 };
124
125        (lower, upper)
126    }
127}
128
#[cfg(test)]
mod test {
    use futures::{executor::block_on, stream, StreamExt};

    use crate::RxExt;

    /// Nine items with a flush on every third item yield three batches of three.
    #[test]
    fn smoke() {
        let expected = vec![vec![0, 1, 2], vec![3, 4, 5], vec![6, 7, 8]];

        let all_events = block_on(
            stream::iter(0..=8)
                .buffer(|_, count| async move { count == 3 })
                .collect::<Vec<_>>(),
        );

        assert_eq!(all_events, expected);
    }
}