use simple_agent_type::prelude::{ProviderError, SimpleAgentsError};
use std::sync::Mutex;
use std::time::{Duration, Instant};
/// Tuning knobs for a [`CircuitBreaker`].
///
/// The breaker trips after `failure_threshold` consecutive failures, stays open for
/// `open_duration`, then lets probe requests through until either `success_threshold`
/// consecutive successes close it again or a single failure reopens it.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct CircuitBreakerConfig {
    /// Consecutive failures, counted while closed, that switch the breaker to open.
    pub failure_threshold: u32,
    /// How long the breaker rejects requests before moving to half-open.
    pub open_duration: Duration,
    /// Consecutive successes, counted while half-open, that close the breaker again.
    pub success_threshold: u32,
}

impl Default for CircuitBreakerConfig {
    /// Trips after 3 consecutive failures, stays open for 10 seconds, and closes
    /// again after a single successful probe.
    fn default() -> Self {
        Self {
            failure_threshold: 3,
            open_duration: Duration::from_secs(10),
            success_threshold: 1,
        }
    }
}
/// Private breaker state. Unlike the public [`CircuitBreakerState`], the open
/// variant carries the instant it was entered so the cool-down can be measured.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum InternalState {
    /// Requests pass through while consecutive failures are counted.
    Closed,
    /// Requests are rejected until `open_duration` has elapsed since `opened_at`.
    Open { opened_at: Instant },
    /// Requests pass through as probes; their outcomes decide whether to close or reopen.
    HalfOpen,
}
/// Publicly observable position of a circuit breaker.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CircuitBreakerState {
    /// Normal operation: every request is allowed.
    Closed,
    /// Tripped: requests are rejected until the open duration has elapsed.
    Open,
    /// Recovering: requests are allowed and their outcomes decide the next state.
    HalfOpen,
}
/// Mutable bookkeeping for a breaker; only ever touched while holding the
/// breaker's `Mutex`.
#[derive(Debug)]
struct CircuitBreakerInner {
    /// Current position in the closed / open / half-open cycle.
    state: InternalState,
    /// Failures seen in a row; compared with `failure_threshold` while closed.
    consecutive_failures: u32,
    /// Successes seen in a row; compared with `success_threshold` while half-open.
    consecutive_successes: u32,
}

impl CircuitBreakerInner {
    /// Fresh bookkeeping: closed, with both streak counters at zero.
    fn new() -> Self {
        let (consecutive_failures, consecutive_successes) = (0, 0);
        Self {
            state: InternalState::Closed,
            consecutive_failures,
            consecutive_successes,
        }
    }
}
#[derive(Debug)]
pub struct CircuitBreaker {
config: CircuitBreakerConfig,
inner: Mutex<CircuitBreakerInner>,
}
impl CircuitBreaker {
pub fn new(config: CircuitBreakerConfig) -> Self {
Self {
config,
inner: Mutex::new(CircuitBreakerInner::new()),
}
}
pub fn allow_request(&self) -> bool {
let mut inner = self
.inner
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner());
match inner.state {
InternalState::Closed => true,
InternalState::Open { opened_at } => {
if opened_at.elapsed() >= self.config.open_duration {
inner.state = InternalState::HalfOpen;
inner.consecutive_successes = 0;
inner.consecutive_failures = 0;
true
} else {
false
}
}
InternalState::HalfOpen => true,
}
}
pub fn record_success(&self) {
let mut inner = self
.inner
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner());
match inner.state {
InternalState::Closed => {
inner.consecutive_failures = 0;
}
InternalState::HalfOpen => {
inner.consecutive_successes = inner.consecutive_successes.saturating_add(1);
if inner.consecutive_successes >= self.config.success_threshold {
inner.state = InternalState::Closed;
inner.consecutive_failures = 0;
inner.consecutive_successes = 0;
}
}
InternalState::Open { .. } => {}
}
}
pub fn record_failure(&self) {
let mut inner = self
.inner
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner());
match inner.state {
InternalState::Closed => {
inner.consecutive_failures = inner.consecutive_failures.saturating_add(1);
if inner.consecutive_failures >= self.config.failure_threshold {
inner.state = InternalState::Open {
opened_at: Instant::now(),
};
}
}
InternalState::HalfOpen => {
inner.state = InternalState::Open {
opened_at: Instant::now(),
};
inner.consecutive_failures = 1;
inner.consecutive_successes = 0;
}
InternalState::Open { .. } => {}
}
}
pub fn record_result(&self, result: &std::result::Result<(), SimpleAgentsError>) {
match result {
Ok(_) => self.record_success(),
Err(error) => {
if matches!(
error,
SimpleAgentsError::Provider(
ProviderError::RateLimit { .. }
| ProviderError::Timeout(_)
| ProviderError::ServerError(_)
) | SimpleAgentsError::Network(_)
) {
self.record_failure();
}
}
}
}
pub fn state(&self) -> CircuitBreakerState {
let inner = self
.inner
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner());
match inner.state {
InternalState::Closed => CircuitBreakerState::Closed,
InternalState::Open { .. } => CircuitBreakerState::Open,
InternalState::HalfOpen => CircuitBreakerState::HalfOpen,
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a breaker from explicit thresholds and open duration.
    fn breaker_with(
        failure_threshold: u32,
        open_duration: Duration,
        success_threshold: u32,
    ) -> CircuitBreaker {
        CircuitBreaker::new(CircuitBreakerConfig {
            failure_threshold,
            open_duration,
            success_threshold,
        })
    }

    #[test]
    fn opens_after_failures() {
        let breaker = breaker_with(2, Duration::from_secs(10), 1);
        assert!(breaker.allow_request());
        breaker.record_failure();
        assert!(breaker.allow_request());
        breaker.record_failure();
        assert_eq!(breaker.state(), CircuitBreakerState::Open);
        assert!(!breaker.allow_request());
    }

    #[test]
    fn success_resets_failure_streak_while_closed() {
        let breaker = breaker_with(2, Duration::from_secs(10), 1);
        breaker.record_failure();
        breaker.record_success();
        // The streak restarted, so this is only the first consecutive failure.
        breaker.record_failure();
        assert_eq!(breaker.state(), CircuitBreakerState::Closed);
        assert!(breaker.allow_request());
    }

    #[test]
    fn closes_after_success_in_half_open() {
        let breaker = breaker_with(1, Duration::from_millis(0), 1);
        breaker.record_failure();
        assert_eq!(breaker.state(), CircuitBreakerState::Open);
        assert!(breaker.allow_request());
        assert_eq!(breaker.state(), CircuitBreakerState::HalfOpen);
        breaker.record_success();
        assert_eq!(breaker.state(), CircuitBreakerState::Closed);
    }

    #[test]
    fn half_open_requires_success_threshold() {
        let breaker = breaker_with(1, Duration::from_millis(0), 2);
        breaker.record_failure();
        assert!(breaker.allow_request());
        assert_eq!(breaker.state(), CircuitBreakerState::HalfOpen);
        breaker.record_success();
        assert_eq!(breaker.state(), CircuitBreakerState::HalfOpen);
        breaker.record_success();
        assert_eq!(breaker.state(), CircuitBreakerState::Closed);
    }

    #[test]
    fn reopens_on_failure_in_half_open() {
        let breaker = breaker_with(1, Duration::from_millis(0), 2);
        breaker.record_failure();
        assert_eq!(breaker.state(), CircuitBreakerState::Open);
        assert!(breaker.allow_request());
        assert_eq!(breaker.state(), CircuitBreakerState::HalfOpen);
        breaker.record_failure();
        assert_eq!(breaker.state(), CircuitBreakerState::Open);
    }

    #[test]
    fn outcomes_while_open_are_ignored() {
        let breaker = breaker_with(1, Duration::from_secs(10), 1);
        breaker.record_failure();
        breaker.record_success();
        assert_eq!(breaker.state(), CircuitBreakerState::Open);
        breaker.record_failure();
        assert_eq!(breaker.state(), CircuitBreakerState::Open);
        assert!(!breaker.allow_request());
    }

    #[test]
    fn default_config_values() {
        let config = CircuitBreakerConfig::default();
        assert_eq!(config.failure_threshold, 3);
        assert_eq!(config.open_duration, Duration::from_secs(10));
        assert_eq!(config.success_threshold, 1);
    }
}