use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Mutex;
use std::time::{Duration, Instant};
use crate::error::ResilienceError;
/// The state a [`CircuitBreaker`] is in.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CircuitBreakerState {
    /// Calls are let through and failures are counted.
    Closed,
    /// Calls are rejected without being run until the open period has elapsed.
    Open,
    /// A single trial ("probe") call is allowed through to decide the next state.
    HalfOpen,
}
/// Tuning parameters for a [`CircuitBreaker`].
///
/// Defaults: trip after 5 failures and stay open for 30 seconds.
#[derive(Debug, Clone)]
pub struct CircuitBreakerPolicy {
    /// Number of consecutive failed calls, counted while the breaker is closed,
    /// at which the breaker trips open.
    failure_threshold: u32,
    /// Minimum time the breaker stays open before the next call is admitted
    /// as a half-open probe.
    open_duration: Duration,
}

impl CircuitBreakerPolicy {
    const DEFAULT_FAILURE_THRESHOLD: u32 = 5;
    const DEFAULT_OPEN_DURATION: Duration = Duration::from_secs(30);

    /// Creates a policy with the default threshold (5) and open duration (30s).
    #[must_use]
    pub fn new() -> Self {
        Self {
            failure_threshold: Self::DEFAULT_FAILURE_THRESHOLD,
            open_duration: Self::DEFAULT_OPEN_DURATION,
        }
    }

    /// Returns this policy with the failure threshold replaced.
    ///
    /// A threshold of `0` behaves like `1`: the first failure trips the breaker.
    #[must_use]
    pub fn with_failure_threshold(self, threshold: u32) -> Self {
        Self {
            failure_threshold: threshold,
            ..self
        }
    }

    /// Returns this policy with the open duration replaced.
    #[must_use]
    pub fn with_open_duration(self, duration: Duration) -> Self {
        Self {
            open_duration: duration,
            ..self
        }
    }
}

impl Default for CircuitBreakerPolicy {
    /// Same as [`CircuitBreakerPolicy::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// Error returned by [`CircuitBreaker::call`].
///
/// `E` is the error type of the wrapped operation.
#[derive(Debug)]
pub enum CircuitBreakerError<E> {
    /// The breaker is open; the operation was not attempted.
    CircuitOpen,
    /// The breaker is half-open and another caller holds the single probe slot;
    /// the operation was not attempted.
    ProbeInFlight,
    /// The operation ran and returned this error.
    DownstreamFailed(E),
    /// The breaker's own bookkeeping failed (e.g. its mutex was poisoned).
    Internal(ResilienceError),
}

impl<E> std::fmt::Display for CircuitBreakerError<E>
where
    E: std::fmt::Display,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            Self::CircuitOpen => "circuit breaker is open",
            Self::ProbeInFlight => "circuit breaker is half-open and a probe is in flight",
            // Variants carrying a payload format it inline and return directly.
            Self::DownstreamFailed(inner) => {
                return write!(f, "downstream call failed: {inner}");
            }
            Self::Internal(inner) => return write!(f, "internal error: {inner}"),
        };
        f.write_str(message)
    }
}

impl<E> std::error::Error for CircuitBreakerError<E> where E: std::fmt::Debug + std::fmt::Display {}
/// Mutable breaker bookkeeping; held inside `CircuitBreaker::inner`'s `Mutex`.
#[derive(Debug)]
struct Internal {
    /// Current state. The `Open` -> `HalfOpen` move is applied lazily when a call arrives.
    state: CircuitBreakerState,
    /// Consecutive failures observed while `Closed`; a success resets it to zero.
    failure_count: u32,
    /// When the breaker last tripped to `Open`; cleared when a successful probe closes it.
    opened_at: Option<Instant>,
}
/// A circuit breaker that can be shared across threads: its state lives behind a
/// `Mutex` and an `AtomicBool`, so every method takes `&self`.
#[derive(Debug)]
pub struct CircuitBreaker {
    /// Threshold and open duration supplied at construction.
    policy: CircuitBreakerPolicy,
    /// State-machine bookkeeping.
    inner: Mutex<Internal>,
    /// `true` while a half-open probe call is running, so at most one probe runs at a time.
    probe_inflight: AtomicBool,
}
impl CircuitBreaker {
#[must_use]
pub fn new(policy: CircuitBreakerPolicy) -> Self {
Self {
policy,
inner: Mutex::new(Internal {
state: CircuitBreakerState::Closed,
failure_count: 0,
opened_at: None,
}),
probe_inflight: AtomicBool::new(false),
}
}
pub fn state(&self) -> Result<CircuitBreakerState, ResilienceError> {
let guard = self
.inner
.lock()
.map_err(|_| ResilienceError::Internal("circuit-breaker mutex poisoned".into()))?;
Ok(guard.state)
}
pub fn call<T, E, F>(&self, f: F) -> Result<T, CircuitBreakerError<E>>
where
F: FnOnce() -> Result<T, E>,
{
let allow_probe = {
let mut guard = self.inner.lock().map_err(|_| {
CircuitBreakerError::Internal(ResilienceError::Internal(
"circuit-breaker mutex poisoned".into(),
))
})?;
if guard.state == CircuitBreakerState::Open {
if let Some(opened_at) = guard.opened_at {
if opened_at.elapsed() >= self.policy.open_duration {
guard.state = CircuitBreakerState::HalfOpen;
}
}
}
match guard.state {
CircuitBreakerState::Closed => false,
CircuitBreakerState::Open => return Err(CircuitBreakerError::CircuitOpen),
CircuitBreakerState::HalfOpen => true,
}
};
if allow_probe {
if self
.probe_inflight
.compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
.is_err()
{
return Err(CircuitBreakerError::ProbeInFlight);
}
}
let outcome = f();
let mut guard = self.inner.lock().map_err(|_| {
self.probe_inflight.store(false, Ordering::SeqCst);
CircuitBreakerError::Internal(ResilienceError::Internal(
"circuit-breaker mutex poisoned".into(),
))
})?;
match (&outcome, guard.state) {
(Ok(_), CircuitBreakerState::HalfOpen) => {
guard.state = CircuitBreakerState::Closed;
guard.failure_count = 0;
guard.opened_at = None;
self.probe_inflight.store(false, Ordering::SeqCst);
}
(Err(_), CircuitBreakerState::HalfOpen) => {
guard.state = CircuitBreakerState::Open;
guard.opened_at = Some(Instant::now());
self.probe_inflight.store(false, Ordering::SeqCst);
}
(Ok(_), CircuitBreakerState::Closed) => {
guard.failure_count = 0;
}
(Err(_), CircuitBreakerState::Closed) => {
guard.failure_count += 1;
if guard.failure_count >= self.policy.failure_threshold {
guard.state = CircuitBreakerState::Open;
guard.opened_at = Some(Instant::now());
}
}
(_, CircuitBreakerState::Open) => {}
}
drop(guard);
outcome.map_err(CircuitBreakerError::DownstreamFailed)
}
}