use aspect_core::{Aspect, AspectError, ProceedingJoinPoint};
use parking_lot::Mutex;
use std::any::Any;
use std::sync::Arc;
use std::time::{Duration, Instant};
/// Observable state of a [`CircuitBreakerAspect`].
///
/// Every field is `Eq` (`Instant` is totally ordered), so the enum derives
/// `Eq` as well as `PartialEq`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CircuitState {
    /// Calls pass through; consecutive failures are being counted.
    Closed,
    /// Calls are rejected until `until`; the first call at or after that
    /// instant moves the breaker to [`CircuitState::HalfOpen`].
    Open { until: Instant },
    /// Trial state entered once the open timeout has elapsed; the outcome of
    /// trial calls decides whether the breaker closes or re-opens.
    HalfOpen,
}
/// Aspect that short-circuits calls to a failing join point.
///
/// The breaker starts in [`CircuitState::Closed`] and lets every call through.
/// Once `failure_threshold` consecutive calls have failed, it moves to
/// [`CircuitState::Open`]. It then rejects calls without running them until
/// `timeout` has elapsed.
///
/// The first call after that moves it to [`CircuitState::HalfOpen`]. There, at
/// most `half_open_max_requests` trial calls are admitted (default 1; 0 is
/// treated as 1). When that many trial calls have succeeded, the breaker
/// closes again. Any failure re-opens it for another `timeout`.
///
/// Clones share one state (it lives behind an `Arc<Mutex<_>>`), so every clone
/// observes and drives the same breaker.
#[derive(Clone, Debug)]
pub struct CircuitBreakerAspect {
    state: Arc<Mutex<CircuitBreakerState>>,
}

/// Mutable bookkeeping shared by every clone of a [`CircuitBreakerAspect`].
#[derive(Debug)]
struct CircuitBreakerState {
    circuit_state: CircuitState,
    /// Consecutive failures seen while `Closed`.
    failure_count: usize,
    /// Successful calls seen during the current `HalfOpen` period.
    success_count: usize,
    /// Trial calls admitted during the current `HalfOpen` period.
    half_open_admitted: usize,
    /// Consecutive failures that trip a `Closed` breaker (0 behaves like 1).
    failure_threshold: usize,
    /// How long the breaker stays `Open` before admitting trial calls.
    timeout: Duration,
    /// Trial-call budget while `HalfOpen`, and the number of successes needed
    /// to close again (0 behaves like 1).
    half_open_max_requests: usize,
}

impl CircuitBreakerState {
    /// Effective trial-call budget while half-open.
    ///
    /// Clamped to at least 1 so that a half-open breaker can always admit a
    /// probe instead of rejecting forever.
    fn half_open_limit(&self) -> usize {
        self.half_open_max_requests.max(1)
    }

    /// Moves to `Open` for `timeout`, starting now, and clears the half-open
    /// counters for the next probing period.
    fn trip(&mut self) {
        self.circuit_state = CircuitState::Open {
            until: open_deadline(self.timeout),
        };
        self.success_count = 0;
        self.half_open_admitted = 0;
    }

    /// Moves to `Closed` and clears every counter.
    fn close(&mut self) {
        self.circuit_state = CircuitState::Closed;
        self.failure_count = 0;
        self.success_count = 0;
        self.half_open_admitted = 0;
    }
}

/// Returns `Instant::now() + timeout` without panicking on overflow.
///
/// `Instant + Duration` panics when the sum is not representable, e.g. for a
/// `Duration::MAX` "stay open until reset" timeout. In that case the span is
/// halved until it fits, which yields a far-future deadline instead. The loop
/// terminates because `now + 0` is always representable.
fn open_deadline(timeout: Duration) -> Instant {
    let now = Instant::now();
    let mut span = timeout;
    loop {
        match now.checked_add(span) {
            Some(deadline) => return deadline,
            None => span /= 2,
        }
    }
}

impl CircuitBreakerAspect {
    /// Creates a closed breaker.
    ///
    /// The breaker opens after `failure_threshold` consecutive failures and
    /// stays open for `timeout` before admitting trial calls. It admits one
    /// trial call while half-open; see
    /// [`with_half_open_requests`](Self::with_half_open_requests).
    pub fn new(failure_threshold: usize, timeout: Duration) -> Self {
        Self {
            state: Arc::new(Mutex::new(CircuitBreakerState {
                circuit_state: CircuitState::Closed,
                failure_count: 0,
                success_count: 0,
                half_open_admitted: 0,
                failure_threshold,
                timeout,
                half_open_max_requests: 1,
            })),
        }
    }

    /// Sets how many trial calls a half-open breaker admits.
    ///
    /// The same number of successes is required before the breaker closes
    /// again. 0 is treated as 1.
    ///
    /// Clones share state, so this also affects any clone made before the
    /// call.
    pub fn with_half_open_requests(self, max_requests: usize) -> Self {
        self.state.lock().half_open_max_requests = max_requests;
        self
    }

    /// Returns a snapshot of the current state.
    pub fn state(&self) -> CircuitState {
        self.state.lock().circuit_state.clone()
    }

    /// Forces the breaker back to `Closed` and clears all counters.
    pub fn reset(&self) {
        self.state.lock().close();
    }

    /// Records a successful call.
    ///
    /// NOTE: calls are not tagged with the state they were admitted in. A call
    /// admitted while `Closed` that finishes during `HalfOpen` is therefore
    /// counted like a trial call.
    fn record_success(&self) {
        let mut guard = self.state.lock();
        let state = &mut *guard;
        match state.circuit_state {
            CircuitState::HalfOpen => {
                state.success_count += 1;
                if state.success_count >= state.half_open_limit() {
                    state.close();
                }
            }
            // A success breaks any run of failures.
            CircuitState::Closed => state.failure_count = 0,
            // A call admitted before the breaker tripped finished late. It
            // clears the counters but does not shorten the open window.
            CircuitState::Open { .. } => {
                state.failure_count = 0;
                state.success_count = 0;
            }
        }
    }

    /// Records a failed call, tripping the breaker when appropriate.
    fn record_failure(&self) {
        let mut guard = self.state.lock();
        let state = &mut *guard;
        match state.circuit_state {
            // Any failed trial call sends the breaker straight back to Open.
            CircuitState::HalfOpen => state.trip(),
            CircuitState::Closed => {
                state.failure_count += 1;
                if state.failure_count >= state.failure_threshold {
                    state.trip();
                }
            }
            // Already open: the open window is not extended.
            CircuitState::Open { .. } => {}
        }
    }

    /// Decides whether a call may proceed, updating the state as needed.
    ///
    /// An `Open` breaker whose timeout has elapsed becomes `HalfOpen`, and this
    /// call counts as its first trial call. A `HalfOpen` breaker admits calls
    /// until its trial budget is used up.
    ///
    /// # Errors
    ///
    /// Returns an execution error in two cases:
    /// - while `Open` and the timeout has not yet elapsed;
    /// - while `HalfOpen` and the trial-call budget is exhausted.
    fn should_allow_request(&self) -> Result<(), AspectError> {
        let mut guard = self.state.lock();
        let state = &mut *guard;
        match state.circuit_state {
            CircuitState::Closed => Ok(()),
            CircuitState::Open { until } => {
                if Instant::now() < until {
                    return Err(AspectError::execution(
                        "Circuit breaker is OPEN - failing fast",
                    ));
                }
                state.circuit_state = CircuitState::HalfOpen;
                state.success_count = 0;
                state.half_open_admitted = 1;
                Ok(())
            }
            CircuitState::HalfOpen => {
                // Without this cap every caller would hit the recovering
                // dependency at once, defeating the purpose of half-open.
                if state.half_open_admitted >= state.half_open_limit() {
                    return Err(AspectError::execution(
                        "Circuit breaker is HALF-OPEN - trial request limit reached",
                    ));
                }
                state.half_open_admitted += 1;
                Ok(())
            }
        }
    }
}

/// Drop guard that records a failure on the breaker if it is dropped while
/// still armed, i.e. when the guarded call unwinds instead of returning.
struct FailureOnUnwind<'a> {
    breaker: &'a CircuitBreakerAspect,
    armed: bool,
}

impl Drop for FailureOnUnwind<'_> {
    fn drop(&mut self) {
        if self.armed {
            self.breaker.record_failure();
        }
    }
}

impl Aspect for CircuitBreakerAspect {
    /// Runs the join point through the breaker.
    ///
    /// The call is rejected without running when the breaker is not admitting
    /// calls. Otherwise the join point proceeds and its outcome is recorded.
    ///
    /// A panic inside the join point is recorded as a failure while it
    /// unwinds. This ensures an admitted half-open trial call always reports
    /// back, instead of leaving the trial budget consumed forever.
    fn around(&self, pjp: ProceedingJoinPoint) -> Result<Box<dyn Any>, AspectError> {
        self.should_allow_request()?;
        let mut unwind_guard = FailureOnUnwind {
            breaker: self,
            armed: true,
        };
        let outcome = pjp.proceed();
        // The call returned normally; its outcome is recorded explicitly below.
        unwind_guard.armed = false;
        match outcome {
            Ok(result) => {
                self.record_success();
                Ok(result)
            }
            Err(e) => {
                self.record_failure();
                Err(e)
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a breaker and feeds it `failures` consecutive failures.
    fn breaker_with_failures(
        threshold: usize,
        timeout: Duration,
        failures: usize,
    ) -> CircuitBreakerAspect {
        let breaker = CircuitBreakerAspect::new(threshold, timeout);
        (0..failures).for_each(|_| breaker.record_failure());
        breaker
    }

    /// True when the breaker currently reports `Open`, whatever its deadline.
    fn is_open(breaker: &CircuitBreakerAspect) -> bool {
        matches!(breaker.state(), CircuitState::Open { .. })
    }

    #[test]
    fn test_circuit_breaker_closed_initially() {
        let fresh = breaker_with_failures(3, Duration::from_secs(1), 0);
        assert_eq!(fresh.state(), CircuitState::Closed);
    }

    #[test]
    fn test_circuit_opens_after_threshold() {
        let tripped = breaker_with_failures(3, Duration::from_secs(60), 3);
        assert!(is_open(&tripped), "Circuit should be open");
    }

    #[test]
    fn test_circuit_rejects_when_open() {
        let tripped = breaker_with_failures(1, Duration::from_secs(60), 1);
        assert!(tripped.should_allow_request().is_err());
    }

    #[test]
    fn test_circuit_transitions_to_half_open() {
        let tripped = breaker_with_failures(1, Duration::from_millis(100), 1);
        assert!(is_open(&tripped));
        // Sleep past the open timeout so the next request is let through.
        std::thread::sleep(Duration::from_millis(150));
        assert!(tripped.should_allow_request().is_ok());
        assert_eq!(tripped.state(), CircuitState::HalfOpen);
    }

    #[test]
    fn test_circuit_closes_after_success() {
        let tripped = breaker_with_failures(1, Duration::from_millis(50), 1);
        std::thread::sleep(Duration::from_millis(60));
        tripped.should_allow_request().unwrap();
        tripped.record_success();
        assert_eq!(tripped.state(), CircuitState::Closed);
    }

    #[test]
    fn test_reset() {
        let tripped = breaker_with_failures(2, Duration::from_secs(60), 2);
        assert!(is_open(&tripped));
        tripped.reset();
        assert_eq!(tripped.state(), CircuitState::Closed);
    }
}