use std::collections::VecDeque;
use std::time::{Duration, Instant};
/// Bounded, time-limited log of recent call outcomes.
///
/// Each entry pairs the instant it was recorded with a flag that is `true`
/// for a retry and `false` for a success. Entries whose age exceeds `window`
/// are purged lazily on the next [`TokenWindow::record`]. At most `capacity`
/// entries are retained, and the oldest is dropped first. A capacity of `0`
/// still keeps the single newest entry.
struct TokenWindow {
    /// Outcomes in recording order, oldest at the front.
    events: VecDeque<(Instant, bool)>,
    /// Entries older than this (strictly) are evicted.
    window: Duration,
    /// Number of entries at which the oldest is dropped to make room.
    capacity: usize,
}
impl TokenWindow {
    /// Creates an empty log with the given size bound and maximum entry age.
    fn new(capacity: usize, window: Duration) -> Self {
        let events = VecDeque::new();
        Self {
            capacity,
            window,
            events,
        }
    }
    /// Appends an outcome stamped with the current time.
    ///
    /// Expired entries are purged first. If the log is still full, the oldest
    /// surviving entry is dropped to make room for the new one.
    fn record(&mut self, is_retry: bool) {
        let stamp = Instant::now();
        self.evict_old(stamp);
        let full = self.events.len() >= self.capacity;
        if full {
            self.events.pop_front();
        }
        self.events.push_back((stamp, is_retry));
    }
    /// Fraction of retained entries that are retries, or `0.0` when empty.
    ///
    /// This does not evict anything. It reflects the log as of the last `record`.
    fn retry_ratio(&self) -> f64 {
        match self.events.len() {
            0 => 0.0,
            total => {
                let retries = self.events.iter().filter(|entry| entry.1).count();
                retries as f64 / total as f64
            }
        }
    }
    /// Pops entries off the front while their age at `now` exceeds `window`.
    ///
    /// Entries are stored in recording order, so it stops at the first entry
    /// that is still young enough.
    fn evict_old(&mut self, now: Instant) {
        while self
            .events
            .front()
            .is_some_and(|&(stamp, _)| now.duration_since(stamp) > self.window)
        {
            self.events.pop_front();
        }
    }
}
/// Observable state of a `CircuitBreaker`.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum CircuitState {
    /// Normal operation. Outcomes are tracked, and the breaker trips to
    /// `Open` once the windowed retry ratio reaches its threshold.
    Closed,
    /// Probe phase, entered once the open period has elapsed. Enough
    /// recorded successes close the breaker, and any retry reopens it.
    HalfOpen,
    /// Tripped. The breaker stays here until its configured open period has
    /// elapsed.
    Open,
}
/// Circuit breaker that trips on a high ratio of retries to total outcomes.
///
/// Callers report outcomes through [`CircuitBreaker::record_success`] and
/// [`CircuitBreaker::record_retry`]. Outcomes are kept in a bounded,
/// time-limited sliding window.
///
/// - While `Closed`, the breaker opens once the windowed retry ratio reaches
///   `threshold_pct` percent. With few samples a single retry can trip it: a
///   retry as the first outcome gives a ratio of 1.0.
/// - After `open_duration` the breaker becomes `HalfOpen`.
/// - While `HalfOpen`, `close_successes` successes close it again, and any
///   retry reopens it.
///
/// The `Open` → `HalfOpen` promotion is lazy. It happens the next time
/// [`CircuitBreaker::state`] or a `record_*` method is called.
pub struct CircuitBreaker {
    /// Current state. `Open` is promoted to `HalfOpen` lazily by `refresh`.
    state: CircuitState,
    /// Sliding window of recent outcomes (`true` = retry).
    window: TokenWindow,
    /// Retry percentage at or above which a `Closed` breaker trips.
    threshold_pct: u8,
    /// How long the breaker stays `Open` before half-opening.
    open_duration: Duration,
    /// Successes required while `HalfOpen` to close the breaker.
    close_successes: u32,
    /// When the breaker last tripped; `None` after it closes.
    opened_at: Option<Instant>,
    /// Successes counted since the breaker last entered `HalfOpen`.
    half_open_successes: u32,
}
impl CircuitBreaker {
    /// Creates a breaker in the `Closed` state.
    ///
    /// * `capacity`: maximum number of outcomes retained in the sliding window.
    /// * `window`: outcomes older than this are dropped from the window.
    /// * `threshold_pct`: retry percentage at or above which a closed breaker
    ///   opens. `0` trips on the first retry; values above `100` never trip.
    /// * `open_duration`: time spent `Open` before moving to `HalfOpen`.
    /// * `close_successes`: successes needed while `HalfOpen` to close.
    pub fn new(
        capacity: usize,
        window: Duration,
        threshold_pct: u8,
        open_duration: Duration,
        close_successes: u32,
    ) -> Self {
        Self {
            state: CircuitState::Closed,
            window: TokenWindow::new(capacity, window),
            threshold_pct,
            open_duration,
            close_successes,
            opened_at: None,
            half_open_successes: 0,
        }
    }
    /// Returns the current state.
    ///
    /// If the breaker is `Open` and `open_duration` has elapsed since it
    /// tripped, it is first promoted to `HalfOpen`.
    pub fn state(&mut self) -> CircuitState {
        self.refresh(Instant::now());
        self.state
    }
    /// Records a successful outcome.
    ///
    /// While `HalfOpen`, each success counts toward `close_successes`. Reaching
    /// that count closes the breaker and clears the outcome window.
    pub fn record_success(&mut self) {
        // Apply a pending Open -> HalfOpen promotion first, so successes that
        // arrive after the open period are counted even if `state()` was not
        // called in between.
        self.refresh(Instant::now());
        self.window.record(false);
        if self.state == CircuitState::HalfOpen {
            self.half_open_successes = self.half_open_successes.saturating_add(1);
            if self.half_open_successes >= self.close_successes {
                self.close();
            }
        }
    }
    /// Records an outcome that needed a retry.
    ///
    /// While `Closed`, this trips the breaker if the windowed retry ratio is
    /// now at or above the threshold. While `HalfOpen`, any retry reopens the
    /// breaker and restarts the open timer. While `Open`, it is only recorded.
    pub fn record_retry(&mut self) {
        let now = Instant::now();
        self.refresh(now);
        self.window.record(true);
        match self.state {
            CircuitState::Closed => {
                let threshold = f64::from(self.threshold_pct) / 100.0;
                if self.window.retry_ratio() >= threshold {
                    self.trip(now);
                }
            }
            CircuitState::HalfOpen => self.trip(now),
            // Already open: the timer keeps running from the original trip.
            CircuitState::Open => {}
        }
    }
    /// Promotes `Open` to `HalfOpen` once `open_duration` has elapsed at `now`.
    fn refresh(&mut self, now: Instant) {
        let expired = self.state == CircuitState::Open
            && self
                .opened_at
                .is_some_and(|opened| now.duration_since(opened) >= self.open_duration);
        if expired {
            self.state = CircuitState::HalfOpen;
            self.half_open_successes = 0;
        }
    }
    /// Moves to `Open` and starts the open timer at `now`.
    fn trip(&mut self, now: Instant) {
        self.state = CircuitState::Open;
        self.opened_at = Some(now);
    }
    /// Moves to `Closed` and discards the outcome window.
    ///
    /// Without clearing, the retries that caused the trip would still be in
    /// the window, and the very next retry would reopen the breaker.
    fn close(&mut self) {
        self.state = CircuitState::Closed;
        self.opened_at = None;
        self.window.events.clear();
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Sliding window wide enough that no outcome expires during a test.
    const WINDOW: Duration = Duration::from_secs(60);
    /// Open period long enough that an `Open` assertion made right after
    /// tripping can never race the half-open timer.
    const LONG_OPEN: Duration = Duration::from_secs(60);
    /// Open period short enough to sleep past. It is still generous enough
    /// that an `Open` check immediately after tripping is not timing-sensitive.
    /// The old 1ms value flipped to `HalfOpen` on any scheduler hiccup.
    const SHORT_OPEN: Duration = Duration::from_millis(50);
    /// Sleep comfortably longer than `SHORT_OPEN`.
    const PAST_SHORT_OPEN: Duration = Duration::from_millis(75);

    /// Builds the breaker used by every test: a 256-entry window, a 50%
    /// threshold, and 4 half-open successes to close.
    fn make_breaker(open_duration: Duration) -> CircuitBreaker {
        CircuitBreaker::new(256, WINDOW, 50, open_duration, 4)
    }

    /// Trips a fresh breaker. The first retry alone gives a 100% ratio.
    fn trip(cb: &mut CircuitBreaker) {
        for _ in 0..10 {
            cb.record_retry();
        }
    }

    #[test]
    fn circuit_breaker_opens_at_50pct_retry_ratio() {
        let mut cb = make_breaker(LONG_OPEN);
        for _ in 0..5 {
            cb.record_success();
        }
        // 4 retries out of 9 outcomes (~44%) is still below the threshold.
        for _ in 0..4 {
            cb.record_retry();
        }
        assert_eq!(cb.state(), CircuitState::Closed);
        // 5 out of 10 (50%) is exactly at the threshold, which trips.
        cb.record_retry();
        assert_eq!(cb.state(), CircuitState::Open);
    }

    #[test]
    fn circuit_breaker_stays_open_before_open_duration() {
        let mut cb = make_breaker(LONG_OPEN);
        trip(&mut cb);
        assert_eq!(cb.state(), CircuitState::Open);
    }

    #[test]
    fn circuit_breaker_half_opens_after_window() {
        let mut cb = make_breaker(SHORT_OPEN);
        trip(&mut cb);
        assert_eq!(cb.state(), CircuitState::Open);
        std::thread::sleep(PAST_SHORT_OPEN);
        assert_eq!(cb.state(), CircuitState::HalfOpen);
    }

    #[test]
    fn circuit_breaker_reopens_on_half_open_retry() {
        let mut cb = make_breaker(SHORT_OPEN);
        trip(&mut cb);
        std::thread::sleep(PAST_SHORT_OPEN);
        assert_eq!(cb.state(), CircuitState::HalfOpen);
        cb.record_retry();
        assert_eq!(cb.state(), CircuitState::Open);
    }

    #[test]
    fn circuit_breaker_closes_after_4_successes() {
        let mut cb = make_breaker(SHORT_OPEN);
        trip(&mut cb);
        std::thread::sleep(PAST_SHORT_OPEN);
        assert_eq!(cb.state(), CircuitState::HalfOpen);
        for _ in 0..3 {
            cb.record_success();
        }
        assert_eq!(cb.state(), CircuitState::HalfOpen);
        cb.record_success();
        assert_eq!(cb.state(), CircuitState::Closed);
    }
}