1use std::{sync::Arc, time::Duration};
5
6use async_trait::async_trait;
7use slim_datapath::api::{EncodedName, ProtoSessionMessageType};
8use tokio::sync::mpsc::Sender;
9use tracing::debug;
10
11use crate::{
12 common::SessionMessage,
13 timer::{Timer, TimerObserver, TimerType},
14};
15
16struct ReliableTimerObserver {
17 tx: Sender<SessionMessage>,
18 message_type: ProtoSessionMessageType,
19 name: Option<EncodedName>,
20}
21
22#[async_trait]
23impl TimerObserver for ReliableTimerObserver {
24 async fn on_timeout(&self, message_id: u32, timeouts: u32) {
25 if let Err(e) = self
26 .tx
27 .send(SessionMessage::TimerTimeout {
28 message_id,
29 message_type: self.message_type,
30 name: self.name,
31 timeouts,
32 })
33 .await
34 {
35 debug!(%message_id, error = %e, "timer timeout: session already closed, dropping");
38 }
39 }
40
41 async fn on_failure(&self, message_id: u32, timeouts: u32) {
42 if let Err(e) = self
44 .tx
45 .send(SessionMessage::TimerFailure {
46 message_id,
47 message_type: self.message_type,
48 name: self.name,
49 timeouts,
50 })
51 .await
52 {
53 debug!(%message_id, error = %e, "timer failure: session already closed, dropping");
55 }
56 }
57
58 async fn on_stop(&self, message_id: u32) {
59 debug!(timer_id = %message_id, "timer stopped");
60 }
61}
62
63#[derive(Clone)]
64pub struct TimerSettings {
65 pub duration: Duration,
66 pub max_duration: Option<Duration>,
67 pub max_retries: Option<u32>,
68 pub timer_type: TimerType,
69}
70
71impl TimerSettings {
72 pub fn new(
74 duration: Duration,
75 max_duration: Option<Duration>,
76 max_retries: Option<u32>,
77 timer_type: TimerType,
78 ) -> Self {
79 Self {
80 duration,
81 max_duration,
82 max_retries,
83 timer_type,
84 }
85 }
86
87 pub fn constant(duration: Duration) -> Self {
89 Self {
90 duration,
91 max_duration: None,
92 max_retries: None,
93 timer_type: TimerType::Constant,
94 }
95 }
96
97 pub fn exponential(initial_duration: Duration, max_duration: Option<Duration>) -> Self {
99 Self {
100 duration: initial_duration,
101 max_duration,
102 max_retries: None,
103 timer_type: TimerType::Exponential,
104 }
105 }
106
107 pub fn with_max_retries(mut self, max_retries: u32) -> Self {
109 self.max_retries = Some(max_retries);
110 self
111 }
112}
113
114pub struct TimerFactory {
115 tx: Sender<SessionMessage>,
117 settings: TimerSettings,
118}
119
120impl TimerFactory {
121 pub fn new(settings: TimerSettings, tx: Sender<SessionMessage>) -> Self {
122 Self {
123 tx: tx.clone(),
124 settings,
125 }
126 }
127
128 pub fn create_timer(&self, id: u32) -> Timer {
129 Timer::new(
130 id,
131 self.settings.timer_type.clone(),
132 self.settings.duration,
133 self.settings.max_duration,
134 self.settings.max_retries,
135 )
136 }
137
138 pub fn create_and_start_timer(
139 &self,
140 id: u32,
141 message_type: ProtoSessionMessageType,
142 name: Option<EncodedName>,
143 ) -> Timer {
144 let t = Timer::new(
145 id,
146 self.settings.timer_type.clone(),
147 self.settings.duration,
148 self.settings.max_duration,
149 self.settings.max_retries,
150 );
151 self.start_timer(&t, message_type, name);
152 t
153 }
154
155 pub fn start_timer(
156 &self,
157 timer: &Timer,
158 message_type: ProtoSessionMessageType,
159 name: Option<EncodedName>,
160 ) {
161 let observer = ReliableTimerObserver {
163 tx: self.tx.clone(),
164 message_type,
165 name,
166 };
167 timer.start(Arc::new(observer));
168 }
169}
170
/// Unit tests for [`TimerSettings`], [`TimerFactory`] and the session-message
/// forwarding done by `ReliableTimerObserver`.
#[cfg(test)]
mod tests {
    use slim_datapath::api::ProtoName;

    use super::*;
    use std::time::Duration;
    use tokio::sync::mpsc;
    use tokio::time::timeout;

    /// Name attached to timers in most tests.
    fn test_encoded_name() -> EncodedName {
        ProtoName::from_strings(["test", "org", "app"])
            .with_id(1)
            .name
            .unwrap()
    }

    /// Waits up to `within` for the next session message.
    ///
    /// Panics with `context` if nothing arrives in time, or if the channel is
    /// closed.
    async fn recv_within(
        rx: &mut mpsc::Receiver<SessionMessage>,
        within: Duration,
        context: &str,
    ) -> SessionMessage {
        timeout(within, rx.recv())
            .await
            .expect(context)
            .expect("Should receive a message")
    }

    /// Checks that `message` is a `DiscoveryRequest` `TimerTimeout` carrying the
    /// expected id, name and timeout count. `context` labels a variant mismatch.
    fn assert_timer_timeout(
        message: SessionMessage,
        expected_id: u32,
        expected_name: Option<EncodedName>,
        expected_timeouts: u32,
        context: &str,
    ) {
        match message {
            SessionMessage::TimerTimeout {
                message_id,
                message_type,
                name,
                timeouts,
            } => {
                assert_eq!(message_id, expected_id);
                assert_eq!(message_type, ProtoSessionMessageType::DiscoveryRequest);
                assert_eq!(name, expected_name);
                assert_eq!(timeouts, expected_timeouts);
            }
            _ => panic!("Expected TimerTimeout message ({context})"),
        }
    }

    #[tokio::test]
    async fn test_timer_factory_new() {
        let (tx, _rx) = mpsc::channel(10);
        let settings =
            TimerSettings::new(Duration::from_millis(100), None, None, TimerType::Constant);

        let factory = TimerFactory::new(settings, tx);

        assert_eq!(factory.settings.duration, Duration::from_millis(100));
        assert!(factory.settings.max_duration.is_none());
        assert!(factory.settings.max_retries.is_none());
        assert!(matches!(factory.settings.timer_type, TimerType::Constant));
    }

    #[tokio::test]
    async fn test_create_and_start_timer() {
        let (tx, mut rx) = mpsc::channel(10);
        let settings = TimerSettings::new(
            Duration::from_millis(50),
            None,
            Some(1),
            TimerType::Constant,
        );
        let factory = TimerFactory::new(settings, tx);
        let timer_id = 123;
        let name = test_encoded_name();

        let _timer = factory.create_and_start_timer(
            timer_id,
            ProtoSessionMessageType::DiscoveryRequest,
            Some(name),
        );

        let message = recv_within(
            &mut rx,
            Duration::from_millis(200),
            "Should receive a message within timeout",
        )
        .await;
        assert_timer_timeout(message, timer_id, Some(name), 1, "create and start");
    }

    #[tokio::test]
    async fn test_timer_timeout_with_constant_timer() {
        let (tx, mut rx) = mpsc::channel(10);
        let settings = TimerSettings::new(
            Duration::from_millis(30),
            None,
            Some(2),
            TimerType::Constant,
        );
        let factory = TimerFactory::new(settings, tx);
        let timer_id = 456;
        let name = test_encoded_name();

        let timer = factory.create_timer(timer_id);
        factory.start_timer(
            &timer,
            ProtoSessionMessageType::DiscoveryRequest,
            Some(name),
        );

        let first = recv_within(
            &mut rx,
            Duration::from_millis(100),
            "Should receive first timeout",
        )
        .await;
        assert_timer_timeout(first, timer_id, Some(name), 1, "first timeout");

        let second = recv_within(
            &mut rx,
            Duration::from_millis(100),
            "Should receive second timeout",
        )
        .await;
        assert_timer_timeout(second, timer_id, Some(name), 2, "second timeout");
    }

    #[tokio::test]
    async fn test_timer_failure_after_max_retries() {
        let (tx, mut rx) = mpsc::channel(10);
        let settings = TimerSettings::new(
            Duration::from_millis(30),
            None,
            Some(1),
            TimerType::Constant,
        );
        let factory = TimerFactory::new(settings, tx);
        let timer_id = 789;
        let name = test_encoded_name();

        let timer = factory.create_timer(timer_id);
        factory.start_timer(
            &timer,
            ProtoSessionMessageType::DiscoveryRequest,
            Some(name),
        );

        let timeout_message = recv_within(
            &mut rx,
            Duration::from_millis(100),
            "Should receive timeout message",
        )
        .await;
        assert_timer_timeout(timeout_message, timer_id, Some(name), 1, "failure test");

        let failure_message = recv_within(
            &mut rx,
            Duration::from_millis(100),
            "Should receive failure message",
        )
        .await;
        match failure_message {
            SessionMessage::TimerFailure {
                message_id,
                message_type,
                name: received_name,
                timeouts,
            } => {
                assert_eq!(message_id, timer_id);
                assert_eq!(message_type, ProtoSessionMessageType::DiscoveryRequest);
                assert_eq!(timeouts, 2);
                assert_eq!(received_name, Some(name));
            }
            _ => panic!("Expected TimerFailure message"),
        }
    }

    #[tokio::test]
    async fn test_exponential_timer() {
        let (tx, mut rx) = mpsc::channel(10);
        let settings = TimerSettings::new(
            Duration::from_millis(20),
            Some(Duration::from_millis(100)),
            Some(2),
            TimerType::Exponential,
        );
        let factory = TimerFactory::new(settings, tx);
        let timer_id = 999;
        let name = test_encoded_name();

        let timer = factory.create_timer(timer_id);
        factory.start_timer(
            &timer,
            ProtoSessionMessageType::DiscoveryRequest,
            Some(name),
        );

        let first = recv_within(
            &mut rx,
            Duration::from_millis(150),
            "Should receive first timeout",
        )
        .await;
        assert_timer_timeout(first, timer_id, Some(name), 1, "exponential first timeout");

        let second = recv_within(
            &mut rx,
            Duration::from_millis(200),
            "Should receive second timeout",
        )
        .await;
        assert_timer_timeout(second, timer_id, Some(name), 2, "exponential second timeout");
    }

    #[tokio::test]
    async fn test_timer_settings_with_all_options() {
        let (tx, _rx) = mpsc::channel(10);
        let duration = Duration::from_millis(500);
        let max_duration = Some(Duration::from_secs(5));
        let max_retries = Some(10);
        let timer_type = TimerType::Exponential;

        let settings = TimerSettings::new(duration, max_duration, max_retries, timer_type);

        let factory = TimerFactory::new(settings, tx);

        assert_eq!(factory.settings.duration, duration);
        assert_eq!(factory.settings.max_duration, max_duration);
        assert_eq!(factory.settings.max_retries, max_retries);
        assert!(matches!(factory.settings.timer_type, TimerType::Exponential));
    }

    #[tokio::test]
    async fn test_multiple_timers() {
        let (tx, mut rx) = mpsc::channel(20);
        let settings = TimerSettings::new(
            Duration::from_millis(50),
            None,
            Some(1),
            TimerType::Constant,
        );
        let factory = TimerFactory::new(settings, tx);
        let name1 = ProtoName::from_strings(["test", "org", "app1"])
            .with_id(1)
            .name
            .unwrap();
        let name2 = ProtoName::from_strings(["test", "org", "app2"])
            .with_id(2)
            .name
            .unwrap();

        let timer1 = factory.create_and_start_timer(
            100,
            ProtoSessionMessageType::DiscoveryRequest,
            Some(name1),
        );
        let timer2 = factory.create_and_start_timer(
            200,
            ProtoSessionMessageType::DiscoveryRequest,
            Some(name2),
        );

        // The two timers may fire in either order, so collect the ids and sort.
        let mut received_ids = Vec::with_capacity(2);
        for _ in 0..2 {
            let message = recv_within(
                &mut rx,
                Duration::from_millis(200),
                "Should receive a message within timeout",
            )
            .await;

            match message {
                SessionMessage::TimerTimeout {
                    message_id,
                    timeouts,
                    ..
                } => {
                    received_ids.push(message_id);
                    assert_eq!(timeouts, 1);
                }
                _ => panic!("Expected TimerTimeout message in multiple timers test"),
            }
        }

        received_ids.sort_unstable();
        assert_eq!(received_ids, vec![100, 200]);

        drop(timer1);
        drop(timer2);
    }

    #[test]
    fn test_timer_settings_creation() {
        let settings1 =
            TimerSettings::new(Duration::from_millis(100), None, None, TimerType::Constant);

        assert_eq!(settings1.duration, Duration::from_millis(100));
        assert!(settings1.max_duration.is_none());
        assert!(settings1.max_retries.is_none());
        assert!(matches!(settings1.timer_type, TimerType::Constant));

        let settings2 = TimerSettings::new(
            Duration::from_secs(1),
            Some(Duration::from_secs(10)),
            Some(5),
            TimerType::Exponential,
        );

        assert_eq!(settings2.duration, Duration::from_secs(1));
        assert_eq!(settings2.max_duration, Some(Duration::from_secs(10)));
        assert_eq!(settings2.max_retries, Some(5));
        assert!(matches!(settings2.timer_type, TimerType::Exponential));
    }

    #[test]
    fn test_timer_settings_convenience_constructors() {
        let constant_settings = TimerSettings::constant(Duration::from_millis(500));
        assert_eq!(constant_settings.duration, Duration::from_millis(500));
        assert!(constant_settings.max_duration.is_none());
        assert!(constant_settings.max_retries.is_none());
        assert!(matches!(constant_settings.timer_type, TimerType::Constant));

        let exponential_settings =
            TimerSettings::exponential(Duration::from_millis(100), Some(Duration::from_secs(5)));
        assert_eq!(exponential_settings.duration, Duration::from_millis(100));
        assert_eq!(
            exponential_settings.max_duration,
            Some(Duration::from_secs(5))
        );
        assert!(exponential_settings.max_retries.is_none());
        assert!(matches!(
            exponential_settings.timer_type,
            TimerType::Exponential
        ));

        let settings_with_retries =
            TimerSettings::constant(Duration::from_millis(250)).with_max_retries(10);
        assert_eq!(settings_with_retries.duration, Duration::from_millis(250));
        assert_eq!(settings_with_retries.max_retries, Some(10));
        assert!(matches!(
            settings_with_retries.timer_type,
            TimerType::Constant
        ));
    }

    #[tokio::test]
    async fn test_timer_factory_with_convenience_constructors() {
        let (tx, mut rx) = mpsc::channel(10);
        let settings = TimerSettings::constant(Duration::from_millis(40)).with_max_retries(1);
        let factory = TimerFactory::new(settings, tx);
        let timer_id = 888;
        let name = test_encoded_name();

        let _timer = factory.create_and_start_timer(
            timer_id,
            ProtoSessionMessageType::DiscoveryRequest,
            Some(name),
        );

        let message = recv_within(
            &mut rx,
            Duration::from_millis(100),
            "Should receive timeout message",
        )
        .await;
        assert_timer_timeout(
            message,
            timer_id,
            Some(name),
            1,
            "convenience constructors",
        );
    }
}