use std::collections::VecDeque;
use std::sync::atomic::{AtomicBool, AtomicU32, AtomicU8, Ordering};
use std::sync::{Arc, Mutex};
use std::time::Duration;
/// Lifecycle phase of a managed connection.
///
/// Each variant carries an explicit discriminant so a state can be stored as a
/// plain `u8` (`state as u8`) and decoded again with [`ConnectionState::from_u8`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConnectionState {
    /// No connection is established or being attempted.
    Disconnected = 0,
    /// An initial connection attempt is in progress.
    Connecting = 1,
    /// The connection is established.
    Connected = 2,
    /// A reconnect attempt is in progress after a loss of connection.
    Reconnecting = 3,
}

impl ConnectionState {
    /// Decodes a discriminant previously produced by `state as u8`.
    ///
    /// Bytes that do not correspond to any variant decode as `Disconnected`.
    fn from_u8(raw: u8) -> Self {
        // Index == discriminant, so the table order must match the enum above.
        const BY_DISCRIMINANT: [ConnectionState; 4] = [
            ConnectionState::Disconnected,
            ConnectionState::Connecting,
            ConnectionState::Connected,
            ConnectionState::Reconnecting,
        ];
        BY_DISCRIMINANT
            .get(usize::from(raw))
            .copied()
            .unwrap_or(Self::Disconnected)
    }
}
/// Tunables for heartbeats, reconnect backoff, and message buffering.
#[derive(Debug, Clone)]
pub struct ResiliencyConfig {
    /// Presumably the period between heartbeats; not read anywhere in this file.
    pub heartbeat_interval: Duration,
    /// Presumably how long to wait for a heartbeat reply; not read anywhere in this file.
    pub heartbeat_timeout: Duration,
    /// Delay before the first reconnect attempt; doubled on each later attempt.
    pub base_backoff: Duration,
    /// Upper bound on the exponential backoff delay.
    pub max_backoff: Duration,
    /// Maximum number of reconnect attempts; `None` means retry forever.
    pub max_reconnect_attempts: Option<u32>,
    /// Share of the backoff delay (in percent) that may be subtracted as jitter.
    pub jitter_percent: u8,
    /// Intended capacity for a message buffer; not read anywhere in this file.
    pub message_buffer_size: usize,
}

impl Default for ResiliencyConfig {
    /// 15 s heartbeat interval, 5 s heartbeat timeout, 1 s to 30 s backoff with
    /// 20 % jitter, unlimited reconnect attempts, and a 256-message buffer.
    fn default() -> Self {
        let secs = Duration::from_secs;
        Self {
            max_reconnect_attempts: None,
            message_buffer_size: 256,
            jitter_percent: 20,
            heartbeat_interval: secs(15),
            heartbeat_timeout: secs(5),
            base_backoff: secs(1),
            max_backoff: secs(30),
        }
    }
}
/// Tracks the lifecycle of one connection and computes reconnect backoff delays.
///
/// All mutable state lives in atomics, so every method takes `&self`.
pub struct ConnectionManager {
    /// Current [`ConnectionState`], stored as its `u8` discriminant.
    state: Arc<AtomicU8>,
    /// `true` exactly while `state` is [`ConnectionState::Connected`]; every
    /// transition method updates it alongside `state`.
    connected: Arc<AtomicBool>,
    /// Number of `set_reconnecting` calls since the last `connected` or `reset`.
    reconnect_attempts: Arc<AtomicU32>,
    /// Backoff and retry policy.
    config: ResiliencyConfig,
}

impl ConnectionManager {
    /// Creates a manager in the `Disconnected` state with zero reconnect attempts.
    pub fn new(config: ResiliencyConfig) -> Self {
        Self {
            state: Arc::new(AtomicU8::new(ConnectionState::Disconnected as u8)),
            connected: Arc::new(AtomicBool::new(false)),
            reconnect_attempts: Arc::new(AtomicU32::new(0)),
            config,
        }
    }

    /// Returns the delay to wait before reconnect attempt `attempt` (0-based).
    ///
    /// The delay is `base_backoff * 2^attempt` (exponent capped at 31), clamped
    /// to `max_backoff`, then reduced by a deterministic jitter of at most
    /// `jitter_percent` percent of that value. A `jitter_percent` above 100 is
    /// treated as 100. All arithmetic saturates, so extreme configurations (e.g.
    /// `Duration::MAX`) cannot overflow or panic.
    pub fn backoff_duration(&self, attempt: u32) -> Duration {
        let base_ms = u64::try_from(self.config.base_backoff.as_millis()).unwrap_or(u64::MAX);
        let exp_ms = base_ms.saturating_mul(1u64 << attempt.min(31));
        let max_ms = u64::try_from(self.config.max_backoff.as_millis()).unwrap_or(u64::MAX);
        let capped_ms = exp_ms.min(max_ms);

        // Multiply in u128 so `capped_ms * percent` cannot overflow. With the
        // percentage clamped to 100 the quotient never exceeds `capped_ms`, so it
        // always fits back into u64 (the fallback is unreachable but panic-free).
        let jitter_percent = u128::from(self.config.jitter_percent.min(100));
        let jitter_range =
            u64::try_from(u128::from(capped_ms) * jitter_percent / 100).unwrap_or(capped_ms);

        let jitter = if jitter_range > 0 {
            // Multiplicative hash of the attempt number: deterministic, but spreads
            // successive attempts across the jitter window.
            let hash = u64::from(attempt).wrapping_mul(2_654_435_761);
            // Saturate so a full-width `jitter_range` cannot overflow the modulus.
            hash % jitter_range.saturating_add(1)
        } else {
            0
        };
        Duration::from_millis(capped_ms.saturating_sub(jitter))
    }

    /// Marks the connection as established and clears the reconnect counter.
    pub fn connected(&self) {
        self.state.store(ConnectionState::Connected as u8, Ordering::SeqCst);
        self.connected.store(true, Ordering::SeqCst);
        self.reconnect_attempts.store(0, Ordering::SeqCst);
    }

    /// Marks the connection as lost; the reconnect counter is left untouched.
    pub fn disconnected(&self) {
        self.state.store(ConnectionState::Disconnected as u8, Ordering::SeqCst);
        self.connected.store(false, Ordering::SeqCst);
    }

    /// Marks an initial connection attempt as in progress.
    pub fn set_connecting(&self) {
        self.state.store(ConnectionState::Connecting as u8, Ordering::SeqCst);
        // Keep the flag in sync with `state`: without this, calling it straight
        // after `connected()` would leave `is_connected()` reporting `true`.
        self.connected.store(false, Ordering::SeqCst);
    }

    /// Marks a reconnect attempt as in progress and increments the attempt counter.
    pub fn set_reconnecting(&self) {
        self.state.store(ConnectionState::Reconnecting as u8, Ordering::SeqCst);
        // A reconnecting connection is by definition not connected.
        self.connected.store(false, Ordering::SeqCst);
        self.reconnect_attempts.fetch_add(1, Ordering::SeqCst);
    }

    /// Returns `true` while the attempt counter is below `max_reconnect_attempts`,
    /// or always when no maximum is configured.
    pub fn should_reconnect(&self) -> bool {
        self.config
            .max_reconnect_attempts
            .is_none_or(|max| self.reconnect_attempts.load(Ordering::SeqCst) < max)
    }

    /// Returns the current connection state.
    pub fn state(&self) -> ConnectionState {
        ConnectionState::from_u8(self.state.load(Ordering::SeqCst))
    }

    /// Returns `true` if the last transition was `connected()`.
    pub fn is_connected(&self) -> bool {
        self.connected.load(Ordering::SeqCst)
    }

    /// Zeroes the reconnect counter without changing the connection state.
    pub fn reset(&self) {
        self.reconnect_attempts.store(0, Ordering::SeqCst);
    }

    /// Returns the number of reconnect attempts since the last success or reset.
    pub fn reconnect_attempts(&self) -> u32 {
        self.reconnect_attempts.load(Ordering::SeqCst)
    }

    /// Returns the configuration this manager was created with.
    pub fn config(&self) -> &ResiliencyConfig {
        &self.config
    }
}

impl std::fmt::Debug for ConnectionManager {
    /// Shows the current state and attempt counter (the config is omitted).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ConnectionManager")
            .field("state", &self.state())
            .field("reconnect_attempts", &self.reconnect_attempts())
            .finish()
    }
}
/// Bounded, thread-safe FIFO of `String` messages.
///
/// Once `max_size` messages are queued, [`MessageBuffer::enqueue`] rejects new
/// messages instead of evicting old ones.
pub struct MessageBuffer {
    /// Pending messages, oldest at the front.
    queue: Mutex<VecDeque<String>>,
    /// Maximum number of messages held at once.
    max_size: usize,
}

impl MessageBuffer {
    /// Creates an empty buffer that holds at most `max_size` messages.
    ///
    /// A `max_size` of 0 produces a buffer that rejects every message.
    pub fn new(max_size: usize) -> Self {
        Self {
            // Cap the up-front reservation so a huge `max_size` does not allocate eagerly.
            queue: Mutex::new(VecDeque::with_capacity(max_size.min(1024))),
            max_size,
        }
    }

    /// Locks the queue, recovering the guard if a previous holder panicked.
    ///
    /// Each critical section in this type is a single `VecDeque` operation on
    /// plain `String`s, so a poisoned lock leaves no broken invariant behind.
    /// Recovering is therefore safe and keeps the buffer usable instead of panicking.
    fn lock_queue(&self) -> std::sync::MutexGuard<'_, VecDeque<String>> {
        self.queue
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Appends `message` to the back of the queue.
    ///
    /// Returns `false`, and drops `message`, when the buffer already holds
    /// `max_size` messages.
    #[must_use = "a `false` result means the message was dropped"]
    pub fn enqueue(&self, message: String) -> bool {
        let mut queue = self.lock_queue();
        if queue.len() >= self.max_size {
            return false;
        }
        queue.push_back(message);
        true
    }

    /// Removes and returns every queued message, oldest first.
    pub fn drain(&self) -> Vec<String> {
        let mut queue = self.lock_queue();
        queue.drain(..).collect()
    }

    /// Returns the number of queued messages.
    #[must_use]
    pub fn len(&self) -> usize {
        self.lock_queue().len()
    }

    /// Returns `true` if no messages are queued.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}

impl std::fmt::Debug for MessageBuffer {
    /// Shows the current length and capacity limit rather than the messages themselves.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("MessageBuffer")
            .field("len", &self.len())
            .field("max_size", &self.max_size)
            .finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn resiliency_config_defaults() {
        let cfg = ResiliencyConfig::default();
        assert_eq!(cfg.heartbeat_interval, Duration::from_secs(15));
        assert_eq!(cfg.heartbeat_timeout, Duration::from_secs(5));
        assert_eq!(cfg.base_backoff, Duration::from_secs(1));
        assert_eq!(cfg.max_backoff, Duration::from_secs(30));
        assert!(cfg.max_reconnect_attempts.is_none());
        assert_eq!(cfg.jitter_percent, 20);
        assert_eq!(cfg.message_buffer_size, 256);
    }

    #[test]
    fn connection_state_from_u8_decodes_discriminants() {
        for state in [
            ConnectionState::Disconnected,
            ConnectionState::Connecting,
            ConnectionState::Connected,
            ConnectionState::Reconnecting,
        ] {
            assert_eq!(ConnectionState::from_u8(state as u8), state);
        }
        // Unknown bytes fall back to Disconnected.
        assert_eq!(ConnectionState::from_u8(42), ConnectionState::Disconnected);
    }

    #[test]
    fn backoff_exponential_growth() {
        let cfg = ResiliencyConfig {
            jitter_percent: 0,
            ..Default::default()
        };
        let mgr = ConnectionManager::new(cfg);
        assert_eq!(mgr.backoff_duration(0), Duration::from_secs(1));
        assert_eq!(mgr.backoff_duration(1), Duration::from_secs(2));
        assert_eq!(mgr.backoff_duration(2), Duration::from_secs(4));
        assert_eq!(mgr.backoff_duration(3), Duration::from_secs(8));
    }

    #[test]
    fn backoff_capped_at_max() {
        let cfg = ResiliencyConfig {
            jitter_percent: 0,
            ..Default::default()
        };
        let mgr = ConnectionManager::new(cfg);
        assert_eq!(mgr.backoff_duration(10), Duration::from_secs(30));
        assert_eq!(mgr.backoff_duration(20), Duration::from_secs(30));
    }

    #[test]
    fn backoff_never_exceeds_max() {
        let mgr = ConnectionManager::new(ResiliencyConfig::default());
        for attempt in 0..64 {
            assert!(mgr.backoff_duration(attempt) <= Duration::from_secs(30));
        }
        assert_eq!(mgr.config().max_backoff, Duration::from_secs(30));
    }

    #[test]
    fn backoff_has_jitter() {
        let cfg = ResiliencyConfig {
            jitter_percent: 20,
            ..Default::default()
        };
        let mgr = ConnectionManager::new(cfg);
        let d = mgr.backoff_duration(3);
        assert!(d >= Duration::from_millis(6400), "duration {d:?} below expected floor");
        assert!(d <= Duration::from_millis(8000), "duration {d:?} above expected ceiling");
        // Attempt 0 always gets zero jitter, so check several attempts.
        let all_exact =
            (0..5).all(|a| mgr.backoff_duration(a).as_millis() == 1000 * (1u128 << a));
        assert!(!all_exact, "jitter should cause at least one attempt to differ from exact 2^n");
    }

    #[test]
    fn connection_state_transitions() {
        let mgr = ConnectionManager::new(ResiliencyConfig::default());
        assert_eq!(mgr.state(), ConnectionState::Disconnected);
        assert!(!mgr.is_connected());
        mgr.set_connecting();
        assert_eq!(mgr.state(), ConnectionState::Connecting);
        assert!(!mgr.is_connected());
        mgr.connected();
        assert_eq!(mgr.state(), ConnectionState::Connected);
        assert!(mgr.is_connected());
        mgr.disconnected();
        assert_eq!(mgr.state(), ConnectionState::Disconnected);
        assert!(!mgr.is_connected());
        mgr.set_reconnecting();
        assert_eq!(mgr.state(), ConnectionState::Reconnecting);
        assert_eq!(mgr.reconnect_attempts(), 1);
    }

    #[test]
    fn connected_and_reset_clear_reconnect_attempts() {
        let mgr = ConnectionManager::new(ResiliencyConfig::default());
        mgr.set_reconnecting();
        mgr.set_reconnecting();
        assert_eq!(mgr.reconnect_attempts(), 2);
        mgr.connected();
        assert_eq!(mgr.reconnect_attempts(), 0);
        mgr.set_reconnecting();
        assert_eq!(mgr.reconnect_attempts(), 1);
        mgr.reset();
        assert_eq!(mgr.reconnect_attempts(), 0);
    }

    #[test]
    fn should_reconnect_respects_max_attempts() {
        let cfg = ResiliencyConfig {
            max_reconnect_attempts: Some(3),
            ..Default::default()
        };
        let mgr = ConnectionManager::new(cfg);
        for _ in 0..3 {
            assert!(mgr.should_reconnect());
            mgr.set_reconnecting();
        }
        assert!(!mgr.should_reconnect());

        let unlimited_mgr = ConnectionManager::new(ResiliencyConfig::default());
        for _ in 0..100 {
            unlimited_mgr.set_reconnecting();
        }
        assert!(unlimited_mgr.should_reconnect());
    }

    #[test]
    fn message_buffer_enqueue_and_drain() {
        let buf = MessageBuffer::new(10);
        assert!(buf.is_empty());
        assert_eq!(buf.len(), 0);
        assert!(buf.enqueue("msg1".into()));
        assert!(buf.enqueue("msg2".into()));
        assert!(buf.enqueue("msg3".into()));
        assert_eq!(buf.len(), 3);
        assert!(!buf.is_empty());
        let drained = buf.drain();
        assert_eq!(drained, vec!["msg1", "msg2", "msg3"]);
        assert!(buf.is_empty());
        let empty = buf.drain();
        assert!(empty.is_empty());
    }

    #[test]
    fn message_buffer_rejects_when_full() {
        let buf = MessageBuffer::new(3);
        assert!(buf.enqueue("a".into()));
        assert!(buf.enqueue("b".into()));
        assert!(buf.enqueue("c".into()));
        assert_eq!(buf.len(), 3);
        assert!(!buf.enqueue("d".into()));
        assert_eq!(buf.len(), 3);
        let drained = buf.drain();
        assert_eq!(drained.len(), 3);
        assert!(buf.enqueue("d".into()));
        assert_eq!(buf.len(), 1);
    }

    #[test]
    fn message_buffer_zero_capacity_rejects_everything() {
        let buf = MessageBuffer::new(0);
        assert!(!buf.enqueue("x".into()));
        assert!(buf.is_empty());
    }
}