1use std::sync::Arc;
6
7use async_trait::async_trait;
9use tokio::time::{self, Duration};
10use tokio_util::sync::CancellationToken;
11use tracing::trace;
12
13#[async_trait]
14pub trait TimerObserver {
15 async fn on_timeout(&self, timer_id: u32, timeouts: u32);
16 async fn on_failure(&self, timer_id: u32, timeouts: u32);
17 async fn on_stop(&self, timer_id: u32);
18}
19
/// How the delay between consecutive timeouts evolves.
///
/// The explicit discriminants are kept so `as` casts keep producing 0 and 1.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TimerType {
    /// Every timeout waits the configured duration.
    Constant = 0,
    /// The first timeout waits the configured duration. Each later delay doubles the
    /// previous one and is capped by the timer's maximum duration when one is set.
    Exponential = 1,
}
25
26#[derive(Debug)]
27pub struct Timer {
28 timer_id: u32,
30
31 timer_type: TimerType,
33
34 duration: Duration,
37
38 max_duration: Option<Duration>,
41
42 max_retries: Option<u32>,
45
46 cancellation_token: CancellationToken,
48}
49
50impl Timer {
51 pub fn new(
52 timer_id: u32,
53 timer_type: TimerType,
54 duration: Duration,
55 max_duration: Option<Duration>,
56 max_retries: Option<u32>,
57 ) -> Self {
58 Timer {
59 timer_id,
60 timer_type,
61 duration,
62 max_duration,
63 max_retries,
64 cancellation_token: CancellationToken::new(),
65 }
66 }
67
68 pub fn start<T: TimerObserver + Send + Sync + 'static>(&self, observer: Arc<T>) {
69 let timer_id = self.timer_id;
70 let timer_type = self.timer_type.clone();
71 let duration = self.duration;
72 let max_retries = self.max_retries;
73 let max_duration = self.max_duration;
74 let cancellation_token = self.cancellation_token.clone();
75
76 tokio::spawn(async move {
77 let mut retry = 0;
78 let mut timeouts = 0;
79 let mut last_duration = duration;
80
81 trace!(%timer_id, "timer started");
82 loop {
83 let timer_duration = match timer_type {
84 TimerType::Constant => {
85 trace!(
86 %timer_id, next_ms = duration.as_millis(),
87 "constant timer",
88 );
89 duration
90 }
91 TimerType::Exponential => {
92 let mut d = duration;
93 if timeouts != 0 {
94 d = last_duration * 2;
95 }
96 match max_duration {
97 None => {
98 trace!(
99 %timer_id, next_ms = d.as_millis(),
100 "exponential timer",
101 );
102 last_duration = d;
103 d
104 }
105 Some(max_d) => {
106 if d > max_d {
107 trace!(
108 %timer_id,
109 next_ms = max_d.as_millis(),
110 "exponential timer (use max duration)",
111 );
112 last_duration = max_d;
113 max_d
114 } else {
115 trace!(
116 %timer_id, next_ms = max_d.as_millis(),
117 "exponential timer",
118 );
119 last_duration = d;
120 d
121 }
122 }
123 }
124 }
125 };
126
127 let timer = time::sleep(timer_duration);
128 tokio::pin!(timer);
129
130 tokio::select! {
131 _ = timer.as_mut() => {
132 timeouts += 1;
133 match max_retries {
134 Some(max) => {
135 if retry < max {
136 observer.on_timeout(timer_id, timeouts).await
137 } else {
138 observer.on_failure(timer_id, timeouts).await;
139 break;
140 }
141 }
142 None => observer.on_timeout(timer_id, timeouts).await
143 }
144 retry += 1;
145 },
146 _ = cancellation_token.cancelled() => {
147 observer.on_stop(timer_id).await;
148 break;
149 },
150 }
151 }
152 });
153 }
154
155 pub fn stop(&mut self) {
156 self.cancellation_token.cancel();
157 self.cancellation_token = CancellationToken::new();
158 }
159
160 pub fn reset<T: TimerObserver + Send + Sync + 'static>(&mut self, observer: Arc<T>) {
161 self.stop();
162 self.start(observer);
163 }
164
165 pub fn get_id(&self) -> u32 {
166 self.timer_id
167 }
168}
169
170impl Drop for Timer {
171 fn drop(&mut self) {
172 self.cancellation_token.cancel();
173 }
174}
175
#[cfg(test)]
mod tests {
    use tracing::debug;
    use tracing_test::traced_test;

    use super::*;

    /// Observer that logs each callback with `debug!`; the tests then search the
    /// captured output with `logs_contain`.
    struct Observer {
        id: u32,
    }

    #[async_trait]
    impl TimerObserver for Observer {
        async fn on_timeout(&self, timer_id: u32, timeouts: u32) {
            debug!(%timeouts, %timer_id, "timeout occurred, retry");
        }

        async fn on_failure(&self, timer_id: u32, timeouts: u32) {
            debug!(%timeouts, %timer_id, "timeout occurred, stop retry");
        }

        async fn on_stop(&self, timer_id: u32) {
            debug!(%timer_id, "timer cancelled");
        }
    }

    #[tokio::test]
    #[traced_test]
    async fn test_timer() {
        // Constant 100 ms timer with three retries: fails on the 4th timeout (~400 ms).
        let constant_observer = Arc::new(Observer { id: 10 });
        let constant_timer = Timer::new(
            constant_observer.id,
            TimerType::Constant,
            Duration::from_millis(100),
            None,
            Some(3),
        );
        constant_timer.start(constant_observer);
        time::sleep(Duration::from_millis(500)).await;

        for n in 1..=3 {
            let expected = format!("timeout occurred, retry timeouts={} timer_id=10", n);
            assert!(logs_contain(&expected));
        }
        assert!(logs_contain("timeout occurred, stop retry timeouts=4 timer_id=10"));

        // Exponential timer capped at 400 ms: delays 100, 200, 400, then clamped to 400.
        let capped_observer = Arc::new(Observer { id: 20 });
        let capped_timer = Timer::new(
            capped_observer.id,
            TimerType::Exponential,
            Duration::from_millis(100),
            Some(Duration::from_millis(400)),
            Some(3),
        );
        capped_timer.start(capped_observer);
        time::sleep(Duration::from_millis(1200)).await;

        assert!(logs_contain("exponential timer timer_id=20 next_ms=400"));
        assert!(logs_contain(
            "exponential timer (use max duration) timer_id=20 next_ms=400"
        ));
        assert!(logs_contain("timeout occurred, stop retry timeouts=4 timer_id=20"));

        // Uncapped exponential timer without a retry limit, stopped after 2 s.
        let unbounded_observer = Arc::new(Observer { id: 30 });
        let mut unbounded_timer = Timer::new(
            unbounded_observer.id,
            TimerType::Exponential,
            Duration::from_millis(100),
            None,
            None,
        );
        unbounded_timer.start(unbounded_observer);
        time::sleep(Duration::from_millis(2000)).await;
        unbounded_timer.stop();
        time::sleep(Duration::from_millis(500)).await;

        for next_ms in [400, 800, 1600] {
            let expected = format!("exponential timer timer_id=30 next_ms={}", next_ms);
            assert!(logs_contain(&expected));
        }
        assert!(logs_contain("timer cancelled timer_id=30"));
    }

    #[tokio::test]
    #[traced_test]
    async fn test_timer_stop() {
        let observer = Arc::new(Observer { id: 10 });
        let mut timer = Timer::new(
            observer.id,
            TimerType::Constant,
            Duration::from_millis(100),
            None,
            Some(5),
        );

        timer.start(observer);
        // Three timeouts fire (at 100, 200 and 300 ms) before the stop.
        time::sleep(Duration::from_millis(350)).await;
        timer.stop();
        time::sleep(Duration::from_millis(500)).await;

        for n in 1..=3 {
            let expected = format!("timeout occurred, retry timeouts={} timer_id=10", n);
            assert!(logs_contain(&expected));
        }
        assert!(logs_contain("timer cancelled timer_id=10"));
    }

    #[tokio::test]
    #[traced_test]
    async fn test_multiple_timers() {
        let mut timers = Vec::new();
        for (id, period_ms) in [(1, 100), (2, 200), (3, 200)] {
            let observer = Arc::new(Observer { id });
            let timer = Timer::new(
                observer.id,
                TimerType::Constant,
                Duration::from_millis(period_ms),
                None,
                Some(5),
            );
            timer.start(observer);
            timers.push(timer);
        }

        time::sleep(Duration::from_millis(700)).await;
        for timer in &mut timers {
            timer.stop();
        }
        time::sleep(Duration::from_millis(500)).await;

        // Timer 1 (100 ms, 5 retries) runs out of retries at 600 ms.
        for n in 1..=5 {
            let expected = format!("timeout occurred, retry timeouts={} timer_id=1", n);
            assert!(logs_contain(&expected));
        }
        assert!(logs_contain("timeout occurred, stop retry timeouts=6 timer_id=1"));

        // Timers 2 and 3 (200 ms) fire three times before the stop at 700 ms.
        for id in [2, 3] {
            for n in 1..=3 {
                let expected =
                    format!("timeout occurred, retry timeouts={} timer_id={}", n, id);
                assert!(logs_contain(&expected));
            }
            let expected = format!("timer cancelled timer_id={}", id);
            assert!(logs_contain(&expected));
        }
    }

    #[tokio::test]
    #[traced_test]
    async fn test_timer_reset() {
        let observer = Arc::new(Observer { id: 10 });
        let mut timer = Timer::new(
            observer.id,
            TimerType::Constant,
            Duration::from_millis(100),
            None,
            Some(5),
        );

        timer.start(Arc::clone(&observer));
        time::sleep(Duration::from_millis(350)).await;
        assert!(logs_contain("timeout occurred, retry timeouts=3 timer_id=10"));

        timer.reset(Arc::clone(&observer));
        time::sleep(Duration::from_millis(250)).await;
        assert!(logs_contain("timeout occurred, retry timeouts=2 timer_id=10"));

        timer.reset(Arc::clone(&observer));
        time::sleep(Duration::from_millis(700)).await;
        assert!(logs_contain("timeout occurred, stop retry timeouts=6 timer_id=10"));

        // Resetting after the previous task has already failed starts a fresh run.
        timer.reset(observer);
        time::sleep(Duration::from_millis(700)).await;
        assert!(logs_contain("timeout occurred, stop retry timeouts=6 timer_id=10"));
    }

    #[tokio::test]
    #[traced_test]
    async fn test_timer_reset_without_start() {
        let observer = Arc::new(Observer { id: 10 });
        let mut timer = Timer::new(
            observer.id,
            TimerType::Constant,
            Duration::from_millis(100),
            None,
            Some(5),
        );

        // `reset` on a never-started timer behaves like `start`.
        timer.reset(observer);
        time::sleep(Duration::from_millis(350)).await;

        assert!(logs_contain("timeout occurred, retry timeouts=3 timer_id=10"));
    }
}
468}