1use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
7use std::sync::Arc;
8use std::time::{Duration, Instant};
9
/// Operating mode of a [`CircuitBreaker`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CircuitState {
    /// Calls are allowed; recorded failures are counted toward tripping the breaker.
    Closed,
    /// Calls are rejected until the configured open timeout has elapsed.
    Open,
    /// Trial mode: calls are allowed, and their outcomes decide whether the breaker
    /// closes again or re-opens.
    HalfOpen,
}
20
/// Thresholds and timeouts that drive a [`CircuitBreaker`]'s state machine.
#[derive(Clone, Debug)]
pub struct CircuitConfig {
    /// Consecutive failures, recorded while closed, that trip the breaker open.
    pub failure_threshold: u32,
    /// Consecutive successes, recorded while half-open, that close the breaker again.
    pub success_threshold: u32,
    /// How long the breaker stays open before a trial call is let through.
    pub open_timeout: Duration,
    /// Time window for failure accounting.
    ///
    /// NOTE(review): confirm how the breaker implementation consumes this value before
    /// relying on it.
    pub failure_window: Duration,
    /// Minimum number of recorded calls (successes + failures) before failures may
    /// open the breaker.
    pub min_calls: u32,
}

impl Default for CircuitConfig {
    /// Opens after 5 consecutive failures once at least 10 calls have been recorded.
    /// Waits 30 s before probing, closes after 3 half-open successes, and uses a
    /// 60 s failure window.
    fn default() -> Self {
        let open_timeout = Duration::from_secs(30);
        let failure_window = Duration::from_secs(60);
        CircuitConfig {
            min_calls: 10,
            failure_threshold: 5,
            success_threshold: 3,
            open_timeout,
            failure_window,
        }
    }
}
47
48#[derive(Debug, Clone)]
50pub struct CircuitStats {
51 pub state: CircuitState,
52 pub consecutive_failures: u32,
53 pub consecutive_successes: u32,
54 pub total_calls: u64,
55 pub total_failures: u64,
56 pub total_successes: u64,
57 pub total_rejected: u64,
58 pub last_failure: Option<Instant>,
59 pub last_state_change: Instant,
60}
61
62#[derive(Debug)]
64pub struct CircuitBreaker {
65 module_name: String,
67 state: std::sync::Mutex<CircuitState>,
69 consecutive_failures: AtomicU32,
71 consecutive_successes: AtomicU32,
73 total_calls: AtomicU64,
75 total_failures: AtomicU64,
77 total_successes: AtomicU64,
79 total_rejected: AtomicU64,
81 last_failure: std::sync::Mutex<Option<Instant>>,
83 last_state_change: std::sync::Mutex<Instant>,
85 config: CircuitConfig,
87}
88
89impl CircuitBreaker {
90 pub fn new(module_name: String, config: CircuitConfig) -> Self {
92 Self {
93 module_name,
94 state: std::sync::Mutex::new(CircuitState::Closed),
95 consecutive_failures: AtomicU32::new(0),
96 consecutive_successes: AtomicU32::new(0),
97 total_calls: AtomicU64::new(0),
98 total_failures: AtomicU64::new(0),
99 total_successes: AtomicU64::new(0),
100 total_rejected: AtomicU64::new(0),
101 last_failure: std::sync::Mutex::new(None),
102 last_state_change: std::sync::Mutex::new(Instant::now()),
103 config,
104 }
105 }
106
107 pub fn with_defaults(module_name: String) -> Self {
109 Self::new(module_name, CircuitConfig::default())
110 }
111
112 pub fn can_execute(&self) -> bool {
114 let mut state = self.state.lock().unwrap();
115
116 match *state {
117 CircuitState::Closed => true,
118 CircuitState::Open => {
119 let elapsed = self.last_state_change.lock().unwrap().elapsed();
121 if elapsed >= self.config.open_timeout {
122 *state = CircuitState::HalfOpen;
124 *self.last_state_change.lock().unwrap() = Instant::now();
125 self.consecutive_successes.store(0, Ordering::Relaxed);
126 true
127 } else {
128 self.total_rejected.fetch_add(1, Ordering::Relaxed);
129 false
130 }
131 }
132 CircuitState::HalfOpen => true,
133 }
134 }
135
136 pub fn record_success(&self) {
138 self.total_calls.fetch_add(1, Ordering::Relaxed);
139 self.total_successes.fetch_add(1, Ordering::Relaxed);
140 self.consecutive_failures.store(0, Ordering::Relaxed);
141
142 let mut state = self.state.lock().unwrap();
143
144 if *state == CircuitState::HalfOpen {
145 let successes = self.consecutive_successes.fetch_add(1, Ordering::Relaxed) + 1;
146 if successes >= self.config.success_threshold {
147 *state = CircuitState::Closed;
149 *self.last_state_change.lock().unwrap() = Instant::now();
150 self.consecutive_failures.store(0, Ordering::Relaxed);
151 }
152 }
153 }
154
155 pub fn record_failure(&self) {
157 self.total_calls.fetch_add(1, Ordering::Relaxed);
158 self.total_failures.fetch_add(1, Ordering::Relaxed);
159
160 let failures = self.consecutive_failures.fetch_add(1, Ordering::Relaxed) + 1;
161 *self.last_failure.lock().unwrap() = Some(Instant::now());
162
163 let mut state = self.state.lock().unwrap();
164
165 if *state == CircuitState::HalfOpen {
166 *state = CircuitState::Open;
168 *self.last_state_change.lock().unwrap() = Instant::now();
169 } else if *state == CircuitState::Closed {
170 if failures >= self.config.failure_threshold {
172 let total = self.total_calls.load(Ordering::Relaxed);
173 if total >= self.config.min_calls as u64 {
174 *state = CircuitState::Open;
175 *self.last_state_change.lock().unwrap() = Instant::now();
176 }
177 }
178 }
179 }
180
181 pub fn state(&self) -> CircuitState {
183 *self.state.lock().unwrap()
184 }
185
186 pub fn stats(&self) -> CircuitStats {
188 CircuitStats {
189 state: self.state(),
190 consecutive_failures: self.consecutive_failures.load(Ordering::Relaxed),
191 consecutive_successes: self.consecutive_successes.load(Ordering::Relaxed),
192 total_calls: self.total_calls.load(Ordering::Relaxed),
193 total_failures: self.total_failures.load(Ordering::Relaxed),
194 total_successes: self.total_successes.load(Ordering::Relaxed),
195 total_rejected: self.total_rejected.load(Ordering::Relaxed),
196 last_failure: *self.last_failure.lock().unwrap(),
197 last_state_change: *self.last_state_change.lock().unwrap(),
198 }
199 }
200
201 pub fn module_name(&self) -> &str {
203 &self.module_name
204 }
205
206 pub fn force_open(&self) {
208 *self.state.lock().unwrap() = CircuitState::Open;
209 *self.last_state_change.lock().unwrap() = Instant::now();
210 }
211
212 pub fn force_close(&self) {
214 *self.state.lock().unwrap() = CircuitState::Closed;
215 *self.last_state_change.lock().unwrap() = Instant::now();
216 self.consecutive_failures.store(0, Ordering::Relaxed);
217 }
218
219 pub fn reset(&self) {
221 self.consecutive_failures.store(0, Ordering::Relaxed);
222 self.consecutive_successes.store(0, Ordering::Relaxed);
223 self.total_calls.store(0, Ordering::Relaxed);
224 self.total_failures.store(0, Ordering::Relaxed);
225 self.total_successes.store(0, Ordering::Relaxed);
226 self.total_rejected.store(0, Ordering::Relaxed);
227 *self.last_failure.lock().unwrap() = None;
228 self.force_close();
229 }
230}
231
232#[derive(Debug)]
234pub struct CircuitRegistry {
235 circuits: DashMap<String, Arc<CircuitBreaker>>,
237 default_config: CircuitConfig,
239}
240
241impl CircuitRegistry {
242 pub fn new() -> Self {
244 Self {
245 circuits: DashMap::new(),
246 default_config: CircuitConfig::default(),
247 }
248 }
249
250 pub fn with_config(config: CircuitConfig) -> Self {
252 Self {
253 circuits: DashMap::new(),
254 default_config: config,
255 }
256 }
257
258 pub fn get_or_create(&self, module_name: &str) -> Arc<CircuitBreaker> {
260 self.circuits
261 .entry(module_name.to_string())
262 .or_insert_with(|| {
263 Arc::new(CircuitBreaker::new(
264 module_name.to_string(),
265 self.default_config.clone(),
266 ))
267 })
268 .clone()
269 }
270
271 pub fn register(&self, module_name: &str, config: CircuitConfig) {
273 self.circuits.insert(
274 module_name.to_string(),
275 Arc::new(CircuitBreaker::new(module_name.to_string(), config)),
276 );
277 }
278
279 pub fn can_execute(&self, module_name: &str) -> bool {
281 self.get_or_create(module_name).can_execute()
282 }
283
284 pub fn record_success(&self, module_name: &str) {
286 self.get_or_create(module_name).record_success();
287 }
288
289 pub fn record_failure(&self, module_name: &str) {
291 self.get_or_create(module_name).record_failure();
292 }
293
294 pub fn stats(&self, module_name: &str) -> Option<CircuitStats> {
296 self.circuits.get(module_name).map(|c| c.stats())
297 }
298
299 pub fn all_stats(&self) -> Vec<(String, CircuitStats)> {
301 self.circuits
302 .iter()
303 .map(|e| (e.key().clone(), e.value().stats()))
304 .collect()
305 }
306
307 pub fn circuit_count(&self) -> usize {
309 self.circuits.len()
310 }
311
312 pub fn open_circuit_count(&self) -> usize {
314 self.circuits
315 .iter()
316 .filter(|e| *e.value().state.lock().unwrap() == CircuitState::Open)
317 .count()
318 }
319}
320
321impl Default for CircuitRegistry {
322 fn default() -> Self {
323 Self::new()
324 }
325}
326
327use dashmap::DashMap;
328
#[cfg(test)]
mod tests {
    use super::*;

    /// Config that trips on the first failure and allows a half-open probe immediately,
    /// so the state machine can be driven without sleeping.
    fn instant_probe_config(success_threshold: u32) -> CircuitConfig {
        CircuitConfig {
            failure_threshold: 1,
            success_threshold,
            min_calls: 0,
            open_timeout: Duration::ZERO,
            ..CircuitConfig::default()
        }
    }

    #[test]
    fn test_circuit_closed_initially() {
        let cb = CircuitBreaker::with_defaults("test".to_string());
        assert_eq!(cb.state(), CircuitState::Closed);
        assert!(cb.can_execute());
    }

    #[test]
    fn test_circuit_opens_after_failures() {
        let config = CircuitConfig {
            failure_threshold: 3,
            min_calls: 0,
            ..CircuitConfig::default()
        };
        let cb = CircuitBreaker::new("test".to_string(), config);

        for _ in 0..2 {
            cb.record_failure();
        }
        assert_eq!(cb.state(), CircuitState::Closed);
        cb.record_failure();

        assert_eq!(cb.state(), CircuitState::Open);
        assert!(!cb.can_execute());
        assert_eq!(cb.stats().total_rejected, 1);
    }

    #[test]
    fn test_circuit_respects_min_calls() {
        let config = CircuitConfig {
            failure_threshold: 2,
            min_calls: 4,
            ..CircuitConfig::default()
        };
        let cb = CircuitBreaker::new("test".to_string(), config);

        for _ in 0..3 {
            cb.record_failure();
        }
        // The streak already exceeds the threshold, but only 3 of the 4 required calls exist.
        assert_eq!(cb.state(), CircuitState::Closed);

        cb.record_failure();
        assert_eq!(cb.state(), CircuitState::Open);
    }

    #[test]
    fn test_circuit_half_open_after_timeout() {
        let config = CircuitConfig {
            failure_threshold: 1,
            min_calls: 0,
            open_timeout: Duration::from_millis(100),
            ..CircuitConfig::default()
        };
        let cb = CircuitBreaker::new("test".to_string(), config);

        cb.record_failure();
        assert_eq!(cb.state(), CircuitState::Open);
        assert!(!cb.can_execute(), "calls must be rejected before open_timeout elapses");

        std::thread::sleep(Duration::from_millis(150));

        assert!(cb.can_execute());
        assert_eq!(cb.state(), CircuitState::HalfOpen);
    }

    #[test]
    fn test_circuit_closes_after_successes() {
        let cb = CircuitBreaker::new("test".to_string(), instant_probe_config(2));

        cb.record_failure();
        assert!(cb.can_execute());
        assert_eq!(cb.state(), CircuitState::HalfOpen);

        cb.record_success();
        assert_eq!(cb.state(), CircuitState::HalfOpen);
        cb.record_success();

        assert_eq!(cb.state(), CircuitState::Closed);
    }

    #[test]
    fn test_circuit_reopens_on_failure_in_half_open() {
        let cb = CircuitBreaker::new("test".to_string(), instant_probe_config(2));

        cb.record_failure();
        assert!(cb.can_execute());
        cb.record_success();
        assert_eq!(cb.state(), CircuitState::HalfOpen);

        cb.record_failure();

        assert_eq!(cb.state(), CircuitState::Open);
    }

    #[test]
    fn test_force_and_reset() {
        let cb = CircuitBreaker::with_defaults("test".to_string());

        cb.record_failure();
        cb.force_open();
        assert_eq!(cb.state(), CircuitState::Open);
        assert!(!cb.can_execute());

        cb.reset();
        let stats = cb.stats();
        assert_eq!(stats.state, CircuitState::Closed);
        assert_eq!(stats.total_calls, 0);
        assert_eq!(stats.total_failures, 0);
        assert_eq!(stats.total_rejected, 0);
        assert!(stats.last_failure.is_none());
    }

    #[test]
    fn test_circuit_registry() {
        let registry = CircuitRegistry::new();

        let cb1 = registry.get_or_create("module1");
        let cb2 = registry.get_or_create("module2");

        assert_eq!(registry.circuit_count(), 2);
        assert!(Arc::ptr_eq(&cb1, &registry.get_or_create("module1")));
        assert_eq!(registry.circuit_count(), 2);

        cb1.record_failure();
        cb2.record_success();

        let stats1 = registry.stats("module1").unwrap();
        let stats2 = registry.stats("module2").unwrap();

        assert_eq!(stats1.total_failures, 1);
        assert_eq!(stats2.total_successes, 1);
        assert!(registry.stats("missing").is_none());
        assert_eq!(registry.open_circuit_count(), 0);
    }
}