use std::sync::Arc;
use std::sync::atomic::{AtomicU8, AtomicU32, AtomicU64, Ordering};
use std::time::Duration;
use tokio::sync::RwLock;
/// State reported by a [`CircuitBreaker`].
///
/// `Hash` is derived alongside `Eq` so states can key maps and sets
/// (e.g. per-state counters).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CircuitState {
    /// Initial state; also entered on any recorded success or on `reset`.
    Closed,
    /// Entered when consecutive failures reach the threshold; reported until
    /// the reset timeout has elapsed since the most recent failure.
    Open,
    /// Reported once the reset timeout has elapsed while Open; the next
    /// recorded success closes the circuit and a further failure normally
    /// re-opens it.
    HalfOpen,
}
/// Consecutive-failure circuit breaker with a time-based half-open probe.
///
/// The authoritative state sits behind an async `RwLock`. The failure counter
/// and the last-failure timestamp are atomics, so they can be read without
/// awaiting.
pub struct CircuitBreaker {
    /// Current state; every assignment to it is made while holding the write lock.
    state: RwLock<CircuitState>,
    /// Failures recorded since the last success or reset.
    consecutive_failures: AtomicU32,
    /// Consecutive failures at which the circuit opens.
    failure_threshold: u32,
    /// How long after the most recent failure an Open circuit becomes HalfOpen.
    reset_timeout: Duration,
    /// Wall-clock time of the most recent failure, in milliseconds since the Unix epoch.
    last_failure_time: AtomicU64,
    /// Encoded copy of `state` shared with the health check:
    /// 0 = Closed, 1 = Open, 2 = HalfOpen.
    health_state: Arc<AtomicU8>,
}
impl CircuitBreaker {
    /// Creates a breaker in the [`CircuitState::Closed`] state.
    ///
    /// * `failure_threshold` – number of consecutive failures that opens the
    ///   circuit (a threshold of 0 opens it on the first failure).
    /// * `reset_timeout` – how long after the most recent failure an Open
    ///   circuit is reported as [`CircuitState::HalfOpen`].
    ///
    /// With the `health` feature enabled this also registers a
    /// `"circuit_breaker"` health check that reports Closed as healthy,
    /// HalfOpen as degraded and Open as unhealthy.
    #[must_use]
    pub fn new(failure_threshold: u32, reset_timeout: Duration) -> Self {
        // Encoded state shared with the health-check closure (0 == Closed);
        // the encoding is defined by `sync_health_state`.
        let health_state = Arc::new(AtomicU8::new(0));
        #[cfg(feature = "health")]
        {
            let hs = Arc::clone(&health_state);
            crate::health::HealthRegistry::register("circuit_breaker", move || {
                match hs.load(Ordering::Relaxed) {
                    0 => crate::health::HealthStatus::Healthy,
                    2 => crate::health::HealthStatus::Degraded,
                    _ => crate::health::HealthStatus::Unhealthy,
                }
            });
        }
        Self {
            state: RwLock::new(CircuitState::Closed),
            consecutive_failures: AtomicU32::new(0),
            failure_threshold,
            reset_timeout,
            last_failure_time: AtomicU64::new(0),
            health_state,
        }
    }

    /// Publishes `state` to the health-check mirror (0 = Closed, 1 = Open,
    /// 2 = HalfOpen). Every caller holds the state write lock.
    fn sync_health_state(&self, state: CircuitState) {
        let val = match state {
            CircuitState::Closed => 0,
            CircuitState::Open => 1,
            CircuitState::HalfOpen => 2,
        };
        // Relaxed: a standalone status value that does not publish other data.
        self.health_state.store(val, Ordering::Relaxed);
    }

    /// Returns `true` once `reset_timeout` has passed since the last recorded failure.
    fn reset_timeout_elapsed(&self) -> bool {
        let last_failure = self.last_failure_time.load(Ordering::SeqCst);
        // saturating_sub: if the wall clock stepped backwards, treat it as
        // zero elapsed time instead of underflowing.
        let elapsed = Duration::from_millis(current_epoch_millis().saturating_sub(last_failure));
        elapsed >= self.reset_timeout
    }

    /// Returns the current state, first moving Open → HalfOpen if the reset
    /// timeout has elapsed since the last failure.
    ///
    /// Some cases take only a shared read lock, so concurrent callers do not
    /// serialize:
    /// * Closed,
    /// * HalfOpen,
    /// * Open while the timeout is still running.
    ///
    /// The write lock is taken only to perform the transition.
    pub async fn state(&self) -> CircuitState {
        // Copying the value out drops the read guard at the end of this statement.
        let observed = *self.state.read().await;
        if observed != CircuitState::Open || !self.reset_timeout_elapsed() {
            return observed;
        }
        let mut state = self.state.write().await;
        // Re-check under the write lock. While this task waited, another task
        // may have closed the circuit, refreshed the failure time or already
        // half-opened it.
        if *state == CircuitState::Open && self.reset_timeout_elapsed() {
            *state = CircuitState::HalfOpen;
            self.sync_health_state(*state);
        }
        *state
    }

    /// Returns `true` if [`state`](Self::state) reports Closed.
    pub async fn is_closed(&self) -> bool {
        self.state().await == CircuitState::Closed
    }

    /// Returns `true` if [`state`](Self::state) reports Open, i.e. the reset
    /// timeout has not yet elapsed since the last failure.
    pub async fn is_open(&self) -> bool {
        self.state().await == CircuitState::Open
    }

    /// Records a successful call: clears the failure counter and closes the
    /// circuit from any state, including Open.
    pub async fn record_success(&self) {
        let mut state = self.state.write().await;
        self.consecutive_failures.store(0, Ordering::SeqCst);
        *state = CircuitState::Closed;
        self.sync_health_state(*state);
    }

    /// Records a failed call and opens the circuit when warranted.
    ///
    /// The circuit opens in either of two cases:
    /// * the consecutive-failure count reaches `failure_threshold`;
    /// * the failure happens while HalfOpen (this opens it immediately).
    ///
    /// Every failure also refreshes the time the reset timeout is measured from.
    pub async fn record_failure(&self) {
        // Counter, timestamp and transition all change under one write lock.
        // This stops a concurrent `record_success`/`reset` from slipping in
        // between and leaving the circuit Open with a zeroed counter.
        let mut state = self.state.write().await;
        // Every writer of the counter holds the write lock, so load + store
        // cannot lose updates. saturating_add avoids overflow at u32::MAX.
        let failures = self
            .consecutive_failures
            .load(Ordering::SeqCst)
            .saturating_add(1);
        self.consecutive_failures.store(failures, Ordering::SeqCst);
        self.last_failure_time
            .store(current_epoch_millis(), Ordering::SeqCst);
        if *state == CircuitState::HalfOpen || failures >= self.failure_threshold {
            *state = CircuitState::Open;
            self.sync_health_state(*state);
        }
    }

    /// Returns the number of failures recorded since the last success or reset.
    #[must_use]
    pub fn consecutive_failures(&self) -> u32 {
        self.consecutive_failures.load(Ordering::SeqCst)
    }

    /// Forces the circuit back to Closed and clears the failure counter.
    pub async fn reset(&self) {
        // Clear the counter under the lock, as `record_success` does, so it
        // cannot interleave with a concurrent `record_failure`.
        let mut state = self.state.write().await;
        self.consecutive_failures.store(0, Ordering::SeqCst);
        *state = CircuitState::Closed;
        self.sync_health_state(*state);
    }
}
impl std::fmt::Debug for CircuitBreaker {
    /// Formats the configuration, the failure counter and the stored state.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `try_read` avoids waiting inside a synchronous formatter. `None`
        // means the lock could not be acquired immediately.
        // This is the *stored* state: an Open circuit whose reset timeout has
        // elapsed becomes HalfOpen only on the next `state()` call.
        let state = self.state.try_read().ok().map(|guard| *guard);
        f.debug_struct("CircuitBreaker")
            .field("state", &state)
            .field("failure_threshold", &self.failure_threshold)
            .field("reset_timeout", &self.reset_timeout)
            .field("consecutive_failures", &self.consecutive_failures())
            .finish_non_exhaustive()
    }
}
/// Milliseconds elapsed since the Unix epoch according to the system wall clock.
///
/// The result depends on the clock's position:
/// * before the epoch, the result is 0;
/// * at a value that does not fit in a `u64`, the result saturates to `u64::MAX`.
///
/// Wall-clock time is not monotonic. Adjusting the system clock can stretch
/// or shrink intervals measured with this value.
fn current_epoch_millis() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => u64::try_from(since_epoch.as_millis()).unwrap_or(u64::MAX),
        Err(_) => 0,
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_initial_state_is_closed() {
        let cb = CircuitBreaker::new(3, Duration::from_secs(30));
        assert_eq!(cb.state().await, CircuitState::Closed);
        assert!(cb.is_closed().await);
    }

    #[tokio::test]
    async fn test_opens_after_threshold() {
        let cb = CircuitBreaker::new(3, Duration::from_secs(30));
        cb.record_failure().await;
        assert!(cb.is_closed().await);
        cb.record_failure().await;
        assert!(cb.is_closed().await);
        cb.record_failure().await;
        assert!(cb.is_open().await);
        assert_eq!(cb.consecutive_failures(), 3);
    }

    #[tokio::test]
    async fn test_success_resets_failures() {
        let cb = CircuitBreaker::new(3, Duration::from_secs(30));
        cb.record_failure().await;
        cb.record_failure().await;
        assert_eq!(cb.consecutive_failures(), 2);
        cb.record_success().await;
        assert_eq!(cb.consecutive_failures(), 0);
        assert!(cb.is_closed().await);
    }

    #[tokio::test]
    async fn test_stays_open_before_timeout() {
        let cb = CircuitBreaker::new(1, Duration::from_secs(30));
        cb.record_failure().await;
        // Repeated reads must not half-open the circuit before the timeout.
        assert_eq!(cb.state().await, CircuitState::Open);
        assert_eq!(cb.state().await, CircuitState::Open);
        assert!(!cb.is_closed().await);
    }

    #[tokio::test]
    async fn test_half_open_after_timeout() {
        let cb = CircuitBreaker::new(1, Duration::from_millis(50));
        cb.record_failure().await;
        assert!(cb.is_open().await);
        tokio::time::sleep(Duration::from_millis(100)).await;
        assert_eq!(cb.state().await, CircuitState::HalfOpen);
    }

    #[tokio::test]
    async fn test_half_open_success_closes() {
        let cb = CircuitBreaker::new(1, Duration::from_millis(10));
        cb.record_failure().await;
        tokio::time::sleep(Duration::from_millis(20)).await;
        assert_eq!(cb.state().await, CircuitState::HalfOpen);
        cb.record_success().await;
        assert!(cb.is_closed().await);
    }

    #[tokio::test]
    async fn test_half_open_failure_reopens() {
        let cb = CircuitBreaker::new(1, Duration::from_millis(10));
        cb.record_failure().await;
        tokio::time::sleep(Duration::from_millis(20)).await;
        assert_eq!(cb.state().await, CircuitState::HalfOpen);
        cb.record_failure().await;
        assert!(cb.is_open().await);
    }

    #[tokio::test]
    async fn test_reset() {
        let cb = CircuitBreaker::new(1, Duration::from_secs(30));
        cb.record_failure().await;
        assert!(cb.is_open().await);
        cb.reset().await;
        assert!(cb.is_closed().await);
        assert_eq!(cb.consecutive_failures(), 0);
    }

    #[tokio::test]
    async fn test_health_state_tracks_transitions() {
        let cb = CircuitBreaker::new(1, Duration::from_millis(10));
        assert_eq!(cb.health_state.load(Ordering::Relaxed), 0);
        cb.record_failure().await;
        assert_eq!(cb.health_state.load(Ordering::Relaxed), 1);
        tokio::time::sleep(Duration::from_millis(20)).await;
        assert_eq!(cb.state().await, CircuitState::HalfOpen);
        assert_eq!(cb.health_state.load(Ordering::Relaxed), 2);
        cb.record_success().await;
        assert_eq!(cb.health_state.load(Ordering::Relaxed), 0);
    }

    #[tokio::test]
    async fn test_concurrent_failures_reach_threshold() {
        let cb = CircuitBreaker::new(2, Duration::from_secs(30));
        tokio::join!(cb.record_failure(), cb.record_failure());
        assert_eq!(cb.consecutive_failures(), 2);
        assert!(cb.is_open().await);
    }
}