tower_circuitbreaker/
lib.rs1use crate::circuit::Circuit;
98use crate::config::CircuitBreakerConfig;
99use crate::layer::CircuitBreakerLayerBuilder;
100use futures::future::BoxFuture;
101#[cfg(feature = "metrics")]
102use metrics::{counter, describe_counter, describe_gauge};
103use std::sync::Arc;
104#[cfg(feature = "metrics")]
105use std::sync::Once;
106use std::task::{Context, Poll};
107use tokio::sync::Mutex;
108use tower::Service;
109#[cfg(feature = "tracing")]
110use tracing::debug;
111
112pub use circuit::CircuitState;
113pub use error::CircuitBreakerError;
114
115mod circuit;
116mod config;
117mod error;
118mod layer;
119
/// Predicate applied to a completed inner call's result; returning `true`
/// makes the breaker record the call as a failure, `false` as a success.
pub(crate) type FailureClassifier<Res, Err> = dyn Fn(&Result<Res, Err>) -> bool + Send + Sync;
/// Reference-counted handle to a [`FailureClassifier`].
pub(crate) type SharedFailureClassifier<Res, Err> = Arc<FailureClassifier<Res, Err>>;

/// Breaker name used in tracing output when the config has no `name` set.
#[cfg(feature = "tracing")]
pub(crate) static DEFAULT_CIRCUIT_BREAKER_NAME: &str = "<unnamed>";

/// Ensures metric descriptions are registered at most once per process
/// (used by `circuit_breaker_builder`).
#[cfg(feature = "metrics")]
static METRICS_INIT: Once = Once::new();
128
129pub fn circuit_breaker_builder<Res, Err>() -> CircuitBreakerLayerBuilder<Res, Err> {
131 #[cfg(feature = "metrics")]
132 {
133 METRICS_INIT.call_once(|| {
134 describe_counter!(
135 "circuitbreaker_calls_total",
136 "Total number of calls through the circuit breaker"
137 );
138 describe_counter!(
139 "circuitbreaker_transitions_total",
140 "Total number of circuit breaker state transitions"
141 );
142 describe_gauge!(
143 "circuitbreaker_state",
144 "Current state of the circuit breaker"
145 );
146 });
147 }
148 CircuitBreakerLayerBuilder::default()
149}
150
/// Tower [`Service`] that guards an inner service `S` with a circuit breaker.
///
/// When used as a [`Service`], `Res` and `Err` are the inner service's
/// response and error types; they also parameterize the breaker's
/// [`CircuitBreakerConfig`].
pub struct CircuitBreaker<S, Res, Err> {
    /// The wrapped service that handles permitted calls.
    inner: S,
    /// Circuit state behind an async mutex; each call future holds a clone of this `Arc`.
    circuit: Arc<Mutex<Circuit>>,
    /// Breaker configuration; each call future holds a clone of this `Arc`.
    config: Arc<CircuitBreakerConfig<Res, Err>>,
}
159
160impl<S, Res, Err> CircuitBreaker<S, Res, Err> {
161 pub(crate) fn new(inner: S, config: Arc<CircuitBreakerConfig<Res, Err>>) -> Self {
163 Self {
164 inner,
165 circuit: Arc::new(Mutex::new(Circuit::new())),
166 config,
167 }
168 }
169
170 pub async fn force_open(&self) {
172 let mut circuit = self.circuit.lock().await;
173 circuit.force_open();
174 }
175
176 pub async fn force_closed(&self) {
178 let mut circuit = self.circuit.lock().await;
179 circuit.force_closed();
180 }
181
182 pub async fn reset(&self) {
184 let mut circuit = self.circuit.lock().await;
185 circuit.reset();
186 }
187
188 pub async fn state(&self) -> CircuitState {
190 let circuit = self.circuit.lock().await;
191 circuit.state()
192 }
193}
194
195impl<S, Req, Res, Err> Service<Req> for CircuitBreaker<S, Res, Err>
196where
197 S: Service<Req, Response = Res, Error = Err> + Clone + Send + 'static,
198 S::Future: Send + 'static,
199 Res: Send + 'static,
200 Err: Send + 'static,
201 Req: Send + 'static,
202{
203 type Response = Res;
204 type Error = CircuitBreakerError<Err>;
205 type Future = BoxFuture<'static, Result<Res, Self::Error>>;
206
207 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
208 self.inner
209 .poll_ready(cx)
210 .map_err(CircuitBreakerError::Inner)
211 }
212
213 fn call(&mut self, req: Req) -> Self::Future {
214 let config = Arc::clone(&self.config);
215 let circuit = Arc::clone(&self.circuit);
216 let mut inner = self.inner.clone();
217
218 Box::pin(async move {
219 #[cfg(feature = "tracing")]
220 {
221 let cb_name = config
222 .name
223 .as_deref()
224 .unwrap_or(DEFAULT_CIRCUIT_BREAKER_NAME);
225 debug!(
226 breaker = cb_name,
227 "Checking if call is permitted by circuit breaker"
228 );
229 }
230
231 #[cfg(feature = "tracing")]
232 let circuit_check_span = {
233 use tracing::{Level, span};
234 let state = {
235 let circuit = circuit.lock().await;
237 circuit.state()
238 };
239 let cb_name = config
240 .name
241 .as_deref()
242 .unwrap_or(DEFAULT_CIRCUIT_BREAKER_NAME);
243 span!(Level::DEBUG, "circuit_check", breaker = cb_name, state = ?state)
244 };
245 #[cfg(feature = "tracing")]
246 let _enter = circuit_check_span.enter();
247
248 let permitted = {
249 let mut circuit = circuit.lock().await;
250 circuit.try_acquire(&config)
251 };
252
253 #[cfg(feature = "tracing")]
254 {
255 let cb_name = config
256 .name
257 .as_deref()
258 .unwrap_or(DEFAULT_CIRCUIT_BREAKER_NAME);
259 if permitted {
260 tracing::trace!(breaker = cb_name, "circuit breaker permitted call");
261 } else {
262 tracing::trace!(
263 breaker = cb_name,
264 "circuit breaker rejected call (circuit open)"
265 );
266 }
267 }
268
269 if !permitted {
270 #[cfg(feature = "metrics")]
271 {
272 let counter = counter!("circuitbreaker_calls_total", "outcome" => "rejected");
273 counter.increment(1);
274 }
275 return Err(CircuitBreakerError::OpenCircuit);
276 }
277
278 let result = inner.call(req).await;
279
280 let mut circuit = circuit.lock().await;
281 if (config.failure_classifier)(&result) {
282 circuit.record_failure(&config);
283 } else {
284 circuit.record_success(&config);
285 }
286
287 result.map_err(CircuitBreakerError::Inner)
288 })
289 }
290}
291
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Config shared by the tests: `failure_rate_threshold` 0.5, a 10-call
    /// sliding window, `minimum_number_of_calls` 10, one permitted half-open
    /// call, a 1s open wait, and any `Err` classified as a failure.
    fn dummy_config() -> CircuitBreakerConfig<(), ()> {
        CircuitBreakerConfig {
            failure_rate_threshold: 0.5,
            sliding_window_size: 10,
            wait_duration_in_open: Duration::from_secs(1),
            permitted_calls_in_half_open: 1,
            failure_classifier: Arc::new(|outcome| outcome.is_err()),
            minimum_number_of_calls: 10,
            #[cfg(feature = "tracing")]
            name: Some("test".into()),
        }
    }

    /// Feeds `failures` failures and then `successes` successes into `circuit`.
    fn replay(
        circuit: &mut Circuit,
        config: &CircuitBreakerConfig<(), ()>,
        failures: usize,
        successes: usize,
    ) {
        for _ in 0..failures {
            circuit.record_failure(config);
        }
        for _ in 0..successes {
            circuit.record_success(config);
        }
    }

    #[test]
    fn transitions_to_open_on_high_failure_rate() {
        let mut circuit = Circuit::new();
        let config = dummy_config();

        replay(&mut circuit, &config, 6, 4);

        assert_eq!(circuit.state(), CircuitState::Open);
    }

    #[test]
    fn stays_closed_on_low_failure_rate() {
        let mut circuit = Circuit::new();
        let config = dummy_config();

        replay(&mut circuit, &config, 2, 8);

        assert_eq!(circuit.state(), CircuitState::Closed);
    }

    #[tokio::test]
    async fn manual_override_controls_work() {
        let breaker = CircuitBreaker::new((), Arc::new(dummy_config()));

        breaker.force_open().await;
        assert_eq!(breaker.state().await, CircuitState::Open);

        breaker.force_closed().await;
        assert_eq!(breaker.state().await, CircuitState::Closed);
    }

    #[test]
    fn test_error_helpers() {
        let open: CircuitBreakerError<&str> = CircuitBreakerError::OpenCircuit;
        assert!(open.is_circuit_open());
        assert_eq!(open.into_inner(), None);

        let wrapped = CircuitBreakerError::Inner("fail");
        assert!(!wrapped.is_circuit_open());
        assert_eq!(wrapped.into_inner(), Some("fail"));
    }
}