// hyperi_rustlib/tiered_sink/circuit.rs
use std::sync::Arc;
use std::sync::atomic::{AtomicU8, AtomicU32, AtomicU64, Ordering};
use std::time::Duration;

use tokio::sync::RwLock;

/// State of a [`CircuitBreaker`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum CircuitState {
    /// Normal operation; consecutive failures are counted toward the threshold.
    Closed,
    /// Tripped after reaching the failure threshold; remains open until the
    /// reset timeout has elapsed since the most recent failure.
    Open,
    /// The reset timeout has elapsed while open; a recorded success closes
    /// the breaker again.
    HalfOpen,
}

27pub struct CircuitBreaker {
33 state: RwLock<CircuitState>,
34 consecutive_failures: AtomicU32,
35 failure_threshold: u32,
36 reset_timeout: Duration,
37 last_failure_time: AtomicU64, health_state: Arc<AtomicU8>,
41}
42
43impl CircuitBreaker {
44 #[must_use]
49 pub fn new(failure_threshold: u32, reset_timeout: Duration) -> Self {
50 let health_state = Arc::new(AtomicU8::new(0)); #[cfg(feature = "health")]
53 {
54 let hs = Arc::clone(&health_state);
55 crate::health::HealthRegistry::register("circuit_breaker", move || {
56 match hs.load(Ordering::Relaxed) {
57 0 => crate::health::HealthStatus::Healthy, 2 => crate::health::HealthStatus::Degraded, _ => crate::health::HealthStatus::Unhealthy, }
61 });
62 }
63
64 Self {
65 state: RwLock::new(CircuitState::Closed),
66 consecutive_failures: AtomicU32::new(0),
67 failure_threshold,
68 reset_timeout,
69 last_failure_time: AtomicU64::new(0),
70 health_state,
71 }
72 }
73
74 fn sync_health_state(&self, state: CircuitState) {
76 let val = match state {
77 CircuitState::Closed => 0,
78 CircuitState::Open => 1,
79 CircuitState::HalfOpen => 2,
80 };
81 self.health_state.store(val, Ordering::Relaxed);
82 }
83
84 pub async fn state(&self) -> CircuitState {
86 let mut state = self.state.write().await;
87
88 if *state == CircuitState::Open {
90 let last_failure = self.last_failure_time.load(Ordering::SeqCst);
91 let now = current_epoch_millis();
92 let elapsed = Duration::from_millis(now.saturating_sub(last_failure));
93
94 if elapsed >= self.reset_timeout {
95 *state = CircuitState::HalfOpen;
96 self.sync_health_state(*state);
97 }
98 }
99
100 *state
101 }
102
103 pub async fn is_closed(&self) -> bool {
105 self.state().await == CircuitState::Closed
106 }
107
108 pub async fn is_open(&self) -> bool {
110 self.state().await == CircuitState::Open
111 }
112
113 pub async fn record_success(&self) {
115 let mut state = self.state.write().await;
116 self.consecutive_failures.store(0, Ordering::SeqCst);
117 *state = CircuitState::Closed;
118 self.sync_health_state(*state);
119 }
120
121 pub async fn record_failure(&self) {
123 let failures = self.consecutive_failures.fetch_add(1, Ordering::SeqCst) + 1;
124 self.last_failure_time
125 .store(current_epoch_millis(), Ordering::SeqCst);
126
127 if failures >= self.failure_threshold {
128 let mut state = self.state.write().await;
129 *state = CircuitState::Open;
130 self.sync_health_state(*state);
131 }
132 }
133
134 #[must_use]
136 pub fn consecutive_failures(&self) -> u32 {
137 self.consecutive_failures.load(Ordering::SeqCst)
138 }
139
140 pub async fn reset(&self) {
142 self.consecutive_failures.store(0, Ordering::SeqCst);
143 let mut state = self.state.write().await;
144 *state = CircuitState::Closed;
145 self.sync_health_state(*state);
146 }
147}
148
149impl std::fmt::Debug for CircuitBreaker {
150 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
151 f.debug_struct("CircuitBreaker")
152 .field("failure_threshold", &self.failure_threshold)
153 .field("reset_timeout", &self.reset_timeout)
154 .field("consecutive_failures", &self.consecutive_failures())
155 .finish_non_exhaustive()
156 }
157}
158
/// Milliseconds elapsed since the Unix epoch, according to the system clock.
///
/// Returns 0 if the clock reads earlier than the epoch. Saturates at
/// `u64::MAX` if the value does not fit in a `u64`.
// NOTE(review): the system clock is not monotonic, so a clock adjustment can
// shorten or lengthen the breaker's effective reset timeout.
fn current_epoch_millis() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => u64::try_from(since_epoch.as_millis()).unwrap_or(u64::MAX),
        Err(_) => 0,
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_initial_state_is_closed() {
        let cb = CircuitBreaker::new(3, Duration::from_secs(30));
        assert_eq!(cb.state().await, CircuitState::Closed);
        assert!(cb.is_closed().await);
    }

    #[tokio::test]
    async fn test_opens_after_threshold() {
        let cb = CircuitBreaker::new(3, Duration::from_secs(30));

        cb.record_failure().await;
        assert!(cb.is_closed().await);

        cb.record_failure().await;
        assert!(cb.is_closed().await);

        cb.record_failure().await;
        assert!(cb.is_open().await);
        assert_eq!(cb.consecutive_failures(), 3);
    }

    #[tokio::test]
    async fn test_success_resets_failures() {
        let cb = CircuitBreaker::new(3, Duration::from_secs(30));

        cb.record_failure().await;
        cb.record_failure().await;
        assert_eq!(cb.consecutive_failures(), 2);

        cb.record_success().await;
        assert_eq!(cb.consecutive_failures(), 0);
        assert!(cb.is_closed().await);
    }

    #[tokio::test]
    async fn test_half_open_after_timeout() {
        let cb = CircuitBreaker::new(1, Duration::from_millis(50));

        cb.record_failure().await;
        assert!(cb.is_open().await);

        tokio::time::sleep(Duration::from_millis(100)).await;

        assert_eq!(cb.state().await, CircuitState::HalfOpen);
    }

    #[tokio::test]
    async fn test_half_open_success_closes() {
        let cb = CircuitBreaker::new(1, Duration::from_millis(10));

        cb.record_failure().await;
        tokio::time::sleep(Duration::from_millis(20)).await;

        assert_eq!(cb.state().await, CircuitState::HalfOpen);

        cb.record_success().await;
        assert!(cb.is_closed().await);
    }

    #[tokio::test]
    async fn test_half_open_failure_reopens() {
        let cb = CircuitBreaker::new(1, Duration::from_millis(10));

        cb.record_failure().await;
        tokio::time::sleep(Duration::from_millis(20)).await;

        assert_eq!(cb.state().await, CircuitState::HalfOpen);

        cb.record_failure().await;
        assert!(cb.is_open().await);
    }

    /// A single failed trial while half-open must reopen the breaker even when
    /// the threshold is larger than one.
    #[tokio::test]
    async fn test_half_open_failure_reopens_with_higher_threshold() {
        let cb = CircuitBreaker::new(3, Duration::from_millis(50));

        for _ in 0..3 {
            cb.record_failure().await;
        }
        tokio::time::sleep(Duration::from_millis(100)).await;
        assert_eq!(cb.state().await, CircuitState::HalfOpen);

        cb.record_failure().await;
        assert!(cb.is_open().await);
    }

    /// The health-check cell mirrors every state transition
    /// (0 = closed, 1 = open, 2 = half-open).
    #[tokio::test]
    async fn test_health_state_tracks_transitions() {
        let cb = CircuitBreaker::new(1, Duration::from_millis(10));
        assert_eq!(cb.health_state.load(Ordering::Relaxed), 0);

        cb.record_failure().await;
        assert_eq!(cb.health_state.load(Ordering::Relaxed), 1);

        tokio::time::sleep(Duration::from_millis(20)).await;
        assert_eq!(cb.state().await, CircuitState::HalfOpen);
        assert_eq!(cb.health_state.load(Ordering::Relaxed), 2);

        cb.record_success().await;
        assert_eq!(cb.health_state.load(Ordering::Relaxed), 0);
    }

    #[tokio::test]
    async fn test_reset() {
        let cb = CircuitBreaker::new(1, Duration::from_secs(30));

        cb.record_failure().await;
        assert!(cb.is_open().await);

        cb.reset().await;
        assert!(cb.is_closed().await);
        assert_eq!(cb.consecutive_failures(), 0);
    }
}