1use std::pin::Pin;
2use std::task::{Context, Poll};
3use std::time::Duration;
4
5use futures_util::{FutureExt, Stream, StreamExt};
6
7use super::{delayed, Delayed};
8
9pub struct Debounced<S>
29where
30 S: Stream,
31{
32 stream: S,
33 delay: Duration,
34 pending: Option<Delayed<S::Item>>,
35}
36
37impl<S> Debounced<S>
38where
39 S: Stream + Unpin,
40{
41 pub fn new(stream: S, delay: Duration) -> Debounced<S> {
44 Debounced {
45 stream,
46 delay,
47 pending: None,
48 }
49 }
50}
51
52impl<S> Stream for Debounced<S>
53where
54 S: Stream + Unpin,
55{
56 type Item = S::Item;
57
58 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
59 while let Poll::Ready(next) = self.stream.poll_next_unpin(cx) {
60 match next {
61 Some(next) => self.pending = Some(delayed(next, self.delay)),
62 None => {
63 if self.pending.is_none() {
64 return Poll::Ready(None);
65 }
66 break;
67 }
68 }
69 }
70
71 match self.pending.as_mut() {
72 Some(pending) => match pending.poll_unpin(cx) {
73 Poll::Ready(value) => {
74 let _ = self.pending.take();
75 Poll::Ready(Some(value))
76 }
77 Poll::Pending => Poll::Pending,
78 },
79 None => Poll::Pending,
80 }
81 }
82}
83
84pub fn debounced<S>(stream: S, delay: Duration) -> Debounced<S>
104where
105 S: Stream + Unpin,
106{
107 Debounced::new(stream, delay)
108}
109
#[cfg(test)]
mod tests {
    use std::sync::{Arc, Mutex};
    use std::time::{Duration, Instant};

    use futures_channel::mpsc::channel;
    use futures_util::future::join;
    use futures_util::{SinkExt, StreamExt};
    use tokio::time::sleep;

    use super::debounced;

    /// Two items sent back-to-back collapse into the later one, which arrives
    /// after one full delay. Dropping the sender then ends the stream.
    #[tokio::test]
    async fn test_debounce() {
        let start = Instant::now();
        let (mut sender, receiver) = channel(1024);
        let mut stream = debounced(receiver, Duration::from_secs(1));

        sender.send(21).await.expect("receiver is alive");
        sender.send(42).await.expect("receiver is alive");

        assert_eq!(stream.next().await, Some(42));
        assert_eq!(start.elapsed().as_secs(), 1);

        drop(sender);
        assert_eq!(stream.next().await, None);
    }

    /// Items arriving closer together than the delay are superseded. Once the
    /// gap exceeds the delay, each item is emitted. The last item is still
    /// delivered after the sender has finished.
    #[tokio::test]
    async fn test_debounce_order() {
        #[derive(Debug, PartialEq, Eq)]
        enum Message {
            Value(u64),
            SenderEnded,
            ReceiverEnded,
        }

        let (mut sender, receiver) = channel(1024);
        let mut receiver = debounced(receiver, Duration::from_millis(100));
        let messages = Arc::new(Mutex::new(vec![]));

        join(
            {
                let messages = Arc::clone(&messages);
                async move {
                    // The gap before item `i` is 23 * i ms. It first exceeds the
                    // 100 ms delay between items 4 and 5, so 0..=3 are superseded.
                    for i in 0..10u64 {
                        sleep(Duration::from_millis(23 * i)).await;
                        sender.send(i).await.expect("receiver is alive");
                    }

                    messages.lock().unwrap().push(Message::SenderEnded);
                    // `sender` is dropped here, which ends the receiver's input.
                }
            },
            {
                let messages = Arc::clone(&messages);

                async move {
                    while let Some(value) = receiver.next().await {
                        messages.lock().unwrap().push(Message::Value(value));
                    }

                    messages.lock().unwrap().push(Message::ReceiverEnded);
                }
            },
        )
        .await;

        assert_eq!(
            messages.lock().unwrap().as_slice(),
            &[
                Message::Value(4),
                Message::Value(5),
                Message::Value(6),
                Message::Value(7),
                Message::Value(8),
                Message::SenderEnded,
                Message::Value(9),
                Message::ReceiverEnded
            ]
        );
    }
}