use std::time::{Duration, Instant};
/// Lifecycle state of a single connection as tracked by the state machine.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConnectionState {
    /// The connection is established and usable.
    Connected,
    /// A connection attempt is in progress.
    Connecting {
        /// Number of failed attempts recorded so far in this connecting phase.
        attempts: u32,
    },
    /// The connection was lost and has not yet been re-attempted.
    Disconnected {
        /// When the most recent failure was recorded.
        since: Instant,
        /// Number of failures recorded while disconnected.
        failures: u32,
    },
    /// Too many failures; reconnecting is refused until `until` has passed.
    Banned {
        /// Instant at which the ban ends.
        until: Instant,
        /// Failure count that led to (or extended) the ban.
        failures: u32,
    },
}

impl ConnectionState {
    /// Returns `true` only for [`ConnectionState::Connected`].
    pub const fn is_available(&self) -> bool {
        matches!(*self, Self::Connected)
    }

    /// Reports whether a new connection attempt may start from this state.
    ///
    /// A new attempt is always allowed from `Disconnected`. From `Banned` it is allowed
    /// only once the current time has reached `until`. It is never allowed from
    /// `Connected` or `Connecting`.
    pub fn should_reconnect(&self) -> bool {
        match *self {
            Self::Connected | Self::Connecting { .. } => false,
            Self::Disconnected { .. } => true,
            Self::Banned { until, .. } => until <= Instant::now(),
        }
    }
}
/// Tuning parameters for the retry, backoff and ban policy of a connection.
#[derive(Debug, Clone)]
pub struct ConnectionConfig {
    /// Backoff before the first retry; doubled for each further failure.
    pub base_delay: Duration,
    /// Upper bound applied to the exponential backoff.
    pub max_delay: Duration,
    /// Failure count at which a connecting/disconnected connection is banned.
    pub max_retries: u32,
    /// Length of a ban, measured from the failure that triggered it.
    pub ban_duration: Duration,
}

impl Default for ConnectionConfig {
    /// 100 ms base delay, 30 s backoff cap, 5 retries, 60 s ban.
    fn default() -> Self {
        let base_delay = Duration::from_millis(100);
        let max_delay = Duration::from_millis(30_000);
        let ban_duration = Duration::from_millis(60_000);
        Self { base_delay, max_delay, max_retries: 5, ban_duration }
    }
}
/// Drives a single connection through connect / disconnect / ban transitions and
/// computes how long a caller should wait before the next attempt.
#[derive(Debug)]
pub struct ConnectionStateMachine {
    state: ConnectionState,
    config: ConnectionConfig,
}

impl ConnectionStateMachine {
    /// Creates a machine in the `Connecting { attempts: 0 }` state.
    pub const fn new(config: ConnectionConfig) -> Self {
        Self { state: ConnectionState::Connecting { attempts: 0 }, config }
    }

    /// Returns a copy of the current state.
    pub const fn state(&self) -> ConnectionState {
        self.state
    }

    /// Returns `true` only while the state is `Connected`.
    pub const fn is_available(&self) -> bool {
        self.state.is_available()
    }

    /// Records a successful connection; moves to `Connected` from any state.
    pub const fn on_success(&mut self) {
        self.state = ConnectionState::Connected;
    }

    /// Records a failed attempt or a dropped connection.
    ///
    /// - `Connected` moves to `Disconnected { failures: 1 }`. `max_retries` is not
    ///   consulted for this first drop.
    /// - `Connecting` and `Disconnected` increment their counter. Once it reaches
    ///   `config.max_retries`, the machine becomes `Banned` for `config.ban_duration`.
    /// - `Banned` restarts the ban from now and increments the failure count.
    ///
    /// Counters saturate at `u32::MAX` instead of overflowing. A `ban_duration` too
    /// large to add to the current `Instant` is clamped instead of panicking.
    pub fn on_failure(&mut self) {
        let now = Instant::now();
        let max_retries = self.config.max_retries;
        self.state = match self.state {
            ConnectionState::Connected => {
                ConnectionState::Disconnected { since: now, failures: 1 }
            }
            ConnectionState::Connecting { attempts } => {
                let attempts = attempts.saturating_add(1);
                if attempts >= max_retries {
                    self.banned(now, attempts)
                } else {
                    ConnectionState::Connecting { attempts }
                }
            }
            ConnectionState::Disconnected { failures, .. } => {
                let failures = failures.saturating_add(1);
                if failures >= max_retries {
                    self.banned(now, failures)
                } else {
                    ConnectionState::Disconnected { since: now, failures }
                }
            }
            ConnectionState::Banned { failures, .. } => {
                self.banned(now, failures.saturating_add(1))
            }
        };
    }

    /// Returns how long to wait before the next attempt.
    ///
    /// - `Connecting` / `Disconnected`: `base_delay * 2^n`, where `n` is the attempt
    ///   or failure count. Saturating arithmetic is used, and the result is capped
    ///   at `max_delay`.
    /// - `Banned`: the time remaining until the ban ends, or zero once it has ended.
    /// - `Connected`: zero.
    pub fn backoff_duration(&self) -> Duration {
        match self.state {
            ConnectionState::Connecting { attempts }
            | ConnectionState::Disconnected { failures: attempts, .. } => {
                let delay = self.config.base_delay.saturating_mul(2u32.saturating_pow(attempts));
                delay.min(self.config.max_delay)
            }
            ConnectionState::Banned { until, .. } => {
                until.saturating_duration_since(Instant::now())
            }
            ConnectionState::Connected => Duration::ZERO,
        }
    }

    /// Starts a new connection attempt if the current state allows it.
    ///
    /// When `ConnectionState::should_reconnect` is true, the machine moves to
    /// `Connecting { attempts: 0 }` and this returns `true`. The previous failure
    /// count is discarded. Otherwise the state is left unchanged and this returns
    /// `false`.
    pub fn try_reconnect(&mut self) -> bool {
        if self.state.should_reconnect() {
            self.state = ConnectionState::Connecting { attempts: 0 };
            true
        } else {
            false
        }
    }

    /// Unconditionally returns to `Connecting { attempts: 0 }`.
    pub const fn reset(&mut self) {
        self.state = ConnectionState::Connecting { attempts: 0 };
    }

    /// Builds a `Banned` state with the given failure count that expires
    /// `config.ban_duration` after `now`.
    fn banned(&self, now: Instant, failures: u32) -> ConnectionState {
        ConnectionState::Banned {
            until: Self::deadline_after(now, self.config.ban_duration),
            failures,
        }
    }

    /// Returns `now + wait`.
    ///
    /// If that sum cannot be represented as an `Instant` (e.g. `wait == Duration::MAX`),
    /// `wait` is halved repeatedly. The first representable deadline is returned, so
    /// a huge ban stays effectively permanent instead of panicking.
    fn deadline_after(now: Instant, mut wait: Duration) -> Instant {
        loop {
            if let Some(deadline) = now.checked_add(wait) {
                return deadline;
            }
            // Terminates: `wait` eventually reaches zero, and `now + 0` is always representable.
            wait /= 2;
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a machine whose config differs from the default only in `max_retries`.
    fn machine_with_retries(max_retries: u32) -> ConnectionStateMachine {
        ConnectionStateMachine::new(ConnectionConfig { max_retries, ..Default::default() })
    }

    #[test]
    fn test_initial_state() {
        let sm = ConnectionStateMachine::new(ConnectionConfig::default());
        assert_eq!(sm.state(), ConnectionState::Connecting { attempts: 0 });
        assert!(!sm.is_available());
    }

    #[test]
    fn test_success_transitions_to_connected() {
        let mut sm = ConnectionStateMachine::new(ConnectionConfig::default());
        sm.on_success();
        assert_eq!(sm.state(), ConnectionState::Connected);
        assert!(sm.is_available());
        assert_eq!(sm.backoff_duration(), Duration::ZERO);
    }

    #[test]
    fn test_failure_from_connected() {
        let mut sm = ConnectionStateMachine::new(ConnectionConfig::default());
        sm.on_success();
        sm.on_failure();
        assert!(matches!(sm.state(), ConnectionState::Disconnected { failures: 1, .. }));
        assert!(!sm.is_available());
        // Disconnected backoff uses the failure count as the exponent: 100ms * 2^1.
        assert_eq!(sm.backoff_duration(), Duration::from_millis(200));
    }

    #[test]
    fn test_repeated_failures_lead_to_ban() {
        let mut sm = machine_with_retries(3);
        sm.on_failure();
        sm.on_failure();
        // One short of the limit: still connecting.
        assert_eq!(sm.state(), ConnectionState::Connecting { attempts: 2 });
        sm.on_failure();
        assert!(matches!(sm.state(), ConnectionState::Banned { failures: 3, .. }));
        assert!(!sm.is_available());
    }

    #[test]
    fn test_active_ban_blocks_reconnect() {
        let mut sm = machine_with_retries(1);
        sm.on_failure();
        assert!(!sm.try_reconnect());
        assert!(matches!(sm.state(), ConnectionState::Banned { .. }));
        let remaining = sm.backoff_duration();
        assert!(remaining > Duration::ZERO && remaining <= Duration::from_secs(60));
    }

    #[test]
    fn test_expired_ban_allows_reconnect() {
        let config = ConnectionConfig {
            max_retries: 1,
            ban_duration: Duration::ZERO,
            ..Default::default()
        };
        let mut sm = ConnectionStateMachine::new(config);
        sm.on_failure();
        assert!(matches!(sm.state(), ConnectionState::Banned { failures: 1, .. }));
        assert_eq!(sm.backoff_duration(), Duration::ZERO);
        assert!(sm.try_reconnect());
        assert_eq!(sm.state(), ConnectionState::Connecting { attempts: 0 });
    }

    #[test]
    fn test_backoff_exponential() {
        let config = ConnectionConfig {
            base_delay: Duration::from_millis(100),
            max_delay: Duration::from_secs(10),
            max_retries: 10,
            ..Default::default()
        };
        let mut sm = ConnectionStateMachine::new(config);
        for expected_ms in [100, 200, 400, 800] {
            assert_eq!(sm.backoff_duration(), Duration::from_millis(expected_ms));
            sm.on_failure();
        }
    }

    #[test]
    fn test_backoff_capped_at_max() {
        let config = ConnectionConfig {
            base_delay: Duration::from_secs(1),
            max_delay: Duration::from_secs(5),
            max_retries: 10,
            ..Default::default()
        };
        let mut sm = ConnectionStateMachine::new(config);
        for _ in 0..5 {
            sm.on_failure();
        }
        // Uncapped this would be 1s * 2^5 = 32s; the cap must be hit exactly.
        assert_eq!(sm.backoff_duration(), Duration::from_secs(5));
    }

    #[test]
    fn test_reset() {
        let mut sm = ConnectionStateMachine::new(ConnectionConfig::default());
        sm.on_success();
        sm.on_failure();
        sm.on_failure();
        sm.reset();
        assert_eq!(sm.state(), ConnectionState::Connecting { attempts: 0 });

        let mut banned = machine_with_retries(1);
        banned.on_failure();
        banned.reset();
        assert_eq!(banned.state(), ConnectionState::Connecting { attempts: 0 });
    }
}