1use crate::circuit::Circuit;
214use futures::future::BoxFuture;
215#[cfg(feature = "metrics")]
216use metrics::{counter, describe_counter, describe_gauge};
217use std::sync::Arc;
218#[cfg(feature = "metrics")]
219use std::sync::Once;
220use std::task::{Context, Poll};
221use tokio::sync::Mutex;
222use tower::Service;
223#[cfg(feature = "tracing")]
224use tracing::debug;
225
226pub use circuit::CircuitState;
227pub use config::{CircuitBreakerConfig, CircuitBreakerConfigBuilder, SlidingWindowType};
228pub use error::CircuitBreakerError;
229pub use events::CircuitBreakerEvent;
230pub use layer::CircuitBreakerLayer;
231
232mod circuit;
233mod config;
234mod error;
235mod events;
236mod layer;
237
/// Predicate deciding whether a completed call counts as a failure.
///
/// The service's `call` records the outcome with `record_failure` when this
/// returns `true` and with `record_success` otherwise.
pub(crate) type FailureClassifier<Res, Err> = dyn Fn(&Result<Res, Err>) -> bool + Send + Sync;
/// Reference-counted, shareable handle to a [`FailureClassifier`].
pub(crate) type SharedFailureClassifier<Res, Err> = Arc<FailureClassifier<Res, Err>>;

/// Handler invoked with the original request when the circuit does not permit a call.
///
/// Its result is returned to the caller instead of `CircuitBreakerError::OpenCircuit`;
/// an `Err` it produces is wrapped in `CircuitBreakerError::Inner`.
pub(crate) type FallbackFn<Req, Res, Err> =
    dyn Fn(Req) -> BoxFuture<'static, Result<Res, Err>> + Send + Sync;
/// Reference-counted handle to a [`FallbackFn`], as stored by `CircuitBreaker::with_fallback`.
pub(crate) type SharedFallback<Req, Res, Err> = Arc<FallbackFn<Req, Res, Err>>;

/// Ensures the metric descriptions in `circuit_breaker_builder` are registered only once.
#[cfg(feature = "metrics")]
static METRICS_INIT: Once = Once::new();
247
248pub fn circuit_breaker_builder<Res, Err>() -> CircuitBreakerConfigBuilder<Res, Err> {
253 #[cfg(feature = "metrics")]
254 {
255 METRICS_INIT.call_once(|| {
256 describe_counter!(
257 "circuitbreaker_calls_total",
258 "Total number of calls through the circuit breaker"
259 );
260 describe_counter!(
261 "circuitbreaker_transitions_total",
262 "Total number of circuit breaker state transitions"
263 );
264 describe_gauge!(
265 "circuitbreaker_state",
266 "Current state of the circuit breaker"
267 );
268 });
269 }
270 CircuitBreakerConfigBuilder::default()
271}
272
273pub struct CircuitBreaker<S, Req, Res, Err> {
277 inner: S,
278 circuit: Arc<Mutex<Circuit>>,
279 state_atomic: Arc<std::sync::atomic::AtomicU8>,
280 config: Arc<CircuitBreakerConfig<Res, Err>>,
281 fallback: Option<SharedFallback<Req, Res, Err>>,
282 _phantom: std::marker::PhantomData<Req>,
283}
284
285impl<S, Req, Res, Err> CircuitBreaker<S, Req, Res, Err> {
286 pub(crate) fn new(inner: S, config: Arc<CircuitBreakerConfig<Res, Err>>) -> Self {
288 let state_atomic = Arc::new(std::sync::atomic::AtomicU8::new(CircuitState::Closed as u8));
289 Self {
290 inner,
291 circuit: Arc::new(Mutex::new(Circuit::new_with_atomic(Arc::clone(
292 &state_atomic,
293 )))),
294 state_atomic,
295 config,
296 fallback: None,
297 _phantom: std::marker::PhantomData,
298 }
299 }
300
301 pub fn with_fallback<F>(mut self, fallback: F) -> Self
303 where
304 F: Fn(Req) -> BoxFuture<'static, Result<Res, Err>> + Send + Sync + 'static,
305 {
306 self.fallback = Some(Arc::new(fallback));
307 self
308 }
309
310 pub async fn force_open(&self) {
312 let mut circuit = self.circuit.lock().await;
313 circuit.force_open(&self.config);
314 }
315
316 pub async fn force_closed(&self) {
318 let mut circuit = self.circuit.lock().await;
319 circuit.force_closed(&self.config);
320 }
321
322 pub async fn reset(&self) {
324 let mut circuit = self.circuit.lock().await;
325 circuit.reset(&self.config);
326 }
327
328 pub async fn state(&self) -> CircuitState {
330 let circuit = self.circuit.lock().await;
331 circuit.state()
332 }
333
334 pub fn state_sync(&self) -> CircuitState {
339 CircuitState::from_u8(self.state_atomic.load(std::sync::atomic::Ordering::Acquire))
340 }
341}
342
343impl<S, Req, Res, Err> Service<Req> for CircuitBreaker<S, Req, Res, Err>
344where
345 S: Service<Req, Response = Res, Error = Err> + Clone + Send + 'static,
346 S::Future: Send + 'static,
347 Res: Send + 'static,
348 Err: Send + 'static,
349 Req: Send + 'static,
350{
351 type Response = Res;
352 type Error = CircuitBreakerError<Err>;
353 type Future = BoxFuture<'static, Result<Res, Self::Error>>;
354
355 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
356 self.inner
357 .poll_ready(cx)
358 .map_err(CircuitBreakerError::Inner)
359 }
360
361 fn call(&mut self, req: Req) -> Self::Future {
362 let config = Arc::clone(&self.config);
363 let circuit = Arc::clone(&self.circuit);
364 let mut inner = self.inner.clone();
365 let fallback = self.fallback.clone();
366
367 Box::pin(async move {
368 #[cfg(feature = "tracing")]
369 {
370 let cb_name = &config.name;
371 debug!(
372 breaker = cb_name,
373 "Checking if call is permitted by circuit breaker"
374 );
375 }
376
377 #[cfg(feature = "tracing")]
378 let circuit_check_span = {
379 use tracing::{Level, span};
380 let state = {
381 let circuit = circuit.lock().await;
383 circuit.state()
384 };
385 let cb_name = &config.name;
386 span!(Level::DEBUG, "circuit_check", breaker = cb_name, state = ?state)
387 };
388 #[cfg(feature = "tracing")]
389 let _enter = circuit_check_span.enter();
390
391 let permitted = {
392 let mut circuit = circuit.lock().await;
393 circuit.try_acquire(&config)
394 };
395
396 #[cfg(feature = "tracing")]
397 {
398 let cb_name = &config.name;
399 if permitted {
400 tracing::trace!(breaker = cb_name, "circuit breaker permitted call");
401 } else {
402 tracing::trace!(
403 breaker = cb_name,
404 "circuit breaker rejected call (circuit open)"
405 );
406 }
407 }
408
409 if !permitted {
410 #[cfg(feature = "metrics")]
411 {
412 let counter = counter!("circuitbreaker_calls_total", "outcome" => "rejected");
413 counter.increment(1);
414 }
415
416 if let Some(fallback_fn) = fallback {
418 #[cfg(feature = "tracing")]
419 {
420 let cb_name = &config.name;
421 tracing::debug!(breaker = cb_name, "Calling fallback handler");
422 }
423
424 return fallback_fn(req).await.map_err(CircuitBreakerError::Inner);
425 }
426
427 return Err(CircuitBreakerError::OpenCircuit);
428 }
429
430 let start = std::time::Instant::now();
431 let result = inner.call(req).await;
432 let duration = start.elapsed();
433
434 let mut circuit = circuit.lock().await;
435 if (config.failure_classifier)(&result) {
436 circuit.record_failure(&config, duration);
437 } else {
438 circuit.record_success(&config, duration);
439 }
440
441 result.map_err(CircuitBreakerError::Inner)
442 })
443 }
444}
445
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::time::Duration;
    use tower_resilience_core::{EventListeners, FnListener};

    /// Latency used for calls that should stay under every slow-call threshold here.
    const QUICK: Duration = Duration::from_millis(10);

    /// Baseline configuration shared by the tests.
    ///
    /// It uses a count-based window of 10, a 0.5 failure-rate threshold and
    /// `minimum_number_of_calls` of 10. Slow-call tracking is disabled and no
    /// listeners are attached.
    fn dummy_config() -> CircuitBreakerConfig<(), ()> {
        CircuitBreakerConfig {
            failure_rate_threshold: 0.5,
            sliding_window_type: crate::config::SlidingWindowType::CountBased,
            sliding_window_size: 10,
            sliding_window_duration: None,
            wait_duration_in_open: Duration::from_secs(1),
            permitted_calls_in_half_open: 1,
            failure_classifier: Arc::new(|r| r.is_err()),
            minimum_number_of_calls: 10,
            slow_call_duration_threshold: None,
            slow_call_rate_threshold: 1.0,
            event_listeners: EventListeners::new(),
            name: "test".into(),
        }
    }

    /// Records `count` failed calls, each with the given `latency`.
    fn record_failures(
        circuit: &mut Circuit,
        config: &CircuitBreakerConfig<(), ()>,
        count: usize,
        latency: Duration,
    ) {
        for _ in 0..count {
            circuit.record_failure(config, latency);
        }
    }

    /// Records `count` successful calls, each with the given `latency`.
    fn record_successes(
        circuit: &mut Circuit,
        config: &CircuitBreakerConfig<(), ()>,
        count: usize,
        latency: Duration,
    ) {
        for _ in 0..count {
            circuit.record_success(config, latency);
        }
    }

    #[test]
    fn transitions_to_open_on_high_failure_rate() {
        let config = dummy_config();
        let mut circuit = Circuit::new();

        record_failures(&mut circuit, &config, 6, QUICK);
        record_successes(&mut circuit, &config, 4, QUICK);

        assert_eq!(circuit.state(), CircuitState::Open);
    }

    #[test]
    fn stays_closed_on_low_failure_rate() {
        let config = dummy_config();
        let mut circuit = Circuit::new();

        record_failures(&mut circuit, &config, 2, QUICK);
        record_successes(&mut circuit, &config, 8, QUICK);

        assert_eq!(circuit.state(), CircuitState::Closed);
    }

    #[tokio::test]
    async fn manual_override_controls_work() {
        let breaker: CircuitBreaker<(), (), (), ()> =
            CircuitBreaker::new((), Arc::new(dummy_config()));

        breaker.force_open().await;
        assert_eq!(breaker.state().await, CircuitState::Open);

        breaker.force_closed().await;
        assert_eq!(breaker.state().await, CircuitState::Closed);
    }

    #[test]
    fn test_error_helpers() {
        let open: CircuitBreakerError<&str> = CircuitBreakerError::OpenCircuit;
        assert!(open.is_circuit_open());
        assert_eq!(open.into_inner(), None);

        let wrapped = CircuitBreakerError::Inner("fail");
        assert!(!wrapped.is_circuit_open());
        assert_eq!(wrapped.into_inner(), Some("fail"));
    }

    #[test]
    fn test_event_listeners() {
        let [state_transitions, call_permitted, call_rejected, successes, failures]: [Arc<
            AtomicUsize,
        >; 5] = Default::default();

        // Clones moved into the listener; index order matches the destructuring above.
        let sinks = (
            Arc::clone(&state_transitions),
            Arc::clone(&call_permitted),
            Arc::clone(&call_rejected),
            Arc::clone(&successes),
            Arc::clone(&failures),
        );

        let config: CircuitBreakerConfig<(), ()> = CircuitBreakerConfig {
            event_listeners: {
                let mut listeners = EventListeners::new();
                listeners.add(FnListener::new(move |event| {
                    let counter = match event {
                        CircuitBreakerEvent::StateTransition { .. } => &sinks.0,
                        CircuitBreakerEvent::CallPermitted { .. } => &sinks.1,
                        CircuitBreakerEvent::CallRejected { .. } => &sinks.2,
                        CircuitBreakerEvent::SuccessRecorded { .. } => &sinks.3,
                        CircuitBreakerEvent::FailureRecorded { .. } => &sinks.4,
                        CircuitBreakerEvent::SlowCallDetected { .. } => return,
                    };
                    counter.fetch_add(1, Ordering::SeqCst);
                }));
                listeners
            },
            ..dummy_config()
        };

        let mut circuit = Circuit::new();
        record_failures(&mut circuit, &config, 6, QUICK);
        record_successes(&mut circuit, &config, 4, QUICK);

        assert_eq!(circuit.state(), CircuitState::Open);
        assert_eq!(state_transitions.load(Ordering::SeqCst), 1);
        assert_eq!(failures.load(Ordering::SeqCst), 6);
        assert_eq!(successes.load(Ordering::SeqCst), 4);

        // Once open, the next acquisition attempt is refused and reported as rejected.
        assert!(!circuit.try_acquire(&config));
        assert_eq!(call_rejected.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn test_slow_call_detection() {
        let slow_calls = Arc::new(AtomicUsize::new(0));
        let slow_sink = Arc::clone(&slow_calls);

        let config: CircuitBreakerConfig<(), ()> = CircuitBreakerConfig {
            slow_call_duration_threshold: Some(Duration::from_millis(100)),
            slow_call_rate_threshold: 0.5,
            event_listeners: {
                let mut listeners = EventListeners::new();
                listeners.add(FnListener::new(move |event| {
                    if let CircuitBreakerEvent::SlowCallDetected { .. } = event {
                        slow_sink.fetch_add(1, Ordering::SeqCst);
                    }
                }));
                listeners
            },
            ..dummy_config()
        };

        let mut circuit = Circuit::new();
        record_successes(&mut circuit, &config, 6, Duration::from_millis(150));
        record_successes(&mut circuit, &config, 4, Duration::from_millis(50));

        assert_eq!(slow_calls.load(Ordering::SeqCst), 6);
        assert_eq!(circuit.state(), CircuitState::Open);
    }

    #[test]
    fn test_slow_call_with_failures() {
        // failure_rate_threshold = 1.0 is intended to keep the failure rate alone
        // from opening the circuit, so the slow-call settings drive the outcome.
        let config: CircuitBreakerConfig<(), ()> = CircuitBreakerConfig {
            failure_rate_threshold: 1.0,
            slow_call_duration_threshold: Some(Duration::from_millis(100)),
            slow_call_rate_threshold: 0.5,
            ..dummy_config()
        };

        let mut circuit = Circuit::new();
        record_failures(&mut circuit, &config, 6, Duration::from_millis(150));
        record_successes(&mut circuit, &config, 4, Duration::from_millis(50));

        assert_eq!(circuit.state(), CircuitState::Open);
    }

    #[tokio::test]
    async fn test_circuit_breaker_sync_state() {
        let breaker: CircuitBreaker<(), (), (), ()> =
            CircuitBreaker::new((), Arc::new(dummy_config()));

        assert_eq!(breaker.state_sync(), CircuitState::Closed);

        breaker.force_open().await;
        assert_eq!(breaker.state_sync(), CircuitState::Open);
        assert_eq!(breaker.state().await, CircuitState::Open);
    }
}