use dashmap::DashMap;
use std::collections::HashMap;
use std::time::{Duration, Instant};
/// Position of a single host's breaker in its Closed → Open → HalfOpen cycle.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub(crate) enum BreakerState {
    /// Not tripped; target errors are counted toward the trip threshold.
    Closed,
    /// Tripped; becomes eligible for a probe once the probe interval elapses.
    Open,
    /// A probe has been granted; a recorded success closes the breaker.
    HalfOpen,
}
/// Mutable per-host breaker record stored inside [`CircuitBreakerManager`].
struct CircuitBreaker {
    /// Current position in the Closed → Open → HalfOpen cycle.
    state: BreakerState,
    /// When the breaker last entered `Open`; `None` until it first trips.
    opened_at: Option<Instant>,
    /// Minimum time the breaker must stay `Open` before a probe is granted.
    probe_interval: Duration,
    /// Target errors recorded since creation or the last recorded success.
    target_error_count: usize,
    /// Error count at which the breaker trips open.
    threshold: usize,
    /// Human-readable explanation of the most recent trip.
    reason: String,
}
impl CircuitBreaker {
    /// Builds a closed breaker with no recorded errors and no trip history.
    fn new(threshold: usize, probe_interval: Duration) -> Self {
        let state = BreakerState::Closed;
        Self {
            threshold,
            probe_interval,
            state,
            target_error_count: 0,
            opened_at: None,
            reason: String::default(),
        }
    }
}
/// Tracks one circuit breaker per host name; breakers are created lazily the
/// first time a host is tripped or records a target error.
pub struct CircuitBreakerManager {
    /// Trip threshold handed to every newly created breaker.
    threshold: usize,
    /// Probe interval handed to every newly created breaker.
    probe_interval: Duration,
    /// Per-host breaker records, keyed by host name.
    breakers: DashMap<String, CircuitBreaker>,
}
impl CircuitBreakerManager {
pub fn new(threshold: usize, probe_interval: Duration) -> Self {
Self {
breakers: DashMap::new(),
threshold,
probe_interval,
}
}
#[allow(dead_code)]
pub fn is_open(&self, host: &str) -> bool {
match self.breakers.get(host) {
Some(breaker) => matches!(breaker.state, BreakerState::Open | BreakerState::HalfOpen),
None => false,
}
}
pub fn should_probe(&self, host: &str) -> bool {
let mut breaker = match self.breakers.get_mut(host) {
Some(b) => b,
None => return false,
};
if breaker.state != BreakerState::Open {
return false;
}
if let Some(opened_at) = breaker.opened_at {
if opened_at.elapsed() >= breaker.probe_interval {
breaker.state = BreakerState::HalfOpen;
return true;
}
}
false
}
pub fn trip(&self, host: &str, reason: &str) {
let mut breaker = self
.breakers
.entry(host.to_string())
.or_insert_with(|| CircuitBreaker::new(self.threshold, self.probe_interval));
breaker.state = BreakerState::Open;
breaker.opened_at = Some(Instant::now());
breaker.reason = reason.to_string();
}
pub fn record_target_error(&self, host: &str) {
let mut breaker = self
.breakers
.entry(host.to_string())
.or_insert_with(|| CircuitBreaker::new(self.threshold, self.probe_interval));
breaker.target_error_count += 1;
if breaker.target_error_count >= breaker.threshold {
breaker.state = BreakerState::Open;
breaker.opened_at = Some(Instant::now());
breaker.reason = format!(
"target_error_count ({}) reached threshold ({})",
breaker.target_error_count, breaker.threshold
);
}
}
pub fn record_success(&self, host: &str) {
if let Some(mut breaker) = self.breakers.get_mut(host) {
if breaker.state == BreakerState::HalfOpen {
breaker.state = BreakerState::Closed;
}
breaker.target_error_count = 0;
}
}
pub fn get_all(&self) -> HashMap<String, bool> {
self.breakers
.iter()
.map(|entry| {
let is_open = entry.value().state != BreakerState::Closed;
(entry.key().clone(), is_open)
})
.collect()
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::thread;
const HOST_A: &str = "yunhq.sse.com.cn";
const HOST_B: &str = "www.szse.cn";
fn manager() -> CircuitBreakerManager {
CircuitBreakerManager::new(3, Duration::from_millis(50))
}
#[test]
fn new_manager_has_no_breakers() {
let mgr = manager();
assert!(mgr.get_all().is_empty());
}
#[test]
fn unknown_host_is_not_open() {
let mgr = manager();
assert!(!mgr.is_open("unknown.host"));
}
#[test]
fn unknown_host_should_not_probe() {
let mgr = manager();
assert!(!mgr.should_probe("unknown.host"));
}
#[test]
fn not_open_before_threshold() {
let mgr = manager();
mgr.record_target_error(HOST_A);
mgr.record_target_error(HOST_A);
assert!(!mgr.is_open(HOST_A));
}
#[test]
fn should_not_probe_when_closed() {
let mgr = manager();
mgr.record_target_error(HOST_A);
assert!(!mgr.should_probe(HOST_A));
}
#[test]
fn trips_at_threshold() {
let mgr = manager(); mgr.record_target_error(HOST_A);
mgr.record_target_error(HOST_A);
mgr.record_target_error(HOST_A);
assert!(mgr.is_open(HOST_A));
}
#[test]
fn trips_above_threshold() {
let mgr = manager();
for _ in 0..5 {
mgr.record_target_error(HOST_A);
}
assert!(mgr.is_open(HOST_A));
}
#[test]
fn manual_trip_opens_breaker() {
let mgr = manager();
mgr.trip(HOST_A, "manual intervention");
assert!(mgr.is_open(HOST_A));
}
#[test]
fn manual_trip_stores_reason() {
let mgr = manager();
mgr.trip(HOST_A, "test reason");
let breaker = mgr.breakers.get(HOST_A).unwrap();
assert_eq!(breaker.reason, "test reason");
}
#[test]
fn should_not_probe_before_interval() {
let mgr = CircuitBreakerManager::new(1, Duration::from_secs(60));
mgr.record_target_error(HOST_A);
assert!(mgr.is_open(HOST_A));
assert!(!mgr.should_probe(HOST_A));
}
#[test]
fn should_probe_after_interval() {
let mgr = CircuitBreakerManager::new(1, Duration::from_millis(20));
mgr.record_target_error(HOST_A);
assert!(mgr.is_open(HOST_A));
thread::sleep(Duration::from_millis(30));
assert!(mgr.should_probe(HOST_A));
}
#[test]
fn should_probe_transitions_to_half_open() {
let mgr = CircuitBreakerManager::new(1, Duration::from_millis(20));
mgr.record_target_error(HOST_A);
thread::sleep(Duration::from_millis(30));
assert!(mgr.should_probe(HOST_A));
let breaker = mgr.breakers.get(HOST_A).unwrap();
assert_eq!(breaker.state, BreakerState::HalfOpen);
}
#[test]
fn half_open_is_still_considered_open() {
let mgr = CircuitBreakerManager::new(1, Duration::from_millis(20));
mgr.record_target_error(HOST_A);
thread::sleep(Duration::from_millis(30));
mgr.should_probe(HOST_A); assert!(mgr.is_open(HOST_A));
}
#[test]
fn should_probe_returns_false_when_half_open() {
let mgr = CircuitBreakerManager::new(1, Duration::from_millis(20));
mgr.record_target_error(HOST_A);
thread::sleep(Duration::from_millis(30));
assert!(mgr.should_probe(HOST_A)); assert!(!mgr.should_probe(HOST_A)); }
#[test]
fn success_in_half_open_closes_breaker() {
let mgr = CircuitBreakerManager::new(1, Duration::from_millis(20));
mgr.record_target_error(HOST_A);
thread::sleep(Duration::from_millis(30));
mgr.should_probe(HOST_A);
mgr.record_success(HOST_A);
assert!(!mgr.is_open(HOST_A));
let breaker = mgr.breakers.get(HOST_A).unwrap();
assert_eq!(breaker.state, BreakerState::Closed);
assert_eq!(breaker.target_error_count, 0);
}
#[test]
fn error_after_half_open_reopens_breaker() {
let mgr = CircuitBreakerManager::new(1, Duration::from_millis(20));
mgr.record_target_error(HOST_A);
thread::sleep(Duration::from_millis(30));
mgr.should_probe(HOST_A);
mgr.record_target_error(HOST_A);
assert!(mgr.is_open(HOST_A));
let breaker = mgr.breakers.get(HOST_A).unwrap();
assert_eq!(breaker.state, BreakerState::Open);
}
#[test]
fn success_resets_error_count_in_closed_state() {
let mgr = manager(); mgr.record_target_error(HOST_A);
mgr.record_target_error(HOST_A);
mgr.record_success(HOST_A);
mgr.record_target_error(HOST_A);
mgr.record_target_error(HOST_A);
assert!(!mgr.is_open(HOST_A));
mgr.record_target_error(HOST_A);
assert!(mgr.is_open(HOST_A));
}
#[test]
fn success_for_unknown_host_is_noop() {
let mgr = manager();
mgr.record_success("nonexistent.host"); assert!(mgr.get_all().is_empty());
}
#[test]
fn hosts_are_independent() {
let mgr = manager();
mgr.trip(HOST_A, "down");
assert!(mgr.is_open(HOST_A));
assert!(!mgr.is_open(HOST_B));
}
#[test]
fn errors_on_one_host_dont_affect_another() {
let mgr = manager();
for _ in 0..3 {
mgr.record_target_error(HOST_A);
}
assert!(mgr.is_open(HOST_A));
assert!(!mgr.is_open(HOST_B));
}
#[test]
fn get_all_reflects_current_state() {
let mgr = manager();
mgr.trip(HOST_A, "test");
mgr.record_target_error(HOST_B);
let all = mgr.get_all();
assert_eq!(all.len(), 2);
assert_eq!(all.get(HOST_A), Some(&true));
assert_eq!(all.get(HOST_B), Some(&false));
}
#[test]
fn get_all_empty_when_no_activity() {
let mgr = manager();
assert!(mgr.get_all().is_empty());
}
#[test]
fn get_all_after_recovery() {
let mgr = CircuitBreakerManager::new(1, Duration::from_millis(20));
mgr.record_target_error(HOST_A);
thread::sleep(Duration::from_millis(30));
mgr.should_probe(HOST_A); mgr.record_success(HOST_A);
let all = mgr.get_all();
assert_eq!(all.get(HOST_A), Some(&false));
}
#[test]
fn threshold_one_trips_on_first_error() {
let mgr = CircuitBreakerManager::new(1, Duration::from_millis(50));
mgr.record_target_error(HOST_A);
assert!(mgr.is_open(HOST_A));
}
#[test]
fn large_threshold_requires_many_errors() {
let mgr = CircuitBreakerManager::new(100, Duration::from_millis(50));
for _ in 0..99 {
mgr.record_target_error(HOST_A);
}
assert!(!mgr.is_open(HOST_A));
mgr.record_target_error(HOST_A);
assert!(mgr.is_open(HOST_A));
}
#[test]
fn full_lifecycle_closed_open_halfopen_closed() {
let mgr = CircuitBreakerManager::new(2, Duration::from_millis(20));
assert!(!mgr.is_open(HOST_A));
mgr.record_target_error(HOST_A);
mgr.record_target_error(HOST_A);
assert!(mgr.is_open(HOST_A));
thread::sleep(Duration::from_millis(30));
assert!(mgr.should_probe(HOST_A));
assert!(mgr.is_open(HOST_A));
mgr.record_success(HOST_A);
assert!(!mgr.is_open(HOST_A));
mgr.record_target_error(HOST_A);
mgr.record_target_error(HOST_A);
assert!(mgr.is_open(HOST_A));
}
#[test]
fn full_lifecycle_with_failed_probe() {
let mgr = CircuitBreakerManager::new(1, Duration::from_millis(20));
mgr.record_target_error(HOST_A);
assert!(mgr.is_open(HOST_A));
thread::sleep(Duration::from_millis(30));
assert!(mgr.should_probe(HOST_A));
mgr.record_target_error(HOST_A);
assert!(mgr.is_open(HOST_A));
let breaker = mgr.breakers.get(HOST_A).unwrap();
assert_eq!(breaker.state, BreakerState::Open);
assert!(breaker.opened_at.unwrap().elapsed() < Duration::from_millis(100));
}
#[test]
fn breaker_state_debug() {
assert_eq!(format!("{:?}", BreakerState::Closed), "Closed");
assert_eq!(format!("{:?}", BreakerState::Open), "Open");
assert_eq!(format!("{:?}", BreakerState::HalfOpen), "HalfOpen");
}
#[test]
fn breaker_state_clone_and_eq() {
let s = BreakerState::Open;
let s2 = s;
let s3 = s;
assert_eq!(s, s2);
assert_eq!(s2, s3);
assert_ne!(BreakerState::Closed, BreakerState::Open);
}
#[test]
fn trip_overwrites_closed_breaker() {
let mgr = manager();
mgr.record_target_error(HOST_A); assert!(!mgr.is_open(HOST_A));
mgr.trip(HOST_A, "forced");
assert!(mgr.is_open(HOST_A));
}
#[test]
fn trip_resets_timer_on_already_open_breaker() {
let mgr = CircuitBreakerManager::new(1, Duration::from_millis(100));
mgr.record_target_error(HOST_A); let first_opened = mgr.breakers.get(HOST_A).unwrap().opened_at.unwrap();
thread::sleep(Duration::from_millis(10));
mgr.trip(HOST_A, "re-trip");
let second_opened = mgr.breakers.get(HOST_A).unwrap().opened_at.unwrap();
assert!(second_opened > first_opened);
}
#[test]
fn many_hosts_tracked_independently() {
let mgr = CircuitBreakerManager::new(2, Duration::from_millis(50));
for i in 0..20 {
let host = format!("host-{i}.example.com");
mgr.record_target_error(&host);
mgr.record_target_error(&host);
assert!(mgr.is_open(&host));
}
let all = mgr.get_all();
assert_eq!(all.len(), 20);
assert!(all.values().all(|&open| open));
}
}