Skip to main content

nobreak/
lib.rs

1use std::fmt;
2use std::sync::Mutex;
3use std::time::{Duration, Instant};
4
5pub fn breaker() -> CircuitBreaker {
6    builder().build()
7}
8
9pub async fn call_async<T, TFut, TOk, TError>(f: T) -> Result<TOk, CircuitError<TError>>
10where
11    T: FnOnce() -> TFut,
12    TFut: Future<Output = Result<TOk, TError>>,
13{
14    builder().build().call_async(f).await
15}
16
17pub fn call<T, TOk, TError>(f: T) -> Result<TOk, CircuitError<TError>>
18where
19    T: FnOnce() -> Result<TOk, TError>,
20{
21    builder().build().call(f)
22}
23
24pub fn builder() -> CircuitBreakerBuilder {
25    CircuitBreakerBuilder {
26        failure_threshold: 5,
27        success_threshold: 2,
28        open_duration: Duration::from_secs(30),
29    }
30}
31
/// Configuration for a [`CircuitBreaker`].
///
/// Obtain one from [`builder`], adjust it with the `with_*` methods, and finish
/// with `build`. The builder is `Debug` and `Clone` so a configuration can be
/// logged or reused to create several breakers.
#[must_use = "a builder does nothing until `build` is called"]
#[derive(Debug, Clone)]
pub struct CircuitBreakerBuilder {
    /// Number of recorded failures that trips the breaker open.
    failure_threshold: usize,
    /// Number of successful trial calls, while half-open, that closes the breaker.
    success_threshold: usize,
    /// How long the breaker stays open before letting a trial call through.
    open_duration: Duration,
}
37
38impl CircuitBreakerBuilder {
39    pub fn with_failure_threshold(mut self, threshold: usize) -> Self {
40        self.failure_threshold = threshold;
41        self
42    }
43
44    pub fn with_success_threshold(mut self, threshold: usize) -> Self {
45        self.success_threshold = threshold;
46        self
47    }
48
49    pub fn with_open_duration(mut self, duration: Duration) -> Self {
50        self.open_duration = duration;
51        self
52    }
53
54    pub fn build(self) -> CircuitBreaker {
55        CircuitBreaker {
56            failure_threshold: self.failure_threshold,
57            success_threshold: self.success_threshold,
58            open_duration: self.open_duration,
59            inner: Mutex::new(Inner {
60                state: State::Closed,
61                failure_count: 0,
62                success_count: 0,
63            }),
64        }
65    }
66}
67
68pub struct CircuitBreaker {
69    failure_threshold: usize,
70    success_threshold: usize,
71    open_duration: Duration,
72    inner: Mutex<Inner>,
73}
74
75struct Inner {
76    state: State,
77    failure_count: usize,
78    success_count: usize,
79}
80
/// Internal phase of the breaker. Unlike the public `CircuitState`, the open
/// phase remembers the instant at which the breaker tripped.
enum State {
    Closed,
    HalfOpen,
    Open { since: Instant },
}
86
87impl CircuitBreaker {
88    pub fn call<T, TOk, TError>(&self, f: T) -> Result<TOk, CircuitError<TError>>
89    where
90        T: FnOnce() -> Result<TOk, TError>,
91    {
92        {
93            let mut inner = self.inner.lock().unwrap();
94            match inner.state {
95                State::Open { since } => {
96                    if since.elapsed() >= self.open_duration {
97                        inner.state = State::HalfOpen;
98                        inner.success_count = 0;
99                    } else {
100                        return Err(CircuitError::Open);
101                    }
102                }
103                State::Closed | State::HalfOpen => {}
104            }
105        }
106
107        match f() {
108            Ok(val) => {
109                self.record_success();
110                Ok(val)
111            }
112            Err(err) => {
113                self.record_failure();
114                Err(CircuitError::Failed { error: err })
115            }
116        }
117    }
118
119    pub async fn call_async<T, TFut, TOk, TError>(&self, f: T) -> Result<TOk, CircuitError<TError>>
120    where
121        T: FnOnce() -> TFut,
122        TFut: Future<Output = Result<TOk, TError>>,
123    {
124        {
125            let mut inner = self.inner.lock().unwrap();
126            match inner.state {
127                State::Open { since } => {
128                    if since.elapsed() >= self.open_duration {
129                        inner.state = State::HalfOpen;
130                        inner.success_count = 0;
131                    } else {
132                        return Err(CircuitError::Open);
133                    }
134                }
135                State::Closed | State::HalfOpen => {}
136            }
137        }
138
139        match f().await {
140            Ok(val) => {
141                self.record_success();
142                Ok(val)
143            }
144            Err(err) => {
145                self.record_failure();
146                Err(CircuitError::Failed { error: err })
147            }
148        }
149    }
150
151    fn record_success(&self) {
152        let mut inner = self.inner.lock().unwrap();
153        match inner.state {
154            State::HalfOpen => {
155                inner.success_count += 1;
156                if inner.success_count >= self.success_threshold {
157                    inner.state = State::Closed;
158                    inner.failure_count = 0;
159                    inner.success_count = 0;
160                }
161            }
162            _ => {
163                inner.failure_count = 0;
164            }
165        }
166    }
167
168    fn record_failure(&self) {
169        let mut inner = self.inner.lock().unwrap();
170        match inner.state {
171            State::HalfOpen => {
172                inner.state = State::Open {
173                    since: Instant::now(),
174                };
175                inner.failure_count = 0;
176                inner.success_count = 0;
177            }
178            _ => {
179                inner.failure_count += 1;
180                if inner.failure_count >= self.failure_threshold {
181                    inner.state = State::Open {
182                        since: Instant::now(),
183                    };
184                }
185            }
186        }
187    }
188
189    /// Returns the current state of the circuit breaker.
190    pub fn state(&self) -> CircuitState {
191        let inner = self.inner.lock().unwrap();
192        match inner.state {
193            State::Closed => CircuitState::Closed,
194            State::Open { since } => {
195                if since.elapsed() >= self.open_duration {
196                    CircuitState::HalfOpen
197                } else {
198                    CircuitState::Open
199                }
200            }
201            State::HalfOpen => CircuitState::HalfOpen,
202        }
203    }
204}
205
/// Publicly visible state of a circuit breaker, as reported by its `state` method.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CircuitState {
    /// Calls are passed through to the operation.
    Closed,
    /// Calls are rejected with `CircuitError::Open` without running the operation.
    Open,
    /// Calls are let through as trials; enough successes close the breaker, and
    /// a failure reopens it.
    HalfOpen,
}
213
/// Error returned by calls made through a circuit breaker.
///
/// `Clone`, `PartialEq` and `Eq` are derived, so they are available whenever the
/// wrapped error type `E` supports them (e.g. for `assert_eq!` in tests).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CircuitError<E> {
    /// The operation was executed but returned an error.
    Failed { error: E },
    /// The circuit is open; the operation was not executed.
    Open,
}
221
222impl<E> CircuitError<E> {
223    /// Extract the underlying error, if one exists.
224    pub fn into_inner(self) -> Option<E> {
225        match self {
226            CircuitError::Failed { error } => Some(error),
227            CircuitError::Open => None,
228        }
229    }
230}
231
232impl<E: fmt::Display> fmt::Display for CircuitError<E> {
233    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
234        match self {
235            CircuitError::Failed { error } => write!(f, "{error}"),
236            CircuitError::Open => write!(f, "circuit breaker is open"),
237        }
238    }
239}
240
241impl<E: std::error::Error + 'static> std::error::Error for CircuitError<E> {
242    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
243        match self {
244            CircuitError::Failed { error } => Some(error),
245            CircuitError::Open => None,
246        }
247    }
248}