use parking_lot::Mutex;
use std::time::{Duration, Instant};
/// Operating mode of a [`CircuitBreaker`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CircuitState {
    /// Calls are allowed and consecutive failures are counted.
    Closed,
    /// Calls are rejected until the recovery timeout has elapsed.
    Open,
    /// Calls are allowed again as probes: enough consecutive successes close
    /// the breaker, a single failure re-opens it.
    HalfOpen,
}
/// Tunable thresholds for a [`CircuitBreaker`].
#[derive(Clone, Debug)]
pub struct CircuitBreakerConfig {
    /// Number of consecutive failures, counted while closed, at which the
    /// breaker opens.
    pub failure_threshold: u32,
    /// How long the breaker stays open, measured from its recorded failure
    /// timestamp, before it may move to half-open.
    pub recovery_timeout: Duration,
    /// Number of consecutive successes, counted while half-open, at which the
    /// breaker closes again.
    pub success_threshold: u32,
}

impl Default for CircuitBreakerConfig {
    /// Opens after 5 consecutive failures, waits 30 seconds before probing,
    /// and closes again after 2 consecutive successful probes.
    fn default() -> Self {
        let recovery_timeout = Duration::from_secs(30);
        Self {
            failure_threshold: 5,
            recovery_timeout,
            success_threshold: 2,
        }
    }
}
/// Mutable bookkeeping of a `CircuitBreaker`, stored behind its `inner` mutex.
#[derive(Debug)]
struct BreakerState {
/// Current state of the circuit.
state: CircuitState,
/// Failures counted while `Closed`; cleared by a success in `Closed` and
/// whenever the breaker closes again from `HalfOpen`.
consecutive_failures: u32,
/// Successes counted while `HalfOpen`; cleared on entering `HalfOpen`, on a
/// half-open failure, and when the breaker closes.
consecutive_successes: u32,
/// Instant stamped by `record_failure`; the recovery timeout is measured
/// from it. `None` until a failure has been recorded.
last_failure_time: Option<Instant>,
}
/// Circuit breaker for calls to a single named provider.
///
/// The mutable state is kept behind a mutex, so every method takes `&self`.
#[derive(Debug)]
pub struct CircuitBreaker {
/// Name of the provider this breaker was created for.
provider: String,
/// Thresholds and recovery timeout, fixed at construction.
config: CircuitBreakerConfig,
/// Current state, counters and failure timestamp.
inner: Mutex<BreakerState>,
}
impl CircuitBreaker {
    /// Creates a breaker for `provider` that starts [`CircuitState::Closed`]
    /// with both counters at zero and no failure recorded.
    pub fn new(provider: impl Into<String>, config: CircuitBreakerConfig) -> Self {
        Self {
            provider: provider.into(),
            config,
            inner: Mutex::new(BreakerState {
                state: CircuitState::Closed,
                consecutive_failures: 0,
                consecutive_successes: 0,
                last_failure_time: None,
            }),
        }
    }

    /// Returns the provider name this breaker was created with.
    pub fn provider(&self) -> &str {
        &self.provider
    }

    /// Returns the current state.
    ///
    /// The `Open` → `HalfOpen` transition is lazy: it is applied here (and in
    /// [`allow_call`](Self::allow_call)) once the recovery timeout has elapsed,
    /// so calling this method may itself change the state.
    pub fn state(&self) -> CircuitState {
        let mut guard = self.inner.lock();
        Self::maybe_transition_to_half_open(&self.config, &mut guard);
        guard.state
    }

    /// Returns how long until the breaker may be probed again.
    ///
    /// Yields `Some(remaining)` only while the stored state is `Open`;
    /// `remaining` is `Duration::ZERO` once the recovery timeout has elapsed.
    /// Returns `None` in every other state. This method never performs the
    /// lazy `Open` → `HalfOpen` transition itself.
    pub fn time_until_probe(&self) -> Option<Duration> {
        let guard = self.inner.lock();
        match (guard.state, guard.last_failure_time) {
            // `saturating_sub` clamps to zero when the timeout is already over.
            (CircuitState::Open, Some(opened_at)) => Some(
                self.config
                    .recovery_timeout
                    .saturating_sub(opened_at.elapsed()),
            ),
            _ => None,
        }
    }

    /// Records a successful call.
    ///
    /// * `Closed`: clears the consecutive-failure counter.
    /// * `HalfOpen`: counts the success and closes the breaker (clearing both
    ///   counters) once `success_threshold` is reached.
    /// * `Open`: ignored.
    pub fn record_success(&self) {
        let mut guard = self.inner.lock();
        match guard.state {
            CircuitState::Closed => {
                guard.consecutive_failures = 0;
            }
            CircuitState::HalfOpen => {
                guard.consecutive_successes += 1;
                if guard.consecutive_successes >= self.config.success_threshold {
                    guard.state = CircuitState::Closed;
                    guard.consecutive_failures = 0;
                    guard.consecutive_successes = 0;
                }
            }
            CircuitState::Open => {}
        }
    }

    /// Records a failed call.
    ///
    /// * `Closed`: counts the failure and opens the breaker once
    ///   `failure_threshold` is reached.
    /// * `HalfOpen`: re-opens the breaker and clears the success counter.
    /// * `Open`: ignored.
    ///
    /// The failure timestamp is only refreshed in `Closed` and `HalfOpen`.
    /// Failures reported while already `Open` (typically calls that were
    /// admitted before the breaker tripped) must not restart the recovery
    /// window, otherwise a trickle of late failures could postpone probing
    /// indefinitely.
    pub fn record_failure(&self) {
        let mut guard = self.inner.lock();
        match guard.state {
            CircuitState::Closed => {
                guard.last_failure_time = Some(Instant::now());
                guard.consecutive_failures += 1;
                if guard.consecutive_failures >= self.config.failure_threshold {
                    guard.state = CircuitState::Open;
                }
            }
            CircuitState::HalfOpen => {
                // A failed probe starts a fresh recovery window.
                guard.last_failure_time = Some(Instant::now());
                guard.state = CircuitState::Open;
                guard.consecutive_successes = 0;
            }
            CircuitState::Open => {
                // Already open: keep the existing recovery deadline.
            }
        }
    }

    /// Returns `true` when a call may be attempted (`Closed` or `HalfOpen`)
    /// and `false` while the breaker is `Open`.
    ///
    /// Applies the lazy `Open` → `HalfOpen` transition first, so the first
    /// call after the recovery timeout is let through as a probe.
    pub fn allow_call(&self) -> bool {
        let mut guard = self.inner.lock();
        Self::maybe_transition_to_half_open(&self.config, &mut guard);
        match guard.state {
            CircuitState::Closed | CircuitState::HalfOpen => true,
            CircuitState::Open => false,
        }
    }

    /// Moves an `Open` breaker to `HalfOpen` (with a zeroed success counter)
    /// once `recovery_timeout` has elapsed since the stored failure timestamp.
    /// Does nothing in any other state or when no timestamp is stored.
    fn maybe_transition_to_half_open(config: &CircuitBreakerConfig, state: &mut BreakerState) {
        if state.state != CircuitState::Open {
            return;
        }
        let Some(stamped_at) = state.last_failure_time else {
            return;
        };
        if stamped_at.elapsed() >= config.recovery_timeout {
            state.state = CircuitState::HalfOpen;
            state.consecutive_successes = 0;
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::thread;

    /// Configuration with a short recovery timeout so tests stay quick.
    fn test_config() -> CircuitBreakerConfig {
        CircuitBreakerConfig {
            failure_threshold: 5,
            recovery_timeout: Duration::from_millis(100),
            success_threshold: 2,
        }
    }

    /// Builds a breaker for `provider` using [`test_config`].
    fn breaker_for(provider: &str) -> CircuitBreaker {
        CircuitBreaker::new(provider, test_config())
    }

    /// Reports `count` failures in a row on `cb`.
    fn fail_times(cb: &CircuitBreaker, count: usize) {
        (0..count).for_each(|_| cb.record_failure());
    }

    /// Sleeps long enough for the test recovery timeout to expire.
    fn wait_past_recovery() {
        thread::sleep(Duration::from_millis(150));
    }

    #[test]
    fn starts_closed() {
        let cb = breaker_for("test");
        assert_eq!(cb.state(), CircuitState::Closed);
        assert!(cb.allow_call());
    }

    #[test]
    fn five_failures_opens_breaker() {
        let cb = breaker_for("test");

        // The first four failures must leave the breaker closed.
        for _attempt in 0..4 {
            assert!(cb.allow_call());
            cb.record_failure();
            assert_eq!(cb.state(), CircuitState::Closed);
        }

        // The fifth one trips it.
        assert!(cb.allow_call());
        cb.record_failure();
        assert_eq!(cb.state(), CircuitState::Open);
        assert!(!cb.allow_call());
    }

    #[test]
    fn open_rejects_immediately() {
        let cb = breaker_for("openai");
        fail_times(&cb, 5);
        assert!(!cb.allow_call());
        assert!(cb.time_until_probe().is_some());
    }

    #[test]
    fn recovery_timeout_transitions_to_half_open() {
        let cb = breaker_for("test");
        fail_times(&cb, 5);
        assert_eq!(cb.state(), CircuitState::Open);

        wait_past_recovery();
        assert_eq!(cb.state(), CircuitState::HalfOpen);
        assert!(cb.allow_call());
    }

    #[test]
    fn half_open_success_closes() {
        let cb = breaker_for("test");
        fail_times(&cb, 5);
        wait_past_recovery();
        assert_eq!(cb.state(), CircuitState::HalfOpen);

        // One success is not enough with a success threshold of two.
        cb.record_success();
        assert_eq!(cb.state(), CircuitState::HalfOpen);

        cb.record_success();
        assert_eq!(cb.state(), CircuitState::Closed);
    }

    #[test]
    fn half_open_failure_reopens() {
        let cb = breaker_for("test");
        fail_times(&cb, 5);
        wait_past_recovery();
        assert_eq!(cb.state(), CircuitState::HalfOpen);

        cb.record_failure();
        assert_eq!(cb.state(), CircuitState::Open);
        assert!(!cb.allow_call());
    }

    #[test]
    fn success_resets_failure_counter() {
        let cb = breaker_for("test");
        fail_times(&cb, 4);
        cb.record_success();

        // The counter restarted, so four more failures still do not trip it.
        fail_times(&cb, 4);
        assert_eq!(cb.state(), CircuitState::Closed);

        cb.record_failure();
        assert_eq!(cb.state(), CircuitState::Open);
    }

    #[test]
    fn concurrent_access_is_safe() {
        let cb = Arc::new(breaker_for("test"));

        // Spawn every worker before joining any of them.
        let workers: Vec<_> = (0..10)
            .map(|_| {
                let shared = Arc::clone(&cb);
                thread::spawn(move || {
                    for _ in 0..100 {
                        let _ = shared.allow_call();
                        shared.record_failure();
                        let _ = shared.state();
                    }
                })
            })
            .collect();

        for worker in workers {
            worker.join().unwrap();
        }

        assert_eq!(cb.state(), CircuitState::Open);
    }
}