reliability_toolkit/
circuit_breaker.rs1use std::sync::Arc;
19use std::time::Duration;
20
21use tokio::sync::Mutex;
22use tokio::time::Instant;
23
24use crate::error::ToolkitError;
25
/// The three states of a circuit breaker.
///
/// Derives `Hash` (alongside `Eq`) so states can key maps and sets, e.g. for
/// per-state metrics.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum CircuitState {
    /// Calls are admitted; consecutive failures are counted toward opening.
    Closed,
    /// Calls are rejected until the configured cool-down has elapsed.
    Open,
    /// A bounded number of probe calls are admitted to test whether the
    /// downstream has recovered.
    HalfOpen,
}
36
/// Configuration for a [`CircuitBreaker`]; finish with `build()`.
///
/// Defaults: open after 5 consecutive failures, stay open for 30 seconds,
/// then allow 1 half-open probe at a time.
#[derive(Clone, Debug)]
pub struct CircuitBreakerBuilder {
    /// Consecutive failures (counted while closed) that open the breaker.
    failure_threshold: u32,
    /// How long the breaker stays open before it admits half-open probes.
    cool_down: Duration,
    /// Probes allowed in flight at once while half-open; this many successful
    /// probes close the breaker again.
    half_open_max_calls: u32,
}

impl Default for CircuitBreakerBuilder {
    fn default() -> Self {
        let cool_down = Duration::from_secs(30);
        Self {
            half_open_max_calls: 1,
            cool_down,
            failure_threshold: 5,
        }
    }
}
54
55impl CircuitBreakerBuilder {
56 #[must_use]
58 pub fn failure_threshold(mut self, n: u32) -> Self {
59 assert!(n > 0, "failure_threshold must be > 0");
60 self.failure_threshold = n;
61 self
62 }
63
64 #[must_use]
66 pub fn cool_down(mut self, d: Duration) -> Self {
67 self.cool_down = d;
68 self
69 }
70
71 #[must_use]
73 pub fn half_open_max_calls(mut self, n: u32) -> Self {
74 assert!(n > 0, "half_open_max_calls must be > 0");
75 self.half_open_max_calls = n;
76 self
77 }
78
79 #[must_use]
81 pub fn build(self) -> CircuitBreaker {
82 CircuitBreaker {
83 inner: Arc::new(Inner {
84 cfg: self,
85 state: Mutex::new(StateMachine {
86 state: CircuitState::Closed,
87 consecutive_failures: 0,
88 opened_at: None,
89 half_open_inflight: 0,
90 half_open_successes: 0,
91 }),
92 }),
93 }
94 }
95}
96
97#[derive(Clone, Debug)]
99pub struct CircuitBreaker {
100 inner: Arc<Inner>,
101}
102
103#[derive(Debug)]
104struct Inner {
105 cfg: CircuitBreakerBuilder,
106 state: Mutex<StateMachine>,
107}
108
109#[derive(Debug)]
110struct StateMachine {
111 state: CircuitState,
112 consecutive_failures: u32,
113 opened_at: Option<Instant>,
114 half_open_inflight: u32,
115 half_open_successes: u32,
116}
117
118impl CircuitBreaker {
119 pub fn new() -> Self {
121 Self::builder().build()
122 }
123
124 pub fn builder() -> CircuitBreakerBuilder {
126 CircuitBreakerBuilder::default()
127 }
128
129 pub async fn state(&self) -> CircuitState {
131 let mut sm = self.inner.state.lock().await;
132 self.tick(&mut sm);
133 sm.state
134 }
135
136 pub async fn call<F, T, E>(&self, fut: F) -> Result<Result<T, E>, ToolkitError>
146 where
147 F: std::future::Future<Output = Result<T, E>>,
148 {
149 let admitted = {
151 let mut sm = self.inner.state.lock().await;
152 self.tick(&mut sm);
153 match sm.state {
154 CircuitState::Closed => true,
155 CircuitState::HalfOpen => {
156 if sm.half_open_inflight < self.inner.cfg.half_open_max_calls {
157 sm.half_open_inflight += 1;
158 true
159 } else {
160 false
161 }
162 }
163 CircuitState::Open => false,
164 }
165 };
166
167 if !admitted {
168 let retry_after = self.retry_after().await;
169 return Err(ToolkitError::CircuitOpen { retry_after });
170 }
171
172 let result = fut.await;
174
175 {
177 let mut sm = self.inner.state.lock().await;
178 match (&result, sm.state) {
179 (Ok(_), CircuitState::Closed) => {
180 sm.consecutive_failures = 0;
181 }
182 (Ok(_), CircuitState::HalfOpen) => {
183 sm.half_open_inflight = sm.half_open_inflight.saturating_sub(1);
184 sm.half_open_successes += 1;
185 if sm.half_open_successes >= self.inner.cfg.half_open_max_calls {
186 sm.state = CircuitState::Closed;
188 sm.consecutive_failures = 0;
189 sm.opened_at = None;
190 sm.half_open_inflight = 0;
191 sm.half_open_successes = 0;
192 }
193 }
194 (Err(_), CircuitState::Closed) => {
195 sm.consecutive_failures += 1;
196 if sm.consecutive_failures >= self.inner.cfg.failure_threshold {
197 sm.state = CircuitState::Open;
198 sm.opened_at = Some(Instant::now());
199 }
200 }
201 (Err(_), CircuitState::HalfOpen) => {
202 sm.state = CircuitState::Open;
203 sm.opened_at = Some(Instant::now());
204 sm.half_open_inflight = 0;
205 sm.half_open_successes = 0;
206 }
207 (_, CircuitState::Open) => {
208 }
211 }
212 }
213
214 Ok(result)
215 }
216
217 pub async fn trip(&self) {
219 let mut sm = self.inner.state.lock().await;
220 sm.state = CircuitState::Open;
221 sm.opened_at = Some(Instant::now());
222 sm.half_open_inflight = 0;
223 sm.half_open_successes = 0;
224 }
225
226 pub async fn reset(&self) {
228 let mut sm = self.inner.state.lock().await;
229 sm.state = CircuitState::Closed;
230 sm.consecutive_failures = 0;
231 sm.opened_at = None;
232 sm.half_open_inflight = 0;
233 sm.half_open_successes = 0;
234 }
235
236 fn tick(&self, sm: &mut StateMachine) {
237 if sm.state == CircuitState::Open {
238 if let Some(t) = sm.opened_at {
239 if Instant::now().duration_since(t) >= self.inner.cfg.cool_down {
240 sm.state = CircuitState::HalfOpen;
241 sm.half_open_inflight = 0;
242 sm.half_open_successes = 0;
243 }
244 }
245 }
246 }
247
248 async fn retry_after(&self) -> Duration {
249 let sm = self.inner.state.lock().await;
250 match sm.opened_at {
251 Some(t) => self
252 .inner
253 .cfg
254 .cool_down
255 .checked_sub(Instant::now().duration_since(t))
256 .unwrap_or_else(|| Duration::from_secs(0)),
257 None => Duration::from_secs(0),
258 }
259 }
260}
261
262impl Default for CircuitBreaker {
263 fn default() -> Self {
264 Self::new()
265 }
266}