Skip to main content

chio_guards/external/
circuit_breaker.rs

1//! Three-state circuit breaker for external service calls.
2//!
3//! The breaker transitions between three states:
4//!
5//! * [`CircuitState::Closed`] -- normal operation. Failures are counted
6//!   inside a sliding window; once the count reaches
7//!   [`CircuitBreakerConfig::failure_threshold`], the breaker opens.
8//! * [`CircuitState::Open`] -- fail-fast. Calls are short-circuited for at
9//!   least [`CircuitBreakerConfig::reset_timeout`].
//! * [`CircuitState::HalfOpen`] -- probing recovery. Calls are admitted
//!   again as trial calls (the breaker itself does not cap how many); after
//!   [`CircuitBreakerConfig::success_threshold`] consecutive successes the
//!   breaker closes. Any failure reopens it.
14//!
15//! The breaker uses a [`Clock`] abstraction for monotonic time so tests can
16//! drive transitions via [`tokio::time::pause`] + `advance` without
17//! wall-clock sleep.
18
19use std::sync::Arc;
20use std::sync::Mutex;
21use std::time::Duration;
22
23use tokio::time::Instant;
24
25use super::cache::{Clock, TokioClock};
26
/// The three positions a circuit breaker can be in.
///
/// The module-level documentation describes how the breaker moves between
/// them.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum CircuitState {
    /// Healthy: calls are passed through to the downstream service.
    Closed,
    /// Tripped: calls are rejected without reaching the downstream service
    /// until the reset timeout has elapsed.
    Open,
    /// Probing: calls are let through again as trials to find out whether
    /// the external dependency has recovered.
    HalfOpen,
}
40
/// Thresholds and timings that govern a [`CircuitBreaker`].
///
/// The values are fixed when the breaker is constructed.
#[derive(Debug, Clone)]
pub struct CircuitBreakerConfig {
    /// Number of failures recorded within `failure_window` at which a
    /// Closed breaker opens.
    pub failure_threshold: u32,
    /// Length of the rolling window used for failure counting while Closed.
    /// A failure whose age has reached this duration no longer counts.
    pub failure_window: Duration,
    /// Number of consecutive successes a HalfOpen breaker must observe
    /// before it returns to Closed.
    pub success_threshold: u32,
    /// Minimum time the breaker stays Open before it may move to HalfOpen
    /// and admit probe calls.
    pub reset_timeout: Duration,
}

impl Default for CircuitBreakerConfig {
    /// Defaults: open after 5 failures within 60 seconds, stay open for 30
    /// seconds, and close again after 2 consecutive successful probes.
    fn default() -> Self {
        let failure_window = Duration::from_secs(60);
        let reset_timeout = Duration::from_secs(30);
        CircuitBreakerConfig {
            failure_threshold: 5,
            failure_window,
            success_threshold: 2,
            reset_timeout,
        }
    }
}
66
/// Thread-safe three-state circuit breaker.
///
/// The mutable state machine lives behind a [`Mutex`]; the configuration and
/// the clock are fixed at construction.
pub struct CircuitBreaker {
    /// Mutable breaker state (current state, failure history, probe
    /// progress), guarded by a mutex.
    inner: Mutex<CircuitInner>,
    /// Thresholds and timeouts supplied when the breaker was created.
    config: CircuitBreakerConfig,
    /// Time source queried for the current instant on every operation that
    /// depends on elapsed time.
    clock: Arc<dyn Clock>,
}
73
/// Mutable state of a [`CircuitBreaker`], kept behind the breaker's mutex.
#[derive(Debug)]
struct CircuitInner {
    /// Current position in the Closed / Open / HalfOpen state machine.
    state: CircuitState,
    /// Timestamps of recent failures in the Closed state.
    failures: Vec<Instant>,
    /// Count of consecutive successes in the HalfOpen state.
    half_open_successes: u32,
    /// When the breaker was last opened, used to schedule the HalfOpen probe.
    /// `None` for a freshly constructed breaker and after it has been closed
    /// or reset.
    opened_at: Option<Instant>,
}
84
85impl CircuitBreaker {
86    /// Create a new breaker with the given configuration and the default
87    /// ([`TokioClock`]) clock.
88    pub fn new(config: CircuitBreakerConfig) -> Self {
89        Self::with_clock(config, Arc::new(TokioClock))
90    }
91
92    /// Create a breaker with a custom clock (primarily for tests).
93    pub fn with_clock(config: CircuitBreakerConfig, clock: Arc<dyn Clock>) -> Self {
94        Self {
95            inner: Mutex::new(CircuitInner {
96                state: CircuitState::Closed,
97                failures: Vec::new(),
98                half_open_successes: 0,
99                opened_at: None,
100            }),
101            config,
102            clock,
103        }
104    }
105
106    /// Current configuration.
107    pub fn config(&self) -> &CircuitBreakerConfig {
108        &self.config
109    }
110
111    /// Current state. Transitions from Open to HalfOpen happen lazily on
112    /// observation, so this call also advances state when appropriate.
113    pub fn current_state(&self) -> CircuitState {
114        let now = self.clock.now();
115        let Ok(mut inner) = self.inner.lock() else {
116            return CircuitState::Open;
117        };
118        self.tick(&mut inner, now);
119        inner.state
120    }
121
122    /// Ask whether a call is currently allowed through. `true` means the
123    /// caller should invoke the downstream service; `false` means the
124    /// breaker is Open and the caller must fail fast.
125    pub fn allow_call(&self) -> bool {
126        let now = self.clock.now();
127        let Ok(mut inner) = self.inner.lock() else {
128            return false;
129        };
130        self.tick(&mut inner, now);
131        !matches!(inner.state, CircuitState::Open)
132    }
133
134    /// Record a successful downstream call.
135    pub fn record_success(&self) {
136        let now = self.clock.now();
137        let Ok(mut inner) = self.inner.lock() else {
138            return;
139        };
140        self.tick(&mut inner, now);
141        match inner.state {
142            CircuitState::Closed => {
143                inner.failures.clear();
144            }
145            CircuitState::HalfOpen => {
146                inner.half_open_successes = inner.half_open_successes.saturating_add(1);
147                if inner.half_open_successes >= self.config.success_threshold {
148                    inner.state = CircuitState::Closed;
149                    inner.failures.clear();
150                    inner.half_open_successes = 0;
151                    inner.opened_at = None;
152                }
153            }
154            CircuitState::Open => {
155                // Shouldn't happen -- Open calls are rejected before reaching
156                // the downstream. Treat as a no-op.
157            }
158        }
159    }
160
161    /// Record a failed downstream call.
162    pub fn record_failure(&self) {
163        let now = self.clock.now();
164        let Ok(mut inner) = self.inner.lock() else {
165            return;
166        };
167        self.tick(&mut inner, now);
168        match inner.state {
169            CircuitState::Closed => {
170                inner.failures.push(now);
171                self.drop_stale_failures(&mut inner, now);
172                if inner.failures.len() as u32 >= self.config.failure_threshold {
173                    inner.state = CircuitState::Open;
174                    inner.opened_at = Some(now);
175                    inner.failures.clear();
176                }
177            }
178            CircuitState::HalfOpen => {
179                inner.state = CircuitState::Open;
180                inner.opened_at = Some(now);
181                inner.half_open_successes = 0;
182            }
183            CircuitState::Open => {
184                // Re-arm the open timer.
185                inner.opened_at = Some(now);
186            }
187        }
188    }
189
190    /// Reset to Closed with no recorded failures. Useful for explicit
191    /// operator intervention or tests.
192    pub fn reset(&self) {
193        if let Ok(mut inner) = self.inner.lock() {
194            inner.state = CircuitState::Closed;
195            inner.failures.clear();
196            inner.half_open_successes = 0;
197            inner.opened_at = None;
198        }
199    }
200
201    fn tick(&self, inner: &mut CircuitInner, now: Instant) {
202        match inner.state {
203            CircuitState::Open => {
204                if let Some(opened) = inner.opened_at {
205                    if now.duration_since(opened) >= self.config.reset_timeout {
206                        inner.state = CircuitState::HalfOpen;
207                        inner.half_open_successes = 0;
208                    }
209                }
210            }
211            CircuitState::Closed => {
212                self.drop_stale_failures(inner, now);
213            }
214            CircuitState::HalfOpen => {}
215        }
216    }
217
218    fn drop_stale_failures(&self, inner: &mut CircuitInner, now: Instant) {
219        let window = self.config.failure_window;
220        inner.failures.retain(|ts| now.duration_since(*ts) < window);
221    }
222}
223
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixture configuration: 60 s failure window, two successes to close,
    /// 10 s reset timeout, and a caller-chosen failure threshold.
    fn config(failure_threshold: u32) -> CircuitBreakerConfig {
        let failure_window = Duration::from_secs(60);
        let reset_timeout = Duration::from_secs(10);
        CircuitBreakerConfig {
            failure_threshold,
            failure_window,
            success_threshold: 2,
            reset_timeout,
        }
    }

    /// Report `count` failures to `breaker`, one after another.
    fn fail(breaker: &CircuitBreaker, count: usize) {
        (0..count).for_each(|_| breaker.record_failure());
    }

    /// Advance the paused tokio clock by 11 s, one second past the fixture's
    /// reset timeout.
    async fn wait_out_reset_timeout() {
        tokio::time::advance(Duration::from_secs(11)).await;
    }

    // A fresh breaker is Closed and lets calls through.
    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn starts_closed() {
        let breaker = CircuitBreaker::new(config(5));
        assert_eq!(breaker.current_state(), CircuitState::Closed);
        assert!(breaker.allow_call());
    }

    // The breaker stays Closed below the threshold and opens on reaching it.
    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn opens_after_threshold_failures() {
        let breaker = CircuitBreaker::new(config(3));
        fail(&breaker, 2);
        assert_eq!(breaker.current_state(), CircuitState::Closed);
        fail(&breaker, 1);
        assert_eq!(breaker.current_state(), CircuitState::Open);
        assert!(!breaker.allow_call());
    }

    // Once the reset timeout has passed, an Open breaker becomes HalfOpen
    // and admits calls again.
    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn transitions_to_half_open_after_reset_timeout() {
        let breaker = CircuitBreaker::new(config(2));
        fail(&breaker, 2);
        assert_eq!(breaker.current_state(), CircuitState::Open);
        wait_out_reset_timeout().await;
        assert_eq!(breaker.current_state(), CircuitState::HalfOpen);
        assert!(breaker.allow_call());
    }

    // Two successful probes (the fixture's success threshold) close the
    // breaker.
    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn half_open_closes_after_success_threshold() {
        let breaker = CircuitBreaker::new(config(2));
        fail(&breaker, 2);
        wait_out_reset_timeout().await;
        assert_eq!(breaker.current_state(), CircuitState::HalfOpen);
        for _ in 0..2 {
            breaker.record_success();
        }
        assert_eq!(breaker.current_state(), CircuitState::Closed);
    }

    // A single failed probe sends a HalfOpen breaker straight back to Open.
    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn half_open_failure_reopens() {
        let breaker = CircuitBreaker::new(config(2));
        fail(&breaker, 2);
        wait_out_reset_timeout().await;
        assert_eq!(breaker.current_state(), CircuitState::HalfOpen);
        fail(&breaker, 1);
        assert_eq!(breaker.current_state(), CircuitState::Open);
    }

    // Failures older than the window do not count towards the threshold.
    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn stale_failures_are_forgotten() {
        let short_window = CircuitBreakerConfig {
            failure_threshold: 3,
            failure_window: Duration::from_secs(5),
            success_threshold: 1,
            reset_timeout: Duration::from_secs(10),
        };
        let breaker = CircuitBreaker::new(short_window);
        fail(&breaker, 2);
        tokio::time::advance(Duration::from_secs(6)).await;
        fail(&breaker, 2);
        // Only 2 failures inside the window -> still closed.
        assert_eq!(breaker.current_state(), CircuitState::Closed);
    }
}