use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
use std::time::{Duration, Instant};
use tokio::sync::Mutex;
/// The three states a [`CircuitBreaker`] can be in.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CircuitState {
/// Calls are allowed; failures recorded in this state are counted toward
/// the breaker's failure threshold.
Closed,
/// Calls are rejected until the recovery timeout has elapsed since the
/// breaker was opened.
Open,
/// A limited number of calls is admitted. A recorded success closes the
/// breaker; a recorded failure opens it again.
HalfOpen,
}
/// A circuit breaker that moves between [`CircuitState::Closed`],
/// [`CircuitState::Open`] and [`CircuitState::HalfOpen`].
///
/// The methods that read or change the state do so while holding the
/// `state` mutex; the atomic fields are only touched from those methods.
pub struct CircuitBreaker {
// Number of failures counted while closed at which the breaker opens.
failure_threshold: u32,
// How long the breaker stays open before `allow` moves it to half-open.
recovery_timeout: Duration,
// Upper bound on calls admitted while half-open.
half_open_max_calls: u32,
// Failures counted while closed; reset to zero by a recorded success.
failure_count: AtomicU32,
// Calls admitted during the current half-open period (set to 1 on entry,
// accounting for the call that triggered the transition).
half_open_calls: AtomicU32,
// When the breaker was last opened, in nanoseconds since `epoch`.
opened_at_nanos: AtomicU64,
// Current state; locked for the duration of every state inspection/transition.
state: Mutex<CircuitState>,
// Reference instant captured at construction; all timestamps are offsets from it.
epoch: Instant,
}
impl CircuitBreaker {
    /// Creates a breaker that starts in [`CircuitState::Closed`].
    ///
    /// * `failure_threshold` - number of failures counted while closed (with no
    ///   success in between) at which the breaker opens.
    /// * `recovery_timeout` - how long the breaker stays open before a call is
    ///   let through again.
    /// * `half_open_max_calls` - upper bound on calls admitted while half-open.
    pub fn new(
        failure_threshold: u32,
        recovery_timeout: Duration,
        half_open_max_calls: u32,
    ) -> Self {
        Self {
            failure_threshold,
            recovery_timeout,
            half_open_max_calls,
            failure_count: AtomicU32::new(0),
            half_open_calls: AtomicU32::new(0),
            opened_at_nanos: AtomicU64::new(0),
            state: Mutex::new(CircuitState::Closed),
            epoch: Instant::now(),
        }
    }

    /// Builds a breaker from a configuration value, interpreting
    /// `recovery_timeout_ms` as milliseconds.
    pub fn from_config(cfg: &llmtrace_core::CircuitBreakerConfig) -> Self {
        Self::new(
            cfg.failure_threshold,
            Duration::from_millis(cfg.recovery_timeout_ms),
            cfg.half_open_max_calls,
        )
    }

    /// Returns `true` if a call may proceed right now.
    ///
    /// * Closed: always `true`.
    /// * Open: `false` until `recovery_timeout` has elapsed since the breaker
    ///   opened; the first call after that switches the breaker to half-open
    ///   and is admitted (even when `half_open_max_calls` is zero).
    /// * HalfOpen: `true` while fewer than `half_open_max_calls` calls have
    ///   been admitted in the current half-open period.
    pub async fn allow(&self) -> bool {
        let mut state = self.state.lock().await;
        match *state {
            CircuitState::Closed => true,
            CircuitState::Open => {
                let opened_at = self.opened_at_nanos.load(Ordering::Acquire);
                // Compare as `Duration`s: casting `recovery_timeout.as_nanos()`
                // to u64 would silently truncate very large timeouts.
                let open_for =
                    Duration::from_nanos(self.elapsed_nanos().saturating_sub(opened_at));
                if open_for >= self.recovery_timeout {
                    *state = CircuitState::HalfOpen;
                    // The call that triggers the transition is the first probe.
                    self.half_open_calls.store(1, Ordering::Release);
                    true
                } else {
                    false
                }
            }
            CircuitState::HalfOpen => {
                // The counter is only modified while the state mutex is held,
                // so load-then-store is race-free. Only admitted calls are
                // counted, so rejected calls can never wrap the counter and
                // re-open the gate.
                let admitted = self.half_open_calls.load(Ordering::Acquire);
                if admitted < self.half_open_max_calls {
                    self.half_open_calls.store(admitted + 1, Ordering::Release);
                    true
                } else {
                    false
                }
            }
        }
    }

    /// Records a successful call.
    ///
    /// Resets the failure count when closed or half-open, and closes a
    /// half-open breaker. Has no effect while the breaker is open.
    pub async fn record_success(&self) {
        let mut state = self.state.lock().await;
        match *state {
            CircuitState::HalfOpen => {
                *state = CircuitState::Closed;
                self.failure_count.store(0, Ordering::Release);
            }
            CircuitState::Closed => {
                self.failure_count.store(0, Ordering::Release);
            }
            CircuitState::Open => {}
        }
    }

    /// Records a failed call.
    ///
    /// While closed, increments the failure count and opens the breaker once
    /// it reaches `failure_threshold`. While half-open, re-opens the breaker
    /// immediately. Has no effect while the breaker is already open.
    pub async fn record_failure(&self) {
        let mut state = self.state.lock().await;
        match *state {
            CircuitState::Closed => {
                let count = self.failure_count.fetch_add(1, Ordering::AcqRel) + 1;
                if count >= self.failure_threshold {
                    self.trip(&mut state);
                }
            }
            CircuitState::HalfOpen => self.trip(&mut state),
            CircuitState::Open => {}
        }
    }

    /// Returns the current state.
    pub async fn state(&self) -> CircuitState {
        *self.state.lock().await
    }

    /// Switches `state` to open and stamps the opening time.
    fn trip(&self, state: &mut CircuitState) {
        *state = CircuitState::Open;
        self.opened_at_nanos
            .store(self.elapsed_nanos(), Ordering::Release);
    }

    /// Nanoseconds elapsed since `epoch`, clamped to `u64::MAX`.
    fn elapsed_nanos(&self) -> u64 {
        let nanos = self.epoch.elapsed().as_nanos();
        // Lossless after clamping: the value is at most `u64::MAX`.
        nanos.min(u128::from(u64::MAX)) as u64
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a breaker with a failure threshold of one and a 10 ms recovery
    /// timeout, records one failure, then sleeps 20 ms so the timeout has passed.
    async fn tripped_past_recovery(half_open_max_calls: u32) -> CircuitBreaker {
        let breaker = CircuitBreaker::new(1, Duration::from_millis(10), half_open_max_calls);
        breaker.record_failure().await;
        tokio::time::sleep(Duration::from_millis(20)).await;
        breaker
    }

    /// A fresh breaker is closed and admits calls.
    #[tokio::test]
    async fn test_closed_allows_calls() {
        let breaker = CircuitBreaker::new(3, Duration::from_secs(5), 1);
        let admitted = breaker.allow().await;
        assert!(admitted);
        assert_eq!(breaker.state().await, CircuitState::Closed);
    }

    /// Reaching the failure threshold opens the breaker and rejects calls.
    #[tokio::test]
    async fn test_opens_after_threshold() {
        let breaker = CircuitBreaker::new(3, Duration::from_secs(60), 1);
        let mut attempts = 0;
        while attempts < 3 {
            assert!(breaker.allow().await);
            breaker.record_failure().await;
            attempts += 1;
        }
        assert_eq!(breaker.state().await, CircuitState::Open);
        let admitted = breaker.allow().await;
        assert!(!admitted);
    }

    /// A success between failures restarts the failure count.
    #[tokio::test]
    async fn test_success_resets_failure_count() {
        let breaker = CircuitBreaker::new(3, Duration::from_secs(60), 1);
        // Two failures, then a success, then one more failure: still below three in a row.
        breaker.record_failure().await;
        breaker.record_failure().await;
        breaker.record_success().await;
        breaker.record_failure().await;
        assert_eq!(breaker.state().await, CircuitState::Closed);
    }

    /// Once the recovery timeout elapses, the next `allow` moves the breaker to half-open.
    #[tokio::test]
    async fn test_half_open_after_recovery_timeout() {
        let breaker = CircuitBreaker::new(1, Duration::from_millis(10), 1);
        breaker.record_failure().await;
        assert_eq!(breaker.state().await, CircuitState::Open);

        tokio::time::sleep(Duration::from_millis(20)).await;

        assert!(breaker.allow().await);
        assert_eq!(breaker.state().await, CircuitState::HalfOpen);
    }

    /// A success while half-open closes the breaker.
    #[tokio::test]
    async fn test_half_open_success_closes() {
        let breaker = tripped_past_recovery(1).await;
        assert!(breaker.allow().await);
        breaker.record_success().await;
        assert_eq!(breaker.state().await, CircuitState::Closed);
    }

    /// A failure while half-open re-opens the breaker.
    #[tokio::test]
    async fn test_half_open_failure_reopens() {
        let breaker = tripped_past_recovery(1).await;
        assert!(breaker.allow().await);
        breaker.record_failure().await;
        assert_eq!(breaker.state().await, CircuitState::Open);
    }

    /// Only `half_open_max_calls` calls are admitted while half-open.
    #[tokio::test]
    async fn test_half_open_limited_calls() {
        let breaker = tripped_past_recovery(2).await;
        let first = breaker.allow().await;
        let second = breaker.allow().await;
        let third = breaker.allow().await;
        assert!(first);
        assert!(second);
        assert!(!third);
    }
}