use crate::recovery::RecoveryResult;
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
/// Position of a [`CircuitBreaker`] in its state machine.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CircuitState {
    /// Calls pass through; failures are counted inside the sliding window.
    Closed,
    /// Calls are rejected immediately until the reset timeout elapses.
    Open,
    /// Calls are admitted as trials: a success closes the circuit, a failure reopens it.
    HalfOpen,
}
/// Tuning parameters for a [`CircuitBreaker`].
#[derive(Debug, Clone)]
pub struct CircuitBreakerConfig {
    /// Number of failures inside `failure_window_ms` that trips a closed circuit open.
    pub failure_threshold: usize,
    /// Length, in milliseconds, of the sliding window over which failures are counted.
    pub failure_window_ms: u64,
    /// How long, in milliseconds, an open circuit waits before admitting trial calls.
    pub reset_timeout_ms: u64,
}
impl Default for CircuitBreakerConfig {
fn default() -> Self {
Self {
failure_threshold: 5,
failure_window_ms: 60000, reset_timeout_ms: 30000, }
}
}
/// Mutable breaker state; [`CircuitBreaker`] keeps it behind an `Arc<Mutex<_>>`.
struct CircuitBreakerInner {
    /// Current position in the state machine.
    state: CircuitState,
    /// Time of the last state transition or reset; drives the Open -> HalfOpen timeout.
    last_state_change: Instant,
    /// Timestamps of recent failures, pruned to the configured window when a failure is
    /// recorded and cleared when the circuit closes or is reset.
    failures: Vec<Instant>,
    /// Thresholds and timeouts supplied at construction.
    config: CircuitBreakerConfig,
}
/// A named circuit breaker whose state machine lives behind a shared mutex.
pub struct CircuitBreaker {
    /// Lock-protected state machine.
    inner: Arc<Mutex<CircuitBreakerInner>>,
    /// Identifier reported in [`CircuitOpenError`] messages.
    name: String,
}
impl CircuitBreaker {
    /// Creates a breaker called `name` using [`CircuitBreakerConfig::default`].
    pub fn new(name: impl Into<String>) -> Self {
        Self::with_config(name, CircuitBreakerConfig::default())
    }

    /// Creates a breaker called `name` with an explicit configuration.
    ///
    /// The breaker starts in [`CircuitState::Closed`] with an empty failure history.
    pub fn with_config(name: impl Into<String>, config: CircuitBreakerConfig) -> Self {
        Self {
            name: name.into(),
            inner: Arc::new(Mutex::new(CircuitBreakerInner {
                config,
                state: CircuitState::Closed,
                failures: Vec::new(),
                last_state_change: Instant::now(),
            })),
        }
    }

    /// Returns the stored state.
    ///
    /// The `Open` -> `HalfOpen` transition is applied lazily by [`CircuitBreaker::execute`].
    /// This method can therefore keep reporting `Open` after the reset timeout has
    /// elapsed, until the next `execute` call.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex has been poisoned.
    pub fn state(&self) -> CircuitState {
        self.inner.lock().unwrap().state
    }

    /// Returns the name this breaker was created with.
    pub fn name(&self) -> &str {
        &self.name
    }

    /// Runs `f` unless the circuit is open, and records its outcome.
    ///
    /// # Errors
    ///
    /// When the circuit is open, returns a boxed [`CircuitOpenError`] without calling `f`.
    /// Otherwise, if `f` fails, returns `f`'s own error, boxed.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex has been poisoned.
    pub fn execute<F, T, E>(&self, f: F) -> RecoveryResult<T>
    where
        F: FnOnce() -> Result<T, E>,
        E: std::error::Error + Send + Sync + 'static,
    {
        let can_proceed = {
            let mut inner = self.inner.lock().unwrap();
            self.update_state(&mut inner);
            inner.state != CircuitState::Open
        };
        if !can_proceed {
            return Err(Box::new(CircuitOpenError::new(&self.name)));
        }
        // The lock is not held while `f` runs, so a slow call does not block other callers.
        // NOTE(review): as a consequence, every caller is admitted while HalfOpen rather
        // than a single probe. Also, an outcome is applied to whatever state the breaker
        // is in when `f` returns. Confirm this is the intended policy.
        match f() {
            Ok(value) => {
                self.on_success();
                Ok(value)
            }
            Err(err) => {
                self.on_failure();
                Err(Box::new(err))
            }
        }
    }

    /// Forces the breaker back to `Closed` and forgets all recorded failures.
    pub fn reset(&self) {
        let mut inner = self.inner.lock().unwrap();
        inner.state = CircuitState::Closed;
        inner.failures.clear();
        inner.last_state_change = Instant::now();
    }

    /// Records a successful call.
    ///
    /// A success while `HalfOpen` closes the circuit and clears the failure history.
    /// In any other state this is a no-op.
    fn on_success(&self) {
        let mut inner = self.inner.lock().unwrap();
        if inner.state == CircuitState::HalfOpen {
            inner.state = CircuitState::Closed;
            inner.failures.clear();
            inner.last_state_change = Instant::now();
        }
    }

    /// Records a failed call.
    ///
    /// A failure while `HalfOpen` reopens the circuit immediately. Otherwise the failure
    /// is added to the sliding window, and a `Closed` circuit opens once the number of
    /// in-window failures reaches `failure_threshold`.
    fn on_failure(&self) {
        let mut inner = self.inner.lock().unwrap();
        let now = Instant::now();
        if inner.state == CircuitState::HalfOpen {
            inner.state = CircuitState::Open;
            inner.last_state_change = now;
            return;
        }
        inner.failures.push(now);
        // Keep only failures no older than the window. Compare ages with the saturating
        // `duration_since` instead of computing `now - window`: `Instant - Duration`
        // panics when the result is not representable by the platform's `Instant`
        // (e.g. shortly after boot, or with a very large window). Here that panic would
        // also poison the mutex and break every later call.
        let window = Duration::from_millis(inner.config.failure_window_ms);
        inner.failures.retain(|&t| now.duration_since(t) <= window);
        if inner.state == CircuitState::Closed
            && inner.failures.len() >= inner.config.failure_threshold
        {
            inner.state = CircuitState::Open;
            inner.last_state_change = now;
        }
    }

    /// Moves an `Open` circuit to `HalfOpen` once `reset_timeout_ms` has elapsed since
    /// the last state change. Leaves other states untouched.
    fn update_state(&self, inner: &mut CircuitBreakerInner) {
        if inner.state == CircuitState::Open {
            let now = Instant::now();
            let elapsed = now.duration_since(inner.last_state_change);
            if elapsed >= Duration::from_millis(inner.config.reset_timeout_ms) {
                inner.state = CircuitState::HalfOpen;
                inner.last_state_change = now;
            }
        }
    }
}
/// Error produced by [`CircuitBreaker::execute`] when the circuit is open and the
/// wrapped call was skipped.
#[derive(Debug)]
pub struct CircuitOpenError {
    /// Name of the breaker that rejected the call.
    circuit_name: String,
}
impl CircuitOpenError {
    /// Builds the error for the breaker called `circuit_name`.
    fn new(circuit_name: &str) -> Self {
        Self {
            circuit_name: circuit_name.to_owned(),
        }
    }

    /// Returns the name of the breaker that rejected the call.
    ///
    /// Lets callers that downcast a boxed error identify the tripped circuit without
    /// parsing the `Display` message.
    pub fn circuit_name(&self) -> &str {
        &self.circuit_name
    }
}
impl std::fmt::Display for CircuitOpenError {
    /// Writes a human-readable message naming the open circuit.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { circuit_name } = self;
        write!(f, "Circuit '{}' is open, failing fast", circuit_name)
    }
}
impl std::error::Error for CircuitOpenError {}