proto_tower_util/
timout_counter.rs1use std::future::Future;
2use std::pin::Pin;
3use std::sync::atomic::AtomicU64;
4use std::sync::RwLock as StdRwLock;
5use std::task::{Context, Poll};
6use std::time::{Duration, Instant};
7
/// Tracks progress through a fixed number of equally spaced ticks.
///
/// Built with [`TimeoutCounter::new`]. Each `next_timeout` call records which tick the
/// counter is in, and `reset` restarts the count from "now".
#[derive(Debug)]
pub struct TimeoutCounter {
    /// Length of a single tick.
    pub tick_rate: Duration,
    /// Number of ticks allowed before `next_timeout` reports failure.
    pub total_count: u64,
    /// Tick index stored by the most recent `next_timeout` call; set to 0 by `reset`.
    pub current_count: AtomicU64,
    /// Instant from which elapsed ticks are measured; replaced by `reset`.
    pub start_time: StdRwLock<Instant>,
}
14
/// Either a number of ticks or a span of time.
///
/// [`TimeoutCounter::new`] takes one of each: which argument carries the count and which
/// carries the duration decides how the tick rate is derived.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CountOrDuration {
    /// A number of ticks.
    Count(u64),
    /// A span of time.
    Duration(Duration),
}
19
20impl TimeoutCounter {
21 pub fn new(interval: CountOrDuration, timeout: CountOrDuration) -> Self {
24 match (interval, timeout) {
25 (CountOrDuration::Count(count), CountOrDuration::Duration(timeout)) => {
26 let tick_rate = Duration::from_millis((timeout.as_millis() / count as u128) as u64);
27 TimeoutCounter {
28 tick_rate,
29 total_count: count,
30 current_count: AtomicU64::new(0),
31 start_time: StdRwLock::new(Instant::now()),
32 }
33 }
34 (CountOrDuration::Duration(tick_rate), CountOrDuration::Count(count)) => TimeoutCounter {
35 tick_rate,
36 total_count: count,
37 current_count: AtomicU64::new(0),
38 start_time: StdRwLock::new(Instant::now()),
39 },
40 (_, _) => panic!("Invalid combination of interval and timeout"),
41 }
42 }
43 pub fn next_timeout(&self) -> NextTimeout {
44 let start_time = self.start_time.read().unwrap();
46 let now = Instant::now();
47 let elapsed_time = now.checked_duration_since(*start_time).expect("Time went backwards");
48 let next_count = (elapsed_time.as_millis() / self.tick_rate.as_millis() + 1) as u64;
50 self.current_count.store(next_count, std::sync::atomic::Ordering::Relaxed);
51 NextTimeout::new(now + self.tick_rate, next_count > self.total_count)
52 }
53
54 pub fn reset(&self) {
55 {
56 let mut lock = self.start_time.write().unwrap();
57 *lock = Instant::now();
58 }
59 self.current_count.store(0, std::sync::atomic::Ordering::Relaxed);
60 }
61}
62
/// Future returned by `TimeoutCounter::next_timeout`.
///
/// Resolves to `Err(())` on the first poll when `fail` is set. Otherwise it resolves to
/// `Ok(())` once `when` has been reached.
///
/// While pending, a helper thread sleeps until `when` and then wakes the task. If the future
/// is dropped early, that thread still runs until the deadline and then issues a harmless
/// wake. NOTE(review): if tokio is a runtime (not just dev) dependency of this crate,
/// `tokio::time::sleep_until` would avoid one OS thread per pending timeout.
pub struct NextTimeout {
    /// Deadline after which the future completes with `Ok(())`.
    when: Instant,
    /// When set, the future completes with `Err(())` on the first poll.
    fail: bool,
    /// Waker slot shared with the helper thread; `None` until the first pending poll.
    waker: Option<std::sync::Arc<std::sync::Mutex<std::task::Waker>>>,
}

impl NextTimeout {
    /// Creates a future that completes at `when`, or fails immediately if `fail` is set.
    pub fn new(when: Instant, fail: bool) -> Self {
        NextTimeout {
            when,
            fail,
            waker: None,
        }
    }
}

impl Future for NextTimeout {
    type Output = Result<(), ()>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // Every field is `Unpin`, so `NextTimeout` is `Unpin` and safe to access mutably.
        let this = self.get_mut();
        if this.fail {
            return Poll::Ready(Err(()));
        }

        if let Some(slot) = &this.waker {
            // Check the deadline while holding the lock. The helper thread only takes this
            // lock after `when` has passed, so if the deadline has not passed yet, the waker
            // stored here is the one it will wake.
            let mut stored = slot
                .lock()
                .unwrap_or_else(std::sync::PoisonError::into_inner);
            if Instant::now() >= this.when {
                return Poll::Ready(Ok(()));
            }
            if !stored.will_wake(cx.waker()) {
                *stored = cx.waker().clone();
            }
            return Poll::Pending;
        }

        if Instant::now() >= this.when {
            return Poll::Ready(Ok(()));
        }

        // First pending poll. Calling `wake_by_ref()` here and returning `Pending` would make
        // the executor re-poll in a tight loop until the deadline (a busy-wait). Instead,
        // start a helper thread that wakes the task once `when` has been reached.
        let slot = std::sync::Arc::new(std::sync::Mutex::new(cx.waker().clone()));
        let timer_slot = std::sync::Arc::clone(&slot);
        let when = this.when;
        let spawned = std::thread::Builder::new().spawn(move || {
            // Loop in case `sleep` returns before the deadline.
            loop {
                let now = Instant::now();
                if now >= when {
                    break;
                }
                std::thread::sleep(when - now);
            }
            timer_slot
                .lock()
                .unwrap_or_else(std::sync::PoisonError::into_inner)
                .wake_by_ref();
        });
        match spawned {
            Ok(_) => this.waker = Some(slot),
            // No thread available: fall back to requesting an immediate re-poll and retry
            // starting the helper on the next poll.
            Err(_) => cx.waker().wake_by_ref(),
        }
        Poll::Pending
    }
}
89
#[cfg(test)]
mod tests {
    use super::*;

    /// Drives a 100 ms x 3-tick counter through a reset. It checks that the tick budget
    /// restarts, that the fourth tick after the reset fails, and that all the waiting took at
    /// least five ticks.
    #[tokio::test]
    async fn test_timeout_counter() {
        let began = Instant::now();
        let counter = TimeoutCounter::new(
            CountOrDuration::Duration(Duration::from_millis(100)),
            CountOrDuration::Count(3),
        );

        // Two ticks before the reset.
        for _ in 0..2 {
            counter.next_timeout().await.unwrap();
        }
        counter.reset();

        // After the reset the full budget is available again: three ticks succeed...
        for _ in 0..3 {
            counter.next_timeout().await.unwrap();
        }
        // ...and the next one is past `total_count`.
        assert!(counter.next_timeout().await.is_err());

        let waited = began.elapsed();
        assert!(waited.as_millis() >= 500, "{:?}", waited);
    }
}