use std::sync::Arc;
use std::sync::Mutex;
use std::time::Duration;
use tokio::time::Instant;
use super::cache::{Clock, TokioClock};
/// The three states a circuit breaker can be in.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CircuitState {
    /// Normal operation: calls are allowed and failures are being counted.
    Closed,
    /// Tripped: calls are rejected until the reset timeout has elapsed.
    Open,
    /// Probing: calls are allowed again; enough successes close the circuit,
    /// while a single failure reopens it.
    HalfOpen,
}
/// Tunable thresholds and timeouts for a circuit breaker.
#[derive(Clone, Debug)]
pub struct CircuitBreakerConfig {
    /// Number of recent failures (recorded while closed) at which the circuit opens.
    pub failure_threshold: u32,
    /// How long a recorded failure keeps counting toward `failure_threshold`;
    /// failures at least this old are discarded.
    pub failure_window: Duration,
    /// Number of successes recorded while half-open that closes the circuit again.
    pub success_threshold: u32,
    /// How long the circuit stays open before it becomes half-open.
    pub reset_timeout: Duration,
}

impl Default for CircuitBreakerConfig {
    /// Defaults: open after 5 failures within 60 seconds, become half-open
    /// after 30 seconds, close again after 2 half-open successes.
    fn default() -> Self {
        let failure_window = Duration::from_secs(60);
        let reset_timeout = Duration::from_secs(30);
        CircuitBreakerConfig {
            failure_threshold: 5,
            failure_window,
            success_threshold: 2,
            reset_timeout,
        }
    }
}
/// Circuit breaker that tracks recent failures and moves between the
/// `Closed`, `Open` and `HalfOpen` states.
///
/// The mutable bookkeeping is kept behind a `Mutex`, and the current time is
/// obtained from an injected `Clock`.
pub struct CircuitBreaker {
// Current state plus failure/success bookkeeping, guarded by the mutex.
inner: Mutex<CircuitInner>,
// Thresholds and timeouts supplied at construction.
config: CircuitBreakerConfig,
// Time source used to obtain the current instant.
clock: Arc<dyn Clock>,
}
/// Mutable bookkeeping of a `CircuitBreaker`, stored inside its mutex.
#[derive(Debug)]
struct CircuitInner {
// State as of the last operation that touched the breaker.
state: CircuitState,
// Timestamps of failures recorded while closed; entries that fall outside
// the configured failure window are pruned.
failures: Vec<Instant>,
// Successes recorded since the circuit last became half-open.
half_open_successes: u32,
// When the circuit was last opened (or a failure was last recorded while
// open); `None` after the circuit is closed or reset.
opened_at: Option<Instant>,
}
impl CircuitBreaker {
    /// Creates a breaker in the `Closed` state that reads time from a `TokioClock`.
    pub fn new(config: CircuitBreakerConfig) -> Self {
        Self::with_clock(config, Arc::new(TokioClock))
    }

    /// Creates a breaker in the `Closed` state that reads time from `clock`.
    pub fn with_clock(config: CircuitBreakerConfig, clock: Arc<dyn Clock>) -> Self {
        Self {
            inner: Mutex::new(CircuitInner {
                state: CircuitState::Closed,
                failures: Vec::new(),
                half_open_successes: 0,
                opened_at: None,
            }),
            config,
            clock,
        }
    }

    /// Returns the configuration this breaker was built with.
    pub fn config(&self) -> &CircuitBreakerConfig {
        &self.config
    }

    /// Returns the current state, after applying any time-based transition
    /// that is due (for example `Open` -> `HalfOpen`).
    ///
    /// Reports `Open` if the internal mutex is poisoned.
    pub fn current_state(&self) -> CircuitState {
        self.lock_and_tick()
            .map_or(CircuitState::Open, |(inner, _)| inner.state)
    }

    /// Returns `true` when a call may be attempted, i.e. the circuit is
    /// `Closed` or `HalfOpen` after applying any due time-based transition.
    ///
    /// Returns `false` if the internal mutex is poisoned.
    pub fn allow_call(&self) -> bool {
        self.lock_and_tick()
            .map_or(false, |(inner, _)| !matches!(inner.state, CircuitState::Open))
    }

    /// Records a successful call.
    ///
    /// * `Closed`: forgets all recorded failures.
    /// * `HalfOpen`: counts the success and closes the circuit once
    ///   `success_threshold` successes have been recorded.
    /// * `Open`: ignored.
    ///
    /// Does nothing if the internal mutex is poisoned.
    pub fn record_success(&self) {
        let Some((mut inner, _)) = self.lock_and_tick() else {
            return;
        };
        match inner.state {
            CircuitState::Closed => inner.failures.clear(),
            CircuitState::HalfOpen => {
                inner.half_open_successes = inner.half_open_successes.saturating_add(1);
                if inner.half_open_successes >= self.config.success_threshold {
                    Self::close(&mut inner);
                }
            }
            CircuitState::Open => {}
        }
    }

    /// Records a failed call.
    ///
    /// * `Closed`: stores the failure and opens the circuit once
    ///   `failure_threshold` failures lie within `failure_window`.
    /// * `HalfOpen`: reopens the circuit immediately.
    /// * `Open`: restarts the reset timeout from now.
    ///
    /// Does nothing if the internal mutex is poisoned.
    pub fn record_failure(&self) {
        let Some((mut inner, now)) = self.lock_and_tick() else {
            return;
        };
        match inner.state {
            CircuitState::Closed => {
                inner.failures.push(now);
                self.drop_stale_failures(&mut inner, now);
                // Saturate instead of truncating: a count that does not fit in
                // `u32` necessarily meets every possible `u32` threshold.
                let recent = u32::try_from(inner.failures.len()).unwrap_or(u32::MAX);
                if recent >= self.config.failure_threshold {
                    Self::trip(&mut inner, now);
                }
            }
            CircuitState::HalfOpen => Self::trip(&mut inner, now),
            CircuitState::Open => {
                // A failure reported while open pushes the half-open probe
                // further into the future.
                inner.opened_at = Some(now);
            }
        }
    }

    /// Forces the breaker back to `Closed` and discards all bookkeeping.
    ///
    /// Does nothing if the internal mutex is poisoned.
    pub fn reset(&self) {
        if let Ok(mut inner) = self.inner.lock() {
            Self::close(&mut inner);
        }
    }

    /// Reads the clock, locks the state and applies due time-based transitions.
    ///
    /// Returns the guard together with the instant that was read, or `None`
    /// if the mutex is poisoned.
    fn lock_and_tick(&self) -> Option<(std::sync::MutexGuard<'_, CircuitInner>, Instant)> {
        // Read the clock first so the mutex is not held while calling into
        // the `Clock` implementation.
        let now = self.clock.now();
        let mut inner = self.inner.lock().ok()?;
        self.tick(&mut inner, now);
        Some((inner, now))
    }

    /// Puts `inner` into the `Closed` state with empty bookkeeping.
    fn close(inner: &mut CircuitInner) {
        inner.state = CircuitState::Closed;
        inner.failures.clear();
        inner.half_open_successes = 0;
        inner.opened_at = None;
    }

    /// Puts `inner` into the `Open` state, starting the reset timeout at `now`.
    fn trip(inner: &mut CircuitInner, now: Instant) {
        inner.state = CircuitState::Open;
        inner.opened_at = Some(now);
        // Failures are only stored while closed and the success counter only
        // advances while half-open, so clearing both here is a no-op for
        // whichever one does not apply to the state being left.
        inner.failures.clear();
        inner.half_open_successes = 0;
    }

    /// Applies time-based transitions: `Open` becomes `HalfOpen` once
    /// `reset_timeout` has elapsed, and stale failures are pruned while `Closed`.
    fn tick(&self, inner: &mut CircuitInner, now: Instant) {
        match inner.state {
            CircuitState::Open => {
                if let Some(opened) = inner.opened_at {
                    if now.duration_since(opened) >= self.config.reset_timeout {
                        inner.state = CircuitState::HalfOpen;
                        inner.half_open_successes = 0;
                    }
                }
            }
            CircuitState::Closed => self.drop_stale_failures(inner, now),
            CircuitState::HalfOpen => {}
        }
    }

    /// Keeps only failures strictly younger than `failure_window`.
    fn drop_stale_failures(&self, inner: &mut CircuitInner, now: Instant) {
        let window = self.config.failure_window;
        inner.failures.retain(|ts| now.duration_since(*ts) < window);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a config with the given failure threshold, a 60 s failure
    /// window, a success threshold of 2 and a 10 s reset timeout.
    fn config(failure_threshold: u32) -> CircuitBreakerConfig {
        let failure_window = Duration::from_secs(60);
        let reset_timeout = Duration::from_secs(10);
        CircuitBreakerConfig {
            failure_threshold,
            failure_window,
            success_threshold: 2,
            reset_timeout,
        }
    }

    /// Records `times` consecutive failures on `breaker`.
    fn fail_times(breaker: &CircuitBreaker, times: usize) {
        for _ in 0..times {
            breaker.record_failure();
        }
    }

    /// Trips a threshold-2 breaker, waits past the reset timeout and checks
    /// that it is half-open before handing it back.
    async fn half_open_breaker() -> CircuitBreaker {
        let breaker = CircuitBreaker::new(config(2));
        fail_times(&breaker, 2);
        tokio::time::advance(Duration::from_secs(11)).await;
        assert_eq!(breaker.current_state(), CircuitState::HalfOpen);
        breaker
    }

    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn starts_closed() {
        let breaker = CircuitBreaker::new(config(5));
        assert_eq!(breaker.current_state(), CircuitState::Closed);
        assert!(breaker.allow_call());
    }

    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn opens_after_threshold_failures() {
        let breaker = CircuitBreaker::new(config(3));
        fail_times(&breaker, 2);
        assert_eq!(breaker.current_state(), CircuitState::Closed);
        fail_times(&breaker, 1);
        assert_eq!(breaker.current_state(), CircuitState::Open);
        assert!(!breaker.allow_call());
    }

    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn transitions_to_half_open_after_reset_timeout() {
        let breaker = CircuitBreaker::new(config(2));
        fail_times(&breaker, 2);
        assert_eq!(breaker.current_state(), CircuitState::Open);
        tokio::time::advance(Duration::from_secs(11)).await;
        assert_eq!(breaker.current_state(), CircuitState::HalfOpen);
        assert!(breaker.allow_call());
    }

    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn half_open_closes_after_success_threshold() {
        let breaker = half_open_breaker().await;
        breaker.record_success();
        breaker.record_success();
        assert_eq!(breaker.current_state(), CircuitState::Closed);
    }

    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn half_open_failure_reopens() {
        let breaker = half_open_breaker().await;
        breaker.record_failure();
        assert_eq!(breaker.current_state(), CircuitState::Open);
    }

    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn stale_failures_are_forgotten() {
        let short_window = CircuitBreakerConfig {
            failure_threshold: 3,
            failure_window: Duration::from_secs(5),
            success_threshold: 1,
            reset_timeout: Duration::from_secs(10),
        };
        let breaker = CircuitBreaker::new(short_window);
        fail_times(&breaker, 2);
        // Move past the 5 s window so the first two failures no longer count.
        tokio::time::advance(Duration::from_secs(6)).await;
        fail_times(&breaker, 2);
        assert_eq!(breaker.current_state(), CircuitState::Closed);
    }
}