1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0

use super::Timestamp;
use core::{
    future::Future,
    pin::Pin,
    task::{Context, Poll},
    time::Duration,
};
use futures::ready;
use tokio::time::{sleep, Instant, Sleep};

/// `super::Timer` implementation backed by Tokio's time driver.
///
/// Timestamps produced and consumed by this timer are offsets from `start`,
/// the instant at which the timer was created.
#[derive(Debug)]
pub struct Timer {
    /// Origin of this timer's clock; every `Timestamp` is an offset from here.
    start: Instant,
    /// Deadline the `sleep` future is currently armed for, or `None` when nothing
    /// is armed (initially, and again after an armed deadline has fired).
    target: Option<Timestamp>,
    /// Pinned Tokio sleep future, reset whenever a new deadline is armed.
    sleep: Pin<Box<Sleep>>,
}

impl Timer {
    fn set_target(&mut self, target: Timestamp) {
        self.target = Some(target);
        let duration = unsafe { target.as_duration() };
        self.sleep.as_mut().reset(self.start + duration);
    }
}

impl Default for Timer {
    fn default() -> Self {
        Self {
            start: Instant::now(),
            target: None,
            sleep: Box::pin(sleep(Duration::from_secs(0))),
        }
    }
}

impl super::Timer for Timer {
    /// Returns the time elapsed since this timer was created, as a `Timestamp`.
    fn now(&self) -> Timestamp {
        let since_start = self.start.elapsed();
        // SAFETY: the duration is measured from `self.start`, the origin that
        // all of this timer's timestamps are relative to.
        // NOTE(review): confirm against `Timestamp::from_duration`'s safety contract.
        unsafe { Timestamp::from_duration(since_start) }
    }

    /// Polls until `target` (an offset from this timer's start) has been reached.
    ///
    /// The underlying `Sleep` is re-armed only when `target` differs from the
    /// deadline it is currently armed for, or when no deadline is armed. Once the
    /// deadline fires, the armed target is cleared.
    fn poll(&mut self, target: Timestamp, cx: &mut Context) -> Poll<()> {
        let already_armed = self.target == Some(target);
        if !already_armed {
            self.set_target(target);
        }

        // Returns `Poll::Pending` early (leaving the target armed) until the
        // deadline passes.
        ready!(self.sleep.as_mut().poll(cx));

        // Deadline reached: nothing is armed any more.
        self.target = None;
        Poll::Ready(())
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::timer::Timer as _;
    use futures_test::task::new_count_waker;

    // Drives the timer against a paused Tokio clock and checks that `poll` is
    // ready exactly when the requested timestamp is not after the current time.
    #[tokio::test(start_paused = true)]
    async fn timer_test() {
        let mut timer = Timer::default();
        // `poll` is called by hand, so a counting test waker stands in for a task.
        let (waker, _count) = new_count_waker();
        let mut cx = Context::from_waker(&waker);

        // Move the paused clock forward one second before taking the first reading.
        tokio::time::advance(Duration::from_secs(1)).await;

        let mut now = timer.now();

        // Five targets spaced one second apart, starting at `now`.
        let mut times = [now; 5];
        for (idx, time) in times.iter_mut().enumerate() {
            *time += Duration::from_secs(idx as _);
        }

        // A target equal to the current time is already due.
        assert!(timer.poll(now, &mut cx).is_ready());

        for _ in 0..times.len() {
            // Poll ascending then descending targets so that every branch of
            // `poll` (no armed target, changed target, unchanged target) is covered.
            for time in times.iter().chain(times.iter().rev()).copied() {
                // Each target is polled twice: when the first poll is still
                // pending, the second one takes the "target unchanged" path.
                assert_eq!(timer.poll(time, &mut cx).is_ready(), time <= now);
                assert_eq!(timer.poll(time, &mut cx).is_ready(), time <= now);
            }

            // Advance Tokio's paused clock and the expected `now` in lockstep.
            tokio::time::advance(Duration::from_secs(1)).await;
            now += Duration::from_secs(1);
        }
    }
}