use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
/// Lifecycle state of a [`CircuitBreaker`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CircuitBreakerState {
    /// Requests are admitted; consecutive failures count toward the trip threshold.
    Closed,
    /// Requests are rejected until more than the breaker's timeout has passed
    /// since the last recorded failure.
    Open,
    /// Trial requests are admitted; enough successes close the circuit and any
    /// failure reopens it.
    HalfOpen,
}

/// A thread-safe circuit breaker.
///
/// It opens after `failure_threshold` consecutive failures. Once more than
/// `timeout` has passed since the last failure, it lets trial requests through.
///
/// Every check-then-transition sequence runs while holding the `state` mutex,
/// so concurrent callers can never overwrite each other's transitions. The
/// atomic counters are likewise only touched under that lock, which is why
/// `Ordering::Relaxed` is sufficient for them.
#[derive(Debug)]
pub struct CircuitBreaker {
    /// Current state; its lock also serializes all counter updates and transitions.
    state: Arc<Mutex<CircuitBreakerState>>,
    /// Failure counter; cleared when the circuit closes and by any success while `Closed`.
    failure_count: Arc<AtomicUsize>,
    /// Successes recorded since the circuit last entered `HalfOpen`.
    success_count: Arc<AtomicUsize>,
    /// Consecutive failures (while `Closed`) that open the circuit.
    failure_threshold: usize,
    /// How long after the last failure an `Open` circuit waits before probing.
    timeout: Duration,
    /// When the most recent failure was recorded; `None` until the first failure.
    last_failure_time: Arc<Mutex<Option<Instant>>>,
}

impl CircuitBreaker {
    /// Successes required while `HalfOpen` before the circuit closes again.
    const HALF_OPEN_SUCCESS_THRESHOLD: usize = 3;

    /// Creates a breaker in the `Closed` state.
    ///
    /// The circuit opens after `failure_threshold` consecutive failures. It
    /// starts admitting trial requests once more than `timeout` has passed
    /// since the last failure. A `failure_threshold` of 0 behaves like 1: the
    /// first failure opens the circuit.
    pub fn new(failure_threshold: usize, timeout: Duration) -> Self {
        Self {
            state: Arc::new(Mutex::new(CircuitBreakerState::Closed)),
            failure_count: Arc::new(AtomicUsize::new(0)),
            success_count: Arc::new(AtomicUsize::new(0)),
            failure_threshold,
            timeout,
            last_failure_time: Arc::new(Mutex::new(None)),
        }
    }

    /// Returns a snapshot of the current state.
    ///
    /// The state may change as soon as this returns.
    ///
    /// # Panics
    /// Panics if the state mutex is poisoned.
    pub fn state(&self) -> CircuitBreakerState {
        *self.state.lock().unwrap()
    }

    /// Reports whether a request may proceed.
    ///
    /// `Closed` and `HalfOpen` always admit. `Open` admits only once more than
    /// `timeout` has passed since the last failure. In that case the circuit
    /// moves to `HalfOpen`, under the same lock, before this returns `true`.
    pub fn allow_request(&self) -> bool {
        let mut state = self.state.lock().unwrap();
        let current = *state;
        match current {
            CircuitBreakerState::Closed | CircuitBreakerState::HalfOpen => true,
            CircuitBreakerState::Open => {
                // Lock order is always `state` -> `last_failure_time`.
                let last_failure = *self.last_failure_time.lock().unwrap();
                let timed_out = last_failure.is_some_and(|time| time.elapsed() > self.timeout);
                if timed_out {
                    self.transition_to_half_open(&mut state);
                }
                timed_out
            }
        }
    }

    /// Records a successful request.
    ///
    /// - `Closed`: clears the failure streak.
    /// - `HalfOpen`: counts toward closing the circuit.
    /// - `Open`: ignored.
    pub fn record_success(&self) {
        let mut state = self.state.lock().unwrap();
        let current = *state;
        match current {
            CircuitBreakerState::Closed => self.failure_count.store(0, Ordering::Relaxed),
            CircuitBreakerState::HalfOpen => {
                // Use the value returned by `fetch_add` rather than re-loading it.
                let successes = self.success_count.fetch_add(1, Ordering::Relaxed) + 1;
                if successes >= Self::HALF_OPEN_SUCCESS_THRESHOLD {
                    self.transition_to_closed(&mut state);
                }
            }
            CircuitBreakerState::Open => {}
        }
    }

    /// Records a failed request and stamps the failure time.
    ///
    /// - `Closed`: the circuit opens once the failure streak reaches
    ///   `failure_threshold`.
    /// - `HalfOpen`: any failure reopens the circuit.
    /// - `Open`: the state is unchanged, but the refreshed timestamp restarts
    ///   the `timeout` wait.
    pub fn record_failure(&self) {
        let mut state = self.state.lock().unwrap();
        let failures = self.failure_count.fetch_add(1, Ordering::Relaxed) + 1;
        // Lock order is always `state` -> `last_failure_time`.
        *self.last_failure_time.lock().unwrap() = Some(Instant::now());
        let current = *state;
        let should_open = match current {
            CircuitBreakerState::Closed => failures >= self.failure_threshold,
            CircuitBreakerState::HalfOpen => true,
            CircuitBreakerState::Open => false,
        };
        if should_open {
            Self::transition_to_open(&mut state);
        }
    }

    /// Moves to `Open`. Callers pass the locked state so the write happens under the lock.
    fn transition_to_open(state: &mut CircuitBreakerState) {
        *state = CircuitBreakerState::Open;
    }

    /// Moves to `HalfOpen` and starts a fresh success count. Callers hold the state lock.
    fn transition_to_half_open(&self, state: &mut CircuitBreakerState) {
        *state = CircuitBreakerState::HalfOpen;
        self.success_count.store(0, Ordering::Relaxed);
    }

    /// Moves to `Closed` and clears both counters. Callers hold the state lock.
    fn transition_to_closed(&self, state: &mut CircuitBreakerState) {
        *state = CircuitBreakerState::Closed;
        self.failure_count.store(0, Ordering::Relaxed);
        self.success_count.store(0, Ordering::Relaxed);
    }

    /// Forces the circuit back to `Closed` and clears both counters.
    ///
    /// The last-failure timestamp is left untouched.
    pub fn reset(&self) {
        let mut state = self.state.lock().unwrap();
        self.transition_to_closed(&mut state);
    }
}
#[cfg(test)]
mod tests {
    // Unit tests for `CircuitBreaker` state transitions.
    use super::*;

    #[test]
    fn does_not_open_on_non_consecutive_failures() {
        let breaker = CircuitBreaker::new(3, Duration::from_secs(60));
        breaker.record_failure();
        breaker.record_success();
        breaker.record_failure();
        breaker.record_success();
        breaker.record_failure();
        assert_eq!(breaker.state(), CircuitBreakerState::Closed);
        assert!(breaker.allow_request());
    }

    #[test]
    fn opens_after_consecutive_failures() {
        let breaker = CircuitBreaker::new(3, Duration::from_secs(60));
        breaker.record_failure();
        breaker.record_failure();
        breaker.record_failure();
        assert_eq!(breaker.state(), CircuitBreakerState::Open);
        assert!(!breaker.allow_request());
    }

    #[test]
    fn half_open_failure_reopens_immediately() {
        let breaker = CircuitBreaker::new(1, Duration::from_millis(5));
        breaker.record_failure();
        assert_eq!(breaker.state(), CircuitBreakerState::Open);
        std::thread::sleep(Duration::from_millis(10));
        assert!(breaker.allow_request());
        assert_eq!(breaker.state(), CircuitBreakerState::HalfOpen);
        breaker.record_failure();
        assert_eq!(breaker.state(), CircuitBreakerState::Open);
    }

    #[test]
    fn reset_closes_open_circuit() {
        let breaker = CircuitBreaker::new(1, Duration::from_secs(60));
        breaker.record_failure();
        assert_eq!(breaker.state(), CircuitBreakerState::Open);
        breaker.reset();
        assert_eq!(breaker.state(), CircuitBreakerState::Closed);
        assert!(breaker.allow_request());
    }

    #[test]
    fn half_open_transitions_to_closed_after_three_successes() {
        let breaker = CircuitBreaker::new(1, Duration::from_millis(5));
        breaker.record_failure();
        std::thread::sleep(Duration::from_millis(10));
        assert!(breaker.allow_request());
        assert_eq!(breaker.state(), CircuitBreakerState::HalfOpen);
        breaker.record_success();
        assert_eq!(breaker.state(), CircuitBreakerState::HalfOpen);
        breaker.record_success();
        assert_eq!(breaker.state(), CircuitBreakerState::HalfOpen);
        breaker.record_success();
        assert_eq!(breaker.state(), CircuitBreakerState::Closed);
        assert!(breaker.allow_request());
    }

    #[test]
    fn partial_successes_in_half_open_do_not_close() {
        let breaker = CircuitBreaker::new(1, Duration::from_millis(5));
        breaker.record_failure();
        std::thread::sleep(Duration::from_millis(10));
        // The probe must actually be admitted; otherwise the assertions below
        // would be checking a breaker that never left `Open`.
        assert!(breaker.allow_request());
        assert_eq!(breaker.state(), CircuitBreakerState::HalfOpen);
        breaker.record_success();
        assert_eq!(breaker.state(), CircuitBreakerState::HalfOpen);
        breaker.record_success();
        assert_eq!(breaker.state(), CircuitBreakerState::HalfOpen);
    }

    #[test]
    fn failure_in_open_state_is_noop() {
        let breaker = CircuitBreaker::new(1, Duration::from_secs(60));
        breaker.record_failure();
        assert_eq!(breaker.state(), CircuitBreakerState::Open);
        breaker.record_failure();
        assert_eq!(breaker.state(), CircuitBreakerState::Open);
    }

    #[test]
    fn success_in_open_state_is_noop() {
        let breaker = CircuitBreaker::new(1, Duration::from_secs(60));
        breaker.record_failure();
        assert_eq!(breaker.state(), CircuitBreakerState::Open);
        breaker.record_success();
        assert_eq!(breaker.state(), CircuitBreakerState::Open);
        assert!(!breaker.allow_request());
    }

    #[test]
    fn open_with_unexpired_timeout_denies_request() {
        let breaker = CircuitBreaker::new(1, Duration::from_secs(60));
        breaker.record_failure();
        assert_eq!(breaker.state(), CircuitBreakerState::Open);
        assert!(!breaker.allow_request());
        assert_eq!(breaker.state(), CircuitBreakerState::Open);
    }

    #[test]
    fn closed_after_reset_accepts_new_failures() {
        let breaker = CircuitBreaker::new(2, Duration::from_secs(60));
        breaker.record_failure();
        breaker.record_failure();
        breaker.reset();
        breaker.record_failure();
        assert_eq!(breaker.state(), CircuitBreakerState::Closed);
        breaker.record_failure();
        assert_eq!(breaker.state(), CircuitBreakerState::Open);
    }

    #[test]
    fn default_opens_after_five_consecutive_failures() {
        let breaker = CircuitBreaker::default();
        for _ in 0..4 {
            breaker.record_failure();
        }
        assert_eq!(breaker.state(), CircuitBreakerState::Closed);
        breaker.record_failure();
        assert_eq!(breaker.state(), CircuitBreakerState::Open);
        assert!(!breaker.allow_request());
    }
}
impl Default for CircuitBreaker {
fn default() -> Self {
Self::new(5, Duration::from_secs(60))
}
}