1use std::{
6 future::Future,
7 sync::{
8 atomic::{AtomicU32, AtomicU64, Ordering},
9 Arc,
10 },
11 time::{Duration, Instant},
12};
13
14use dashmap::DashMap;
15use parking_lot::RwLock;
16
/// The state a circuit breaker is currently in.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CircuitState {
    /// Requests are admitted.
    Closed,
    /// Requests are rejected until the configured timeout has elapsed.
    Open,
    /// A limited number of trial requests are admitted to probe recovery.
    HalfOpen,
}

/// Renders the state as a lowercase label: `closed`, `open` or `half-open`.
impl std::fmt::Display for CircuitState {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::Closed => "closed",
            Self::Open => "open",
            Self::HalfOpen => "half-open",
        };
        f.write_str(label)
    }
}
37
/// Error produced when a circuit breaker refuses to admit a request.
#[derive(Debug, Clone)]
pub struct CircuitOpenError {
    /// Name of the breaker that rejected the request.
    pub circuit_name: String,
    /// Suggested delay before the caller tries again.
    pub retry_after: Duration,
}

/// Formats as `circuit '<name>' is open, retry after <duration>`, with the
/// duration rendered through its `Debug` implementation.
impl std::fmt::Display for CircuitOpenError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            circuit_name,
            retry_after,
        } = self;
        f.write_fmt(format_args!(
            "circuit '{}' is open, retry after {:?}",
            circuit_name, retry_after
        ))
    }
}

impl std::error::Error for CircuitOpenError {}
58
/// Tunable thresholds and timings for a circuit breaker.
#[derive(Debug, Clone)]
pub struct CircuitBreakerConfig {
    /// Number of failures inside `failure_window` that opens the circuit.
    pub failure_threshold: u32,
    /// Number of successes while half-open that closes the circuit again.
    pub success_threshold: u32,
    /// How long the circuit stays open before it becomes half-open.
    pub timeout: Duration,
    /// Sliding window over which failures are counted while closed.
    pub failure_window: Duration,
    /// Maximum number of requests admitted at once while half-open.
    pub half_open_requests: u32,
}

impl Default for CircuitBreakerConfig {
    /// Defaults: 5 failures within 60 s open the circuit, it stays open for
    /// 30 s, and 3 half-open successes (one request at a time) close it.
    fn default() -> Self {
        let timeout = Duration::from_secs(30);
        let failure_window = Duration::from_secs(60);
        Self {
            failure_threshold: 5,
            success_threshold: 3,
            timeout,
            failure_window,
            half_open_requests: 1,
        }
    }
}

impl CircuitBreakerConfig {
    /// Creates a configuration with the given failure threshold and default
    /// values for every other setting.
    pub fn new(failure_threshold: u32) -> Self {
        Self {
            failure_threshold,
            ..Self::default()
        }
    }

    /// Sets the number of half-open successes required to close the circuit.
    pub fn with_success_threshold(self, threshold: u32) -> Self {
        Self {
            success_threshold: threshold,
            ..self
        }
    }

    /// Sets how long the circuit stays open before becoming half-open.
    pub fn with_timeout(self, timeout: Duration) -> Self {
        Self { timeout, ..self }
    }

    /// Sets the sliding window over which failures are counted.
    pub fn with_failure_window(self, window: Duration) -> Self {
        Self {
            failure_window: window,
            ..self
        }
    }

    /// Sets the number of concurrent half-open requests; values below 1 are
    /// raised to 1.
    pub fn with_half_open_requests(self, requests: u32) -> Self {
        Self {
            half_open_requests: std::cmp::max(requests, 1),
            ..self
        }
    }
}
119
120#[derive(Debug, Clone)]
122pub struct CircuitBreakerStats {
123 pub state: CircuitState,
125 pub success_count: u64,
127 pub failure_count: u64,
129 pub rejected_count: u64,
131 pub failures_in_window: u32,
133 pub half_open_successes: u32,
135 pub time_in_state: Duration,
137}
138
139pub struct CircuitBreaker {
141 name: String,
142 config: CircuitBreakerConfig,
143 state: RwLock<CircuitState>,
144 state_changed_at: RwLock<Instant>,
146 failures: RwLock<Vec<Instant>>,
148 half_open_successes: AtomicU32,
150 half_open_in_flight: AtomicU32,
152 success_count: AtomicU64,
154 failure_count: AtomicU64,
155 rejected_count: AtomicU64,
156}
157
158impl CircuitBreaker {
159 pub fn new(name: impl Into<String>, config: CircuitBreakerConfig) -> Self {
161 Self {
162 name: name.into(),
163 config,
164 state: RwLock::new(CircuitState::Closed),
165 state_changed_at: RwLock::new(Instant::now()),
166 failures: RwLock::new(Vec::new()),
167 half_open_successes: AtomicU32::new(0),
168 half_open_in_flight: AtomicU32::new(0),
169 success_count: AtomicU64::new(0),
170 failure_count: AtomicU64::new(0),
171 rejected_count: AtomicU64::new(0),
172 }
173 }
174
175 pub fn name(&self) -> &str {
177 &self.name
178 }
179
180 pub fn check(&self) -> Result<(), CircuitOpenError> {
185 self.maybe_transition_to_half_open();
186
187 let state = *self.state.read();
188
189 match state {
190 CircuitState::Closed => Ok(()),
191 CircuitState::Open => {
192 self.rejected_count.fetch_add(1, Ordering::Relaxed);
193 let elapsed = self.state_changed_at.read().elapsed();
194 let retry_after = self.config.timeout.saturating_sub(elapsed);
195 Err(CircuitOpenError {
196 circuit_name: self.name.clone(),
197 retry_after,
198 })
199 }
200 CircuitState::HalfOpen => {
201 let in_flight = self.half_open_in_flight.load(Ordering::Acquire);
203 if in_flight < self.config.half_open_requests {
204 self.half_open_in_flight.fetch_add(1, Ordering::AcqRel);
205 Ok(())
206 } else {
207 self.rejected_count.fetch_add(1, Ordering::Relaxed);
208 Err(CircuitOpenError {
209 circuit_name: self.name.clone(),
210 retry_after: Duration::from_millis(100),
211 })
212 }
213 }
214 }
215 }
216
217 pub fn record_success(&self) {
219 self.success_count.fetch_add(1, Ordering::Relaxed);
220
221 let state = *self.state.read();
222
223 if state == CircuitState::HalfOpen {
224 self.half_open_in_flight.fetch_sub(1, Ordering::AcqRel);
225 let successes = self.half_open_successes.fetch_add(1, Ordering::AcqRel) + 1;
226
227 if successes >= self.config.success_threshold {
228 self.transition_to(CircuitState::Closed);
229 }
230 }
231 }
232
233 pub fn record_failure(&self) {
235 self.failure_count.fetch_add(1, Ordering::Relaxed);
236
237 let state = *self.state.read();
238
239 match state {
240 CircuitState::Closed => {
241 let now = Instant::now();
242 let mut failures = self.failures.write();
243
244 failures.push(now);
246
247 let cutoff = now - self.config.failure_window;
249 failures.retain(|&t| t > cutoff);
250
251 if failures.len() as u32 >= self.config.failure_threshold {
253 drop(failures); self.transition_to(CircuitState::Open);
255 }
256 }
257 CircuitState::HalfOpen => {
258 self.half_open_in_flight.fetch_sub(1, Ordering::AcqRel);
259 self.transition_to(CircuitState::Open);
261 }
262 CircuitState::Open => {
263 }
265 }
266 }
267
268 pub async fn call<F, Fut, T, E>(&self, f: F) -> Result<T, CircuitBreakerError<E>>
270 where
271 F: FnOnce() -> Fut,
272 Fut: Future<Output = Result<T, E>>,
273 {
274 self.check().map_err(CircuitBreakerError::CircuitOpen)?;
275
276 match f().await {
277 Ok(result) => {
278 self.record_success();
279 Ok(result)
280 }
281 Err(e) => {
282 self.record_failure();
283 Err(CircuitBreakerError::Inner(e))
284 }
285 }
286 }
287
288 pub fn get_state(&self) -> CircuitState {
290 self.maybe_transition_to_half_open();
291 *self.state.read()
292 }
293
294 pub fn get_stats(&self) -> CircuitBreakerStats {
296 self.maybe_transition_to_half_open();
297
298 let state = *self.state.read();
299 let failures = self.failures.read();
300 let now = Instant::now();
301 let cutoff = now - self.config.failure_window;
302 let failures_in_window = failures.iter().filter(|&&t| t > cutoff).count() as u32;
303
304 CircuitBreakerStats {
305 state,
306 success_count: self.success_count.load(Ordering::Relaxed),
307 failure_count: self.failure_count.load(Ordering::Relaxed),
308 rejected_count: self.rejected_count.load(Ordering::Relaxed),
309 failures_in_window,
310 half_open_successes: self.half_open_successes.load(Ordering::Relaxed),
311 time_in_state: self.state_changed_at.read().elapsed(),
312 }
313 }
314
315 pub fn reset(&self) {
317 self.transition_to(CircuitState::Closed);
318 self.failures.write().clear();
319 }
320
321 fn transition_to(&self, new_state: CircuitState) {
322 let mut state = self.state.write();
323 let old_state = *state;
324
325 if old_state != new_state {
326 *state = new_state;
327 *self.state_changed_at.write() = Instant::now();
328
329 if new_state == CircuitState::HalfOpen || new_state == CircuitState::Closed {
331 self.half_open_successes.store(0, Ordering::Relaxed);
332 self.half_open_in_flight.store(0, Ordering::Relaxed);
333 }
334
335 if new_state == CircuitState::Closed {
337 self.failures.write().clear();
338 }
339
340 #[cfg(feature = "otel")]
341 tracing::info!(
342 circuit = %self.name,
343 old_state = %old_state,
344 new_state = %new_state,
345 "circuit breaker state changed"
346 );
347 }
348 }
349
350 fn maybe_transition_to_half_open(&self) {
351 let state = *self.state.read();
352 if state == CircuitState::Open {
353 let elapsed = self.state_changed_at.read().elapsed();
354 if elapsed >= self.config.timeout {
355 self.transition_to(CircuitState::HalfOpen);
356 }
357 }
358 }
359}
360
361#[derive(Debug)]
363pub enum CircuitBreakerError<E> {
364 CircuitOpen(CircuitOpenError),
366 Inner(E),
368}
369
370impl<E: std::fmt::Display> std::fmt::Display for CircuitBreakerError<E> {
371 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
372 match self {
373 CircuitBreakerError::CircuitOpen(e) => write!(f, "{}", e),
374 CircuitBreakerError::Inner(e) => write!(f, "{}", e),
375 }
376 }
377}
378
379impl<E: std::error::Error + 'static> std::error::Error for CircuitBreakerError<E> {
380 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
381 match self {
382 CircuitBreakerError::CircuitOpen(e) => Some(e),
383 CircuitBreakerError::Inner(e) => Some(e),
384 }
385 }
386}
387
388pub struct CircuitBreakerManager {
390 breakers: DashMap<String, Arc<CircuitBreaker>>,
391 default_config: CircuitBreakerConfig,
392}
393
394impl CircuitBreakerManager {
395 pub fn new(default_config: CircuitBreakerConfig) -> Self {
397 Self {
398 breakers: DashMap::new(),
399 default_config,
400 }
401 }
402
403 pub fn get_or_create(&self, name: &str) -> Arc<CircuitBreaker> {
405 self.breakers
406 .entry(name.to_string())
407 .or_insert_with(|| Arc::new(CircuitBreaker::new(name, self.default_config.clone())))
408 .clone()
409 }
410
411 pub fn get(&self, name: &str) -> Option<Arc<CircuitBreaker>> {
413 self.breakers.get(name).map(|r| r.clone())
414 }
415
416 pub fn create_with_config(
418 &self,
419 name: &str,
420 config: CircuitBreakerConfig,
421 ) -> Arc<CircuitBreaker> {
422 let breaker = Arc::new(CircuitBreaker::new(name, config));
423 self.breakers.insert(name.to_string(), breaker.clone());
424 breaker
425 }
426
427 pub fn get_all_stats(&self) -> Vec<(String, CircuitBreakerStats)> {
429 self.breakers
430 .iter()
431 .map(|entry| (entry.key().clone(), entry.value().get_stats()))
432 .collect()
433 }
434
435 pub fn reset_all(&self) {
437 for entry in self.breakers.iter() {
438 entry.value().reset();
439 }
440 }
441
442 pub fn remove(&self, name: &str) {
444 self.breakers.remove(name);
445 }
446
447 pub fn clear(&self) {
449 self.breakers.clear();
450 }
451
452 pub fn len(&self) -> usize {
454 self.breakers.len()
455 }
456
457 pub fn is_empty(&self) -> bool {
459 self.breakers.is_empty()
460 }
461}
462
463impl Default for CircuitBreakerManager {
464 fn default() -> Self {
465 Self::new(CircuitBreakerConfig::default())
466 }
467}
468
469pub struct KeyedCircuitBreaker<K: std::hash::Hash + Eq + Clone + Send + Sync + 'static> {
495 breakers: DashMap<K, Arc<CircuitBreaker>>,
496 config: CircuitBreakerConfig,
497 counter: std::sync::atomic::AtomicU64,
499}
500
501impl<K: std::hash::Hash + Eq + Clone + Send + Sync + 'static> KeyedCircuitBreaker<K> {
502 pub fn new(config: CircuitBreakerConfig) -> Self {
506 Self {
507 breakers: DashMap::new(),
508 config,
509 counter: std::sync::atomic::AtomicU64::new(0),
510 }
511 }
512
513 pub fn check(&self, key: &K) -> Result<(), CircuitOpenError> {
517 self.get_or_create(key).check()
518 }
519
520 pub fn record_success(&self, key: &K) {
522 self.get_or_create(key).record_success()
523 }
524
525 pub fn record_failure(&self, key: &K) {
527 self.get_or_create(key).record_failure()
528 }
529
530 pub async fn call<F, Fut, T, E>(&self, key: &K, f: F) -> Result<T, CircuitBreakerError<E>>
532 where
533 F: FnOnce() -> Fut,
534 Fut: Future<Output = Result<T, E>>,
535 {
536 self.get_or_create(key).call(f).await
537 }
538
539 pub fn get_state(&self, key: &K) -> Option<CircuitState> {
541 self.breakers.get(key).map(|cb| cb.get_state())
542 }
543
544 pub fn get_stats(&self, key: &K) -> Option<CircuitBreakerStats> {
546 self.breakers.get(key).map(|cb| cb.get_stats())
547 }
548
549 pub fn get_all_stats(&self) -> Vec<(K, CircuitBreakerStats)> {
551 self.breakers
552 .iter()
553 .map(|entry| (entry.key().clone(), entry.value().get_stats()))
554 .collect()
555 }
556
557 pub fn reset(&self, key: &K) {
559 if let Some(cb) = self.breakers.get(key) {
560 cb.reset();
561 }
562 }
563
564 pub fn reset_all(&self) {
566 for entry in self.breakers.iter() {
567 entry.value().reset();
568 }
569 }
570
571 pub fn remove(&self, key: &K) {
573 self.breakers.remove(key);
574 }
575
576 pub fn clear(&self) {
578 self.breakers.clear();
579 }
580
581 pub fn len(&self) -> usize {
583 self.breakers.len()
584 }
585
586 pub fn is_empty(&self) -> bool {
588 self.breakers.is_empty()
589 }
590
591 pub fn get(&self, key: &K) -> Option<Arc<CircuitBreaker>> {
593 self.breakers.get(key).map(|r| r.clone())
594 }
595
596 fn get_or_create(&self, key: &K) -> Arc<CircuitBreaker> {
597 self.breakers
598 .entry(key.clone())
599 .or_insert_with(|| {
600 let id = self.counter.fetch_add(1, Ordering::Relaxed);
601 Arc::new(CircuitBreaker::new(
602 format!("keyed-{}", id),
603 self.config.clone(),
604 ))
605 })
606 .clone()
607 }
608}
609
610impl<K: std::hash::Hash + Eq + Clone + Send + Sync + 'static> Default for KeyedCircuitBreaker<K> {
611 fn default() -> Self {
612 Self::new(CircuitBreakerConfig::default())
613 }
614}
615
#[cfg(test)]
mod tests {
    use super::*;

    /// Open-state timeout short enough for a test to wait out.
    const SHORT_TIMEOUT: Duration = Duration::from_millis(10);
    /// Sleep that comfortably exceeds `SHORT_TIMEOUT`.
    const PAST_SHORT_TIMEOUT: Duration = Duration::from_millis(20);

    /// Builds an owned key for the keyed breaker tests.
    fn key(name: &str) -> String {
        name.to_string()
    }

    /// Blocks until a breaker configured with `SHORT_TIMEOUT` has timed out.
    fn wait_out_timeout() {
        std::thread::sleep(PAST_SHORT_TIMEOUT);
    }

    #[test]
    fn test_circuit_breaker_config_default() {
        let CircuitBreakerConfig {
            failure_threshold,
            success_threshold,
            timeout,
            ..
        } = CircuitBreakerConfig::default();

        assert_eq!(failure_threshold, 5);
        assert_eq!(success_threshold, 3);
        assert_eq!(timeout, Duration::from_secs(30));
    }

    #[test]
    fn test_circuit_breaker_config_builder() {
        let one_minute = Duration::from_secs(60);
        let built = CircuitBreakerConfig::new(10)
            .with_success_threshold(5)
            .with_timeout(one_minute)
            .with_half_open_requests(3);

        assert_eq!(built.failure_threshold, 10);
        assert_eq!(built.success_threshold, 5);
        assert_eq!(built.timeout, one_minute);
        assert_eq!(built.half_open_requests, 3);
    }

    #[test]
    fn test_circuit_breaker_initial_state() {
        let breaker = CircuitBreaker::new("test", CircuitBreakerConfig::default());
        assert_eq!(breaker.get_state(), CircuitState::Closed);
    }

    #[test]
    fn test_circuit_breaker_opens_on_failures() {
        let breaker = CircuitBreaker::new("test", CircuitBreakerConfig::new(3));

        // The first two failures stay below the threshold; the third trips it.
        let expected_after_each = [
            CircuitState::Closed,
            CircuitState::Closed,
            CircuitState::Open,
        ];
        for expected in expected_after_each.iter() {
            breaker.record_failure();
            assert_eq!(breaker.get_state(), *expected);
        }
    }

    #[test]
    fn test_circuit_breaker_check_when_open() {
        let breaker = CircuitBreaker::new("test", CircuitBreakerConfig::new(1));

        breaker.record_failure();
        assert_eq!(breaker.get_state(), CircuitState::Open);

        let outcome = breaker.check();
        assert!(outcome.is_err());

        let rejection = outcome.unwrap_err();
        assert_eq!(rejection.circuit_name, "test");
    }

    #[test]
    fn test_circuit_breaker_transitions_to_half_open() {
        let breaker = CircuitBreaker::new(
            "test",
            CircuitBreakerConfig::new(1).with_timeout(SHORT_TIMEOUT),
        );

        breaker.record_failure();
        assert_eq!(breaker.get_state(), CircuitState::Open);

        wait_out_timeout();

        assert_eq!(breaker.get_state(), CircuitState::HalfOpen);
    }

    #[test]
    fn test_circuit_breaker_closes_on_success() {
        let breaker = CircuitBreaker::new(
            "test",
            CircuitBreakerConfig::new(1)
                .with_timeout(SHORT_TIMEOUT)
                .with_success_threshold(2),
        );

        breaker.record_failure();
        wait_out_timeout();

        assert_eq!(breaker.get_state(), CircuitState::HalfOpen);

        // First trial succeeds but the threshold of two is not met yet.
        breaker.check().unwrap();
        breaker.record_success();
        assert_eq!(breaker.get_state(), CircuitState::HalfOpen);

        // Second trial reaches the threshold and closes the circuit.
        breaker.check().unwrap();
        breaker.record_success();
        assert_eq!(breaker.get_state(), CircuitState::Closed);
    }

    #[test]
    fn test_circuit_breaker_reopens_on_half_open_failure() {
        let breaker = CircuitBreaker::new(
            "test",
            CircuitBreakerConfig::new(1).with_timeout(SHORT_TIMEOUT),
        );

        breaker.record_failure();
        wait_out_timeout();

        assert_eq!(breaker.get_state(), CircuitState::HalfOpen);

        breaker.check().unwrap();
        breaker.record_failure();
        assert_eq!(breaker.get_state(), CircuitState::Open);
    }

    #[test]
    fn test_circuit_breaker_reset() {
        let breaker = CircuitBreaker::new("test", CircuitBreakerConfig::new(1));

        breaker.record_failure();
        assert_eq!(breaker.get_state(), CircuitState::Open);

        breaker.reset();
        assert_eq!(breaker.get_state(), CircuitState::Closed);
    }

    #[test]
    fn test_circuit_breaker_stats() {
        let breaker = CircuitBreaker::new("test", CircuitBreakerConfig::new(5));

        breaker.record_success();
        breaker.record_success();
        breaker.record_failure();

        let snapshot = breaker.get_stats();
        assert_eq!(snapshot.state, CircuitState::Closed);
        assert_eq!(snapshot.success_count, 2);
        assert_eq!(snapshot.failure_count, 1);
        assert_eq!(snapshot.failures_in_window, 1);
    }

    #[tokio::test]
    async fn test_circuit_breaker_call_success() {
        let breaker = CircuitBreaker::new("test", CircuitBreakerConfig::default());

        let outcome = breaker
            .call(|| async { Ok::<_, std::io::Error>("success") })
            .await;

        assert!(outcome.is_ok());
        assert_eq!(outcome.unwrap(), "success");
        assert_eq!(breaker.get_stats().success_count, 1);
    }

    #[tokio::test]
    async fn test_circuit_breaker_call_failure() {
        let breaker = CircuitBreaker::new("test", CircuitBreakerConfig::default());

        let outcome: Result<(), CircuitBreakerError<std::io::Error>> = breaker
            .call(|| async { Err(std::io::Error::new(std::io::ErrorKind::Other, "failed")) })
            .await;

        assert!(outcome.is_err());
        assert_eq!(breaker.get_stats().failure_count, 1);
    }

    #[test]
    fn test_circuit_breaker_manager() {
        let manager = CircuitBreakerManager::default();

        let first = manager.get_or_create("service1");
        let second = manager.get_or_create("service2");
        let first_again = manager.get_or_create("service1");

        assert_eq!(first.name(), "service1");
        assert_eq!(second.name(), "service2");
        // Asking for the same name twice yields the same breaker instance.
        assert!(Arc::ptr_eq(&first, &first_again));
    }

    #[test]
    fn test_circuit_breaker_manager_custom_config() {
        let manager = CircuitBreakerManager::default();

        let breaker = manager.create_with_config("custom", CircuitBreakerConfig::new(10));

        assert_eq!(breaker.name(), "custom");
        assert_eq!(breaker.get_state(), CircuitState::Closed);
    }

    #[test]
    fn test_circuit_breaker_manager_get_all_stats() {
        let manager = CircuitBreakerManager::default();

        manager.get_or_create("a").record_success();
        manager.get_or_create("b").record_failure();

        let all_stats = manager.get_all_stats();
        assert_eq!(all_stats.len(), 2);
    }

    #[test]
    fn test_circuit_breaker_manager_reset_all() {
        let manager = CircuitBreakerManager::new(CircuitBreakerConfig::new(1));

        let breaker = manager.get_or_create("test");
        breaker.record_failure();
        assert_eq!(breaker.get_state(), CircuitState::Open);

        manager.reset_all();
        assert_eq!(breaker.get_state(), CircuitState::Closed);
    }

    #[test]
    fn test_circuit_state_display() {
        let expectations = [
            (CircuitState::Closed, "closed"),
            (CircuitState::Open, "open"),
            (CircuitState::HalfOpen, "half-open"),
        ];
        for (state, label) in expectations.iter() {
            assert_eq!(format!("{}", state), *label);
        }
    }

    #[test]
    fn test_keyed_circuit_breaker_basic() {
        let keyed = KeyedCircuitBreaker::<String>::new(CircuitBreakerConfig::default());

        // Each first check lazily creates a closed breaker for its key.
        assert!(keyed.check(&key("key1")).is_ok());
        assert!(keyed.check(&key("key2")).is_ok());

        assert_eq!(keyed.len(), 2);
    }

    #[test]
    fn test_keyed_circuit_breaker_isolation() {
        let keyed = KeyedCircuitBreaker::<String>::new(CircuitBreakerConfig::new(1));

        keyed.record_failure(&key("key1"));

        // Tripping one key must not affect another.
        assert!(keyed.check(&key("key1")).is_err());
        assert!(keyed.check(&key("key2")).is_ok());
    }

    #[test]
    fn test_keyed_circuit_breaker_stats() {
        let keyed = KeyedCircuitBreaker::<String>::new(CircuitBreakerConfig::default());
        let (a, b) = (key("a"), key("b"));

        keyed.record_success(&a);
        keyed.record_success(&a);
        keyed.record_failure(&b);

        let stats_a = keyed.get_stats(&a).unwrap();
        assert_eq!(stats_a.success_count, 2);

        let stats_b = keyed.get_stats(&b).unwrap();
        assert_eq!(stats_b.failure_count, 1);

        let all_stats = keyed.get_all_stats();
        assert_eq!(all_stats.len(), 2);
    }

    #[test]
    fn test_keyed_circuit_breaker_reset() {
        let keyed = KeyedCircuitBreaker::<String>::new(CircuitBreakerConfig::new(1));
        let target = key("key");

        keyed.record_failure(&target);
        assert!(keyed.check(&target).is_err());

        keyed.reset(&target);
        assert!(keyed.check(&target).is_ok());
    }

    #[test]
    fn test_keyed_circuit_breaker_reset_all() {
        let keyed = KeyedCircuitBreaker::<String>::new(CircuitBreakerConfig::new(1));
        let (a, b) = (key("a"), key("b"));

        keyed.record_failure(&a);
        keyed.record_failure(&b);

        assert!(keyed.check(&a).is_err());
        assert!(keyed.check(&b).is_err());

        keyed.reset_all();

        assert!(keyed.check(&a).is_ok());
        assert!(keyed.check(&b).is_ok());
    }

    #[tokio::test]
    async fn test_keyed_circuit_breaker_call() {
        let keyed = KeyedCircuitBreaker::<String>::new(CircuitBreakerConfig::default());

        let outcome = keyed
            .call(&key("test"), || async { Ok::<_, std::io::Error>("success") })
            .await;

        assert!(outcome.is_ok());
        assert_eq!(outcome.unwrap(), "success");
    }

    #[test]
    fn test_keyed_circuit_breaker_remove() {
        let keyed = KeyedCircuitBreaker::<String>::new(CircuitBreakerConfig::default());
        let target = key("key");

        keyed.check(&target).ok();
        assert_eq!(keyed.len(), 1);

        keyed.remove(&target);
        assert_eq!(keyed.len(), 0);
    }

    #[test]
    fn test_keyed_circuit_breaker_clear() {
        let keyed = KeyedCircuitBreaker::<String>::new(CircuitBreakerConfig::default());

        for name in ["a", "b", "c"].iter() {
            keyed.check(&key(name)).ok();
        }

        assert_eq!(keyed.len(), 3);

        keyed.clear();
        assert!(keyed.is_empty());
    }

    #[test]
    fn test_keyed_circuit_breaker_default() {
        let keyed = KeyedCircuitBreaker::<u64>::default();
        assert!(keyed.check(&1).is_ok());
        assert!(keyed.check(&2).is_ok());
    }
}