Skip to main content

atomr_core/pattern/
circuit_breaker.rs

1//! Circuit breaker.
2
3use std::future::Future;
4use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
5use std::sync::Arc;
6use std::time::Duration;
7
/// Observable state of a [`CircuitBreaker`], as reported by `CircuitBreaker::state`.
#[non_exhaustive]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CircuitBreakerState {
    /// Calls are executed normally.
    Closed,
    /// Too many consecutive failures; calls are rejected without running.
    Open,
    /// The reset timeout has elapsed since the breaker opened.
    HalfOpen,
}
15
16pub struct CircuitBreaker {
17    max_failures: u32,
18    call_timeout: Duration,
19    reset_timeout: Duration,
20    failures: AtomicU32,
21    opened_at_ns: AtomicU64,
22    // packed state: 0=closed, 1=open, 2=half-open
23    state: AtomicU32,
24}
25
26impl CircuitBreaker {
27    pub fn new(max_failures: u32, call_timeout: Duration, reset_timeout: Duration) -> Arc<Self> {
28        Arc::new(Self {
29            max_failures,
30            call_timeout,
31            reset_timeout,
32            failures: AtomicU32::new(0),
33            opened_at_ns: AtomicU64::new(0),
34            state: AtomicU32::new(0),
35        })
36    }
37
38    pub fn state(&self) -> CircuitBreakerState {
39        match self.state.load(Ordering::Acquire) {
40            0 => CircuitBreakerState::Closed,
41            1 => {
42                // Compare elapsed since the breaker opened (epoch in
43                // ns since process start) — Phase 3.4 fix; the
44                // previous comparison used `Instant::now().elapsed()`
45                // which is always 0 and never transitioned to half-open.
46                let now_ns = self.elapsed_ns();
47                let opened_ns = self.opened_at_ns.load(Ordering::Acquire);
48                if opened_ns > 0 && now_ns.saturating_sub(opened_ns) >= self.reset_timeout.as_nanos() as u64 {
49                    CircuitBreakerState::HalfOpen
50                } else {
51                    CircuitBreakerState::Open
52                }
53            }
54            _ => CircuitBreakerState::HalfOpen,
55        }
56    }
57
58    fn elapsed_ns(&self) -> u64 {
59        // Stable epoch chosen at first call of `record_failure`. We
60        // approximate via `std::time::SystemTime` + `UNIX_EPOCH` so
61        // both `record_failure` and `state()` agree.
62        std::time::SystemTime::now()
63            .duration_since(std::time::UNIX_EPOCH)
64            .map(|d| d.as_nanos() as u64)
65            .unwrap_or(0)
66    }
67
68    pub async fn call<F, Fut, T, E>(&self, f: F) -> Result<T, CircuitBreakerError<E>>
69    where
70        F: FnOnce() -> Fut,
71        Fut: Future<Output = Result<T, E>>,
72    {
73        let st = self.state.load(Ordering::Acquire);
74        if st == 1 {
75            return Err(CircuitBreakerError::Open);
76        }
77        let res = tokio::time::timeout(self.call_timeout, f()).await;
78        match res {
79            Ok(Ok(v)) => {
80                self.failures.store(0, Ordering::Release);
81                self.state.store(0, Ordering::Release);
82                Ok(v)
83            }
84            Ok(Err(e)) => {
85                self.record_failure();
86                Err(CircuitBreakerError::Inner(e))
87            }
88            Err(_) => {
89                self.record_failure();
90                Err(CircuitBreakerError::Timeout)
91            }
92        }
93    }
94
95    fn record_failure(&self) {
96        let n = self.failures.fetch_add(1, Ordering::AcqRel) + 1;
97        if n >= self.max_failures {
98            self.state.store(1, Ordering::Release);
99            self.opened_at_ns.store(self.elapsed_ns(), Ordering::Release);
100        }
101    }
102}
103
104#[derive(Debug, thiserror::Error)]
105#[non_exhaustive]
106pub enum CircuitBreakerError<E> {
107    #[error("circuit breaker is open")]
108    Open,
109    #[error("call timed out")]
110    Timeout,
111    #[error(transparent)]
112    Inner(E),
113}
114
#[cfg(test)]
mod tests {
    use super::*;

    /// `max_failures` consecutive errors open the breaker, and the next call
    /// is rejected without running.
    #[tokio::test]
    async fn opens_after_max_failures() {
        // A long reset timeout keeps the breaker open for the whole test,
        // even on a slow machine.
        let cb = CircuitBreaker::new(2, Duration::from_secs(1), Duration::from_secs(60));
        for _ in 0..2 {
            let _ = cb.call(|| async { Err::<(), _>(1) }).await;
        }
        let res: Result<(), _> = cb.call(|| async { Ok::<(), u32>(()) }).await;
        assert!(matches!(res, Err(CircuitBreakerError::Open)));
    }

    /// A successful callee's value is returned unchanged.
    #[tokio::test]
    async fn passes_success_through() {
        let cb = CircuitBreaker::new(2, Duration::from_secs(1), Duration::from_secs(60));
        let res = cb.call(|| async { Ok::<_, u32>(42) }).await;
        assert!(matches!(res, Ok(42)));
        assert_eq!(cb.state(), CircuitBreakerState::Closed);
    }

    /// A callee error is wrapped in `Inner` and, below the threshold, leaves
    /// the breaker closed.
    #[tokio::test]
    async fn wraps_callee_error() {
        let cb = CircuitBreaker::new(2, Duration::from_secs(1), Duration::from_secs(60));
        let res = cb.call(|| async { Err::<(), _>(7u32) }).await;
        assert!(matches!(res, Err(CircuitBreakerError::Inner(7))));
        assert_eq!(cb.state(), CircuitBreakerState::Closed);
    }

    /// A callee slower than `call_timeout` yields `Timeout`.
    #[tokio::test]
    async fn times_out_slow_calls() {
        let cb = CircuitBreaker::new(2, Duration::from_millis(10), Duration::from_secs(60));
        let res = cb
            .call(|| async {
                tokio::time::sleep(Duration::from_secs(5)).await;
                Ok::<(), u32>(())
            })
            .await;
        assert!(matches!(res, Err(CircuitBreakerError::Timeout)));
    }

    /// Failures must be consecutive: a success in between resets the count.
    #[tokio::test]
    async fn success_resets_failure_count() {
        let cb = CircuitBreaker::new(2, Duration::from_secs(1), Duration::from_secs(60));
        let _ = cb.call(|| async { Err::<(), _>(1) }).await;
        let _ = cb.call(|| async { Ok::<(), i32>(()) }).await;
        let _ = cb.call(|| async { Err::<(), _>(1) }).await;
        assert_eq!(cb.state(), CircuitBreakerState::Closed);
        let res = cb.call(|| async { Ok::<(), i32>(()) }).await;
        assert!(matches!(res, Ok(())));
    }

    /// `state()` reports `Open` before the reset timeout has elapsed.
    #[tokio::test]
    async fn reports_open_state() {
        let cb = CircuitBreaker::new(1, Duration::from_secs(1), Duration::from_secs(60));
        let _ = cb.call(|| async { Err::<(), _>(1) }).await;
        assert_eq!(cb.state(), CircuitBreakerState::Open);
    }
}