use std::time::Duration;
/// Lifecycle state of a circuit breaker.
///
/// The explicit discriminants are the raw bytes that are stored in, and decoded from,
/// an `AtomicU8`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum CircuitState {
    /// Requests are allowed; failures are being tallied.
    Closed = 0,
    /// Requests are rejected until the recovery timeout has elapsed.
    Open = 1,
    /// Requests are allowed again while successes are tallied toward closing.
    HalfOpen = 2,
}

impl From<u8> for CircuitState {
    /// Decodes a raw discriminant byte.
    ///
    /// `1` decodes to `Open`, `2` to `HalfOpen`; `0` and every other byte decode to
    /// `Closed`, so the conversion never fails.
    fn from(value: u8) -> Self {
        const OPEN: u8 = CircuitState::Open as u8;
        const HALF_OPEN: u8 = CircuitState::HalfOpen as u8;
        match value {
            OPEN => Self::Open,
            HALF_OPEN => Self::HalfOpen,
            // 0 plus any unrecognised byte falls back to Closed.
            _ => Self::Closed,
        }
    }
}
/// Tuning parameters for a circuit breaker.
#[derive(Debug, Clone)]
pub struct CircuitBreakerConfig {
    /// Number of failures recorded while Closed (the tally is cleared by a success)
    /// at which the breaker opens.
    pub failure_threshold: u32,
    /// How long the breaker stays Open, measured from the most recent failure or trip,
    /// before it may move to HalfOpen.
    pub recovery_timeout: Duration,
    /// Number of successes recorded while HalfOpen at which the breaker closes again.
    pub success_threshold: u32,
}

impl Default for CircuitBreakerConfig {
    /// Opens after 5 failures, waits 30 seconds before probing, closes after 1 success.
    fn default() -> Self {
        Self::new(5, Duration::from_secs(30))
    }
}

impl CircuitBreakerConfig {
    /// Builds a config from a failure threshold and recovery timeout.
    ///
    /// `success_threshold` starts at 1; use [`Self::with_success_threshold`] to change it.
    pub fn new(failure_threshold: u32, recovery_timeout: Duration) -> Self {
        const DEFAULT_SUCCESS_THRESHOLD: u32 = 1;
        Self {
            success_threshold: DEFAULT_SUCCESS_THRESHOLD,
            failure_threshold,
            recovery_timeout,
        }
    }

    /// Returns this config with `success_threshold` replaced by `threshold`.
    pub fn with_success_threshold(self, threshold: u32) -> Self {
        Self {
            success_threshold: threshold,
            ..self
        }
    }
}
/// Common interface for circuit breakers that can be shared between threads.
pub trait CircuitBreaker: Send + Sync {
    /// Returns the current state.
    ///
    /// Implementations may advance the state as part of this call (for example,
    /// promoting Open to HalfOpen once a timeout has passed).
    fn state(&self) -> CircuitState;

    /// Returns `true` exactly when [`state`](Self::state) reports `Open`.
    fn is_open(&self) -> bool {
        matches!(self.state(), CircuitState::Open)
    }

    /// Returns `true` unless [`state`](Self::state) reports `Open`.
    ///
    /// Both Closed and HalfOpen let requests through.
    fn allows_request(&self) -> bool {
        !matches!(self.state(), CircuitState::Open)
    }

    /// Reports that a guarded call succeeded.
    fn record_success(&self);

    /// Reports that a guarded call failed.
    fn record_failure(&self);

    /// Forces the breaker into the Open state.
    fn trip(&self);

    /// Returns the breaker to its initial state.
    fn reset(&self);

    /// Time remaining before the breaker may leave the Open state, if applicable.
    fn time_until_half_open(&self) -> Option<Duration>;

    /// Current failure tally.
    fn failure_count(&self) -> u32;
}
/// Circuit breaker whose state and counters live in atomics.
///
/// A mutex is taken only for the last-failure timestamp.
pub struct AtomicCircuitBreaker {
    // NOTE(review): this field is `pub`, so outside code can store arbitrary bytes;
    // unrecognised values decode as `Closed` via `From<u8>`.
    /// Raw `CircuitState` discriminant.
    pub state_val: std::sync::atomic::AtomicU8,
    /// Failures tallied while Closed.
    failure_count_val: std::sync::atomic::AtomicU32,
    /// Successes tallied while HalfOpen.
    success_count_val: std::sync::atomic::AtomicU32,
    /// When a failure was last recorded or the breaker was tripped; cleared on close/reset.
    last_failure: parking_lot::Mutex<Option<std::time::Instant>>,
    /// Thresholds and timeout supplied at construction.
    config: CircuitBreakerConfig,
}

impl AtomicCircuitBreaker {
    /// Creates a breaker in the Closed state with zeroed tallies and no recorded failure.
    pub fn new(config: CircuitBreakerConfig) -> Self {
        use std::sync::atomic::{AtomicU32, AtomicU8};
        let closed = CircuitState::Closed as u8;
        Self {
            config,
            last_failure: parking_lot::Mutex::new(None),
            success_count_val: AtomicU32::new(0),
            failure_count_val: AtomicU32::new(0),
            state_val: AtomicU8::new(closed),
        }
    }

    /// Creates a breaker using `CircuitBreakerConfig::default()`.
    pub fn with_defaults() -> Self {
        Self::new(Default::default())
    }
}
impl CircuitBreaker for AtomicCircuitBreaker {
    /// Returns the current state.
    ///
    /// Open is promoted to HalfOpen lazily here, once `recovery_timeout` has elapsed since
    /// the last recorded failure or trip. The `last_failure` lock is held across the
    /// timeout check and the Open -> HalfOpen compare-exchange, so it cannot interleave
    /// with `trip`, which holds the same lock while publishing Open.
    fn state(&self) -> CircuitState {
        use std::sync::atomic::Ordering;
        let current = CircuitState::from(self.state_val.load(Ordering::Acquire));
        if current != CircuitState::Open {
            return current;
        }
        let stamp = self.last_failure.lock();
        let Some(last) = *stamp else {
            return current;
        };
        if last.elapsed() < self.config.recovery_timeout {
            return current;
        }
        match self.state_val.compare_exchange(
            CircuitState::Open as u8,
            CircuitState::HalfOpen as u8,
            Ordering::AcqRel,
            Ordering::Acquire,
        ) {
            Ok(_) => {
                // Begin the HalfOpen window with a clean success tally.
                self.success_count_val.store(0, Ordering::Release);
                CircuitState::HalfOpen
            }
            // Another thread changed the state first; report what it actually is now
            // instead of the stale Open loaded above.
            Err(actual) => CircuitState::from(actual),
        }
    }

    /// Records a successful call.
    ///
    /// - Closed: clears the failure tally.
    /// - HalfOpen: counts the success. On reaching `success_threshold` it closes the
    ///   breaker and clears both tallies and the failure timestamp.
    /// - Open: ignored.
    fn record_success(&self) {
        use std::sync::atomic::Ordering;
        match CircuitState::from(self.state_val.load(Ordering::Acquire)) {
            CircuitState::Closed => self.failure_count_val.store(0, Ordering::Release),
            CircuitState::HalfOpen => {
                let successes = self.success_count_val.fetch_add(1, Ordering::AcqRel) + 1;
                if successes < self.config.success_threshold {
                    return;
                }
                // Close only if still HalfOpen. A concurrent failed probe may already have
                // reopened the breaker, and an unconditional store would silently undo that.
                let closed = self.state_val.compare_exchange(
                    CircuitState::HalfOpen as u8,
                    CircuitState::Closed as u8,
                    Ordering::AcqRel,
                    Ordering::Acquire,
                );
                if closed.is_ok() {
                    self.failure_count_val.store(0, Ordering::Release);
                    self.success_count_val.store(0, Ordering::Release);
                    *self.last_failure.lock() = None;
                }
            }
            CircuitState::Open => {}
        }
    }

    /// Records a failed call.
    ///
    /// The failure timestamp is refreshed in every state, so failures recorded while
    /// Open also push back the HalfOpen promotion.
    /// - Closed: tallies the failure and opens on reaching `failure_threshold`.
    /// - HalfOpen: reopens immediately and clears the success tally.
    /// - Open: nothing further.
    fn record_failure(&self) {
        use std::sync::atomic::Ordering;
        // Stamp before reading/publishing state so a newly-Open breaker is never
        // observed with an outdated timestamp.
        *self.last_failure.lock() = Some(std::time::Instant::now());
        match CircuitState::from(self.state_val.load(Ordering::Acquire)) {
            CircuitState::Closed => {
                let count = self.failure_count_val.fetch_add(1, Ordering::AcqRel) + 1;
                if count >= self.config.failure_threshold {
                    self.state_val
                        .store(CircuitState::Open as u8, Ordering::Release);
                }
            }
            CircuitState::HalfOpen => {
                self.state_val
                    .store(CircuitState::Open as u8, Ordering::Release);
                self.success_count_val.store(0, Ordering::Release);
            }
            CircuitState::Open => {}
        }
    }

    /// Forces the breaker Open and restarts the recovery timer.
    fn trip(&self) {
        use std::sync::atomic::Ordering;
        // Refresh the timestamp *before* publishing Open, and keep the lock held while
        // publishing. Otherwise a concurrent `state()` could see Open paired with an old
        // timestamp and promote straight to HalfOpen, undoing the trip.
        let mut stamp = self.last_failure.lock();
        *stamp = Some(std::time::Instant::now());
        self.state_val
            .store(CircuitState::Open as u8, Ordering::Release);
    }

    /// Returns to Closed with both tallies zeroed and no recorded failure.
    fn reset(&self) {
        use std::sync::atomic::Ordering;
        self.state_val
            .store(CircuitState::Closed as u8, Ordering::Release);
        self.failure_count_val.store(0, Ordering::Release);
        self.success_count_val.store(0, Ordering::Release);
        *self.last_failure.lock() = None;
    }

    /// Time left before the breaker becomes eligible for HalfOpen.
    ///
    /// Returns `None` unless the stored state is Open and a failure timestamp exists.
    /// Returns `Duration::ZERO` once the timeout has passed. This reads the raw stored
    /// state and does not perform the promotion that `state()` does.
    fn time_until_half_open(&self) -> Option<Duration> {
        use std::sync::atomic::Ordering;
        if CircuitState::from(self.state_val.load(Ordering::Acquire)) != CircuitState::Open {
            return None;
        }
        let last = (*self.last_failure.lock())?;
        Some(self.config.recovery_timeout.saturating_sub(last.elapsed()))
    }

    /// Failures tallied while Closed since the last success, close, or reset.
    fn failure_count(&self) -> u32 {
        self.failure_count_val
            .load(std::sync::atomic::Ordering::Acquire)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    #[test]
    fn test_starts_closed() {
        let cb = AtomicCircuitBreaker::with_defaults();
        assert_eq!(cb.state(), CircuitState::Closed);
        assert!(cb.allows_request());
    }

    #[test]
    fn test_opens_after_threshold() {
        let config = CircuitBreakerConfig::new(3, Duration::from_secs(30));
        let cb = AtomicCircuitBreaker::new(config);
        cb.record_failure();
        cb.record_failure();
        assert_eq!(cb.state(), CircuitState::Closed);
        cb.record_failure();
        assert_eq!(cb.state(), CircuitState::Open);
        assert!(!cb.allows_request());
    }

    #[test]
    fn test_transitions_to_half_open() {
        let config = CircuitBreakerConfig::new(1, Duration::from_millis(10));
        let cb = AtomicCircuitBreaker::new(config);
        cb.record_failure();
        assert_eq!(cb.state(), CircuitState::Open);
        thread::sleep(Duration::from_millis(15));
        assert_eq!(cb.state(), CircuitState::HalfOpen);
    }

    #[test]
    fn test_closes_on_success_in_half_open() {
        let config = CircuitBreakerConfig::new(1, Duration::from_millis(10));
        let cb = AtomicCircuitBreaker::new(config);
        cb.record_failure();
        thread::sleep(Duration::from_millis(15));
        assert_eq!(cb.state(), CircuitState::HalfOpen);
        cb.record_success();
        assert_eq!(cb.state(), CircuitState::Closed);
    }

    #[test]
    fn test_success_threshold() {
        let config =
            CircuitBreakerConfig::new(1, Duration::from_millis(10)).with_success_threshold(3);
        let cb = AtomicCircuitBreaker::new(config);
        cb.record_failure();
        thread::sleep(Duration::from_millis(15));
        assert_eq!(cb.state(), CircuitState::HalfOpen);
        cb.record_success();
        assert_eq!(cb.state(), CircuitState::HalfOpen);
        cb.record_success();
        assert_eq!(cb.state(), CircuitState::HalfOpen);
        cb.record_success();
        assert_eq!(cb.state(), CircuitState::Closed);
    }

    #[test]
    fn test_failure_in_half_open_reopens() {
        let config = CircuitBreakerConfig::new(1, Duration::from_millis(10));
        let cb = AtomicCircuitBreaker::new(config);
        cb.record_failure();
        thread::sleep(Duration::from_millis(15));
        assert_eq!(cb.state(), CircuitState::HalfOpen);
        cb.record_failure();
        // `time_until_half_open` reads the stored state without promoting it, so this
        // check does not depend on how quickly the assertion runs after the failure.
        assert!(cb.time_until_half_open().is_some());
    }

    #[test]
    fn test_success_in_closed_resets_failure_count() {
        let config = CircuitBreakerConfig::new(3, Duration::from_secs(30));
        let cb = AtomicCircuitBreaker::new(config);
        cb.record_failure();
        cb.record_failure();
        assert_eq!(cb.failure_count(), 2);
        cb.record_success();
        assert_eq!(cb.failure_count(), 0);
        cb.record_failure();
        cb.record_failure();
        assert_eq!(cb.state(), CircuitState::Closed);
    }

    #[test]
    fn test_time_until_half_open() {
        let cb = AtomicCircuitBreaker::with_defaults();
        assert_eq!(cb.time_until_half_open(), None);
        cb.trip();
        let remaining = cb
            .time_until_half_open()
            .expect("a tripped breaker is Open with a timestamp");
        assert!(remaining > Duration::ZERO);
        assert!(remaining <= Duration::from_secs(30));
        cb.reset();
        assert_eq!(cb.time_until_half_open(), None);
    }

    #[test]
    fn test_unknown_discriminant_decodes_as_closed() {
        assert_eq!(CircuitState::from(0), CircuitState::Closed);
        assert_eq!(CircuitState::from(1), CircuitState::Open);
        assert_eq!(CircuitState::from(2), CircuitState::HalfOpen);
        assert_eq!(CircuitState::from(200), CircuitState::Closed);
    }

    #[test]
    fn test_trip_and_reset() {
        let cb = AtomicCircuitBreaker::with_defaults();
        cb.trip();
        assert_eq!(cb.state(), CircuitState::Open);
        cb.reset();
        assert_eq!(cb.state(), CircuitState::Closed);
    }

    #[test]
    fn test_thread_safety() {
        use std::sync::Arc;
        let config = CircuitBreakerConfig::new(100, Duration::from_secs(30));
        let cb = Arc::new(AtomicCircuitBreaker::new(config));
        let handles: Vec<_> = (0..10)
            .map(|_| {
                let cb = Arc::clone(&cb);
                thread::spawn(move || {
                    for _ in 0..10 {
                        cb.record_failure();
                    }
                })
            })
            .collect();
        for h in handles {
            h.join().unwrap();
        }
        assert_eq!(cb.state(), CircuitState::Open);
        // Open is only published by the 100th increment, so every failure was tallied
        // while Closed; any lost update would leave the count short.
        assert_eq!(cb.failure_count(), 100);
    }
}