use std::fmt;
use std::sync::Arc;
use std::sync::atomic::{AtomicU8, AtomicU32, AtomicU64, Ordering};
use std::time::{Duration, SystemTime};
use async_trait::async_trait;
use crate::llm::middleware::LlmMiddleware;
use crate::llm::{CallOptions, LlmError, Message};
// Raw `u8` encodings of the breaker state, used for the value held in the
// breaker's `AtomicU8` state field.
/// Encoded value for the closed state.
const STATE_CLOSED: u8 = 0;
/// Encoded value for the open state.
const STATE_OPEN: u8 = 1;
/// Encoded value for the half-open state.
const STATE_HALF_OPEN: u8 = 2;
/// State of a circuit breaker as reported by `CircuitBreaker::state`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum CircuitState {
/// Calls are let through by `pre_invoke`.
Closed,
/// Calls are rejected by `pre_invoke` with `CircuitBreakerOpenError`.
Open,
/// A bounded number of trial calls (`half_open_max_calls`) are let through;
/// any further calls are rejected.
HalfOpen,
}
#[allow(
    clippy::match_same_arms,
    reason = "Each enum variant maps to its corresponding state value"
)]
impl CircuitState {
    /// Decodes a raw state byte.
    ///
    /// Any byte that is not one of the known `STATE_*` encodings decodes to
    /// [`CircuitState::Closed`].
    const fn from_u8(value: u8) -> Self {
        if value == STATE_OPEN {
            Self::Open
        } else if value == STATE_HALF_OPEN {
            Self::HalfOpen
        } else {
            // Covers `STATE_CLOSED` as well as every unrecognised byte.
            Self::Closed
        }
    }

    /// Encodes this state as its raw `STATE_*` byte.
    #[allow(
        dead_code,
        reason = "Method provided for API completeness, may be used in future"
    )]
    const fn as_u8(self) -> u8 {
        match self {
            Self::HalfOpen => STATE_HALF_OPEN,
            Self::Open => STATE_OPEN,
            Self::Closed => STATE_CLOSED,
        }
    }
}
/// Tunable thresholds for a circuit breaker.
#[derive(Clone, Debug)]
pub struct CircuitBreakerConfig {
/// Failure count (reset to zero by any recorded success) at which the
/// breaker switches to the open state.
pub failure_threshold: u32,
/// How long after the most recently recorded failure an open breaker waits
/// before it is moved to half-open.
pub recovery_timeout: Duration,
/// Maximum number of calls admitted while half-open; also the number of
/// recorded successes required in half-open before the breaker closes.
pub half_open_max_calls: u32,
}
impl Default for CircuitBreakerConfig {
    /// Default thresholds: open after 5 failures, wait 30 seconds before
    /// probing, and admit a single trial call while half-open.
    fn default() -> Self {
        let failure_threshold = 5;
        let half_open_max_calls = 1;
        let recovery_timeout = Duration::from_secs(30);
        Self {
            failure_threshold,
            recovery_timeout,
            half_open_max_calls,
        }
    }
}
/// Error produced when the breaker refuses a call: either it is open, or it is
/// half-open and its trial-call budget is already used up.
#[derive(thiserror::Error)]
#[error("circuit breaker is open")]
pub struct CircuitBreakerOpenError;
/// Circuit breaker whose bookkeeping lives in atomics.
///
/// Every counter sits behind an `Arc`, so a clone of the breaker observes and
/// updates the same state as the value it was cloned from.
#[derive(Clone)]
pub struct CircuitBreaker {
// Thresholds this breaker was built with.
config: CircuitBreakerConfig,
// Current state, stored as one of the `STATE_*` bytes.
state: Arc<AtomicU8>,
// Failures recorded since the last recorded success or `reset`.
failure_count: Arc<AtomicU32>,
// Successes recorded while half-open.
success_count: Arc<AtomicU32>,
// Wall-clock time of the latest recorded failure, in milliseconds since the
// Unix epoch; 0 means no failure has been recorded.
last_failure_time: Arc<AtomicU64>,
// Calls admitted so far during the current half-open period.
half_open_calls: Arc<AtomicU32>,
}
#[allow(
    clippy::missing_fields_in_debug,
    reason = "Atomic fields are internal implementation details"
)]
impl fmt::Debug for CircuitBreaker {
    /// Formats the configuration, the current state and the failure count.
    ///
    /// The state is obtained through `self.state()`, which can itself move an
    /// open breaker to half-open once the recovery timeout has elapsed.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let state = self.state();
        let failures = self.failure_count();

        let mut out = f.debug_struct("CircuitBreaker");
        out.field("config", &self.config);
        out.field("state", &state);
        out.field("failure_count", &failures);
        out.finish()
    }
}
impl CircuitBreaker {
    /// Creates a breaker in the closed state with every counter at zero.
    #[must_use]
    pub fn new(config: CircuitBreakerConfig) -> Self {
        Self {
            config,
            state: Arc::new(AtomicU8::new(STATE_CLOSED)),
            failure_count: Arc::new(AtomicU32::new(0)),
            success_count: Arc::new(AtomicU32::new(0)),
            last_failure_time: Arc::new(AtomicU64::new(0)),
            half_open_calls: Arc::new(AtomicU32::new(0)),
        }
    }

    /// Returns the current state.
    ///
    /// The Open -> HalfOpen transition is performed lazily here: when the
    /// breaker is open and at least `recovery_timeout` has passed since the
    /// last recorded failure, this call switches it to half-open, clears the
    /// half-open counters and returns [`CircuitState::HalfOpen`].
    #[must_use]
    pub fn state(&self) -> CircuitState {
        let current_state = self.state.load(Ordering::Acquire);
        if current_state == STATE_OPEN {
            let last_failure = self.last_failure_time.load(Ordering::Relaxed);
            // Saturate instead of truncating: a timeout whose millisecond
            // count exceeds `u64::MAX` must behave as "effectively never",
            // not wrap around to a small value.
            let recovery_ms =
                u64::try_from(self.config.recovery_timeout.as_millis()).unwrap_or(u64::MAX);
            let now = current_time_millis();
            // `last_failure == 0` means no failure timestamp is recorded, so
            // there is nothing to measure the timeout against.
            if last_failure > 0
                && now.saturating_sub(last_failure) >= recovery_ms
                && self
                    .state
                    .compare_exchange(
                        current_state,
                        STATE_HALF_OPEN,
                        Ordering::AcqRel,
                        Ordering::Relaxed,
                    )
                    .is_ok()
            {
                // Start the half-open period with clean probe accounting. A
                // success left behind by a racing `handle_success` must not
                // count towards closing the breaker.
                self.success_count.store(0, Ordering::Relaxed);
                self.half_open_calls.store(0, Ordering::Relaxed);
                return CircuitState::HalfOpen;
            }
        }
        // Re-read: the state may have been changed by another caller.
        CircuitState::from_u8(self.state.load(Ordering::Acquire))
    }

    /// Returns the number of failures recorded since the last recorded
    /// success or [`reset`](Self::reset).
    #[must_use]
    pub fn failure_count(&self) -> u32 {
        self.failure_count.load(Ordering::Relaxed)
    }

    /// Forces the breaker back to the closed state and zeroes every counter
    /// and the last-failure timestamp.
    pub fn reset(&self) {
        self.state.store(STATE_CLOSED, Ordering::Release);
        self.failure_count.store(0, Ordering::Relaxed);
        self.success_count.store(0, Ordering::Relaxed);
        self.last_failure_time.store(0, Ordering::Relaxed);
        self.half_open_calls.store(0, Ordering::Relaxed);
    }

    /// Test hook: overwrites the last-failure timestamp (milliseconds since
    /// the Unix epoch) so the recovery timeout can be simulated.
    #[cfg(test)]
    fn set_last_failure_time(&self, time: u64) {
        self.last_failure_time.store(time, Ordering::Relaxed);
    }

    /// Records a successful call.
    ///
    /// Always clears the failure count. While half-open it also counts the
    /// success and closes the breaker once `half_open_max_calls` successes
    /// have been recorded.
    fn handle_success(&self) {
        self.failure_count.store(0, Ordering::Relaxed);
        let current_state = self.state.load(Ordering::Acquire);
        if current_state == STATE_HALF_OPEN {
            let successes = self.success_count.fetch_add(1, Ordering::Relaxed) + 1;
            if successes >= self.config.half_open_max_calls
                && self
                    .state
                    .compare_exchange(
                        current_state,
                        STATE_CLOSED,
                        Ordering::AcqRel,
                        Ordering::Relaxed,
                    )
                    .is_ok()
            {
                self.success_count.store(0, Ordering::Relaxed);
                self.half_open_calls.store(0, Ordering::Relaxed);
            }
        }
    }

    /// Records a failed call.
    ///
    /// Stamps the failure time, bumps the failure count and opens the breaker
    /// when either the failure threshold is reached or the failure happened
    /// while half-open.
    fn handle_failure(&self) {
        self.last_failure_time
            .store(current_time_millis(), Ordering::Relaxed);
        let failures = self.failure_count.fetch_add(1, Ordering::Relaxed) + 1;
        let current_state = self.state.load(Ordering::Acquire);
        if failures >= self.config.failure_threshold {
            if current_state != STATE_OPEN
                && self
                    .state
                    .compare_exchange(
                        current_state,
                        STATE_OPEN,
                        Ordering::AcqRel,
                        Ordering::Relaxed,
                    )
                    .is_ok()
            {
                self.success_count.store(0, Ordering::Relaxed);
            }
        } else if current_state == STATE_HALF_OPEN
            && self
                .state
                .compare_exchange(
                    current_state,
                    STATE_OPEN,
                    Ordering::AcqRel,
                    Ordering::Relaxed,
                )
                .is_ok()
        {
            // A failed trial call re-opens the breaker even below the threshold.
            self.success_count.store(0, Ordering::Relaxed);
            self.half_open_calls.store(0, Ordering::Relaxed);
        }
    }
}
impl Default for CircuitBreaker {
    /// Builds a breaker from the default [`CircuitBreakerConfig`].
    fn default() -> Self {
        let config = CircuitBreakerConfig::default();
        Self::new(config)
    }
}
#[cfg_attr(target_family = "wasm", async_trait(?Send))]
#[cfg_attr(not(target_family = "wasm"), async_trait)]
impl LlmMiddleware for CircuitBreaker {
    /// Gate run before a call.
    ///
    /// Passes when the breaker is closed, or when it is half-open and fewer
    /// than `half_open_max_calls` trial calls have been admitted. Every other
    /// case is rejected with a boxed [`CircuitBreakerOpenError`].
    async fn pre_invoke(
        &self,
        _messages: &mut Vec<Message>,
        _options: &mut CallOptions,
    ) -> Result<(), LlmError> {
        let observed = self.state();
        if observed == CircuitState::Closed {
            return Ok(());
        }
        if observed == CircuitState::HalfOpen {
            // Claim a trial slot first, then give it back if the budget was
            // already exhausted.
            let already_admitted = self.half_open_calls.fetch_add(1, Ordering::Relaxed);
            if already_admitted < self.config.half_open_max_calls {
                return Ok(());
            }
            self.half_open_calls.fetch_sub(1, Ordering::Relaxed);
        }
        Err(LlmError::Other(Box::new(CircuitBreakerOpenError)))
    }

    /// Feeds the call outcome into the breaker; `result` itself is not
    /// modified and this hook always returns `Ok(())`.
    async fn post_invoke(&self, result: &mut Result<Message, LlmError>) -> Result<(), LlmError> {
        if result.is_ok() {
            self.handle_success();
        } else {
            self.handle_failure();
        }
        Ok(())
    }
}
impl fmt::Debug for CircuitBreakerOpenError {
    /// Writes the bare type name; the error carries no fields.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("CircuitBreakerOpenError")
    }
}
/// Returns the wall-clock time in milliseconds since the Unix epoch, or 0 when
/// the system clock reads earlier than the epoch.
#[allow(
    clippy::cast_possible_truncation,
    reason = "Milliseconds fit in u64 for realistic time ranges"
)]
fn current_time_millis() -> u64 {
    match SystemTime::now().duration_since(SystemTime::UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_millis() as u64,
        Err(_) => 0,
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::llm::ChatModel;
    use crate::llm::middleware::MiddlewareModel;
    use crate::llm::mock::MockChatModel;

    // Timing note: `state()` compares wall-clock milliseconds against the
    // recovery timeout. Tests that assert `Open` therefore use a long timeout
    // (30 s) so a millisecond tick between the failure and the assertion
    // cannot promote the breaker to half-open. Where expiry is needed, the
    // last failure is back-dated by 60 s.

    /// A new breaker starts closed with no recorded failures.
    #[test]
    fn test_circuit_breaker_new() {
        let config = CircuitBreakerConfig::default();
        let breaker = CircuitBreaker::new(config);
        assert_eq!(breaker.state(), CircuitState::Closed);
        assert_eq!(breaker.failure_count(), 0);
    }

    /// The default configuration carries the documented values.
    #[test]
    fn test_circuit_breaker_default_config() {
        let config = CircuitBreakerConfig::default();
        assert_eq!(config.failure_threshold, 5);
        assert_eq!(config.recovery_timeout, Duration::from_secs(30));
        assert_eq!(config.half_open_max_calls, 1);
    }

    /// `CircuitBreaker::default()` starts closed.
    #[test]
    fn test_circuit_breaker_default() {
        let breaker = CircuitBreaker::default();
        assert_eq!(breaker.state(), CircuitState::Closed);
    }

    /// A closed breaker lets calls through the middleware stack.
    #[tokio::test]
    async fn test_circuit_breaker_closed_allows_calls() {
        let base_model = MockChatModel::new("gpt-4").with_response("Hello!");
        let breaker = CircuitBreaker::default();
        let model = MiddlewareModel::new(base_model).with_middleware(breaker);
        let messages = vec![Message::human("Hi")];
        let result = model.invoke(&messages, None).await;
        let _ = result.unwrap();
    }

    /// Reaching the failure threshold opens the breaker.
    #[tokio::test]
    async fn test_circuit_breaker_transitions_to_open() {
        let config = CircuitBreakerConfig {
            failure_threshold: 3,
            // Long timeout: the breaker must still be open when asserted.
            recovery_timeout: Duration::from_secs(30),
            half_open_max_calls: 1,
        };
        let breaker = CircuitBreaker::new(config);
        for _ in 0..3 {
            breaker.handle_failure();
        }
        assert_eq!(breaker.state(), CircuitState::Open);
        assert_eq!(breaker.failure_count(), 3);
    }

    /// An open breaker rejects calls with the "circuit breaker is open" error.
    #[tokio::test]
    async fn test_circuit_breaker_open_rejects_calls() {
        let config = CircuitBreakerConfig {
            failure_threshold: 2,
            recovery_timeout: Duration::from_secs(30),
            half_open_max_calls: 1,
        };
        let breaker = CircuitBreaker::new(config);
        breaker.handle_failure();
        breaker.handle_failure();
        assert_eq!(breaker.state(), CircuitState::Open);
        let base_model = MockChatModel::new("gpt-4").with_response("Hello!");
        let model = MiddlewareModel::new(base_model).with_middleware(breaker);
        let messages = vec![Message::human("Hi")];
        let result = model.invoke(&messages, None).await;
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("circuit breaker is open")
        );
    }

    /// With `half_open_max_calls = 2`, two successful trial calls are needed
    /// to close the breaker again.
    #[tokio::test]
    async fn test_circuit_breaker_half_open_allows_limited_calls() {
        let config = CircuitBreakerConfig {
            failure_threshold: 2,
            recovery_timeout: Duration::from_secs(30),
            half_open_max_calls: 2,
        };
        let breaker = CircuitBreaker::new(config);
        breaker.handle_failure();
        breaker.handle_failure();
        assert_eq!(breaker.state(), CircuitState::Open);
        // Back-date the failure past the recovery timeout.
        breaker.set_last_failure_time(current_time_millis().saturating_sub(60_000));
        assert_eq!(breaker.state(), CircuitState::HalfOpen);
        let base_model = MockChatModel::new("gpt-4").with_response("Hello!");
        let model = MiddlewareModel::new(base_model.clone()).with_middleware(breaker.clone());
        let messages = vec![Message::human("Hi")];
        let result1 = model.invoke(&messages, None).await;
        let _ = result1.unwrap();
        assert_eq!(breaker.state(), CircuitState::HalfOpen);
        let model2 = MiddlewareModel::new(base_model.clone()).with_middleware(breaker.clone());
        let result2 = model2.invoke(&messages, None).await;
        let _ = result2.unwrap();
        assert_eq!(breaker.state(), CircuitState::Closed);
        let model3 = MiddlewareModel::new(base_model).with_middleware(breaker.clone());
        let result3 = model3.invoke(&messages, None).await;
        let _ = result3.unwrap();
    }

    /// A single successful trial call closes a half-open breaker when
    /// `half_open_max_calls = 1`.
    #[tokio::test]
    async fn test_circuit_breaker_half_open_success_closes() {
        let config = CircuitBreakerConfig {
            failure_threshold: 2,
            recovery_timeout: Duration::from_millis(1),
            half_open_max_calls: 1,
        };
        let breaker = CircuitBreaker::new(config);
        breaker.handle_failure();
        breaker.handle_failure();
        breaker.set_last_failure_time(current_time_millis().saturating_sub(100));
        assert_eq!(breaker.state(), CircuitState::HalfOpen);
        let base_model = MockChatModel::new("gpt-4").with_response("Hello!");
        let model = MiddlewareModel::new(base_model).with_middleware(breaker.clone());
        let messages = vec![Message::human("Hi")];
        let result = model.invoke(&messages, None).await;
        let _ = result.unwrap();
        assert_eq!(breaker.state(), CircuitState::Closed);
        assert_eq!(breaker.failure_count(), 0);
    }

    /// A failed trial call sends a half-open breaker back to open.
    #[tokio::test]
    async fn test_circuit_breaker_half_open_failure_opens() {
        let config = CircuitBreakerConfig {
            failure_threshold: 2,
            // Long timeout: the failed trial call stamps a fresh failure time,
            // and the breaker must still be open when asserted afterwards.
            recovery_timeout: Duration::from_secs(30),
            half_open_max_calls: 1,
        };
        let breaker = CircuitBreaker::new(config);
        breaker.handle_failure();
        breaker.handle_failure();
        breaker.set_last_failure_time(current_time_millis().saturating_sub(60_000));
        assert_eq!(breaker.state(), CircuitState::HalfOpen);
        let base_model = MockChatModel::new("gpt-4").with_error();
        let model = MiddlewareModel::new(base_model).with_middleware(breaker.clone());
        let messages = vec![Message::human("Hi")];
        let result = model.invoke(&messages, None).await;
        let _ = result.unwrap_err();
        assert_eq!(breaker.state(), CircuitState::Open);
    }

    /// `reset` closes an open breaker and clears the failure count.
    #[test]
    fn test_circuit_breaker_reset() {
        let config = CircuitBreakerConfig {
            failure_threshold: 2,
            recovery_timeout: Duration::from_secs(30),
            half_open_max_calls: 1,
        };
        let breaker = CircuitBreaker::new(config);
        breaker.handle_failure();
        breaker.handle_failure();
        assert_eq!(breaker.state(), CircuitState::Open);
        breaker.reset();
        assert_eq!(breaker.state(), CircuitState::Closed);
        assert_eq!(breaker.failure_count(), 0);
    }

    /// A success while closed zeroes the accumulated failure count.
    #[tokio::test]
    async fn test_circuit_breaker_success_in_closed_resets_failures() {
        let config = CircuitBreakerConfig {
            failure_threshold: 5,
            recovery_timeout: Duration::from_secs(30),
            half_open_max_calls: 1,
        };
        let breaker = CircuitBreaker::new(config);
        breaker.handle_failure();
        breaker.handle_failure();
        assert_eq!(breaker.failure_count(), 2);
        let base_model = MockChatModel::new("gpt-4").with_response("Hello!");
        let model = MiddlewareModel::new(base_model).with_middleware(breaker.clone());
        let messages = vec![Message::human("Hi")];
        let result = model.invoke(&messages, None).await;
        let _ = result.unwrap();
        assert_eq!(breaker.failure_count(), 0);
    }

    /// State bytes round-trip, and unknown bytes decode to `Closed`.
    #[test]
    fn test_circuit_state_constants() {
        assert_eq!(CircuitState::Closed.as_u8(), 0);
        assert_eq!(CircuitState::Open.as_u8(), 1);
        assert_eq!(CircuitState::HalfOpen.as_u8(), 2);
        assert_eq!(CircuitState::from_u8(0), CircuitState::Closed);
        assert_eq!(CircuitState::from_u8(1), CircuitState::Open);
        assert_eq!(CircuitState::from_u8(2), CircuitState::HalfOpen);
        assert_eq!(CircuitState::from_u8(255), CircuitState::Closed);
    }

    /// An open breaker becomes half-open once the recovery timeout has passed.
    #[test]
    fn test_circuit_breaker_open_to_half_open_transition() {
        let config = CircuitBreakerConfig {
            failure_threshold: 2,
            recovery_timeout: Duration::from_secs(30),
            half_open_max_calls: 1,
        };
        let breaker = CircuitBreaker::new(config);
        breaker.handle_failure();
        breaker.handle_failure();
        assert_eq!(breaker.state(), CircuitState::Open);
        // Back-date the failure past the recovery timeout.
        breaker.set_last_failure_time(current_time_millis().saturating_sub(60_000));
        assert_eq!(breaker.state(), CircuitState::HalfOpen);
    }

    /// Failures below the threshold keep the breaker closed, and a later
    /// success clears them.
    #[tokio::test]
    async fn test_circuit_breaker_multiple_failures_then_success() {
        let config = CircuitBreakerConfig {
            failure_threshold: 3,
            recovery_timeout: Duration::from_secs(30),
            half_open_max_calls: 1,
        };
        let breaker = CircuitBreaker::new(config);
        breaker.handle_failure();
        breaker.handle_failure();
        assert_eq!(breaker.failure_count(), 2);
        assert_eq!(breaker.state(), CircuitState::Closed);
        let base_model = MockChatModel::new("gpt-4").with_response("Hello!");
        let model = MiddlewareModel::new(base_model).with_middleware(breaker.clone());
        let messages = vec![Message::human("Hi")];
        let result = model.invoke(&messages, None).await;
        let _ = result.unwrap();
        assert_eq!(breaker.failure_count(), 0);
        assert_eq!(breaker.state(), CircuitState::Closed);
    }

    /// A clone reports the same state and failure count as the original.
    #[test]
    fn test_circuit_breaker_clone() {
        let config = CircuitBreakerConfig {
            failure_threshold: 2,
            recovery_timeout: Duration::from_secs(30),
            half_open_max_calls: 1,
        };
        let breaker = CircuitBreaker::new(config);
        breaker.handle_failure();
        breaker.handle_failure();
        let cloned = breaker.clone();
        assert_eq!(breaker.state(), cloned.state());
        assert_eq!(breaker.failure_count(), cloned.failure_count());
    }
}