use crate::channel::ChannelStatus;
use std::time::Duration;
/// Configuration for exponential-backoff reconnection.
///
/// A reconnect loop starts waiting `initial_delay`, doubles the wait after
/// each attempt up to `max_delay`, and gives up (the channel is treated as
/// failed) once `max_attempts` attempts have been made.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ReconnectPolicy {
    /// Delay used for the first reconnect attempt.
    pub initial_delay: Duration,
    /// Ceiling for the backoff delay; should be at least `initial_delay`.
    pub max_delay: Duration,
    /// Number of attempts after which the channel is considered failed.
    /// A value of `0` means the channel is failed before any attempt.
    pub max_attempts: u32,
}

impl Default for ReconnectPolicy {
    /// One-second initial delay, five-minute (300 s) ceiling, ten attempts.
    fn default() -> Self {
        Self {
            initial_delay: Duration::from_secs(1),
            max_delay: Duration::from_secs(300),
            max_attempts: 10,
        }
    }
}
/// Mutable backoff bookkeeping for a reconnect loop governed by a
/// [`ReconnectPolicy`].
#[derive(Debug, Clone)]
pub struct ReconnectState {
    /// Count of reconnect attempts; incremented by `next_delay`, cleared by `reset`.
    pub attempts: u32,
    /// Time of the most recent `next_delay` call; `None` after construction or `reset`.
    pub last_attempt: Option<std::time::Instant>,
    /// Backoff delay for the next attempt; doubled (up to the policy's
    /// `max_delay`) after each `next_delay` call and restored to
    /// `initial_delay` by `reset`.
    pub current_delay: Duration,
    /// Policy supplying the initial delay, the delay ceiling and the attempt limit.
    policy: ReconnectPolicy,
}
impl ReconnectState {
    /// Creates a fresh state with no attempts recorded.
    ///
    /// The stored delay starts at the policy's `initial_delay`.
    pub fn new(policy: ReconnectPolicy) -> Self {
        let initial = policy.initial_delay;
        Self {
            attempts: 0,
            last_attempt: None,
            current_delay: initial,
            policy,
        }
    }

    /// Records a reconnect attempt and returns how long to wait before it.
    ///
    /// Increments `attempts` (saturating at `u32::MAX`) and stamps
    /// `last_attempt` with the current time. It then doubles the stored delay
    /// for the following call, never exceeding the policy's `max_delay`.
    ///
    /// The returned delay is itself clamped to `max_delay`. This means a policy
    /// whose `initial_delay` is larger than `max_delay` still respects the
    /// ceiling on the very first attempt.
    pub fn next_delay(&mut self) -> Duration {
        // Saturate instead of `+= 1`: overflow would panic in debug builds and
        // wrap to 0 in release, making a failed channel report `Connected`.
        self.attempts = self.attempts.saturating_add(1);
        self.last_attempt = Some(std::time::Instant::now());

        let max = self.policy.max_delay;
        let delay = self.current_delay.min(max);
        self.current_delay = delay.saturating_mul(2).min(max);
        delay
    }

    /// Returns `true` once `attempts` has reached the policy's `max_attempts`.
    ///
    /// With `max_attempts == 0` this is `true` even before any attempt.
    pub fn should_mark_failed(&self) -> bool {
        self.attempts >= self.policy.max_attempts
    }

    /// Clears the attempt counter and timestamp, and restores the delay to the
    /// policy's `initial_delay`.
    pub fn reset(&mut self) {
        self.attempts = 0;
        self.last_attempt = None;
        self.current_delay = self.policy.initial_delay;
    }

    /// Maps the current state onto a [`ChannelStatus`].
    ///
    /// The result is `Failed` once the attempt limit is reached. Otherwise it
    /// is `Reconnecting` if any attempt has been recorded, and `Connected` if none has.
    pub fn channel_status(&self) -> ChannelStatus {
        if self.should_mark_failed() {
            ChannelStatus::Failed
        } else if self.attempts > 0 {
            ChannelStatus::Reconnecting
        } else {
            ChannelStatus::Connected
        }
    }

    /// Returns the policy this state was created with.
    pub fn policy(&self) -> &ReconnectPolicy {
        &self.policy
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a policy with a one-second initial delay and the given ceiling
    /// (in seconds) and attempt limit.
    fn make_policy(max_delay_secs: u64, max_attempts: u32) -> ReconnectPolicy {
        ReconnectPolicy {
            initial_delay: Duration::from_secs(1),
            max_delay: Duration::from_secs(max_delay_secs),
            max_attempts,
        }
    }

    #[test]
    fn default_policy_values() {
        let p = ReconnectPolicy::default();
        assert_eq!(p.initial_delay, Duration::from_secs(1));
        assert_eq!(p.max_delay, Duration::from_secs(300));
        assert_eq!(p.max_attempts, 10);
    }

    #[test]
    fn exponential_backoff_doubles() {
        let mut state = ReconnectState::new(make_policy(300, 10));
        assert_eq!(state.next_delay(), Duration::from_secs(1));
        assert_eq!(state.next_delay(), Duration::from_secs(2));
        assert_eq!(state.next_delay(), Duration::from_secs(4));
        assert_eq!(state.next_delay(), Duration::from_secs(8));
        assert_eq!(state.next_delay(), Duration::from_secs(16));
    }

    #[test]
    fn backoff_caps_at_max_delay() {
        let mut state = ReconnectState::new(make_policy(10, 20));
        assert_eq!(state.next_delay(), Duration::from_secs(1));
        assert_eq!(state.next_delay(), Duration::from_secs(2));
        assert_eq!(state.next_delay(), Duration::from_secs(4));
        assert_eq!(state.next_delay(), Duration::from_secs(8));
        assert_eq!(state.next_delay(), Duration::from_secs(10));
        assert_eq!(state.next_delay(), Duration::from_secs(10));
    }

    #[test]
    fn should_mark_failed_after_max_attempts() {
        let mut state = ReconnectState::new(make_policy(300, 3));
        assert!(!state.should_mark_failed());
        state.next_delay();
        assert!(!state.should_mark_failed());
        state.next_delay();
        assert!(!state.should_mark_failed());
        state.next_delay();
        assert!(state.should_mark_failed());
    }

    #[test]
    fn reset_clears_state() {
        let mut state = ReconnectState::new(make_policy(300, 10));
        state.next_delay();
        state.next_delay();
        assert_eq!(state.attempts, 2);
        assert!(state.last_attempt.is_some());

        state.reset();
        assert_eq!(state.attempts, 0);
        assert!(state.last_attempt.is_none());
        assert_eq!(state.current_delay, Duration::from_secs(1));
        assert_eq!(state.next_delay(), Duration::from_secs(1));
    }

    #[test]
    fn channel_status_transitions() {
        let mut state = ReconnectState::new(make_policy(300, 3));
        assert_eq!(state.channel_status(), ChannelStatus::Connected);
        state.next_delay();
        assert_eq!(state.channel_status(), ChannelStatus::Reconnecting);
        state.next_delay();
        assert_eq!(state.channel_status(), ChannelStatus::Reconnecting);
        state.next_delay();
        assert_eq!(state.channel_status(), ChannelStatus::Failed);
        state.reset();
        assert_eq!(state.channel_status(), ChannelStatus::Connected);
    }

    /// A policy allowing zero attempts is failed before any attempt is made.
    #[test]
    fn zero_max_attempts_fails_immediately() {
        let state = ReconnectState::new(make_policy(300, 0));
        assert!(state.should_mark_failed());
        assert_eq!(state.channel_status(), ChannelStatus::Failed);
    }

    #[test]
    fn policy_accessor_returns_configured_policy() {
        let state = ReconnectState::new(make_policy(42, 7));
        assert_eq!(state.policy().initial_delay, Duration::from_secs(1));
        assert_eq!(state.policy().max_delay, Duration::from_secs(42));
        assert_eq!(state.policy().max_attempts, 7);
    }
}