Skip to main content

slim_session/
timer.rs

1// Copyright AGNTCY Contributors (https://github.com/agntcy)
2// SPDX-License-Identifier: Apache-2.0
3
4// Standard library imports
5use std::sync::Arc;
6
7// Third-party crates
8use async_trait::async_trait;
9use tokio::time::{self, Duration};
10use tokio_util::sync::CancellationToken;
11use tracing::trace;
12
13#[async_trait]
14pub trait TimerObserver {
15    async fn on_timeout(&self, timer_id: u32, timeouts: u32);
16    async fn on_failure(&self, timer_id: u32, timeouts: u32);
17    async fn on_stop(&self, timer_id: u32);
18}
19
/// Strategy used by [`Timer`] to choose the wait before each expiry.
#[derive(Clone, Debug)]
pub enum TimerType {
    /// Always wait `duration`.
    Constant = 0,
    /// Start from `duration` and double the wait after every expiry,
    /// optionally bounded by `max_duration`.
    Exponential = 1,
}
25
26#[derive(Debug)]
27pub struct Timer {
28    /// timer id
29    timer_id: u32,
30
31    /// timer type
32    timer_type: TimerType,
33
34    /// constant timer: timer duration
35    /// exponential timer: min timer duration. at every new timer the duration is computers as last_duration * 2
36    duration: Duration,
37
38    /// constant timer: None
39    /// exponential timer: maximum timer duration. once the duration reaches this time it will not be encreased anymore
40    max_duration: Option<Duration>,
41
42    /// if not None, it indicates the maximum number of retryes before call on_failure
43    /// if set to None the timer will go on forever unless cancelled
44    max_retries: Option<u32>,
45
46    /// token used to cancel the timer
47    cancellation_token: CancellationToken,
48}
49
50impl Timer {
51    pub fn new(
52        timer_id: u32,
53        timer_type: TimerType,
54        duration: Duration,
55        max_duration: Option<Duration>,
56        max_retries: Option<u32>,
57    ) -> Self {
58        Timer {
59            timer_id,
60            timer_type,
61            duration,
62            max_duration,
63            max_retries,
64            cancellation_token: CancellationToken::new(),
65        }
66    }
67
68    pub fn start<T: TimerObserver + Send + Sync + 'static>(&self, observer: Arc<T>) {
69        let timer_id = self.timer_id;
70        let timer_type = self.timer_type.clone();
71        let duration = self.duration;
72        let max_retries = self.max_retries;
73        let max_duration = self.max_duration;
74        let cancellation_token = self.cancellation_token.clone();
75
76        tokio::spawn(async move {
77            let mut retry = 0;
78            let mut timeouts = 0;
79            let mut last_duration = duration;
80
81            trace!(%timer_id, "timer started");
82            loop {
83                let timer_duration = match timer_type {
84                    TimerType::Constant => {
85                        trace!(
86                            %timer_id, next_ms = duration.as_millis(),
87                            "constant timer",
88                        );
89                        duration
90                    }
91                    TimerType::Exponential => {
92                        let mut d = duration;
93                        if timeouts != 0 {
94                            d = last_duration * 2;
95                        }
96                        match max_duration {
97                            None => {
98                                trace!(
99                                    %timer_id, next_ms = d.as_millis(),
100                                    "exponential timer",
101                                );
102                                last_duration = d;
103                                d
104                            }
105                            Some(max_d) => {
106                                if d > max_d {
107                                    trace!(
108                                        %timer_id,
109                                        next_ms = max_d.as_millis(),
110                                        "exponential timer (use max duration)",
111                                    );
112                                    last_duration = max_d;
113                                    max_d
114                                } else {
115                                    trace!(
116                                        %timer_id, next_ms = max_d.as_millis(),
117                                        "exponential timer",
118                                    );
119                                    last_duration = d;
120                                    d
121                                }
122                            }
123                        }
124                    }
125                };
126
127                let timer = time::sleep(timer_duration);
128                tokio::pin!(timer);
129
130                tokio::select! {
131                    _ = timer.as_mut() => {
132                        timeouts += 1;
133                        match max_retries {
134                            Some(max) => {
135                                if retry < max {
136                                    observer.on_timeout(timer_id, timeouts).await
137                                } else {
138                                    observer.on_failure(timer_id, timeouts).await;
139                                    break;
140                                }
141                            }
142                            None => observer.on_timeout(timer_id, timeouts).await
143                        }
144                        retry += 1;
145                    },
146                    _ = cancellation_token.cancelled() => {
147                        observer.on_stop(timer_id).await;
148                        break;
149                    },
150                }
151            }
152        });
153    }
154
155    pub fn stop(&mut self) {
156        self.cancellation_token.cancel();
157        self.cancellation_token = CancellationToken::new();
158    }
159
160    pub fn reset<T: TimerObserver + Send + Sync + 'static>(&mut self, observer: Arc<T>) {
161        self.stop();
162        self.start(observer);
163    }
164
165    pub fn get_id(&self) -> u32 {
166        self.timer_id
167    }
168}
169
170impl Drop for Timer {
171    fn drop(&mut self) {
172        self.cancellation_token.cancel();
173    }
174}
175
176// tests
#[cfg(test)]
mod tests {
    use tracing::debug;
    use tracing_test::traced_test;

    use super::*;

    /// Observer that only logs its callbacks; tests assert on the captured logs.
    struct Observer {
        id: u32,
    }

    #[async_trait]
    impl TimerObserver for Observer {
        async fn on_timeout(&self, timer_id: u32, timeouts: u32) {
            debug!(
                %timeouts, %timer_id,
                "timeout occurred, retry",
            );
        }

        async fn on_failure(&self, timer_id: u32, timeouts: u32) {
            debug!(
                %timeouts, %timer_id,
                "timeout occurred, stop retry",
            );
        }

        async fn on_stop(&self, timer_id: u32) {
            debug!(%timer_id, "timer cancelled");
        }
    }

    #[tokio::test]
    #[traced_test]
    async fn test_timer() {
        // Constant timer, 100ms, 3 retries: three timeouts, failure on the 4th expiry.
        let o = Arc::new(Observer { id: 10 });
        let t = Timer::new(
            o.id,
            TimerType::Constant,
            Duration::from_millis(100),
            None,
            Some(3),
        );

        t.start(o);

        time::sleep(Duration::from_millis(500)).await;

        for expected_msg in [
            "timeout occurred, retry timeouts=1 timer_id=10",
            "timeout occurred, retry timeouts=2 timer_id=10",
            "timeout occurred, retry timeouts=3 timer_id=10",
            "timeout occurred, stop retry timeouts=4 timer_id=10",
        ] {
            assert!(logs_contain(expected_msg), "missing log line: {}", expected_msg);
        }

        // Exponential timer starting at 100ms, capped at 400ms, 3 retries:
        // waits 100, 200, 400, then 800 is capped to 400; failure at ~1100ms.
        let o = Arc::new(Observer { id: 20 });
        let t = Timer::new(
            o.id,
            TimerType::Exponential,
            Duration::from_millis(100),
            Some(Duration::from_millis(400)),
            Some(3),
        );

        t.start(o);
        time::sleep(Duration::from_millis(1200)).await;

        for expected_msg in [
            "exponential timer timer_id=20 next_ms=400",
            "exponential timer (use max duration) timer_id=20 next_ms=400",
            "timeout occurred, stop retry timeouts=4 timer_id=20",
        ] {
            assert!(logs_contain(expected_msg), "missing log line: {}", expected_msg);
        }

        // Exponential timer without cap or retry limit: waits 100, 200, 400, 800
        // (fourth expiry at ~1500ms), then 1600; stopped at 2000ms while waiting.
        let o = Arc::new(Observer { id: 30 });
        let mut t = Timer::new(
            o.id,
            TimerType::Exponential,
            Duration::from_millis(100),
            None,
            None,
        );

        t.start(o);

        time::sleep(Duration::from_millis(2000)).await;
        t.stop();
        time::sleep(Duration::from_millis(500)).await;

        for expected_msg in [
            "exponential timer timer_id=30 next_ms=100",
            "exponential timer timer_id=30 next_ms=200",
            "exponential timer timer_id=30 next_ms=400",
            "exponential timer timer_id=30 next_ms=800",
            "exponential timer timer_id=30 next_ms=1600",
            "timer cancelled timer_id=30",
        ] {
            assert!(logs_contain(expected_msg), "missing log line: {}", expected_msg);
        }
    }

    #[tokio::test]
    #[traced_test]
    async fn test_timer_stop() {
        let o = Arc::new(Observer { id: 10 });

        let mut t = Timer::new(
            o.id,
            TimerType::Constant,
            Duration::from_millis(100),
            None,
            Some(5),
        );

        t.start(o);

        // Three expiries (100, 200, 300ms), then the timer is stopped.
        time::sleep(Duration::from_millis(350)).await;

        t.stop();

        time::sleep(Duration::from_millis(500)).await;

        for expected_msg in [
            "timeout occurred, retry timeouts=1 timer_id=10",
            "timeout occurred, retry timeouts=2 timer_id=10",
            "timeout occurred, retry timeouts=3 timer_id=10",
            "timer cancelled timer_id=10",
        ] {
            assert!(logs_contain(expected_msg), "missing log line: {}", expected_msg);
        }
    }

    #[tokio::test]
    #[traced_test]
    async fn test_multiple_timers() {
        let o1 = Arc::new(Observer { id: 1 });
        let o2 = Arc::new(Observer { id: 2 });
        let o3 = Arc::new(Observer { id: 3 });

        let mut t1 = Timer::new(
            o1.id,
            TimerType::Constant,
            Duration::from_millis(100),
            None,
            Some(5),
        );
        let mut t2 = Timer::new(
            o2.id,
            TimerType::Constant,
            Duration::from_millis(200),
            None,
            Some(5),
        );
        let mut t3 = Timer::new(
            o3.id,
            TimerType::Constant,
            Duration::from_millis(200),
            None,
            Some(5),
        );

        t1.start(o1);
        t2.start(o2);
        t3.start(o3);

        time::sleep(Duration::from_millis(700)).await;

        t1.stop();
        t2.stop();
        t3.stop();

        time::sleep(Duration::from_millis(500)).await;

        for expected_msg in [
            // 100ms
            "timeout occurred, retry timeouts=1 timer_id=1",
            // 200ms
            "timeout occurred, retry timeouts=2 timer_id=1",
            "timeout occurred, retry timeouts=1 timer_id=2",
            "timeout occurred, retry timeouts=1 timer_id=3",
            // 300ms
            "timeout occurred, retry timeouts=3 timer_id=1",
            // 400ms
            "timeout occurred, retry timeouts=4 timer_id=1",
            "timeout occurred, retry timeouts=2 timer_id=2",
            "timeout occurred, retry timeouts=2 timer_id=3",
            // 500ms
            "timeout occurred, retry timeouts=5 timer_id=1",
            // 600ms: the 6th expiry of timer 1 exceeds its 5 retries
            "timeout occurred, stop retry timeouts=6 timer_id=1",
            "timeout occurred, retry timeouts=3 timer_id=2",
            "timeout occurred, retry timeouts=3 timer_id=3",
            // 700ms: timers 2 and 3 are still running and get cancelled
            "timer cancelled timer_id=2",
            "timer cancelled timer_id=3",
        ] {
            assert!(logs_contain(expected_msg), "missing log line: {}", expected_msg);
        }
    }

    #[tokio::test]
    #[traced_test]
    async fn test_timer_reset() {
        // A reset restarts the timeout numbering, so the same message legitimately
        // shows up several times; count occurrences instead of checking presence.
        // The unique timer id keeps these counts independent of other tests.
        let assert_count = |msg: &str, expected: usize| {
            logs_assert(|lines: &[&str]| {
                let found = lines.iter().filter(|line| line.contains(msg)).count();
                if found == expected {
                    Ok(())
                } else {
                    Err(format!(
                        "expected {} log line(s) containing {:?}, found {}",
                        expected, msg, found
                    ))
                }
            });
        };

        let o = Arc::new(Observer { id: 40 });

        let mut t = Timer::new(
            o.id,
            TimerType::Constant,
            Duration::from_millis(100),
            None,
            Some(5),
        );

        // First run: three expiries before the reset.
        t.start(o.clone());
        time::sleep(Duration::from_millis(350)).await;
        assert_count("timeout occurred, retry timeouts=3 timer_id=40", 1);

        // Second run: the first one is cancelled and numbering starts again.
        t.reset(o.clone());
        time::sleep(Duration::from_millis(250)).await;
        assert_count("timer cancelled timer_id=40", 1);
        assert_count("timeout occurred, retry timeouts=2 timer_id=40", 2);

        // Third run: cancels the second one and runs until failure (~600ms).
        t.reset(o.clone());
        time::sleep(Duration::from_millis(700)).await;
        assert_count("timer cancelled timer_id=40", 2);
        assert_count("timeout occurred, stop retry timeouts=6 timer_id=40", 1);

        // Fourth run: nothing was running, so no extra cancellation; it fails again.
        t.reset(o);
        time::sleep(Duration::from_millis(700)).await;
        assert_count("timer cancelled timer_id=40", 2);
        assert_count("timeout occurred, stop retry timeouts=6 timer_id=40", 2);
    }

    #[tokio::test]
    #[traced_test]
    async fn test_timer_reset_without_start() {
        let o = Arc::new(Observer { id: 10 });

        let mut t = Timer::new(
            o.id,
            TimerType::Constant,
            Duration::from_millis(100),
            None,
            Some(5),
        );

        // Resetting a timer that was never started simply starts it.
        t.reset(o);

        time::sleep(Duration::from_millis(350)).await;

        let expected_msg = "timeout occurred, retry timeouts=3 timer_id=10";
        assert!(logs_contain(expected_msg), "missing log line: {}", expected_msg);
    }
}
464
465        let expected_msg = "timeout occurred, retry timeouts=3 timer_id=10";
466        assert!(logs_contain(expected_msg));
467    }
468}