use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
use std::time::Duration;
use async_trait::async_trait;
use crate::error::PeError;
use crate::llm::{LlmProvider, LlmResponse, ToolSchema};
use crate::message::Message;
use crate::provider_middleware::ProviderMiddleware;
/// Breaker state: calls are passed through to the provider.
const STATE_CLOSED: u32 = 0;
/// Breaker state: calls are rejected until the recovery timeout has elapsed.
const STATE_OPEN: u32 = 1;
/// Breaker state: one probe call has been admitted and its outcome is pending.
const STATE_HALF_OPEN: u32 = 2;

/// Lock-free circuit breaker.
///
/// The breaker starts closed. Once `failure_threshold` failures have been
/// recorded without an intervening success it opens and rejects calls. After
/// `recovery_timeout` it admits exactly one probe call (half-open); a recorded
/// success closes the breaker, a recorded failure opens it again. If a probe
/// never reports back, a replacement probe is admitted after another
/// `recovery_timeout`, so the breaker cannot get stuck half-open.
///
/// All mutable state lives in atomics and every method takes `&self`.
pub struct CircuitBreaker {
    /// Number of recorded failures at which the breaker opens.
    failure_threshold: u32,
    /// Minimum wait after opening (or after admitting a probe) before the
    /// next probe is admitted.
    recovery_timeout: Duration,
    /// One of `STATE_CLOSED`, `STATE_OPEN`, `STATE_HALF_OPEN`.
    state: AtomicU32,
    /// Failures recorded since the last recorded success.
    failure_count: AtomicU32,
    /// Wall-clock milliseconds since the Unix epoch at which the breaker last
    /// opened or last admitted a probe. Doubles as the claim token that
    /// guarantees a single probe per recovery window.
    opened_at: AtomicU64,
}

impl CircuitBreaker {
    /// Creates a closed breaker that opens after `failure_threshold` recorded
    /// failures and waits `recovery_timeout` before probing again.
    pub fn new(failure_threshold: u32, recovery_timeout: Duration) -> Self {
        Self {
            failure_threshold,
            recovery_timeout,
            state: AtomicU32::new(STATE_CLOSED),
            failure_count: AtomicU32::new(0),
            opened_at: AtomicU64::new(0),
        }
    }

    /// Returns the current state as `"closed"`, `"open"`, `"half-open"`, or
    /// `"unknown"` for any other stored value.
    pub fn state_name(&self) -> &'static str {
        match self.state.load(Ordering::SeqCst) {
            STATE_CLOSED => "closed",
            STATE_OPEN => "open",
            STATE_HALF_OPEN => "half-open",
            _ => "unknown",
        }
    }

    /// Returns the number of failures recorded since the last recorded success.
    pub fn failure_count(&self) -> u32 {
        self.failure_count.load(Ordering::SeqCst)
    }

    /// Converts a duration to whole milliseconds, clamping to `u64::MAX`
    /// instead of silently truncating the `u128` value.
    fn duration_to_millis(duration: Duration) -> u64 {
        // Lossless: the value is clamped into `u64` range before the cast.
        duration.as_millis().min(u128::from(u64::MAX)) as u64
    }

    /// Wall-clock milliseconds since the Unix epoch; a clock set before the
    /// epoch reads as 0.
    fn now_millis() -> u64 {
        let since_epoch = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default();
        Self::duration_to_millis(since_epoch)
    }

    /// Records a successful call: clears the failure count and closes the breaker.
    fn record_success(&self) {
        self.failure_count.store(0, Ordering::SeqCst);
        self.state.store(STATE_CLOSED, Ordering::SeqCst);
    }

    /// Records a failed call and opens the breaker once the threshold is reached.
    ///
    /// The failure count is not reset on opening, so a failed probe re-opens
    /// the breaker immediately.
    fn record_failure(&self) {
        // `fetch_add` wraps on overflow; `saturating_add` keeps the derived
        // count from panicking in debug builds at `u32::MAX`.
        let count = self
            .failure_count
            .fetch_add(1, Ordering::SeqCst)
            .saturating_add(1);
        if count >= self.failure_threshold {
            // Publish the timestamp BEFORE the state. A reader that observes
            // `STATE_OPEN` must never pair it with a stale `opened_at`, or it
            // would conclude the recovery timeout has already elapsed.
            self.opened_at.store(Self::now_millis(), Ordering::SeqCst);
            self.state.store(STATE_OPEN, Ordering::SeqCst);
        }
    }

    /// Decides whether a call may proceed.
    ///
    /// Closed: always. Open or half-open: only for the single caller that
    /// claims the probe slot for the current recovery window.
    fn should_allow(&self) -> bool {
        match self.state.load(Ordering::SeqCst) {
            STATE_CLOSED => true,
            STATE_OPEN | STATE_HALF_OPEN => self.try_claim_probe(),
            _ => false,
        }
    }

    /// Tries to claim the probe slot; returns `true` for the one winner.
    ///
    /// The slot is claimed by compare-and-swapping `opened_at` from the
    /// observed timestamp to "now". Every other caller then sees either a
    /// fresh timestamp (timeout not elapsed) or a failed CAS, and is rejected.
    fn try_claim_probe(&self) -> bool {
        let stamp = self.opened_at.load(Ordering::SeqCst);
        let now = Self::now_millis();
        // `saturating_sub` treats a wall clock that moved backwards as "no
        // time elapsed" rather than underflowing.
        if now.saturating_sub(stamp) < Self::duration_to_millis(self.recovery_timeout) {
            return false;
        }
        if self
            .opened_at
            .compare_exchange(stamp, now, Ordering::SeqCst, Ordering::SeqCst)
            .is_err()
        {
            // Another caller claimed the slot, or the breaker was re-opened.
            return false;
        }
        // Mark the probe as in flight. If the state is no longer OPEN (already
        // half-open, or closed by a concurrent success) it is left as-is.
        let _ = self.state.compare_exchange(
            STATE_OPEN,
            STATE_HALF_OPEN,
            Ordering::SeqCst,
            Ordering::SeqCst,
        );
        true
    }
}
#[async_trait]
impl ProviderMiddleware for CircuitBreaker {
    /// Guards a call to `next.complete` with the breaker.
    ///
    /// When `should_allow` refuses the call, `next` is not invoked and a
    /// `PeError::LlmProvider` is returned. Otherwise the provider's result is
    /// returned unchanged after updating the breaker: a success is recorded
    /// as a success, an error for which `is_transient()` is true is recorded
    /// as a failure, and any other error leaves the breaker untouched.
    async fn wrap_complete(
        &self,
        messages: &[Message],
        tools: &[ToolSchema],
        next: &dyn LlmProvider,
    ) -> Result<LlmResponse, PeError> {
        if self.should_allow() {
            let outcome = next.complete(messages, tools).await;
            // Inspect by reference so the result can be handed back untouched.
            match &outcome {
                Ok(_) => self.record_success(),
                Err(err) if err.is_transient() => self.record_failure(),
                // Non-transient errors do not count against the provider.
                Err(_) => {}
            }
            outcome
        } else {
            Err(PeError::LlmProvider {
                details: "circuit breaker open — provider is unavailable".into(),
            })
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mock_provider::MockProvider;

    /// Provider error used by the tests to simulate a failing call.
    fn llm_err() -> PeError {
        PeError::LlmProvider {
            details: "err".into(),
        }
    }

    /// Mock provider queued with `n` consecutive `llm_err()` responses.
    fn fail_provider(n: usize) -> MockProvider {
        (0..n).fold(MockProvider::new(), |provider, _| {
            provider.respond_with_error(llm_err())
        })
    }

    /// Sends one call with no messages and no tools through the breaker.
    async fn attempt(
        breaker: &CircuitBreaker,
        provider: &MockProvider,
    ) -> Result<LlmResponse, PeError> {
        breaker.wrap_complete(&[], &[], provider).await
    }

    /// A closed breaker forwards the call and stays closed on success.
    #[tokio::test]
    async fn test_closed_allows_calls() {
        let breaker = CircuitBreaker::new(3, Duration::from_secs(60));
        let provider = MockProvider::new().respond_with("ok");
        let response = attempt(&breaker, &provider).await.unwrap();
        assert_eq!(response.message.content.as_text(), Some("ok"));
        assert_eq!(breaker.state_name(), "closed");
    }

    /// The breaker opens only once the failure threshold is reached.
    #[tokio::test]
    async fn test_opens_after_threshold_failures() {
        let breaker = CircuitBreaker::new(2, Duration::from_secs(60));
        let failing = fail_provider(2);
        let _ = attempt(&breaker, &failing).await;
        assert_eq!(breaker.state_name(), "closed");
        let _ = attempt(&breaker, &failing).await;
        assert_eq!(breaker.state_name(), "open");
    }

    /// An open breaker rejects the call without consuming a provider response.
    #[tokio::test]
    async fn test_open_rejects_immediately() {
        let breaker = CircuitBreaker::new(1, Duration::from_secs(60));
        let failing = fail_provider(1);
        let _ = attempt(&breaker, &failing).await;
        assert_eq!(breaker.state_name(), "open");
        let healthy = MockProvider::new().respond_with("should not reach");
        let rejection = attempt(&breaker, &healthy).await.unwrap_err();
        assert!(matches!(rejection, PeError::LlmProvider { .. }));
        assert_eq!(healthy.remaining(), 1);
    }

    /// After the recovery timeout a successful probe closes the breaker, and
    /// a failed probe opens it again.
    #[tokio::test]
    async fn test_half_open_recovery_and_reopen() {
        let breaker = CircuitBreaker::new(1, Duration::from_millis(10));
        let first_failure = fail_provider(1);
        let _ = attempt(&breaker, &first_failure).await;
        assert_eq!(breaker.state_name(), "open");
        tokio::time::sleep(Duration::from_millis(20)).await;
        let probe = MockProvider::new().respond_with("recovered");
        let response = attempt(&breaker, &probe).await.unwrap();
        assert_eq!(response.message.content.as_text(), Some("recovered"));
        assert_eq!(breaker.state_name(), "closed");
        let second_failure = fail_provider(1);
        let _ = attempt(&breaker, &second_failure).await;
        tokio::time::sleep(Duration::from_millis(20)).await;
        let failing_probe = fail_provider(1);
        let _ = attempt(&breaker, &failing_probe).await;
        assert_eq!(breaker.state_name(), "open");
    }

    /// A `PermissionDenied` error neither opens the breaker nor counts as a failure.
    #[tokio::test]
    async fn test_permanent_errors_dont_trip_breaker() {
        let breaker = CircuitBreaker::new(1, Duration::from_secs(60));
        let denied = MockProvider::new().respond_with_error(PeError::PermissionDenied {
            action: "write".into(),
        });
        let _ = attempt(&breaker, &denied).await;
        assert_eq!(breaker.state_name(), "closed");
        assert_eq!(breaker.failure_count(), 0);
    }

    /// A success after two failures resets the failure counter to zero.
    #[tokio::test]
    async fn test_success_resets_failure_count() {
        let breaker = CircuitBreaker::new(3, Duration::from_secs(60));
        let flaky = fail_provider(2).respond_with("ok");
        let _ = attempt(&breaker, &flaky).await;
        let _ = attempt(&breaker, &flaky).await;
        assert_eq!(breaker.failure_count(), 2);
        let _ = attempt(&breaker, &flaky).await;
        assert_eq!(breaker.failure_count(), 0);
    }
}