futures_rx/stream_ext/
throttle.rs

1use std::{
2    future::Future,
3    pin::Pin,
4    task::{Context, Poll},
5};
6
7use futures::{
8    stream::{Fuse, FusedStream},
9    Stream, StreamExt,
10};
11use pin_project_lite::pin_project;
12
/// Selects which items a [`Throttle`] stream lets through.
///
/// An item that arrives while no interval is running starts a new interval;
/// items that arrive while an interval is running are suppressed, except that
/// the most recent one may be remembered as the "trailing" item.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ThrottleConfig {
    /// Emit only the item that starts an interval.
    Leading,
    /// Emit only the last item received during an interval, once that
    /// interval has completed. The item that started the interval is dropped.
    Trailing,
    /// Emit both the item that starts an interval and the last item received
    /// during it.
    All,
}
18
19pin_project! {
20    /// Stream for the [`throttle`](RxStreamExt::throttle) method.
21    #[must_use = "streams do nothing unless polled"]
22    pub struct Throttle<S: Stream, Fut, F> {
23        config: ThrottleConfig,
24        #[pin]
25        stream: Fuse<S>,
26        f: F,
27        #[pin]
28        current_interval: Option<Fut>,
29        trailing: Option<S::Item>,
30    }
31}
32
33impl<S: Stream, Fut, F> Throttle<S, Fut, F> {
34    pub(crate) fn new(stream: S, f: F, config: ThrottleConfig) -> Self {
35        Self {
36            config,
37            stream: stream.fuse(),
38            f,
39            current_interval: None,
40            trailing: None,
41        }
42    }
43}
44
45impl<S: Stream, Fut, F> FusedStream for Throttle<S, Fut, F>
46where
47    F: for<'a> FnMut(&'a S::Item) -> Fut,
48    Fut: Future,
49{
50    fn is_terminated(&self) -> bool {
51        self.stream.is_terminated()
52    }
53}
54
55impl<S: Stream, Fut, F> Stream for Throttle<S, Fut, F>
56where
57    F: for<'a> FnMut(&'a S::Item) -> Fut,
58    Fut: Future,
59{
60    type Item = S::Item;
61
62    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
63        let mut this = self.project();
64        let is_in_interval = this
65            .current_interval
66            .as_mut()
67            .as_pin_mut()
68            .map(|it| it.poll(cx).is_pending())
69            .unwrap_or(false);
70
71        if !is_in_interval && this.current_interval.is_some() {
72            this.current_interval.set(None);
73
74            if matches!(this.config, ThrottleConfig::All | ThrottleConfig::Trailing) {
75                if let Some(trailing) = this.trailing.take() {
76                    return Poll::Ready(Some(trailing));
77                }
78            }
79        }
80
81        match this.stream.poll_next(cx) {
82            Poll::Ready(Some(item)) => {
83                if is_in_interval {
84                    this.trailing.replace(item);
85                } else {
86                    this.current_interval.set(Some((this.f)(&item)));
87
88                    if matches!(this.config, ThrottleConfig::All | ThrottleConfig::Leading) {
89                        return Poll::Ready(Some(item));
90                    }
91                }
92
93                cx.waker().wake_by_ref();
94
95                Poll::Pending
96            }
97            Poll::Ready(None) => Poll::Ready(None),
98            Poll::Pending => Poll::Pending,
99        }
100    }
101
102    fn size_hint(&self) -> (usize, Option<usize>) {
103        let (lower, upper) = self.stream.size_hint();
104        // we know for sure that the final event (if any) will always emit,
105        // any other events depend on a time interval and must be discarded.
106        let lower = if lower > 0 { 1 } else { 0 };
107
108        (lower, upper)
109    }
110}
111
#[cfg(test)]
mod test {
    use futures::{executor::block_on, stream, Stream, StreamExt};
    use futures_time::{future::IntoFuture, time::Duration};

    use crate::RxExt;

    /// Runs the same ten-item source through each throttle flavour and checks
    /// which items get through.
    #[test]
    fn smoke() {
        // Default flavour.
        let throttled = block_on(
            create_stream()
                .throttle(|_| Duration::from_millis(175).into_future())
                .collect::<Vec<_>>(),
        );
        assert_eq!(throttled, [0, 4, 8]);

        // Trailing flavour.
        let throttled_trailing = block_on(
            create_stream()
                .throttle_trailing(|_| Duration::from_millis(175).into_future())
                .collect::<Vec<_>>(),
        );
        assert_eq!(throttled_trailing, [3, 7]);

        // Leading and trailing combined.
        let throttled_all = block_on(
            create_stream()
                .throttle_all(|_| Duration::from_millis(175).into_future())
                .collect::<Vec<_>>(),
        );
        assert_eq!(throttled_all, [0, 3, 4, 7, 8]);
    }

    /// Yields `0..10`, sleeping 50 ms before each item.
    fn create_stream() -> impl Stream<Item = usize> {
        stream::unfold(0, move |next| async move {
            if next >= 10 {
                return None;
            }

            Duration::from_millis(50).into_future().await;

            Some((next, next + 1))
        })
    }
}