use std::{sync::Arc, time::Duration};
use async_trait::async_trait;
use slim_datapath::api::{EncodedName, ProtoSessionMessageType};
use tokio::sync::mpsc::Sender;
use tracing::debug;
use crate::{
common::SessionMessage,
timer::{Timer, TimerObserver, TimerType},
};
struct ReliableTimerObserver {
tx: Sender<SessionMessage>,
message_type: ProtoSessionMessageType,
name: Option<EncodedName>,
}
#[async_trait]
impl TimerObserver for ReliableTimerObserver {
async fn on_timeout(&self, message_id: u32, timeouts: u32) {
if let Err(e) = self
.tx
.send(SessionMessage::TimerTimeout {
message_id,
message_type: self.message_type,
name: self.name,
timeouts,
})
.await
{
debug!(%message_id, error = %e, "timer timeout: session already closed, dropping");
}
}
async fn on_failure(&self, message_id: u32, timeouts: u32) {
if let Err(e) = self
.tx
.send(SessionMessage::TimerFailure {
message_id,
message_type: self.message_type,
name: self.name,
timeouts,
})
.await
{
debug!(%message_id, error = %e, "timer failure: session already closed, dropping");
}
}
async fn on_stop(&self, message_id: u32) {
debug!(timer_id = %message_id, "timer stopped");
}
}
/// Configuration applied to every timer created by a [`TimerFactory`].
#[derive(Clone)]
pub struct TimerSettings {
    /// Initial interval passed to `Timer::new`.
    pub duration: Duration,
    /// Optional upper bound on the interval, passed to `Timer::new`
    /// (presumably only relevant for growing timer types — confirm in `Timer`).
    pub max_duration: Option<Duration>,
    /// Optional retry limit passed to `Timer::new`; `None` presumably means
    /// no limit (confirm in `Timer`).
    pub max_retries: Option<u32>,
    /// Timer strategy passed to `Timer::new`.
    pub timer_type: TimerType,
}

impl TimerSettings {
    /// Creates settings with every field specified explicitly.
    pub fn new(
        duration: Duration,
        max_duration: Option<Duration>,
        max_retries: Option<u32>,
        timer_type: TimerType,
    ) -> Self {
        Self {
            duration,
            max_duration,
            max_retries,
            timer_type,
        }
    }

    /// `TimerType::Constant` settings with no `max_duration` and no
    /// `max_retries`; chain [`TimerSettings::with_max_retries`] to add a limit.
    pub fn constant(duration: Duration) -> Self {
        Self::new(duration, None, None, TimerType::Constant)
    }

    /// `TimerType::Exponential` settings starting at `initial_duration`,
    /// optionally capped by `max_duration`, with no `max_retries`.
    pub fn exponential(initial_duration: Duration, max_duration: Option<Duration>) -> Self {
        Self::new(initial_duration, max_duration, None, TimerType::Exponential)
    }

    /// Returns these settings with `max_retries` set to `Some(max_retries)`.
    ///
    /// Consumes `self` and returns the updated value, so the result must be
    /// used — discarding it would silently lose the setting.
    #[must_use]
    pub fn with_max_retries(mut self, max_retries: u32) -> Self {
        self.max_retries = Some(max_retries);
        self
    }
}
/// Creates [`Timer`]s that share one [`TimerSettings`] and report their events
/// to a session over a shared channel.
pub struct TimerFactory {
    /// Channel handed (cloned) to the observer of every started timer.
    tx: Sender<SessionMessage>,
    /// Parameters applied to every timer this factory creates.
    settings: TimerSettings,
}

impl TimerFactory {
    /// Creates a factory that builds timers from `settings` and reports their
    /// events on `tx`.
    pub fn new(settings: TimerSettings, tx: Sender<SessionMessage>) -> Self {
        // `tx` is already owned here; store it directly instead of cloning.
        Self { tx, settings }
    }

    /// Builds a timer with identifier `id` from the factory settings.
    ///
    /// The timer is NOT started; call [`TimerFactory::start_timer`] to attach
    /// an observer and start it.
    pub fn create_timer(&self, id: u32) -> Timer {
        Timer::new(
            id,
            self.settings.timer_type.clone(),
            self.settings.duration,
            self.settings.max_duration,
            self.settings.max_retries,
        )
    }

    /// Builds a timer with identifier `id` and starts it immediately.
    ///
    /// Timer events are forwarded to the session tagged with `message_type`
    /// and `name`. Returns the started timer.
    pub fn create_and_start_timer(
        &self,
        id: u32,
        message_type: ProtoSessionMessageType,
        name: Option<EncodedName>,
    ) -> Timer {
        // Reuse `create_timer` so both entry points always build timers the
        // same way.
        let timer = self.create_timer(id);
        self.start_timer(&timer, message_type, name);
        timer
    }

    /// Starts `timer` with an observer that forwards its timeout and failure
    /// events to this factory's channel, tagged with `message_type` and `name`.
    pub fn start_timer(
        &self,
        timer: &Timer,
        message_type: ProtoSessionMessageType,
        name: Option<EncodedName>,
    ) {
        let observer = ReliableTimerObserver {
            tx: self.tx.clone(),
            message_type,
            name,
        };
        timer.start(Arc::new(observer));
    }
}
#[cfg(test)]
mod tests {
    use slim_datapath::api::ProtoName;

    use super::*;
    use std::time::Duration;
    use tokio::sync::mpsc;
    use tokio::time::timeout;

    /// Builds the `EncodedName` shared by most tests (`test/org/app`, id 1).
    fn test_encoded_name() -> EncodedName {
        ProtoName::from_strings(["test", "org", "app"])
            .with_id(1)
            .name
            .unwrap()
    }

    /// Waits up to `wait` for the next message on `rx`.
    ///
    /// Panics (failing the test) if nothing arrives in time or if the channel
    /// is closed; `what` describes the expected message in the panic text.
    async fn recv_within(
        rx: &mut mpsc::Receiver<SessionMessage>,
        wait: Duration,
        what: &str,
    ) -> SessionMessage {
        timeout(wait, rx.recv())
            .await
            .unwrap_or_else(|_| panic!("should receive {} within {:?}", what, wait))
            .unwrap_or_else(|| panic!("channel closed before {} arrived", what))
    }

    /// Asserts that `message` is a `TimerTimeout` for a `DiscoveryRequest`
    /// carrying the expected id, name and timeout count.
    fn assert_timer_timeout(
        message: SessionMessage,
        expected_id: u32,
        expected_name: Option<EncodedName>,
        expected_timeouts: u32,
    ) {
        match message {
            SessionMessage::TimerTimeout {
                message_id,
                message_type,
                name,
                timeouts,
            } => {
                assert_eq!(message_id, expected_id);
                assert_eq!(message_type, ProtoSessionMessageType::DiscoveryRequest);
                assert_eq!(name, expected_name);
                assert_eq!(timeouts, expected_timeouts);
            }
            _ => panic!("Expected TimerTimeout message"),
        }
    }

    #[tokio::test]
    async fn test_timer_factory_new() {
        let (tx, _rx) = mpsc::channel(10);
        let settings =
            TimerSettings::new(Duration::from_millis(100), None, None, TimerType::Constant);
        let factory = TimerFactory::new(settings, tx);
        assert_eq!(factory.settings.duration, Duration::from_millis(100));
        assert!(factory.settings.max_duration.is_none());
        assert!(factory.settings.max_retries.is_none());
        // `matches!` only yields a bool; it must be wrapped in `assert!`.
        assert!(matches!(factory.settings.timer_type, TimerType::Constant));
    }

    #[tokio::test]
    async fn test_create_and_start_timer() {
        let (tx, mut rx) = mpsc::channel(10);
        let settings = TimerSettings::new(
            Duration::from_millis(50),
            None,
            Some(1),
            TimerType::Constant,
        );
        let factory = TimerFactory::new(settings, tx);
        let timer_id = 123;
        let name = test_encoded_name();
        // Keep the timer alive for the duration of the test.
        let _timer = factory.create_and_start_timer(
            timer_id,
            ProtoSessionMessageType::DiscoveryRequest,
            Some(name),
        );

        let message = recv_within(&mut rx, Duration::from_millis(200), "a timeout message").await;
        assert_timer_timeout(message, timer_id, Some(name), 1);
    }

    #[tokio::test]
    async fn test_timer_timeout_with_constant_timer() {
        let (tx, mut rx) = mpsc::channel(10);
        let settings = TimerSettings::new(
            Duration::from_millis(30),
            None,
            Some(2),
            TimerType::Constant,
        );
        let factory = TimerFactory::new(settings, tx);
        let timer_id = 456;
        let name = test_encoded_name();
        let timer = factory.create_timer(timer_id);
        factory.start_timer(
            &timer,
            ProtoSessionMessageType::DiscoveryRequest,
            Some(name),
        );

        let first = recv_within(&mut rx, Duration::from_millis(100), "the first timeout").await;
        assert_timer_timeout(first, timer_id, Some(name), 1);

        let second = recv_within(&mut rx, Duration::from_millis(100), "the second timeout").await;
        assert_timer_timeout(second, timer_id, Some(name), 2);
    }

    #[tokio::test]
    async fn test_timer_failure_after_max_retries() {
        let (tx, mut rx) = mpsc::channel(10);
        let settings = TimerSettings::new(
            Duration::from_millis(30),
            None,
            Some(1),
            TimerType::Constant,
        );
        let factory = TimerFactory::new(settings, tx);
        let timer_id = 789;
        let name = test_encoded_name();
        let timer = factory.create_timer(timer_id);
        factory.start_timer(
            &timer,
            ProtoSessionMessageType::DiscoveryRequest,
            Some(name),
        );

        let timeout_message =
            recv_within(&mut rx, Duration::from_millis(100), "the timeout message").await;
        assert_timer_timeout(timeout_message, timer_id, Some(name), 1);

        let failure_message =
            recv_within(&mut rx, Duration::from_millis(100), "the failure message").await;
        match failure_message {
            SessionMessage::TimerFailure {
                message_id,
                message_type,
                name: received_name,
                timeouts,
            } => {
                assert_eq!(message_id, timer_id);
                assert_eq!(message_type, ProtoSessionMessageType::DiscoveryRequest);
                assert_eq!(timeouts, 2);
                assert_eq!(received_name, Some(name));
            }
            _ => panic!("Expected TimerFailure message"),
        }
    }

    #[tokio::test]
    async fn test_exponential_timer() {
        let (tx, mut rx) = mpsc::channel(10);
        let settings = TimerSettings::new(
            Duration::from_millis(20),
            Some(Duration::from_millis(100)),
            Some(2),
            TimerType::Exponential,
        );
        let factory = TimerFactory::new(settings, tx);
        let timer_id = 999;
        let name = test_encoded_name();
        let timer = factory.create_timer(timer_id);
        factory.start_timer(
            &timer,
            ProtoSessionMessageType::DiscoveryRequest,
            Some(name),
        );

        let first = recv_within(&mut rx, Duration::from_millis(150), "the first timeout").await;
        assert_timer_timeout(first, timer_id, Some(name), 1);

        let second = recv_within(&mut rx, Duration::from_millis(200), "the second timeout").await;
        assert_timer_timeout(second, timer_id, Some(name), 2);
    }

    #[tokio::test]
    async fn test_timer_settings_with_all_options() {
        let (tx, _rx) = mpsc::channel(10);
        let duration = Duration::from_millis(500);
        let max_duration = Some(Duration::from_secs(5));
        let max_retries = Some(10);
        let timer_type = TimerType::Exponential;
        let settings = TimerSettings::new(duration, max_duration, max_retries, timer_type);
        let factory = TimerFactory::new(settings, tx);
        assert_eq!(factory.settings.duration, duration);
        assert_eq!(factory.settings.max_duration, max_duration);
        assert_eq!(factory.settings.max_retries, max_retries);
        assert!(matches!(factory.settings.timer_type, TimerType::Exponential));
    }

    #[tokio::test]
    async fn test_multiple_timers() {
        let (tx, mut rx) = mpsc::channel(20);
        let settings = TimerSettings::new(
            Duration::from_millis(50),
            None,
            Some(1),
            TimerType::Constant,
        );
        let factory = TimerFactory::new(settings, tx);
        let name1 = ProtoName::from_strings(["test", "org", "app1"])
            .with_id(1)
            .name
            .unwrap();
        let name2 = ProtoName::from_strings(["test", "org", "app2"])
            .with_id(2)
            .name
            .unwrap();
        let timer1 = factory.create_and_start_timer(
            100,
            ProtoSessionMessageType::DiscoveryRequest,
            Some(name1),
        );
        let timer2 = factory.create_and_start_timer(
            200,
            ProtoSessionMessageType::DiscoveryRequest,
            Some(name2),
        );

        let mut received_ids = Vec::new();
        for _ in 0..2 {
            let message =
                recv_within(&mut rx, Duration::from_millis(200), "a timeout message").await;
            match message {
                SessionMessage::TimerTimeout {
                    message_id,
                    message_type: _,
                    name: _,
                    timeouts,
                } => {
                    received_ids.push(message_id);
                    assert_eq!(timeouts, 1);
                }
                _ => panic!("Expected TimerTimeout message in multiple timers test"),
            }
        }
        // The two timers may fire in either order.
        received_ids.sort_unstable();
        assert_eq!(received_ids, vec![100, 200]);

        // Keep both timers alive until every expected message has arrived.
        drop(timer1);
        drop(timer2);
    }

    #[test]
    fn test_timer_settings_creation() {
        let settings1 =
            TimerSettings::new(Duration::from_millis(100), None, None, TimerType::Constant);
        assert_eq!(settings1.duration, Duration::from_millis(100));
        assert!(settings1.max_duration.is_none());
        assert!(settings1.max_retries.is_none());
        assert!(matches!(settings1.timer_type, TimerType::Constant));

        let settings2 = TimerSettings::new(
            Duration::from_secs(1),
            Some(Duration::from_secs(10)),
            Some(5),
            TimerType::Exponential,
        );
        assert_eq!(settings2.duration, Duration::from_secs(1));
        assert_eq!(settings2.max_duration, Some(Duration::from_secs(10)));
        assert_eq!(settings2.max_retries, Some(5));
        assert!(matches!(settings2.timer_type, TimerType::Exponential));
    }

    #[test]
    fn test_timer_settings_convenience_constructors() {
        let constant_settings = TimerSettings::constant(Duration::from_millis(500));
        assert_eq!(constant_settings.duration, Duration::from_millis(500));
        assert!(constant_settings.max_duration.is_none());
        assert!(constant_settings.max_retries.is_none());
        assert!(matches!(constant_settings.timer_type, TimerType::Constant));

        let exponential_settings =
            TimerSettings::exponential(Duration::from_millis(100), Some(Duration::from_secs(5)));
        assert_eq!(exponential_settings.duration, Duration::from_millis(100));
        assert_eq!(
            exponential_settings.max_duration,
            Some(Duration::from_secs(5))
        );
        assert!(exponential_settings.max_retries.is_none());
        assert!(matches!(
            exponential_settings.timer_type,
            TimerType::Exponential
        ));

        let settings_with_retries =
            TimerSettings::constant(Duration::from_millis(250)).with_max_retries(10);
        assert_eq!(settings_with_retries.duration, Duration::from_millis(250));
        assert_eq!(settings_with_retries.max_retries, Some(10));
        assert!(matches!(
            settings_with_retries.timer_type,
            TimerType::Constant
        ));
    }

    #[tokio::test]
    async fn test_timer_factory_with_convenience_constructors() {
        let (tx, mut rx) = mpsc::channel(10);
        let settings = TimerSettings::constant(Duration::from_millis(40)).with_max_retries(1);
        let factory = TimerFactory::new(settings, tx);
        let timer_id = 888;
        let name = test_encoded_name();
        let _timer = factory.create_and_start_timer(
            timer_id,
            ProtoSessionMessageType::DiscoveryRequest,
            Some(name),
        );

        let timeout_message =
            recv_within(&mut rx, Duration::from_millis(100), "the timeout message").await;
        assert_timer_timeout(timeout_message, timer_id, Some(name), 1);
    }
}