Skip to main content

slim_session/
timer_factory.rs

1// Copyright AGNTCY Contributors (https://github.com/agntcy)
2// SPDX-License-Identifier: Apache-2.0
3
4use std::{sync::Arc, time::Duration};
5
6use async_trait::async_trait;
7use slim_datapath::api::{EncodedName, ProtoSessionMessageType};
8use tokio::sync::mpsc::Sender;
9use tracing::debug;
10
11use crate::{
12    common::SessionMessage,
13    timer::{Timer, TimerObserver, TimerType},
14};
15
16struct ReliableTimerObserver {
17    tx: Sender<SessionMessage>,
18    message_type: ProtoSessionMessageType,
19    name: Option<EncodedName>,
20}
21
22#[async_trait]
23impl TimerObserver for ReliableTimerObserver {
24    async fn on_timeout(&self, message_id: u32, timeouts: u32) {
25        if let Err(e) = self
26            .tx
27            .send(SessionMessage::TimerTimeout {
28                message_id,
29                message_type: self.message_type,
30                name: self.name,
31                timeouts,
32            })
33            .await
34        {
35            // The session processing loop has already exited (session closed).
36            // The timer fired after the receiver was dropped; nothing to do.
37            debug!(%message_id, error = %e, "timer timeout: session already closed, dropping");
38        }
39    }
40
41    async fn on_failure(&self, message_id: u32, timeouts: u32) {
42        // remove the state for the lost message
43        if let Err(e) = self
44            .tx
45            .send(SessionMessage::TimerFailure {
46                message_id,
47                message_type: self.message_type,
48                name: self.name,
49                timeouts,
50            })
51            .await
52        {
53            // Same race: session closed before the failure notification arrived.
54            debug!(%message_id, error = %e, "timer failure: session already closed, dropping");
55        }
56    }
57
58    async fn on_stop(&self, message_id: u32) {
59        debug!(timer_id = %message_id, "timer stopped");
60    }
61}
62
63#[derive(Clone)]
64pub struct TimerSettings {
65    pub duration: Duration,
66    pub max_duration: Option<Duration>,
67    pub max_retries: Option<u32>,
68    pub timer_type: TimerType,
69}
70
71impl TimerSettings {
72    /// Create a new TimerSettings with the specified parameters
73    pub fn new(
74        duration: Duration,
75        max_duration: Option<Duration>,
76        max_retries: Option<u32>,
77        timer_type: TimerType,
78    ) -> Self {
79        Self {
80            duration,
81            max_duration,
82            max_retries,
83            timer_type,
84        }
85    }
86
87    /// Create a constant timer settings with the specified duration
88    pub fn constant(duration: Duration) -> Self {
89        Self {
90            duration,
91            max_duration: None,
92            max_retries: None,
93            timer_type: TimerType::Constant,
94        }
95    }
96
97    /// Create an exponential timer settings with the specified duration and max duration
98    pub fn exponential(initial_duration: Duration, max_duration: Option<Duration>) -> Self {
99        Self {
100            duration: initial_duration,
101            max_duration,
102            max_retries: None,
103            timer_type: TimerType::Exponential,
104        }
105    }
106
107    /// Set the maximum number of retries before failure
108    pub fn with_max_retries(mut self, max_retries: u32) -> Self {
109        self.max_retries = Some(max_retries);
110        self
111    }
112}
113
114pub struct TimerFactory {
115    //observer: Arc<ReliableTimerObserver>,
116    tx: Sender<SessionMessage>,
117    settings: TimerSettings,
118}
119
120impl TimerFactory {
121    pub fn new(settings: TimerSettings, tx: Sender<SessionMessage>) -> Self {
122        Self {
123            tx: tx.clone(),
124            settings,
125        }
126    }
127
128    pub fn create_timer(&self, id: u32) -> Timer {
129        Timer::new(
130            id,
131            self.settings.timer_type.clone(),
132            self.settings.duration,
133            self.settings.max_duration,
134            self.settings.max_retries,
135        )
136    }
137
138    pub fn create_and_start_timer(
139        &self,
140        id: u32,
141        message_type: ProtoSessionMessageType,
142        name: Option<EncodedName>,
143    ) -> Timer {
144        let t = Timer::new(
145            id,
146            self.settings.timer_type.clone(),
147            self.settings.duration,
148            self.settings.max_duration,
149            self.settings.max_retries,
150        );
151        self.start_timer(&t, message_type, name);
152        t
153    }
154
155    pub fn start_timer(
156        &self,
157        timer: &Timer,
158        message_type: ProtoSessionMessageType,
159        name: Option<EncodedName>,
160    ) {
161        // start timer
162        let observer = ReliableTimerObserver {
163            tx: self.tx.clone(),
164            message_type,
165            name,
166        };
167        timer.start(Arc::new(observer));
168    }
169}
170
#[cfg(test)]
mod tests {
    use slim_datapath::api::ProtoName;

    use super::*;
    use std::time::Duration;
    use tokio::sync::mpsc;
    use tokio::time::timeout;

    /// Builds the `EncodedName` shared by most tests.
    fn test_encoded_name() -> EncodedName {
        ProtoName::from_strings(["test", "org", "app"])
            .with_id(1)
            .name
            .unwrap()
    }

    /// Waits at most `millis` milliseconds for the next message on `rx`.
    ///
    /// Panics, mentioning `context`, if nothing arrives in time or the channel is closed.
    async fn recv_within(
        rx: &mut mpsc::Receiver<SessionMessage>,
        millis: u64,
        context: &str,
    ) -> SessionMessage {
        timeout(Duration::from_millis(millis), rx.recv())
            .await
            .unwrap_or_else(|_| panic!("{context}: no message within {millis}ms"))
            .unwrap_or_else(|| panic!("{context}: channel closed"))
    }

    /// Asserts that `message` is a `DiscoveryRequest` `TimerTimeout` carrying the
    /// expected id, timeout count and name.
    fn assert_timer_timeout(
        message: SessionMessage,
        expected_id: u32,
        expected_timeouts: u32,
        expected_name: Option<EncodedName>,
        context: &str,
    ) {
        match message {
            SessionMessage::TimerTimeout {
                message_id,
                message_type,
                name,
                timeouts,
            } => {
                assert_eq!(message_id, expected_id);
                assert_eq!(message_type, ProtoSessionMessageType::DiscoveryRequest);
                assert_eq!(timeouts, expected_timeouts);
                assert_eq!(name, expected_name);
            }
            _ => panic!("Expected TimerTimeout message ({context})"),
        }
    }

    #[tokio::test]
    async fn test_timer_factory_new() {
        // Arrange
        let (tx, _rx) = mpsc::channel(10);
        let settings =
            TimerSettings::new(Duration::from_millis(100), None, None, TimerType::Constant);

        // Act
        let factory = TimerFactory::new(settings, tx);

        // Assert
        assert_eq!(factory.settings.duration, Duration::from_millis(100));
        assert!(factory.settings.max_duration.is_none());
        assert!(factory.settings.max_retries.is_none());
        // `matches!` only yields a bool; it must be wrapped in `assert!` to check anything.
        assert!(matches!(factory.settings.timer_type, TimerType::Constant));
    }

    #[tokio::test]
    async fn test_create_and_start_timer() {
        // Arrange
        let (tx, mut rx) = mpsc::channel(10);
        let settings = TimerSettings::new(
            Duration::from_millis(50),
            None,
            Some(1), // Only 1 retry to make test faster
            TimerType::Constant,
        );
        let factory = TimerFactory::new(settings, tx);
        let timer_id = 123;
        let name = test_encoded_name();

        // Act - keep the timer bound so it stays alive while we wait.
        let _timer = factory.create_and_start_timer(
            timer_id,
            ProtoSessionMessageType::DiscoveryRequest,
            Some(name),
        );

        // Assert - we should receive a timeout message
        let message = recv_within(&mut rx, 200, "first timeout").await;
        assert_timer_timeout(message, timer_id, 1, Some(name), "create_and_start_timer");
    }

    #[tokio::test]
    async fn test_timer_timeout_with_constant_timer() {
        // Arrange
        let (tx, mut rx) = mpsc::channel(10);
        let settings = TimerSettings::new(
            Duration::from_millis(30),
            None,
            Some(2), // Allow 2 retries
            TimerType::Constant,
        );
        let factory = TimerFactory::new(settings, tx);
        let timer_id = 456;
        let name = test_encoded_name();

        // Act
        let timer = factory.create_timer(timer_id);
        factory.start_timer(
            &timer,
            ProtoSessionMessageType::DiscoveryRequest,
            Some(name),
        );

        // Assert - two successive timeouts with an increasing counter
        let first = recv_within(&mut rx, 100, "first timeout").await;
        assert_timer_timeout(first, timer_id, 1, Some(name), "first timeout");

        let second = recv_within(&mut rx, 100, "second timeout").await;
        assert_timer_timeout(second, timer_id, 2, Some(name), "second timeout");
    }

    #[tokio::test]
    async fn test_timer_failure_after_max_retries() {
        // Arrange
        let (tx, mut rx) = mpsc::channel(10);
        let settings = TimerSettings::new(
            Duration::from_millis(30),
            None,
            Some(1), // Only 1 retry, then failure
            TimerType::Constant,
        );
        let factory = TimerFactory::new(settings, tx);
        let timer_id = 789;
        let name = test_encoded_name();

        // Act
        let timer = factory.create_timer(timer_id);
        factory.start_timer(
            &timer,
            ProtoSessionMessageType::DiscoveryRequest,
            Some(name),
        );

        // Assert - we should receive a timeout followed by a failure
        let timeout_message = recv_within(&mut rx, 100, "timeout before failure").await;
        assert_timer_timeout(timeout_message, timer_id, 1, Some(name), "failure test");

        match recv_within(&mut rx, 100, "failure").await {
            SessionMessage::TimerFailure {
                message_id,
                message_type,
                name: received_name,
                timeouts,
            } => {
                assert_eq!(message_id, timer_id);
                assert_eq!(message_type, ProtoSessionMessageType::DiscoveryRequest);
                // With one retry allowed, the failure is reported on the second expiry.
                assert_eq!(timeouts, 2);
                assert_eq!(received_name, Some(name));
            }
            _ => panic!("Expected TimerFailure message"),
        }
    }

    #[tokio::test]
    async fn test_exponential_timer() {
        // Arrange
        let (tx, mut rx) = mpsc::channel(10);
        let settings = TimerSettings::new(
            Duration::from_millis(20),        // Start with 20ms
            Some(Duration::from_millis(100)), // Max 100ms
            Some(2),                          // Allow 2 retries before failure
            TimerType::Exponential,
        );
        let factory = TimerFactory::new(settings, tx);
        let timer_id = 999;
        let name = test_encoded_name();

        // Act
        let timer = factory.create_timer(timer_id);
        factory.start_timer(
            &timer,
            ProtoSessionMessageType::DiscoveryRequest,
            Some(name),
        );

        // Assert - two successive timeouts arrive within generous deadlines.
        // The interval growth itself is not measured here.
        let first = recv_within(&mut rx, 150, "exponential first timeout").await;
        assert_timer_timeout(first, timer_id, 1, Some(name), "exponential first timeout");

        let second = recv_within(&mut rx, 200, "exponential second timeout").await;
        assert_timer_timeout(second, timer_id, 2, Some(name), "exponential second timeout");
    }

    #[tokio::test]
    async fn test_timer_settings_with_all_options() {
        // Arrange
        let (tx, _rx) = mpsc::channel(10);
        let duration = Duration::from_millis(500);
        let max_duration = Some(Duration::from_secs(5));
        let max_retries = Some(10);
        let timer_type = TimerType::Exponential;

        let settings = TimerSettings::new(duration, max_duration, max_retries, timer_type);

        // Act
        let factory = TimerFactory::new(settings, tx);

        // Assert
        assert_eq!(factory.settings.duration, duration);
        assert_eq!(factory.settings.max_duration, max_duration);
        assert_eq!(factory.settings.max_retries, max_retries);
        assert!(matches!(factory.settings.timer_type, TimerType::Exponential));
    }

    #[tokio::test]
    async fn test_multiple_timers() {
        // Arrange
        let (tx, mut rx) = mpsc::channel(20);
        let settings = TimerSettings::new(
            Duration::from_millis(50),
            None,
            Some(1),
            TimerType::Constant,
        );
        let factory = TimerFactory::new(settings, tx);
        let name1 = ProtoName::from_strings(["test", "org", "app1"])
            .with_id(1)
            .name
            .unwrap();
        let name2 = ProtoName::from_strings(["test", "org", "app2"])
            .with_id(2)
            .name
            .unwrap();

        // Act - create and start multiple timers
        let timer1 = factory.create_and_start_timer(
            100,
            ProtoSessionMessageType::DiscoveryRequest,
            Some(name1),
        );
        let timer2 = factory.create_and_start_timer(
            200,
            ProtoSessionMessageType::DiscoveryRequest,
            Some(name2),
        );

        // Assert - we should receive one timeout from each timer, in any order
        let mut received_ids = Vec::with_capacity(2);
        for _ in 0..2 {
            match recv_within(&mut rx, 200, "multiple timers").await {
                SessionMessage::TimerTimeout {
                    message_id,
                    timeouts,
                    ..
                } => {
                    received_ids.push(message_id);
                    assert_eq!(timeouts, 1);
                }
                _ => panic!("Expected TimerTimeout message in multiple timers test"),
            }
        }

        received_ids.sort_unstable();
        assert_eq!(received_ids, vec![100, 200]);

        // The timers were kept bound until the assertions finished; release them now.
        drop(timer1);
        drop(timer2);
    }

    #[test]
    fn test_timer_settings_creation() {
        // Test creating TimerSettings with different configurations
        let settings1 =
            TimerSettings::new(Duration::from_millis(100), None, None, TimerType::Constant);

        assert_eq!(settings1.duration, Duration::from_millis(100));
        assert!(settings1.max_duration.is_none());
        assert!(settings1.max_retries.is_none());
        assert!(matches!(settings1.timer_type, TimerType::Constant));

        let settings2 = TimerSettings::new(
            Duration::from_secs(1),
            Some(Duration::from_secs(10)),
            Some(5),
            TimerType::Exponential,
        );

        assert_eq!(settings2.duration, Duration::from_secs(1));
        assert_eq!(settings2.max_duration, Some(Duration::from_secs(10)));
        assert_eq!(settings2.max_retries, Some(5));
        assert!(matches!(settings2.timer_type, TimerType::Exponential));
    }

    #[test]
    fn test_timer_settings_convenience_constructors() {
        // Test constant timer constructor
        let constant_settings = TimerSettings::constant(Duration::from_millis(500));
        assert_eq!(constant_settings.duration, Duration::from_millis(500));
        assert!(constant_settings.max_duration.is_none());
        assert!(constant_settings.max_retries.is_none());
        assert!(matches!(constant_settings.timer_type, TimerType::Constant));

        // Test exponential timer constructor
        let exponential_settings =
            TimerSettings::exponential(Duration::from_millis(100), Some(Duration::from_secs(5)));
        assert_eq!(exponential_settings.duration, Duration::from_millis(100));
        assert_eq!(
            exponential_settings.max_duration,
            Some(Duration::from_secs(5))
        );
        assert!(exponential_settings.max_retries.is_none());
        assert!(matches!(
            exponential_settings.timer_type,
            TimerType::Exponential
        ));

        // Test fluent builder pattern
        let settings_with_retries =
            TimerSettings::constant(Duration::from_millis(250)).with_max_retries(10);
        assert_eq!(settings_with_retries.duration, Duration::from_millis(250));
        assert_eq!(settings_with_retries.max_retries, Some(10));
        assert!(matches!(
            settings_with_retries.timer_type,
            TimerType::Constant
        ));
    }

    #[tokio::test]
    async fn test_timer_factory_with_convenience_constructors() {
        // Arrange
        let (tx, mut rx) = mpsc::channel(10);
        let settings = TimerSettings::constant(Duration::from_millis(40)).with_max_retries(1);
        let factory = TimerFactory::new(settings, tx);
        let timer_id = 888;
        let name = test_encoded_name();

        // Act
        let _timer = factory.create_and_start_timer(
            timer_id,
            ProtoSessionMessageType::DiscoveryRequest,
            Some(name),
        );

        // Assert
        let timeout_message = recv_within(&mut rx, 100, "convenience constructors").await;
        assert_timer_timeout(
            timeout_message,
            timer_id,
            1,
            Some(name),
            "convenience constructors",
        );
    }
}