use std::time::{Duration, Instant};
/// Observable state of a [`CircuitBreaker`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BreakerState {
    /// Calls are allowed; failures are counted until a success resets the count.
    Closed,
    /// Calls are rejected until the reset timeout has elapsed since the breaker tripped.
    Open,
    /// The reset timeout has elapsed: calls are allowed again, and the next recorded
    /// result either closes the breaker (success) or trips it again (failure).
    HalfOpen,
}

/// A consecutive-failure circuit breaker.
///
/// The breaker starts [`BreakerState::Closed`]. After `failure_threshold` failures with
/// no intervening success it trips to [`BreakerState::Open`] and rejects calls. Once
/// `reset_timeout` has elapsed since tripping, it reports [`BreakerState::HalfOpen`] and
/// lets calls through again. A recorded success then closes it; a recorded failure trips
/// it again.
///
/// Results recorded while the breaker is still open are ignored. Such results come from
/// calls issued before the breaker tripped, so they must neither close the breaker early
/// nor extend the open period.
///
/// The Open -> HalfOpen transition is evaluated lazily, whenever the state is queried or
/// a result is recorded. Note that [`CircuitBreaker::allow`] does not limit how many
/// calls pass while half-open.
#[derive(Debug)]
pub struct CircuitBreaker {
    /// Failures (with no success in between) that trip the breaker; always at least 1.
    failure_threshold: u32,
    /// How long the breaker stays open before reporting half-open.
    reset_timeout: Duration,
    /// Stored state; may still read `Open` after the timeout until `state()` runs.
    state: BreakerState,
    /// Failures counted while closed; reset to 0 by a recorded success (saturating).
    failures: u32,
    /// When the breaker last tripped; `None` while closed.
    opened_at: Option<Instant>,
}

impl CircuitBreaker {
    /// Creates a closed breaker.
    ///
    /// `failure_threshold` is clamped to at least 1, so a threshold of 0 behaves like 1.
    pub fn new(failure_threshold: u32, reset_timeout: Duration) -> Self {
        Self {
            failure_threshold: failure_threshold.max(1),
            reset_timeout,
            state: BreakerState::Closed,
            failures: 0,
            opened_at: None,
        }
    }

    /// Returns the current state.
    ///
    /// If the breaker is open and `reset_timeout` has elapsed since it tripped, it first
    /// moves to half-open.
    pub fn state(&mut self) -> BreakerState {
        if self.state == BreakerState::Open
            && self
                .opened_at
                .is_some_and(|opened_at| opened_at.elapsed() >= self.reset_timeout)
        {
            self.state = BreakerState::HalfOpen;
        }
        self.state
    }

    /// Returns `true` if a call may proceed, i.e. the breaker is closed or half-open.
    #[must_use]
    pub fn allow(&mut self) -> bool {
        !matches!(self.state(), BreakerState::Open)
    }

    /// Records a successful call.
    ///
    /// When closed or half-open, this clears the failure count and closes the breaker.
    /// It is ignored while the breaker is still open: that success belongs to a call
    /// issued before the breaker tripped, and only the reset timeout may end the open
    /// period.
    pub fn record_success(&mut self) {
        match self.state() {
            BreakerState::Open => {}
            BreakerState::Closed | BreakerState::HalfOpen => {
                self.failures = 0;
                self.opened_at = None;
                self.state = BreakerState::Closed;
            }
        }
    }

    /// Records a failed call.
    ///
    /// - Closed: counts the failure and trips once the threshold is reached.
    /// - Half-open: trips immediately.
    /// - Still open: ignored, so late failures from calls issued before the breaker
    ///   tripped do not extend the open period.
    pub fn record_failure(&mut self) {
        match self.state() {
            BreakerState::Open => {}
            BreakerState::HalfOpen => self.trip(),
            BreakerState::Closed => {
                self.failures = self.failures.saturating_add(1);
                if self.failures >= self.failure_threshold {
                    self.trip();
                }
            }
        }
    }

    /// Moves to `Open` and starts the reset timeout from now.
    fn trip(&mut self) {
        self.state = BreakerState::Open;
        self.opened_at = Some(Instant::now());
    }
}
#[cfg(test)]
mod tests {
    use super::{BreakerState, CircuitBreaker};
    use std::time::Duration;

    /// One failure below the threshold keeps the breaker closed; reaching the
    /// threshold opens it and starts rejecting calls.
    #[test]
    fn opens_after_threshold() {
        let mut cb = CircuitBreaker::new(2, Duration::from_secs(1));
        assert!(cb.allow());

        // Each recorded failure should leave the breaker in the next expected state.
        for expected in [BreakerState::Closed, BreakerState::Open] {
            cb.record_failure();
            assert_eq!(cb.state(), expected);
        }

        assert!(!cb.allow());
    }

    /// With a zero reset timeout, a tripped breaker is immediately half-open,
    /// and a recorded success closes it again.
    #[test]
    fn half_opens_after_reset_timeout() {
        let mut cb = CircuitBreaker::new(1, Duration::ZERO);

        cb.record_failure();
        assert_eq!(cb.state(), BreakerState::HalfOpen);

        cb.record_success();
        assert_eq!(cb.state(), BreakerState::Closed);
    }
}