use serde::{Deserialize, Serialize};
use std::sync::{Arc, Mutex, MutexGuard};
use std::time::{Duration, Instant};
/// Internal state of a `CircuitBreaker`.
///
/// Only `Open` carries data: the instant the breaker tripped, which is
/// compared against the breaker's reset duration to decide when to move on
/// to `HalfOpen`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum CircuitState {
    /// Normal operation; consecutive failures are being counted.
    Closed,
    /// Tripped at `opened_at`; `is_open` reports `true` until the reset
    /// duration has elapsed since then.
    Open { opened_at: Instant },
    /// The reset duration elapsed; the next recorded success closes the
    /// breaker and the next recorded failure re-opens it.
    HalfOpen,
}
impl CircuitState {
fn snapshot(&self) -> CircuitStateSnapshot {
match self {
Self::Closed => CircuitStateSnapshot::Closed,
Self::Open { .. } => CircuitStateSnapshot::Open,
Self::HalfOpen => CircuitStateSnapshot::HalfOpen,
}
}
}
/// Public, serializable view of a `CircuitBreaker`'s state.
///
/// Serialized in `snake_case`: `"closed"`, `"open"` and `"half_open"`.
/// The type is a plain fieldless enum, so it is `Copy` and `Hash` and can be
/// passed by value or used as a map key / metrics label.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum CircuitStateSnapshot {
    /// Normal operation; consecutive failures are being counted.
    Closed,
    /// The breaker has tripped.
    Open,
    /// The reset duration elapsed; the next recorded outcome decides whether
    /// the breaker closes or re-opens.
    HalfOpen,
}
/// Formats the state as `closed`, `open` or `half_open`.
///
/// Goes through `Formatter::pad`, so width, fill, alignment and precision
/// flags (e.g. `{:<9}`) are honoured exactly as they are for `str`; with no
/// flags the output is just the label.
impl std::fmt::Display for CircuitStateSnapshot {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::Closed => "closed",
            Self::Open => "open",
            Self::HalfOpen => "half_open",
        };
        // `write!(f, ...)` would bypass the caller's padding/alignment flags.
        f.pad(label)
    }
}
pub struct CircuitBreaker {
name: String,
state: Arc<Mutex<CircuitState>>,
failure_threshold: u32,
reset_duration: Duration,
consecutive_failures: Arc<Mutex<u32>>,
}
impl CircuitBreaker {
    /// Locks the state mutex, recovering the guard if a previous holder
    /// panicked.
    ///
    /// Poisoning is logged and then ignored: this impl only ever replaces the
    /// guarded enum with a single assignment, so there is no multi-step
    /// update a panic could interrupt halfway.
    fn lock_state(&self) -> MutexGuard<'_, CircuitState> {
        self.state.lock().unwrap_or_else(|poisoned| {
            tracing::warn!(
                breaker = %self.name,
                "circuit breaker state lock poisoned; continuing with recovered state"
            );
            poisoned.into_inner()
        })
    }

    /// Locks the consecutive-failure counter, recovering from poisoning the
    /// same way as [`Self::lock_state`].
    ///
    /// Lock order: methods that need both locks take `lock_state` first.
    fn lock_failures(&self) -> MutexGuard<'_, u32> {
        self.consecutive_failures.lock().unwrap_or_else(|poisoned| {
            tracing::warn!(
                breaker = %self.name,
                "circuit breaker failure-count lock poisoned; continuing with recovered state"
            );
            poisoned.into_inner()
        })
    }

    /// Creates a closed breaker.
    ///
    /// * `name` — label attached to log events.
    /// * `failure_threshold` — consecutive failures (while closed) that trip
    ///   the breaker. A value of `0` behaves like `1`.
    /// * `reset_secs` — whole seconds to stay open before moving to
    ///   half-open. Use [`Self::with_reset_duration`] for finer resolution.
    pub fn new(name: impl Into<String>, failure_threshold: u32, reset_secs: u64) -> Self {
        Self::with_reset_duration(name, failure_threshold, Duration::from_secs(reset_secs))
    }

    /// Creates a closed breaker whose open period is an arbitrary
    /// [`Duration`] (e.g. milliseconds) rather than whole seconds.
    ///
    /// Parameters otherwise match [`Self::new`].
    pub fn with_reset_duration(
        name: impl Into<String>,
        failure_threshold: u32,
        reset_duration: Duration,
    ) -> Self {
        Self {
            name: name.into(),
            state: Arc::new(Mutex::new(CircuitState::Closed)),
            failure_threshold,
            reset_duration,
            consecutive_failures: Arc::new(Mutex::new(0)),
        }
    }

    /// The label this breaker was created with.
    pub fn name(&self) -> &str {
        &self.name
    }

    /// Returns `true` while the breaker is open and the caller should not
    /// proceed.
    ///
    /// This is **not** a pure query. Once the reset duration has elapsed
    /// since the breaker tripped, the call moves it from `Open` to `HalfOpen`
    /// and returns `false`.
    ///
    /// While half-open, every caller gets `false`; the number of trial calls
    /// is not limited. The state stays half-open until `record_success` or
    /// `record_failure` settles it. Use [`Self::state_snapshot`] to inspect
    /// the state without this transition.
    pub fn is_open(&self) -> bool {
        let mut state = self.lock_state();
        if let CircuitState::Open { opened_at } = &*state {
            if opened_at.elapsed() < self.reset_duration {
                return true;
            }
            *state = CircuitState::HalfOpen;
            tracing::info!(breaker = %self.name, "circuit breaker → half_open");
            return false;
        }
        // Closed and HalfOpen both let the call through.
        false
    }

    /// Returns the current state without side effects.
    ///
    /// Unlike [`Self::is_open`], this never performs the `Open` → `HalfOpen`
    /// transition. It can therefore still report `Open` after the reset
    /// duration has elapsed.
    pub fn state_snapshot(&self) -> CircuitStateSnapshot {
        self.lock_state().snapshot()
    }

    /// Records a successful call.
    ///
    /// Always resets the consecutive-failure count, and closes a half-open
    /// breaker. A success recorded while `Open` leaves the breaker open, for
    /// example when it comes from a call that started before the trip.
    pub fn record_success(&self) {
        let mut state = self.lock_state();
        let mut failures = self.lock_failures();
        *failures = 0;
        if matches!(*state, CircuitState::HalfOpen) {
            *state = CircuitState::Closed;
            tracing::info!(breaker = %self.name, "circuit breaker → closed (recovered)");
        }
    }

    /// Records a failed call and returns the resulting state.
    ///
    /// * `Closed`: increments the failure count. When it reaches
    ///   `failure_threshold`, the breaker opens and the count resets to 0.
    /// * `HalfOpen`: the trial call failed, so the breaker re-opens
    ///   immediately with a fresh timestamp.
    /// * `Open`: ignored; the open timer is not extended.
    pub fn record_failure(&self) -> CircuitStateSnapshot {
        let mut state = self.lock_state();
        let mut failures = self.lock_failures();
        match &*state {
            CircuitState::HalfOpen => {
                *state = CircuitState::Open {
                    opened_at: Instant::now(),
                };
                tracing::warn!(breaker = %self.name, "circuit breaker → open (probe failed)");
            }
            CircuitState::Closed => {
                *failures += 1;
                if *failures >= self.failure_threshold {
                    *state = CircuitState::Open {
                        opened_at: Instant::now(),
                    };
                    *failures = 0;
                    tracing::warn!(
                        breaker = %self.name,
                        threshold = self.failure_threshold,
                        "circuit breaker → open (threshold reached)"
                    );
                }
            }
            CircuitState::Open { .. } => {}
        }
        state.snapshot()
    }
}
/// Cloning yields another handle to the *same* breaker.
///
/// The state and failure counter are `Arc`-shared rather than copied, so
/// failures and successes recorded through any clone affect all of them.
/// The name and configuration are copied.
impl Clone for CircuitBreaker {
    fn clone(&self) -> Self {
        // Exhaustive destructuring: adding a field forces this impl to be
        // revisited.
        let Self {
            name,
            state,
            failure_threshold,
            reset_duration,
            consecutive_failures,
        } = self;
        Self {
            name: name.clone(),
            state: Arc::clone(state),
            failure_threshold: *failure_threshold,
            reset_duration: *reset_duration,
            consecutive_failures: Arc::clone(consecutive_failures),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_starts_closed() {
        let cb = CircuitBreaker::new("test", 3, 60);
        assert_eq!(cb.name(), "test");
        assert!(!cb.is_open());
        assert_eq!(cb.state_snapshot(), CircuitStateSnapshot::Closed);
    }

    #[test]
    fn test_opens_at_threshold() {
        let cb = CircuitBreaker::new("test", 2, 60);
        cb.record_failure();
        assert_eq!(cb.state_snapshot(), CircuitStateSnapshot::Closed);
        cb.record_failure();
        assert_eq!(cb.state_snapshot(), CircuitStateSnapshot::Open);
        assert!(cb.is_open());
    }

    #[test]
    fn test_record_failure_returns_resulting_state() {
        let cb = CircuitBreaker::new("test", 2, 60);
        assert_eq!(cb.record_failure(), CircuitStateSnapshot::Closed);
        assert_eq!(cb.record_failure(), CircuitStateSnapshot::Open);
    }

    #[test]
    fn test_success_resets_failures() {
        let cb = CircuitBreaker::new("test", 3, 60);
        cb.record_failure();
        cb.record_failure();
        cb.record_success();
        cb.record_failure();
        assert_eq!(cb.state_snapshot(), CircuitStateSnapshot::Closed);
    }

    #[test]
    fn test_stays_open_before_reset_elapses() {
        let cb = CircuitBreaker::new("test", 1, 60);
        cb.record_failure();
        assert!(cb.is_open());
        assert_eq!(cb.state_snapshot(), CircuitStateSnapshot::Open);
    }

    #[test]
    fn test_failures_ignored_while_open() {
        let cb = CircuitBreaker::new("test", 1, 60);
        assert_eq!(cb.record_failure(), CircuitStateSnapshot::Open);
        assert_eq!(cb.record_failure(), CircuitStateSnapshot::Open);
        assert!(cb.is_open());
    }

    #[test]
    fn test_success_while_open_does_not_close() {
        let cb = CircuitBreaker::new("test", 1, 60);
        cb.record_failure();
        cb.record_success();
        assert_eq!(cb.state_snapshot(), CircuitStateSnapshot::Open);
        assert!(cb.is_open());
    }

    #[test]
    fn test_half_open_to_closed_on_success() {
        let cb = CircuitBreaker::new("test", 1, 0);
        cb.record_failure();
        // With a zero reset duration the first check moves Open -> HalfOpen.
        assert!(!cb.is_open());
        assert_eq!(cb.state_snapshot(), CircuitStateSnapshot::HalfOpen);
        cb.record_success();
        assert_eq!(cb.state_snapshot(), CircuitStateSnapshot::Closed);
    }

    #[test]
    fn test_half_open_to_open_on_failure() {
        let cb = CircuitBreaker::new("test", 1, 0);
        cb.record_failure();
        assert!(!cb.is_open());
        assert_eq!(cb.state_snapshot(), CircuitStateSnapshot::HalfOpen);
        assert_eq!(cb.record_failure(), CircuitStateSnapshot::Open);
        assert_eq!(cb.state_snapshot(), CircuitStateSnapshot::Open);
    }

    #[test]
    fn test_clones_share_state() {
        let cb = CircuitBreaker::new("test", 1, 60);
        let other = cb.clone();
        cb.record_failure();
        assert_eq!(other.state_snapshot(), CircuitStateSnapshot::Open);
        assert!(other.is_open());
    }

    #[test]
    fn test_clones_share_failure_count() {
        let cb = CircuitBreaker::new("test", 2, 60);
        let other = cb.clone();
        cb.record_failure();
        other.record_failure();
        assert_eq!(cb.state_snapshot(), CircuitStateSnapshot::Open);
    }

    #[test]
    fn test_snapshot_display() {
        assert_eq!(CircuitStateSnapshot::Closed.to_string(), "closed");
        assert_eq!(CircuitStateSnapshot::Open.to_string(), "open");
        assert_eq!(CircuitStateSnapshot::HalfOpen.to_string(), "half_open");
    }
}