atomr_core/pattern/
circuit_breaker.rs1use std::future::Future;
4use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
5use std::sync::Arc;
6use std::time::Duration;
7
/// Externally observable state of a [`CircuitBreaker`].
#[non_exhaustive]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CircuitBreakerState {
    /// Calls are passed through to the guarded operation.
    Closed,
    /// The breaker has tripped; calls are rejected without being run.
    Open,
    /// The reset timeout has elapsed since the breaker opened.
    HalfOpen,
}
15
16pub struct CircuitBreaker {
17 max_failures: u32,
18 call_timeout: Duration,
19 reset_timeout: Duration,
20 failures: AtomicU32,
21 opened_at_ns: AtomicU64,
22 state: AtomicU32,
24}
25
26impl CircuitBreaker {
27 pub fn new(max_failures: u32, call_timeout: Duration, reset_timeout: Duration) -> Arc<Self> {
28 Arc::new(Self {
29 max_failures,
30 call_timeout,
31 reset_timeout,
32 failures: AtomicU32::new(0),
33 opened_at_ns: AtomicU64::new(0),
34 state: AtomicU32::new(0),
35 })
36 }
37
38 pub fn state(&self) -> CircuitBreakerState {
39 match self.state.load(Ordering::Acquire) {
40 0 => CircuitBreakerState::Closed,
41 1 => {
42 let now_ns = self.elapsed_ns();
47 let opened_ns = self.opened_at_ns.load(Ordering::Acquire);
48 if opened_ns > 0 && now_ns.saturating_sub(opened_ns) >= self.reset_timeout.as_nanos() as u64 {
49 CircuitBreakerState::HalfOpen
50 } else {
51 CircuitBreakerState::Open
52 }
53 }
54 _ => CircuitBreakerState::HalfOpen,
55 }
56 }
57
58 fn elapsed_ns(&self) -> u64 {
59 std::time::SystemTime::now()
63 .duration_since(std::time::UNIX_EPOCH)
64 .map(|d| d.as_nanos() as u64)
65 .unwrap_or(0)
66 }
67
68 pub async fn call<F, Fut, T, E>(&self, f: F) -> Result<T, CircuitBreakerError<E>>
69 where
70 F: FnOnce() -> Fut,
71 Fut: Future<Output = Result<T, E>>,
72 {
73 let st = self.state.load(Ordering::Acquire);
74 if st == 1 {
75 return Err(CircuitBreakerError::Open);
76 }
77 let res = tokio::time::timeout(self.call_timeout, f()).await;
78 match res {
79 Ok(Ok(v)) => {
80 self.failures.store(0, Ordering::Release);
81 self.state.store(0, Ordering::Release);
82 Ok(v)
83 }
84 Ok(Err(e)) => {
85 self.record_failure();
86 Err(CircuitBreakerError::Inner(e))
87 }
88 Err(_) => {
89 self.record_failure();
90 Err(CircuitBreakerError::Timeout)
91 }
92 }
93 }
94
95 fn record_failure(&self) {
96 let n = self.failures.fetch_add(1, Ordering::AcqRel) + 1;
97 if n >= self.max_failures {
98 self.state.store(1, Ordering::Release);
99 self.opened_at_ns.store(self.elapsed_ns(), Ordering::Release);
100 }
101 }
102}
103
104#[derive(Debug, thiserror::Error)]
105#[non_exhaustive]
106pub enum CircuitBreakerError<E> {
107 #[error("circuit breaker is open")]
108 Open,
109 #[error("call timed out")]
110 Timeout,
111 #[error(transparent)]
112 Inner(E),
113}
114
#[cfg(test)]
mod tests {
    use super::*;

    /// Drives one call whose inner future fails, discarding the outcome.
    async fn fail_once(breaker: &CircuitBreaker) {
        let _ = breaker.call(|| async { Err::<(), _>(1) }).await;
    }

    /// With `max_failures = 2`, two consecutive failures must trip the
    /// breaker so the following call is rejected without running.
    #[tokio::test]
    async fn opens_after_max_failures() {
        let breaker = CircuitBreaker::new(2, Duration::from_secs(1), Duration::from_secs(1));
        for _ in 0..2 {
            fail_once(&breaker).await;
        }

        let rejected: Result<(), _> = breaker.call(|| async { Ok::<(), u32>(()) }).await;
        assert!(
            matches!(rejected, Err(CircuitBreakerError::Open)),
            "breaker should reject calls once it has tripped"
        );
    }
}
128}