use std::sync::Mutex;
use std::sync::atomic::{AtomicU32, Ordering};
use std::time::{Duration, Instant};
use serde::{Deserialize, Serialize};
use tracing::{debug, info, warn};
use crate::models::HttpError;
/// Observable condition of a `CircuitBreaker`, as computed by `CircuitBreaker::state`.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq)]
pub enum CircuitState {
    /// Fewer consecutive failures than the threshold have been recorded; requests pass.
    Closed,
    /// The failure threshold is reached and the cooldown since the last failure is
    /// still running; `CircuitBreaker::check` rejects requests.
    Open,
    /// The failure threshold is reached but the cooldown since the last failure has
    /// elapsed; `CircuitBreaker::check` lets requests through as probes.
    HalfOpen,
}
/// Serializable settings used to build a `CircuitBreaker` via `CircuitBreaker::from_config`.
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
pub struct CircuitBreakerConfig {
    /// Consecutive failures after which the circuit opens.
    pub failure_threshold: u32,
    /// Cooldown in seconds; `CircuitBreaker::from_config` turns this into the
    /// breaker's cooldown, measured from the most recent failure.
    pub reset_timeout_secs: u64,
    /// Probe interval in seconds.
    ///
    /// NOTE(review): `CircuitBreaker::from_config` does not read this field;
    /// confirm whether it is consumed anywhere else.
    pub probe_interval_secs: u64,
}
impl Default for CircuitBreakerConfig {
fn default() -> Self {
Self {
failure_threshold: 5,
reset_timeout_secs: 30,
probe_interval_secs: 30,
}
}
}
/// Per-provider circuit breaker: counts consecutive failures and rejects requests
/// once a threshold is reached, until a cooldown after the last failure has passed.
pub struct CircuitBreaker {
    // --- fixed configuration ---
    /// Name of the upstream provider, used in log fields and error messages.
    provider: String,
    /// Number of consecutive failures at which the circuit opens.
    threshold: u32,
    /// How long after the most recent failure the circuit stays open.
    cooldown: Duration,

    // --- mutable state ---
    /// Failures recorded since the last success or reset.
    failure_count: AtomicU32,
    /// When the most recent failure was recorded; `None` initially and after a reset.
    last_failure: Mutex<Option<Instant>>,
}
impl CircuitBreaker {
pub fn new(provider: impl Into<String>, threshold: u32, cooldown: Duration) -> Self {
Self {
failure_count: AtomicU32::new(0),
threshold,
last_failure: Mutex::new(None),
cooldown,
provider: provider.into(),
}
}
pub fn with_defaults(provider: impl Into<String>) -> Self {
Self::new(provider, 5, Duration::from_secs(30))
}
pub fn from_config(provider: impl Into<String>, config: &CircuitBreakerConfig) -> Self {
Self::new(
provider,
config.failure_threshold,
Duration::from_secs(config.reset_timeout_secs),
)
}
pub fn state(&self) -> CircuitState {
let failures = self.failure_count.load(Ordering::Relaxed);
if failures < self.threshold {
return CircuitState::Closed;
}
let lock = self.last_failure.lock().unwrap_or_else(|e| e.into_inner());
match *lock {
Some(last) if last.elapsed() >= self.cooldown => CircuitState::HalfOpen,
_ => CircuitState::Open,
}
}
pub fn check(&self) -> Result<(), HttpError> {
match self.state() {
CircuitState::Closed => Ok(()),
CircuitState::HalfOpen => {
debug!(
provider = %self.provider,
"Circuit half-open, allowing probe request"
);
Ok(())
}
CircuitState::Open => {
let remaining = {
let lock = self.last_failure.lock().unwrap_or_else(|e| e.into_inner());
lock.map(|last| self.cooldown.saturating_sub(last.elapsed()))
.unwrap_or(self.cooldown)
};
warn!(
provider = %self.provider,
remaining_secs = remaining.as_secs(),
"Circuit open, rejecting request"
);
Err(HttpError::Other(format!(
"Circuit breaker open for provider '{}'. \
Too many consecutive failures ({}). \
Will retry in {}s.",
self.provider,
self.failure_count.load(Ordering::Relaxed),
remaining.as_secs(),
)))
}
}
}
pub fn record_success(&self) {
let prev = self.failure_count.swap(0, Ordering::Relaxed);
if prev >= self.threshold {
info!(
provider = %self.provider,
"Circuit breaker closed after successful probe"
);
}
}
pub fn record_failure(&self) {
let new_count = self.failure_count.fetch_add(1, Ordering::Relaxed) + 1;
{
let mut lock = self.last_failure.lock().unwrap_or_else(|e| e.into_inner());
*lock = Some(Instant::now());
}
if new_count == self.threshold {
warn!(
provider = %self.provider,
threshold = self.threshold,
cooldown_secs = self.cooldown.as_secs(),
"Circuit breaker opened after {} consecutive failures",
self.threshold
);
}
}
pub fn failure_count(&self) -> u32 {
self.failure_count.load(Ordering::Relaxed)
}
pub fn reset(&self) {
self.failure_count.store(0, Ordering::Relaxed);
let mut lock = self.last_failure.lock().unwrap_or_else(|e| e.into_inner());
*lock = None;
}
}
impl std::fmt::Debug for CircuitBreaker {
    /// Shows the provider, the computed state, the live failure count, the threshold
    /// and the cooldown.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Capture the dynamic values up front: the state is computed before the
        // counter is read.
        let current = self.state();
        let failures = self.failure_count.load(Ordering::Relaxed);

        let mut out = f.debug_struct("CircuitBreaker");
        out.field("provider", &self.provider);
        out.field("state", &current);
        out.field("failure_count", &failures);
        out.field("threshold", &self.threshold);
        out.field("cooldown", &self.cooldown);
        out.finish()
    }
}
#[cfg(test)]
#[path = "circuit_breaker_tests.rs"]
mod tests;