use std::sync::atomic::{AtomicU64, AtomicU8, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
/// Operating mode of a circuit breaker.
///
/// The explicit discriminants are the byte values produced by `as u8` casts,
/// which is how the state is kept inside an atomic.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CircuitState {
    /// Requests are allowed; failures are being counted (byte value `0`).
    Closed = 0,
    /// Requests are rejected until the recovery timeout elapses (byte value `1`).
    Open = 1,
    /// Requests are allowed again as probes after a recovery timeout (byte value `2`).
    HalfOpen = 2,
}
impl From<u8> for CircuitState {
fn from(value: u8) -> Self {
match value {
0 => CircuitState::Closed,
1 => CircuitState::Open,
2 => CircuitState::HalfOpen,
_ => CircuitState::Closed, }
}
}
/// Tunable thresholds for a circuit breaker.
#[derive(Clone, Debug)]
pub struct CircuitBreakerConfig {
    /// Number of consecutive failures at which a closed breaker opens.
    pub failure_threshold: u32,
    /// Minimum time that must pass after the most recent recorded failure
    /// before an open breaker starts letting requests through again.
    pub recovery_timeout: Duration,
}
impl Default for CircuitBreakerConfig {
    /// Default settings: open after 5 consecutive failures and wait
    /// 30 seconds before allowing requests again.
    fn default() -> Self {
        let failure_threshold = 5;
        let recovery_timeout = Duration::from_secs(30);
        Self {
            failure_threshold,
            recovery_timeout,
        }
    }
}
/// Circuit breaker whose mutable state lives entirely in atomics, so every
/// operation takes `&self`.
#[derive(Debug)]
pub struct CircuitBreaker {
    /// Current [`CircuitState`], stored as its `u8` discriminant.
    state: AtomicU8,
    /// Failures recorded since the counter was last zeroed.
    consecutive_failures: AtomicU64,
    /// Offset of the most recent recorded failure from `epoch`, in
    /// nanoseconds; `0` until a failure has been recorded.
    last_failure_time_nanos: AtomicU64,
    /// Thresholds that drive the state transitions.
    config: CircuitBreakerConfig,
    /// Instant captured at construction; the reference point that lets a
    /// failure time be stored as a plain `u64`.
    epoch: Instant,
}
impl CircuitBreaker {
    /// Creates a breaker using [`CircuitBreakerConfig::default`].
    pub fn new() -> Self {
        Self::with_config(CircuitBreakerConfig::default())
    }

    /// Creates a closed breaker with no recorded failures.
    ///
    /// The moment of construction becomes the reference instant against
    /// which failure timestamps are stored.
    pub fn with_config(config: CircuitBreakerConfig) -> Self {
        Self {
            state: AtomicU8::new(CircuitState::Closed as u8),
            consecutive_failures: AtomicU64::new(0),
            last_failure_time_nanos: AtomicU64::new(0),
            config,
            epoch: Instant::now(),
        }
    }

    /// Returns the state currently stored in the breaker.
    ///
    /// This is a plain read with no side effects: an `Open` breaker whose
    /// recovery timeout has already elapsed keeps reporting `Open` until
    /// [`allow_request`](Self::allow_request) performs the transition.
    pub fn state(&self) -> CircuitState {
        CircuitState::from(self.state.load(Ordering::Acquire))
    }

    /// Decides whether a request may proceed.
    ///
    /// * `Closed` and `HalfOpen`: always `true`.
    /// * `Open`: `false` until `recovery_timeout` has passed since the most
    ///   recent recorded failure; after that the breaker is moved to
    ///   `HalfOpen` and `true` is returned.
    pub fn allow_request(&self) -> bool {
        match self.state() {
            CircuitState::Closed | CircuitState::HalfOpen => true,
            CircuitState::Open => {
                // A concurrent `record_failure` may publish a timestamp later
                // than the `now` sampled here; saturate to zero instead of
                // relying on `duration_since`, which is documented as allowed
                // to panic in that situation.
                let waited =
                    Instant::now().saturating_duration_since(self.last_failure_time());
                if waited < self.config.recovery_timeout {
                    return false;
                }
                match self.state.compare_exchange(
                    CircuitState::Open as u8,
                    CircuitState::HalfOpen as u8,
                    Ordering::AcqRel,
                    Ordering::Acquire,
                ) {
                    Ok(_) => true,
                    // Losing the race means another thread already left
                    // `Open`. Decide from the value the CAS observed: both
                    // `HalfOpen` and `Closed` (a probe already succeeded, or
                    // `reset` ran) admit the request.
                    Err(observed) => CircuitState::from(observed) != CircuitState::Open,
                }
            }
        }
    }

    /// Records a successful call.
    ///
    /// * `Closed`: zeroes the consecutive-failure counter.
    /// * `HalfOpen`: zeroes the counter and closes the breaker.
    /// * `Open`: ignored.
    pub fn record_success(&self) {
        match self.state() {
            CircuitState::Open => {}
            CircuitState::Closed => self.consecutive_failures.store(0, Ordering::Release),
            CircuitState::HalfOpen => {
                // Zero the counter before closing so that a failure arriving
                // right after the transition is not judged against a stale
                // count.
                self.consecutive_failures.store(0, Ordering::Release);
                // Close only if the breaker is still half-open. A blind store
                // here could overwrite a concurrent failure's
                // HalfOpen -> Open transition and silently re-close the
                // circuit; with the CAS a racing failure wins.
                let _ = self.state.compare_exchange(
                    CircuitState::HalfOpen as u8,
                    CircuitState::Closed as u8,
                    Ordering::AcqRel,
                    Ordering::Acquire,
                );
            }
        }
    }

    /// Records a failed call.
    ///
    /// Always increments the consecutive-failure counter and refreshes the
    /// last-failure timestamp, then:
    ///
    /// * `HalfOpen`: reopens the breaker.
    /// * `Closed`: opens the breaker once the counter has reached
    ///   `failure_threshold`.
    /// * `Open`: no state change.
    pub fn record_failure(&self) {
        let failures = self.consecutive_failures.fetch_add(1, Ordering::Relaxed) + 1;

        // Offset from the construction instant, clamped to `u64::MAX` if it
        // does not fit. Stored before any state change so the timestamp is
        // already in place when the breaker becomes `Open`.
        let nanos = Instant::now()
            .saturating_duration_since(self.epoch)
            .as_nanos()
            .try_into()
            .unwrap_or(u64::MAX);
        self.last_failure_time_nanos.store(nanos, Ordering::Release);

        let should_open = match self.state() {
            CircuitState::HalfOpen => true,
            CircuitState::Closed => failures >= u64::from(self.config.failure_threshold),
            CircuitState::Open => false,
        };
        if should_open {
            self.state
                .store(CircuitState::Open as u8, Ordering::Release);
        }
    }

    /// Reconstructs the instant of the most recent recorded failure from the
    /// stored nanosecond offset.
    fn last_failure_time(&self) -> Instant {
        let nanos = self.last_failure_time_nanos.load(Ordering::Acquire);
        self.epoch + Duration::from_nanos(nanos)
    }

    /// Returns the number of failures recorded since the counter was last
    /// zeroed.
    pub fn consecutive_failures(&self) -> u64 {
        self.consecutive_failures.load(Ordering::Relaxed)
    }

    /// Forces the breaker back to `Closed` with a zeroed failure counter.
    ///
    /// The last-failure timestamp is left untouched.
    pub fn reset(&self) {
        // Same order as `record_success`: clear the counter first so a
        // failure racing with the reset cannot immediately re-trip the
        // breaker on a stale count.
        self.consecutive_failures.store(0, Ordering::Release);
        self.state
            .store(CircuitState::Closed as u8, Ordering::Release);
    }
}
impl Default for CircuitBreaker {
fn default() -> Self {
Self::new()
}
}
impl Clone for CircuitBreaker {
    /// Produces an independent breaker seeded with this one's current values.
    ///
    /// The state, failure counter and last-failure offset are read with three
    /// separate relaxed loads and placed in fresh atomics, so later changes
    /// to either breaker are not visible to the other. The configuration is
    /// cloned and the reference instant is copied unchanged.
    fn clone(&self) -> Self {
        let state_bits = self.state.load(Ordering::Relaxed);
        let failure_count = self.consecutive_failures.load(Ordering::Relaxed);
        let last_failure_offset = self.last_failure_time_nanos.load(Ordering::Relaxed);

        Self {
            state: AtomicU8::new(state_bits),
            consecutive_failures: AtomicU64::new(failure_count),
            last_failure_time_nanos: AtomicU64::new(last_failure_offset),
            config: self.config.clone(),
            epoch: self.epoch,
        }
    }
}
/// A [`CircuitBreaker`] behind an [`Arc`], so that one breaker instance can be
/// held by several owners (for example, several threads).
pub type SharedCircuitBreaker = Arc<CircuitBreaker>;
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    /// Recovery timeout used by tests that wait for the breaker to recover.
    const SHORT_TIMEOUT: Duration = Duration::from_millis(100);
    /// Sleep that comfortably exceeds `SHORT_TIMEOUT`.
    const PAST_SHORT_TIMEOUT: Duration = Duration::from_millis(150);

    /// Builds a breaker with the given threshold and recovery timeout.
    fn breaker(failure_threshold: u32, recovery_timeout: Duration) -> CircuitBreaker {
        CircuitBreaker::with_config(CircuitBreakerConfig {
            failure_threshold,
            recovery_timeout,
        })
    }

    /// Builds a threshold-2 breaker and trips it with two failures.
    fn tripped(recovery_timeout: Duration) -> CircuitBreaker {
        let cb = breaker(2, recovery_timeout);
        cb.record_failure();
        cb.record_failure();
        assert_eq!(cb.state(), CircuitState::Open);
        cb
    }

    /// Builds a tripped breaker, waits out its timeout and moves it to
    /// `HalfOpen` through `allow_request`.
    fn half_open() -> CircuitBreaker {
        let cb = tripped(SHORT_TIMEOUT);
        thread::sleep(PAST_SHORT_TIMEOUT);
        assert!(cb.allow_request());
        assert_eq!(cb.state(), CircuitState::HalfOpen);
        cb
    }

    /// Runs `op` once on each of `threads` threads sharing `cb`, then joins
    /// them all.
    fn run_concurrently(cb: &Arc<CircuitBreaker>, threads: usize, op: fn(&CircuitBreaker)) {
        let handles: Vec<_> = (0..threads)
            .map(|_| {
                let shared = Arc::clone(cb);
                thread::spawn(move || op(&shared))
            })
            .collect();
        for handle in handles {
            handle.join().unwrap();
        }
    }

    #[test]
    fn test_initial_state() {
        let cb = CircuitBreaker::new();
        assert_eq!(cb.state(), CircuitState::Closed);
        assert_eq!(cb.consecutive_failures(), 0);
        assert!(cb.allow_request());
    }

    #[test]
    fn test_failure_threshold() {
        let cb = breaker(3, Duration::from_secs(1));

        cb.record_failure();
        assert_eq!(cb.state(), CircuitState::Closed);
        assert_eq!(cb.consecutive_failures(), 1);

        cb.record_failure();
        assert_eq!(cb.state(), CircuitState::Closed);
        assert_eq!(cb.consecutive_failures(), 2);

        cb.record_failure();
        assert_eq!(cb.state(), CircuitState::Open);
        assert_eq!(cb.consecutive_failures(), 3);
    }

    // Despite the name, this checks that an open breaker REJECTS requests
    // while the recovery timeout has not elapsed.
    #[test]
    fn test_fail_open_when_circuit_open() {
        let cb = tripped(Duration::from_secs(10));
        assert!(!cb.allow_request());
    }

    #[test]
    fn test_recovery_after_timeout() {
        let cb = tripped(SHORT_TIMEOUT);
        thread::sleep(PAST_SHORT_TIMEOUT);
        assert!(cb.allow_request());
        assert_eq!(cb.state(), CircuitState::HalfOpen);
    }

    #[test]
    fn test_half_open_success_closes_circuit() {
        let cb = half_open();
        cb.record_success();
        assert_eq!(cb.state(), CircuitState::Closed);
        assert_eq!(cb.consecutive_failures(), 0);
    }

    #[test]
    fn test_half_open_failure_reopens_circuit() {
        let cb = half_open();
        cb.record_failure();
        assert_eq!(cb.state(), CircuitState::Open);
    }

    #[test]
    fn test_success_resets_failure_count() {
        let cb = CircuitBreaker::new();
        cb.record_failure();
        cb.record_failure();
        assert_eq!(cb.consecutive_failures(), 2);

        cb.record_success();
        assert_eq!(cb.consecutive_failures(), 0);
        assert_eq!(cb.state(), CircuitState::Closed);
    }

    #[test]
    fn test_reset() {
        let cb = tripped(Duration::from_secs(10));
        cb.reset();
        assert_eq!(cb.state(), CircuitState::Closed);
        assert_eq!(cb.consecutive_failures(), 0);
        assert!(cb.allow_request());
    }

    #[test]
    fn test_concurrent_failures() {
        let cb = Arc::new(CircuitBreaker::new());
        run_concurrently(&cb, 10, CircuitBreaker::record_failure);
        assert_eq!(cb.consecutive_failures(), 10);
        assert_eq!(cb.state(), CircuitState::Open);
    }

    #[test]
    fn test_clone() {
        let cb1 = CircuitBreaker::new();
        cb1.record_failure();

        let cb2 = cb1.clone();
        assert_eq!(cb2.consecutive_failures(), 1);

        // The clone owns its own counters.
        cb2.record_failure();
        assert_eq!(cb1.consecutive_failures(), 1);
        assert_eq!(cb2.consecutive_failures(), 2);
    }

    #[test]
    fn test_concurrent_half_open_to_closed_race() {
        let cb = Arc::new(half_open());
        run_concurrently(&cb, 10, CircuitBreaker::record_success);
        assert_eq!(cb.state(), CircuitState::Closed);
        assert_eq!(cb.consecutive_failures(), 0);
    }

    #[test]
    fn test_concurrent_half_open_to_open_race() {
        let cb = Arc::new(half_open());
        run_concurrently(&cb, 10, CircuitBreaker::record_failure);
        assert_eq!(cb.state(), CircuitState::Open);
        assert!(cb.consecutive_failures() >= 2);
    }

    // NOTE: `Instant` cannot be moved backwards from a test, so this only
    // re-checks the ordinary Open -> HalfOpen recovery path.
    #[test]
    fn test_time_going_backwards() {
        let cb = tripped(SHORT_TIMEOUT);
        thread::sleep(PAST_SHORT_TIMEOUT);
        cb.allow_request();
        assert_eq!(cb.state(), CircuitState::HalfOpen);
    }

    // Uses a 1 ms timeout with a 5 ms sleep per cycle so that 100 full
    // Closed -> Open -> HalfOpen -> Closed cycles finish in well under a
    // couple of seconds instead of sleeping 150 ms per cycle.
    #[test]
    fn test_rapid_state_transitions() {
        let cb = breaker(2, Duration::from_millis(1));
        for _ in 0..100 {
            cb.record_failure();
            cb.record_failure();
            assert_eq!(cb.state(), CircuitState::Open);

            thread::sleep(Duration::from_millis(5));
            assert!(cb.allow_request());
            assert_eq!(cb.state(), CircuitState::HalfOpen);

            cb.record_success();
            assert_eq!(cb.state(), CircuitState::Closed);
        }
        assert_eq!(cb.state(), CircuitState::Closed);
        assert_eq!(cb.consecutive_failures(), 0);
    }

    #[test]
    fn test_concurrent_mixed_operations() {
        let cb = Arc::new(CircuitBreaker::new());
        let pause = Duration::from_micros(100);

        let failing = {
            let cb = Arc::clone(&cb);
            thread::spawn(move || {
                for _ in 0..20 {
                    cb.record_failure();
                    thread::sleep(pause);
                }
            })
        };
        let succeeding = {
            let cb = Arc::clone(&cb);
            thread::spawn(move || {
                for _ in 0..20 {
                    cb.record_success();
                    thread::sleep(pause);
                }
            })
        };
        let asking = {
            let cb = Arc::clone(&cb);
            thread::spawn(move || {
                for _ in 0..20 {
                    let _allowed = cb.allow_request();
                    thread::sleep(pause);
                }
            })
        };
        let observing = {
            let cb = Arc::clone(&cb);
            thread::spawn(move || {
                for _ in 0..20 {
                    let _state = cb.state();
                    let _failures = cb.consecutive_failures();
                    thread::sleep(pause);
                }
            })
        };

        for handle in vec![failing, succeeding, asking, observing] {
            handle.join().unwrap();
        }

        // Only 20 failures were ever recorded, and the default 30 s recovery
        // timeout cannot elapse during this test, so the breaker can never
        // have reached HalfOpen.
        assert!(cb.consecutive_failures() <= 20);
        assert_ne!(cb.state(), CircuitState::HalfOpen);
    }

    #[test]
    fn test_state_consistency_after_clone() {
        let cb1 = breaker(2, SHORT_TIMEOUT);
        cb1.record_failure();

        let cb2 = cb1.clone();
        assert_eq!(cb1.state(), CircuitState::Closed);
        assert_eq!(cb2.state(), CircuitState::Closed);

        cb2.record_failure();
        assert_eq!(cb2.consecutive_failures(), 2);
        assert_eq!(cb2.state(), CircuitState::Open);
        assert_eq!(cb1.state(), CircuitState::Closed);
    }

    // Exercises many failures past the threshold; it does not approach the
    // actual `u64` limit of the counter.
    #[test]
    fn test_consecutive_failures_overflow_resistance() {
        let cb = CircuitBreaker::new();
        for _ in 0..10_000 {
            cb.record_failure();
        }
        assert!(cb.consecutive_failures() >= 10_000);
        assert_eq!(cb.state(), CircuitState::Open);
    }

    #[test]
    fn test_recovery_timeout_boundary() {
        let cb = breaker(1, Duration::from_millis(50));
        cb.record_failure();
        assert_eq!(cb.state(), CircuitState::Open);

        // `state()` never transitions on its own, so this holds regardless of
        // how long the sleep actually took.
        thread::sleep(Duration::from_millis(30));
        assert_eq!(cb.state(), CircuitState::Open);

        // 60 ms total is past the 50 ms timeout.
        thread::sleep(Duration::from_millis(30));
        cb.allow_request();
        assert_eq!(cb.state(), CircuitState::HalfOpen);
    }

    #[test]
    fn test_allow_request_consistency() {
        let cb = breaker(2, SHORT_TIMEOUT);
        assert!(cb.allow_request());

        cb.record_failure();
        cb.record_failure();
        assert!(!cb.allow_request());
        assert!(!cb.allow_request());

        thread::sleep(PAST_SHORT_TIMEOUT);
        assert!(cb.allow_request());
    }
}