use parking_lot::Mutex;
use std::sync::atomic::{AtomicU64, AtomicU8, Ordering};
use std::time::{Duration, Instant};
use thiserror::Error;
/// Lifecycle state of a [`ProviderCircuitBreaker`], storable as a `u8` discriminant.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CircuitState {
    /// Requests are admitted and consecutive failures are counted.
    Closed = 0,
    /// Requests are rejected until the configured open duration elapses.
    Open = 1,
    /// Requests are admitted; enough consecutive successes close the circuit,
    /// and any failure re-opens it.
    HalfOpen = 2,
}
impl CircuitState {
    /// Decodes a raw discriminant produced by [`CircuitState::as_u8`].
    ///
    /// `0` and `1` map to `Closed` and `Open`; every other value (including
    /// `2`) decodes as `HalfOpen`.
    #[inline]
    fn from_u8(value: u8) -> Self {
        const CLOSED: u8 = CircuitState::Closed as u8;
        const OPEN: u8 = CircuitState::Open as u8;
        match value {
            CLOSED => Self::Closed,
            OPEN => Self::Open,
            _ => Self::HalfOpen,
        }
    }
    /// Encodes this state as its `#[repr(u8)]` discriminant.
    #[inline]
    fn as_u8(&self) -> u8 {
        let current: Self = *self;
        current as u8
    }
}
/// Thresholds and timing that drive a [`ProviderCircuitBreaker`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CircuitBreakerConfig {
    /// Consecutive failures while closed that trip the circuit open.
    pub failure_threshold: u32,
    /// How long the circuit stays open before a request is admitted half-open.
    pub open_duration: Duration,
    /// Consecutive half-open successes required to close the circuit again.
    pub half_open_successes: u32,
}
impl Default for CircuitBreakerConfig {
    /// Five failures to trip, a 30 second open window, one success to close.
    fn default() -> Self {
        Self::new(5, Duration::from_secs(30), 1)
    }
}
impl CircuitBreakerConfig {
    /// Builds a validated configuration.
    ///
    /// # Panics
    ///
    /// Panics if `failure_threshold` or `half_open_successes` is zero
    /// (checked in that order).
    #[inline]
    pub fn new(failure_threshold: u32, open_duration: Duration, half_open_successes: u32) -> Self {
        assert!(failure_threshold != 0, "failure_threshold cannot be zero");
        assert!(half_open_successes != 0, "half_open_successes cannot be zero");
        Self {
            failure_threshold,
            open_duration,
            half_open_successes,
        }
    }
    /// Returns a copy with `failure_threshold` replaced.
    ///
    /// Unlike [`CircuitBreakerConfig::new`], this does not reject zero.
    #[inline]
    #[must_use]
    pub fn with_failure_threshold(self, threshold: u32) -> Self {
        Self {
            failure_threshold: threshold,
            ..self
        }
    }
    /// Returns a copy with `open_duration` replaced.
    #[inline]
    #[must_use]
    pub fn with_open_duration(self, duration: Duration) -> Self {
        Self {
            open_duration: duration,
            ..self
        }
    }
    /// Returns a copy with `half_open_successes` replaced.
    ///
    /// Unlike [`CircuitBreakerConfig::new`], this does not reject zero.
    #[inline]
    #[must_use]
    pub fn with_half_open_successes(self, successes: u32) -> Self {
        Self {
            half_open_successes: successes,
            ..self
        }
    }
}
/// Rejection returned when a provider's circuit is open.
#[derive(Debug, Error, Clone, PartialEq, Eq)]
#[error("Circuit breaker open for provider '{provider}': retry after {remaining:?}")]
pub struct CircuitOpenError {
    /// Name of the provider whose circuit rejected the request.
    pub provider: String,
    /// Time left in the open window when the request was rejected.
    pub remaining: Duration,
}
impl CircuitOpenError {
    /// Builds the error from any string-like provider name.
    #[inline]
    pub fn new(provider: impl Into<String>, remaining: Duration) -> Self {
        let provider: String = provider.into();
        Self { provider, remaining }
    }
}
/// Per-provider circuit breaker that gates requests based on recorded outcomes.
///
/// The state and counters are atomics; `opened_at` is guarded by a mutex.
#[derive(Debug)]
pub struct ProviderCircuitBreaker {
    /// Provider name reported in errors and diagnostics.
    provider_name: String,
    /// Thresholds and timing used for state transitions.
    config: CircuitBreakerConfig,
    /// Current [`CircuitState`], encoded with `CircuitState::as_u8`.
    state: AtomicU8,
    /// Failures recorded in a row while closed; cleared by a closed-state
    /// success, by closing from half-open, or by `reset`.
    consecutive_failures: AtomicU64,
    /// Successes recorded in a row while half-open.
    consecutive_successes: AtomicU64,
    /// When the circuit last opened; `None` initially, after closing, and after `reset`.
    opened_at: Mutex<Option<Instant>>,
}
impl ProviderCircuitBreaker {
    /// Creates a breaker for `provider_name` that starts in [`CircuitState::Closed`].
    #[inline]
    pub fn new(provider_name: String, config: CircuitBreakerConfig) -> Self {
        Self {
            provider_name,
            config,
            state: AtomicU8::new(CircuitState::Closed.as_u8()),
            consecutive_failures: AtomicU64::new(0),
            consecutive_successes: AtomicU64::new(0),
            opened_at: Mutex::new(None),
        }
    }

    /// Creates a breaker using [`CircuitBreakerConfig::default`].
    #[inline]
    pub fn with_defaults(provider_name: String) -> Self {
        Self::new(provider_name, CircuitBreakerConfig::default())
    }

    /// Decides whether a request to the provider may proceed.
    ///
    /// Closed and half-open circuits always admit the request. An open circuit
    /// rejects requests until `open_duration` has elapsed since it opened; the
    /// first call after that moves the circuit to half-open and admits the
    /// request.
    ///
    /// # Errors
    ///
    /// Returns [`CircuitOpenError`] carrying the remaining open time while the
    /// circuit is open and `open_duration` has not yet elapsed.
    pub fn allow_request(&self) -> Result<(), CircuitOpenError> {
        match self.load_state() {
            CircuitState::Closed | CircuitState::HalfOpen => Ok(()),
            CircuitState::Open => self.try_leave_open(),
        }
    }

    /// Slow path of `allow_request` for a circuit observed as open.
    ///
    /// Runs under the `opened_at` lock so only one caller performs the
    /// Open -> HalfOpen transition, and so the state cannot change between
    /// reading `opened_at` and acting on it.
    fn try_leave_open(&self) -> Result<(), CircuitOpenError> {
        let opened_at = self.opened_at.lock();
        // Another thread may have moved the circuit out of Open while we waited
        // for the lock; Closed and HalfOpen both admit the request.
        if self.load_state() != CircuitState::Open {
            return Ok(());
        }
        if let Some(timestamp) = *opened_at {
            let elapsed = timestamp.elapsed();
            if elapsed < self.config.open_duration {
                let remaining = self.config.open_duration.saturating_sub(elapsed);
                return Err(CircuitOpenError::new(&self.provider_name, remaining));
            }
        }
        // The open window has elapsed (or no timestamp was recorded): start a
        // fresh half-open trial. Clear the success counter before publishing
        // HalfOpen so successes recorded after the transition are not wiped.
        self.consecutive_successes.store(0, Ordering::SeqCst);
        self.state
            .store(CircuitState::HalfOpen.as_u8(), Ordering::SeqCst);
        Ok(())
    }

    /// Records a successful request.
    ///
    /// While closed, clears the failure streak. While half-open, counts toward
    /// `half_open_successes` and closes the circuit once that many accrue.
    /// Results reported while open are ignored.
    pub fn record_success(&self) {
        match self.load_state() {
            CircuitState::Closed => {
                self.consecutive_failures.store(0, Ordering::SeqCst);
            }
            CircuitState::HalfOpen => {
                let successes = self.consecutive_successes.fetch_add(1, Ordering::SeqCst) + 1;
                if successes >= u64::from(self.config.half_open_successes) {
                    self.close_if_in(CircuitState::HalfOpen);
                }
            }
            CircuitState::Open => {}
        }
    }

    /// Records a failed request.
    ///
    /// While closed, trips the circuit open once `failure_threshold`
    /// consecutive failures accrue. While half-open, any failure re-opens the
    /// circuit. Results reported while open are ignored.
    pub fn record_failure(&self) {
        match self.load_state() {
            CircuitState::Closed => {
                let failures = self.consecutive_failures.fetch_add(1, Ordering::SeqCst) + 1;
                if failures >= u64::from(self.config.failure_threshold) {
                    self.open_if_in(CircuitState::Closed);
                }
            }
            CircuitState::HalfOpen => self.open_if_in(CircuitState::HalfOpen),
            CircuitState::Open => {}
        }
    }

    /// Moves the circuit to Open, stamping `opened_at`, if it is still in `expected`.
    ///
    /// `opened_at` is written before `Open` is published, both under the lock,
    /// so an `allow_request` caller never sees `Open` with a stale or missing
    /// timestamp. The re-check also stops concurrent threshold crossings from
    /// re-stamping, and therefore extending, an already-open window.
    fn open_if_in(&self, expected: CircuitState) {
        let mut opened_at = self.opened_at.lock();
        if self.load_state() == expected {
            *opened_at = Some(Instant::now());
            self.state.store(CircuitState::Open.as_u8(), Ordering::SeqCst);
        }
    }

    /// Moves the circuit to Closed and clears all counters, if it is still in `expected`.
    fn close_if_in(&self, expected: CircuitState) {
        let mut opened_at = self.opened_at.lock();
        if self.load_state() == expected {
            // Clear counters before publishing Closed so failures recorded
            // after the transition are not lost.
            self.consecutive_failures.store(0, Ordering::SeqCst);
            self.consecutive_successes.store(0, Ordering::SeqCst);
            *opened_at = None;
            self.state
                .store(CircuitState::Closed.as_u8(), Ordering::SeqCst);
        }
    }

    /// Forces the circuit back to Closed and clears all counters and the open timestamp.
    pub fn reset(&self) {
        let mut opened_at = self.opened_at.lock();
        self.consecutive_failures.store(0, Ordering::SeqCst);
        self.consecutive_successes.store(0, Ordering::SeqCst);
        *opened_at = None;
        self.state
            .store(CircuitState::Closed.as_u8(), Ordering::SeqCst);
    }

    /// Returns the current circuit state.
    #[inline]
    pub fn state(&self) -> CircuitState {
        self.load_state()
    }

    /// Returns the provider name this breaker guards.
    #[inline]
    pub fn provider_name(&self) -> &str {
        &self.provider_name
    }

    /// Returns the configuration this breaker was built with.
    #[inline]
    pub fn config(&self) -> &CircuitBreakerConfig {
        &self.config
    }

    /// Returns the current consecutive-failure count.
    #[inline]
    pub fn consecutive_failures(&self) -> u64 {
        self.consecutive_failures.load(Ordering::SeqCst)
    }

    /// Returns the current consecutive-success count.
    #[inline]
    pub fn consecutive_successes(&self) -> u64 {
        self.consecutive_successes.load(Ordering::SeqCst)
    }

    /// Returns how long an open circuit will keep rejecting requests.
    ///
    /// Returns `None` unless the circuit is open with a recorded timestamp.
    /// The result saturates at zero once the window has elapsed.
    #[inline]
    pub fn remaining_open_time(&self) -> Option<Duration> {
        if self.load_state() != CircuitState::Open {
            return None;
        }
        let opened_at = self.opened_at.lock();
        // Re-check under the lock: the circuit may have left Open meanwhile.
        if self.load_state() != CircuitState::Open {
            return None;
        }
        (*opened_at).map(|t| self.config.open_duration.saturating_sub(t.elapsed()))
    }

    /// Decodes the atomic state byte.
    #[inline]
    fn load_state(&self) -> CircuitState {
        CircuitState::from_u8(self.state.load(Ordering::SeqCst))
    }
}
/// Snapshot of a [`ProviderCircuitBreaker`], produced by its `diagnostics` method.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CircuitBreakerDiagnostics {
    /// Provider name the breaker guards.
    pub provider: String,
    /// State observed when the snapshot was taken.
    pub state: CircuitState,
    /// Consecutive-failure counter value.
    pub consecutive_failures: u64,
    /// Consecutive-success counter value.
    pub consecutive_successes: u64,
    /// `true` when `state` is [`CircuitState::Open`].
    pub is_open: bool,
    /// Remaining open window; `None` unless the circuit was open when read.
    pub remaining_open_time: Option<Duration>,
}
impl ProviderCircuitBreaker {
    /// Captures a snapshot of the breaker for logging or metrics.
    ///
    /// Each value is read separately, so under concurrent use the fields are
    /// not guaranteed to be mutually consistent.
    pub fn diagnostics(&self) -> CircuitBreakerDiagnostics {
        let state = self.load_state();
        let is_open = matches!(state, CircuitState::Open);
        let consecutive_failures = self.consecutive_failures();
        let consecutive_successes = self.consecutive_successes();
        let remaining_open_time = self.remaining_open_time();
        CircuitBreakerDiagnostics {
            provider: self.provider_name().to_owned(),
            state,
            consecutive_failures,
            consecutive_successes,
            is_open,
            remaining_open_time,
        }
    }
}
// Unit tests for the circuit state codec, configuration, breaker transitions,
// diagnostics, and the open-circuit error.
#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn circuit_state_from_u8() {
        assert_eq!(CircuitState::from_u8(0), CircuitState::Closed);
        assert_eq!(CircuitState::from_u8(1), CircuitState::Open);
        assert_eq!(CircuitState::from_u8(2), CircuitState::HalfOpen);
        // Any unrecognised discriminant decodes as HalfOpen.
        assert_eq!(CircuitState::from_u8(255), CircuitState::HalfOpen);
    }
    #[test]
    fn circuit_state_as_u8() {
        assert_eq!(CircuitState::Closed.as_u8(), 0);
        assert_eq!(CircuitState::Open.as_u8(), 1);
        assert_eq!(CircuitState::HalfOpen.as_u8(), 2);
    }
    #[test]
    fn config_default() {
        let config = CircuitBreakerConfig::default();
        assert_eq!(config.failure_threshold, 5);
        assert_eq!(config.open_duration, Duration::from_secs(30));
        assert_eq!(config.half_open_successes, 1);
    }
    #[test]
    fn config_new_valid() {
        let config = CircuitBreakerConfig::new(3, Duration::from_secs(10), 2);
        assert_eq!(config.failure_threshold, 3);
        assert_eq!(config.open_duration, Duration::from_secs(10));
        assert_eq!(config.half_open_successes, 2);
    }
    #[test]
    #[should_panic(expected = "failure_threshold cannot be zero")]
    fn config_new_zero_failure_threshold() {
        CircuitBreakerConfig::new(0, Duration::from_secs(10), 1);
    }
    #[test]
    #[should_panic(expected = "half_open_successes cannot be zero")]
    fn config_new_zero_half_open_successes() {
        CircuitBreakerConfig::new(3, Duration::from_secs(10), 0);
    }
    #[test]
    fn config_builder_methods() {
        let config = CircuitBreakerConfig::default()
            .with_failure_threshold(10)
            .with_open_duration(Duration::from_secs(60))
            .with_half_open_successes(2);
        assert_eq!(config.failure_threshold, 10);
        assert_eq!(config.open_duration, Duration::from_secs(60));
        assert_eq!(config.half_open_successes, 2);
    }
    #[test]
    fn breaker_allows_when_closed() {
        let breaker = ProviderCircuitBreaker::with_defaults("test".to_string());
        assert!(breaker.allow_request().is_ok());
        assert_eq!(breaker.state(), CircuitState::Closed);
    }
    #[test]
    fn breaker_success_in_closed_state() {
        let breaker = ProviderCircuitBreaker::with_defaults("test".to_string());
        breaker.record_success();
        assert_eq!(breaker.consecutive_failures(), 0);
    }
    #[test]
    fn breaker_opens_after_threshold() {
        let config = CircuitBreakerConfig::new(3, Duration::from_secs(30), 1);
        let breaker = ProviderCircuitBreaker::new("test".to_string(), config);
        breaker.record_failure();
        assert_eq!(breaker.state(), CircuitState::Closed);
        breaker.record_failure();
        assert_eq!(breaker.state(), CircuitState::Closed);
        breaker.record_failure();
        assert_eq!(breaker.state(), CircuitState::Open);
        assert!(breaker.allow_request().is_err());
    }
    #[test]
    fn breaker_success_resets_failure_count() {
        let config = CircuitBreakerConfig::new(3, Duration::from_secs(30), 1);
        let breaker = ProviderCircuitBreaker::new("test".to_string(), config);
        breaker.record_failure();
        breaker.record_failure();
        assert_eq!(breaker.consecutive_failures(), 2);
        breaker.record_success();
        assert_eq!(breaker.consecutive_failures(), 0);
    }
    #[test]
    fn breaker_reset() {
        let config = CircuitBreakerConfig::new(1, Duration::from_secs(30), 1);
        let breaker = ProviderCircuitBreaker::new("test".to_string(), config);
        breaker.record_failure();
        assert_eq!(breaker.state(), CircuitState::Open);
        breaker.reset();
        assert_eq!(breaker.state(), CircuitState::Closed);
        assert!(breaker.allow_request().is_ok());
    }
    #[test]
    fn breaker_half_open_on_duration_elapsed() {
        // Short open window so the test can wait it out with a brief sleep.
        let config = CircuitBreakerConfig::new(1, Duration::from_millis(50), 1);
        let breaker = ProviderCircuitBreaker::new("test".to_string(), config);
        breaker.record_failure();
        assert_eq!(breaker.state(), CircuitState::Open);
        std::thread::sleep(Duration::from_millis(60));
        assert!(breaker.allow_request().is_ok());
        assert_eq!(breaker.state(), CircuitState::HalfOpen);
    }
    #[test]
    fn breaker_half_open_success_closes_circuit() {
        let config = CircuitBreakerConfig::new(1, Duration::from_secs(30), 1);
        let breaker = ProviderCircuitBreaker::new("test".to_string(), config);
        breaker.reset();
        // Force HalfOpen directly instead of waiting out the open window.
        breaker
            .state
            .store(CircuitState::HalfOpen.as_u8(), Ordering::SeqCst);
        breaker.record_success();
        assert_eq!(breaker.state(), CircuitState::Closed);
    }
    #[test]
    fn breaker_half_open_failure_reopens() {
        let config = CircuitBreakerConfig::new(1, Duration::from_secs(30), 1);
        let breaker = ProviderCircuitBreaker::new("test".to_string(), config);
        breaker.reset();
        // Force HalfOpen directly instead of waiting out the open window.
        breaker
            .state
            .store(CircuitState::HalfOpen.as_u8(), Ordering::SeqCst);
        breaker.record_failure();
        assert_eq!(breaker.state(), CircuitState::Open);
    }
    #[test]
    fn breaker_multiple_half_open_successes() {
        // Three half-open successes are required before the circuit closes.
        let config = CircuitBreakerConfig::new(1, Duration::from_secs(30), 3);
        let breaker = ProviderCircuitBreaker::new("test".to_string(), config);
        breaker.reset();
        breaker
            .state
            .store(CircuitState::HalfOpen.as_u8(), Ordering::SeqCst);
        breaker.record_success();
        assert_eq!(breaker.state(), CircuitState::HalfOpen);
        breaker.record_success();
        assert_eq!(breaker.state(), CircuitState::HalfOpen);
        breaker.record_success();
        assert_eq!(breaker.state(), CircuitState::Closed);
    }
    #[test]
    fn breaker_diagnostics() {
        let config = CircuitBreakerConfig::new(2, Duration::from_secs(30), 1);
        let breaker = ProviderCircuitBreaker::new("openai".to_string(), config);
        breaker.record_failure();
        let diag = breaker.diagnostics();
        assert_eq!(diag.provider, "openai");
        assert_eq!(diag.state, CircuitState::Closed);
        assert_eq!(diag.consecutive_failures, 1);
        assert!(!diag.is_open);
    }
    #[test]
    fn breaker_diagnostics_when_open() {
        let config = CircuitBreakerConfig::new(1, Duration::from_secs(30), 1);
        let breaker = ProviderCircuitBreaker::new("anthropic".to_string(), config);
        breaker.record_failure();
        let diag = breaker.diagnostics();
        assert!(diag.is_open);
        assert!(diag.remaining_open_time.is_some());
    }
    #[test]
    fn circuit_open_error_display() {
        let error = CircuitOpenError::new("openai", Duration::from_secs(10));
        let msg = error.to_string();
        assert!(msg.contains("openai"));
        assert!(msg.contains("10"));
    }
    #[test]
    fn circuit_open_error_clone() {
        let error = CircuitOpenError::new("test", Duration::from_secs(5));
        let cloned = error.clone();
        assert_eq!(error, cloned);
    }
}