use std::time::Instant;
use serde::{Deserialize, Serialize};
use tracing::info;
use tracing::warn;
/// Position of a [`SafetyCircuitBreaker`] in its Closed → Open → HalfOpen → Closed cycle.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Serialize, Deserialize)]
#[non_exhaustive]
pub enum CircuitState {
    /// Actions are allowed; violations are being counted against the threshold.
    Closed,
    /// Too many violations were recorded; actions are rejected until the cooldown elapses.
    Open,
    /// The cooldown has elapsed; the next check admits one test action and closes the breaker.
    HalfOpen,
}
/// Sliding-window circuit breaker that blocks actions after too many safety violations.
///
/// Construct with [`SafetyCircuitBreaker::new`], report failures with `record_violation`,
/// and gate each action with `check_allowed`.
#[derive(Debug)]
pub struct SafetyCircuitBreaker {
    /// Current position in the Closed → Open → HalfOpen → Closed cycle.
    pub state: CircuitState,
    /// Number of violations inside the sliding window that trips the breaker open.
    pub threshold: usize,
    /// Seconds after the most recent violation before an open breaker admits a test action.
    pub cooldown_secs: u64,
    /// Instant of the most recently recorded violation, if any.
    last_failure: Option<Instant>,
    /// Width, in seconds, of the sliding window over which violations are counted.
    window_secs: u64,
    /// Instants of recorded violations; pruned to the window whenever a violation is recorded.
    failure_timestamps: Vec<Instant>,
}
impl SafetyCircuitBreaker {
    /// Creates a breaker in [`CircuitState::Closed`] with no recorded violations.
    ///
    /// * `threshold` — violations inside the sliding window needed to trip the breaker to
    ///   [`CircuitState::Open`]. `0` behaves like `1`: the first violation trips it.
    /// * `window_secs` — width of the sliding window, in seconds, over which violations are
    ///   counted. Very large values (e.g. `u64::MAX` for "never expire") are accepted.
    /// * `cooldown_secs` — seconds that must pass after the most recent violation before an
    ///   open breaker lets a test action through.
    #[must_use]
    pub fn new(threshold: usize, window_secs: u64, cooldown_secs: u64) -> Self {
        Self {
            state: CircuitState::Closed,
            threshold,
            cooldown_secs,
            last_failure: None,
            window_secs,
            failure_timestamps: Vec::new(),
        }
    }

    /// Records a safety violation at the current instant.
    ///
    /// Violations older than `window_secs` are dropped. If the number still inside the
    /// window reaches `threshold`, the breaker moves to [`CircuitState::Open`]; a warning is
    /// logged only on the transition itself. Every call also restarts the cooldown, so a
    /// steady stream of violations keeps an open breaker open.
    pub fn record_violation(&mut self) {
        let now = Instant::now();
        self.failure_timestamps.push(now);
        self.last_failure = Some(now);

        // Compare each violation's age against the window instead of computing
        // `now - window`: `Instant - Duration` panics when the resulting point in time
        // cannot be represented, which a large `window_secs` can trigger. The `<=` keeps
        // the original inclusive boundary (`t >= now - window`).
        let window = std::time::Duration::from_secs(self.window_secs);
        self.failure_timestamps.retain(|&t| now.saturating_duration_since(t) <= window);

        if self.failure_timestamps.len() >= self.threshold {
            if self.state != CircuitState::Open {
                warn!(
                    failure_count = self.failure_timestamps.len(),
                    threshold = self.threshold,
                    "Circuit breaker tripped to Open"
                );
            }
            self.state = CircuitState::Open;
        }
    }

    /// Returns whether an action may proceed, advancing the state machine as a side effect.
    ///
    /// * `Closed` — returns `true`.
    /// * `Open` — returns `false`, unless `cooldown_secs` have elapsed since the last recorded
    ///   violation, in which case the breaker moves to `HalfOpen` and is handled as below.
    /// * `HalfOpen` — admits this single call as a test action: the breaker returns to
    ///   `Closed`, the violation history is cleared, and `true` is returned.
    ///
    /// An `Open` breaker with no recorded violation (reachable only by assigning the public
    /// `state` field directly) stays open until [`Self::reset`] is called.
    pub fn check_allowed(&mut self) -> bool {
        if self.state == CircuitState::Open
            && let Some(last) = self.last_failure
            && last.elapsed() >= std::time::Duration::from_secs(self.cooldown_secs)
        {
            info!("Circuit breaker transitioning to HalfOpen");
            self.state = CircuitState::HalfOpen;
        }
        match self.state {
            CircuitState::Closed => true,
            CircuitState::Open => false,
            CircuitState::HalfOpen => {
                info!("Circuit breaker test action allowed, transitioning to Closed");
                self.state = CircuitState::Closed;
                self.failure_timestamps.clear();
                true
            }
        }
    }

    /// Forces the breaker back to [`CircuitState::Closed`] and forgets every recorded violation.
    pub fn reset(&mut self) {
        info!("Circuit breaker force-reset to Closed");
        self.state = CircuitState::Closed;
        self.failure_timestamps.clear();
        self.last_failure = None;
    }
}