// debounced/stream.rs

1use std::pin::Pin;
2use std::task::{Context, Poll};
3use std::time::Duration;
4
5use futures_util::{FutureExt, Stream, StreamExt};
6
7use super::{delayed, Delayed};
8
9/// Stream that delays its items for a given duration and only yields the most
10/// recent item afterwards.
11///
12/// ```rust
13/// # use std::time::{Duration, Instant};
14/// # use futures_util::{SinkExt, StreamExt};
15/// # tokio_test::block_on(async {
16/// use debounced::Debounced;
17///
18/// # let start = Instant::now();
19/// let (mut sender, receiver) = futures_channel::mpsc::channel(1024);
20/// let mut debounced = Debounced::new(receiver, Duration::from_secs(1));
21/// sender.send(21).await;
22/// sender.send(42).await;
23/// assert_eq!(debounced.next().await, Some(42));
24/// assert_eq!(start.elapsed().as_secs(), 1);
25/// std::mem::drop(sender);
26/// assert_eq!(debounced.next().await, None);
27/// # })
28pub struct Debounced<S>
29where
30    S: Stream,
31{
32    stream: S,
33    delay: Duration,
34    pending: Option<Delayed<S::Item>>,
35}
36
37impl<S> Debounced<S>
38where
39    S: Stream + Unpin,
40{
41    /// Returns a new stream that delays its items for a given duration and only
42    /// yields the most recent item afterwards.
43    pub fn new(stream: S, delay: Duration) -> Debounced<S> {
44        Debounced {
45            stream,
46            delay,
47            pending: None,
48        }
49    }
50}
51
52impl<S> Stream for Debounced<S>
53where
54    S: Stream + Unpin,
55{
56    type Item = S::Item;
57
58    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
59        while let Poll::Ready(next) = self.stream.poll_next_unpin(cx) {
60            match next {
61                Some(next) => self.pending = Some(delayed(next, self.delay)),
62                None => {
63                    if self.pending.is_none() {
64                        return Poll::Ready(None);
65                    }
66                    break;
67                }
68            }
69        }
70
71        match self.pending.as_mut() {
72            Some(pending) => match pending.poll_unpin(cx) {
73                Poll::Ready(value) => {
74                    let _ = self.pending.take();
75                    Poll::Ready(Some(value))
76                }
77                Poll::Pending => Poll::Pending,
78            },
79            None => Poll::Pending,
80        }
81    }
82}
83
84/// Returns a new stream that delays its items for a given duration and only
85/// yields the most recent item afterwards.
86///
87/// ```rust
88/// # use std::time::{Duration, Instant};
89/// # use futures_util::{SinkExt, StreamExt};
90/// # tokio_test::block_on(async {
91/// use debounced::debounced;
92///
93/// # let start = Instant::now();
94/// let (mut sender, receiver) = futures_channel::mpsc::channel(1024);
95/// let mut debounced = debounced(receiver, Duration::from_secs(1));
96/// sender.send(21).await;
97/// sender.send(42).await;
98/// assert_eq!(debounced.next().await, Some(42));
99/// assert_eq!(start.elapsed().as_secs(), 1);
100/// std::mem::drop(sender);
101/// assert_eq!(debounced.next().await, None);
102/// # })
103pub fn debounced<S>(stream: S, delay: Duration) -> Debounced<S>
104where
105    S: Stream + Unpin,
106{
107    Debounced::new(stream, delay)
108}
109
#[cfg(test)]
mod tests {
    use std::sync::{Arc, Mutex};
    use std::time::{Duration, Instant};

    use futures_channel::mpsc::channel;
    use futures_util::future::join;
    use futures_util::{SinkExt, StreamExt};
    use tokio::time::sleep;

    use super::debounced;

    /// Two items sent back-to-back collapse into the latest one, which arrives
    /// after one delay; dropping the sender then ends the debounced stream.
    #[tokio::test]
    async fn test_debounce() {
        let start = Instant::now();
        let (mut sender, receiver) = channel(1024);
        let mut debounced = debounced(receiver, Duration::from_secs(1));
        sender.send(21).await.expect("receiver is still alive");
        sender.send(42).await.expect("receiver is still alive");
        assert_eq!(debounced.next().await, Some(42));
        assert_eq!(start.elapsed().as_secs(), 1);
        drop(sender);
        assert_eq!(debounced.next().await, None);
    }

    /// Item `i` is sent after an extra sleep of `23 * i` ms against a 100 ms
    /// delay. Items 0..=3 follow each other within 100 ms and collapse into
    /// item 4. Every later gap exceeds the delay, so each later item is
    /// delivered. The last item still arrives after the sender has finished.
    #[tokio::test]
    async fn test_debounce_order() {
        #[derive(Debug, PartialEq, Eq)]
        enum Message {
            Value(u64),
            SenderEnded,
            ReceiverEnded,
        }

        let (mut sender, receiver) = channel(1024);
        let mut receiver = debounced(receiver, Duration::from_millis(100));
        let messages = Arc::new(Mutex::new(vec![]));

        let produce = {
            let messages = Arc::clone(&messages);
            async move {
                for i in 0..10u64 {
                    sleep(Duration::from_millis(23 * i)).await;
                    sender.send(i).await.expect("receiver is still alive");
                }

                messages.lock().unwrap().push(Message::SenderEnded);
                // Closing the channel is what lets the receiver side finish.
                drop(sender);
            }
        };

        let consume = {
            let messages = Arc::clone(&messages);
            async move {
                while let Some(value) = receiver.next().await {
                    messages.lock().unwrap().push(Message::Value(value));
                }

                messages.lock().unwrap().push(Message::ReceiverEnded);
            }
        };

        join(produce, consume).await;

        assert_eq!(
            messages.lock().unwrap().as_slice(),
            &[
                Message::Value(4),
                Message::Value(5),
                Message::Value(6),
                Message::Value(7),
                Message::Value(8),
                Message::SenderEnded,
                Message::Value(9),
                Message::ReceiverEnded
            ]
        );
    }
}