#![deny(unsafe_code)]
use crate::types::PeerId;
use dashmap::DashMap;
use parking_lot::RwLock;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::Notify;
use tokio::time::interval;
use tracing::{info, warn};
/// The three states a [`CircuitBreaker`] can be in.
///
/// `Closed` lets traffic through, `Open` rejects it until a timeout elapses,
/// and `HalfOpen` admits a bounded number of probe requests to test whether
/// the peer has recovered.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CircuitState {
    /// Normal operation: every request is admitted.
    Closed,
    /// Tripped: requests are rejected until the configured timeout elapses.
    Open,
    /// Recovery probe: only a limited number of requests are admitted.
    HalfOpen,
}
/// Tuning knobs for a single [`CircuitBreaker`].
///
/// While closed, the breaker only evaluates the failure rate once
/// `failure_threshold` consecutive failures have been seen. It opens when that
/// evaluation finds at least `min_requests` outcomes in the sliding window and a
/// failure fraction of at least `failure_rate_threshold`.
/// NOTE(review): because both conditions must hold, interleaved failures that
/// never reach `failure_threshold` in a row do not open the circuit, whatever
/// the rate.
#[derive(Clone, Debug)]
pub struct CircuitBreakerConfig {
    /// Consecutive failures (while closed) that trigger the failure-rate check.
    pub failure_threshold: u32,
    /// Consecutive half-open successes required to close the circuit again.
    pub success_threshold: u32,
    /// How long the circuit stays open before a request may move it to half-open.
    pub timeout: Duration,
    /// Failure fraction (failures / total, in `0.0..=1.0`) at or above which the
    /// circuit may open.
    pub failure_rate_threshold: f64,
    /// Minimum number of outcomes in the sliding window before the failure rate
    /// is considered.
    pub min_requests: u32,
    /// Length of the sliding window over which the failure rate is computed.
    pub window_duration: Duration,
    /// Maximum number of admitted half-open requests whose outcome has not yet
    /// been recorded.
    pub half_open_max_requests: u32,
}

impl Default for CircuitBreakerConfig {
    /// 5 consecutive failures, 3 successes to close, a one-minute open timeout
    /// and a one-minute window, a 50% failure rate over at least 10 requests, and
    /// a single half-open probe.
    fn default() -> Self {
        let one_minute = Duration::from_secs(60);
        CircuitBreakerConfig {
            failure_threshold: 5,
            success_threshold: 3,
            timeout: one_minute,
            failure_rate_threshold: 0.5,
            min_requests: 10,
            window_duration: one_minute,
            half_open_max_requests: 1,
        }
    }
}
/// Counters describing one breaker's history since creation or its last reset.
#[derive(Clone, Debug, Default)]
pub struct CircuitBreakerStats {
    /// Outcomes passed to `record_outcome`.
    pub total_requests: u64,
    /// Outcomes recorded as successes.
    pub successful_requests: u64,
    /// Outcomes recorded as failures.
    pub failed_requests: u64,
    /// Requests refused by `allow_request`.
    pub rejected_requests: u64,
    /// Failure fraction of the sliding window at the time the snapshot was taken.
    pub failure_rate: f64,
    /// Number of state transitions.
    pub state_changes: u64,
    /// When the most recent state transition happened, if any.
    pub last_state_change: Option<Instant>,
    /// Accumulated time spent closed (credited when the breaker leaves the state).
    pub time_in_closed: Duration,
    /// Accumulated time spent open (credited when the breaker leaves the state).
    pub time_in_open: Duration,
    /// Accumulated time spent half-open (credited when the breaker leaves the state).
    pub time_in_half_open: Duration,
}
/// Time-bounded record of recent request outcomes, used to compute the failure
/// rate over the last `duration`.
///
/// Entries are appended with `Instant::now()`, which never goes backwards, so
/// the oldest entries are always at the front. Expiry therefore only pops from
/// the front of the deque, which is O(1) per expired entry.
#[derive(Debug)]
struct SlidingWindow {
    duration: Duration,
    /// `(recorded_at, success)` pairs, oldest first.
    outcomes: std::collections::VecDeque<(Instant, bool)>,
    success_count: usize,
    failure_count: usize,
}

impl SlidingWindow {
    /// Creates an empty window covering the last `duration`.
    fn new(duration: Duration) -> Self {
        Self {
            duration,
            outcomes: std::collections::VecDeque::new(),
            success_count: 0,
            failure_count: 0,
        }
    }

    /// Records one outcome at the current instant, then drops expired entries.
    fn record(&mut self, success: bool) {
        self.outcomes.push_back((Instant::now(), success));
        if success {
            self.success_count += 1;
        } else {
            self.failure_count += 1;
        }
        self.cleanup();
    }

    /// Removes entries older than `duration` and keeps the counters in sync.
    fn cleanup(&mut self) {
        // `Instant - Duration` panics when the result would precede the clock's
        // origin (e.g. a window longer than the system/process uptime). In that
        // case no entry can be old enough to expire yet.
        let cutoff = match Instant::now().checked_sub(self.duration) {
            Some(cutoff) => cutoff,
            None => return,
        };
        while let Some(&(recorded_at, success)) = self.outcomes.front() {
            if recorded_at >= cutoff {
                break;
            }
            self.outcomes.pop_front();
            if success {
                self.success_count -= 1;
            } else {
                self.failure_count -= 1;
            }
        }
    }

    /// Number of outcomes currently retained. Entries are only expired by
    /// `record`, so this may include stale entries if nothing was recorded recently.
    fn total_requests(&self) -> usize {
        self.success_count + self.failure_count
    }

    /// Fraction of retained outcomes that were failures; `0.0` when empty.
    fn failure_rate(&self) -> f64 {
        let total = self.total_requests();
        if total == 0 {
            0.0
        } else {
            self.failure_count as f64 / total as f64
        }
    }

    /// Discards every retained outcome.
    fn reset(&mut self) {
        self.outcomes.clear();
        self.success_count = 0;
        self.failure_count = 0;
    }
}
/// A single circuit breaker guarding calls to one peer.
///
/// All state lives behind locks or atomics, so the methods take `&self`.
pub struct CircuitBreaker {
    /// Thresholds and timeouts this breaker was created with.
    config: CircuitBreakerConfig,
    /// Current state.
    state: Arc<RwLock<CircuitState>>,
    /// When the breaker entered its current state.
    state_changed_at: Arc<RwLock<Instant>>,
    /// Failures in a row while closed.
    consecutive_failures: AtomicUsize,
    /// Successes in a row while half-open.
    consecutive_successes: AtomicUsize,
    /// Half-open requests admitted whose outcome has not yet been recorded.
    half_open_requests: AtomicUsize,
    /// Recent outcomes used to compute the failure rate.
    window: Arc<RwLock<SlidingWindow>>,
    /// Cumulative counters returned by `stats()`.
    stats: Arc<RwLock<CircuitBreakerStats>>,
    /// Woken on every state transition and on `reset()`.
    state_change_notify: Arc<Notify>,
}
impl CircuitBreaker {
pub fn new(config: CircuitBreakerConfig) -> Self {
let window_duration = config.window_duration;
Self {
config,
state: Arc::new(RwLock::new(CircuitState::Closed)),
state_changed_at: Arc::new(RwLock::new(Instant::now())),
consecutive_failures: AtomicUsize::new(0),
consecutive_successes: AtomicUsize::new(0),
half_open_requests: AtomicUsize::new(0),
window: Arc::new(RwLock::new(SlidingWindow::new(window_duration))),
stats: Arc::new(RwLock::new(CircuitBreakerStats::default())),
state_change_notify: Arc::new(Notify::new()),
}
}
pub fn allow_request(&self) -> bool {
let current_state = *self.state.read();
match current_state {
CircuitState::Closed => true,
CircuitState::Open => {
let elapsed = self.state_changed_at.read().elapsed();
if elapsed >= self.config.timeout {
self.transition_to_half_open();
true
} else {
self.stats.write().rejected_requests += 1;
false
}
}
CircuitState::HalfOpen => {
let current = self.half_open_requests.load(Ordering::Acquire);
if current < self.config.half_open_max_requests as usize {
self.half_open_requests.fetch_add(1, Ordering::Release);
true
} else {
self.stats.write().rejected_requests += 1;
false
}
}
}
}
pub fn record_outcome(&self, success: bool) {
{
let mut stats = self.stats.write();
stats.total_requests += 1;
if success {
stats.successful_requests += 1;
} else {
stats.failed_requests += 1;
}
}
self.window.write().record(success);
let current_state = *self.state.read();
match current_state {
CircuitState::Closed => {
if success {
self.consecutive_failures.store(0, Ordering::Release);
} else {
let failures = self.consecutive_failures.fetch_add(1, Ordering::AcqRel) + 1;
if failures >= self.config.failure_threshold as usize {
self.check_and_open_circuit();
}
}
}
CircuitState::Open => {
warn!("Outcome recorded while circuit is open");
}
CircuitState::HalfOpen => {
if success {
let successes = self.consecutive_successes.fetch_add(1, Ordering::AcqRel) + 1;
if successes >= self.config.success_threshold as usize {
self.transition_to_closed();
}
} else {
self.transition_to_open();
}
self.half_open_requests.fetch_sub(1, Ordering::Release);
}
}
}
fn check_and_open_circuit(&self) {
let window = self.window.read();
let total_requests = window.total_requests();
let failure_rate = window.failure_rate();
if total_requests >= self.config.min_requests as usize
&& failure_rate >= self.config.failure_rate_threshold
{
drop(window); self.transition_to_open();
}
}
fn transition_to_open(&self) {
let mut state = self.state.write();
let previous_state = *state;
if previous_state != CircuitState::Open {
*state = CircuitState::Open;
*self.state_changed_at.write() = Instant::now();
self.consecutive_failures.store(0, Ordering::Release);
self.consecutive_successes.store(0, Ordering::Release);
self.update_state_stats(previous_state, CircuitState::Open);
info!(
"Circuit breaker opened (failure rate: {:.2}%)",
self.window.read().failure_rate() * 100.0
);
self.state_change_notify.notify_waiters();
}
}
fn transition_to_half_open(&self) {
let mut state = self.state.write();
let previous_state = *state;
if previous_state == CircuitState::Open {
*state = CircuitState::HalfOpen;
*self.state_changed_at.write() = Instant::now();
self.consecutive_successes.store(0, Ordering::Release);
self.half_open_requests.store(0, Ordering::Release);
self.window.write().reset();
self.update_state_stats(previous_state, CircuitState::HalfOpen);
info!("Circuit breaker half-opened for testing");
self.state_change_notify.notify_waiters();
}
}
fn transition_to_closed(&self) {
let mut state = self.state.write();
let previous_state = *state;
if previous_state != CircuitState::Closed {
*state = CircuitState::Closed;
*self.state_changed_at.write() = Instant::now();
self.consecutive_failures.store(0, Ordering::Release);
self.consecutive_successes.store(0, Ordering::Release);
self.update_state_stats(previous_state, CircuitState::Closed);
info!("Circuit breaker closed");
self.state_change_notify.notify_waiters();
}
}
fn update_state_stats(&self, from_state: CircuitState, _to_state: CircuitState) {
let mut stats = self.stats.write();
stats.state_changes += 1;
if let Some(last_change) = stats.last_state_change {
let duration = last_change.elapsed();
match from_state {
CircuitState::Closed => stats.time_in_closed += duration,
CircuitState::Open => stats.time_in_open += duration,
CircuitState::HalfOpen => stats.time_in_half_open += duration,
}
}
stats.last_state_change = Some(Instant::now());
stats.failure_rate = self.window.read().failure_rate();
}
pub fn state(&self) -> CircuitState {
*self.state.read()
}
pub fn stats(&self) -> CircuitBreakerStats {
let mut stats = self.stats.read().clone();
stats.failure_rate = self.window.read().failure_rate();
stats
}
pub async fn wait_for_state_change(&self) {
self.state_change_notify.notified().await;
}
pub fn reset(&self) {
*self.state.write() = CircuitState::Closed;
*self.state_changed_at.write() = Instant::now();
self.consecutive_failures.store(0, Ordering::Release);
self.consecutive_successes.store(0, Ordering::Release);
self.half_open_requests.store(0, Ordering::Release);
self.window.write().reset();
*self.stats.write() = CircuitBreakerStats::default();
self.state_change_notify.notify_waiters();
}
}
/// Owns one [`CircuitBreaker`] per peer plus a periodically refreshed aggregate
/// of their statistics.
pub struct CircuitBreakerManager {
    /// Per-peer breakers. The `Arc` is shared with every clone of this manager.
    breakers: Arc<DashMap<PeerId, Arc<CircuitBreaker>>>,
    /// Configuration used for breakers created on first use.
    default_config: CircuitBreakerConfig,
    /// Aggregate snapshot refreshed by the maintenance task; shared with clones.
    global_stats: Arc<RwLock<GlobalCircuitStats>>,
    /// Background maintenance task. Only the instance returned by `new` holds it.
    maintenance_handle: Option<tokio::task::JoinHandle<()>>,
}
/// Aggregate view over every breaker owned by a [`CircuitBreakerManager`].
#[derive(Clone, Debug, Default)]
pub struct GlobalCircuitStats {
    /// Number of breakers included in the snapshot.
    pub total_breakers: usize,
    /// Breakers currently open.
    pub open_circuits: usize,
    /// Breakers currently half-open.
    pub half_open_circuits: usize,
    /// Sum of every breaker's recorded outcomes.
    pub total_requests: u64,
    /// Sum of every breaker's rejected requests.
    pub total_rejected: u64,
    /// Mean of the per-breaker failure rates; `0.0` when there are no breakers.
    pub avg_failure_rate: f64,
}
impl CircuitBreakerManager {
    /// How often the background task refreshes [`GlobalCircuitStats`].
    const MAINTENANCE_INTERVAL: Duration = Duration::from_secs(10);

    /// Creates an empty manager and spawns its background maintenance task.
    ///
    /// Breakers are created lazily, one per peer, from clones of
    /// `default_config`. The maintenance task refreshes the snapshot returned by
    /// [`get_global_stats`](Self::get_global_stats) every
    /// `MAINTENANCE_INTERVAL`. It stops at the first tick after every handle to
    /// this manager has been dropped.
    ///
    /// # Panics
    ///
    /// Panics if called outside a Tokio runtime, because it uses `tokio::spawn`.
    pub fn new(default_config: CircuitBreakerConfig) -> Self {
        let manager = Self {
            breakers: Arc::new(DashMap::new()),
            default_config,
            global_stats: Arc::new(RwLock::new(GlobalCircuitStats::default())),
            maintenance_handle: None,
        };
        let maintenance_manager = manager.clone();
        let handle = tokio::spawn(async move {
            maintenance_manager.run_maintenance().await;
        });
        Self {
            maintenance_handle: Some(handle),
            ..manager
        }
    }

    /// Returns the breaker for `peer_id`, creating one from the default config
    /// if none exists yet.
    pub fn get_breaker(&self, peer_id: PeerId) -> Arc<CircuitBreaker> {
        self.breakers
            .entry(peer_id)
            .or_insert_with(|| Arc::new(CircuitBreaker::new(self.default_config.clone())))
            .clone()
    }

    /// [`CircuitBreaker::allow_request`] for `peer_id` (creates the breaker on demand).
    pub fn allow_request(&self, peer_id: PeerId) -> bool {
        self.get_breaker(peer_id).allow_request()
    }

    /// [`CircuitBreaker::record_outcome`] for `peer_id` (creates the breaker on demand).
    pub fn record_outcome(&self, peer_id: PeerId, success: bool) {
        self.get_breaker(peer_id).record_outcome(success);
    }

    /// Current state of `peer_id`'s breaker (creates the breaker on demand).
    pub fn get_state(&self, peer_id: PeerId) -> CircuitState {
        self.get_breaker(peer_id).state()
    }

    /// Statistics of `peer_id`'s breaker (creates the breaker on demand).
    pub fn get_stats(&self, peer_id: PeerId) -> CircuitBreakerStats {
        self.get_breaker(peer_id).stats()
    }

    /// Returns the most recent aggregate snapshot.
    ///
    /// The snapshot is only refreshed by the maintenance task, so it can lag the
    /// live breakers by up to `MAINTENANCE_INTERVAL`. It holds default values
    /// until the task first runs.
    pub fn get_global_stats(&self) -> GlobalCircuitStats {
        self.global_stats.read().clone()
    }

    /// Resets `peer_id`'s breaker if one exists. Unknown peers are ignored.
    pub fn reset(&self, peer_id: PeerId) {
        if let Some(breaker) = self.breakers.get(&peer_id) {
            breaker.reset();
        }
    }

    /// Forgets `peer_id`'s breaker. Callers still holding the `Arc` from
    /// `get_breaker` keep a breaker that is no longer tracked.
    pub fn remove(&self, peer_id: PeerId) {
        self.breakers.remove(&peer_id);
    }

    /// Periodically refreshes the global snapshot until no other handle to the
    /// shared breaker map remains.
    async fn run_maintenance(&self) {
        let mut ticker = interval(Self::MAINTENANCE_INTERVAL);
        loop {
            ticker.tick().await;
            // This task owns one manager clone. If it holds the last reference
            // to the shared map, every user-facing manager is gone. Nobody can
            // read the snapshot any more, so stop instead of looping forever.
            if Arc::strong_count(&self.breakers) <= 1 {
                break;
            }
            self.update_global_stats();
        }
    }

    /// Recomputes [`GlobalCircuitStats`] from every breaker and stores it.
    fn update_global_stats(&self) {
        let mut snapshot = GlobalCircuitStats::default();
        let mut total_failure_rate = 0.0;
        for entry in self.breakers.iter() {
            let breaker = entry.value();
            let stats = breaker.stats();
            // Count inside the loop so the average's denominator matches exactly
            // the breakers that were summed, even under concurrent insert/remove.
            snapshot.total_breakers += 1;
            snapshot.total_requests += stats.total_requests;
            snapshot.total_rejected += stats.rejected_requests;
            total_failure_rate += stats.failure_rate;
            match breaker.state() {
                CircuitState::Open => snapshot.open_circuits += 1,
                CircuitState::HalfOpen => snapshot.half_open_circuits += 1,
                CircuitState::Closed => {}
            }
        }
        if snapshot.total_breakers > 0 {
            snapshot.avg_failure_rate = total_failure_rate / snapshot.total_breakers as f64;
        }
        *self.global_stats.write() = snapshot;
    }

    /// Stops the maintenance task (only when called on the instance returned by
    /// `new`) and clears all breakers. The breaker map is shared, so clones see
    /// it emptied too.
    pub fn shutdown(&mut self) {
        if let Some(handle) = self.maintenance_handle.take() {
            handle.abort();
        }
        self.breakers.clear();
    }
}
impl Clone for CircuitBreakerManager {
fn clone(&self) -> Self {
Self {
breakers: self.breakers.clone(),
default_config: self.default_config.clone(),
global_stats: self.global_stats.clone(),
maintenance_handle: None,
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    // `sleep` is not imported by the parent module, only `interval`.
    use tokio::time::sleep;

    #[tokio::test]
    async fn test_circuit_breaker_closed() {
        let config = CircuitBreakerConfig::default();
        let breaker = CircuitBreaker::new(config);
        assert_eq!(breaker.state(), CircuitState::Closed);
        assert!(breaker.allow_request());
        breaker.record_outcome(true);
        assert_eq!(breaker.state(), CircuitState::Closed);
    }

    #[tokio::test]
    async fn test_circuit_breaker_opens_on_failures() {
        let config = CircuitBreakerConfig {
            failure_threshold: 3,
            min_requests: 1,
            ..Default::default()
        };
        let breaker = CircuitBreaker::new(config);
        for _ in 0..3 {
            assert!(breaker.allow_request());
            breaker.record_outcome(false);
        }
        assert_eq!(breaker.state(), CircuitState::Open);
        assert!(!breaker.allow_request());
        let stats = breaker.stats();
        assert_eq!(stats.failed_requests, 3);
        assert_eq!(stats.rejected_requests, 1);
    }

    #[tokio::test]
    async fn test_circuit_breaker_half_open() {
        // min_requests must be reachable: the failure-rate check only opens the
        // circuit once the window holds at least min_requests outcomes.
        let config = CircuitBreakerConfig {
            failure_threshold: 2,
            min_requests: 1,
            timeout: Duration::from_millis(100),
            ..Default::default()
        };
        let breaker = CircuitBreaker::new(config);
        for _ in 0..2 {
            breaker.record_outcome(false);
        }
        assert_eq!(breaker.state(), CircuitState::Open);
        sleep(Duration::from_millis(150)).await;
        assert!(breaker.allow_request());
        assert_eq!(breaker.state(), CircuitState::HalfOpen);
    }

    #[tokio::test]
    async fn test_circuit_breaker_closes_after_success() {
        let config = CircuitBreakerConfig {
            failure_threshold: 2,
            success_threshold: 2,
            min_requests: 1,
            timeout: Duration::from_millis(50),
            ..Default::default()
        };
        let breaker = CircuitBreaker::new(config);
        for _ in 0..2 {
            breaker.record_outcome(false);
        }
        assert_eq!(breaker.state(), CircuitState::Open);
        sleep(Duration::from_millis(100)).await;
        assert!(breaker.allow_request());
        breaker.record_outcome(true);
        assert!(breaker.allow_request());
        breaker.record_outcome(true);
        assert_eq!(breaker.state(), CircuitState::Closed);
    }

    #[tokio::test]
    async fn test_half_open_admits_limited_probes() {
        let config = CircuitBreakerConfig {
            failure_threshold: 1,
            min_requests: 1,
            timeout: Duration::from_millis(20),
            half_open_max_requests: 1,
            ..Default::default()
        };
        let breaker = CircuitBreaker::new(config);
        breaker.record_outcome(false);
        assert_eq!(breaker.state(), CircuitState::Open);
        sleep(Duration::from_millis(50)).await;
        // The request that moves the breaker to half-open occupies the only slot.
        assert!(breaker.allow_request());
        assert!(!breaker.allow_request());
        breaker.record_outcome(false);
        assert_eq!(breaker.state(), CircuitState::Open);
    }

    #[tokio::test]
    async fn test_time_in_closed_counts_initial_period() {
        let config = CircuitBreakerConfig {
            failure_threshold: 1,
            min_requests: 1,
            ..Default::default()
        };
        let breaker = CircuitBreaker::new(config);
        sleep(Duration::from_millis(20)).await;
        breaker.record_outcome(false);
        assert_eq!(breaker.state(), CircuitState::Open);
        let stats = breaker.stats();
        assert_eq!(stats.state_changes, 1);
        assert!(stats.time_in_closed >= Duration::from_millis(10));
    }

    #[tokio::test]
    async fn test_circuit_breaker_manager() {
        let config = CircuitBreakerConfig::default();
        let manager = CircuitBreakerManager::new(config);
        let peer1 = PeerId::random();
        let peer2 = PeerId::random();
        assert!(manager.allow_request(peer1));
        assert!(manager.allow_request(peer2));
        manager.record_outcome(peer1, true);
        manager.record_outcome(peer2, false);
        assert_eq!(manager.get_state(peer1), CircuitState::Closed);
        // Global stats are a cached snapshot; refresh it instead of waiting for
        // the maintenance task (which never gets polled in this test).
        manager.update_global_stats();
        let global_stats = manager.get_global_stats();
        assert_eq!(global_stats.total_breakers, 2);
    }

    #[test]
    fn test_sliding_window() {
        let mut window = SlidingWindow::new(Duration::from_secs(1));
        window.record(true);
        window.record(false);
        window.record(true);
        window.record(false);
        assert_eq!(window.total_requests(), 4);
        assert_eq!(window.failure_rate(), 0.5);
        window.reset();
        assert_eq!(window.total_requests(), 0);
        assert_eq!(window.failure_rate(), 0.0);
    }
}