use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
use std::time::{Duration, Instant};
use parking_lot::Mutex;
/// The three positions a [`CircuitBreaker`] can be in.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CircuitState {
    /// Normal operation: failures are counted toward `failure_threshold`.
    Closed,
    /// Tripped: stays here until `reset_timeout` has elapsed since opening.
    Open,
    /// Probing: a failure reopens the breaker, enough successes close it.
    HalfOpen,
}
/// Tuning knobs for a [`CircuitBreaker`].
#[derive(Clone, Debug)]
pub struct CircuitBreakerConfig {
    /// Number of failures inside one `failure_window` that opens the breaker.
    pub failure_threshold: u32,
    /// Length of a failure-counting window, measured from the first failure in it.
    pub failure_window: Duration,
    /// How long the breaker stays open before it may move to half-open.
    pub reset_timeout: Duration,
    /// Successes required while half-open before the breaker closes again.
    pub success_threshold: u32,
}
impl Default for CircuitBreakerConfig {
fn default() -> Self {
Self {
failure_threshold: 5,
failure_window: Duration::from_secs(60),
reset_timeout: Duration::from_secs(30),
success_threshold: 1,
}
}
}
/// A circuit breaker whose state transitions are serialized by `state_lock`.
///
/// Timestamps are stored as milliseconds since `epoch`. A stored value of 0
/// is used as a "not set" marker.
pub struct CircuitBreaker {
    /// Thresholds and timeouts supplied at construction.
    config: CircuitBreakerConfig,
    /// Reference instant that all millisecond timestamps are relative to.
    epoch: Instant,
    /// Current state; every public method locks this before touching counters.
    state_lock: Mutex<CircuitState>,
    /// Failures counted in the current window while closed.
    failure_count: AtomicU32,
    /// Start of the current failure window in ms since `epoch` (0 = no window).
    window_start: AtomicU64,
    /// When the breaker last opened, in ms since `epoch`.
    opened_at: AtomicU64,
    /// Successes counted since entering the half-open state.
    half_open_successes: AtomicU32,
}
impl CircuitBreaker {
pub fn new(config: CircuitBreakerConfig) -> Self {
Self {
config,
failure_count: AtomicU32::new(0),
window_start: AtomicU64::new(0),
opened_at: AtomicU64::new(0),
half_open_successes: AtomicU32::new(0),
state_lock: Mutex::new(CircuitState::Closed),
epoch: Instant::now(),
}
}
pub fn state(&self) -> CircuitState {
let mut state = self.state_lock.lock();
self.evaluate_state(&mut state);
*state
}
pub fn record_success(&self) {
let mut state = self.state_lock.lock();
self.evaluate_state(&mut state);
match *state {
CircuitState::Closed => {
self.failure_count.store(0, Ordering::SeqCst);
self.window_start.store(0, Ordering::SeqCst);
}
CircuitState::HalfOpen => {
let successes = self.half_open_successes.fetch_add(1, Ordering::SeqCst) + 1;
if successes >= self.config.success_threshold {
*state = CircuitState::Closed;
self.failure_count.store(0, Ordering::SeqCst);
self.window_start.store(0, Ordering::SeqCst);
self.half_open_successes.store(0, Ordering::SeqCst);
tracing::info!("circuit breaker closed after successful recovery");
}
}
CircuitState::Open => {
}
}
}
pub fn record_failure(&self) {
let mut state = self.state_lock.lock();
self.evaluate_state(&mut state);
let now = self.now_millis();
match *state {
CircuitState::Closed => {
let window_start = self.window_start.load(Ordering::SeqCst);
let window_ms = self.config.failure_window.as_millis() as u64;
if window_start == 0 || now > window_start + window_ms {
self.window_start.store(now, Ordering::SeqCst);
self.failure_count.store(1, Ordering::SeqCst);
if self.config.failure_threshold == 1 {
*state = CircuitState::Open;
self.opened_at.store(now, Ordering::SeqCst);
tracing::warn!(
failures = 1,
threshold = self.config.failure_threshold,
"circuit breaker opened"
);
}
} else {
let count = self.failure_count.fetch_add(1, Ordering::SeqCst) + 1;
if count >= self.config.failure_threshold {
*state = CircuitState::Open;
self.opened_at.store(now, Ordering::SeqCst);
tracing::warn!(
failures = count,
threshold = self.config.failure_threshold,
"circuit breaker opened"
);
}
}
}
CircuitState::HalfOpen => {
*state = CircuitState::Open;
self.opened_at.store(now, Ordering::SeqCst);
self.half_open_successes.store(0, Ordering::SeqCst);
tracing::warn!("circuit breaker reopened after half-open failure");
}
CircuitState::Open => {
}
}
}
fn evaluate_state(&self, state: &mut CircuitState) {
if *state == CircuitState::Open {
let now = self.now_millis();
let opened_at = self.opened_at.load(Ordering::SeqCst);
let reset_ms = self.config.reset_timeout.as_millis() as u64;
if now > opened_at + reset_ms {
*state = CircuitState::HalfOpen;
self.half_open_successes.store(0, Ordering::SeqCst);
tracing::info!("circuit breaker transitioned to half-open");
}
}
}
fn now_millis(&self) -> u64 {
let elapsed = self.epoch.elapsed().as_millis() as u64;
elapsed.max(1)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread::sleep;

    /// Builds a config with a 60 s failure window and a single-success
    /// recovery requirement. Only the trip threshold and reset timeout vary.
    fn config_with(failure_threshold: u32, reset_timeout: Duration) -> CircuitBreakerConfig {
        CircuitBreakerConfig {
            failure_threshold,
            failure_window: Duration::from_secs(60),
            reset_timeout,
            success_threshold: 1,
        }
    }

    /// Records `n` consecutive failures on `breaker`.
    fn fail_times(breaker: &CircuitBreaker, n: usize) {
        (0..n).for_each(|_| breaker.record_failure());
    }

    #[test]
    fn test_initial_state_closed() {
        let breaker = CircuitBreaker::new(CircuitBreakerConfig::default());
        assert_eq!(breaker.state(), CircuitState::Closed);
    }

    #[test]
    fn test_opens_after_threshold() {
        let breaker = CircuitBreaker::new(config_with(3, Duration::from_secs(30)));
        fail_times(&breaker, 3);
        assert_eq!(breaker.state(), CircuitState::Open);
    }

    #[test]
    fn test_success_resets_count() {
        let breaker = CircuitBreaker::new(config_with(3, Duration::from_secs(30)));
        fail_times(&breaker, 2);
        breaker.record_success();
        fail_times(&breaker, 2);
        // Only two failures since the success: still below the threshold of 3.
        assert_eq!(breaker.state(), CircuitState::Closed);
    }

    #[test]
    fn test_half_open_after_timeout() {
        let breaker = CircuitBreaker::new(config_with(1, Duration::from_millis(50)));
        breaker.record_failure();
        assert_eq!(breaker.state(), CircuitState::Open);
        sleep(Duration::from_millis(100));
        assert_eq!(breaker.state(), CircuitState::HalfOpen);
    }

    #[test]
    fn test_closes_after_half_open_success() {
        let breaker = CircuitBreaker::new(config_with(1, Duration::from_millis(20)));
        breaker.record_failure();
        assert_eq!(breaker.state(), CircuitState::Open);
        sleep(Duration::from_millis(50));
        assert_eq!(breaker.state(), CircuitState::HalfOpen);
        breaker.record_success();
        assert_eq!(breaker.state(), CircuitState::Closed);
    }

    #[test]
    fn test_reopens_on_half_open_failure() {
        let breaker = CircuitBreaker::new(config_with(1, Duration::from_millis(20)));
        breaker.record_failure();
        sleep(Duration::from_millis(50));
        assert_eq!(breaker.state(), CircuitState::HalfOpen);
        breaker.record_failure();
        assert_eq!(breaker.state(), CircuitState::Open);
    }
}