use std::collections::HashMap;
use std::sync::RwLock;
use std::time::{Duration, Instant};
use crate::error::{ClusterError, Result};
#[derive(Debug, Clone)]
pub struct CircuitBreakerConfig {
pub failure_threshold: u32,
pub cooldown: Duration,
}
impl Default for CircuitBreakerConfig {
fn default() -> Self {
Self {
failure_threshold: 5,
cooldown: Duration::from_secs(10),
}
}
}
pub struct CircuitBreaker {
peers: RwLock<HashMap<u64, PeerBreaker>>,
config: CircuitBreakerConfig,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CircuitState {
Closed,
Open,
HalfOpen,
}
struct PeerBreaker {
state: CircuitState,
consecutive_failures: u32,
last_state_change: Instant,
}
impl PeerBreaker {
fn new() -> Self {
Self {
state: CircuitState::Closed,
consecutive_failures: 0,
last_state_change: Instant::now(),
}
}
}
impl CircuitBreaker {
pub fn new(config: CircuitBreakerConfig) -> Self {
Self {
peers: RwLock::new(HashMap::new()),
config,
}
}
pub fn check(&self, peer: u64) -> Result<()> {
let mut peers = self.peers.write().unwrap_or_else(|p| p.into_inner());
let breaker = peers.entry(peer).or_insert_with(PeerBreaker::new);
match breaker.state {
CircuitState::Closed => Ok(()),
CircuitState::HalfOpen => Ok(()), CircuitState::Open => {
if breaker.last_state_change.elapsed() >= self.config.cooldown {
breaker.state = CircuitState::HalfOpen;
breaker.last_state_change = Instant::now();
Ok(())
} else {
Err(ClusterError::CircuitOpen {
node_id: peer,
failures: breaker.consecutive_failures,
})
}
}
}
}
pub fn record_success(&self, peer: u64) {
let mut peers = self.peers.write().unwrap_or_else(|p| p.into_inner());
let breaker = peers.entry(peer).or_insert_with(PeerBreaker::new);
breaker.consecutive_failures = 0;
if breaker.state != CircuitState::Closed {
breaker.state = CircuitState::Closed;
breaker.last_state_change = Instant::now();
}
}
pub fn record_failure(&self, peer: u64) {
let mut peers = self.peers.write().unwrap_or_else(|p| p.into_inner());
let breaker = peers.entry(peer).or_insert_with(PeerBreaker::new);
breaker.consecutive_failures += 1;
match breaker.state {
CircuitState::Closed => {
if breaker.consecutive_failures >= self.config.failure_threshold {
breaker.state = CircuitState::Open;
breaker.last_state_change = Instant::now();
}
}
CircuitState::HalfOpen => {
breaker.state = CircuitState::Open;
breaker.last_state_change = Instant::now();
}
CircuitState::Open => {
breaker.last_state_change = Instant::now();
}
}
}
pub fn state(&self, peer: u64) -> CircuitState {
let peers = self.peers.read().unwrap_or_else(|p| p.into_inner());
peers
.get(&peer)
.map(|b| b.state)
.unwrap_or(CircuitState::Closed)
}
pub fn failure_count(&self, peer: u64) -> u32 {
let peers = self.peers.read().unwrap_or_else(|p| p.into_inner());
peers
.get(&peer)
.map(|b| b.consecutive_failures)
.unwrap_or(0)
}
}
#[derive(Debug, Clone)]
pub struct RetryPolicy {
pub max_attempts: u32,
pub initial_delay: Duration,
pub max_delay: Duration,
pub backoff_multiplier: f64,
}
impl Default for RetryPolicy {
fn default() -> Self {
Self {
max_attempts: 3,
initial_delay: Duration::from_millis(100),
max_delay: Duration::from_secs(5),
backoff_multiplier: 2.0,
}
}
}
impl RetryPolicy {
pub fn delay_for_attempt(&self, attempt: u32) -> Duration {
if attempt == 0 {
return self.initial_delay;
}
let factor = self.backoff_multiplier.powi(attempt as i32);
let delay_ms = self.initial_delay.as_millis() as f64 * factor;
let capped = delay_ms.min(self.max_delay.as_millis() as f64);
Duration::from_millis(capped as u64)
}
pub fn is_retryable(err: &ClusterError) -> bool {
matches!(err, ClusterError::Transport { .. })
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn circuit_starts_closed() {
let cb = CircuitBreaker::new(CircuitBreakerConfig::default());
assert_eq!(cb.state(42), CircuitState::Closed);
cb.check(42).unwrap(); }
#[test]
fn circuit_opens_after_threshold() {
let cb = CircuitBreaker::new(CircuitBreakerConfig {
failure_threshold: 3,
cooldown: Duration::from_secs(60),
});
cb.check(1).unwrap();
cb.record_failure(1);
cb.record_failure(1);
assert_eq!(cb.state(1), CircuitState::Closed);
cb.record_failure(1); assert_eq!(cb.state(1), CircuitState::Open);
assert_eq!(cb.failure_count(1), 3);
let err = cb.check(1).unwrap_err();
assert!(err.to_string().contains("circuit open"), "{err}");
}
#[test]
fn circuit_half_open_after_cooldown() {
let cb = CircuitBreaker::new(CircuitBreakerConfig {
failure_threshold: 1,
cooldown: Duration::from_millis(10),
});
cb.record_failure(1); assert_eq!(cb.state(1), CircuitState::Open);
std::thread::sleep(Duration::from_millis(15));
cb.check(1).unwrap();
assert_eq!(cb.state(1), CircuitState::HalfOpen);
}
#[test]
fn half_open_success_closes_circuit() {
let cb = CircuitBreaker::new(CircuitBreakerConfig {
failure_threshold: 1,
cooldown: Duration::from_millis(5),
});
cb.record_failure(1);
std::thread::sleep(Duration::from_millis(10));
cb.check(1).unwrap();
cb.record_success(1);
assert_eq!(cb.state(1), CircuitState::Closed);
assert_eq!(cb.failure_count(1), 0);
}
#[test]
fn half_open_failure_reopens_circuit() {
let cb = CircuitBreaker::new(CircuitBreakerConfig {
failure_threshold: 1,
cooldown: Duration::from_millis(5),
});
cb.record_failure(1);
std::thread::sleep(Duration::from_millis(10));
cb.check(1).unwrap();
cb.record_failure(1); assert_eq!(cb.state(1), CircuitState::Open);
}
#[test]
fn success_resets_failure_count() {
let cb = CircuitBreaker::new(CircuitBreakerConfig {
failure_threshold: 5,
cooldown: Duration::from_secs(60),
});
cb.record_failure(1);
cb.record_failure(1);
assert_eq!(cb.failure_count(1), 2);
cb.record_success(1);
assert_eq!(cb.failure_count(1), 0);
assert_eq!(cb.state(1), CircuitState::Closed);
}
#[test]
fn retry_delay_exponential_backoff() {
let policy = RetryPolicy {
max_attempts: 5,
initial_delay: Duration::from_millis(100),
max_delay: Duration::from_secs(5),
backoff_multiplier: 2.0,
};
assert_eq!(policy.delay_for_attempt(0), Duration::from_millis(100));
assert_eq!(policy.delay_for_attempt(1), Duration::from_millis(200));
assert_eq!(policy.delay_for_attempt(2), Duration::from_millis(400));
assert_eq!(policy.delay_for_attempt(3), Duration::from_millis(800));
}
#[test]
fn retry_delay_capped_at_max() {
let policy = RetryPolicy {
max_attempts: 10,
initial_delay: Duration::from_secs(1),
max_delay: Duration::from_secs(5),
backoff_multiplier: 10.0,
};
assert_eq!(policy.delay_for_attempt(2), Duration::from_secs(5));
}
#[test]
fn retryable_errors() {
assert!(RetryPolicy::is_retryable(&ClusterError::Transport {
detail: "timeout".into(),
}));
assert!(!RetryPolicy::is_retryable(&ClusterError::Codec {
detail: "bad crc".into(),
}));
assert!(!RetryPolicy::is_retryable(&ClusterError::CircuitOpen {
node_id: 1,
failures: 5,
}));
assert!(!RetryPolicy::is_retryable(&ClusterError::NodeUnreachable {
node_id: 1
}));
}
}