Skip to main content

do_over/
circuit_breaker.rs

//! Circuit breaker policy for preventing cascading failures.
//!
//! The circuit breaker monitors failures and temporarily stops calling a failing
//! service, giving it time to recover.
//!
//! # States
//!
//! - **Closed**: Normal operation, requests flow through. Consecutive failures are counted.
//! - **Open**: Circuit is tripped. Requests fail immediately without calling the service.
//! - **Half-Open**: After the reset timeout elapses, one test request is allowed through.
//!   If it succeeds the circuit closes; if it fails the circuit opens again.
//!
//! # Examples
//!
//! ```rust
//! use do_over::{policy::Policy, circuit_breaker::CircuitBreaker, error::DoOverError};
//! use std::time::Duration;
//!
//! # async fn example() -> Result<(), DoOverError<std::io::Error>> {
//! // Open the circuit after 5 failures; allow a test request after 30 seconds
//! let breaker = CircuitBreaker::new(5, Duration::from_secs(30));
//!
//! match breaker.execute(|| async {
//!     Ok::<_, DoOverError<std::io::Error>>("success")
//! }).await {
//!     Ok(result) => println!("Success: {}", result),
//!     Err(DoOverError::CircuitOpen) => println!("Circuit is open"),
//!     Err(e) => println!("Error: {:?}", e),
//! }
//! # Ok(())
//! # }
//! ```

use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Mutex, MutexGuard};
use std::time::{Duration, Instant};

use tokio::sync::RwLock;

use crate::error::DoOverError;
use crate::policy::Policy;

/// Internal circuit breaker state.
///
/// `Debug`, `PartialEq` and `Eq` are derived so the state can appear in
/// diagnostics and be compared with `==` (e.g. in assertions).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum State {
    /// Normal operation - requests pass through.
    Closed,
    /// Circuit is open - requests fail immediately.
    Open,
    /// Testing if service recovered - one request allowed.
    HalfOpen,
}

50/// A circuit breaker that prevents cascading failures.
51///
52/// The circuit breaker tracks consecutive failures and "trips" when the failure
53/// threshold is reached, causing subsequent requests to fail immediately without
54/// calling the underlying service.
55///
56/// # State Transitions
57///
58/// ```text
59/// Closed --[failures >= threshold]--> Open
60/// Open --[reset_timeout elapsed]--> HalfOpen
61/// HalfOpen --[success]--> Closed
62/// HalfOpen --[failure]--> Open
63/// ```
64///
65/// # Examples
66///
67/// ```rust
68/// use do_over::{policy::Policy, circuit_breaker::CircuitBreaker, error::DoOverError};
69/// use std::time::Duration;
70///
71/// # async fn example() {
72/// let breaker = CircuitBreaker::new(3, Duration::from_secs(60));
73///
74/// // Use the breaker to protect calls to an external service
75/// let result: Result<String, DoOverError<String>> = breaker.execute(|| async {
76///     Ok("response".to_string())
77/// }).await;
78/// # }
79/// ```
80pub struct CircuitBreaker {
81    failure_threshold: usize,
82    reset_timeout: Duration,
83    failures: AtomicUsize,
84    opened_at: RwLock<Option<Instant>>,
85    state: RwLock<State>,
86}
87
88impl Clone for CircuitBreaker {
89    fn clone(&self) -> Self {
90        Self {
91            failure_threshold: self.failure_threshold,
92            reset_timeout: self.reset_timeout,
93            failures: AtomicUsize::new(self.failures.load(Ordering::Relaxed)),
94            opened_at: RwLock::new(*self.opened_at.blocking_read()),
95            state: RwLock::new(*self.state.blocking_read()),
96        }
97    }
98}
99
100impl CircuitBreaker {
101    /// Create a new circuit breaker.
102    ///
103    /// # Arguments
104    ///
105    /// * `failure_threshold` - Number of consecutive failures before the circuit opens
106    /// * `reset_timeout` - How long to wait before transitioning from Open to Half-Open
107    ///
108    /// # Examples
109    ///
110    /// ```rust
111    /// use do_over::circuit_breaker::CircuitBreaker;
112    /// use std::time::Duration;
113    ///
114    /// // Open after 5 failures, wait 60 seconds before testing recovery
115    /// let breaker = CircuitBreaker::new(5, Duration::from_secs(60));
116    /// ```
117    pub fn new(failure_threshold: usize, reset_timeout: Duration) -> Self {
118        Self {
119            failure_threshold,
120            reset_timeout,
121            failures: AtomicUsize::new(0),
122            opened_at: RwLock::new(None),
123            state: RwLock::new(State::Closed),
124        }
125    }
126}
127
128#[async_trait::async_trait]
129impl<E> Policy<DoOverError<E>> for CircuitBreaker
130where
131    E: Send + Sync,
132{
133    async fn execute<F, Fut, T>(&self, f: F) -> Result<T, DoOverError<E>>
134    where
135        F: Fn() -> Fut + Send + Sync,
136        Fut: std::future::Future<Output = Result<T, DoOverError<E>>> + Send,
137        T: Send,
138    {
139        {
140            let state = self.state.read().await;
141            if matches!(*state, State::Open) {
142                let opened = self.opened_at.read().await;
143                if let Some(t) = *opened {
144                    if t.elapsed() >= self.reset_timeout {
145                        drop(opened);
146                        *self.state.write().await = State::HalfOpen;
147                    } else {
148                        return Err(DoOverError::CircuitOpen);
149                    }
150                }
151            }
152        }
153
154        match f().await {
155            Ok(v) => {
156                self.failures.store(0, Ordering::Relaxed);
157                *self.state.write().await = State::Closed;
158                Ok(v)
159            }
160            Err(e) => {
161                let count = self.failures.fetch_add(1, Ordering::Relaxed) + 1;
162                if count >= self.failure_threshold {
163                    *self.state.write().await = State::Open;
164                    *self.opened_at.write().await = Some(Instant::now());
165                }
166                Err(e)
167            }
168        }
169    }
170}