failsafe/
circuit_breaker.rs

1use super::error::Error;
2use super::failure_policy::FailurePolicy;
3use super::failure_predicate::{self, FailurePredicate};
4use super::instrument::Instrument;
5use super::state_machine::StateMachine;
6
7/// A circuit breaker's public interface.
8pub trait CircuitBreaker {
9    /// Requests permission to call.
10    ///
11    /// It returns `true` if a call is allowed, or `false` if prohibited.
12    fn is_call_permitted(&self) -> bool;
13
14    /// Executes a given function within circuit breaker.
15    ///
16    /// Depending on function result value, the call will be recorded as success or failure.
17    #[inline]
18    fn call<F, E, R>(&self, f: F) -> Result<R, Error<E>>
19    where
20        F: FnOnce() -> Result<R, E>,
21    {
22        self.call_with(failure_predicate::Any, f)
23    }
24
25    /// Executes a given function within circuit breaker.
26    ///
27    /// Depending on function result value, the call will be recorded as success or failure.
28    /// It checks error by the provided predicate. If the predicate returns `true` for the
29    /// error, the call is recorded as failure otherwise considered this error as a success.
30    fn call_with<P, F, E, R>(&self, predicate: P, f: F) -> Result<R, Error<E>>
31    where
32        P: FailurePredicate<E>,
33        F: FnOnce() -> Result<R, E>;
34}
35
36impl<POLICY, INSTRUMENT> CircuitBreaker for StateMachine<POLICY, INSTRUMENT>
37where
38    POLICY: FailurePolicy,
39    INSTRUMENT: Instrument,
40{
41    #[inline]
42    fn is_call_permitted(&self) -> bool {
43        self.is_call_permitted()
44    }
45
46    fn call_with<P, F, E, R>(&self, predicate: P, f: F) -> Result<R, Error<E>>
47    where
48        P: FailurePredicate<E>,
49        F: FnOnce() -> Result<R, E>,
50    {
51        if !self.is_call_permitted() {
52            return Err(Error::Rejected);
53        }
54
55        match f() {
56            Ok(ok) => {
57                self.on_success();
58                Ok(ok)
59            }
60            Err(err) => {
61                if predicate.is_err(&err) {
62                    self.on_error();
63                } else {
64                    self.on_success();
65                }
66                Err(Error::Inner(err))
67            }
68        }
69    }
70}
71
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use super::super::backoff;
    use super::super::config::Config;
    use super::super::failure_policy::consecutive_failures;
    use super::*;

    #[test]
    fn call_with() {
        let circuit_breaker = new_circuit_breaker();
        // Inverted predicate: a `true` error is benign (recorded as a success); only a `false`
        // error counts as a failure.
        let is_err = |err: &bool| !(*err);

        for _ in 0..2 {
            match circuit_breaker.call_with(is_err, || Err::<(), _>(true)) {
                Err(Error::Inner(true)) => {}
                other => panic!("unexpected result: {:?}", other),
            }
            assert!(circuit_breaker.is_call_permitted());
        }

        match circuit_breaker.call_with(is_err, || Err::<(), _>(false)) {
            Err(Error::Inner(false)) => {}
            other => panic!("unexpected result: {:?}", other),
        }
        assert!(!circuit_breaker.is_call_permitted());
    }

    #[test]
    fn call_with_ok_does_not_consult_predicate() {
        let circuit_breaker = new_circuit_breaker();
        let must_not_run =
            |_: &bool| -> bool { panic!("predicate must not be consulted for Ok results") };

        match circuit_breaker.call_with(must_not_run, || Ok::<_, bool>(7)) {
            Ok(7) => {}
            other => panic!("unexpected result: {:?}", other),
        }
        assert!(circuit_breaker.is_call_permitted());
    }

    #[test]
    fn call_ok() {
        let circuit_breaker = new_circuit_breaker();

        circuit_breaker.call(|| Ok::<_, ()>(())).unwrap();
        assert!(circuit_breaker.is_call_permitted());
    }

    #[test]
    fn call_err() {
        let circuit_breaker = new_circuit_breaker();

        match circuit_breaker.call(|| Err::<(), _>(())) {
            Err(Error::Inner(())) => {}
            other => panic!("unexpected result: {:?}", other),
        }
        assert!(!circuit_breaker.is_call_permitted());

        match circuit_breaker.call(|| Err::<(), _>(())) {
            Err(Error::Rejected) => {}
            other => panic!("unexpected result: {:?}", other),
        }
        assert!(!circuit_breaker.is_call_permitted());
    }

    #[test]
    fn rejected_call_does_not_run_function() {
        let circuit_breaker = new_circuit_breaker();

        // Trip the breaker: the test policy opens after a single consecutive failure.
        match circuit_breaker.call(|| Err::<(), _>(())) {
            Err(Error::Inner(())) => {}
            other => panic!("unexpected result: {:?}", other),
        }
        assert!(!circuit_breaker.is_call_permitted());

        let mut invoked = false;
        match circuit_breaker.call(|| {
            invoked = true;
            Ok::<(), ()>(())
        }) {
            Err(Error::Rejected) => {}
            other => panic!("unexpected result: {:?}", other),
        }
        assert!(!invoked, "a rejected call must not execute the wrapped function");
    }

    /// Builds a breaker that opens after one consecutive failure and stays open for 5 seconds.
    fn new_circuit_breaker() -> impl CircuitBreaker {
        let backoff = backoff::constant(Duration::from_secs(5));
        let policy = consecutive_failures(1, backoff);
        Config::new().failure_policy(policy).build()
    }
}