use std::sync::atomic::{AtomicU64, AtomicU8, Ordering};
use std::time::{Duration, Instant};
/// Breaker state, stored in [`CircuitBreaker`]'s `AtomicU8` as its `u8` discriminant.
///
/// Derives `Debug` for diagnostics and `Copy`/`PartialEq` so states can be passed
/// around and compared without going through the raw `u8`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
enum CircuitState {
    /// Requests flow normally; consecutive failures are counted.
    Closed = 0,
    /// Requests are rejected until the open period has elapsed.
    Open = 1,
    /// Requests are let through again; their outcomes decide whether to close or re-open.
    HalfOpen = 2,
}
/// Tuning knobs for a [`CircuitBreaker`].
#[derive(Clone, Debug)]
pub struct CircuitBreakerConfig {
    /// Consecutive failures (while closed) that trip the breaker open.
    pub failure_threshold: u32,
    /// How long the breaker stays open before letting requests through again.
    pub open_duration: Duration,
    /// Consecutive successes needed while half-open to close the breaker again.
    pub half_open_successes: u32,
}

impl Default for CircuitBreakerConfig {
    /// Trip after 5 consecutive failures, stay open for 30 seconds, and close again
    /// after a single successful half-open request.
    fn default() -> Self {
        const FAILURE_THRESHOLD: u32 = 5;
        const OPEN_SECS: u64 = 30;
        const HALF_OPEN_SUCCESSES: u32 = 1;

        CircuitBreakerConfig {
            failure_threshold: FAILURE_THRESHOLD,
            open_duration: Duration::from_secs(OPEN_SECS),
            half_open_successes: HALF_OPEN_SUCCESSES,
        }
    }
}
/// Thread-safe circuit breaker that stops letting requests through after repeated failures.
///
/// The breaker starts `Closed`. After `failure_threshold` consecutive failures it trips to
/// `Open` and rejects requests until `open_duration` has elapsed. It then moves to
/// `HalfOpen`, where requests are allowed again. In `HalfOpen`, a failure re-opens the
/// breaker and `half_open_successes` consecutive successes close it.
///
/// The state and the counters are atomics, so the non-`Open` path of
/// [`CircuitBreaker::allow_request`] takes no lock. Every state *transition* runs while
/// holding the `opened_at` mutex. Each transition re-checks the state it expects to leave,
/// and resets the target state's bookkeeping before publishing the new state. As a
/// result, a thread that observes a state never sees a missing timestamp or stale
/// counters, and a breaker that has just been re-opened cannot be closed by a late success.
pub struct CircuitBreaker {
    config: CircuitBreakerConfig,
    /// Current [`CircuitState`], stored as its `u8` discriminant.
    state: AtomicU8,
    /// Failures seen in a row while `Closed`.
    consecutive_failures: AtomicU64,
    /// Successes seen in a row while `HalfOpen`.
    consecutive_successes: AtomicU64,
    /// When the breaker last tripped to `Open`; `None` before the first trip and after
    /// [`CircuitBreaker::reset`]. Holding this mutex also serializes state transitions.
    opened_at: parking_lot::Mutex<Option<Instant>>,
}

impl CircuitBreaker {
    /// Creates a breaker in the `Closed` state with zeroed counters.
    pub fn new(config: CircuitBreakerConfig) -> Self {
        Self {
            config,
            state: AtomicU8::new(CircuitState::Closed as u8),
            consecutive_failures: AtomicU64::new(0),
            consecutive_successes: AtomicU64::new(0),
            opened_at: parking_lot::Mutex::new(None),
        }
    }

    /// Decides whether a request may be attempted right now.
    ///
    /// `Closed` and `HalfOpen` always allow the request. `Open` allows it only once
    /// `open_duration` has elapsed since the breaker tripped. The first caller to notice
    /// this moves the breaker to `HalfOpen` and zeroes the half-open success streak.
    ///
    /// # Errors
    ///
    /// Returns [`CircuitOpenError`] while the breaker is `Open` and the open period has not
    /// yet elapsed. `remaining` is the time left in that period.
    pub fn allow_request(&self) -> Result<(), CircuitOpenError> {
        // Fast path without locking: only the `Open` state needs the timestamp.
        if !matches!(self.load_state(), CircuitState::Open) {
            return Ok(());
        }

        let mut opened_at = self.opened_at.lock();
        // Re-check under the lock. Since the load above, another caller may already have
        // moved the breaker to `HalfOpen`, or it may have been reset.
        if !matches!(self.load_state(), CircuitState::Open) {
            return Ok(());
        }

        let tripped_at: Option<Instant> = *opened_at;
        match tripped_at {
            Some(t) if t.elapsed() >= self.config.open_duration => {
                self.enter(&mut opened_at, CircuitState::HalfOpen);
                Ok(())
            }
            _ => Err(CircuitOpenError {
                remaining: self
                    .config
                    .open_duration
                    .saturating_sub(tripped_at.map(|t| t.elapsed()).unwrap_or_default()),
            }),
        }
    }

    /// Reports that a request succeeded.
    ///
    /// When `Closed`, this resets the failure streak. When `HalfOpen`, it counts toward
    /// `half_open_successes`, and reaching that count closes the breaker. When `Open`, it is
    /// ignored.
    pub fn record_success(&self) {
        match self.load_state() {
            CircuitState::Closed => {
                self.consecutive_failures.store(0, Ordering::SeqCst);
            }
            CircuitState::HalfOpen => {
                let successes = self.consecutive_successes.fetch_add(1, Ordering::SeqCst) + 1;
                if successes >= u64::from(self.config.half_open_successes) {
                    // No-op if a concurrent failure re-opened the breaker in the meantime.
                    self.transition(CircuitState::HalfOpen, CircuitState::Closed);
                }
            }
            CircuitState::Open => {}
        }
    }

    /// Reports that a request failed.
    ///
    /// When `Closed`, this extends the failure streak and trips the breaker once it reaches
    /// `failure_threshold`. When `HalfOpen`, it re-opens the breaker immediately. When
    /// `Open`, it is ignored.
    pub fn record_failure(&self) {
        match self.load_state() {
            CircuitState::Closed => {
                let failures = self.consecutive_failures.fetch_add(1, Ordering::SeqCst) + 1;
                if failures >= u64::from(self.config.failure_threshold) {
                    // Only the first caller to cross the threshold actually trips it and
                    // stamps `opened_at`; later callers find the state is no longer `Closed`.
                    self.transition(CircuitState::Closed, CircuitState::Open);
                }
            }
            CircuitState::HalfOpen => {
                self.transition(CircuitState::HalfOpen, CircuitState::Open);
            }
            CircuitState::Open => {}
        }
    }

    /// Forces the breaker back to `Closed`, clearing both streaks and the open timestamp.
    pub fn reset(&self) {
        let mut opened_at = self.opened_at.lock();
        *opened_at = None;
        self.consecutive_successes.store(0, Ordering::SeqCst);
        self.enter(&mut opened_at, CircuitState::Closed);
    }

    /// Moves the breaker from `from` to `to` under the `opened_at` lock. Does nothing if the
    /// breaker is no longer in `from`, for example because another thread transitioned it
    /// first.
    fn transition(&self, from: CircuitState, to: CircuitState) {
        let mut opened_at = self.opened_at.lock();
        if self.state.load(Ordering::SeqCst) == from as u8 {
            self.enter(&mut opened_at, to);
        }
    }

    /// Prepares the bookkeeping for `to`, then publishes `to` as the current state.
    ///
    /// The caller must hold the `opened_at` lock and passes the guarded value in. The
    /// bookkeeping is written *before* the state store, so any thread that sees the new
    /// state also sees the reset counter or fresh timestamp.
    fn enter(&self, opened_at: &mut Option<Instant>, to: CircuitState) {
        match to {
            CircuitState::Closed => self.consecutive_failures.store(0, Ordering::SeqCst),
            CircuitState::Open => *opened_at = Some(Instant::now()),
            CircuitState::HalfOpen => self.consecutive_successes.store(0, Ordering::SeqCst),
        }
        self.state.store(to as u8, Ordering::SeqCst);
    }

    /// Decodes the atomic state. Only this type writes `state`, so values other than
    /// 0 and 1 can only be 2 (`HalfOpen`).
    fn load_state(&self) -> CircuitState {
        match self.state.load(Ordering::SeqCst) {
            0 => CircuitState::Closed,
            1 => CircuitState::Open,
            _ => CircuitState::HalfOpen,
        }
    }
}
#[derive(Debug, thiserror::Error)]
#[error("Circuit is open — retry after {remaining:?}")]
pub struct CircuitOpenError {
pub remaining: Duration,
}
/// Accumulates a streamed response: visible text plus optional "thinking" output.
#[derive(Debug, Default)]
pub struct PartialResponse {
    /// Visible response text received so far.
    text: String,
    /// Thinking text received so far.
    thinking: String,
    /// Set by any `push_thinking` call (even with an empty delta); cleared by `clear`.
    has_thinking: bool,
}

impl PartialResponse {
    /// Creates an empty accumulator; same as `PartialResponse::default()`.
    pub fn new() -> Self {
        Default::default()
    }

    /// Appends a text delta.
    pub fn push_text(&mut self, delta: &str) {
        self.text += delta;
    }

    /// Appends a thinking delta and marks the response as having thinking content.
    pub fn push_thinking(&mut self, delta: &str) {
        self.thinking += delta;
        self.has_thinking = true;
    }

    /// Moves the accumulated text out, leaving the text buffer empty.
    ///
    /// Thinking content and the `has_thinking` flag are left untouched.
    pub fn take_text(&mut self) -> String {
        std::mem::take(&mut self.text)
    }

    /// Text accumulated so far.
    pub fn text(&self) -> &str {
        self.text.as_str()
    }

    /// Thinking accumulated so far.
    pub fn thinking(&self) -> &str {
        self.thinking.as_str()
    }

    /// Whether `push_thinking` has been called since creation or the last `clear`.
    pub fn has_thinking(&self) -> bool {
        self.has_thinking
    }

    /// True when both the text and thinking buffers are empty. The `has_thinking` flag
    /// is not consulted.
    pub fn is_empty(&self) -> bool {
        [self.text.as_str(), self.thinking.as_str()]
            .iter()
            .all(|buf| buf.is_empty())
    }

    /// Empties both buffers and resets the `has_thinking` flag.
    pub fn clear(&mut self) {
        self.has_thinking = false;
        self.thinking.clear();
        self.text.clear();
    }
}
/// Ordered list of model identifiers.
#[derive(Debug, Clone)]
pub struct FallbackChain {
    /// Model identifiers, in the order they were supplied.
    pub models: Vec<String>,
}

impl Default for FallbackChain {
    /// A single-entry chain containing `openai/gpt-4o-mini`.
    fn default() -> Self {
        Self::new(vec![String::from("openai/gpt-4o-mini")])
    }
}

impl FallbackChain {
    /// Wraps `models` as-is, keeping their order.
    pub fn new(models: Vec<String>) -> Self {
        FallbackChain { models }
    }

    /// Model identifier at `index`, or `None` if the chain is shorter than that.
    pub fn get(&self, index: usize) -> Option<&str> {
        self.models.get(index).map(String::as_str)
    }

    /// True when the chain holds no models.
    pub fn is_empty(&self) -> bool {
        self.models.first().is_none()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn circuit_breaker_allows_when_closed() {
        let cb = CircuitBreaker::new(CircuitBreakerConfig::default());
        assert!(cb.allow_request().is_ok());
    }

    #[test]
    fn circuit_breaker_opens_after_threshold() {
        let config = CircuitBreakerConfig { failure_threshold: 3, ..Default::default() };
        let cb = CircuitBreaker::new(config);
        cb.record_failure();
        cb.record_failure();
        assert!(cb.allow_request().is_ok());
        cb.record_failure();
        assert!(cb.allow_request().is_err());
    }

    #[test]
    fn circuit_breaker_success_resets_failure_streak() {
        let config = CircuitBreakerConfig { failure_threshold: 2, ..Default::default() };
        let cb = CircuitBreaker::new(config);
        cb.record_failure();
        cb.record_success();
        cb.record_failure();
        // The success broke the streak, so two failures in total must not trip it.
        assert!(cb.allow_request().is_ok());
    }

    #[test]
    fn circuit_breaker_open_error_reports_remaining_time() {
        let config = CircuitBreakerConfig {
            failure_threshold: 1,
            open_duration: Duration::from_secs(60),
            ..Default::default()
        };
        let cb = CircuitBreaker::new(config);
        cb.record_failure();
        let err = cb.allow_request().unwrap_err();
        assert!(err.remaining > Duration::ZERO);
        assert!(err.remaining <= Duration::from_secs(60));
        assert!(err.to_string().starts_with("Circuit is open"));
    }

    #[test]
    fn circuit_breaker_half_open_closes_after_required_successes() {
        let config = CircuitBreakerConfig {
            failure_threshold: 1,
            open_duration: Duration::ZERO,
            half_open_successes: 2,
        };
        let cb = CircuitBreaker::new(config);
        cb.record_failure();
        // With a zero open period the next request is let through as a half-open probe.
        assert!(cb.allow_request().is_ok());
        assert!(matches!(cb.load_state(), CircuitState::HalfOpen));
        cb.record_success();
        assert!(matches!(cb.load_state(), CircuitState::HalfOpen));
        cb.record_success();
        assert!(matches!(cb.load_state(), CircuitState::Closed));
    }

    #[test]
    fn circuit_breaker_half_open_failure_reopens() {
        let config = CircuitBreakerConfig {
            failure_threshold: 1,
            open_duration: Duration::ZERO,
            ..Default::default()
        };
        let cb = CircuitBreaker::new(config);
        cb.record_failure();
        assert!(cb.allow_request().is_ok());
        cb.record_failure();
        assert!(matches!(cb.load_state(), CircuitState::Open));
    }

    #[test]
    fn circuit_breaker_failure_streak_restarts_after_recovery() {
        let config = CircuitBreakerConfig {
            failure_threshold: 2,
            open_duration: Duration::ZERO,
            ..Default::default()
        };
        let cb = CircuitBreaker::new(config);
        cb.record_failure();
        cb.record_failure();
        assert!(cb.allow_request().is_ok()); // half-open probe
        cb.record_success(); // closes the breaker
        cb.record_failure();
        // A single failure after recovery must not re-trip a threshold-2 breaker.
        assert!(matches!(cb.load_state(), CircuitState::Closed));
    }

    #[test]
    fn circuit_breaker_resets() {
        let config = CircuitBreakerConfig { failure_threshold: 1, ..Default::default() };
        let cb = CircuitBreaker::new(config);
        cb.record_failure();
        assert!(cb.allow_request().is_err());
        cb.reset();
        assert!(cb.allow_request().is_ok());
    }

    #[test]
    fn partial_response() {
        let mut pr = PartialResponse::new();
        pr.push_text("Hello ");
        pr.push_text("world");
        assert_eq!(pr.text(), "Hello world");
        assert!(!pr.take_text().is_empty());
        assert!(pr.text().is_empty());
    }

    #[test]
    fn partial_response_thinking_and_clear() {
        let mut pr = PartialResponse::new();
        assert!(pr.is_empty());
        assert!(!pr.has_thinking());
        pr.push_thinking("step 1");
        pr.push_thinking(", step 2");
        assert!(pr.has_thinking());
        assert_eq!(pr.thinking(), "step 1, step 2");
        pr.push_text("answer");
        assert_eq!(pr.take_text(), "answer");
        // Taking the text leaves the thinking buffer alone.
        assert_eq!(pr.thinking(), "step 1, step 2");
        pr.clear();
        assert!(pr.is_empty());
        assert!(!pr.has_thinking());
    }

    #[test]
    fn fallback_chain_lookup() {
        let chain = FallbackChain::default();
        assert!(!chain.is_empty());
        assert_eq!(chain.get(0), Some("openai/gpt-4o-mini"));
        assert_eq!(chain.get(1), None);
        let custom = FallbackChain::new(vec!["a".to_string(), "b".to_string()]);
        assert_eq!(custom.get(1), Some("b"));
        assert!(FallbackChain::new(Vec::new()).is_empty());
    }
}