use std::time::{Duration, Instant};
/// Operating state of a circuit breaker.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CircuitState {
    /// Requests are allowed through.
    Closed,
    /// Requests are rejected until the configured open duration has elapsed.
    Open,
    /// The open duration has elapsed and requests are allowed through again;
    /// the next recorded outcome decides whether the circuit closes or reopens.
    HalfOpen,
}
/// Tunable parameters for a circuit breaker.
#[derive(Clone, Debug)]
pub struct CircuitBreakerConfig {
    /// Number of recorded failures at which a closed circuit opens.
    pub failure_threshold: u32,
    /// How long an open circuit rejects requests before it may go half-open.
    pub open_duration: Duration,
    /// Age below which previously fetched fallback data is still considered valid.
    pub fallback_cache_ttl: Duration,
}

impl Default for CircuitBreakerConfig {
    /// Returns the default configuration: open after 3 failures, stay open for
    /// 60 seconds, and treat fallback data as valid for 24 hours.
    fn default() -> Self {
        const SECS_PER_DAY: u64 = 24 * 60 * 60;

        let failure_threshold = 3;
        let open_duration = Duration::from_secs(60);
        let fallback_cache_ttl = Duration::from_secs(SECS_PER_DAY);

        Self {
            failure_threshold,
            open_duration,
            fallback_cache_ttl,
        }
    }
}
/// Circuit breaker guarding calls to a single named service.
#[derive(Debug)]
pub struct CircuitBreaker {
    /// Current operating state.
    state: CircuitState,
    /// Failures recorded since the last recorded success (or since creation).
    failure_count: u32,
    /// Instant at which the circuit most recently opened; cleared when a
    /// recorded success closes it.
    opened_at: Option<Instant>,
    /// Thresholds and durations governing state transitions.
    config: CircuitBreakerConfig,
    /// Label attached to the log events emitted on state transitions.
    service_name: &'static str,
}
impl CircuitBreaker {
    /// Creates a closed circuit breaker for `service_name` using the default
    /// configuration.
    #[must_use]
    pub fn new(service_name: &'static str) -> Self {
        Self::with_config(service_name, CircuitBreakerConfig::default())
    }

    /// Creates a closed circuit breaker for `service_name` with an explicit
    /// `config`.
    ///
    /// `service_name` is only used to label the log events this type emits.
    #[must_use]
    pub fn with_config(service_name: &'static str, config: CircuitBreakerConfig) -> Self {
        Self {
            state: CircuitState::Closed,
            failure_count: 0,
            opened_at: None,
            config,
            service_name,
        }
    }

    /// Reports whether a request may be attempted right now.
    ///
    /// Closed and half-open circuits always allow the request. An open circuit
    /// rejects it until `open_duration` has elapsed since it opened; the first
    /// call after that moves the circuit to half-open and allows the request.
    pub fn should_allow_request(&mut self) -> bool {
        match self.state {
            CircuitState::Closed | CircuitState::HalfOpen => true,
            CircuitState::Open => match self.opened_at {
                Some(opened) if opened.elapsed() >= self.config.open_duration => {
                    tracing::info!(
                        service = self.service_name,
                        "Circuit breaker transitioning to half-open"
                    );
                    self.state = CircuitState::HalfOpen;
                    true
                }
                // Still cooling down, or no open timestamp was recorded.
                _ => false,
            },
        }
    }

    /// Records a successful request.
    ///
    /// Closes the circuit from any state, resets the failure count and clears
    /// the open timestamp. A log event is emitted only when the circuit was not
    /// already closed.
    pub fn record_success(&mut self) {
        if self.state != CircuitState::Closed {
            tracing::info!(
                service = self.service_name,
                previous_state = ?self.state,
                "Circuit breaker closing after successful request"
            );
        }
        self.state = CircuitState::Closed;
        self.failure_count = 0;
        self.opened_at = None;
    }

    /// Records a failed request.
    ///
    /// The failure is always counted. A closed circuit opens once the count
    /// reaches `failure_threshold`; a half-open circuit reopens immediately; an
    /// already open circuit stays open and keeps its original open timestamp.
    pub fn record_failure(&mut self) {
        // Failures keep being counted while the circuit is open, so the counter
        // is unbounded. Saturate instead of using `+= 1`, which by default
        // panics on overflow in debug builds and wraps in release builds.
        self.failure_count = self.failure_count.saturating_add(1);

        match self.state {
            CircuitState::Closed => {
                if self.failure_count >= self.config.failure_threshold {
                    tracing::warn!(
                        service = self.service_name,
                        failure_count = self.failure_count,
                        "Circuit breaker opening after {} failures",
                        self.failure_count
                    );
                    self.trip();
                }
            }
            CircuitState::HalfOpen => {
                tracing::warn!(
                    service = self.service_name,
                    "Circuit breaker reopening after half-open failure"
                );
                self.trip();
            }
            // Already open: leave the state and the cooldown start untouched.
            CircuitState::Open => {}
        }
    }

    /// Moves the circuit to the open state and restarts the cooldown clock.
    fn trip(&mut self) {
        self.state = CircuitState::Open;
        self.opened_at = Some(Instant::now());
    }

    /// Returns the current state (only compiled for tests).
    #[cfg(test)]
    pub fn state(&self) -> CircuitState {
        self.state
    }

    /// Reports whether fallback data fetched at `fetched_at` is still usable,
    /// i.e. strictly younger than `fallback_cache_ttl`.
    #[must_use]
    pub fn is_fallback_valid(&self, fetched_at: Instant) -> bool {
        fetched_at.elapsed() < self.config.fallback_cache_ttl
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a breaker that opens on the first failure and then stays open for
    /// `open_duration`.
    ///
    /// A zero `open_duration` makes the cooldown count as elapsed immediately,
    /// which lets the half-open paths be exercised without sleeping.
    fn breaker_tripping_on_first_failure(open_duration: Duration) -> CircuitBreaker {
        let config = CircuitBreakerConfig {
            failure_threshold: 1,
            open_duration,
            ..Default::default()
        };
        CircuitBreaker::with_config("test", config)
    }

    #[test]
    fn test_circuit_starts_closed() {
        let cb = CircuitBreaker::new("test");
        assert_eq!(cb.state(), CircuitState::Closed);
    }

    #[test]
    fn test_allows_requests_when_closed() {
        let mut cb = CircuitBreaker::new("test");
        assert!(cb.should_allow_request());
    }

    #[test]
    fn test_opens_after_threshold_failures() {
        let config = CircuitBreakerConfig {
            failure_threshold: 3,
            ..Default::default()
        };
        let mut cb = CircuitBreaker::with_config("test", config);

        cb.record_failure();
        assert_eq!(cb.state(), CircuitState::Closed);
        cb.record_failure();
        assert_eq!(cb.state(), CircuitState::Closed);
        cb.record_failure();
        assert_eq!(cb.state(), CircuitState::Open);
    }

    #[test]
    fn test_blocks_requests_when_open() {
        let mut cb = breaker_tripping_on_first_failure(Duration::from_secs(60));

        cb.record_failure();
        assert_eq!(cb.state(), CircuitState::Open);
        assert!(!cb.should_allow_request());
        // A rejected request must not change the state.
        assert_eq!(cb.state(), CircuitState::Open);
    }

    #[test]
    fn test_transitions_to_half_open_after_cooldown() {
        let mut cb = breaker_tripping_on_first_failure(Duration::from_secs(0));

        cb.record_failure();
        assert_eq!(cb.state(), CircuitState::Open);

        assert!(cb.should_allow_request());
        assert_eq!(cb.state(), CircuitState::HalfOpen);
    }

    #[test]
    fn test_closes_on_success() {
        let mut cb = breaker_tripping_on_first_failure(Duration::from_secs(0));

        cb.record_failure();
        assert!(cb.should_allow_request());
        cb.record_success();

        assert_eq!(cb.state(), CircuitState::Closed);
        assert_eq!(cb.failure_count, 0);
    }

    #[test]
    fn test_reopens_on_half_open_failure() {
        let mut cb = breaker_tripping_on_first_failure(Duration::from_secs(0));

        cb.record_failure();
        assert!(cb.should_allow_request());
        assert_eq!(cb.state(), CircuitState::HalfOpen);

        cb.record_failure();
        assert_eq!(cb.state(), CircuitState::Open);
    }

    #[test]
    fn test_success_resets_failure_count() {
        let mut cb = CircuitBreaker::new("test");

        cb.record_failure();
        cb.record_failure();
        assert_eq!(cb.failure_count, 2);

        cb.record_success();
        assert_eq!(cb.failure_count, 0);
    }

    #[test]
    fn test_fallback_validity() {
        let fetched = Instant::now();

        // The default TTL is far longer than any test run, so freshly fetched
        // data is valid regardless of scheduling delays.
        let long_lived = CircuitBreaker::new("test");
        assert!(long_lived.is_fallback_valid(fetched));

        // With a zero TTL no age is strictly below the limit, so the data is
        // already expired.
        let config = CircuitBreakerConfig {
            fallback_cache_ttl: Duration::from_secs(0),
            ..Default::default()
        };
        let expired = CircuitBreaker::with_config("test", config);
        assert!(!expired.is_fallback_valid(fetched));
    }
}