Skip to main content

mx_core/resilience/
circuit_breaker.rs

1//! Circuit breaker trait and state types for fault tolerance.
2
3use std::time::Duration;
4
/// The three states a circuit breaker can be in.
///
/// Each variant carries an explicit `u8` discriminant so that a state can be
/// written to a plain byte and recovered again through `From<u8>`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum CircuitState {
    /// Normal operation: requests flow through.
    Closed = 0,
    /// Tripped: requests are rejected immediately.
    Open = 1,
    /// Probing: the circuit is testing whether the protected service recovered.
    HalfOpen = 2,
}

impl From<u8> for CircuitState {
    /// Decodes a raw discriminant byte.
    ///
    /// `1` and `2` map to [`CircuitState::Open`] and
    /// [`CircuitState::HalfOpen`]; `0` and every unrecognised value map to
    /// [`CircuitState::Closed`].
    fn from(value: u8) -> Self {
        if value == Self::Open as u8 {
            Self::Open
        } else if value == Self::HalfOpen as u8 {
            Self::HalfOpen
        } else {
            // Covers the `Closed` discriminant as well as out-of-range bytes.
            Self::Closed
        }
    }
}
27
/// Thresholds and timing that drive a circuit breaker.
#[derive(Debug, Clone)]
pub struct CircuitBreakerConfig {
    /// Number of consecutive failures before the circuit opens.
    pub failure_threshold: u32,
    /// How long to wait before moving from `Open` to `HalfOpen`.
    pub recovery_timeout: Duration,
    /// Successes required while half-open before the circuit closes (default: 1).
    pub success_threshold: u32,
}

impl Default for CircuitBreakerConfig {
    /// Five failures to open, a 30 second recovery timeout, one success to close.
    fn default() -> Self {
        Self::new(5, Duration::from_secs(30))
    }
}

impl CircuitBreakerConfig {
    /// Builds a config from a failure threshold and recovery timeout.
    ///
    /// `success_threshold` starts at `1`; change it with
    /// [`with_success_threshold`](Self::with_success_threshold).
    pub fn new(failure_threshold: u32, recovery_timeout: Duration) -> Self {
        let success_threshold = 1;
        Self {
            failure_threshold,
            recovery_timeout,
            success_threshold,
        }
    }

    /// Returns this config with the half-open success requirement replaced.
    pub fn with_success_threshold(self, threshold: u32) -> Self {
        Self {
            success_threshold: threshold,
            ..self
        }
    }
}
65
66/// Trait for circuit breaker pattern implementation.
67///
68/// The circuit breaker prevents cascading failures by:
69/// 1. Tracking failure rates
70/// 2. Opening the circuit when failures exceed threshold
71/// 3. Periodically testing if the service has recovered
72/// 4. Closing the circuit when recovery is confirmed
73pub trait CircuitBreaker: Send + Sync {
74    /// Returns the current circuit state.
75    fn state(&self) -> CircuitState;
76
77    /// Returns whether the circuit is currently open (rejecting requests).
78    fn is_open(&self) -> bool {
79        self.state() == CircuitState::Open
80    }
81
82    /// Returns whether the circuit allows requests through.
83    fn allows_request(&self) -> bool {
84        self.state() != CircuitState::Open
85    }
86
87    /// Records a successful operation.
88    fn record_success(&self);
89
90    /// Records a failed operation.
91    fn record_failure(&self);
92
93    /// Manually opens the circuit.
94    fn trip(&self);
95
96    /// Manually closes the circuit.
97    fn reset(&self);
98
99    /// Returns the time until the circuit will transition to half-open.
100    /// Returns `None` if the circuit is not open.
101    fn time_until_half_open(&self) -> Option<Duration>;
102
103    /// Returns the current failure count.
104    fn failure_count(&self) -> u32;
105}
106
107/// Thread-safe circuit breaker using atomic operations.
108pub struct AtomicCircuitBreaker {
109    pub state_val: std::sync::atomic::AtomicU8,
110    failure_count_val: std::sync::atomic::AtomicU32,
111    success_count_val: std::sync::atomic::AtomicU32,
112    last_failure: parking_lot::Mutex<Option<std::time::Instant>>,
113    config: CircuitBreakerConfig,
114}
115
116impl AtomicCircuitBreaker {
117    /// Creates a new circuit breaker with the given configuration.
118    pub fn new(config: CircuitBreakerConfig) -> Self {
119        Self {
120            state_val: std::sync::atomic::AtomicU8::new(CircuitState::Closed as u8),
121            failure_count_val: std::sync::atomic::AtomicU32::new(0),
122            success_count_val: std::sync::atomic::AtomicU32::new(0),
123            last_failure: parking_lot::Mutex::new(None),
124            config,
125        }
126    }
127
128    /// Creates a circuit breaker with default configuration.
129    pub fn with_defaults() -> Self {
130        Self::new(CircuitBreakerConfig::default())
131    }
132}
133
134impl CircuitBreaker for AtomicCircuitBreaker {
135    fn state(&self) -> CircuitState {
136        use std::sync::atomic::Ordering;
137
138        let state = CircuitState::from(self.state_val.load(Ordering::Acquire));
139
140        // Check if we should transition from Open to HalfOpen
141        if state == CircuitState::Open
142            && let Some(last) = *self.last_failure.lock()
143            && last.elapsed() >= self.config.recovery_timeout
144            && self
145                .state_val
146                .compare_exchange(
147                    CircuitState::Open as u8,
148                    CircuitState::HalfOpen as u8,
149                    Ordering::AcqRel,
150                    Ordering::Acquire,
151                )
152                .is_ok()
153        {
154            self.success_count_val.store(0, Ordering::Release);
155            return CircuitState::HalfOpen;
156        }
157
158        state
159    }
160
161    fn record_success(&self) {
162        use std::sync::atomic::Ordering;
163
164        let state = CircuitState::from(self.state_val.load(Ordering::Acquire));
165
166        match state {
167            CircuitState::Closed => {
168                self.failure_count_val.store(0, Ordering::Release);
169            }
170            CircuitState::HalfOpen => {
171                let successes = self.success_count_val.fetch_add(1, Ordering::AcqRel) + 1;
172                if successes >= self.config.success_threshold {
173                    self.state_val
174                        .store(CircuitState::Closed as u8, Ordering::Release);
175                    self.failure_count_val.store(0, Ordering::Release);
176                    self.success_count_val.store(0, Ordering::Release);
177                    *self.last_failure.lock() = None;
178                }
179            }
180            CircuitState::Open => {}
181        }
182    }
183
184    fn record_failure(&self) {
185        use std::sync::atomic::Ordering;
186
187        *self.last_failure.lock() = Some(std::time::Instant::now());
188
189        let state = CircuitState::from(self.state_val.load(Ordering::Acquire));
190
191        match state {
192            CircuitState::Closed => {
193                let count = self.failure_count_val.fetch_add(1, Ordering::AcqRel) + 1;
194                if count >= self.config.failure_threshold {
195                    self.state_val
196                        .store(CircuitState::Open as u8, Ordering::Release);
197                }
198            }
199            CircuitState::HalfOpen => {
200                self.state_val
201                    .store(CircuitState::Open as u8, Ordering::Release);
202                self.success_count_val.store(0, Ordering::Release);
203            }
204            CircuitState::Open => {}
205        }
206    }
207
208    fn trip(&self) {
209        use std::sync::atomic::Ordering;
210        self.state_val
211            .store(CircuitState::Open as u8, Ordering::Release);
212        *self.last_failure.lock() = Some(std::time::Instant::now());
213    }
214
215    fn reset(&self) {
216        use std::sync::atomic::Ordering;
217        self.state_val
218            .store(CircuitState::Closed as u8, Ordering::Release);
219        self.failure_count_val.store(0, Ordering::Release);
220        self.success_count_val.store(0, Ordering::Release);
221        *self.last_failure.lock() = None;
222    }
223
224    fn time_until_half_open(&self) -> Option<Duration> {
225        use std::sync::atomic::Ordering;
226
227        let state = CircuitState::from(self.state_val.load(Ordering::Acquire));
228        if state != CircuitState::Open {
229            return None;
230        }
231
232        let last = (*self.last_failure.lock())?;
233        let elapsed = last.elapsed();
234
235        if elapsed >= self.config.recovery_timeout {
236            Some(Duration::ZERO)
237        } else {
238            Some(self.config.recovery_timeout - elapsed)
239        }
240    }
241
242    fn failure_count(&self) -> u32 {
243        self.failure_count_val
244            .load(std::sync::atomic::Ordering::Acquire)
245    }
246}
247
/// Unit tests for [`AtomicCircuitBreaker`].
///
/// Tests that need the recovery timeout to elapse use a 10 ms timeout and a
/// 15 ms sleep; the rest avoid sleeping entirely.
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    #[test]
    fn test_starts_closed() {
        let cb = AtomicCircuitBreaker::with_defaults();
        assert_eq!(cb.state(), CircuitState::Closed);
        assert!(cb.allows_request());
    }

    #[test]
    fn test_opens_after_threshold() {
        let config = CircuitBreakerConfig::new(3, Duration::from_secs(30));
        let cb = AtomicCircuitBreaker::new(config);

        cb.record_failure();
        cb.record_failure();
        assert_eq!(cb.state(), CircuitState::Closed);

        cb.record_failure();
        assert_eq!(cb.state(), CircuitState::Open);
        assert!(!cb.allows_request());
    }

    #[test]
    fn test_transitions_to_half_open() {
        let config = CircuitBreakerConfig::new(1, Duration::from_millis(10));
        let cb = AtomicCircuitBreaker::new(config);

        cb.record_failure();
        assert_eq!(cb.state(), CircuitState::Open);

        thread::sleep(Duration::from_millis(15));
        assert_eq!(cb.state(), CircuitState::HalfOpen);
    }

    #[test]
    fn test_closes_on_success_in_half_open() {
        let config = CircuitBreakerConfig::new(1, Duration::from_millis(10));
        let cb = AtomicCircuitBreaker::new(config);

        cb.record_failure();
        thread::sleep(Duration::from_millis(15));
        assert_eq!(cb.state(), CircuitState::HalfOpen);

        cb.record_success();
        assert_eq!(cb.state(), CircuitState::Closed);
    }

    #[test]
    fn test_success_threshold() {
        let config =
            CircuitBreakerConfig::new(1, Duration::from_millis(10)).with_success_threshold(3);
        let cb = AtomicCircuitBreaker::new(config);

        cb.record_failure();
        thread::sleep(Duration::from_millis(15));
        assert_eq!(cb.state(), CircuitState::HalfOpen);

        cb.record_success();
        assert_eq!(cb.state(), CircuitState::HalfOpen); // Need 3 successes

        cb.record_success();
        assert_eq!(cb.state(), CircuitState::HalfOpen); // Need 1 more

        cb.record_success();
        assert_eq!(cb.state(), CircuitState::Closed); // Done
    }

    #[test]
    fn test_failure_in_half_open_reopens() {
        use std::sync::atomic::Ordering;

        let config = CircuitBreakerConfig::new(1, Duration::from_secs(30));
        let cb = AtomicCircuitBreaker::new(config);

        // Enter HalfOpen directly so the test does not have to sleep through
        // a recovery timeout.
        cb.state_val
            .store(CircuitState::HalfOpen as u8, Ordering::Release);

        cb.record_failure();
        assert_eq!(cb.state(), CircuitState::Open);
        assert!(!cb.allows_request());
    }

    #[test]
    fn test_trip_and_reset() {
        let cb = AtomicCircuitBreaker::with_defaults();
        cb.trip();
        assert_eq!(cb.state(), CircuitState::Open);

        cb.reset();
        assert_eq!(cb.state(), CircuitState::Closed);
    }

    #[test]
    fn test_time_until_half_open() {
        let cb = AtomicCircuitBreaker::with_defaults();
        assert_eq!(cb.time_until_half_open(), None); // Not open yet

        cb.trip();
        let remaining = cb
            .time_until_half_open()
            .expect("an open circuit reports a remaining time");
        assert!(remaining <= Duration::from_secs(30));

        cb.reset();
        assert_eq!(cb.time_until_half_open(), None); // Closed again
    }

    #[test]
    fn test_success_resets_failure_count() {
        let cb = AtomicCircuitBreaker::with_defaults();

        cb.record_failure();
        cb.record_failure();
        assert_eq!(cb.failure_count(), 2);

        cb.record_success();
        assert_eq!(cb.failure_count(), 0);
        assert_eq!(cb.state(), CircuitState::Closed);
    }

    #[test]
    fn test_thread_safety() {
        use std::sync::Arc;

        // 10 threads x 10 failures reach the threshold of 100 exactly.
        let config = CircuitBreakerConfig::new(100, Duration::from_secs(30));
        let cb = Arc::new(AtomicCircuitBreaker::new(config));

        let handles: Vec<_> = (0..10)
            .map(|_| {
                let cb = Arc::clone(&cb);
                thread::spawn(move || {
                    for _ in 0..10 {
                        cb.record_failure();
                    }
                })
            })
            .collect();

        for h in handles {
            h.join().unwrap();
        }

        assert_eq!(cb.state(), CircuitState::Open);
    }
}
353}