use std::collections::HashMap;
use std::net::SocketAddr;
use std::time::{Duration, Instant};
/// Failure-tracking state kept for one remote server.
///
/// Instances are created and mutated by the registry in this file; outside
/// code can only read the current [`BreakerState`] through `state()`.
#[derive(Clone, Debug)]
pub struct CircuitBreaker {
    /// Current phase of the breaker.
    state: BreakerState,
    /// When the in-flight half-open probe was admitted; `None` when no probe
    /// is outstanding.
    probe_started_at: Option<Instant>,
    /// Timestamps of failures recorded while the breaker was closed.
    failures: Vec<Instant>,
    /// Instant after which an open breaker may admit a probe.
    cooldown_until: Option<Instant>,
    /// Cooldown applied the next time the breaker opens.
    current_cooldown: Duration,
}
/// Phase of a circuit breaker.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum BreakerState {
    /// Traffic flows normally; failures are being counted.
    Closed,
    /// Traffic is rejected until the cooldown deadline passes.
    Open,
    /// A single probe request has been admitted; other traffic is rejected
    /// until the probe resolves or times out.
    HalfOpen,
}
/// Tunable thresholds shared by every breaker in a registry.
#[derive(Clone, Copy, Debug)]
pub struct BreakerConfig {
    /// Sliding window over which failures are counted while closed.
    pub window: Duration,
    /// Number of in-window failures that opens the breaker.
    pub failure_threshold: usize,
    /// Cooldown used the first time a breaker opens (and after a recovery).
    pub initial_cooldown: Duration,
    /// Upper bound for the cooldown as it doubles after failed probes.
    pub max_cooldown: Duration,
    /// How long a half-open probe may stay unresolved before another probe
    /// is admitted.
    pub probe_timeout: Duration,
}
impl Default for BreakerConfig {
fn default() -> Self {
Self {
window: Duration::from_secs(60),
failure_threshold: 5,
initial_cooldown: Duration::from_secs(60),
max_cooldown: Duration::from_secs(600),
probe_timeout: Duration::from_secs(30),
}
}
}
impl CircuitBreaker {
fn new(initial_cooldown: Duration) -> Self {
Self {
state: BreakerState::Closed,
failures: Vec::new(),
cooldown_until: None,
current_cooldown: initial_cooldown,
probe_started_at: None,
}
}
#[allow(dead_code)]
pub fn state(&self) -> BreakerState {
self.state
}
}
/// Collection of circuit breakers keyed by server address.
///
/// The derived `Default` yields an empty registry using
/// `BreakerConfig::default()`.
#[derive(Default, Debug)]
pub struct CircuitBreakerRegistry {
    /// Thresholds applied to every breaker in this registry.
    config: BreakerConfig,
    /// One breaker per server address that has been seen.
    breakers: HashMap<SocketAddr, CircuitBreaker>,
}
impl CircuitBreakerRegistry {
pub fn new() -> Self {
Self::with_config(BreakerConfig::default())
}
pub fn with_config(config: BreakerConfig) -> Self {
Self {
config,
breakers: HashMap::new(),
}
}
const MAX_BREAKERS: usize = 4096;
pub fn allow(&mut self, server: SocketAddr) -> bool {
if self.breakers.len() >= Self::MAX_BREAKERS && !self.breakers.contains_key(&server) {
self.evict_idle_closed();
}
let now = Instant::now();
let breaker = self
.breakers
.entry(server)
.or_insert_with(|| CircuitBreaker::new(self.config.initial_cooldown));
match breaker.state {
BreakerState::Closed => true,
BreakerState::Open => {
if let Some(until) = breaker.cooldown_until {
if now >= until {
breaker.state = BreakerState::HalfOpen;
breaker.cooldown_until = None;
breaker.probe_started_at = Some(now);
true
} else {
false
}
} else {
breaker.state = BreakerState::HalfOpen;
breaker.probe_started_at = Some(now);
true
}
}
BreakerState::HalfOpen => {
if let Some(started) = breaker.probe_started_at
&& now.duration_since(started) >= self.config.probe_timeout
{
breaker.probe_started_at = Some(now);
return true;
}
false
}
}
}
pub fn record_success(&mut self, server: SocketAddr) {
if let Some(breaker) = self.breakers.get_mut(&server) {
breaker.state = BreakerState::Closed;
breaker.failures.clear();
breaker.cooldown_until = None;
breaker.current_cooldown = self.config.initial_cooldown;
breaker.probe_started_at = None;
}
}
pub fn record_failure(&mut self, server: SocketAddr) {
if self.breakers.len() >= Self::MAX_BREAKERS && !self.breakers.contains_key(&server) {
self.evict_idle_closed();
}
let now = Instant::now();
let breaker = self
.breakers
.entry(server)
.or_insert_with(|| CircuitBreaker::new(self.config.initial_cooldown));
match breaker.state {
BreakerState::HalfOpen => {
breaker.current_cooldown =
(breaker.current_cooldown * 2).min(self.config.max_cooldown);
breaker.cooldown_until = Some(now + breaker.current_cooldown);
breaker.state = BreakerState::Open;
breaker.failures.clear();
}
BreakerState::Open => {
}
BreakerState::Closed => {
breaker
.failures
.retain(|t| now.saturating_duration_since(*t) <= self.config.window);
breaker.failures.push(now);
if breaker.failures.len() >= self.config.failure_threshold {
breaker.cooldown_until = Some(now + breaker.current_cooldown);
breaker.state = BreakerState::Open;
breaker.failures.clear();
}
}
}
}
#[allow(dead_code)]
pub fn states(&self) -> impl Iterator<Item = (SocketAddr, BreakerState)> + '_ {
self.breakers.iter().map(|(addr, b)| (*addr, b.state))
}
pub fn is_open(&self, server: SocketAddr) -> bool {
self.breakers
.get(&server)
.map(|b| matches!(b.state, BreakerState::Open))
.unwrap_or(false)
}
pub fn is_blocking(&self, server: SocketAddr) -> bool {
self.breakers
.get(&server)
.map(|b| {
matches!(b.state, BreakerState::Open)
&& b.cooldown_until
.map(|until| Instant::now() < until)
.unwrap_or(false)
})
.unwrap_or(false)
}
fn evict_idle_closed(&mut self) {
self.breakers.retain(|_, b| {
!(matches!(b.state, BreakerState::Closed)
&& b.failures.is_empty()
&& b.probe_started_at.is_none())
});
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Short timings so the state machine can be exercised with brief sleeps.
    fn fast_config() -> BreakerConfig {
        BreakerConfig {
            window: Duration::from_secs(1),
            failure_threshold: 3,
            initial_cooldown: Duration::from_millis(50),
            max_cooldown: Duration::from_millis(400),
            probe_timeout: Duration::from_millis(500),
        }
    }

    /// The single server address used by every test.
    fn addr() -> SocketAddr {
        "127.0.0.1:5064".parse().unwrap()
    }

    /// Registry built from `fast_config()` with nothing recorded yet.
    fn fresh() -> CircuitBreakerRegistry {
        CircuitBreakerRegistry::with_config(fast_config())
    }

    /// Registry whose breaker for `addr()` has just been tripped by three
    /// failures (the `fast_config()` threshold).
    fn tripped() -> CircuitBreakerRegistry {
        let mut reg = fresh();
        (0..3).for_each(|_| reg.record_failure(addr()));
        reg
    }

    /// Sleeps slightly longer than the 50 ms initial cooldown.
    fn wait_out_initial_cooldown() {
        std::thread::sleep(Duration::from_millis(60));
    }

    #[test]
    fn closed_allows_traffic_by_default() {
        let mut reg = fresh();
        assert!(reg.allow(addr()));
    }

    #[test]
    fn trips_after_threshold_failures() {
        let mut reg = tripped();
        assert!(!reg.allow(addr()));
        assert!(reg.is_open(addr()));
    }

    #[test]
    fn half_open_after_cooldown() {
        let mut reg = tripped();
        wait_out_initial_cooldown();
        // The first caller becomes the probe; the next one is turned away.
        assert!(reg.allow(addr()));
        assert!(!reg.allow(addr()));
    }

    #[test]
    fn success_in_half_open_returns_to_closed() {
        let mut reg = tripped();
        wait_out_initial_cooldown();
        let _ = reg.allow(addr());
        reg.record_success(addr());
        assert!(reg.allow(addr()));
    }

    #[test]
    fn failure_in_half_open_doubles_cooldown() {
        let mut reg = tripped();
        assert_eq!(
            reg.breakers[&addr()].current_cooldown,
            Duration::from_millis(50)
        );

        wait_out_initial_cooldown();
        assert!(reg.allow(addr()));
        reg.record_failure(addr());

        assert_eq!(
            reg.breakers[&addr()].current_cooldown,
            Duration::from_millis(100)
        );
        assert!(reg.is_open(addr()));
        assert!(!reg.allow(addr()));
    }

    #[test]
    fn is_blocking_distinguishes_probe_ready_from_hard_blocked() {
        let reg = tripped();
        assert!(reg.is_open(addr()));
        assert!(reg.is_blocking(addr()));

        wait_out_initial_cooldown();
        assert!(reg.is_open(addr()));
        assert!(!reg.is_blocking(addr()));
    }

    #[test]
    fn open_breaker_recovers_via_is_blocking_then_allow() {
        let mut reg = tripped();
        wait_out_initial_cooldown();
        assert!(
            !reg.is_blocking(addr()),
            "probe-ready breaker must not block"
        );
        assert!(reg.allow(addr()), "allow() must admit the probe");
        reg.record_success(addr());
        assert!(reg.allow(addr()), "breaker recovered to CLOSED");
        assert!(!reg.is_open(addr()));
    }

    #[test]
    fn stale_half_open_probe_self_heals_after_probe_timeout() {
        let mut reg = tripped();
        wait_out_initial_cooldown();
        assert!(reg.allow(addr()), "first probe admitted (now HALF_OPEN)");
        assert!(
            !reg.allow(addr()),
            "probe in flight — further traffic denied"
        );

        // Longer than the 500 ms probe timeout.
        std::thread::sleep(Duration::from_millis(550));
        assert!(
            reg.allow(addr()),
            "stale probe treated as failed — a fresh probe is admitted"
        );
    }

    #[test]
    fn old_failures_drop_out_of_window() {
        let short_window = BreakerConfig {
            window: Duration::from_millis(100),
            failure_threshold: 3,
            ..fast_config()
        };
        let mut reg = CircuitBreakerRegistry::with_config(short_window);

        reg.record_failure(addr());
        reg.record_failure(addr());
        std::thread::sleep(Duration::from_millis(150));

        // The two earlier failures have aged out, so this is only the first
        // in-window failure and the breaker stays closed.
        reg.record_failure(addr());
        assert!(reg.allow(addr()));
    }
}