use crate::error::{Error, Recoverability};
use crate::types::Time;
use std::sync::atomic::{AtomicU8, AtomicU32, AtomicU64, Ordering};
use std::time::Duration;
/// Policy hooks consulted when an operation fails and may be retried.
///
/// Implementors must be `Send + Sync`. The code that drives these hooks is not
/// part of this file; `ExponentialBackoff` is the implementation visible here.
pub trait RecoveryStrategy: Send + Sync {
/// Reports whether another try should be made after `error` was observed
/// on attempt number `attempt`.
fn should_retry(&self, error: &Error, attempt: u32) -> bool;
/// Returns how long to wait for the given `attempt` number.
///
/// NOTE(review): the tests in this file request the first delay with
/// `attempt == 0`, so numbering is presumably zero-based; confirm with the caller.
fn backoff_duration(&self, attempt: u32) -> Duration;
/// Notification hook that receives an attempt count.
///
/// NOTE(review): presumably invoked once the operation finally succeeds;
/// the call site is not visible here, so confirm against the caller.
fn on_success(&self, attempts: u32);
/// Notification hook that receives the last error and an attempt count.
///
/// NOTE(review): presumably invoked when retrying is abandoned; the call
/// site is not visible here, so confirm against the caller.
fn on_give_up(&self, error: &Error, attempts: u32);
}
/// Retry policy whose delay grows geometrically with the attempt number.
///
/// All delay arithmetic in `backoff_duration` is done in whole milliseconds.
#[derive(Debug)]
pub struct ExponentialBackoff {
/// Delay for attempt 0; later attempts scale it by `multiplier^attempt`.
initial: Duration,
/// Cap applied to the scaled delay before jitter is computed.
max: Duration,
/// Growth factor raised to the attempt number.
multiplier: f64,
/// `should_retry` returns `false` once `attempt >= max_attempts`.
max_attempts: u32,
/// Jitter as a fraction of the capped delay. `new` sets it to `0.1`;
/// `with_jitter` keeps it within `0.0..=1.0`.
jitter: f64,
}
impl ExponentialBackoff {
    /// Builds a backoff policy with the default jitter fraction of `0.1`.
    ///
    /// * `initial` - delay used for attempt 0.
    /// * `max` - cap applied to the scaled delay.
    /// * `multiplier` - growth factor raised to the attempt number.
    /// * `max_attempts` - attempt number at which `should_retry` starts
    ///   returning `false`.
    ///
    /// No argument is validated here.
    #[inline]
    #[must_use]
    pub fn new(initial: Duration, max: Duration, multiplier: f64, max_attempts: u32) -> Self {
        // Jitter fraction used until `with_jitter` overrides it.
        const DEFAULT_JITTER: f64 = 0.1;

        Self {
            jitter: DEFAULT_JITTER,
            max_attempts,
            multiplier,
            max,
            initial,
        }
    }

    /// Returns the policy with its jitter fraction replaced.
    ///
    /// Finite values are clamped into `0.0..=1.0`; NaN and both infinities
    /// turn jitter off (`0.0`).
    #[inline]
    #[must_use]
    pub fn with_jitter(self, jitter: f64) -> Self {
        let sanitized = match jitter {
            fraction if fraction.is_finite() => fraction.clamp(0.0, 1.0),
            _ => 0.0,
        };
        Self {
            jitter: sanitized,
            ..self
        }
    }
}
impl RecoveryStrategy for ExponentialBackoff {
fn should_retry(&self, error: &Error, attempt: u32) -> bool {
if attempt >= self.max_attempts {
return false;
}
error.recoverability() == Recoverability::Transient
}
#[allow(
clippy::cast_possible_wrap,
clippy::cast_precision_loss,
clippy::cast_sign_loss
)]
fn backoff_duration(&self, attempt: u32) -> Duration {
let exponent = i32::try_from(attempt).unwrap_or(i32::MAX);
let factor = self.multiplier.powi(exponent);
let mut base_ms = (self.initial.as_millis() as f64 * factor) as u64;
let max_ms = self.max.as_millis().min(u128::from(u64::MAX)) as u64;
if base_ms > max_ms {
base_ms = max_ms;
}
let jitter_amount = (base_ms as f64 * self.jitter) as u64;
let with_jitter = if jitter_amount == 0 {
base_ms
} else {
let jitter_range = jitter_amount.saturating_mul(2).max(1);
let variation = u64::from(attempt).wrapping_mul(31) % jitter_range;
base_ms
.saturating_sub(jitter_amount)
.saturating_add(variation)
};
Duration::from_millis(with_jitter)
}
fn on_success(&self, _attempts: u32) {}
fn on_give_up(&self, _error: &Error, _attempts: u32) {}
}
/// State of a `CircuitBreaker`, stored inside the breaker as its `u8`
/// discriminant.
///
/// * `Closed` (0) - `should_try` returns `true`; failures are counted and
///   reaching the failure threshold moves the breaker to `Open`.
/// * `Open` (1) - `should_try` returns `false` until the recovery timeout has
///   passed since the last recorded failure, at which point it moves the
///   breaker to `HalfOpen`.
/// * `HalfOpen` (2) - `should_try` returns `true`; a recorded failure moves
///   the breaker back to `Open`, and enough recorded successes move it to
///   `Closed`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum CircuitState {
Closed = 0, Open = 1, HalfOpen = 2, }
/// Circuit breaker whose mutable state lives entirely in atomics, so every
/// method takes `&self`.
///
/// Time is never read from a clock here; callers pass the current `Time`
/// into `should_try` and `record_failure`.
#[derive(Debug)]
pub struct CircuitBreaker {
/// Failure count at which a closed breaker moves to `Open`.
failure_threshold: u32,
/// How long after the last recorded failure an open breaker stays open.
recovery_timeout: Duration,
/// Current `CircuitState`, stored as its `u8` discriminant.
state: AtomicU8,
/// Failures counted while closed; zeroed by a success recorded outside
/// `HalfOpen` and when the breaker closes again.
failures: AtomicU32,
// `last_failure_time`: nanosecond timestamp written by every
// `record_failure` call (0 until the first one).
// `successes_needed`: successes required in `HalfOpen` before the breaker
// closes; `new` fixes it at 1.
last_failure_time: AtomicU64, successes_needed: u32,
/// Successes counted since the breaker last entered `HalfOpen`.
consecutive_successes: AtomicU32,
}
impl CircuitBreaker {
#[inline]
#[must_use]
pub fn new(failure_threshold: u32, recovery_timeout: Duration) -> Self {
Self {
failure_threshold,
recovery_timeout,
state: AtomicU8::new(CircuitState::Closed as u8),
failures: AtomicU32::new(0),
last_failure_time: AtomicU64::new(0),
successes_needed: 1, consecutive_successes: AtomicU32::new(0),
}
}
pub fn should_try(&self, now: Time) -> bool {
match self.state() {
CircuitState::Closed | CircuitState::HalfOpen => true,
CircuitState::Open => {
let last = Time::from_nanos(self.last_failure_time.load(Ordering::Relaxed));
let timeout_nanos =
self.recovery_timeout.as_nanos().min(u128::from(u64::MAX)) as u64;
if now >= last.saturating_add_nanos(timeout_nanos) {
if self.transition(CircuitState::Open, CircuitState::HalfOpen) {
self.consecutive_successes.store(0, Ordering::Relaxed);
return true;
}
self.state() == CircuitState::HalfOpen
} else {
false
}
}
}
}
pub fn record_success(&self) {
if self.state() == CircuitState::HalfOpen {
let successes = self.consecutive_successes.fetch_add(1, Ordering::Relaxed) + 1;
if successes >= self.successes_needed {
self.reset();
}
} else {
self.failures.store(0, Ordering::Relaxed);
}
}
pub fn record_failure(&self, now: Time) {
self.last_failure_time
.store(now.as_nanos(), Ordering::Relaxed);
match self.state() {
CircuitState::Closed => {
let failures = self.failures.fetch_add(1, Ordering::Relaxed) + 1;
if failures >= self.failure_threshold {
self.transition(CircuitState::Closed, CircuitState::Open);
}
}
CircuitState::HalfOpen => {
self.transition(CircuitState::HalfOpen, CircuitState::Open);
}
CircuitState::Open => {
}
}
}
#[inline]
fn state(&self) -> CircuitState {
match self.state.load(Ordering::Acquire) {
0 => CircuitState::Closed,
2 => CircuitState::HalfOpen,
_ => CircuitState::Open,
}
}
#[inline]
fn transition(&self, from: CircuitState, to: CircuitState) -> bool {
self.state
.compare_exchange(from as u8, to as u8, Ordering::AcqRel, Ordering::Acquire)
.is_ok()
}
fn reset(&self) {
if self.transition(CircuitState::HalfOpen, CircuitState::Closed) {
self.failures.store(0, Ordering::Relaxed);
self.consecutive_successes.store(0, Ordering::Relaxed);
}
}
}
// Unit tests for the backoff arithmetic and the circuit-breaker state machine.
#[cfg(test)]
mod tests {
use super::*;
// With jitter disabled, the delay doubles on each attempt starting from `initial`.
#[test]
fn backoff_increases() {
let backoff =
ExponentialBackoff::new(Duration::from_millis(10), Duration::from_secs(1), 2.0, 5)
.with_jitter(0.0);
assert_eq!(backoff.backoff_duration(0), Duration::from_millis(10));
assert_eq!(backoff.backoff_duration(1), Duration::from_millis(20));
assert_eq!(backoff.backoff_duration(2), Duration::from_millis(40));
}
// An attempt number that does not fit in `i32` must still yield the `max` delay.
#[test]
fn backoff_saturates_for_attempts_beyond_i32_range() {
let backoff = ExponentialBackoff::new(
Duration::from_millis(10),
Duration::from_secs(1),
2.0,
u32::MAX,
)
.with_jitter(0.0);
let duration = backoff.backoff_duration((i32::MAX as u32).saturating_add(1));
assert_eq!(
duration,
Duration::from_secs(1),
"large attempts must saturate at max delay instead of wrapping the exponent"
);
}
// `with_jitter` clamps to 0.0..=1.0 and maps NaN to 0.0; compared bit-for-bit.
#[test]
fn jitter_is_clamped_to_documented_range() {
let high =
ExponentialBackoff::new(Duration::from_millis(10), Duration::from_secs(1), 2.0, 5)
.with_jitter(5.0);
assert_eq!(high.jitter.to_bits(), 1.0f64.to_bits());
let low =
ExponentialBackoff::new(Duration::from_millis(10), Duration::from_secs(1), 2.0, 5)
.with_jitter(-1.0);
assert_eq!(low.jitter.to_bits(), 0.0f64.to_bits());
let nan =
ExponentialBackoff::new(Duration::from_millis(10), Duration::from_secs(1), 2.0, 5)
.with_jitter(f64::NAN);
assert_eq!(nan.jitter.to_bits(), 0.0f64.to_bits());
}
// Threshold 2: the first failure keeps the breaker closed, the second opens it.
#[test]
fn circuit_breaker_trips() {
let cb = CircuitBreaker::new(2, Duration::from_secs(1));
let t0 = Time::from_secs(100);
assert!(cb.should_try(t0));
cb.record_failure(t0); assert!(cb.should_try(t0));
cb.record_failure(t0);
assert!(!cb.should_try(t0));
assert_eq!(cb.state(), CircuitState::Open);
}
// After the 1 s timeout the breaker goes half-open; one success closes it.
#[test]
fn circuit_breaker_recovers() {
let cb = CircuitBreaker::new(1, Duration::from_secs(1));
let t0 = Time::from_secs(100);
let t1 = Time::from_secs(102);
cb.record_failure(t0); assert!(!cb.should_try(t0));
assert!(cb.should_try(t1)); assert_eq!(cb.state(), CircuitState::HalfOpen);
cb.record_success();
assert_eq!(cb.state(), CircuitState::Closed);
}
// Single-threaded stand-in for a race: the breaker has already gone
// HalfOpen -> Open when `reset` runs, so `reset` must leave it open.
#[test]
fn reset_does_not_overwrite_concurrent_open() {
let cb = CircuitBreaker::new(1, Duration::from_secs(1));
let t0 = Time::from_secs(100);
let t1 = Time::from_secs(102);
cb.record_failure(t0); assert!(cb.should_try(t1));
cb.record_failure(t1);
assert_eq!(cb.state(), CircuitState::Open);
cb.reset();
assert_eq!(
cb.state(),
CircuitState::Open,
"reset must not overwrite concurrent HalfOpen→Open transition"
);
}
// Exercises the derived `Copy`, `PartialEq` and `Debug` impls of `CircuitState`.
#[test]
fn circuit_state_debug_clone_copy_eq() {
let s = CircuitState::Closed;
let s2 = s; let s3 = s;
assert_eq!(s, s2);
assert_eq!(s, s3);
assert_ne!(s, CircuitState::Open);
assert_ne!(s, CircuitState::HalfOpen);
let dbg = format!("{s:?}");
assert!(dbg.contains("Closed"));
}
}