1use std::{
6 future::Future,
7 sync::{
8 atomic::{AtomicU32, AtomicU64, Ordering},
9 Arc,
10 },
11 time::{Duration, Instant},
12};
13
14use dashmap::DashMap;
15use parking_lot::RwLock;
16
/// The three positions of a circuit breaker's state machine.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CircuitState {
    /// Requests pass through; failures inside the failure window are counted.
    Closed,
    /// Requests are rejected until the open timeout elapses.
    Open,
    /// A limited number of probe requests are admitted; enough successes close the
    /// circuit again, a failure re-opens it.
    HalfOpen,
}

impl std::fmt::Display for CircuitState {
    /// Writes the lowercase, hyphenated name: `closed`, `open` or `half-open`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::Closed => "closed",
            Self::Open => "open",
            Self::HalfOpen => "half-open",
        };
        f.write_str(label)
    }
}
37
/// Error returned when a breaker refuses to let a request through.
#[derive(Debug, Clone)]
pub struct CircuitOpenError {
    /// Name of the breaker that rejected the request.
    pub circuit_name: String,
    /// Suggested delay before the caller tries again.
    pub retry_after: Duration,
}

impl std::fmt::Display for CircuitOpenError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            circuit_name,
            retry_after,
        } = self;
        write!(
            f,
            "circuit '{}' is open, retry after {:?}",
            circuit_name, retry_after
        )
    }
}

/// Leaf error: `source()` keeps the default `None`.
impl std::error::Error for CircuitOpenError {}
58
/// Thresholds and timings for a circuit breaker.
///
/// Start from [`CircuitBreakerConfig::new`] or `Default::default()` and adjust with the
/// `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct CircuitBreakerConfig {
    /// Failures inside `failure_window` that trip a closed breaker open.
    pub failure_threshold: u32,
    /// Successful half-open probes needed to close the breaker again.
    pub success_threshold: u32,
    /// How long the breaker stays open before it starts admitting probes.
    pub timeout: Duration,
    /// Sliding window over which failures are counted while closed.
    pub failure_window: Duration,
    /// Maximum number of probes admitted at the same time while half-open.
    pub half_open_requests: u32,
}

impl Default for CircuitBreakerConfig {
    /// 5 failures per 60 s window trip the breaker; it waits 30 s, then needs 3 successful
    /// probes (admitted one at a time) to close.
    fn default() -> Self {
        Self {
            half_open_requests: 1,
            failure_window: Duration::from_secs(60),
            timeout: Duration::from_secs(30),
            success_threshold: 3,
            failure_threshold: 5,
        }
    }
}

impl CircuitBreakerConfig {
    /// Default configuration with a custom `failure_threshold`.
    pub fn new(failure_threshold: u32) -> Self {
        let defaults = Self::default();
        Self {
            failure_threshold,
            ..defaults
        }
    }

    /// Sets how many half-open successes close the breaker.
    pub fn with_success_threshold(self, threshold: u32) -> Self {
        Self {
            success_threshold: threshold,
            ..self
        }
    }

    /// Sets how long the breaker stays open before probing.
    pub fn with_timeout(self, timeout: Duration) -> Self {
        Self { timeout, ..self }
    }

    /// Sets the sliding window used to count failures.
    pub fn with_failure_window(self, window: Duration) -> Self {
        Self {
            failure_window: window,
            ..self
        }
    }

    /// Sets the number of concurrent half-open probes; values below 1 are raised to 1,
    /// because a limit of zero would reject every probe.
    pub fn with_half_open_requests(self, requests: u32) -> Self {
        Self {
            half_open_requests: requests.max(1),
            ..self
        }
    }
}
119
120#[derive(Debug, Clone)]
122pub struct CircuitBreakerStats {
123 pub state: CircuitState,
125 pub success_count: u64,
127 pub failure_count: u64,
129 pub rejected_count: u64,
131 pub failures_in_window: u32,
133 pub half_open_successes: u32,
135 pub time_in_state: Duration,
137}
138
139pub struct CircuitBreaker {
141 name: String,
142 config: CircuitBreakerConfig,
143 state: RwLock<CircuitState>,
144 state_changed_at: RwLock<Instant>,
146 failures: RwLock<Vec<Instant>>,
148 half_open_successes: AtomicU32,
150 half_open_in_flight: AtomicU32,
152 success_count: AtomicU64,
154 failure_count: AtomicU64,
155 rejected_count: AtomicU64,
156}
157
158impl CircuitBreaker {
159 pub fn new(name: impl Into<String>, config: CircuitBreakerConfig) -> Self {
161 Self {
162 name: name.into(),
163 config,
164 state: RwLock::new(CircuitState::Closed),
165 state_changed_at: RwLock::new(Instant::now()),
166 failures: RwLock::new(Vec::new()),
167 half_open_successes: AtomicU32::new(0),
168 half_open_in_flight: AtomicU32::new(0),
169 success_count: AtomicU64::new(0),
170 failure_count: AtomicU64::new(0),
171 rejected_count: AtomicU64::new(0),
172 }
173 }
174
175 pub fn name(&self) -> &str {
177 &self.name
178 }
179
180 pub fn check(&self) -> Result<(), CircuitOpenError> {
185 self.maybe_transition_to_half_open();
186
187 let state = *self.state.read();
188
189 match state {
190 CircuitState::Closed => Ok(()),
191 CircuitState::Open => {
192 self.rejected_count.fetch_add(1, Ordering::Relaxed);
193 let elapsed = self.state_changed_at.read().elapsed();
194 let retry_after = self.config.timeout.saturating_sub(elapsed);
195 Err(CircuitOpenError {
196 circuit_name: self.name.clone(),
197 retry_after,
198 })
199 }
200 CircuitState::HalfOpen => {
201 let in_flight = self.half_open_in_flight.load(Ordering::Acquire);
203 if in_flight < self.config.half_open_requests {
204 self.half_open_in_flight.fetch_add(1, Ordering::AcqRel);
205 Ok(())
206 } else {
207 self.rejected_count.fetch_add(1, Ordering::Relaxed);
208 Err(CircuitOpenError {
209 circuit_name: self.name.clone(),
210 retry_after: Duration::from_millis(100),
211 })
212 }
213 }
214 }
215 }
216
217 pub fn record_success(&self) {
219 self.success_count.fetch_add(1, Ordering::Relaxed);
220
221 let state = *self.state.read();
222
223 if state == CircuitState::HalfOpen {
224 self.half_open_in_flight.fetch_sub(1, Ordering::AcqRel);
225 let successes = self.half_open_successes.fetch_add(1, Ordering::AcqRel) + 1;
226
227 if successes >= self.config.success_threshold {
228 self.transition_to(CircuitState::Closed);
229 }
230 }
231 }
232
233 pub fn record_failure(&self) {
235 self.failure_count.fetch_add(1, Ordering::Relaxed);
236
237 let state = *self.state.read();
238
239 match state {
240 CircuitState::Closed => {
241 let now = Instant::now();
242 let mut failures = self.failures.write();
243
244 failures.push(now);
246
247 let cutoff = now - self.config.failure_window;
249 failures.retain(|&t| t > cutoff);
250
251 if failures.len() as u32 >= self.config.failure_threshold {
253 drop(failures); self.transition_to(CircuitState::Open);
255 }
256 }
257 CircuitState::HalfOpen => {
258 self.half_open_in_flight.fetch_sub(1, Ordering::AcqRel);
259 self.transition_to(CircuitState::Open);
261 }
262 CircuitState::Open => {
263 }
265 }
266 }
267
268 pub async fn call<F, Fut, T, E>(&self, f: F) -> Result<T, CircuitBreakerError<E>>
270 where
271 F: FnOnce() -> Fut,
272 Fut: Future<Output = Result<T, E>>,
273 {
274 self.check().map_err(CircuitBreakerError::CircuitOpen)?;
275
276 match f().await {
277 Ok(result) => {
278 self.record_success();
279 Ok(result)
280 }
281 Err(e) => {
282 self.record_failure();
283 Err(CircuitBreakerError::Inner(e))
284 }
285 }
286 }
287
288 pub fn get_state(&self) -> CircuitState {
290 self.maybe_transition_to_half_open();
291 *self.state.read()
292 }
293
294 pub fn get_stats(&self) -> CircuitBreakerStats {
296 self.maybe_transition_to_half_open();
297
298 let state = *self.state.read();
299 let failures = self.failures.read();
300 let now = Instant::now();
301 let cutoff = now - self.config.failure_window;
302 let failures_in_window = failures.iter().filter(|&&t| t > cutoff).count() as u32;
303
304 CircuitBreakerStats {
305 state,
306 success_count: self.success_count.load(Ordering::Relaxed),
307 failure_count: self.failure_count.load(Ordering::Relaxed),
308 rejected_count: self.rejected_count.load(Ordering::Relaxed),
309 failures_in_window,
310 half_open_successes: self.half_open_successes.load(Ordering::Relaxed),
311 time_in_state: self.state_changed_at.read().elapsed(),
312 }
313 }
314
315 pub fn reset(&self) {
317 self.transition_to(CircuitState::Closed);
318 self.failures.write().clear();
319 }
320
321 fn transition_to(&self, new_state: CircuitState) {
322 let mut state = self.state.write();
323 let old_state = *state;
324
325 if old_state != new_state {
326 *state = new_state;
327 *self.state_changed_at.write() = Instant::now();
328
329 if new_state == CircuitState::HalfOpen || new_state == CircuitState::Closed {
331 self.half_open_successes.store(0, Ordering::Relaxed);
332 self.half_open_in_flight.store(0, Ordering::Relaxed);
333 }
334
335 if new_state == CircuitState::Closed {
337 self.failures.write().clear();
338 }
339
340 #[cfg(feature = "otel")]
341 tracing::info!(
342 circuit = %self.name,
343 old_state = %old_state,
344 new_state = %new_state,
345 "circuit breaker state changed"
346 );
347 }
348 }
349
350 fn maybe_transition_to_half_open(&self) {
351 let state = *self.state.read();
352 if state == CircuitState::Open {
353 let elapsed = self.state_changed_at.read().elapsed();
354 if elapsed >= self.config.timeout {
355 self.transition_to(CircuitState::HalfOpen);
356 }
357 }
358 }
359}
360
361#[derive(Debug)]
363pub enum CircuitBreakerError<E> {
364 CircuitOpen(CircuitOpenError),
366 Inner(E),
368}
369
370impl<E: std::fmt::Display> std::fmt::Display for CircuitBreakerError<E> {
371 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
372 match self {
373 CircuitBreakerError::CircuitOpen(e) => write!(f, "{}", e),
374 CircuitBreakerError::Inner(e) => write!(f, "{}", e),
375 }
376 }
377}
378
379impl<E: std::error::Error + 'static> std::error::Error for CircuitBreakerError<E> {
380 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
381 match self {
382 CircuitBreakerError::CircuitOpen(e) => Some(e),
383 CircuitBreakerError::Inner(e) => Some(e),
384 }
385 }
386}
387
388pub struct CircuitBreakerManager {
390 breakers: DashMap<String, Arc<CircuitBreaker>>,
391 default_config: CircuitBreakerConfig,
392}
393
394impl CircuitBreakerManager {
395 pub fn new(default_config: CircuitBreakerConfig) -> Self {
397 Self {
398 breakers: DashMap::new(),
399 default_config,
400 }
401 }
402
403 pub fn get_or_create(&self, name: &str) -> Arc<CircuitBreaker> {
405 self.breakers
406 .entry(name.to_string())
407 .or_insert_with(|| Arc::new(CircuitBreaker::new(name, self.default_config.clone())))
408 .clone()
409 }
410
411 pub fn get(&self, name: &str) -> Option<Arc<CircuitBreaker>> {
413 self.breakers.get(name).map(|r| r.clone())
414 }
415
416 pub fn create_with_config(
418 &self,
419 name: &str,
420 config: CircuitBreakerConfig,
421 ) -> Arc<CircuitBreaker> {
422 let breaker = Arc::new(CircuitBreaker::new(name, config));
423 self.breakers.insert(name.to_string(), breaker.clone());
424 breaker
425 }
426
427 pub fn get_all_stats(&self) -> Vec<(String, CircuitBreakerStats)> {
429 self.breakers
430 .iter()
431 .map(|entry| (entry.key().clone(), entry.value().get_stats()))
432 .collect()
433 }
434
435 pub fn reset_all(&self) {
437 for entry in self.breakers.iter() {
438 entry.value().reset();
439 }
440 }
441
442 pub fn remove(&self, name: &str) {
444 self.breakers.remove(name);
445 }
446
447 pub fn clear(&self) {
449 self.breakers.clear();
450 }
451
452 pub fn len(&self) -> usize {
454 self.breakers.len()
455 }
456
457 pub fn is_empty(&self) -> bool {
459 self.breakers.is_empty()
460 }
461}
462
463impl Default for CircuitBreakerManager {
464 fn default() -> Self {
465 Self::new(CircuitBreakerConfig::default())
466 }
467}
468
469pub struct KeyedCircuitBreaker<K: std::hash::Hash + Eq + Clone + Send + Sync + 'static> {
495 breakers: DashMap<K, Arc<CircuitBreaker>>,
496 config: CircuitBreakerConfig,
497 counter: std::sync::atomic::AtomicU64,
499}
500
501impl<K: std::hash::Hash + Eq + Clone + Send + Sync + 'static> KeyedCircuitBreaker<K> {
502 pub fn new(config: CircuitBreakerConfig) -> Self {
506 Self {
507 breakers: DashMap::new(),
508 config,
509 counter: std::sync::atomic::AtomicU64::new(0),
510 }
511 }
512
513 pub fn check(&self, key: &K) -> Result<(), CircuitOpenError> {
518 self.get_or_create(key).check()
519 }
520
521 pub fn record_success(&self, key: &K) {
523 self.get_or_create(key).record_success()
524 }
525
526 pub fn record_failure(&self, key: &K) {
528 self.get_or_create(key).record_failure()
529 }
530
531 pub async fn call<F, Fut, T, E>(&self, key: &K, f: F) -> Result<T, CircuitBreakerError<E>>
534 where
535 F: FnOnce() -> Fut,
536 Fut: Future<Output = Result<T, E>>,
537 {
538 self.get_or_create(key).call(f).await
539 }
540
541 pub fn get_state(&self, key: &K) -> Option<CircuitState> {
543 self.breakers.get(key).map(|cb| cb.get_state())
544 }
545
546 pub fn get_stats(&self, key: &K) -> Option<CircuitBreakerStats> {
548 self.breakers.get(key).map(|cb| cb.get_stats())
549 }
550
551 pub fn get_all_stats(&self) -> Vec<(K, CircuitBreakerStats)> {
553 self.breakers
554 .iter()
555 .map(|entry| (entry.key().clone(), entry.value().get_stats()))
556 .collect()
557 }
558
559 pub fn reset(&self, key: &K) {
561 if let Some(cb) = self.breakers.get(key) {
562 cb.reset();
563 }
564 }
565
566 pub fn reset_all(&self) {
568 for entry in self.breakers.iter() {
569 entry.value().reset();
570 }
571 }
572
573 pub fn remove(&self, key: &K) {
575 self.breakers.remove(key);
576 }
577
578 pub fn clear(&self) {
580 self.breakers.clear();
581 }
582
583 pub fn len(&self) -> usize {
585 self.breakers.len()
586 }
587
588 pub fn is_empty(&self) -> bool {
590 self.breakers.is_empty()
591 }
592
593 pub fn get(&self, key: &K) -> Option<Arc<CircuitBreaker>> {
595 self.breakers.get(key).map(|r| r.clone())
596 }
597
598 fn get_or_create(&self, key: &K) -> Arc<CircuitBreaker> {
599 self.breakers
600 .entry(key.clone())
601 .or_insert_with(|| {
602 let id = self.counter.fetch_add(1, Ordering::Relaxed);
603 Arc::new(CircuitBreaker::new(
604 format!("keyed-{}", id),
605 self.config.clone(),
606 ))
607 })
608 .clone()
609 }
610}
611
612impl<K: std::hash::Hash + Eq + Clone + Send + Sync + 'static> Default for KeyedCircuitBreaker<K> {
613 fn default() -> Self {
614 Self::new(CircuitBreakerConfig::default())
615 }
616}
617
// Unit tests for the single breaker state machine, the named manager and the keyed breaker.
#[cfg(test)]
mod tests {
    use super::*;

    // The default config carries the built-in thresholds and open timeout.
    #[test]
    fn test_circuit_breaker_config_default() {
        let config = CircuitBreakerConfig::default();
        assert_eq!(config.failure_threshold, 5);
        assert_eq!(config.success_threshold, 3);
        assert_eq!(config.timeout, Duration::from_secs(30));
    }

    // Each builder method overrides its own field.
    #[test]
    fn test_circuit_breaker_config_builder() {
        let config = CircuitBreakerConfig::new(10)
            .with_success_threshold(5)
            .with_timeout(Duration::from_secs(60))
            .with_half_open_requests(3);

        assert_eq!(config.failure_threshold, 10);
        assert_eq!(config.success_threshold, 5);
        assert_eq!(config.timeout, Duration::from_secs(60));
        assert_eq!(config.half_open_requests, 3);
    }

    // A fresh breaker starts closed.
    #[test]
    fn test_circuit_breaker_initial_state() {
        let cb = CircuitBreaker::new("test", CircuitBreakerConfig::default());
        assert_eq!(cb.get_state(), CircuitState::Closed);
    }

    // The breaker opens exactly when the failure count reaches the threshold (3 here).
    #[test]
    fn test_circuit_breaker_opens_on_failures() {
        let config = CircuitBreakerConfig::new(3);
        let cb = CircuitBreaker::new("test", config);

        cb.record_failure();
        assert_eq!(cb.get_state(), CircuitState::Closed);
        cb.record_failure();
        assert_eq!(cb.get_state(), CircuitState::Closed);
        cb.record_failure();
        assert_eq!(cb.get_state(), CircuitState::Open);
    }

    // An open breaker rejects `check` with an error naming the circuit.
    #[test]
    fn test_circuit_breaker_check_when_open() {
        let config = CircuitBreakerConfig::new(1);
        let cb = CircuitBreaker::new("test", config);

        cb.record_failure();
        assert_eq!(cb.get_state(), CircuitState::Open);

        let result = cb.check();
        assert!(result.is_err());

        let err = result.unwrap_err();
        assert_eq!(err.circuit_name, "test");
    }

    // After sleeping past the 10 ms open timeout, the breaker reports half-open.
    #[test]
    fn test_circuit_breaker_transitions_to_half_open() {
        let config = CircuitBreakerConfig::new(1).with_timeout(Duration::from_millis(10));
        let cb = CircuitBreaker::new("test", config);

        cb.record_failure();
        assert_eq!(cb.get_state(), CircuitState::Open);

        std::thread::sleep(Duration::from_millis(20));

        assert_eq!(cb.get_state(), CircuitState::HalfOpen);
    }

    // With success_threshold = 2, one probe success keeps it half-open; the second closes it.
    #[test]
    fn test_circuit_breaker_closes_on_success() {
        let config = CircuitBreakerConfig::new(1)
            .with_timeout(Duration::from_millis(10))
            .with_success_threshold(2);
        let cb = CircuitBreaker::new("test", config);

        cb.record_failure();
        std::thread::sleep(Duration::from_millis(20));

        assert_eq!(cb.get_state(), CircuitState::HalfOpen);

        // Each probe is admitted via `check` before its outcome is recorded.
        cb.check().unwrap();
        cb.record_success();
        assert_eq!(cb.get_state(), CircuitState::HalfOpen);

        cb.check().unwrap();
        cb.record_success();
        assert_eq!(cb.get_state(), CircuitState::Closed);
    }

    // A failed half-open probe re-opens the breaker.
    #[test]
    fn test_circuit_breaker_reopens_on_half_open_failure() {
        let config = CircuitBreakerConfig::new(1).with_timeout(Duration::from_millis(10));
        let cb = CircuitBreaker::new("test", config);

        cb.record_failure();
        std::thread::sleep(Duration::from_millis(20));

        assert_eq!(cb.get_state(), CircuitState::HalfOpen);

        cb.check().unwrap();
        cb.record_failure();
        assert_eq!(cb.get_state(), CircuitState::Open);
    }

    // `reset` forces an open breaker back to closed.
    #[test]
    fn test_circuit_breaker_reset() {
        let config = CircuitBreakerConfig::new(1);
        let cb = CircuitBreaker::new("test", config);

        cb.record_failure();
        assert_eq!(cb.get_state(), CircuitState::Open);

        cb.reset();
        assert_eq!(cb.get_state(), CircuitState::Closed);
    }

    // Stats reflect recorded outcomes; one failure below the threshold keeps it closed.
    #[test]
    fn test_circuit_breaker_stats() {
        let config = CircuitBreakerConfig::new(5);
        let cb = CircuitBreaker::new("test", config);

        cb.record_success();
        cb.record_success();
        cb.record_failure();

        let stats = cb.get_stats();
        assert_eq!(stats.state, CircuitState::Closed);
        assert_eq!(stats.success_count, 2);
        assert_eq!(stats.failure_count, 1);
        assert_eq!(stats.failures_in_window, 1);
    }

    // `call` passes an Ok value through and records a success.
    #[tokio::test]
    async fn test_circuit_breaker_call_success() {
        let cb = CircuitBreaker::new("test", CircuitBreakerConfig::default());

        let result = cb
            .call(|| async { Ok::<_, std::io::Error>("success") })
            .await;

        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "success");
        assert_eq!(cb.get_stats().success_count, 1);
    }

    // `call` wraps the operation's error and records a failure.
    #[tokio::test]
    async fn test_circuit_breaker_call_failure() {
        let cb = CircuitBreaker::new("test", CircuitBreakerConfig::default());

        let result: Result<(), CircuitBreakerError<std::io::Error>> = cb
            .call(|| async { Err(std::io::Error::new(std::io::ErrorKind::Other, "failed")) })
            .await;

        assert!(result.is_err());
        assert_eq!(cb.get_stats().failure_count, 1);
    }

    // The manager creates breakers by name and returns the same Arc for a repeated name.
    #[test]
    fn test_circuit_breaker_manager() {
        let manager = CircuitBreakerManager::default();

        let cb1 = manager.get_or_create("service1");
        let cb2 = manager.get_or_create("service2");
        let cb1_again = manager.get_or_create("service1");

        assert_eq!(cb1.name(), "service1");
        assert_eq!(cb2.name(), "service2");
        assert!(Arc::ptr_eq(&cb1, &cb1_again));
    }

    // `create_with_config` registers a named breaker with a custom config.
    #[test]
    fn test_circuit_breaker_manager_custom_config() {
        let manager = CircuitBreakerManager::default();

        let config = CircuitBreakerConfig::new(10);
        let cb = manager.create_with_config("custom", config);

        assert_eq!(cb.name(), "custom");
        assert_eq!(cb.get_state(), CircuitState::Closed);
    }

    // `get_all_stats` returns one entry per registered breaker.
    #[test]
    fn test_circuit_breaker_manager_get_all_stats() {
        let manager = CircuitBreakerManager::default();

        manager.get_or_create("a").record_success();
        manager.get_or_create("b").record_failure();

        let stats = manager.get_all_stats();
        assert_eq!(stats.len(), 2);
    }

    // `reset_all` closes breakers that were open.
    #[test]
    fn test_circuit_breaker_manager_reset_all() {
        let manager = CircuitBreakerManager::new(CircuitBreakerConfig::new(1));

        let cb = manager.get_or_create("test");
        cb.record_failure();
        assert_eq!(cb.get_state(), CircuitState::Open);

        manager.reset_all();
        assert_eq!(cb.get_state(), CircuitState::Closed);
    }

    // Display strings of each state.
    #[test]
    fn test_circuit_state_display() {
        assert_eq!(format!("{}", CircuitState::Closed), "closed");
        assert_eq!(format!("{}", CircuitState::Open), "open");
        assert_eq!(format!("{}", CircuitState::HalfOpen), "half-open");
    }

    // Checking two distinct keys creates two breakers.
    #[test]
    fn test_keyed_circuit_breaker_basic() {
        let cb = KeyedCircuitBreaker::<String>::new(CircuitBreakerConfig::default());

        assert!(cb.check(&"key1".to_string()).is_ok());
        assert!(cb.check(&"key2".to_string()).is_ok());

        assert_eq!(cb.len(), 2);
    }

    // Opening one key's breaker does not affect another key.
    #[test]
    fn test_keyed_circuit_breaker_isolation() {
        let config = CircuitBreakerConfig::new(1);
        let cb = KeyedCircuitBreaker::<String>::new(config);

        cb.record_failure(&"key1".to_string());

        assert!(cb.check(&"key1".to_string()).is_err());
        assert!(cb.check(&"key2".to_string()).is_ok());
    }

    // Per-key stats are tracked separately; recording on a new key creates its breaker.
    #[test]
    fn test_keyed_circuit_breaker_stats() {
        let cb = KeyedCircuitBreaker::<String>::new(CircuitBreakerConfig::default());

        cb.record_success(&"a".to_string());
        cb.record_success(&"a".to_string());
        cb.record_failure(&"b".to_string());

        let stats_a = cb.get_stats(&"a".to_string()).unwrap();
        assert_eq!(stats_a.success_count, 2);

        let stats_b = cb.get_stats(&"b".to_string()).unwrap();
        assert_eq!(stats_b.failure_count, 1);

        let all_stats = cb.get_all_stats();
        assert_eq!(all_stats.len(), 2);
    }

    // Resetting a key re-admits requests for that key.
    #[test]
    fn test_keyed_circuit_breaker_reset() {
        let config = CircuitBreakerConfig::new(1);
        let cb = KeyedCircuitBreaker::<String>::new(config);

        cb.record_failure(&"key".to_string());
        assert!(cb.check(&"key".to_string()).is_err());

        cb.reset(&"key".to_string());
        assert!(cb.check(&"key".to_string()).is_ok());
    }

    // `reset_all` re-admits requests for every key.
    #[test]
    fn test_keyed_circuit_breaker_reset_all() {
        let config = CircuitBreakerConfig::new(1);
        let cb = KeyedCircuitBreaker::<String>::new(config);

        cb.record_failure(&"a".to_string());
        cb.record_failure(&"b".to_string());

        assert!(cb.check(&"a".to_string()).is_err());
        assert!(cb.check(&"b".to_string()).is_err());

        cb.reset_all();

        assert!(cb.check(&"a".to_string()).is_ok());
        assert!(cb.check(&"b".to_string()).is_ok());
    }

    // Keyed `call` forwards to the per-key breaker and returns the Ok value.
    #[tokio::test]
    async fn test_keyed_circuit_breaker_call() {
        let cb = KeyedCircuitBreaker::<String>::new(CircuitBreakerConfig::default());

        let result = cb
            .call(&"test".to_string(), || async {
                Ok::<_, std::io::Error>("success")
            })
            .await;

        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "success");
    }

    // `remove` drops the breaker for a key.
    #[test]
    fn test_keyed_circuit_breaker_remove() {
        let cb = KeyedCircuitBreaker::<String>::new(CircuitBreakerConfig::default());

        cb.check(&"key".to_string()).ok();
        assert_eq!(cb.len(), 1);

        cb.remove(&"key".to_string());
        assert_eq!(cb.len(), 0);
    }

    // `clear` drops every breaker.
    #[test]
    fn test_keyed_circuit_breaker_clear() {
        let cb = KeyedCircuitBreaker::<String>::new(CircuitBreakerConfig::default());

        cb.check(&"a".to_string()).ok();
        cb.check(&"b".to_string()).ok();
        cb.check(&"c".to_string()).ok();

        assert_eq!(cb.len(), 3);

        cb.clear();
        assert!(cb.is_empty());
    }

    // `Default` works for non-String keys too.
    #[test]
    fn test_keyed_circuit_breaker_default() {
        let cb = KeyedCircuitBreaker::<u64>::default();
        assert!(cb.check(&1).is_ok());
        assert!(cb.check(&2).is_ok());
    }
}