stygian_proxy/
circuit_breaker.rs1use std::sync::atomic::{AtomicU8, AtomicU32, AtomicU64, Ordering};
17use std::time::{SystemTime, UNIX_EPOCH};
18
/// Normal operation: requests are admitted and failures are counted.
pub const STATE_CLOSED: u8 = 0;
/// Tripped: requests are rejected until the cool-down after the last failure elapses.
pub const STATE_OPEN: u8 = 1;
/// Probing: requests are admitted; a recorded success closes the circuit and a
/// recorded failure re-opens it.
pub const STATE_HALF_OPEN: u8 = 2;
22
/// Circuit breaker whose mutable state lives entirely in atomics, so a single
/// instance can be shared across threads (e.g. behind an `Arc`) without a lock.
pub struct CircuitBreaker {
    // --- configuration, fixed at construction ---
    /// Failure count at which a closed breaker trips open.
    threshold: u32,
    /// Milliseconds after the last recorded failure before an open breaker
    /// may move to half-open.
    half_open_after_ms: u64,

    // --- runtime state, updated concurrently ---
    /// Current state: one of `STATE_CLOSED`, `STATE_OPEN`, `STATE_HALF_OPEN`.
    state: AtomicU8,
    /// Failures recorded since the counter was last reset.
    failure_count: AtomicU32,
    /// Wall-clock time of the most recent recorded failure, in milliseconds
    /// since the Unix epoch (`0` until the first failure).
    last_failure: AtomicU64,
}
35
36impl CircuitBreaker {
37 pub const fn new(threshold: u32, half_open_after_ms: u64) -> Self {
39 Self {
40 state: AtomicU8::new(STATE_CLOSED),
41 failure_count: AtomicU32::new(0),
42 last_failure: AtomicU64::new(0),
43 threshold,
44 half_open_after_ms,
45 }
46 }
47
48 #[inline]
50 pub fn state(&self) -> u8 {
51 self.state.load(Ordering::Acquire)
52 }
53
54 pub fn is_available(&self) -> bool {
59 match self.state.load(Ordering::Acquire) {
60 STATE_CLOSED | STATE_HALF_OPEN => true,
61 STATE_OPEN => {
62 let elapsed_ms = now_ms().saturating_sub(self.last_failure.load(Ordering::Acquire));
63 if elapsed_ms >= self.half_open_after_ms {
64 let _ = self.state.compare_exchange(
68 STATE_OPEN,
69 STATE_HALF_OPEN,
70 Ordering::AcqRel,
71 Ordering::Acquire,
72 );
73 true
74 } else {
75 false
76 }
77 }
78 _ => false,
79 }
80 }
81
82 pub fn record_success(&self) {
87 if self
88 .state
89 .compare_exchange(
90 STATE_HALF_OPEN,
91 STATE_CLOSED,
92 Ordering::AcqRel,
93 Ordering::Acquire,
94 )
95 .is_ok()
96 {
97 self.failure_count.store(0, Ordering::Release);
98 }
99 }
100
101 pub fn record_failure(&self) {
107 let count = self.failure_count.fetch_add(1, Ordering::AcqRel) + 1;
108 self.last_failure.store(now_ms(), Ordering::Release);
109
110 let current_state = self.state.load(Ordering::Acquire);
111 if current_state == STATE_CLOSED && count >= self.threshold {
112 let _ = self.state.compare_exchange(
114 STATE_CLOSED,
115 STATE_OPEN,
116 Ordering::AcqRel,
117 Ordering::Acquire,
118 );
119 } else if current_state == STATE_HALF_OPEN {
120 let _ = self.state.compare_exchange(
122 STATE_HALF_OPEN,
123 STATE_OPEN,
124 Ordering::AcqRel,
125 Ordering::Acquire,
126 );
127 }
128 }
129}
130
/// Milliseconds since the Unix epoch according to the system wall clock.
///
/// Returns `0` when the clock reads earlier than the epoch, and `u64::MAX`
/// when the millisecond count does not fit in 64 bits. `SystemTime` is not
/// monotonic, so successive calls can go backwards if the clock is adjusted.
#[inline]
fn now_ms() -> u64 {
    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map_or(0, |since_epoch| {
            u64::try_from(since_epoch.as_millis()).unwrap_or(u64::MAX)
        })
}
140
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::thread;

    use super::*;

    /// Builds a breaker from `(threshold, cool-down in ms)`.
    fn make_breaker(threshold: u32, half_open_after_ms: u64) -> CircuitBreaker {
        CircuitBreaker::new(threshold, half_open_after_ms)
    }

    #[test]
    fn failures_open_circuit() {
        let cb = make_breaker(3, 30_000);
        assert_eq!(cb.state(), STATE_CLOSED);

        // Two failures stay below the threshold of three.
        (0..2).for_each(|_| cb.record_failure());
        assert_eq!(cb.state(), STATE_CLOSED, "not tripped yet");

        cb.record_failure();
        assert_eq!(cb.state(), STATE_OPEN, "should be open after threshold");
        assert!(!cb.is_available());
    }

    #[test]
    fn half_open_after_elapsed() {
        // A zero cool-down lets the very next availability check move to half-open.
        let cb = make_breaker(1, 0);
        cb.record_failure();
        assert_eq!(cb.state(), STATE_OPEN);

        assert!(cb.is_available(), "should transition to half-open");
        assert_eq!(cb.state(), STATE_HALF_OPEN);
    }

    #[test]
    fn success_in_half_open_closes_circuit() {
        let cb = make_breaker(1, 0);
        cb.record_failure();
        assert!(cb.is_available());

        cb.record_success();
        assert_eq!(cb.state(), STATE_CLOSED);
        assert!(cb.is_available());
    }

    #[test]
    fn failure_in_half_open_reopens() {
        let cb = make_breaker(1, 0);
        cb.record_failure();
        assert!(cb.is_available());

        cb.record_failure();
        assert_eq!(cb.state(), STATE_OPEN);
    }

    #[test]
    fn concurrent_failures_open_circuit() {
        const WORKERS: usize = 100;
        let shared = Arc::new(make_breaker(5, 30_000));

        let mut workers = Vec::with_capacity(WORKERS);
        for _ in 0..WORKERS {
            let cb = Arc::clone(&shared);
            workers.push(thread::spawn(move || cb.record_failure()));
        }
        for worker in workers {
            assert!(worker.join().is_ok(), "worker thread should not panic");
        }

        assert_eq!(shared.state(), STATE_OPEN);
        assert!(shared.failure_count.load(Ordering::Relaxed) >= 5);
    }
}