little_stomper/asynchronous/
delayable_stream.rs

1use std::{pin::Pin, task::Poll, time::Duration};
2
3use futures::{future::pending, Future, FutureExt, Stream};
4use tokio::{
5    select,
6    sync::mpsc::{self, UnboundedReceiver, UnboundedSender},
7    task::JoinHandle,
8    time::{sleep, Instant},
9};
10
11use crate::error::StomperError;
12
/// A request sent from a `ResettableTimerResetter` to the timer it controls.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ResettableTimerCommand {
    /// Restart the current period from now, keeping its length.
    Reset,
    /// Replace the period with the given one, restarting it from now.
    ChangePeriod(Duration),
}
17
18/// Provides the API for resetting or changing the period of it's associated ResettableTimer
19pub struct ResettableTimerResetter {
20    sender: UnboundedSender<ResettableTimerCommand>,
21}
22
23impl ResettableTimerResetter {
24    /// Resets the timer, so that the next time the timer emits will be `now + period`
25    pub fn reset(&self) -> Result<(), StomperError> {
26        self.sender
27            .send(ResettableTimerCommand::Reset)
28            .map(|_| ())
29            .map_err(|_| StomperError::new("Error resetting stream"))
30    }
31
32    /// Changes the period of the timer.
33    ///
34    /// This will also reset the timer, so that the next emission will be at `now + new_period`
35    pub fn change_period(&self, new_period: Duration) -> Result<(), StomperError> {
36        self.sender
37            .send(ResettableTimerCommand::ChangePeriod(new_period))
38            .map(|_| ())
39            .map_err(|_| StomperError::new("Error updating stream period"))
40    }
41}
42
43/// A timer, implemented as as a stream of `unit` emissions, emitting each time `period` has elapsed.
44///
45/// The timer can be reset, so that the period will begin anew from `now`, and the period can be changed, which also causes a reset.
46pub struct ResettableTimer {
47    period: Duration,
48    receiver: Option<UnboundedReceiver<ResettableTimerCommand>>,
49    task: Option<JoinHandle<StreamState>>,
50}
51
52#[derive(Debug)]
53enum StreamState {
54    /// Indicates whether the stream has fired, and what the new period is
55    Fired(JoinHandle<StreamState>),
56}
57
58impl ResettableTimer {
59    /// Creates a new timer with the given `period`. Note that the timer will not start until it is first polled.
60    ///
61    /// Both the timer and the resetter allowing the timer to be reset or modified are returned
62    pub fn create(period: Duration) -> (Self, ResettableTimerResetter) {
63        let (sender, receiver) = mpsc::unbounded_channel();
64        (
65            ResettableTimer {
66                period,
67                receiver: Some(receiver),
68                task: None,
69            },
70            ResettableTimerResetter { sender },
71        )
72    }
73
74    /// Creates a new timer with an infinite `period`. This is useful when a timer is required prior to knowledge about
75    /// what its period should be.
76    pub fn default() -> (Self, ResettableTimerResetter) {
77        Self::create(Duration::from_millis(0))
78    }
79
80    fn create_task_no_receiver(
81        period: Duration,
82    ) -> Pin<Box<dyn Future<Output = StreamState> + Send>> {
83        // The period can never be changed, the sleep never reset; so just sleep
84        sleep(period)
85            .map(move |_| {
86                StreamState::Fired(tokio::task::spawn(
87                    ResettableTimer::create_task_no_receiver(period).boxed(),
88                ))
89            })
90            .boxed()
91    }
92
93    fn create_task_with_receiver(
94        period: Duration,
95        receiver: UnboundedReceiver<ResettableTimerCommand>,
96    ) -> Pin<Box<dyn Future<Output = StreamState> + Send>> {
97        async move {
98            let period = period;
99            let mut receiver = receiver;
100            let mut sleep = Box::pin(sleep(period));
101
102            let receive = receiver.recv();
103
104            // reset the timer; not strictly necessary first iteration, but not harm
105            sleep.as_mut().reset(Instant::now() + period);
106
107            let command_to_new_period = |period, command| {
108                if let ResettableTimerCommand::ChangePeriod(new_period) = command {
109                    new_period
110                } else {
111                    period
112                }
113            };
114
115            // If the period is not set to a positive number, wait only for events that change it
116            if period.as_millis() == 0 {
117                match receive.await {
118                    None => pending::<StreamState>().await, // sender was dropped, so no new event will be received or emitted; wait indefinitely
119                    Some(command) => {
120                        ResettableTimer::create_task_with_receiver(
121                            command_to_new_period(period, command),
122                            receiver,
123                        )
124                        .await
125                    }
126                }
127            } else {
128                select! {
129                    _ = &mut sleep => {
130                        StreamState::Fired(tokio::task::spawn(ResettableTimer::create_task_with_receiver(period, receiver).boxed()))
131                    }
132
133                    received = receive => match received {
134                        None => ResettableTimer::create_task_no_receiver(period).await, // wait again, but this time without being able to reset
135                        Some(command) => ResettableTimer::create_task_with_receiver(command_to_new_period(period, command), receiver).await,
136                    }
137                }
138            }
139        }.boxed()
140    }
141
142    fn poll_existing_task(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Option<()>> {
143        self.task
144            .as_mut()
145            .unwrap()
146            .poll_unpin(cx) // Poll the task
147            .map(|state| {
148                // Transform the result appropriately...
149                match state {
150                    Ok(StreamState::Fired(new_task)) => {
151                        // ... updating our own state along the way
152                        self.task.replace(new_task);
153                        Some(())
154                    }
155                    _ => None,
156                }
157            })
158    }
159
160    fn initialise(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Option<()>> {
161        let receiver = self
162            .receiver
163            .take()
164            .expect("Bad state: neither remote not receiver present");
165
166        let task = ResettableTimer::create_task_with_receiver(self.period, receiver).boxed();
167
168        self.task.replace(tokio::task::spawn(task));
169        // poll the task so that it will trigger the waker if it gets done
170        self.task.as_mut().unwrap().poll_unpin(cx).map(|_| None)
171    }
172}
173
174impl Drop for ResettableTimer {
175    fn drop(&mut self) {
176        if let Some(task) = self.task.take() {
177            task.abort()
178        }
179    }
180}
181
182impl Stream for ResettableTimer {
183    type Item = ();
184
185    fn poll_next(
186        self: Pin<&mut Self>,
187        cx: &mut std::task::Context<'_>,
188    ) -> Poll<Option<Self::Item>> {
189        let stream = self.get_mut();
190
191        if stream.task.is_some() {
192            stream.poll_existing_task(cx)
193        } else {
194            stream.initialise(cx)
195        }
196    }
197}
198
#[cfg(test)]
mod test {

    use futures::{FutureExt, StreamExt};
    use tokio::{
        task::yield_now,
        time::{pause, resume, sleep},
    };

    use super::*;

    /// Whether the timer had an emission ready when it was checked.
    #[derive(Debug, PartialEq, Eq)]
    enum State {
        Fired,
        NotFired,
    }

    /// Sleeps for `millis` with Tokio time paused, then checks without blocking whether `stream`
    /// has an item ready.
    async fn sleep_and_wait(stream: &mut ResettableTimer, millis: u64) -> State {
        pause();
        sleep(Duration::from_millis(millis)).await;
        resume();

        match stream.next().now_or_never() {
            None => State::NotFired,
            _ => State::Fired,
        }
    }

    #[tokio::test]
    async fn it_does_not_fire_if_not_elapsed() {
        let (mut stream, _) = ResettableTimer::create(Duration::from_millis(5000));

        delay_expecting_not_fired(&mut stream, 4800).await;
    }

    #[tokio::test]
    async fn it_fires_if_elapsed() {
        let (mut stream, _) = ResettableTimer::create(Duration::from_millis(5000));

        stream.next().now_or_never(); // Kick off the timer.

        delay_expecting_fired(&mut stream, 6000).await;
    }

    #[tokio::test]
    async fn it_keeps_firing() {
        let (mut stream, _) = ResettableTimer::create(Duration::from_millis(500));
        stream.next().now_or_never(); // Kick off the timer.

        delay_expecting_fired(&mut stream, 600).await;
        delay_expecting_fired(&mut stream, 500).await;
        delay_expecting_fired(&mut stream, 500).await;
        delay_expecting_fired(&mut stream, 500).await;
        delay_expecting_fired(&mut stream, 500).await;
        delay_expecting_not_fired(&mut stream, 300).await;
        delay_expecting_fired(&mut stream, 300).await;
    }

    async fn delay_expecting_fired(stream: &mut ResettableTimer, millis: u64) {
        assert_eq!(State::Fired, sleep_and_wait(stream, millis).await);
    }
    async fn delay_expecting_not_fired(stream: &mut ResettableTimer, millis: u64) {
        assert_eq!(State::NotFired, sleep_and_wait(stream, millis).await);
    }

    #[tokio::test]
    async fn it_fires_later_if_reset() {
        let (mut stream, resetter) = ResettableTimer::create(Duration::from_millis(5000));
        stream.next().now_or_never(); // Kick off the timer.

        delay_expecting_not_fired(&mut stream, 4000).await;
        resetter.reset().expect("Unexpected error");
        yield_now().await;
        delay_expecting_not_fired(&mut stream, 2000).await;
        delay_expecting_fired(&mut stream, 3050).await;
    }

    #[tokio::test]
    async fn it_stays_on_new_schedule_after_reset() {
        let (mut stream, resetter) = ResettableTimer::create(Duration::from_millis(5000));

        delay_expecting_not_fired(&mut stream, 4000).await;
        resetter.reset().expect("Unexpected error");
        yield_now().await;
        delay_expecting_not_fired(&mut stream, 2000).await;
        delay_expecting_fired(&mut stream, 3050).await;
        delay_expecting_not_fired(&mut stream, 4000).await;
        delay_expecting_fired(&mut stream, 2000).await;
        delay_expecting_fired(&mut stream, 5000).await;
        delay_expecting_fired(&mut stream, 5000).await;
    }

    #[tokio::test]
    async fn it_changes_period_and_resets() {
        let (mut stream, resetter) = ResettableTimer::create(Duration::from_millis(5000));
        stream.next().now_or_never(); // Kick off the timer.

        delay_expecting_fired(&mut stream, 6000).await;

        resetter
            .change_period(Duration::from_millis(7000))
            .expect("Unexpected Error");

        // Nothing after original period would have passed
        delay_expecting_not_fired(&mut stream, 4100).await;

        // Nothing after new period would have passed from first fired
        delay_expecting_not_fired(&mut stream, 2000).await;

        // Fired after new period duration from reset
        delay_expecting_fired(&mut stream, 1000).await;

        // New period is maintained
        delay_expecting_not_fired(&mut stream, 5000).await;
        delay_expecting_fired(&mut stream, 2000).await;
    }

    #[tokio::test]
    async fn it_only_fires_once_a_period_is_set_on_a_default_timer() {
        let (mut stream, resetter) = ResettableTimer::default();
        stream.next().now_or_never(); // Kick off the timer.

        // A zero period never fires, however long we wait.
        delay_expecting_not_fired(&mut stream, 10_000).await;

        resetter
            .change_period(Duration::from_millis(1000))
            .expect("Unexpected error");
        yield_now().await;

        // The new period is measured from the change.
        delay_expecting_not_fired(&mut stream, 500).await;
        delay_expecting_fired(&mut stream, 600).await;
    }

    #[tokio::test]
    async fn it_ends_task_when_dropped() {
        let (mut stream, resetter) = ResettableTimer::create(Duration::from_millis(5000));
        stream.next().now_or_never(); // Kick off the timer

        drop(stream);

        yield_now().await;

        resetter
            .reset()
            .expect_err("Should be an error because the other end is no longer listening");
    }
}