1use std::{future::Future, sync::Arc, time::Duration};
2
3use thiserror::Error;
4use tokio::sync::Mutex;
5
6use crate::resil::{WindowConfig, WindowSnapshot, breaker_state::CircuitBreakerState};
7
/// Coarse state reported by a circuit breaker.
///
/// The variant names follow the usual circuit-breaker vocabulary; the transition rules are
/// implemented by `CircuitBreakerState` and are not visible in this section.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BreakerState {
    /// Conventionally: calls are admitted normally.
    Closed,
    /// Conventionally: calls are rejected.
    Open,
    /// Conventionally: a limited number of trial calls is admitted.
    HalfOpen,
}
18
/// Basic circuit breaker settings passed to `CircuitBreakerState::new`.
///
/// NOTE(review): the field notes are inferred from the field names; confirm against the state
/// machine that consumes them.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BreakerConfig {
    /// Presumably the number of failures after which the breaker opens.
    pub failure_threshold: u32,
    /// Presumably how long the breaker stays open before it probes again.
    pub reset_timeout: Duration,
}
27
28impl Default for BreakerConfig {
29 fn default() -> Self {
30 Self {
31 failure_threshold: 5,
32 reset_timeout: Duration::from_secs(30),
33 }
34 }
35}
36
37#[derive(Debug, Clone, PartialEq, Eq)]
39pub struct BreakerPolicyConfig {
40 pub window: WindowConfig,
42 pub min_request_count: u64,
44 pub failure_ratio_percent: u8,
46 pub drop_ratio_percent: u8,
48 pub half_open_max_calls: u32,
50 pub force_pass_interval: Duration,
52 pub sre_rejection_enabled: bool,
54 pub sre_k_millis: u32,
56 pub sre_protection: u64,
58}
59
60impl Default for BreakerPolicyConfig {
61 fn default() -> Self {
62 Self {
63 window: WindowConfig::default(),
64 min_request_count: 20,
65 failure_ratio_percent: 50,
66 drop_ratio_percent: 20,
67 half_open_max_calls: 1,
68 force_pass_interval: Duration::from_secs(5),
69 sre_rejection_enabled: false,
70 sre_k_millis: 1500,
71 sre_protection: 5,
72 }
73 }
74}
75
76impl BreakerPolicyConfig {
77 pub fn google_sre() -> Self {
79 Self {
80 sre_rejection_enabled: true,
81 drop_ratio_percent: 0,
82 failure_ratio_percent: 100,
83 ..Self::default()
84 }
85 }
86}
87
88#[derive(Debug, Clone, PartialEq)]
90pub struct BreakerSnapshot {
91 pub state: BreakerState,
93 pub consecutive_failures: u32,
95 pub half_open_in_flight: u32,
97 pub window: WindowSnapshot,
99}
100
101#[derive(Debug, Error, Clone, Copy, PartialEq, Eq)]
103pub enum BreakerError {
104 #[error("circuit breaker is open")]
106 Open,
107 #[error("circuit breaker dropped request")]
109 Dropped,
110}
111
112#[derive(Debug, Error, PartialEq, Eq)]
114pub enum BreakerCallError<E> {
115 #[error(transparent)]
117 Rejected(#[from] BreakerError),
118 #[error("protected call failed: {0}")]
120 Inner(E),
121}
122
123#[derive(Debug)]
125pub struct CircuitBreaker {
126 state: CircuitBreakerState,
127}
128
129impl CircuitBreaker {
130 pub fn new(failure_threshold: u32, reset_timeout: Duration) -> Self {
132 Self {
133 state: CircuitBreakerState::new(
134 BreakerConfig {
135 failure_threshold,
136 reset_timeout,
137 },
138 BreakerPolicyConfig::default(),
139 ),
140 }
141 }
142
143 pub fn state(&mut self) -> BreakerState {
145 self.state.state()
146 }
147
148 pub fn allow(&mut self) -> bool {
150 self.state.allow().is_ok()
151 }
152
153 pub fn record_success(&mut self) {
155 self.state.record_success();
156 }
157
158 pub fn record_failure(&mut self) {
160 self.state.record_failure();
161 }
162}
163
164#[derive(Debug, Clone)]
166pub struct SharedCircuitBreaker {
167 state: Arc<Mutex<CircuitBreakerState>>,
168}
169
170impl SharedCircuitBreaker {
171 pub fn new(config: BreakerConfig) -> Self {
173 Self::with_policy(config, BreakerPolicyConfig::default())
174 }
175
176 pub fn with_policy(config: BreakerConfig, policy: BreakerPolicyConfig) -> Self {
178 Self {
179 state: Arc::new(Mutex::new(CircuitBreakerState::new(config, policy))),
180 }
181 }
182
183 pub async fn allow(&self) -> Result<BreakerGuard, BreakerError> {
185 self.state.lock().await.allow()?;
186 Ok(BreakerGuard {
187 breaker: self.clone(),
188 completed: false,
189 })
190 }
191
192 pub async fn do_request<F, Fut, T, E>(&self, request: F) -> Result<T, BreakerCallError<E>>
194 where
195 F: FnOnce() -> Fut,
196 Fut: Future<Output = Result<T, E>>,
197 {
198 self.do_with_acceptable(request, |_| false).await
199 }
200
201 pub async fn do_with_fallback<F, Fut, Fb, FbFut, T, E>(
203 &self,
204 request: F,
205 fallback: Fb,
206 ) -> Result<T, E>
207 where
208 F: FnOnce() -> Fut,
209 Fut: Future<Output = Result<T, E>>,
210 Fb: FnOnce(BreakerError) -> FbFut,
211 FbFut: Future<Output = Result<T, E>>,
212 {
213 let guard = match self.allow().await {
214 Ok(guard) => guard,
215 Err(error) => return fallback(error).await,
216 };
217
218 match request().await {
219 Ok(value) => {
220 guard.record_success().await;
221 Ok(value)
222 }
223 Err(error) => {
224 guard.record_failure().await;
225 Err(error)
226 }
227 }
228 }
229
230 pub async fn do_with_acceptable<F, Fut, T, E, A>(
232 &self,
233 request: F,
234 acceptable: A,
235 ) -> Result<T, BreakerCallError<E>>
236 where
237 F: FnOnce() -> Fut,
238 Fut: Future<Output = Result<T, E>>,
239 A: Fn(&E) -> bool,
240 {
241 let guard = self.allow().await?;
242 match request().await {
243 Ok(value) => {
244 guard.record_success().await;
245 Ok(value)
246 }
247 Err(error) if acceptable(&error) => {
248 guard.record_success().await;
249 Err(BreakerCallError::Inner(error))
250 }
251 Err(error) => {
252 guard.record_failure().await;
253 Err(BreakerCallError::Inner(error))
254 }
255 }
256 }
257
258 pub async fn state(&self) -> BreakerState {
260 self.state.lock().await.state()
261 }
262
263 pub async fn snapshot(&self) -> BreakerSnapshot {
265 self.state.lock().await.snapshot()
266 }
267
268 async fn record_success(&self) {
269 self.state.lock().await.record_success();
270 }
271
272 async fn record_failure(&self) {
273 self.state.lock().await.record_failure();
274 }
275}
276
277#[derive(Debug)]
279pub struct BreakerGuard {
280 breaker: SharedCircuitBreaker,
281 completed: bool,
282}
283
284impl BreakerGuard {
285 pub async fn record_success(mut self) {
287 self.breaker.record_success().await;
288 self.completed = true;
289 }
290
291 pub async fn record_failure(mut self) {
293 self.breaker.record_failure().await;
294 self.completed = true;
295 }
296}
297
298impl Drop for BreakerGuard {
299 fn drop(&mut self) {
300 if !self.completed {
301 let breaker = self.breaker.clone();
302 if let Ok(handle) = tokio::runtime::Handle::try_current() {
303 handle.spawn(async move {
304 breaker.record_failure().await;
305 });
306 }
307 }
308 }
309}