// proto_tower_util/timout_counter.rs

use std::future::Future;
use std::pin::Pin;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::RwLock as StdRwLock;
use std::sync::{Arc, Mutex, PoisonError};
use std::task::{Context, Poll, Waker};
use std::thread;
use std::time::{Duration, Instant};

/// Counts fixed-length ticks since a start instant and hands out [`NextTimeout`] futures.
///
/// Each call to [`TimeoutCounter::next_timeout`] works out which tick the counter is in; once
/// that tick number exceeds `total_count`, the returned future fails instead of waiting.
#[derive(Debug)]
pub struct TimeoutCounter {
    /// Length of one tick; each deadline handed out is one tick after the call that produced it.
    pub tick_rate: Duration,
    /// Number of ticks allowed before `next_timeout` reports failure.
    pub total_count: u64,
    /// Tick number computed by the most recent `next_timeout` call (0 after construction or
    /// `reset`).
    pub current_count: AtomicU64,
    /// Instant that ticks are counted from; set on construction and by `reset`.
    pub start_time: StdRwLock<Instant>,
}

/// Either a number of ticks or a span of time.
///
/// [`TimeoutCounter::new`] expects one of each: a count for one argument and a duration for the
/// other.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CountOrDuration {
    /// A number of ticks.
    Count(u64),
    /// A length of time.
    Duration(Duration),
}

20impl TimeoutCounter {
21    /// TODO this can actually be a fixed timeout with default of 1 increments (assuming the future gets cancelled)
22    /// Then the second part would always be duration and the tick increment would be num or duration
23    pub fn new(interval: CountOrDuration, timeout: CountOrDuration) -> Self {
24        match (interval, timeout) {
25            (CountOrDuration::Count(count), CountOrDuration::Duration(timeout)) => {
26                let tick_rate = Duration::from_millis((timeout.as_millis() / count as u128) as u64);
27                TimeoutCounter {
28                    tick_rate,
29                    total_count: count,
30                    current_count: AtomicU64::new(0),
31                    start_time: StdRwLock::new(Instant::now()),
32                }
33            }
34            (CountOrDuration::Duration(tick_rate), CountOrDuration::Count(count)) => TimeoutCounter {
35                tick_rate,
36                total_count: count,
37                current_count: AtomicU64::new(0),
38                start_time: StdRwLock::new(Instant::now()),
39            },
40            (_, _) => panic!("Invalid combination of interval and timeout"),
41        }
42    }
43    pub fn next_timeout(&self) -> NextTimeout {
44        // adjust count and timeout for current time
45        let start_time = self.start_time.read().unwrap();
46        let now = Instant::now();
47        let elapsed_time = now.checked_duration_since(*start_time).expect("Time went backwards");
48        // find tick
49        let next_count = (elapsed_time.as_millis() / self.tick_rate.as_millis() + 1) as u64;
50        self.current_count.store(next_count, std::sync::atomic::Ordering::Relaxed);
51        NextTimeout::new(now + self.tick_rate, next_count > self.total_count)
52    }
53
54    pub fn reset(&self) {
55        {
56            let mut lock = self.start_time.write().unwrap();
57            *lock = Instant::now();
58        }
59        self.current_count.store(0, std::sync::atomic::Ordering::Relaxed);
60    }
61}
62
/// Future returned by [`TimeoutCounter::next_timeout`].
///
/// Resolves to `Ok(())` once its deadline has passed, or to `Err(())` on the first poll when it
/// was created in the failed state.
///
/// Waiting does not spin: the first poll that finds the deadline still in the future starts a
/// detached background thread that sleeps until the deadline and then wakes the task. If the
/// future is dropped early, that thread still runs to the deadline and issues one stale (and
/// harmless) wake-up.
// NOTE(review): if this crate can depend on an async runtime, its timer (e.g.
// `tokio::time::sleep`) would avoid one OS thread per pending timeout — confirm and consider it.
pub struct NextTimeout {
    when: Instant,
    fail: bool,
    // Waker shared with the timer thread; `None` until a timer thread has been started.
    waker: Option<Arc<Mutex<Waker>>>,
}

impl NextTimeout {
    /// Creates a future that completes at `when`, or fails immediately when `fail` is `true`.
    pub fn new(when: Instant, fail: bool) -> Self {
        NextTimeout {
            when,
            fail,
            waker: None,
        }
    }

    /// Spawns a detached thread that sleeps until `when`, then wakes whichever waker is stored
    /// in `shared` at that moment. Returns `false` if the OS refused to create the thread.
    fn spawn_timer(when: Instant, shared: Arc<Mutex<Waker>>) -> bool {
        thread::Builder::new()
            .spawn(move || {
                // Loop on the clock so the wake-up is only sent once the deadline has really
                // passed; the re-poll it triggers is then certain to observe completion.
                loop {
                    let now = Instant::now();
                    if now >= when {
                        break;
                    }
                    thread::sleep(when - now);
                }
                shared
                    .lock()
                    .unwrap_or_else(PoisonError::into_inner)
                    .wake_by_ref();
            })
            .is_ok()
    }
}

impl Future for NextTimeout {
    type Output = Result<(), ()>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // Every field is `Unpin`, so the pinned future can be mutated directly.
        let this = self.get_mut();
        if this.fail {
            return Poll::Ready(Err(()));
        }
        if Instant::now() >= this.when {
            return Poll::Ready(Ok(()));
        }

        if let Some(shared) = &this.waker {
            // Timer already running: make sure it wakes the task polling us now, which may
            // differ from the one that polled first.
            let mut registered = shared.lock().unwrap_or_else(PoisonError::into_inner);
            if !registered.will_wake(cx.waker()) {
                *registered = cx.waker().clone();
            }
            drop(registered);
            // The timer may have fired with the old waker just before the swap. It only fires
            // after the deadline, so re-checking the clock closes that gap.
            return if Instant::now() >= this.when {
                Poll::Ready(Ok(()))
            } else {
                Poll::Pending
            };
        }

        let shared = Arc::new(Mutex::new(cx.waker().clone()));
        if Self::spawn_timer(this.when, Arc::clone(&shared)) {
            this.waker = Some(shared);
        } else {
            // No timer thread available: fall back to being polled again right away and retry
            // the spawn on the next poll.
            cx.waker().wake_by_ref();
        }
        Poll::Pending
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    // Two ticks, a reset, three more ticks, then the fourth tick after the reset exceeds the
    // budget of three. Five awaited ticks of 100 ms put the total at 500 ms or more.
    #[tokio::test]
    async fn test_timeout_counter() {
        let began = Instant::now();
        let counter = TimeoutCounter::new(
            CountOrDuration::Duration(Duration::from_millis(100)),
            CountOrDuration::Count(3),
        );

        for _ in 0..2 {
            counter.next_timeout().await.unwrap();
        }
        counter.reset();
        for _ in 0..3 {
            counter.next_timeout().await.unwrap();
        }
        assert!(counter.next_timeout().await.is_err());

        let waited = began.elapsed();
        assert!(waited.as_millis() >= 500, "{:?}", waited);
    }
}