use std::collections::HashMap;
use std::sync::RwLock;
use std::time::{Duration, Instant};
/// Settings that control when a per-domain circuit trips and how it recovers.
#[derive(Debug, Clone)]
pub struct CircuitBreakerConfig {
/// Number of failures inside `failure_window` at which a closed circuit opens.
pub failure_threshold: u32,
/// Number of successes recorded while half-open at which the circuit closes again.
pub success_threshold: u32,
/// How long a circuit stays open before `allow_request` moves it to half-open.
pub open_duration: Duration,
/// Age limit for recorded failures; older ones are discarded and no longer counted.
pub failure_window: Duration,
/// When `true`, `record_timeout` records a failure; when `false` it does nothing.
pub timeout_as_failure: bool,
/// When `true`, `record_server_error` records a failure; when `false` it does nothing.
pub server_error_as_failure: bool,
/// When `true`, `record_rate_limit` records a failure; when `false` it does nothing.
pub rate_limit_as_failure: bool,
}
impl Default for CircuitBreakerConfig {
fn default() -> Self {
Self {
failure_threshold: 5,
success_threshold: 2,
open_duration: Duration::from_secs(30),
failure_window: Duration::from_secs(60),
timeout_as_failure: true,
server_error_as_failure: true,
rate_limit_as_failure: false, }
}
}
impl CircuitBreakerConfig {
pub fn production() -> Self {
Self {
failure_threshold: 10,
success_threshold: 3,
open_duration: Duration::from_secs(60),
failure_window: Duration::from_secs(120),
timeout_as_failure: true,
server_error_as_failure: true,
rate_limit_as_failure: false,
}
}
pub fn aggressive() -> Self {
Self {
failure_threshold: 3,
success_threshold: 1,
open_duration: Duration::from_secs(15),
failure_window: Duration::from_secs(30),
timeout_as_failure: true,
server_error_as_failure: true,
rate_limit_as_failure: true,
}
}
}
/// State of a single domain's circuit.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CircuitState {
/// Requests are allowed; failures are counted toward the trip threshold.
Closed,
/// Requests are rejected until the configured open duration has elapsed.
Open,
/// Requests are allowed as probes: enough successes close the circuit, a failure reopens it.
HalfOpen,
}
/// Bookkeeping for one domain's circuit.
#[derive(Debug)]
struct DomainCircuit {
/// Current state of this circuit.
state: CircuitState,
/// Timestamps of recorded failures; pruned by `cleanup_old_failures`.
failures: Vec<Instant>,
/// Successes recorded since the circuit last entered `HalfOpen`.
successes_in_half_open: u32,
/// When the circuit was last opened (refreshed by failures recorded while open);
/// `None` if it has never opened or has been closed again.
opened_at: Option<Instant>,
/// Timestamp of the most recently recorded failure. It is written by
/// `record_failure` but not read anywhere in the code visible here.
last_failure: Option<Instant>,
}
impl DomainCircuit {
    /// Creates a circuit in the `Closed` state with no recorded history.
    fn new() -> Self {
        Self {
            state: CircuitState::Closed,
            failures: Vec::new(),
            successes_in_half_open: 0,
            opened_at: None,
            last_failure: None,
        }
    }

    /// Returns how many recorded failures are strictly younger than `window`.
    ///
    /// The age of each failure is compared against `window` instead of
    /// computing `Instant::now() - window`: that subtraction panics when it
    /// underflows the platform's `Instant` representation (for example with a
    /// very large window), and callers hold the breaker's lock at that point.
    fn recent_failures(&self, window: Duration) -> u32 {
        let now = Instant::now();
        let count = self
            .failures
            .iter()
            .filter(|&&at| now.saturating_duration_since(at) < window)
            .count();
        // Clamp rather than truncate if the count ever exceeds `u32::MAX`.
        count.min(u32::MAX as usize) as u32
    }

    /// Drops every recorded failure whose age is `window` or more.
    ///
    /// Uses the same underflow-free age comparison as `recent_failures`.
    fn cleanup_old_failures(&mut self, window: Duration) {
        let now = Instant::now();
        self.failures
            .retain(|&at| now.saturating_duration_since(at) < window);
    }
}
/// Circuit breaker that tracks an independent circuit for each domain string.
pub struct CircuitBreaker {
/// Thresholds and timers applied to every domain.
config: CircuitBreakerConfig,
/// Per-domain circuit state keyed by domain, guarded by a read-write lock.
circuits: RwLock<HashMap<String, DomainCircuit>>,
}
impl CircuitBreaker {
pub fn new(config: CircuitBreakerConfig) -> Self {
Self {
config,
circuits: RwLock::new(HashMap::new()),
}
}
pub fn default_config() -> Self {
Self::new(CircuitBreakerConfig::default())
}
pub fn allow_request(&self, domain: &str) -> bool {
let mut circuits = self.circuits.write().unwrap();
let circuit = circuits.entry(domain.to_string()).or_insert_with(DomainCircuit::new);
match circuit.state {
CircuitState::Closed => true,
CircuitState::Open => {
if let Some(opened_at) = circuit.opened_at {
if opened_at.elapsed() >= self.config.open_duration {
circuit.state = CircuitState::HalfOpen;
circuit.successes_in_half_open = 0;
true
} else {
false
}
} else {
false
}
}
CircuitState::HalfOpen => true,
}
}
pub fn record_success(&self, domain: &str) {
let mut circuits = self.circuits.write().unwrap();
if let Some(circuit) = circuits.get_mut(domain) {
match circuit.state {
CircuitState::HalfOpen => {
circuit.successes_in_half_open += 1;
if circuit.successes_in_half_open >= self.config.success_threshold {
circuit.state = CircuitState::Closed;
circuit.failures.clear();
circuit.opened_at = None;
circuit.successes_in_half_open = 0;
}
}
CircuitState::Closed => {
circuit.cleanup_old_failures(self.config.failure_window);
}
CircuitState::Open => {
}
}
}
}
pub fn record_failure(&self, domain: &str) {
let mut circuits = self.circuits.write().unwrap();
let circuit = circuits.entry(domain.to_string()).or_insert_with(DomainCircuit::new);
circuit.failures.push(Instant::now());
circuit.last_failure = Some(Instant::now());
circuit.cleanup_old_failures(self.config.failure_window);
match circuit.state {
CircuitState::Closed => {
if circuit.recent_failures(self.config.failure_window) >= self.config.failure_threshold {
circuit.state = CircuitState::Open;
circuit.opened_at = Some(Instant::now());
}
}
CircuitState::HalfOpen => {
circuit.state = CircuitState::Open;
circuit.opened_at = Some(Instant::now());
circuit.successes_in_half_open = 0;
}
CircuitState::Open => {
circuit.opened_at = Some(Instant::now());
}
}
}
pub fn record_timeout(&self, domain: &str) {
if self.config.timeout_as_failure {
self.record_failure(domain);
}
}
pub fn record_server_error(&self, domain: &str) {
if self.config.server_error_as_failure {
self.record_failure(domain);
}
}
pub fn record_rate_limit(&self, domain: &str) {
if self.config.rate_limit_as_failure {
self.record_failure(domain);
}
}
pub fn get_state(&self, domain: &str) -> CircuitState {
let circuits = self.circuits.read().unwrap();
circuits.get(domain).map(|c| c.state).unwrap_or(CircuitState::Closed)
}
pub fn get_open_circuits(&self) -> Vec<String> {
let circuits = self.circuits.read().unwrap();
circuits
.iter()
.filter(|(_, c)| c.state == CircuitState::Open)
.map(|(domain, _)| domain.clone())
.collect()
}
pub fn reset(&self, domain: &str) {
let mut circuits = self.circuits.write().unwrap();
circuits.remove(domain);
}
pub fn reset_all(&self) {
let mut circuits = self.circuits.write().unwrap();
circuits.clear();
}
pub fn stats(&self) -> CircuitBreakerStats {
let circuits = self.circuits.read().unwrap();
let total = circuits.len();
let open = circuits.values().filter(|c| c.state == CircuitState::Open).count();
let half_open = circuits.values().filter(|c| c.state == CircuitState::HalfOpen).count();
let closed = circuits.values().filter(|c| c.state == CircuitState::Closed).count();
CircuitBreakerStats {
total_domains: total,
open_circuits: open,
half_open_circuits: half_open,
closed_circuits: closed,
}
}
}
/// Snapshot of how many tracked domains are in each circuit state.
#[derive(Debug, Clone)]
pub struct CircuitBreakerStats {
/// Number of domains that currently have a circuit entry.
pub total_domains: usize,
/// Number of circuits stored as `Open`.
pub open_circuits: usize,
/// Number of circuits stored as `HalfOpen`.
pub half_open_circuits: usize,
/// Number of circuits stored as `Closed`.
pub closed_circuits: usize,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Domain used by the single-domain tests.
    const HOST: &str = "example.com";

    /// Records `count` consecutive failures for `domain`.
    fn fail_repeatedly(breaker: &CircuitBreaker, domain: &str, count: usize) {
        (0..count).for_each(|_| breaker.record_failure(domain));
    }

    #[test]
    fn test_circuit_starts_closed() {
        let cb = CircuitBreaker::default_config();

        let allowed = cb.allow_request(HOST);

        assert!(allowed);
        assert_eq!(cb.get_state(HOST), CircuitState::Closed);
    }

    #[test]
    fn test_circuit_opens_after_failures() {
        let cb = CircuitBreaker::new(CircuitBreakerConfig {
            failure_threshold: 3,
            ..Default::default()
        });

        fail_repeatedly(&cb, HOST, 3);

        assert_eq!(cb.get_state(HOST), CircuitState::Open);
        assert!(!cb.allow_request(HOST));
    }

    #[test]
    fn test_circuit_transitions_to_half_open() {
        let cb = CircuitBreaker::new(CircuitBreakerConfig {
            failure_threshold: 2,
            open_duration: Duration::from_millis(10),
            ..Default::default()
        });

        fail_repeatedly(&cb, HOST, 2);
        assert_eq!(cb.get_state(HOST), CircuitState::Open);

        // Wait past the open duration so the next request is let through as a probe.
        std::thread::sleep(Duration::from_millis(15));

        assert!(cb.allow_request(HOST));
        assert_eq!(cb.get_state(HOST), CircuitState::HalfOpen);
    }

    #[test]
    fn test_circuit_closes_after_successes() {
        let cb = CircuitBreaker::new(CircuitBreakerConfig {
            failure_threshold: 2,
            success_threshold: 2,
            open_duration: Duration::from_millis(10),
            ..Default::default()
        });

        fail_repeatedly(&cb, HOST, 2);
        std::thread::sleep(Duration::from_millis(15));

        // Called only for its side effect of moving the circuit to half-open.
        let _ = cb.allow_request(HOST);
        cb.record_success(HOST);
        cb.record_success(HOST);

        assert_eq!(cb.get_state(HOST), CircuitState::Closed);
    }

    #[test]
    fn test_stats() {
        let cb = CircuitBreaker::new(CircuitBreakerConfig {
            failure_threshold: 2,
            ..Default::default()
        });

        cb.allow_request("good.com");
        fail_repeatedly(&cb, "bad.com", 2);

        let CircuitBreakerStats {
            total_domains,
            open_circuits,
            closed_circuits,
            ..
        } = cb.stats();
        assert_eq!(total_domains, 2);
        assert_eq!(open_circuits, 1);
        assert_eq!(closed_circuits, 1);
    }
}