use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::time::{Duration, SystemTime, UNIX_EPOCH};
/// Failure-counting circuit breaker.
///
/// The breaker opens once `threshold` failures have been recorded without an
/// intervening success, and reports itself closed again once `timeout` has
/// elapsed since the most recent failure. All mutable state lives in atomics,
/// so every method takes `&self` and the breaker can be shared between threads
/// (for example behind an `Arc`). Fields are updated individually rather than
/// as one transaction, so concurrent callers may observe intermediate states.
pub struct CircuitBreaker {
    /// Failures recorded since the last success, timeout expiry, or `reset`.
    failures: AtomicU64,
    /// Successes recorded since construction or the last `reset`.
    successes: AtomicU64,
    /// Failure count at which the breaker opens.
    threshold: u64,
    /// How long the breaker stays open after the most recent failure.
    timeout: Duration,
    /// Wall-clock time of the most recent failure, in nanoseconds since the
    /// Unix epoch (0 until the first failure and after `reset`).
    last_failure: AtomicU64,
    /// Whether the breaker is currently tripped.
    is_open: AtomicBool,
}

impl CircuitBreaker {
    /// Creates a closed breaker that opens after `threshold` failures and stays
    /// open for `timeout` after the most recent failure.
    ///
    /// A `threshold` of 0 or 1 opens the breaker on the first failure.
    /// Sub-second timeouts are honoured.
    pub fn new(threshold: u64, timeout: Duration) -> Self {
        Self {
            failures: AtomicU64::new(0),
            successes: AtomicU64::new(0),
            threshold,
            timeout,
            last_failure: AtomicU64::new(0),
            is_open: AtomicBool::new(false),
        }
    }

    /// Current wall-clock time in nanoseconds since the Unix epoch.
    ///
    /// Yields 0 instead of panicking if the system clock reports a time before
    /// the epoch, and saturates at `u64::MAX` (around the year 2554).
    fn now_nanos() -> u64 {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            // Clamp before casting so the `as` conversion cannot truncate.
            .map_or(0, |since| since.as_nanos().min(u128::from(u64::MAX)) as u64)
    }

    /// Returns `true` while the breaker is tripped and `timeout` has not yet
    /// elapsed since the most recent failure.
    ///
    /// Side effect: once the timeout has elapsed, this closes the breaker and
    /// clears the failure count, so later calls return `false`.
    pub fn is_open(&self) -> bool {
        // SeqCst pairs with the stores in `record_failure`: a thread that sees
        // the open flag also sees the timestamp written before it.
        if !self.is_open.load(Ordering::SeqCst) {
            return false;
        }
        let last_failure = self.last_failure.load(Ordering::SeqCst);
        // Measure at nanosecond resolution. Truncating both sides to whole
        // seconds turned every sub-second timeout into zero (so the breaker
        // never stayed open) and rounded longer timeouts by up to a second.
        // NOTE: this is wall-clock time; if the clock steps backwards,
        // `saturating_sub` yields zero and the breaker stays open longer.
        let elapsed = Duration::from_nanos(Self::now_nanos().saturating_sub(last_failure));
        if elapsed >= self.timeout {
            self.try_reset();
            false
        } else {
            true
        }
    }

    /// Records a successful call. This bumps the success count, clears the
    /// failure count and closes the breaker, even if its timeout has not
    /// elapsed yet.
    pub fn record_success(&self) {
        self.successes.fetch_add(1, Ordering::SeqCst);
        self.failures.store(0, Ordering::SeqCst);
        self.is_open.store(false, Ordering::SeqCst);
    }

    /// Records a failed call, stamps its time, and opens the breaker once the
    /// failure count reaches `threshold`.
    pub fn record_failure(&self) {
        let failures = self.failures.fetch_add(1, Ordering::SeqCst) + 1;
        // The timestamp is stored before the open flag so that readers which
        // observe `is_open == true` never pair it with an older timestamp.
        self.last_failure.store(Self::now_nanos(), Ordering::SeqCst);
        if failures >= self.threshold {
            self.is_open.store(true, Ordering::SeqCst);
        }
    }

    /// Closes the breaker after its timeout has expired.
    ///
    /// Unlike `reset`, this keeps the success count and last-failure time.
    fn try_reset(&self) {
        self.is_open.store(false, Ordering::SeqCst);
        self.failures.store(0, Ordering::SeqCst);
    }

    /// Restores the breaker to its freshly constructed state.
    pub fn reset(&self) {
        self.failures.store(0, Ordering::SeqCst);
        self.successes.store(0, Ordering::SeqCst);
        self.is_open.store(false, Ordering::SeqCst);
        self.last_failure.store(0, Ordering::SeqCst);
    }

    /// Classifies the breaker's current state.
    ///
    /// This calls [`CircuitBreaker::is_open`], so an expired breaker is closed
    /// as a side effect. `HalfOpen` means "not open, but at least one failure
    /// recorded since the last success or reset".
    pub fn get_state(&self) -> CircuitState {
        if self.is_open() {
            CircuitState::Open
        } else if self.failures.load(Ordering::SeqCst) > 0 {
            CircuitState::HalfOpen
        } else {
            CircuitState::Closed
        }
    }

    /// Returns a snapshot of the counters and state.
    pub fn get_stats(&self) -> CircuitStats {
        // Resolve the state first. `get_state` may close an expired breaker and
        // clear `failures`, so the counters must be read afterwards to agree
        // with it. `is_open` is derived from that single evaluation.
        let state = self.get_state();
        CircuitStats {
            failures: self.failures.load(Ordering::SeqCst),
            successes: self.successes.load(Ordering::SeqCst),
            is_open: state == CircuitState::Open,
            state,
        }
    }
}

/// Coarse state reported by [`CircuitBreaker::get_state`].
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub enum CircuitState {
    /// Not open, and no failures recorded since the last success or reset.
    Closed,
    /// Not open, but at least one failure recorded since the last success or reset.
    HalfOpen,
    /// Tripped, and the timeout has not elapsed since the most recent failure.
    Open,
}

/// Point-in-time snapshot returned by [`CircuitBreaker::get_stats`].
#[derive(Clone, PartialEq, Eq, Debug)]
pub struct CircuitStats {
    /// Failures recorded since the last success, timeout expiry, or reset.
    pub failures: u64,
    /// Successes recorded since construction or the last reset.
    pub successes: u64,
    /// Whether `state` is [`CircuitState::Open`].
    pub is_open: bool,
    /// State at the moment of the snapshot.
    pub state: CircuitState,
}
/// Decorates an `AIProvider` so that its completions are gated by a
/// [`CircuitBreaker`].
pub struct CircuitBreakerAIProvider {
    /// Provider whose calls are guarded.
    inner: Box<dyn crate::common::traits::AIProvider>,
    /// Breaker fed with the outcome of every forwarded completion.
    circuit_breaker: std::sync::Arc<CircuitBreaker>,
}

impl CircuitBreakerAIProvider {
    /// Wraps `inner` with a new breaker built from `threshold` and `timeout`
    /// (see [`CircuitBreaker::new`]).
    pub fn new(
        inner: Box<dyn crate::common::traits::AIProvider>,
        threshold: u64,
        timeout: Duration,
    ) -> Self {
        Self {
            inner,
            circuit_breaker: std::sync::Arc::new(CircuitBreaker::new(threshold, timeout)),
        }
    }

    /// Forwards `prompt` to the wrapped provider unless the breaker is open.
    ///
    /// The inner call's outcome is recorded on the breaker: `Ok` counts as a
    /// success and `Err` as a failure. The inner result is returned unchanged.
    ///
    /// # Errors
    ///
    /// While the breaker is open, this returns
    /// `LocalModelError::ConfigurationError` without calling the inner
    /// provider. Otherwise it propagates any error from the inner `complete`.
    pub async fn complete_with_circuit_breaker(
        &self,
        prompt: &str,
    ) -> Result<String, crate::errors::LocalModelError> {
        let breaker = &self.circuit_breaker;
        if breaker.is_open() {
            // NOTE(review): a "breaker open" condition is transient, not a
            // configuration problem; a dedicated error variant would let callers
            // distinguish it. Kept as-is to preserve the existing contract.
            return Err(crate::errors::LocalModelError::ConfigurationError {
                message: "Circuit breaker is open - too many failures".to_string(),
            });
        }

        let outcome = self.inner.complete(prompt).await;
        if outcome.is_ok() {
            breaker.record_success();
        } else {
            breaker.record_failure();
        }
        outcome
    }
}
#[cfg(test)]
mod tests {
    // `tokio` is reachable through the extern prelude (this file uses `async fn`,
    // so edition >= 2018); a separate `use tokio;` is redundant.
    use super::*;

    /// Walks the breaker through Closed -> Open -> (timeout) -> HalfOpen -> Closed.
    #[tokio::test]
    async fn test_circuit_breaker() {
        let circuit_breaker = CircuitBreaker::new(3, Duration::from_millis(100));
        assert!(!circuit_breaker.is_open());
        assert_eq!(circuit_breaker.get_state(), CircuitState::Closed);

        for _ in 0..3 {
            circuit_breaker.record_failure();
        }
        assert!(circuit_breaker.is_open());
        assert_eq!(circuit_breaker.get_state(), CircuitState::Open);

        tokio::time::sleep(Duration::from_millis(101)).await;
        assert!(!circuit_breaker.is_open());

        circuit_breaker.record_failure();
        circuit_breaker.record_failure();
        assert_eq!(circuit_breaker.get_state(), CircuitState::HalfOpen);

        circuit_breaker.record_success();
        assert_eq!(circuit_breaker.get_state(), CircuitState::Closed);
    }

    /// With a long timeout the breaker stays open; stats track success and reset.
    #[test]
    fn test_stats_while_open_and_after_reset() {
        let breaker = CircuitBreaker::new(2, Duration::from_secs(60));
        breaker.record_failure();
        breaker.record_failure();

        let stats = breaker.get_stats();
        assert_eq!(stats.failures, 2);
        assert_eq!(stats.successes, 0);
        assert!(stats.is_open);
        assert_eq!(stats.state, CircuitState::Open);

        breaker.record_success();
        let stats = breaker.get_stats();
        assert_eq!(stats.failures, 0);
        assert_eq!(stats.successes, 1);
        assert!(!stats.is_open);
        assert_eq!(stats.state, CircuitState::Closed);

        breaker.reset();
        assert_eq!(breaker.get_stats().successes, 0);
    }
}