1use std::fmt;
2use std::sync::Mutex;
3use std::time::{Duration, Instant};
4
5pub fn breaker() -> CircuitBreaker {
6 builder().build()
7}
8
9pub async fn call_async<T, TFut, TOk, TError>(f: T) -> Result<TOk, CircuitError<TError>>
10where
11 T: FnOnce() -> TFut,
12 TFut: Future<Output = Result<TOk, TError>>,
13{
14 builder().build().call_async(f).await
15}
16
17pub fn call<T, TOk, TError>(f: T) -> Result<TOk, CircuitError<TError>>
18where
19 T: FnOnce() -> Result<TOk, TError>,
20{
21 builder().build().call(f)
22}
23
24pub fn builder() -> CircuitBreakerBuilder {
25 CircuitBreakerBuilder {
26 failure_threshold: 5,
27 success_threshold: 2,
28 open_duration: Duration::from_secs(30),
29 }
30}
31
/// Configuration for a [`CircuitBreaker`]; obtain one from [`builder`] and finish with `build`.
#[derive(Debug, Clone)]
#[must_use = "a builder does nothing until `build` is called"]
pub struct CircuitBreakerBuilder {
    /// Consecutive failures (while closed) that open the breaker.
    failure_threshold: usize,
    /// Successful trial calls (while half-open) that close the breaker again.
    success_threshold: usize,
    /// How long the breaker stays open before admitting a trial call.
    open_duration: Duration,
}
37
38impl CircuitBreakerBuilder {
39 pub fn with_failure_threshold(mut self, threshold: usize) -> Self {
40 self.failure_threshold = threshold;
41 self
42 }
43
44 pub fn with_success_threshold(mut self, threshold: usize) -> Self {
45 self.success_threshold = threshold;
46 self
47 }
48
49 pub fn with_open_duration(mut self, duration: Duration) -> Self {
50 self.open_duration = duration;
51 self
52 }
53
54 pub fn build(self) -> CircuitBreaker {
55 CircuitBreaker {
56 failure_threshold: self.failure_threshold,
57 success_threshold: self.success_threshold,
58 open_duration: self.open_duration,
59 inner: Mutex::new(Inner {
60 state: State::Closed,
61 failure_count: 0,
62 success_count: 0,
63 }),
64 }
65 }
66}
67
/// A circuit breaker guarding calls to a fallible operation.
///
/// While *closed*, calls pass through and consecutive failures are counted; reaching
/// `failure_threshold` *opens* the breaker, which then rejects calls until `open_duration` has
/// elapsed. The next call after that becomes a trial in the *half-open* state:
/// `success_threshold` successes close the breaker again, and any failure reopens it.
/// Build one with [`builder`] or [`breaker`].
#[derive(Debug)]
pub struct CircuitBreaker {
    /// Consecutive failures (while closed) needed to open the breaker.
    failure_threshold: usize,
    /// Successful trial calls (while half-open) needed to close the breaker.
    success_threshold: usize,
    /// How long the breaker stays open before admitting a trial call.
    open_duration: Duration,
    /// Mutable state; the lock is taken only for short bookkeeping, never while the wrapped
    /// operation runs.
    inner: Mutex<Inner>,
}

/// Mutable state of a [`CircuitBreaker`], kept behind its mutex.
#[derive(Debug)]
struct Inner {
    state: State,
    /// Consecutive failures observed while closed.
    failure_count: usize,
    /// Successful trial calls observed while half-open.
    success_count: usize,
}

/// Internal state machine. `Open` records when the breaker tripped so the cool-down can be
/// evaluated lazily when the next call (or state query) arrives.
#[derive(Debug)]
enum State {
    Closed,
    Open { since: Instant },
    HalfOpen,
}
86
87impl CircuitBreaker {
88 pub fn call<T, TOk, TError>(&self, f: T) -> Result<TOk, CircuitError<TError>>
89 where
90 T: FnOnce() -> Result<TOk, TError>,
91 {
92 {
93 let mut inner = self.inner.lock().unwrap();
94 match inner.state {
95 State::Open { since } => {
96 if since.elapsed() >= self.open_duration {
97 inner.state = State::HalfOpen;
98 inner.success_count = 0;
99 } else {
100 return Err(CircuitError::Open);
101 }
102 }
103 State::Closed | State::HalfOpen => {}
104 }
105 }
106
107 match f() {
108 Ok(val) => {
109 self.record_success();
110 Ok(val)
111 }
112 Err(err) => {
113 self.record_failure();
114 Err(CircuitError::Failed { error: err })
115 }
116 }
117 }
118
119 pub async fn call_async<T, TFut, TOk, TError>(&self, f: T) -> Result<TOk, CircuitError<TError>>
120 where
121 T: FnOnce() -> TFut,
122 TFut: Future<Output = Result<TOk, TError>>,
123 {
124 {
125 let mut inner = self.inner.lock().unwrap();
126 match inner.state {
127 State::Open { since } => {
128 if since.elapsed() >= self.open_duration {
129 inner.state = State::HalfOpen;
130 inner.success_count = 0;
131 } else {
132 return Err(CircuitError::Open);
133 }
134 }
135 State::Closed | State::HalfOpen => {}
136 }
137 }
138
139 match f().await {
140 Ok(val) => {
141 self.record_success();
142 Ok(val)
143 }
144 Err(err) => {
145 self.record_failure();
146 Err(CircuitError::Failed { error: err })
147 }
148 }
149 }
150
151 fn record_success(&self) {
152 let mut inner = self.inner.lock().unwrap();
153 match inner.state {
154 State::HalfOpen => {
155 inner.success_count += 1;
156 if inner.success_count >= self.success_threshold {
157 inner.state = State::Closed;
158 inner.failure_count = 0;
159 inner.success_count = 0;
160 }
161 }
162 _ => {
163 inner.failure_count = 0;
164 }
165 }
166 }
167
168 fn record_failure(&self) {
169 let mut inner = self.inner.lock().unwrap();
170 match inner.state {
171 State::HalfOpen => {
172 inner.state = State::Open {
173 since: Instant::now(),
174 };
175 inner.failure_count = 0;
176 inner.success_count = 0;
177 }
178 _ => {
179 inner.failure_count += 1;
180 if inner.failure_count >= self.failure_threshold {
181 inner.state = State::Open {
182 since: Instant::now(),
183 };
184 }
185 }
186 }
187 }
188
189 pub fn state(&self) -> CircuitState {
191 let inner = self.inner.lock().unwrap();
192 match inner.state {
193 State::Closed => CircuitState::Closed,
194 State::Open { since } => {
195 if since.elapsed() >= self.open_duration {
196 CircuitState::HalfOpen
197 } else {
198 CircuitState::Open
199 }
200 }
201 State::HalfOpen => CircuitState::HalfOpen,
202 }
203 }
204}
205
/// Externally visible state of a [`CircuitBreaker`], as returned by `CircuitBreaker::state`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CircuitState {
    /// Calls pass through and failures are counted.
    Closed,
    /// Calls are rejected with `CircuitError::Open` until the cool-down elapses.
    Open,
    /// Trial calls are admitted; enough successes close the breaker, a failure reopens it.
    HalfOpen,
}
213
/// Error returned by calls made through a [`CircuitBreaker`].
#[derive(Debug)]
pub enum CircuitError<E> {
    /// The call ran and the wrapped operation returned `error`.
    Failed {
        /// The error produced by the wrapped operation.
        error: E,
    },
    /// The breaker was open, so the wrapped operation was not run.
    Open,
}
221
222impl<E> CircuitError<E> {
223 pub fn into_inner(self) -> Option<E> {
225 match self {
226 CircuitError::Failed { error } => Some(error),
227 CircuitError::Open => None,
228 }
229 }
230}
231
232impl<E: fmt::Display> fmt::Display for CircuitError<E> {
233 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
234 match self {
235 CircuitError::Failed { error } => write!(f, "{error}"),
236 CircuitError::Open => write!(f, "circuit breaker is open"),
237 }
238 }
239}
240
241impl<E: std::error::Error + 'static> std::error::Error for CircuitError<E> {
242 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
243 match self {
244 CircuitError::Failed { error } => Some(error),
245 CircuitError::Open => None,
246 }
247 }
248}