futures_rx/stream_ext/
throttle.rs1use std::{
2 future::Future,
3 pin::Pin,
4 task::{Context, Poll},
5};
6
7use futures::{
8 stream::{Fuse, FusedStream},
9 Stream, StreamExt,
10};
11use pin_project_lite::pin_project;
12
/// Selects which items of each throttling interval a [`Throttle`] stream emits.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ThrottleConfig {
    /// Emit the item that opens an interval; items arriving while it runs are not emitted.
    Leading,
    /// Emit only the most recent item received while an interval ran, once it completes.
    /// The item that opened the interval is not emitted.
    Trailing,
    /// Emit both the item that opens an interval and the most recent item received
    /// while it ran (once it completes).
    All,
}
18
pin_project! {
    /// Stream adapter that rate-limits a source stream using interval futures
    /// produced by a user closure.
    ///
    /// Which items are emitted for each interval is controlled by [`ThrottleConfig`].
    #[must_use = "streams do nothing unless polled"]
    pub struct Throttle<S: Stream, Fut, F> {
        // Leading/trailing emission policy for each interval.
        config: ThrottleConfig,
        // Source stream, wrapped in `Fuse`; termination reporting is delegated to it.
        #[pin]
        stream: Fuse<S>,
        // Builds the interval future from (a reference to) the item that opens an interval.
        f: F,
        // Interval currently running; `None` when no interval is active.
        #[pin]
        current_interval: Option<Fut>,
        // Most recent item received while an interval was running.
        trailing: Option<S::Item>,
    }
}
32
33impl<S: Stream, Fut, F> Throttle<S, Fut, F> {
34 pub(crate) fn new(stream: S, f: F, config: ThrottleConfig) -> Self {
35 Self {
36 config,
37 stream: stream.fuse(),
38 f,
39 current_interval: None,
40 trailing: None,
41 }
42 }
43}
44
45impl<S: Stream, Fut, F> FusedStream for Throttle<S, Fut, F>
46where
47 F: for<'a> FnMut(&'a S::Item) -> Fut,
48 Fut: Future,
49{
50 fn is_terminated(&self) -> bool {
51 self.stream.is_terminated()
52 }
53}
54
55impl<S: Stream, Fut, F> Stream for Throttle<S, Fut, F>
56where
57 F: for<'a> FnMut(&'a S::Item) -> Fut,
58 Fut: Future,
59{
60 type Item = S::Item;
61
62 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
63 let mut this = self.project();
64 let is_in_interval = this
65 .current_interval
66 .as_mut()
67 .as_pin_mut()
68 .map(|it| it.poll(cx).is_pending())
69 .unwrap_or(false);
70
71 if !is_in_interval && this.current_interval.is_some() {
72 this.current_interval.set(None);
73
74 if matches!(this.config, ThrottleConfig::All | ThrottleConfig::Trailing) {
75 if let Some(trailing) = this.trailing.take() {
76 return Poll::Ready(Some(trailing));
77 }
78 }
79 }
80
81 match this.stream.poll_next(cx) {
82 Poll::Ready(Some(item)) => {
83 if is_in_interval {
84 this.trailing.replace(item);
85 } else {
86 this.current_interval.set(Some((this.f)(&item)));
87
88 if matches!(this.config, ThrottleConfig::All | ThrottleConfig::Leading) {
89 return Poll::Ready(Some(item));
90 }
91 }
92
93 cx.waker().wake_by_ref();
94
95 Poll::Pending
96 }
97 Poll::Ready(None) => Poll::Ready(None),
98 Poll::Pending => Poll::Pending,
99 }
100 }
101
102 fn size_hint(&self) -> (usize, Option<usize>) {
103 let (lower, upper) = self.stream.size_hint();
104 let lower = if lower > 0 { 1 } else { 0 };
107
108 (lower, upper)
109 }
110}
111
#[cfg(test)]
mod test {
    use futures::{executor::block_on, stream, Stream, StreamExt};
    use futures_time::{future::IntoFuture, time::Duration};

    use crate::RxExt;

    /// Yields `0..10`, sleeping 50ms before each item, and ends right after the last one.
    fn create_stream() -> impl Stream<Item = usize> {
        stream::unfold(0usize, |count| async move {
            if count >= 10 {
                return None;
            }

            Duration::from_millis(50).into_future().await;
            Some((count, count + 1))
        })
    }

    #[test]
    fn smoke() {
        // Leading edge only: the item that opens each 175ms interval.
        let leading: Vec<usize> = block_on(
            create_stream()
                .throttle(|_| Duration::from_millis(175).into_future())
                .collect(),
        );
        assert_eq!(leading, [0, 4, 8]);

        // Trailing edge only: the latest item seen while each interval ran. The item
        // still buffered when the source ends (9) is not emitted.
        let trailing: Vec<usize> = block_on(
            create_stream()
                .throttle_trailing(|_| Duration::from_millis(175).into_future())
                .collect(),
        );
        assert_eq!(trailing, [3, 7]);

        // Both edges.
        let all: Vec<usize> = block_on(
            create_stream()
                .throttle_all(|_| Duration::from_millis(175).into_future())
                .collect(),
        );
        assert_eq!(all, [0, 3, 4, 7, 8]);
    }
}