use std::time::Duration;
use crate::ErrorKind;
/// Tuning parameters for a circuit breaker.
///
/// Create one with [`CircuitBreakerConfig::new`] (equivalent to [`Default`]),
/// adjust it with the chained `#[must_use]` builder methods, and read values
/// back through the `get_*` accessors.
#[derive(Debug, Clone)]
pub struct CircuitBreakerConfig {
    /// Failure-count threshold; default `5`.
    /// NOTE(review): presumably the failure count at which the breaker opens —
    /// the state machine is not in this file, confirm.
    failure_threshold: u32,
    /// Success-count threshold; default `2`.
    /// NOTE(review): presumably the successes needed to close again — confirm.
    success_threshold: u32,
    /// Timeout; default 30 seconds.
    /// NOTE(review): likely how long the breaker stays open before probing — confirm.
    timeout: Duration,
    /// Failure-rate threshold, always within `0.0..=1.0` (enforced by the
    /// setter); default `0.5`.
    failure_rate_threshold: f64,
    /// Minimum request count; default `10`.
    /// NOTE(review): presumably the sample size required before the failure
    /// rate is considered — confirm.
    minimum_requests: u32,
    /// Classifies error kinds as failures; see [`CircuitBreakerConfig::is_failure`].
    failure_predicate: FailurePredicate,
}

impl Default for CircuitBreakerConfig {
    /// Returns the baseline configuration: failure threshold 5, success
    /// threshold 2, 30 s timeout, failure-rate threshold 0.5, minimum of 10
    /// requests, and the default [`FailurePredicate`].
    fn default() -> Self {
        Self {
            failure_threshold: 5,
            success_threshold: 2,
            timeout: Duration::from_secs(30),
            failure_rate_threshold: 0.5,
            minimum_requests: 10,
            failure_predicate: FailurePredicate::default(),
        }
    }
}

impl CircuitBreakerConfig {
    /// Creates a configuration populated with the [`Default`] values.
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the failure threshold.
    #[must_use]
    pub fn failure_threshold(mut self, threshold: u32) -> Self {
        self.failure_threshold = threshold;
        self
    }

    /// Sets the success threshold.
    #[must_use]
    pub fn success_threshold(mut self, threshold: u32) -> Self {
        self.success_threshold = threshold;
        self
    }

    /// Sets the timeout.
    #[must_use]
    pub fn timeout(mut self, timeout: Duration) -> Self {
        self.timeout = timeout;
        self
    }

    /// Sets the failure-rate threshold, clamped into `0.0..=1.0`.
    ///
    /// Infinite values clamp to the nearest bound. A `NaN` argument is
    /// ignored and the previously configured value is kept: `f64::clamp`
    /// returns `NaN` unchanged, and a stored `NaN` compares false against
    /// every rate, so any `rate >= threshold` style check would silently
    /// never fire.
    #[must_use]
    pub fn failure_rate_threshold(mut self, threshold: f64) -> Self {
        if !threshold.is_nan() {
            self.failure_rate_threshold = threshold.clamp(0.0, 1.0);
        }
        self
    }

    /// Sets the minimum request count.
    #[must_use]
    pub fn minimum_requests(mut self, count: u32) -> Self {
        self.minimum_requests = count;
        self
    }

    /// Replaces the predicate that decides which error kinds are failures.
    #[must_use]
    pub fn failure_predicate(mut self, predicate: FailurePredicate) -> Self {
        self.failure_predicate = predicate;
        self
    }

    /// Returns the configured failure threshold.
    pub fn get_failure_threshold(&self) -> u32 {
        self.failure_threshold
    }

    /// Returns the configured success threshold.
    pub fn get_success_threshold(&self) -> u32 {
        self.success_threshold
    }

    /// Returns the configured timeout.
    pub fn get_timeout(&self) -> Duration {
        self.timeout
    }

    /// Returns the configured failure-rate threshold (always in `0.0..=1.0`).
    pub fn get_failure_rate_threshold(&self) -> f64 {
        self.failure_rate_threshold
    }

    /// Returns the configured minimum request count.
    pub fn get_minimum_requests(&self) -> u32 {
        self.minimum_requests
    }

    /// Borrows the configured failure predicate.
    pub fn get_failure_predicate(&self) -> &FailurePredicate {
        &self.failure_predicate
    }

    /// Returns whether `kind` counts as a failure under the configured predicate.
    pub fn is_failure(&self, kind: ErrorKind) -> bool {
        self.failure_predicate.is_failure(kind)
    }
}
/// Decides which [`ErrorKind`]s count as failures.
///
/// A kind is a failure when it is in the include list and is *not* in the
/// exclude list. Exclusion therefore always wins, regardless of the order in
/// which [`FailurePredicate::include`] and [`FailurePredicate::exclude`] were
/// called.
#[derive(Debug, Clone)]
pub struct FailurePredicate {
    /// Kinds treated as failures (unless also excluded).
    include: Vec<ErrorKind>,
    /// Kinds never treated as failures, even when also included.
    exclude: Vec<ErrorKind>,
}

impl Default for FailurePredicate {
    /// Treats `Timeout`, `Connection`, `Unavailable` and `Internal` as
    /// failures, with nothing excluded.
    fn default() -> Self {
        Self::only([
            ErrorKind::Timeout,
            ErrorKind::Connection,
            ErrorKind::Unavailable,
            ErrorKind::Internal,
        ])
    }
}

impl FailurePredicate {
    /// Builds a predicate whose include list is exactly `kinds` (in iteration
    /// order) and whose exclude list is empty.
    pub fn only(kinds: impl IntoIterator<Item = ErrorKind>) -> Self {
        let included: Vec<ErrorKind> = kinds.into_iter().collect();
        Self {
            include: included,
            exclude: Vec::new(),
        }
    }

    /// Adds `kind` to the exclude list, so it is never a failure.
    #[must_use]
    pub fn exclude(mut self, kind: ErrorKind) -> Self {
        self.exclude.extend(std::iter::once(kind));
        self
    }

    /// Adds `kind` to the include list; it still won't count if it is excluded.
    #[must_use]
    pub fn include(mut self, kind: ErrorKind) -> Self {
        self.include.extend(std::iter::once(kind));
        self
    }

    /// Returns `true` when `kind` is included and not excluded.
    pub fn is_failure(&self, kind: ErrorKind) -> bool {
        let excluded = self.exclude.contains(&kind);
        !excluded && self.include.contains(&kind)
    }
}
/// Lifecycle state of a circuit breaker.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CircuitState {
    /// Closed; displayed as `"closed"`.
    Closed,
    /// Open; displayed as `"open"`.
    Open,
    /// Half-open; displayed as `"half-open"`.
    HalfOpen,
}

impl CircuitState {
    /// Returns `true` for [`CircuitState::Closed`].
    pub fn is_closed(&self) -> bool {
        matches!(self, CircuitState::Closed)
    }

    /// Returns `true` for [`CircuitState::Open`].
    pub fn is_open(&self) -> bool {
        matches!(self, CircuitState::Open)
    }

    /// Returns `true` for [`CircuitState::HalfOpen`].
    pub fn is_half_open(&self) -> bool {
        matches!(self, CircuitState::HalfOpen)
    }

    /// Returns the lowercase name used by the `Display` impl
    /// (`"closed"`, `"open"` or `"half-open"`).
    pub fn as_str(&self) -> &'static str {
        match self {
            CircuitState::Closed => "closed",
            CircuitState::Open => "open",
            CircuitState::HalfOpen => "half-open",
        }
    }
}

impl std::fmt::Display for CircuitState {
    /// Writes the state's lowercase name.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `pad` honours width/fill/alignment (and precision truncation) the
        // same way `str`'s own Display does, so `{:<10}` aligns in log tables.
        // With no flags it writes the name unchanged.
        f.pad(self.as_str())
    }
}
/// Counters and timestamps describing a circuit breaker.
///
/// All fields are public. The methods below are read-only conveniences over
/// them.
#[derive(Debug, Clone)]
pub struct CircuitStats {
    /// Current state; starts as [`CircuitState::Closed`].
    pub state: CircuitState,
    /// Failure counter (returned by [`CircuitStats::failure_count`]).
    pub failure_count: u32,
    /// Success counter (returned by [`CircuitStats::success_count`]).
    pub success_count: u32,
    /// Denominator of [`CircuitStats::failure_rate`].
    pub total_requests: u64,
    /// Numerator of [`CircuitStats::failure_rate`].
    pub failed_requests: u64,
    /// NOTE(review): presumably when the breaker last opened; set elsewhere.
    pub last_open_time: Option<std::time::Instant>,
    /// NOTE(review): presumably when the breaker last closed; set elsewhere.
    pub last_close_time: Option<std::time::Instant>,
}

impl CircuitStats {
    /// Creates zeroed stats in the `Closed` state with no timestamps recorded.
    pub fn new() -> Self {
        CircuitStats {
            total_requests: 0,
            failed_requests: 0,
            failure_count: 0,
            success_count: 0,
            last_open_time: None,
            last_close_time: None,
            state: CircuitState::Closed,
        }
    }

    /// Returns the stored state.
    pub fn current_state(&self) -> CircuitState {
        self.state
    }

    /// Returns the stored failure counter.
    pub fn failure_count(&self) -> u32 {
        self.failure_count
    }

    /// Returns the stored success counter.
    pub fn success_count(&self) -> u32 {
        self.success_count
    }

    /// Returns `failed_requests / total_requests`, or `0.0` when no requests
    /// have been recorded (this avoids a division by zero).
    pub fn failure_rate(&self) -> f64 {
        match self.total_requests {
            0 => 0.0,
            total => self.failed_requests as f64 / total as f64,
        }
    }
}

impl Default for CircuitStats {
    /// Same as [`CircuitStats::new`].
    fn default() -> Self {
        CircuitStats::new()
    }
}
/// A state transition of a circuit breaker, with a human-readable `Display`.
#[derive(Debug, Clone)]
pub enum CircuitEvent {
    /// The breaker opened.
    Opened {
        /// Failure count reported in the `Display` output.
        failure_count: u32,
        /// Error message appended to the `Display` output.
        last_error: String,
    },
    /// The breaker became half-open.
    HalfOpened,
    /// The breaker closed.
    Closed {
        /// Success count reported in the `Display` output.
        success_count: u32,
    },
}

impl std::fmt::Display for CircuitEvent {
    /// Renders a one-line description of the transition.
    fn fmt(&self, out: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match *self {
            CircuitEvent::HalfOpened => out.write_str("circuit half-opened (testing recovery)"),
            CircuitEvent::Closed { success_count } => {
                write!(out, "circuit closed after {} successes", success_count)
            }
            CircuitEvent::Opened {
                failure_count,
                ref last_error,
            } => write!(
                out,
                "circuit opened after {} failures: {}",
                failure_count, last_error
            ),
        }
    }
}
// Unit tests for the configuration builder, failure predicate, state/stats
// types and their `Display` output. Some tests read private fields directly,
// which Rust allows because this module is a child of the defining module.
#[cfg(test)]
mod tests {
    use super::*;

    // Field values produced by `CircuitBreakerConfig::default()`.
    #[test]
    fn test_default_config() {
        let config = CircuitBreakerConfig::default();
        assert_eq!(config.failure_threshold, 5);
        assert_eq!(config.success_threshold, 2);
        assert_eq!(config.timeout, Duration::from_secs(30));
        assert_eq!(config.failure_rate_threshold, 0.5);
        assert_eq!(config.minimum_requests, 10);
    }

    // Every builder setter is visible through its matching `get_*` accessor.
    #[test]
    fn test_config_builder() {
        let config = CircuitBreakerConfig::new()
            .failure_threshold(10)
            .success_threshold(3)
            .timeout(Duration::from_secs(60))
            .failure_rate_threshold(0.8)
            .minimum_requests(20);
        assert_eq!(config.get_failure_threshold(), 10);
        assert_eq!(config.get_success_threshold(), 3);
        assert_eq!(config.get_timeout(), Duration::from_secs(60));
        assert_eq!(config.get_failure_rate_threshold(), 0.8);
        assert_eq!(config.get_minimum_requests(), 20);
    }

    // Out-of-range rate thresholds are clamped into 0.0..=1.0.
    #[test]
    fn test_failure_rate_threshold_clamped() {
        let config = CircuitBreakerConfig::new().failure_rate_threshold(1.5);
        assert_eq!(config.get_failure_rate_threshold(), 1.0);
        let config = CircuitBreakerConfig::new().failure_rate_threshold(-0.5);
        assert_eq!(config.get_failure_rate_threshold(), 0.0);
    }

    // The default predicate counts the four listed kinds and nothing else checked here.
    #[test]
    fn test_default_failure_predicate() {
        let predicate = FailurePredicate::default();
        assert!(predicate.is_failure(ErrorKind::Timeout));
        assert!(predicate.is_failure(ErrorKind::Connection));
        assert!(predicate.is_failure(ErrorKind::Unavailable));
        assert!(predicate.is_failure(ErrorKind::Internal));
        assert!(!predicate.is_failure(ErrorKind::Forbidden));
        assert!(!predicate.is_failure(ErrorKind::NotFound));
    }

    // `only` replaces the default include list entirely.
    #[test]
    fn test_failure_predicate_only() {
        let predicate = FailurePredicate::only([ErrorKind::Timeout]);
        assert!(predicate.is_failure(ErrorKind::Timeout));
        assert!(!predicate.is_failure(ErrorKind::Connection));
    }

    // `exclude` removes one kind without affecting the others.
    #[test]
    fn test_failure_predicate_exclude() {
        let predicate = FailurePredicate::default().exclude(ErrorKind::Timeout);
        assert!(!predicate.is_failure(ErrorKind::Timeout));
        assert!(predicate.is_failure(ErrorKind::Connection));
    }

    // Each state answers `true` only for its own predicate.
    #[test]
    fn test_circuit_state() {
        assert!(CircuitState::Closed.is_closed());
        assert!(!CircuitState::Closed.is_open());
        assert!(!CircuitState::Closed.is_half_open());
        assert!(!CircuitState::Open.is_closed());
        assert!(CircuitState::Open.is_open());
        assert!(!CircuitState::Open.is_half_open());
        assert!(!CircuitState::HalfOpen.is_closed());
        assert!(!CircuitState::HalfOpen.is_open());
        assert!(CircuitState::HalfOpen.is_half_open());
    }

    // Fresh stats are zeroed; the failure rate is failed/total once requests exist.
    #[test]
    fn test_circuit_stats() {
        let mut stats = CircuitStats::new();
        assert_eq!(stats.current_state(), CircuitState::Closed);
        assert_eq!(stats.failure_count(), 0);
        assert_eq!(stats.success_count(), 0);
        assert_eq!(stats.failure_rate(), 0.0);
        stats.total_requests = 10;
        stats.failed_requests = 3;
        assert!((stats.failure_rate() - 0.3).abs() < f64::EPSILON);
    }

    // Event messages include their counts and, for `Opened`, the error text.
    #[test]
    fn test_circuit_event_display() {
        let event = CircuitEvent::Opened {
            failure_count: 5,
            last_error: "connection refused".to_string(),
        };
        let display = event.to_string();
        assert!(display.contains("5 failures"));
        assert!(display.contains("connection refused"));
        let event = CircuitEvent::HalfOpened;
        assert!(event.to_string().contains("half-opened"));
        let event = CircuitEvent::Closed { success_count: 2 };
        assert!(event.to_string().contains("2 successes"));
    }

    // Exact lowercase names produced by `CircuitState`'s Display impl.
    #[test]
    fn test_circuit_state_display() {
        assert_eq!(format!("{}", CircuitState::Closed), "closed");
        assert_eq!(format!("{}", CircuitState::Open), "open");
        assert_eq!(format!("{}", CircuitState::HalfOpen), "half-open");
    }

    // `Default` matches `new()`.
    #[test]
    fn test_circuit_stats_default() {
        let stats = CircuitStats::default();
        assert_eq!(stats.state, CircuitState::Closed);
        assert_eq!(stats.failure_count, 0);
    }

    // `include` extends an `only` list.
    #[test]
    fn test_failure_predicate_include() {
        let predicate = FailurePredicate::only([ErrorKind::Timeout]).include(ErrorKind::Connection);
        assert!(predicate.is_failure(ErrorKind::Timeout));
        assert!(predicate.is_failure(ErrorKind::Connection));
        assert!(!predicate.is_failure(ErrorKind::NotFound));
    }

    // `CircuitBreakerConfig::is_failure` delegates to the configured predicate.
    #[test]
    fn test_config_is_failure() {
        let config = CircuitBreakerConfig::default();
        assert!(config.is_failure(ErrorKind::Timeout));
        assert!(!config.is_failure(ErrorKind::NotFound));
    }

    // A custom predicate fully replaces the default one.
    #[test]
    fn test_config_custom_predicate() {
        let predicate = FailurePredicate::only([ErrorKind::NotFound]);
        let config = CircuitBreakerConfig::new().failure_predicate(predicate);
        assert!(config.is_failure(ErrorKind::NotFound));
        assert!(!config.is_failure(ErrorKind::Timeout));
    }

    // The predicate accessor returns the default predicate on a default config.
    #[test]
    fn test_config_get_failure_predicate() {
        let config = CircuitBreakerConfig::default();
        let predicate = config.get_failure_predicate();
        assert!(predicate.is_failure(ErrorKind::Timeout));
    }

    // Cloning an event preserves its variant and payload.
    #[test]
    fn test_circuit_event_clone() {
        let event = CircuitEvent::Opened {
            failure_count: 5,
            last_error: "error".to_string(),
        };
        let cloned = event.clone();
        match cloned {
            CircuitEvent::Opened { failure_count, .. } => assert_eq!(failure_count, 5),
            _ => panic!("Wrong variant"),
        }
    }

    // `CircuitState` is `Copy`: the original stays usable after assignment.
    #[test]
    fn test_circuit_state_copy() {
        let state = CircuitState::Open;
        let copied: CircuitState = state;
        assert_eq!(state, copied);
    }

    // Timestamp fields are plain public options that callers may set.
    #[test]
    fn test_circuit_stats_with_times() {
        let mut stats = CircuitStats::new();
        stats.last_open_time = Some(std::time::Instant::now());
        stats.last_close_time = Some(std::time::Instant::now());
        assert!(stats.last_open_time.is_some());
        assert!(stats.last_close_time.is_some());
    }
}