1use std::sync::Arc;
34use std::time::{Duration, Instant};
35use tokio::sync::RwLock;
36
/// Position of a [`CircuitBreaker`] in its closed → open → half-open cycle.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CircuitState {
    /// Calls are forwarded; consecutive failures are counted towards opening.
    Closed,
    /// Calls are rejected until the configured timeout has elapsed.
    Open,
    /// Calls are forwarded as trials: enough successes close the circuit,
    /// a single failure reopens it.
    HalfOpen,
}
47
/// Tuning knobs for a [`CircuitBreaker`].
#[derive(Clone, Debug)]
pub struct CircuitBreakerConfig {
    /// Number of consecutive failures, while closed, that opens the circuit.
    pub failure_threshold: u32,
    /// How long the circuit stays open, measured from the most recently recorded
    /// failure, before a trial call is admitted in the half-open state.
    pub timeout: Duration,
    /// Number of successful half-open calls needed to close the circuit again.
    pub success_threshold: u32,
}
58
59impl Default for CircuitBreakerConfig {
60 #[inline]
61 fn default() -> Self {
62 Self {
63 failure_threshold: 5,
64 timeout: Duration::from_secs(60),
65 success_threshold: 2,
66 }
67 }
68}
69
70pub struct CircuitBreaker {
72 name: String,
73 config: CircuitBreakerConfig,
74 state: Arc<RwLock<CircuitBreakerState>>,
75}
76
77#[derive(Debug)]
78struct CircuitBreakerState {
79 circuit_state: CircuitState,
80 failure_count: u32,
81 success_count: u32,
82 last_failure_time: Option<Instant>,
83}
84
85impl CircuitBreaker {
86 #[must_use]
88 pub fn new(name: impl Into<String>, config: CircuitBreakerConfig) -> Self {
89 Self {
90 name: name.into(),
91 config,
92 state: Arc::new(RwLock::new(CircuitBreakerState {
93 circuit_state: CircuitState::Closed,
94 failure_count: 0,
95 success_count: 0,
96 last_failure_time: None,
97 })),
98 }
99 }
100
101 #[must_use]
103 pub async fn state(&self) -> CircuitState {
104 self.state.read().await.circuit_state
105 }
106
107 #[must_use]
109 #[inline]
110 pub fn name(&self) -> &str {
111 &self.name
112 }
113
114 pub async fn call<F, Fut, T, E>(&self, f: F) -> Result<T, CircuitBreakerError<E>>
116 where
117 F: FnOnce() -> Fut,
118 Fut: std::future::Future<Output = Result<T, E>>,
119 {
120 if let Err(CircuitBreakerError::CircuitOpen) = self.check_state::<E>().await {
122 return Err(CircuitBreakerError::CircuitOpen);
123 }
124
125 match f().await {
127 Ok(result) => {
128 self.on_success().await;
129 Ok(result)
130 }
131 Err(e) => {
132 self.on_failure().await;
133 Err(CircuitBreakerError::CallFailed(e))
134 }
135 }
136 }
137
138 async fn check_state<E>(&self) -> Result<(), CircuitBreakerError<E>> {
140 let mut state = self.state.write().await;
141
142 match state.circuit_state {
143 CircuitState::Closed => Ok(()),
144 CircuitState::Open => {
145 if let Some(last_failure) = state.last_failure_time {
147 if last_failure.elapsed() >= self.config.timeout {
148 state.circuit_state = CircuitState::HalfOpen;
150 state.success_count = 0;
151 tracing::info!(
152 circuit_breaker = %self.name,
153 "Circuit breaker transitioning to half-open"
154 );
155 Ok(())
156 } else {
157 Err(CircuitBreakerError::CircuitOpen)
158 }
159 } else {
160 Err(CircuitBreakerError::CircuitOpen)
161 }
162 }
163 CircuitState::HalfOpen => Ok(()),
164 }
165 }
166
167 async fn on_success(&self) {
169 let mut state = self.state.write().await;
170
171 match state.circuit_state {
172 CircuitState::Closed => {
173 state.failure_count = 0;
175 }
176 CircuitState::HalfOpen => {
177 state.success_count += 1;
178 if state.success_count >= self.config.success_threshold {
179 state.circuit_state = CircuitState::Closed;
181 state.failure_count = 0;
182 state.success_count = 0;
183 tracing::info!(
184 circuit_breaker = %self.name,
185 "Circuit breaker closed after recovery"
186 );
187 }
188 }
189 CircuitState::Open => {}
190 }
191 }
192
193 async fn on_failure(&self) {
195 let mut state = self.state.write().await;
196
197 match state.circuit_state {
198 CircuitState::Closed => {
199 state.failure_count += 1;
200 if state.failure_count >= self.config.failure_threshold {
201 state.circuit_state = CircuitState::Open;
203 state.last_failure_time = Some(Instant::now());
204 tracing::warn!(
205 circuit_breaker = %self.name,
206 failures = state.failure_count,
207 "Circuit breaker opened due to failures"
208 );
209 }
210 }
211 CircuitState::HalfOpen => {
212 state.circuit_state = CircuitState::Open;
214 state.last_failure_time = Some(Instant::now());
215 state.success_count = 0;
216 tracing::warn!(
217 circuit_breaker = %self.name,
218 "Circuit breaker reopened after failed recovery attempt"
219 );
220 }
221 CircuitState::Open => {
222 state.last_failure_time = Some(Instant::now());
224 }
225 }
226 }
227
228 pub async fn reset(&self) {
230 let mut state = self.state.write().await;
231 state.circuit_state = CircuitState::Closed;
232 state.failure_count = 0;
233 state.success_count = 0;
234 state.last_failure_time = None;
235 tracing::info!(circuit_breaker = %self.name, "Circuit breaker manually reset");
236 }
237}
238
239#[derive(Debug, thiserror::Error)]
241pub enum CircuitBreakerError<E> {
242 #[error("Circuit breaker is open")]
244 CircuitOpen,
245 #[error("Call failed: {0}")]
247 CallFailed(E),
248}
249
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a breaker named `"test"` with the given settings.
    fn breaker_with(
        failure_threshold: u32,
        timeout: Duration,
        success_threshold: u32,
    ) -> CircuitBreaker {
        CircuitBreaker::new(
            "test",
            CircuitBreakerConfig {
                failure_threshold,
                timeout,
                success_threshold,
            },
        )
    }

    /// Sends `times` failing calls through `breaker`, discarding the results.
    async fn fail(breaker: &CircuitBreaker, times: usize) {
        for _ in 0..times {
            let _ = breaker.call(|| async { Err::<(), _>("error") }).await;
        }
    }

    /// Sends one successful call through `breaker` and returns its result.
    async fn succeed(
        breaker: &CircuitBreaker,
    ) -> Result<&'static str, CircuitBreakerError<String>> {
        breaker.call(|| async { Ok::<_, String>("success") }).await
    }

    #[test]
    fn test_default_config() {
        let config = CircuitBreakerConfig::default();
        assert_eq!(config.failure_threshold, 5);
        assert_eq!(config.timeout, Duration::from_secs(60));
        assert_eq!(config.success_threshold, 2);
    }

    #[tokio::test]
    async fn test_circuit_breaker_closed_state() {
        let breaker = breaker_with(3, Duration::from_secs(1), 2);
        assert_eq!(breaker.name(), "test");
        assert_eq!(breaker.state().await, CircuitState::Closed);

        assert!(succeed(&breaker).await.is_ok());
        assert_eq!(breaker.state().await, CircuitState::Closed);
    }

    #[tokio::test]
    async fn test_circuit_breaker_opens_after_failures() {
        let breaker = breaker_with(3, Duration::from_secs(1), 2);

        fail(&breaker, 2).await;
        assert_eq!(breaker.state().await, CircuitState::Closed);

        fail(&breaker, 1).await;
        assert_eq!(breaker.state().await, CircuitState::Open);
    }

    #[tokio::test]
    async fn test_success_resets_consecutive_failures() {
        let breaker = breaker_with(3, Duration::from_secs(1), 2);

        fail(&breaker, 2).await;
        assert!(succeed(&breaker).await.is_ok());
        fail(&breaker, 2).await;

        // Only two consecutive failures since the success: below the threshold of three.
        assert_eq!(breaker.state().await, CircuitState::Closed);
    }

    #[tokio::test]
    async fn test_circuit_breaker_rejects_when_open() {
        let breaker = breaker_with(2, Duration::from_secs(10), 2);
        fail(&breaker, 2).await;

        let invoked = std::cell::Cell::new(false);
        let result = breaker
            .call(|| {
                invoked.set(true);
                async { Ok::<_, String>("success") }
            })
            .await;

        assert!(matches!(result, Err(CircuitBreakerError::CircuitOpen)));
        assert!(
            !invoked.get(),
            "the operation must not run while the circuit is open"
        );
    }

    #[tokio::test]
    async fn test_circuit_breaker_half_open() {
        let breaker = breaker_with(2, Duration::from_millis(100), 2);
        fail(&breaker, 2).await;
        assert_eq!(breaker.state().await, CircuitState::Open);

        tokio::time::sleep(Duration::from_millis(150)).await;

        // The first call after the timeout is admitted as a trial. One success is below
        // `success_threshold` (2), so the breaker must still be half-open.
        assert!(succeed(&breaker).await.is_ok());
        assert_eq!(breaker.state().await, CircuitState::HalfOpen);
    }

    #[tokio::test]
    async fn test_circuit_breaker_reopens_after_failed_probe() {
        let breaker = breaker_with(2, Duration::from_millis(100), 2);
        fail(&breaker, 2).await;

        tokio::time::sleep(Duration::from_millis(150)).await;

        fail(&breaker, 1).await;
        assert_eq!(breaker.state().await, CircuitState::Open);

        // The failed probe restarted the timeout, so the next call is rejected.
        assert!(matches!(
            succeed(&breaker).await,
            Err(CircuitBreakerError::CircuitOpen)
        ));
    }

    #[tokio::test]
    async fn test_circuit_breaker_recovery() {
        let breaker = breaker_with(2, Duration::from_millis(100), 2);
        fail(&breaker, 2).await;

        tokio::time::sleep(Duration::from_millis(150)).await;

        for _ in 0..2 {
            assert!(succeed(&breaker).await.is_ok());
        }

        assert_eq!(breaker.state().await, CircuitState::Closed);
    }

    #[tokio::test]
    async fn test_circuit_breaker_reset() {
        let breaker = breaker_with(2, Duration::from_secs(10), 2);
        fail(&breaker, 2).await;
        assert_eq!(breaker.state().await, CircuitState::Open);

        breaker.reset().await;

        assert_eq!(breaker.state().await, CircuitState::Closed);
        assert!(succeed(&breaker).await.is_ok());
    }
}