use std::collections::HashMap;
use std::time::{Duration, Instant};
use awaken_contract::contract::executor::InferenceExecutionError;
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum CircuitStatus {
Closed,
Open,
HalfOpen,
}
#[derive(Debug)]
struct CircuitState {
consecutive_failures: u32,
last_failure: Instant,
status: CircuitStatus,
half_open_attempts: u32,
}
impl CircuitState {
fn new() -> Self {
Self {
consecutive_failures: 0,
last_failure: Instant::now(),
status: CircuitStatus::Closed,
half_open_attempts: 0,
}
}
}
#[derive(Debug, Clone)]
pub struct CircuitBreakerConfig {
pub failure_threshold: u32,
pub cooldown: Duration,
pub half_open_max: u32,
}
impl Default for CircuitBreakerConfig {
fn default() -> Self {
Self {
failure_threshold: 5,
cooldown: Duration::from_secs(30),
half_open_max: 1,
}
}
}
pub struct CircuitBreaker {
states: parking_lot::RwLock<HashMap<String, CircuitState>>,
config: CircuitBreakerConfig,
}
impl CircuitBreaker {
pub fn new(config: CircuitBreakerConfig) -> Self {
Self {
states: parking_lot::RwLock::new(HashMap::new()),
config,
}
}
pub fn check(&self, model: &str) -> Result<(), InferenceExecutionError> {
let mut states = self.states.write();
let state = states
.entry(model.to_string())
.or_insert_with(CircuitState::new);
match state.status {
CircuitStatus::Closed => Ok(()),
CircuitStatus::Open => {
if state.last_failure.elapsed() >= self.config.cooldown {
state.status = CircuitStatus::HalfOpen;
state.half_open_attempts = 1;
Ok(())
} else {
Err(InferenceExecutionError::Provider(format!(
"circuit breaker open for model {model}"
)))
}
}
CircuitStatus::HalfOpen => {
if state.half_open_attempts < self.config.half_open_max {
state.half_open_attempts += 1;
Ok(())
} else {
Err(InferenceExecutionError::Provider(format!(
"circuit breaker open for model {model}"
)))
}
}
}
}
pub fn record_success(&self, model: &str) {
let mut states = self.states.write();
if let Some(state) = states.get_mut(model) {
state.consecutive_failures = 0;
state.half_open_attempts = 0;
state.status = CircuitStatus::Closed;
}
}
pub fn record_failure(&self, model: &str) {
let mut states = self.states.write();
let state = states
.entry(model.to_string())
.or_insert_with(CircuitState::new);
state.consecutive_failures += 1;
state.last_failure = Instant::now();
if state.status == CircuitStatus::HalfOpen {
state.status = CircuitStatus::Open;
} else if state.consecutive_failures >= self.config.failure_threshold {
state.status = CircuitStatus::Open;
}
}
}
#[cfg(test)]
mod tests {
use super::*;
fn fast_config() -> CircuitBreakerConfig {
CircuitBreakerConfig {
failure_threshold: 3,
cooldown: Duration::from_millis(50),
half_open_max: 1,
}
}
#[test]
fn closed_allows_requests() {
let cb = CircuitBreaker::new(fast_config());
assert!(cb.check("model-a").is_ok());
}
#[test]
fn opens_after_threshold_failures() {
let cb = CircuitBreaker::new(fast_config());
for _ in 0..3 {
cb.record_failure("model-a");
}
assert!(cb.check("model-a").is_err());
}
#[test]
fn below_threshold_stays_closed() {
let cb = CircuitBreaker::new(fast_config());
cb.record_failure("model-a");
cb.record_failure("model-a");
assert!(cb.check("model-a").is_ok());
}
#[test]
fn success_resets_failure_count() {
let cb = CircuitBreaker::new(fast_config());
cb.record_failure("model-a");
cb.record_failure("model-a");
cb.record_success("model-a");
cb.record_failure("model-a");
cb.record_failure("model-a");
assert!(cb.check("model-a").is_ok());
}
#[test]
fn transitions_to_half_open_after_cooldown() {
let config = CircuitBreakerConfig {
failure_threshold: 2,
cooldown: Duration::from_millis(10),
half_open_max: 1,
};
let cb = CircuitBreaker::new(config);
cb.record_failure("model-a");
cb.record_failure("model-a");
assert!(cb.check("model-a").is_err());
std::thread::sleep(Duration::from_millis(15));
assert!(cb.check("model-a").is_ok());
assert!(cb.check("model-a").is_err());
}
#[test]
fn half_open_success_closes_circuit() {
let config = CircuitBreakerConfig {
failure_threshold: 2,
cooldown: Duration::from_millis(10),
half_open_max: 1,
};
let cb = CircuitBreaker::new(config);
cb.record_failure("model-a");
cb.record_failure("model-a");
std::thread::sleep(Duration::from_millis(15));
assert!(cb.check("model-a").is_ok());
cb.record_success("model-a");
assert!(cb.check("model-a").is_ok());
assert!(cb.check("model-a").is_ok());
}
#[test]
fn half_open_failure_reopens_circuit() {
let config = CircuitBreakerConfig {
failure_threshold: 2,
cooldown: Duration::from_millis(10),
half_open_max: 1,
};
let cb = CircuitBreaker::new(config);
cb.record_failure("model-a");
cb.record_failure("model-a");
std::thread::sleep(Duration::from_millis(15));
assert!(cb.check("model-a").is_ok());
cb.record_failure("model-a");
assert!(cb.check("model-a").is_err());
}
#[test]
fn independent_models() {
let cb = CircuitBreaker::new(fast_config());
for _ in 0..3 {
cb.record_failure("model-a");
}
assert!(cb.check("model-a").is_err());
assert!(cb.check("model-b").is_ok());
}
#[test]
fn default_config_values() {
let config = CircuitBreakerConfig::default();
assert_eq!(config.failure_threshold, 5);
assert_eq!(config.cooldown, Duration::from_secs(30));
assert_eq!(config.half_open_max, 1);
}
#[test]
fn error_message_contains_model_name() {
let cb = CircuitBreaker::new(fast_config());
for _ in 0..3 {
cb.record_failure("gpt-4o");
}
let err = cb.check("gpt-4o").unwrap_err();
let msg = err.to_string();
assert!(msg.contains("gpt-4o"), "error should mention model: {msg}");
assert!(
msg.contains("circuit breaker"),
"error should mention circuit breaker: {msg}"
);
}
#[test]
fn half_open_allows_multiple_probes() {
let config = CircuitBreakerConfig {
failure_threshold: 2,
cooldown: Duration::from_millis(10),
half_open_max: 3,
};
let cb = CircuitBreaker::new(config);
cb.record_failure("m");
cb.record_failure("m");
assert!(cb.check("m").is_err());
std::thread::sleep(Duration::from_millis(15));
assert!(cb.check("m").is_ok()); assert!(cb.check("m").is_ok()); assert!(cb.check("m").is_ok()); assert!(cb.check("m").is_err()); }
mod proptest_circuit_breaker {
use super::*;
use proptest::prelude::*;
proptest! {
#[test]
fn opens_exactly_at_threshold(
threshold in 1u32..20,
failures in 0u32..30,
) {
let cb = CircuitBreaker::new(CircuitBreakerConfig {
failure_threshold: threshold,
cooldown: Duration::from_secs(60),
half_open_max: 1,
});
for _ in 0..failures {
cb.record_failure("model");
}
if failures >= threshold {
prop_assert!(
cb.check("model").is_err(),
"should be open after {failures} >= {threshold} failures"
);
} else {
prop_assert!(
cb.check("model").is_ok(),
"should be closed after {failures} < {threshold} failures"
);
}
}
#[test]
fn success_always_resets_to_closed(
pre_failures in 0u32..10,
) {
let cb = CircuitBreaker::new(CircuitBreakerConfig {
failure_threshold: 20, cooldown: Duration::from_secs(60),
half_open_max: 1,
});
for _ in 0..pre_failures {
cb.record_failure("model");
}
cb.record_success("model");
cb.record_failure("model");
prop_assert!(
cb.check("model").is_ok(),
"circuit should be closed after success reset"
);
}
#[test]
fn independent_models_do_not_interfere(
failures_a in 0u32..10,
failures_b in 0u32..10,
) {
let threshold = 5;
let cb = CircuitBreaker::new(CircuitBreakerConfig {
failure_threshold: threshold,
cooldown: Duration::from_secs(60),
half_open_max: 1,
});
for _ in 0..failures_a {
cb.record_failure("model-a");
}
for _ in 0..failures_b {
cb.record_failure("model-b");
}
let a_expected_open = failures_a >= threshold;
let b_expected_open = failures_b >= threshold;
prop_assert_eq!(cb.check("model-a").is_err(), a_expected_open);
prop_assert_eq!(cb.check("model-b").is_err(), b_expected_open);
}
}
}
}