use std::sync::Arc;
use std::time::{Duration, Instant};
use async_trait::async_trait;
use rust_decimal::Decimal;
use tokio::sync::Mutex;
use crate::llm::error::LlmError;
use crate::llm::provider::{
CompletionRequest, CompletionResponse, LlmProvider, ModelMetadata, ToolCompletionRequest,
ToolCompletionResponse,
};
/// Tuning knobs for [`CircuitBreakerProvider`].
#[derive(Debug, Clone)]
pub struct CircuitBreakerConfig {
    /// Number of consecutive transient failures, observed while the breaker is
    /// `Closed`, that trips it to `Open`.
    pub failure_threshold: u32,
    /// How long the breaker stays `Open` before it admits a probe request.
    pub recovery_timeout: Duration,
    /// Number of successful calls required while `HalfOpen` before the breaker
    /// closes again.
    pub half_open_successes_needed: u32,
}

impl Default for CircuitBreakerConfig {
    /// Trip after 5 consecutive transient failures, wait 30 seconds before
    /// probing, and require 2 successful probes to close.
    fn default() -> Self {
        let failure_threshold = 5;
        let recovery_timeout = Duration::from_secs(30);
        let half_open_successes_needed = 2;
        Self {
            failure_threshold,
            recovery_timeout,
            half_open_successes_needed,
        }
    }
}
/// Externally observable state of a [`CircuitBreakerProvider`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CircuitState {
    /// Normal operation: requests are forwarded and transient failures counted.
    Closed,
    /// Tripped: requests are rejected until the recovery timeout has elapsed.
    Open,
    /// Probing: requests are forwarded to test whether the provider recovered.
    HalfOpen,
}

/// Mutable breaker bookkeeping, kept behind the provider's mutex.
struct BreakerState {
    /// Current position in the state machine.
    state: CircuitState,
    /// Transient failures counted since the last reset.
    consecutive_failures: u32,
    /// Set whenever the breaker enters `Open`; cleared when it recovers to
    /// `Closed`.
    opened_at: Option<Instant>,
    /// Successful probes counted during the current `HalfOpen` episode.
    half_open_successes: u32,
}

impl BreakerState {
    /// A fresh `Closed` breaker with no recorded history.
    fn new() -> Self {
        Self {
            opened_at: None,
            half_open_successes: 0,
            consecutive_failures: 0,
            state: CircuitState::Closed,
        }
    }
}
pub struct CircuitBreakerProvider {
inner: Arc<dyn LlmProvider>,
state: Mutex<BreakerState>,
config: CircuitBreakerConfig,
}
impl CircuitBreakerProvider {
pub fn new(inner: Arc<dyn LlmProvider>, config: CircuitBreakerConfig) -> Self {
Self {
inner,
state: Mutex::new(BreakerState::new()),
config,
}
}
pub async fn circuit_state(&self) -> CircuitState {
self.state.lock().await.state
}
pub async fn consecutive_failures(&self) -> u32 {
self.state.lock().await.consecutive_failures
}
async fn check_allowed(&self) -> Result<(), LlmError> {
let mut state = self.state.lock().await;
match state.state {
CircuitState::Closed | CircuitState::HalfOpen => Ok(()),
CircuitState::Open => {
if let Some(opened_at) = state.opened_at {
if opened_at.elapsed() >= self.config.recovery_timeout {
state.state = CircuitState::HalfOpen;
state.half_open_successes = 0;
tracing::info!(
provider = self.inner.model_name(),
"Circuit breaker: Open -> HalfOpen, allowing probe"
);
Ok(())
} else {
let remaining = self
.config
.recovery_timeout
.checked_sub(opened_at.elapsed())
.unwrap_or(Duration::ZERO);
Err(LlmError::RequestFailed {
provider: self.inner.model_name().to_string(),
reason: format!(
"Circuit breaker open ({} consecutive failures, \
recovery in {:.0}s)",
state.consecutive_failures,
remaining.as_secs_f64()
),
})
}
} else {
state.state = CircuitState::Closed;
Ok(())
}
}
}
}
async fn record_success(&self) {
let mut state = self.state.lock().await;
match state.state {
CircuitState::Closed => {
state.consecutive_failures = 0;
}
CircuitState::HalfOpen => {
state.half_open_successes += 1;
if state.half_open_successes >= self.config.half_open_successes_needed {
state.state = CircuitState::Closed;
state.consecutive_failures = 0;
state.opened_at = None;
tracing::info!(
provider = self.inner.model_name(),
"Circuit breaker: HalfOpen -> Closed (recovered)"
);
}
}
CircuitState::Open => {
debug_assert!(
false,
"BUG: record_success() called while circuit breaker is Open — \
check_allowed() was bypassed for provider {}",
self.inner.model_name()
);
state.state = CircuitState::Closed;
state.consecutive_failures = 0;
state.opened_at = None;
}
}
}
async fn record_failure(&self, err: &LlmError) {
if !is_transient(err) {
return;
}
let mut state = self.state.lock().await;
match state.state {
CircuitState::Closed => {
state.consecutive_failures += 1;
if state.consecutive_failures >= self.config.failure_threshold {
state.state = CircuitState::Open;
state.opened_at = Some(Instant::now());
tracing::warn!(
provider = self.inner.model_name(),
failures = state.consecutive_failures,
"Circuit breaker: Closed -> Open"
);
}
}
CircuitState::HalfOpen => {
state.state = CircuitState::Open;
state.opened_at = Some(Instant::now());
state.half_open_successes = 0;
tracing::warn!(
provider = self.inner.model_name(),
"Circuit breaker: HalfOpen -> Open (probe failed)"
);
}
CircuitState::Open => {}
}
}
}
/// Reports whether `err` should count against the circuit breaker.
///
/// Counted errors fall into three groups:
/// - request/transport failures (`RequestFailed`, `Http`, `Io`);
/// - throttling or malformed replies (`RateLimited`, `InvalidResponse`);
/// - session problems (`SessionExpired`, `SessionRenewalFailed`).
///
/// Every other variant (for example `AuthFailed`, `ContextLengthExceeded`,
/// `ModelNotAvailable` or `Json`) is ignored by the breaker.
fn is_transient(err: &LlmError) -> bool {
    let transport = matches!(
        err,
        LlmError::RequestFailed { .. } | LlmError::Http(_) | LlmError::Io(_)
    );
    let throttled_or_malformed = matches!(
        err,
        LlmError::RateLimited { .. } | LlmError::InvalidResponse { .. }
    );
    let session = matches!(
        err,
        LlmError::SessionExpired { .. } | LlmError::SessionRenewalFailed { .. }
    );
    transport || throttled_or_malformed || session
}
// Only the two completion entry points are gated by the breaker. Every other
// method is forwarded straight to the wrapped provider and never touches
// breaker state.
#[async_trait]
impl LlmProvider for CircuitBreakerProvider {
    /// Forwards a plain completion, subject to the breaker.
    async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse, LlmError> {
        // Fail fast, without calling the inner provider, while the breaker is open.
        self.check_allowed().await?;
        let outcome = self.inner.complete(request).await;
        match outcome {
            Ok(response) => {
                self.record_success().await;
                Ok(response)
            }
            Err(error) => {
                self.record_failure(&error).await;
                Err(error)
            }
        }
    }

    /// Forwards a tool-enabled completion, subject to the same breaker as
    /// `complete`.
    async fn complete_with_tools(
        &self,
        request: ToolCompletionRequest,
    ) -> Result<ToolCompletionResponse, LlmError> {
        self.check_allowed().await?;
        let outcome = self.inner.complete_with_tools(request).await;
        match outcome {
            Ok(response) => {
                self.record_success().await;
                Ok(response)
            }
            Err(error) => {
                self.record_failure(&error).await;
                Err(error)
            }
        }
    }

    fn model_name(&self) -> &str {
        self.inner.model_name()
    }

    fn cost_per_token(&self) -> (Decimal, Decimal) {
        self.inner.cost_per_token()
    }

    fn cache_write_multiplier(&self) -> Decimal {
        self.inner.cache_write_multiplier()
    }

    fn cache_read_discount(&self) -> Decimal {
        self.inner.cache_read_discount()
    }

    async fn list_models(&self) -> Result<Vec<String>, LlmError> {
        self.inner.list_models().await
    }

    async fn model_metadata(&self) -> Result<ModelMetadata, LlmError> {
        self.inner.model_metadata().await
    }

    fn effective_model_name(&self, requested_model: Option<&str>) -> String {
        self.inner.effective_model_name(requested_model)
    }

    fn active_model_name(&self) -> String {
        self.inner.active_model_name()
    }

    fn set_model(&self, model: &str) -> Result<(), LlmError> {
        self.inner.set_model(model)
    }

    fn calculate_cost(&self, input_tokens: u32, output_tokens: u32) -> Decimal {
        self.inner.calculate_cost(input_tokens, output_tokens)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::testing::StubLlm;

    fn make_request() -> CompletionRequest {
        CompletionRequest::new(vec![crate::llm::ChatMessage::user("hello")])
    }

    fn make_tool_request() -> ToolCompletionRequest {
        ToolCompletionRequest::new(vec![crate::llm::ChatMessage::user("hello")], vec![])
    }

    /// Config with a 50 ms recovery timeout and a single-probe half-open
    /// phase, so state-machine tests run quickly.
    fn fast_config(threshold: u32) -> CircuitBreakerConfig {
        CircuitBreakerConfig {
            failure_threshold: threshold,
            recovery_timeout: Duration::from_millis(50),
            half_open_successes_needed: 1,
        }
    }

    /// True when `err` is the breaker's own fast-fail rejection rather than an
    /// error produced by the inner provider.
    fn is_breaker_rejection(err: &LlmError) -> bool {
        matches!(
            err,
            LlmError::RequestFailed { reason, .. } if reason.contains("Circuit breaker open")
        )
    }

    #[tokio::test]
    async fn closed_allows_calls_and_resets_on_success() {
        let stub = Arc::new(StubLlm::new("ok").with_model_name("test"));
        let cb = CircuitBreakerProvider::new(stub, fast_config(3));
        let resp = cb.complete(make_request()).await;
        assert!(resp.is_ok());
        assert_eq!(cb.circuit_state().await, CircuitState::Closed);
        assert_eq!(cb.consecutive_failures().await, 0);
    }

    #[tokio::test]
    async fn failures_accumulate_then_trip_to_open() {
        let stub = Arc::new(StubLlm::failing("test"));
        let cb = CircuitBreakerProvider::new(stub, fast_config(3));
        for i in 0..2 {
            let _ = cb.complete(make_request()).await;
            assert_eq!(cb.circuit_state().await, CircuitState::Closed);
            assert_eq!(cb.consecutive_failures().await, i + 1);
        }
        let _ = cb.complete(make_request()).await;
        assert_eq!(cb.circuit_state().await, CircuitState::Open);
    }

    #[tokio::test]
    async fn open_rejects_immediately() {
        let stub = Arc::new(StubLlm::failing("test"));
        let cb = CircuitBreakerProvider::new(
            stub,
            CircuitBreakerConfig {
                failure_threshold: 1,
                recovery_timeout: Duration::from_secs(60),
                half_open_successes_needed: 1,
            },
        );
        let _ = cb.complete(make_request()).await;
        assert_eq!(cb.circuit_state().await, CircuitState::Open);
        let err = cb.complete(make_request()).await.unwrap_err();
        match err {
            LlmError::RequestFailed { reason, .. } => {
                assert!(
                    reason.contains("Circuit breaker open"),
                    "Expected circuit breaker message, got: {}",
                    reason
                );
            }
            other => panic!("Expected RequestFailed, got: {:?}", other),
        }
    }

    #[tokio::test]
    async fn recovery_timeout_transitions_to_half_open() {
        let stub = Arc::new(StubLlm::failing("test"));
        let cb = CircuitBreakerProvider::new(stub, fast_config(1));
        let _ = cb.complete(make_request()).await;
        assert_eq!(cb.circuit_state().await, CircuitState::Open);
        tokio::time::sleep(Duration::from_millis(60)).await;
        // After the timeout the breaker must forward a probe to the inner
        // provider. Without this check, the Open assertion below would also
        // pass if the breaker simply kept rejecting.
        let err = cb.complete(make_request()).await.unwrap_err();
        assert!(
            !is_breaker_rejection(&err),
            "probe should reach the inner provider, got breaker rejection: {:?}",
            err
        );
        // The probe failed, so the breaker is Open again.
        assert_eq!(cb.circuit_state().await, CircuitState::Open);
    }

    #[tokio::test]
    async fn half_open_success_closes_circuit() {
        let stub = Arc::new(StubLlm::failing("test"));
        let cb = CircuitBreakerProvider::new(stub.clone(), fast_config(1));
        let _ = cb.complete(make_request()).await;
        assert_eq!(cb.circuit_state().await, CircuitState::Open);
        tokio::time::sleep(Duration::from_millis(60)).await;
        stub.set_failing(false);
        let resp = cb.complete(make_request()).await;
        assert!(resp.is_ok());
        assert_eq!(cb.circuit_state().await, CircuitState::Closed);
        assert_eq!(cb.consecutive_failures().await, 0);
    }

    #[tokio::test]
    async fn half_open_failure_reopens_circuit() {
        let stub = Arc::new(StubLlm::failing("test"));
        let cb = CircuitBreakerProvider::new(stub, fast_config(1));
        let _ = cb.complete(make_request()).await;
        tokio::time::sleep(Duration::from_millis(60)).await;
        let err = cb.complete(make_request()).await.unwrap_err();
        assert!(
            !is_breaker_rejection(&err),
            "probe should reach the inner provider, got breaker rejection: {:?}",
            err
        );
        assert_eq!(cb.circuit_state().await, CircuitState::Open);
        // Re-opening restarts the recovery timer, so an immediate retry is
        // rejected by the breaker itself.
        let err = cb.complete(make_request()).await.unwrap_err();
        assert!(
            is_breaker_rejection(&err),
            "expected breaker rejection right after re-open, got: {:?}",
            err
        );
    }

    #[tokio::test]
    async fn non_transient_errors_do_not_trip_breaker() {
        let stub = Arc::new(StubLlm::failing_non_transient("test"));
        let cb = CircuitBreakerProvider::new(stub, fast_config(1));
        for _ in 0..5 {
            let _ = cb.complete(make_request()).await;
        }
        assert_eq!(cb.circuit_state().await, CircuitState::Closed);
        assert_eq!(cb.consecutive_failures().await, 0);
    }

    #[tokio::test]
    async fn success_resets_failure_count() {
        let stub = Arc::new(StubLlm::failing("test"));
        let cb = CircuitBreakerProvider::new(stub.clone(), fast_config(3));
        let _ = cb.complete(make_request()).await;
        let _ = cb.complete(make_request()).await;
        assert_eq!(cb.consecutive_failures().await, 2);
        stub.set_failing(false);
        let resp = cb.complete(make_request()).await;
        assert!(resp.is_ok());
        assert_eq!(cb.consecutive_failures().await, 0);
    }

    #[tokio::test]
    async fn complete_with_tools_uses_same_breaker_logic() {
        let stub = Arc::new(StubLlm::failing("test"));
        let cb = CircuitBreakerProvider::new(stub, fast_config(2));
        let _ = cb.complete_with_tools(make_tool_request()).await;
        let _ = cb.complete_with_tools(make_tool_request()).await;
        assert_eq!(cb.circuit_state().await, CircuitState::Open);
    }

    #[tokio::test]
    async fn multiple_half_open_successes_needed() {
        let stub = Arc::new(StubLlm::failing("test"));
        let cb = CircuitBreakerProvider::new(
            stub.clone(),
            CircuitBreakerConfig {
                failure_threshold: 1,
                recovery_timeout: Duration::from_millis(50),
                half_open_successes_needed: 3,
            },
        );
        let _ = cb.complete(make_request()).await;
        tokio::time::sleep(Duration::from_millis(60)).await;
        stub.set_failing(false);
        let _ = cb.complete(make_request()).await;
        assert_eq!(cb.circuit_state().await, CircuitState::HalfOpen);
        let _ = cb.complete(make_request()).await;
        assert_eq!(cb.circuit_state().await, CircuitState::HalfOpen);
        let _ = cb.complete(make_request()).await;
        assert_eq!(cb.circuit_state().await, CircuitState::Closed);
    }

    #[test]
    fn transient_classification() {
        assert!(is_transient(&LlmError::RequestFailed {
            provider: "p".into(),
            reason: "err".into(),
        }));
        assert!(is_transient(&LlmError::RateLimited {
            provider: "p".into(),
            retry_after: None,
        }));
        assert!(is_transient(&LlmError::InvalidResponse {
            provider: "p".into(),
            reason: "bad".into(),
        }));
        assert!(is_transient(&LlmError::SessionExpired {
            provider: "p".into(),
        }));
        assert!(is_transient(&LlmError::SessionRenewalFailed {
            provider: "p".into(),
            reason: "timeout".into(),
        }));
        assert!(is_transient(&LlmError::Io(std::io::Error::new(
            std::io::ErrorKind::ConnectionReset,
            "reset"
        ))));
        assert!(!is_transient(&LlmError::AuthFailed {
            provider: "p".into(),
        }));
        assert!(!is_transient(&LlmError::ContextLengthExceeded {
            used: 100_000,
            limit: 50_000,
        }));
        assert!(!is_transient(&LlmError::ModelNotAvailable {
            provider: "p".into(),
            model: "m".into(),
        }));
        assert!(!is_transient(&LlmError::Json(
            serde_json::from_str::<String>("bad").unwrap_err()
        )));
    }

    #[tokio::test]
    async fn passthrough_methods_delegate_to_inner() {
        let stub = Arc::new(StubLlm::new("ok").with_model_name("my-model"));
        let cb = CircuitBreakerProvider::new(stub, fast_config(3));
        assert_eq!(cb.model_name(), "my-model");
        assert_eq!(cb.active_model_name(), "my-model");
        assert_eq!(cb.cost_per_token(), (Decimal::ZERO, Decimal::ZERO));
        assert_eq!(cb.calculate_cost(100, 50), Decimal::ZERO);
    }

    /// Provider whose completions never resolve.
    struct HangingProvider;

    #[async_trait]
    impl LlmProvider for HangingProvider {
        fn model_name(&self) -> &str {
            "hanging"
        }

        fn cost_per_token(&self) -> (Decimal, Decimal) {
            (Decimal::ZERO, Decimal::ZERO)
        }

        async fn complete(
            &self,
            _request: CompletionRequest,
        ) -> Result<CompletionResponse, LlmError> {
            std::future::pending().await
        }

        async fn complete_with_tools(
            &self,
            _request: ToolCompletionRequest,
        ) -> Result<ToolCompletionResponse, LlmError> {
            std::future::pending().await
        }
    }

    #[tokio::test]
    async fn hanging_provider_behind_breaker_can_be_timed_out() {
        let hanging: Arc<dyn LlmProvider> = Arc::new(HangingProvider);
        let cb = CircuitBreakerProvider::new(hanging, fast_config(1));
        let result =
            tokio::time::timeout(Duration::from_millis(100), cb.complete(make_request())).await;
        assert!(result.is_err(), "should timeout, not hang");
    }

    #[tokio::test]
    async fn rapid_open_close_cycles_do_not_corrupt_state() {
        let stub = Arc::new(StubLlm::failing("test"));
        let cb = CircuitBreakerProvider::new(
            stub.clone(),
            CircuitBreakerConfig {
                failure_threshold: 1,
                recovery_timeout: Duration::from_millis(10),
                half_open_successes_needed: 1,
            },
        );
        for _ in 0..5 {
            let _ = cb.complete(make_request()).await;
            assert_eq!(cb.circuit_state().await, CircuitState::Open);
            tokio::time::sleep(Duration::from_millis(15)).await;
            // Each post-timeout call must be a real probe, not a rejection.
            let err = cb.complete(make_request()).await.unwrap_err();
            assert!(
                !is_breaker_rejection(&err),
                "probe should reach the inner provider, got breaker rejection: {:?}",
                err
            );
            assert_eq!(cb.circuit_state().await, CircuitState::Open);
        }
        tokio::time::sleep(Duration::from_millis(15)).await;
        stub.set_failing(false);
        let result = cb.complete(make_request()).await;
        assert!(result.is_ok());
        assert_eq!(cb.circuit_state().await, CircuitState::Closed);
    }

    /// Provider whose calls alternate between a transient error (even-numbered
    /// calls, starting at 0) and a non-transient error (odd-numbered calls).
    struct AlternatingErrorProvider {
        calls: std::sync::atomic::AtomicUsize,
    }

    impl AlternatingErrorProvider {
        fn next_error(&self) -> LlmError {
            let n = self
                .calls
                .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
            if n % 2 == 0 {
                LlmError::RequestFailed {
                    provider: "alternating".into(),
                    reason: "transient".into(),
                }
            } else {
                LlmError::AuthFailed {
                    provider: "alternating".into(),
                }
            }
        }
    }

    #[async_trait]
    impl LlmProvider for AlternatingErrorProvider {
        fn model_name(&self) -> &str {
            "alternating"
        }

        fn cost_per_token(&self) -> (Decimal, Decimal) {
            (Decimal::ZERO, Decimal::ZERO)
        }

        async fn complete(
            &self,
            _request: CompletionRequest,
        ) -> Result<CompletionResponse, LlmError> {
            Err(self.next_error())
        }

        async fn complete_with_tools(
            &self,
            _request: ToolCompletionRequest,
        ) -> Result<ToolCompletionResponse, LlmError> {
            Err(self.next_error())
        }
    }

    #[tokio::test]
    async fn mixed_error_types_only_transient_counts() {
        let non_transient = Arc::new(StubLlm::failing_non_transient("test"));
        let cb_nt = CircuitBreakerProvider::new(non_transient, fast_config(3));
        for _ in 0..100 {
            let _ = cb_nt.complete(make_request()).await;
        }
        assert_eq!(cb_nt.circuit_state().await, CircuitState::Closed);
        assert_eq!(cb_nt.consecutive_failures().await, 0);

        // Interleaved errors: non-transient ones must neither count toward the
        // threshold nor reset the running transient streak.
        let cb_mixed = CircuitBreakerProvider::new(
            Arc::new(AlternatingErrorProvider {
                calls: std::sync::atomic::AtomicUsize::new(0),
            }),
            fast_config(3),
        );
        // transient, non-transient, transient, non-transient
        for _ in 0..4 {
            let _ = cb_mixed.complete(make_request()).await;
        }
        assert_eq!(cb_mixed.circuit_state().await, CircuitState::Closed);
        assert_eq!(cb_mixed.consecutive_failures().await, 2);
        // The third transient error reaches the threshold.
        let _ = cb_mixed.complete(make_request()).await;
        assert_eq!(cb_mixed.circuit_state().await, CircuitState::Open);
    }

    #[tokio::test]
    async fn test_cooldown_at_zero_nanos() {
        let stub = Arc::new(StubLlm::failing("test"));
        let cb = CircuitBreakerProvider::new(
            stub.clone(),
            CircuitBreakerConfig {
                failure_threshold: 1,
                recovery_timeout: Duration::ZERO,
                half_open_successes_needed: 1,
            },
        );
        let _ = cb.complete(make_request()).await;
        assert_eq!(cb.circuit_state().await, CircuitState::Open);
        stub.set_failing(false);
        let result = cb.complete(make_request()).await;
        assert!(
            result.is_ok(),
            "zero recovery_timeout should allow immediate probe"
        );
        assert_eq!(
            cb.circuit_state().await,
            CircuitState::Closed,
            "successful probe after zero-timeout should close the circuit"
        );
        stub.set_failing(true);
        let _ = cb.complete(make_request()).await;
        assert_eq!(cb.circuit_state().await, CircuitState::Open);
        let _ = cb.complete(make_request()).await;
        assert_eq!(
            cb.circuit_state().await,
            CircuitState::Open,
            "failed probe should re-open circuit even with zero timeout"
        );
    }

    #[tokio::test]
    async fn test_circuit_breaker_half_open_failure_reopens() {
        let stub = Arc::new(StubLlm::failing("test"));
        let cb = CircuitBreakerProvider::new(
            stub.clone(),
            CircuitBreakerConfig {
                failure_threshold: 1,
                recovery_timeout: Duration::from_millis(20),
                half_open_successes_needed: 3,
            },
        );
        let _ = cb.complete(make_request()).await;
        assert_eq!(cb.circuit_state().await, CircuitState::Open);
        tokio::time::sleep(Duration::from_millis(30)).await;
        stub.set_failing(false);
        let _ = cb.complete(make_request()).await;
        assert_eq!(cb.circuit_state().await, CircuitState::HalfOpen);
        stub.set_failing(true);
        let _ = cb.complete(make_request()).await;
        assert_eq!(
            cb.circuit_state().await,
            CircuitState::Open,
            "failure in half-open should immediately re-open the circuit"
        );
        tokio::time::sleep(Duration::from_millis(30)).await;
        stub.set_failing(false);
        let _ = cb.complete(make_request()).await;
        assert_eq!(cb.circuit_state().await, CircuitState::HalfOpen);
        let _ = cb.complete(make_request()).await;
        assert_eq!(cb.circuit_state().await, CircuitState::HalfOpen);
        let _ = cb.complete(make_request()).await;
        assert_eq!(
            cb.circuit_state().await,
            CircuitState::Closed,
            "3 fresh successes needed after re-open, not 2"
        );
        assert_eq!(cb.consecutive_failures().await, 0);
    }
}