// mx_core/resilience/circuit_breaker.rs
use std::time::Duration;

/// State of a circuit breaker.
///
/// The explicit `u8` discriminants let the state be stored in an atomic byte
/// and decoded again through the `From<u8>` impl below.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum CircuitState {
    /// Requests are allowed and failures are counted toward the trip threshold.
    Closed = 0,
    /// Requests are rejected until the recovery timeout has elapsed.
    Open = 1,
    /// Requests are allowed; successes count toward closing, a failure re-opens.
    HalfOpen = 2,
}

impl From<u8> for CircuitState {
    /// Decodes a raw discriminant. Unknown values decode as `Closed`.
    fn from(raw: u8) -> Self {
        match raw {
            1 => CircuitState::Open,
            2 => CircuitState::HalfOpen,
            // 0 and every out-of-range byte fall back to Closed.
            _ => CircuitState::Closed,
        }
    }
}

/// Tuning parameters for a circuit breaker.
#[derive(Debug, Clone)]
pub struct CircuitBreakerConfig {
    /// Number of failures recorded while Closed that opens the breaker.
    pub failure_threshold: u32,
    /// How long after the most recent failure an Open breaker waits before it
    /// may move to HalfOpen.
    pub recovery_timeout: Duration,
    /// Number of successes recorded while HalfOpen that closes the breaker.
    pub success_threshold: u32,
}

impl Default for CircuitBreakerConfig {
    /// Five failures to open, a 30-second recovery timeout, one success to close.
    fn default() -> Self {
        CircuitBreakerConfig {
            success_threshold: 1,
            recovery_timeout: Duration::from_secs(30),
            failure_threshold: 5,
        }
    }
}

impl CircuitBreakerConfig {
    /// Builds a config from a failure threshold and recovery timeout.
    ///
    /// `success_threshold` keeps its default value of 1.
    pub fn new(failure_threshold: u32, recovery_timeout: Duration) -> Self {
        Self {
            failure_threshold,
            recovery_timeout,
            ..Self::default()
        }
    }

    /// Returns this config with `success_threshold` replaced by `threshold`.
    pub fn with_success_threshold(self, threshold: u32) -> Self {
        Self {
            success_threshold: threshold,
            ..self
        }
    }
}

66pub trait CircuitBreaker: Send + Sync {
74 fn state(&self) -> CircuitState;
76
77 fn is_open(&self) -> bool {
79 self.state() == CircuitState::Open
80 }
81
82 fn allows_request(&self) -> bool {
84 self.state() != CircuitState::Open
85 }
86
87 fn record_success(&self);
89
90 fn record_failure(&self);
92
93 fn trip(&self);
95
96 fn reset(&self);
98
99 fn time_until_half_open(&self) -> Option<Duration>;
102
103 fn failure_count(&self) -> u32;
105}
106
107pub struct AtomicCircuitBreaker {
109 pub state_val: std::sync::atomic::AtomicU8,
110 failure_count_val: std::sync::atomic::AtomicU32,
111 success_count_val: std::sync::atomic::AtomicU32,
112 last_failure: parking_lot::Mutex<Option<std::time::Instant>>,
113 config: CircuitBreakerConfig,
114}
115
116impl AtomicCircuitBreaker {
117 pub fn new(config: CircuitBreakerConfig) -> Self {
119 Self {
120 state_val: std::sync::atomic::AtomicU8::new(CircuitState::Closed as u8),
121 failure_count_val: std::sync::atomic::AtomicU32::new(0),
122 success_count_val: std::sync::atomic::AtomicU32::new(0),
123 last_failure: parking_lot::Mutex::new(None),
124 config,
125 }
126 }
127
128 pub fn with_defaults() -> Self {
130 Self::new(CircuitBreakerConfig::default())
131 }
132}
133
134impl CircuitBreaker for AtomicCircuitBreaker {
135 fn state(&self) -> CircuitState {
136 use std::sync::atomic::Ordering;
137
138 let state = CircuitState::from(self.state_val.load(Ordering::Acquire));
139
140 if state == CircuitState::Open
142 && let Some(last) = *self.last_failure.lock()
143 && last.elapsed() >= self.config.recovery_timeout
144 && self
145 .state_val
146 .compare_exchange(
147 CircuitState::Open as u8,
148 CircuitState::HalfOpen as u8,
149 Ordering::AcqRel,
150 Ordering::Acquire,
151 )
152 .is_ok()
153 {
154 self.success_count_val.store(0, Ordering::Release);
155 return CircuitState::HalfOpen;
156 }
157
158 state
159 }
160
161 fn record_success(&self) {
162 use std::sync::atomic::Ordering;
163
164 let state = CircuitState::from(self.state_val.load(Ordering::Acquire));
165
166 match state {
167 CircuitState::Closed => {
168 self.failure_count_val.store(0, Ordering::Release);
169 }
170 CircuitState::HalfOpen => {
171 let successes = self.success_count_val.fetch_add(1, Ordering::AcqRel) + 1;
172 if successes >= self.config.success_threshold {
173 self.state_val
174 .store(CircuitState::Closed as u8, Ordering::Release);
175 self.failure_count_val.store(0, Ordering::Release);
176 self.success_count_val.store(0, Ordering::Release);
177 *self.last_failure.lock() = None;
178 }
179 }
180 CircuitState::Open => {}
181 }
182 }
183
184 fn record_failure(&self) {
185 use std::sync::atomic::Ordering;
186
187 *self.last_failure.lock() = Some(std::time::Instant::now());
188
189 let state = CircuitState::from(self.state_val.load(Ordering::Acquire));
190
191 match state {
192 CircuitState::Closed => {
193 let count = self.failure_count_val.fetch_add(1, Ordering::AcqRel) + 1;
194 if count >= self.config.failure_threshold {
195 self.state_val
196 .store(CircuitState::Open as u8, Ordering::Release);
197 }
198 }
199 CircuitState::HalfOpen => {
200 self.state_val
201 .store(CircuitState::Open as u8, Ordering::Release);
202 self.success_count_val.store(0, Ordering::Release);
203 }
204 CircuitState::Open => {}
205 }
206 }
207
208 fn trip(&self) {
209 use std::sync::atomic::Ordering;
210 self.state_val
211 .store(CircuitState::Open as u8, Ordering::Release);
212 *self.last_failure.lock() = Some(std::time::Instant::now());
213 }
214
215 fn reset(&self) {
216 use std::sync::atomic::Ordering;
217 self.state_val
218 .store(CircuitState::Closed as u8, Ordering::Release);
219 self.failure_count_val.store(0, Ordering::Release);
220 self.success_count_val.store(0, Ordering::Release);
221 *self.last_failure.lock() = None;
222 }
223
224 fn time_until_half_open(&self) -> Option<Duration> {
225 use std::sync::atomic::Ordering;
226
227 let state = CircuitState::from(self.state_val.load(Ordering::Acquire));
228 if state != CircuitState::Open {
229 return None;
230 }
231
232 let last = (*self.last_failure.lock())?;
233 let elapsed = last.elapsed();
234
235 if elapsed >= self.config.recovery_timeout {
236 Some(Duration::ZERO)
237 } else {
238 Some(self.config.recovery_timeout - elapsed)
239 }
240 }
241
242 fn failure_count(&self) -> u32 {
243 self.failure_count_val
244 .load(std::sync::atomic::Ordering::Acquire)
245 }
246}
247
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::Ordering;
    use std::thread;

    /// Decodes the stored state directly, bypassing the lazy Open -> HalfOpen
    /// promotion in `state()`, so assertions made right after a failure do not
    /// race against a short recovery timeout.
    fn raw_state(cb: &AtomicCircuitBreaker) -> CircuitState {
        CircuitState::from(cb.state_val.load(Ordering::Acquire))
    }

    #[test]
    fn test_starts_closed() {
        let cb = AtomicCircuitBreaker::with_defaults();
        assert_eq!(cb.state(), CircuitState::Closed);
        assert!(cb.allows_request());
    }

    #[test]
    fn test_opens_after_threshold() {
        let config = CircuitBreakerConfig::new(3, Duration::from_secs(30));
        let cb = AtomicCircuitBreaker::new(config);

        cb.record_failure();
        cb.record_failure();
        assert_eq!(cb.state(), CircuitState::Closed);

        cb.record_failure();
        assert_eq!(cb.state(), CircuitState::Open);
        assert!(!cb.allows_request());
    }

    #[test]
    fn test_transitions_to_half_open() {
        let config = CircuitBreakerConfig::new(1, Duration::from_millis(10));
        let cb = AtomicCircuitBreaker::new(config);

        cb.record_failure();
        // Raw read: `state()` could already report HalfOpen on a slow machine.
        assert_eq!(raw_state(&cb), CircuitState::Open);

        thread::sleep(Duration::from_millis(15));
        assert_eq!(cb.state(), CircuitState::HalfOpen);
    }

    #[test]
    fn test_closes_on_success_in_half_open() {
        let config = CircuitBreakerConfig::new(1, Duration::from_millis(10));
        let cb = AtomicCircuitBreaker::new(config);

        cb.record_failure();
        thread::sleep(Duration::from_millis(15));
        assert_eq!(cb.state(), CircuitState::HalfOpen);

        cb.record_success();
        assert_eq!(cb.state(), CircuitState::Closed);
    }

    #[test]
    fn test_success_threshold() {
        let config =
            CircuitBreakerConfig::new(1, Duration::from_millis(10)).with_success_threshold(3);
        let cb = AtomicCircuitBreaker::new(config);

        cb.record_failure();
        thread::sleep(Duration::from_millis(15));
        assert_eq!(cb.state(), CircuitState::HalfOpen);

        cb.record_success();
        assert_eq!(cb.state(), CircuitState::HalfOpen);
        cb.record_success();
        assert_eq!(cb.state(), CircuitState::HalfOpen);
        cb.record_success();
        assert_eq!(cb.state(), CircuitState::Closed);
    }

    #[test]
    fn test_failure_in_half_open_reopens() {
        let config = CircuitBreakerConfig::new(1, Duration::from_millis(10));
        let cb = AtomicCircuitBreaker::new(config);

        cb.record_failure();
        thread::sleep(Duration::from_millis(15));
        assert_eq!(cb.state(), CircuitState::HalfOpen);

        cb.record_failure();
        assert_eq!(raw_state(&cb), CircuitState::Open);
    }

    #[test]
    fn test_success_in_closed_resets_failure_count() {
        let cb = AtomicCircuitBreaker::with_defaults();

        cb.record_failure();
        cb.record_failure();
        assert_eq!(cb.failure_count(), 2);

        cb.record_success();
        assert_eq!(cb.failure_count(), 0);
        assert_eq!(cb.state(), CircuitState::Closed);
    }

    #[test]
    fn test_trip_and_reset() {
        let cb = AtomicCircuitBreaker::with_defaults();
        cb.trip();
        assert_eq!(cb.state(), CircuitState::Open);

        cb.reset();
        assert_eq!(cb.state(), CircuitState::Closed);
    }

    #[test]
    fn test_time_until_half_open() {
        let cb = AtomicCircuitBreaker::with_defaults();
        assert_eq!(cb.time_until_half_open(), None);

        cb.record_failure();
        cb.trip();
        let remaining = cb
            .time_until_half_open()
            .expect("a tripped breaker reports a remaining wait");
        assert!(remaining <= Duration::from_secs(30));
        assert!(remaining > Duration::from_secs(25));

        cb.reset();
        assert_eq!(cb.time_until_half_open(), None);
        assert_eq!(cb.failure_count(), 0);
    }

    #[test]
    fn test_thread_safety() {
        let config = CircuitBreakerConfig::new(100, Duration::from_secs(30));
        let cb = Arc::new(AtomicCircuitBreaker::new(config));

        let handles: Vec<_> = (0..10)
            .map(|_| {
                let cb = Arc::clone(&cb);
                thread::spawn(move || {
                    for _ in 0..10 {
                        cb.record_failure();
                    }
                })
            })
            .collect();

        for h in handles {
            h.join().unwrap();
        }

        assert_eq!(cb.state(), CircuitState::Open);
    }
}
353}