Skip to main content

rs_zero/resil/
breaker.rs

1use std::{future::Future, sync::Arc, time::Duration};
2
3use thiserror::Error;
4use tokio::sync::Mutex;
5
6use crate::resil::{WindowConfig, WindowSnapshot, breaker_state::CircuitBreakerState};
7
/// Position of a circuit breaker in its closed → open → half-open state machine.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BreakerState {
    /// Normal operation: calls go through and failures are tallied.
    Closed,
    /// Tripped: calls are refused until the reset timeout has elapsed.
    Open,
    /// Probing: a bounded number of trial calls decide whether to close again.
    HalfOpen,
}
18
/// Basic circuit breaker settings: when to trip and how long to stay tripped.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BreakerConfig {
    /// How many failures in a row open the breaker.
    pub failure_threshold: u32,
    /// How long an open breaker waits before letting a half-open trial call through.
    pub reset_timeout: Duration,
}

impl Default for BreakerConfig {
    /// Trips after 5 consecutive failures and probes again after 30 seconds.
    fn default() -> Self {
        const DEFAULT_FAILURE_THRESHOLD: u32 = 5;
        let reset_timeout = Duration::from_secs(30);
        Self {
            failure_threshold: DEFAULT_FAILURE_THRESHOLD,
            reset_timeout,
        }
    }
}
36
37/// Advanced breaker policy used by production-oriented breakers.
38#[derive(Debug, Clone, PartialEq, Eq)]
39pub struct BreakerPolicyConfig {
40    /// Rolling window configuration used for aggregate decisions.
41    pub window: WindowConfig,
42    /// Minimum request count before aggregate failure ratio can open or drop.
43    pub min_request_count: u64,
44    /// Failure ratio percentage that opens the breaker when enough samples exist.
45    pub failure_ratio_percent: u8,
46    /// Deterministic drop percentage while the rolling window is unhealthy.
47    pub drop_ratio_percent: u8,
48    /// Maximum concurrent trial calls in half-open state.
49    pub half_open_max_calls: u32,
50    /// Minimum interval between forced trial calls while open.
51    pub force_pass_interval: Duration,
52    /// Enables Google SRE style client-side throttling while closed.
53    pub sre_rejection_enabled: bool,
54    /// SRE throttling multiplier in millis. `1500` means `k = 1.5`.
55    pub sre_k_millis: u32,
56    /// Minimum total samples before SRE throttling can reject requests.
57    pub sre_protection: u64,
58}
59
60impl Default for BreakerPolicyConfig {
61    fn default() -> Self {
62        Self {
63            window: WindowConfig::default(),
64            min_request_count: 20,
65            failure_ratio_percent: 50,
66            drop_ratio_percent: 20,
67            half_open_max_calls: 1,
68            force_pass_interval: Duration::from_secs(5),
69            sre_rejection_enabled: false,
70            sre_k_millis: 1500,
71            sre_protection: 5,
72        }
73    }
74}
75
76impl BreakerPolicyConfig {
77    /// Returns a policy using Google SRE style adaptive rejection.
78    pub fn google_sre() -> Self {
79        Self {
80            sre_rejection_enabled: true,
81            drop_ratio_percent: 0,
82            failure_ratio_percent: 100,
83            ..Self::default()
84        }
85    }
86}
87
88/// Snapshot of breaker state and rolling statistics.
89#[derive(Debug, Clone, PartialEq)]
90pub struct BreakerSnapshot {
91    /// Current breaker state.
92    pub state: BreakerState,
93    /// Consecutive backend failures.
94    pub consecutive_failures: u32,
95    /// Current half-open trial calls.
96    pub half_open_in_flight: u32,
97    /// Rolling window statistics.
98    pub window: WindowSnapshot,
99}
100
101/// Error returned when a breaker rejects a call.
102#[derive(Debug, Error, Clone, Copy, PartialEq, Eq)]
103pub enum BreakerError {
104    /// The breaker is open.
105    #[error("circuit breaker is open")]
106    Open,
107    /// The breaker probabilistically dropped the call while unhealthy.
108    #[error("circuit breaker dropped request")]
109    Dropped,
110}
111
112/// Error returned by protected breaker calls.
113#[derive(Debug, Error, PartialEq, Eq)]
114pub enum BreakerCallError<E> {
115    /// The breaker rejected the call before the operation ran.
116    #[error(transparent)]
117    Rejected(#[from] BreakerError),
118    /// The protected operation returned an error.
119    #[error("protected call failed: {0}")]
120    Inner(E),
121}
122
123/// Small circuit breaker suitable for local protection and tests.
124#[derive(Debug)]
125pub struct CircuitBreaker {
126    state: CircuitBreakerState,
127}
128
129impl CircuitBreaker {
130    /// Creates a breaker that opens after `failure_threshold` failures.
131    pub fn new(failure_threshold: u32, reset_timeout: Duration) -> Self {
132        Self {
133            state: CircuitBreakerState::new(
134                BreakerConfig {
135                    failure_threshold,
136                    reset_timeout,
137                },
138                BreakerPolicyConfig::default(),
139            ),
140        }
141    }
142
143    /// Returns the current state, applying reset timeout transition if needed.
144    pub fn state(&mut self) -> BreakerState {
145        self.state.state()
146    }
147
148    /// Returns whether the next call may proceed.
149    pub fn allow(&mut self) -> bool {
150        self.state.allow().is_ok()
151    }
152
153    /// Records a successful call.
154    pub fn record_success(&mut self) {
155        self.state.record_success();
156    }
157
158    /// Records a failed call.
159    pub fn record_failure(&mut self) {
160        self.state.record_failure();
161    }
162}
163
164/// Thread-safe circuit breaker handle for async services.
165#[derive(Debug, Clone)]
166pub struct SharedCircuitBreaker {
167    state: Arc<Mutex<CircuitBreakerState>>,
168}
169
170impl SharedCircuitBreaker {
171    /// Creates a shared circuit breaker from configuration.
172    pub fn new(config: BreakerConfig) -> Self {
173        Self::with_policy(config, BreakerPolicyConfig::default())
174    }
175
176    /// Creates a shared circuit breaker with advanced rolling-window policy.
177    pub fn with_policy(config: BreakerConfig, policy: BreakerPolicyConfig) -> Self {
178        Self {
179            state: Arc::new(Mutex::new(CircuitBreakerState::new(config, policy))),
180        }
181    }
182
183    /// Attempts to enter the protected section.
184    pub async fn allow(&self) -> Result<BreakerGuard, BreakerError> {
185        self.state.lock().await.allow()?;
186        Ok(BreakerGuard {
187            breaker: self.clone(),
188            completed: false,
189        })
190    }
191
192    /// Runs a protected async operation and records success or failure.
193    pub async fn do_request<F, Fut, T, E>(&self, request: F) -> Result<T, BreakerCallError<E>>
194    where
195        F: FnOnce() -> Fut,
196        Fut: Future<Output = Result<T, E>>,
197    {
198        self.do_with_acceptable(request, |_| false).await
199    }
200
201    /// Runs a protected operation with a fallback used only for breaker rejection.
202    pub async fn do_with_fallback<F, Fut, Fb, FbFut, T, E>(
203        &self,
204        request: F,
205        fallback: Fb,
206    ) -> Result<T, E>
207    where
208        F: FnOnce() -> Fut,
209        Fut: Future<Output = Result<T, E>>,
210        Fb: FnOnce(BreakerError) -> FbFut,
211        FbFut: Future<Output = Result<T, E>>,
212    {
213        let guard = match self.allow().await {
214            Ok(guard) => guard,
215            Err(error) => return fallback(error).await,
216        };
217
218        match request().await {
219            Ok(value) => {
220                guard.record_success().await;
221                Ok(value)
222            }
223            Err(error) => {
224                guard.record_failure().await;
225                Err(error)
226            }
227        }
228    }
229
230    /// Runs a protected operation and lets callers mark some errors as acceptable.
231    pub async fn do_with_acceptable<F, Fut, T, E, A>(
232        &self,
233        request: F,
234        acceptable: A,
235    ) -> Result<T, BreakerCallError<E>>
236    where
237        F: FnOnce() -> Fut,
238        Fut: Future<Output = Result<T, E>>,
239        A: Fn(&E) -> bool,
240    {
241        let guard = self.allow().await?;
242        match request().await {
243            Ok(value) => {
244                guard.record_success().await;
245                Ok(value)
246            }
247            Err(error) if acceptable(&error) => {
248                guard.record_success().await;
249                Err(BreakerCallError::Inner(error))
250            }
251            Err(error) => {
252                guard.record_failure().await;
253                Err(BreakerCallError::Inner(error))
254            }
255        }
256    }
257
258    /// Returns the current state.
259    pub async fn state(&self) -> BreakerState {
260        self.state.lock().await.state()
261    }
262
263    /// Returns a snapshot with state and rolling statistics.
264    pub async fn snapshot(&self) -> BreakerSnapshot {
265        self.state.lock().await.snapshot()
266    }
267
268    async fn record_success(&self) {
269        self.state.lock().await.record_success();
270    }
271
272    async fn record_failure(&self) {
273        self.state.lock().await.record_failure();
274    }
275}
276
277/// Guard returned by [`SharedCircuitBreaker::allow`].
278#[derive(Debug)]
279pub struct BreakerGuard {
280    breaker: SharedCircuitBreaker,
281    completed: bool,
282}
283
284impl BreakerGuard {
285    /// Marks the protected operation as successful.
286    pub async fn record_success(mut self) {
287        self.breaker.record_success().await;
288        self.completed = true;
289    }
290
291    /// Marks the protected operation as failed.
292    pub async fn record_failure(mut self) {
293        self.breaker.record_failure().await;
294        self.completed = true;
295    }
296}
297
298impl Drop for BreakerGuard {
299    fn drop(&mut self) {
300        if !self.completed {
301            let breaker = self.breaker.clone();
302            if let Ok(handle) = tokio::runtime::Handle::try_current() {
303                handle.spawn(async move {
304                    breaker.record_failure().await;
305                });
306            }
307        }
308    }
309}