1#![deny(missing_docs)]
18
19use std::sync::Mutex;
20use std::time::{Duration, Instant};
21
/// Tuning knobs for a [`CircuitBreaker`].
///
/// [`Default`] gives 5 consecutive failures to open, 2 consecutive successes
/// to close, and a backoff that starts at 10 s and doubles up to 120 s.
#[derive(Debug, Clone)]
pub struct CircuitBreakerConfig {
    /// Number of consecutive failures, while closed, that opens the breaker.
    pub failure_threshold: u32,
    /// Number of consecutive successful probes, while half-open, that closes
    /// the breaker again.
    pub success_threshold: u32,
    /// Backoff used when the breaker opens from the closed state.
    pub initial_backoff: Duration,
    /// Upper bound applied whenever the backoff is doubled (a failed probe,
    /// a failure while open, or tripping an already open/half-open breaker).
    pub max_backoff: Duration,
}

impl Default for CircuitBreakerConfig {
    /// Returns the defaults: thresholds of 5 failures / 2 successes and a
    /// 10 s initial backoff capped at 120 s.
    fn default() -> Self {
        Self {
            failure_threshold: 5,
            success_threshold: 2,
            initial_backoff: Duration::from_secs(10),
            max_backoff: Duration::from_secs(120),
        }
    }
}
52
/// Internal state machine of a circuit breaker.
#[derive(Debug)]
enum State {
    /// Normal operation: calls pass and failures in a row are counted.
    Closed { consecutive_failures: u32 },
    /// Probing: calls pass, successes in a row are counted, and `backoff`
    /// is remembered so a failed probe can escalate from it.
    HalfOpen { consecutive_successes: u32, backoff: Duration },
    /// Rejecting calls until `until`; `backoff` is the length of this open period.
    Open { until: Instant, backoff: Duration },
}
67
68#[derive(Debug, thiserror::Error)]
79pub enum CircuitError<E> {
80 #[error("circuit breaker `{0}` is open")]
83 Open(String),
84
85 #[error(transparent)]
87 Inner(E),
88}
89
90pub struct CircuitBreaker {
97 name: String,
98 config: CircuitBreakerConfig,
99 state: Mutex<State>,
100}
101
102fn lock_state(m: &Mutex<State>) -> std::sync::MutexGuard<'_, State> {
110 m.lock().unwrap_or_else(|poisoned| {
111 tracing::warn!("circuit breaker mutex was poisoned — recovering inner state");
112 poisoned.into_inner()
113 })
114}
115
116impl CircuitBreaker {
117 pub fn new(name: impl Into<String>, config: CircuitBreakerConfig) -> Self {
119 Self {
120 name: name.into(),
121 config,
122 state: Mutex::new(State::Closed {
123 consecutive_failures: 0,
124 }),
125 }
126 }
127
128 pub fn name(&self) -> &str {
130 &self.name
131 }
132
133 pub fn is_open(&self) -> bool {
135 let state = lock_state(&self.state);
136 matches!(&*state, State::Open { until, .. } if Instant::now() < *until)
137 }
138
139 pub fn allow(&self) -> bool {
143 let mut state = lock_state(&self.state);
144 match &*state {
145 State::Closed { .. } | State::HalfOpen { .. } => true,
146 State::Open { until, backoff } => {
147 if Instant::now() >= *until {
148 let backoff = *backoff;
149 *state = State::HalfOpen {
150 consecutive_successes: 0,
151 backoff,
152 };
153 tracing::info!(name = %self.name, "circuit breaker: open → half-open");
154 true
155 } else {
156 false
157 }
158 }
159 }
160 }
161
162 pub fn on_success(&self) {
165 let mut state = lock_state(&self.state);
166 match &mut *state {
167 State::Closed {
168 consecutive_failures,
169 } => {
170 *consecutive_failures = 0;
171 }
172 State::HalfOpen {
173 consecutive_successes,
174 ..
175 } => {
176 *consecutive_successes += 1;
177 if *consecutive_successes >= self.config.success_threshold {
178 tracing::info!(name = %self.name, "circuit breaker: half-open → closed");
179 *state = State::Closed {
180 consecutive_failures: 0,
181 };
182 }
183 }
184 State::Open { .. } => {}
185 }
186 }
187
188 pub fn on_failure(&self) {
194 let mut state = lock_state(&self.state);
195 match &*state {
196 State::Closed {
197 consecutive_failures,
198 } => {
199 let next = consecutive_failures + 1;
200 if next >= self.config.failure_threshold {
201 let backoff = self.config.initial_backoff;
202 tracing::warn!(name = %self.name, failures = next, "circuit breaker: closed → open");
203 *state = State::Open {
204 until: Instant::now() + backoff,
205 backoff,
206 };
207 } else {
208 *state = State::Closed {
209 consecutive_failures: next,
210 };
211 }
212 }
213 State::HalfOpen { backoff, .. } => {
214 let new_backoff = (*backoff * 2).min(self.config.max_backoff);
215 tracing::warn!(name = %self.name, backoff_secs = new_backoff.as_secs(), "circuit breaker: half-open → open (probe failed)");
216 *state = State::Open {
217 until: Instant::now() + new_backoff,
218 backoff: new_backoff,
219 };
220 }
221 State::Open { backoff, .. } => {
222 let new_backoff = (*backoff * 2).min(self.config.max_backoff);
223 *state = State::Open {
224 until: Instant::now() + new_backoff,
225 backoff: new_backoff,
226 };
227 }
228 }
229 }
230
231 pub fn trip(&self) {
235 let mut state = lock_state(&self.state);
236 let backoff = match &*state {
237 State::Open { backoff, .. } | State::HalfOpen { backoff, .. } => {
238 (*backoff * 2).min(self.config.max_backoff)
239 }
240 State::Closed { .. } => self.config.initial_backoff,
241 };
242 tracing::warn!(name = %self.name, backoff_secs = backoff.as_secs(), "circuit breaker tripped");
243 *state = State::Open {
244 until: Instant::now() + backoff,
245 backoff,
246 };
247 }
248
249 pub fn reset(&self) {
252 let mut state = lock_state(&self.state);
253 tracing::info!(name = %self.name, "circuit breaker reset");
254 *state = State::Closed {
255 consecutive_failures: 0,
256 };
257 }
258
259 pub async fn call<F, Fut, T, E>(&self, f: F) -> Result<T, CircuitError<E>>
262 where
263 F: FnOnce() -> Fut,
264 Fut: std::future::Future<Output = Result<T, E>>,
265 {
266 if !self.allow() {
267 return Err(CircuitError::Open(self.name.clone()));
268 }
269 match f().await {
270 Ok(v) => {
271 self.on_success();
272 Ok(v)
273 }
274 Err(e) => {
275 self.on_failure();
276 Err(CircuitError::Inner(e))
277 }
278 }
279 }
280}
281
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Small thresholds and millisecond backoffs so time-based tests stay fast.
    fn fast_config() -> CircuitBreakerConfig {
        CircuitBreakerConfig {
            failure_threshold: 2,
            success_threshold: 2,
            initial_backoff: Duration::from_millis(20),
            max_backoff: Duration::from_millis(200),
        }
    }

    /// Backoff currently stored in the breaker's state (`None` while closed).
    ///
    /// Lets tests assert backoff escalation exactly, instead of inferring it
    /// from how long the breaker stays open, which depends on scheduler timing.
    fn recorded_backoff(cb: &CircuitBreaker) -> Option<Duration> {
        let state = cb.state.lock().unwrap();
        match &*state {
            State::Open { backoff, .. } | State::HalfOpen { backoff, .. } => Some(*backoff),
            State::Closed { .. } => None,
        }
    }

    #[test]
    fn opens_after_threshold() {
        let cb = CircuitBreaker::new("t", fast_config());
        assert!(cb.allow());
        cb.on_failure();
        assert!(cb.allow());
        cb.on_failure();
        assert!(!cb.allow());
        assert!(cb.is_open());
        assert_eq!(recorded_backoff(&cb), Some(Duration::from_millis(20)));
    }

    #[test]
    fn success_resets_failure_count() {
        let cb = CircuitBreaker::new("t", fast_config());
        cb.on_failure();
        cb.on_success();
        cb.on_failure();
        // Only one failure since the success, which is below the threshold of 2.
        assert!(cb.allow());
        assert_eq!(recorded_backoff(&cb), None);
    }

    #[tokio::test]
    async fn transitions_to_half_open_after_backoff() {
        let cb = CircuitBreaker::new("t", fast_config());
        cb.on_failure();
        cb.on_failure();
        assert!(cb.is_open());

        tokio::time::sleep(Duration::from_millis(30)).await;
        assert!(cb.allow());
        // Half-open is not "open", but it keeps the backoff to escalate from.
        assert!(!cb.is_open());
        assert_eq!(recorded_backoff(&cb), Some(Duration::from_millis(20)));
    }

    #[tokio::test]
    async fn half_open_success_closes() {
        let cb = CircuitBreaker::new("t", fast_config());
        cb.on_failure();
        cb.on_failure();
        tokio::time::sleep(Duration::from_millis(30)).await;

        assert!(cb.allow());
        cb.on_success();
        cb.on_success();
        assert_eq!(recorded_backoff(&cb), None);
        // Closed again: a single failure must not reopen it.
        cb.on_failure();
        assert!(cb.allow());
    }

    #[tokio::test]
    async fn half_open_failure_reopens_with_bigger_backoff() {
        let cb = CircuitBreaker::new("t", fast_config());
        cb.on_failure();
        cb.on_failure();
        tokio::time::sleep(Duration::from_millis(30)).await;
        assert!(cb.allow());

        cb.on_failure();
        assert_eq!(recorded_backoff(&cb), Some(Duration::from_millis(40)));
        assert!(!cb.allow());
    }

    #[tokio::test]
    async fn call_short_circuits_when_open() {
        let cb = CircuitBreaker::new("t", fast_config());
        cb.trip();
        let result: Result<(), CircuitError<&str>> = cb.call(|| async { Ok::<(), &str>(()) }).await;
        assert!(matches!(result, Err(CircuitError::Open(_))));
    }

    #[tokio::test]
    async fn call_records_success_and_failure() {
        let cb = CircuitBreaker::new("t", fast_config());
        let r: Result<(), CircuitError<&str>> = cb.call(|| async { Err("boom") }).await;
        assert!(matches!(r, Err(CircuitError::Inner("boom"))));

        let r: Result<u8, CircuitError<&str>> = cb.call(|| async { Ok(7) }).await;
        assert!(matches!(r, Ok(7)));

        // The success reset the streak, so one more failure keeps it closed...
        let _: Result<(), CircuitError<&str>> = cb.call(|| async { Err("boom") }).await;
        assert!(!cb.is_open());
        // ...and a second consecutive failure trips it.
        let _: Result<(), CircuitError<&str>> = cb.call(|| async { Err("boom") }).await;
        assert!(cb.is_open());
    }

    #[test]
    fn trip_and_reset_are_idempotent() {
        let cb = CircuitBreaker::new("t", fast_config());
        cb.trip();
        assert!(cb.is_open());
        cb.reset();
        assert!(!cb.is_open());
        cb.reset();
        assert!(!cb.is_open());
    }

    #[test]
    fn trip_while_open_doubles_backoff() {
        let cb = CircuitBreaker::new("t", fast_config());
        cb.trip();
        assert_eq!(recorded_backoff(&cb), Some(Duration::from_millis(20)));
        cb.trip();
        assert!(cb.is_open());
        assert_eq!(recorded_backoff(&cb), Some(Duration::from_millis(40)));
    }

    #[test]
    fn backoff_is_capped_at_max() {
        let cb = CircuitBreaker::new("t", fast_config());
        let expected_ms = [20u64, 40, 80, 160, 200, 200];
        for &ms in &expected_ms {
            cb.trip();
            assert_eq!(recorded_backoff(&cb), Some(Duration::from_millis(ms)));
        }
    }

    #[test]
    fn on_failure_while_open_extends_backoff() {
        let cb = CircuitBreaker::new("t", fast_config());
        cb.trip();
        assert!(cb.is_open());
        cb.on_failure();
        assert_eq!(recorded_backoff(&cb), Some(Duration::from_millis(40)));
        cb.on_failure();
        assert_eq!(recorded_backoff(&cb), Some(Duration::from_millis(80)));
        assert!(cb.is_open());
    }

    #[test]
    fn recovers_from_poisoned_mutex() {
        use std::sync::Arc;
        let cb = Arc::new(CircuitBreaker::new("t", fast_config()));
        let cb2 = cb.clone();
        let h = std::thread::spawn(move || {
            let _ = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
                let _guard = cb2.state.lock().unwrap();
                panic!("simulated panic while holding lock");
            }));
        });
        h.join().unwrap();
        assert!(!cb.is_open());
        cb.on_failure();
        cb.on_failure();
        assert!(
            cb.is_open(),
            "breaker should have tripped despite poisoned mutex"
        );
    }
}