chio_guards/external/
circuit_breaker.rs1use std::sync::Arc;
20use std::sync::Mutex;
21use std::time::Duration;
22
23use tokio::time::Instant;
24
25use super::cache::{Clock, TokioClock};
26
/// The three modes of the breaker's state machine.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CircuitState {
    /// Normal operation: calls are admitted and failures are tallied within the failure window.
    Closed,
    /// Tripped: calls are refused until the reset timeout has elapsed.
    Open,
    /// Probing: calls are admitted; enough successes close the circuit, a failure reopens it.
    HalfOpen,
}
/// Thresholds and timeouts that drive a [`CircuitBreaker`].
#[derive(Clone, Debug)]
pub struct CircuitBreakerConfig {
    /// How many failures inside `failure_window` trip a closed circuit open.
    pub failure_threshold: u32,
    /// How far back failures are remembered while the circuit is closed; older ones are
    /// discarded.
    pub failure_window: Duration,
    /// How many successes must be recorded while half-open before the circuit closes again.
    pub success_threshold: u32,
    /// How long an open circuit waits before becoming half-open, measured from the most recent
    /// trip (or the most recent failure recorded while it was already open).
    pub reset_timeout: Duration,
}

impl Default for CircuitBreakerConfig {
    /// Trips after 5 failures within 60 seconds, stays open for 30 seconds, and needs
    /// 2 half-open successes to close again.
    fn default() -> Self {
        CircuitBreakerConfig {
            reset_timeout: Duration::from_secs(30),
            success_threshold: 2,
            failure_window: Duration::from_secs(60),
            failure_threshold: 5,
        }
    }
}
67pub struct CircuitBreaker {
69 inner: Mutex<CircuitInner>,
70 config: CircuitBreakerConfig,
71 clock: Arc<dyn Clock>,
72}
73
74#[derive(Debug)]
75struct CircuitInner {
76 state: CircuitState,
77 failures: Vec<Instant>,
79 half_open_successes: u32,
81 opened_at: Option<Instant>,
83}
84
85impl CircuitBreaker {
86 pub fn new(config: CircuitBreakerConfig) -> Self {
89 Self::with_clock(config, Arc::new(TokioClock))
90 }
91
92 pub fn with_clock(config: CircuitBreakerConfig, clock: Arc<dyn Clock>) -> Self {
94 Self {
95 inner: Mutex::new(CircuitInner {
96 state: CircuitState::Closed,
97 failures: Vec::new(),
98 half_open_successes: 0,
99 opened_at: None,
100 }),
101 config,
102 clock,
103 }
104 }
105
106 pub fn config(&self) -> &CircuitBreakerConfig {
108 &self.config
109 }
110
111 pub fn current_state(&self) -> CircuitState {
114 let now = self.clock.now();
115 let Ok(mut inner) = self.inner.lock() else {
116 return CircuitState::Open;
117 };
118 self.tick(&mut inner, now);
119 inner.state
120 }
121
122 pub fn allow_call(&self) -> bool {
126 let now = self.clock.now();
127 let Ok(mut inner) = self.inner.lock() else {
128 return false;
129 };
130 self.tick(&mut inner, now);
131 !matches!(inner.state, CircuitState::Open)
132 }
133
134 pub fn record_success(&self) {
136 let now = self.clock.now();
137 let Ok(mut inner) = self.inner.lock() else {
138 return;
139 };
140 self.tick(&mut inner, now);
141 match inner.state {
142 CircuitState::Closed => {
143 inner.failures.clear();
144 }
145 CircuitState::HalfOpen => {
146 inner.half_open_successes = inner.half_open_successes.saturating_add(1);
147 if inner.half_open_successes >= self.config.success_threshold {
148 inner.state = CircuitState::Closed;
149 inner.failures.clear();
150 inner.half_open_successes = 0;
151 inner.opened_at = None;
152 }
153 }
154 CircuitState::Open => {
155 }
158 }
159 }
160
161 pub fn record_failure(&self) {
163 let now = self.clock.now();
164 let Ok(mut inner) = self.inner.lock() else {
165 return;
166 };
167 self.tick(&mut inner, now);
168 match inner.state {
169 CircuitState::Closed => {
170 inner.failures.push(now);
171 self.drop_stale_failures(&mut inner, now);
172 if inner.failures.len() as u32 >= self.config.failure_threshold {
173 inner.state = CircuitState::Open;
174 inner.opened_at = Some(now);
175 inner.failures.clear();
176 }
177 }
178 CircuitState::HalfOpen => {
179 inner.state = CircuitState::Open;
180 inner.opened_at = Some(now);
181 inner.half_open_successes = 0;
182 }
183 CircuitState::Open => {
184 inner.opened_at = Some(now);
186 }
187 }
188 }
189
190 pub fn reset(&self) {
193 if let Ok(mut inner) = self.inner.lock() {
194 inner.state = CircuitState::Closed;
195 inner.failures.clear();
196 inner.half_open_successes = 0;
197 inner.opened_at = None;
198 }
199 }
200
201 fn tick(&self, inner: &mut CircuitInner, now: Instant) {
202 match inner.state {
203 CircuitState::Open => {
204 if let Some(opened) = inner.opened_at {
205 if now.duration_since(opened) >= self.config.reset_timeout {
206 inner.state = CircuitState::HalfOpen;
207 inner.half_open_successes = 0;
208 }
209 }
210 }
211 CircuitState::Closed => {
212 self.drop_stale_failures(inner, now);
213 }
214 CircuitState::HalfOpen => {}
215 }
216 }
217
218 fn drop_stale_failures(&self, inner: &mut CircuitInner, now: Instant) {
219 let window = self.config.failure_window;
220 inner.failures.retain(|ts| now.duration_since(*ts) < window);
221 }
222}
223
#[cfg(test)]
mod tests {
    use super::*;

    /// Config with the given failure threshold, a 60s failure window, a success threshold of 2
    /// and a 10s reset timeout.
    fn config(failure_threshold: u32) -> CircuitBreakerConfig {
        CircuitBreakerConfig {
            failure_threshold,
            failure_window: Duration::from_secs(60),
            success_threshold: 2,
            reset_timeout: Duration::from_secs(10),
        }
    }

    /// A breaker built from `config(failure_threshold)` that has just tripped open.
    fn tripped_breaker(failure_threshold: u32) -> CircuitBreaker {
        let cb = CircuitBreaker::new(config(failure_threshold));
        for _ in 0..failure_threshold {
            cb.record_failure();
        }
        cb
    }

    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn starts_closed() {
        let cb = CircuitBreaker::new(config(5));
        assert_eq!(cb.current_state(), CircuitState::Closed);
        assert!(cb.allow_call());
    }

    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn opens_after_threshold_failures() {
        let cb = CircuitBreaker::new(config(3));
        cb.record_failure();
        cb.record_failure();
        assert_eq!(cb.current_state(), CircuitState::Closed);
        cb.record_failure();
        assert_eq!(cb.current_state(), CircuitState::Open);
        assert!(!cb.allow_call());
    }

    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn transitions_to_half_open_after_reset_timeout() {
        let cb = tripped_breaker(2);
        assert_eq!(cb.current_state(), CircuitState::Open);
        tokio::time::advance(Duration::from_secs(11)).await;
        assert_eq!(cb.current_state(), CircuitState::HalfOpen);
        assert!(cb.allow_call());
    }

    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn stays_open_until_reset_timeout_elapses() {
        let cb = tripped_breaker(2);
        tokio::time::advance(Duration::from_secs(9)).await;
        assert_eq!(cb.current_state(), CircuitState::Open);
        // The transition happens exactly when `reset_timeout` has elapsed.
        tokio::time::advance(Duration::from_secs(1)).await;
        assert_eq!(cb.current_state(), CircuitState::HalfOpen);
    }

    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn half_open_closes_after_success_threshold() {
        let cb = tripped_breaker(2);
        tokio::time::advance(Duration::from_secs(11)).await;
        assert_eq!(cb.current_state(), CircuitState::HalfOpen);
        cb.record_success();
        // One success is below the threshold of 2.
        assert_eq!(cb.current_state(), CircuitState::HalfOpen);
        cb.record_success();
        assert_eq!(cb.current_state(), CircuitState::Closed);
        assert!(cb.allow_call());
    }

    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn half_open_failure_reopens() {
        let cb = tripped_breaker(2);
        tokio::time::advance(Duration::from_secs(11)).await;
        assert_eq!(cb.current_state(), CircuitState::HalfOpen);
        cb.record_failure();
        assert_eq!(cb.current_state(), CircuitState::Open);
        assert!(!cb.allow_call());
    }

    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn stale_failures_are_forgotten() {
        let cb = CircuitBreaker::new(CircuitBreakerConfig {
            failure_threshold: 3,
            failure_window: Duration::from_secs(5),
            success_threshold: 1,
            reset_timeout: Duration::from_secs(10),
        });
        cb.record_failure();
        cb.record_failure();
        tokio::time::advance(Duration::from_secs(6)).await;
        cb.record_failure();
        cb.record_failure();
        // The first two failures aged out of the 5s window, so only two count.
        assert_eq!(cb.current_state(), CircuitState::Closed);
        // A third failure inside the window does trip the circuit.
        cb.record_failure();
        assert_eq!(cb.current_state(), CircuitState::Open);
    }

    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn success_while_closed_clears_failures() {
        let cb = CircuitBreaker::new(config(3));
        cb.record_failure();
        cb.record_failure();
        cb.record_success();
        cb.record_failure();
        cb.record_failure();
        assert_eq!(cb.current_state(), CircuitState::Closed);
        cb.record_failure();
        assert_eq!(cb.current_state(), CircuitState::Open);
    }

    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn success_while_open_is_ignored() {
        let cb = tripped_breaker(1);
        cb.record_success();
        assert_eq!(cb.current_state(), CircuitState::Open);
        assert!(!cb.allow_call());
    }

    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn failure_while_open_restarts_reset_timeout() {
        let cb = tripped_breaker(1);
        tokio::time::advance(Duration::from_secs(6)).await;
        cb.record_failure();
        // 12s after the trip but only 6s after the latest failure: still open.
        tokio::time::advance(Duration::from_secs(6)).await;
        assert_eq!(cb.current_state(), CircuitState::Open);
        tokio::time::advance(Duration::from_secs(4)).await;
        assert_eq!(cb.current_state(), CircuitState::HalfOpen);
    }

    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn reset_closes_circuit_and_forgets_failures() {
        let cb = tripped_breaker(1);
        assert_eq!(cb.current_state(), CircuitState::Open);
        cb.reset();
        assert_eq!(cb.current_state(), CircuitState::Closed);
        assert!(cb.allow_call());

        let cb = CircuitBreaker::new(config(2));
        cb.record_failure();
        cb.reset();
        cb.record_failure();
        // The pre-reset failure must not count toward the threshold.
        assert_eq!(cb.current_state(), CircuitState::Closed);
    }
}