little_stomper/asynchronous/
delayable_stream.rs1use std::{pin::Pin, task::Poll, time::Duration};
2
3use futures::{future::pending, Future, FutureExt, Stream};
4use tokio::{
5 select,
6 sync::mpsc::{self, UnboundedReceiver, UnboundedSender},
7 task::JoinHandle,
8 time::{sleep, Instant},
9};
10
11use crate::error::StomperError;
12
/// Instructions sent from a [`ResettableTimerResetter`] to its timer.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ResettableTimerCommand {
    /// Restart the countdown, keeping the current period.
    Reset,
    /// Replace the period with the given one and restart the countdown.
    ChangePeriod(Duration),
}
17
18pub struct ResettableTimerResetter {
20 sender: UnboundedSender<ResettableTimerCommand>,
21}
22
23impl ResettableTimerResetter {
24 pub fn reset(&self) -> Result<(), StomperError> {
26 self.sender
27 .send(ResettableTimerCommand::Reset)
28 .map(|_| ())
29 .map_err(|_| StomperError::new("Error resetting stream"))
30 }
31
32 pub fn change_period(&self, new_period: Duration) -> Result<(), StomperError> {
36 self.sender
37 .send(ResettableTimerCommand::ChangePeriod(new_period))
38 .map(|_| ())
39 .map_err(|_| StomperError::new("Error updating stream period"))
40 }
41}
42
43pub struct ResettableTimer {
47 period: Duration,
48 receiver: Option<UnboundedReceiver<ResettableTimerCommand>>,
49 task: Option<JoinHandle<StreamState>>,
50}
51
52#[derive(Debug)]
53enum StreamState {
54 Fired(JoinHandle<StreamState>),
56}
57
58impl ResettableTimer {
59 pub fn create(period: Duration) -> (Self, ResettableTimerResetter) {
63 let (sender, receiver) = mpsc::unbounded_channel();
64 (
65 ResettableTimer {
66 period,
67 receiver: Some(receiver),
68 task: None,
69 },
70 ResettableTimerResetter { sender },
71 )
72 }
73
74 pub fn default() -> (Self, ResettableTimerResetter) {
77 Self::create(Duration::from_millis(0))
78 }
79
80 fn create_task_no_receiver(
81 period: Duration,
82 ) -> Pin<Box<dyn Future<Output = StreamState> + Send>> {
83 sleep(period)
85 .map(move |_| {
86 StreamState::Fired(tokio::task::spawn(
87 ResettableTimer::create_task_no_receiver(period).boxed(),
88 ))
89 })
90 .boxed()
91 }
92
93 fn create_task_with_receiver(
94 period: Duration,
95 receiver: UnboundedReceiver<ResettableTimerCommand>,
96 ) -> Pin<Box<dyn Future<Output = StreamState> + Send>> {
97 async move {
98 let period = period;
99 let mut receiver = receiver;
100 let mut sleep = Box::pin(sleep(period));
101
102 let receive = receiver.recv();
103
104 sleep.as_mut().reset(Instant::now() + period);
106
107 let command_to_new_period = |period, command| {
108 if let ResettableTimerCommand::ChangePeriod(new_period) = command {
109 new_period
110 } else {
111 period
112 }
113 };
114
115 if period.as_millis() == 0 {
117 match receive.await {
118 None => pending::<StreamState>().await, Some(command) => {
120 ResettableTimer::create_task_with_receiver(
121 command_to_new_period(period, command),
122 receiver,
123 )
124 .await
125 }
126 }
127 } else {
128 select! {
129 _ = &mut sleep => {
130 StreamState::Fired(tokio::task::spawn(ResettableTimer::create_task_with_receiver(period, receiver).boxed()))
131 }
132
133 received = receive => match received {
134 None => ResettableTimer::create_task_no_receiver(period).await, Some(command) => ResettableTimer::create_task_with_receiver(command_to_new_period(period, command), receiver).await,
136 }
137 }
138 }
139 }.boxed()
140 }
141
142 fn poll_existing_task(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Option<()>> {
143 self.task
144 .as_mut()
145 .unwrap()
146 .poll_unpin(cx) .map(|state| {
148 match state {
150 Ok(StreamState::Fired(new_task)) => {
151 self.task.replace(new_task);
153 Some(())
154 }
155 _ => None,
156 }
157 })
158 }
159
160 fn initialise(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Option<()>> {
161 let receiver = self
162 .receiver
163 .take()
164 .expect("Bad state: neither remote not receiver present");
165
166 let task = ResettableTimer::create_task_with_receiver(self.period, receiver).boxed();
167
168 self.task.replace(tokio::task::spawn(task));
169 self.task.as_mut().unwrap().poll_unpin(cx).map(|_| None)
171 }
172}
173
174impl Drop for ResettableTimer {
175 fn drop(&mut self) {
176 if let Some(task) = self.task.take() {
177 task.abort()
178 }
179 }
180}
181
182impl Stream for ResettableTimer {
183 type Item = ();
184
185 fn poll_next(
186 self: Pin<&mut Self>,
187 cx: &mut std::task::Context<'_>,
188 ) -> Poll<Option<Self::Item>> {
189 let stream = self.get_mut();
190
191 if stream.task.is_some() {
192 stream.poll_existing_task(cx)
193 } else {
194 stream.initialise(cx)
195 }
196 }
197}
198
#[cfg(test)]
mod test {

    use futures::{FutureExt, StreamExt};
    use tokio::{
        task::yield_now,
        time::{pause, resume, sleep},
    };

    use super::*;

    /// Whether the stream had an item ready when it was checked.
    #[derive(Debug, PartialEq, Eq)]
    enum State {
        Fired,
        NotFired,
    }

    /// Lets `millis` of tokio time pass while time is paused, then checks
    /// without waiting whether `stream` has an item ready.
    ///
    /// An ended stream (`Some(None)`) panics instead of being mistaken for a
    /// firing; a timer stream is never expected to terminate in these tests.
    async fn sleep_and_wait(stream: &mut ResettableTimer, millis: u64) -> State {
        pause();
        sleep(Duration::from_millis(millis)).await;
        resume();

        match stream.next().now_or_never() {
            None => State::NotFired,
            Some(Some(())) => State::Fired,
            Some(None) => panic!("timer stream ended unexpectedly"),
        }
    }

    async fn delay_expecting_fired(stream: &mut ResettableTimer, millis: u64) {
        assert_eq!(State::Fired, sleep_and_wait(stream, millis).await);
    }

    async fn delay_expecting_not_fired(stream: &mut ResettableTimer, millis: u64) {
        assert_eq!(State::NotFired, sleep_and_wait(stream, millis).await);
    }

    #[tokio::test]
    async fn it_does_not_fire_if_not_elapsed() {
        let (mut stream, _) = ResettableTimer::create(Duration::from_millis(5000));

        delay_expecting_not_fired(&mut stream, 4800).await;
    }

    #[tokio::test]
    async fn it_fires_if_elapsed() {
        let (mut stream, _) = ResettableTimer::create(Duration::from_millis(5000));

        // The first poll starts the timer.
        stream.next().now_or_never();
        delay_expecting_fired(&mut stream, 6000).await;
    }

    #[tokio::test]
    async fn it_keeps_firing() {
        let (mut stream, _) = ResettableTimer::create(Duration::from_millis(500));
        stream.next().now_or_never();

        delay_expecting_fired(&mut stream, 600).await;
        delay_expecting_fired(&mut stream, 500).await;
        delay_expecting_fired(&mut stream, 500).await;
        delay_expecting_fired(&mut stream, 500).await;
        delay_expecting_fired(&mut stream, 500).await;
        delay_expecting_not_fired(&mut stream, 300).await;
        delay_expecting_fired(&mut stream, 300).await;
    }

    #[tokio::test]
    async fn it_fires_later_if_reset() {
        let (mut stream, resetter) = ResettableTimer::create(Duration::from_millis(5000));
        stream.next().now_or_never();

        delay_expecting_not_fired(&mut stream, 4000).await;
        resetter.reset().expect("Unexpected error");
        yield_now().await;
        delay_expecting_not_fired(&mut stream, 2000).await;
        delay_expecting_fired(&mut stream, 3050).await;
    }

    #[tokio::test]
    async fn it_stays_on_new_schedule_after_reset() {
        let (mut stream, resetter) = ResettableTimer::create(Duration::from_millis(5000));

        delay_expecting_not_fired(&mut stream, 4000).await;
        resetter.reset().expect("Unexpected error");
        yield_now().await;
        delay_expecting_not_fired(&mut stream, 2000).await;
        delay_expecting_fired(&mut stream, 3050).await;
        delay_expecting_not_fired(&mut stream, 4000).await;
        delay_expecting_fired(&mut stream, 2000).await;
        delay_expecting_fired(&mut stream, 5000).await;
        delay_expecting_fired(&mut stream, 5000).await;
    }

    #[tokio::test]
    async fn it_changes_period_and_resets() {
        let (mut stream, resetter) = ResettableTimer::create(Duration::from_millis(5000));
        stream.next().now_or_never();
        delay_expecting_fired(&mut stream, 6000).await;

        resetter
            .change_period(Duration::from_millis(7000))
            .expect("Unexpected Error");

        delay_expecting_not_fired(&mut stream, 4100).await;
        delay_expecting_not_fired(&mut stream, 2000).await;
        delay_expecting_fired(&mut stream, 1000).await;
        delay_expecting_not_fired(&mut stream, 5000).await;
        delay_expecting_fired(&mut stream, 2000).await;
    }

    #[tokio::test]
    async fn it_does_not_fire_with_zero_period_until_changed() {
        let (mut stream, resetter) = ResettableTimer::default();
        stream.next().now_or_never();

        delay_expecting_not_fired(&mut stream, 10_000).await;

        resetter
            .change_period(Duration::from_millis(500))
            .expect("Unexpected error");
        delay_expecting_fired(&mut stream, 600).await;
    }

    #[tokio::test]
    async fn it_keeps_firing_after_resetter_is_dropped() {
        let (mut stream, resetter) = ResettableTimer::create(Duration::from_millis(500));
        stream.next().now_or_never();
        drop(resetter);

        delay_expecting_fired(&mut stream, 600).await;
        delay_expecting_fired(&mut stream, 500).await;
        delay_expecting_not_fired(&mut stream, 300).await;
    }

    #[tokio::test]
    async fn it_ends_task_when_dropped() {
        let (mut stream, resetter) = ResettableTimer::create(Duration::from_millis(5000));
        stream.next().now_or_never();
        drop(stream);

        yield_now().await;

        resetter
            .reset()
            .expect_err("Should be an error because the other end is no longer listening");
    }
}