//! Node health tracking and circuit breaking for federated nodes.
//!
//! [`HealthChecker`] turns reported request outcomes into a per-node
//! `HealthState`; [`CircuitBreaker`] cuts traffic to nodes that keep
//! failing and lets it resume after a reset timeout.

use super::traits::*;
use std::collections::HashMap;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::RwLock;
use std::time::{Duration, Instant};

#[derive(Debug, Clone)]
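/// Point-in-time health snapshot for a single node, in the shape returned by
/// [`HealthChecker::all_statuses`].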
pub struct HealthStatus {
    pub node_id: NodeId,
    pub state: HealthState,
    pub latency_p50: Duration,
    pub latency_p99: Duration,
    pub queue_depth: u32,
    pub last_updated: Instant,
}

impl From<NodeHealth> for HealthStatus {
    fn from(health: NodeHealth) -> Self {
        Self {
            node_id: health.node_id,
            state: health.status,
            latency_p50: health.latency_p50,
            latency_p99: health.latency_p99,
            queue_depth: health.queue_depth,
            last_updated: health.last_check,
        }
    }
}

#[derive(Debug, Clone)]
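/// Tuning parameters for [`HealthChecker`].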
pub struct HealthConfig {
    /// Interval between periodic health checks.
    pub check_interval: Duration,
    /// Timeout applied to a single health probe.
    pub probe_timeout: Duration,
    /// Consecutive failures before a node is marked unhealthy.
    pub failure_threshold: u32,
    /// Consecutive successes before a node is considered healthy again.
    pub recovery_threshold: u32,
    /// Latency above which a successful request still marks the node degraded.
    pub degraded_latency: Duration,
}

impl Default for HealthConfig {
    fn default() -> Self {
        Self {
            check_interval: Duration::from_secs(10),
            probe_timeout: Duration::from_secs(5),
            failure_threshold: 3,
            recovery_threshold: 2,
            degraded_latency: Duration::from_secs(1),
        }
    }
}

#[derive(Debug, Clone)]
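/// Internal per-node bookkeeping: the latest reported health plus streak counters.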
struct NodeState {
    health: NodeHealth,
    consecutive_failures: u32,
    consecutive_successes: u32,
}

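/// Tracks the health of registered nodes from reported request outcomes.
///
/// Callers register a node, then report each request's result via
/// [`HealthChecker::report_success`] or [`HealthChecker::report_failure`];
/// the checker derives a `HealthState` from consecutive results and latency.
///
/// Illustrative usage (mirrors the unit tests below; the node name is made up):
///
/// ```ignore
/// let checker = HealthChecker::default();
/// let node = NodeId("node-a".to_string());
/// checker.register_node(node.clone());
/// checker.report_success(&node, Duration::from_millis(40));
/// ```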
pub struct HealthChecker {
    config: HealthConfig,
    states: RwLock<HashMap<NodeId, NodeState>>,
    monitoring: AtomicBool,
}

impl HealthChecker {
    pub fn new(config: HealthConfig) -> Self {
        Self {
            config,
            states: RwLock::new(HashMap::new()),
            monitoring: AtomicBool::new(false),
        }
    }

    pub fn register_node(&self, node_id: NodeId) {
        // A newly registered node starts as `Unknown` until results are reported.
        let mut states = self.states.write().expect("health lock poisoned");

        let health = NodeHealth {
            node_id: node_id.clone(),
            status: HealthState::Unknown,
            latency_p50: Duration::ZERO,
            latency_p99: Duration::ZERO,
            throughput: 0,
            gpu_utilization: None,
            queue_depth: 0,
            last_check: Instant::now(),
        };

        states.insert(
            node_id,
            NodeState {
                health,
                consecutive_failures: 0,
                consecutive_successes: 0,
            },
        );
    }

    pub fn deregister_node(&self, node_id: &NodeId) {
        let mut states = self.states.write().expect("health lock poisoned");
        states.remove(node_id);
    }

    pub fn report_success(&self, node_id: &NodeId, latency: Duration) {
        // A success clears the failure streak and updates the latency estimate.
        let mut states = self.states.write().expect("health lock poisoned");

        if let Some(state) = states.get_mut(node_id) {
            state.consecutive_failures = 0;
            state.consecutive_successes += 1;

            // Exponentially weighted moving average: 90% previous estimate, 10% new sample.
            let old_latency = state.health.latency_p50;
            state.health.latency_p50 = Duration::from_millis(
                (old_latency.as_millis() as u64 * 9 + latency.as_millis() as u64) / 10,
            );

            state.health.last_check = Instant::now();

            // A slow response degrades the node even though the call succeeded.
            if latency > self.config.degraded_latency {
                state.health.status = HealthState::Degraded;
            } else if state.consecutive_successes >= self.config.recovery_threshold {
                state.health.status = HealthState::Healthy;
            }
        }
    }

    pub fn report_failure(&self, node_id: &NodeId) {
        // A failure resets the success streak; enough consecutive failures
        // mark the node unhealthy, otherwise it is only degraded.
        let mut states = self.states.write().expect("health lock poisoned");

        if let Some(state) = states.get_mut(node_id) {
            state.consecutive_successes = 0;
            state.consecutive_failures += 1;
            state.health.last_check = Instant::now();

            if state.consecutive_failures >= self.config.failure_threshold {
                state.health.status = HealthState::Unhealthy;
            } else {
                state.health.status = HealthState::Degraded;
            }
        }
    }

    pub fn all_statuses(&self) -> Vec<HealthStatus> {
        let states = self.states.read().expect("health lock poisoned");
        states
            .values()
            .map(|s| HealthStatus::from(s.health.clone()))
            .collect()
    }

    pub fn is_monitoring(&self) -> bool {
        self.monitoring.load(Ordering::SeqCst)
    }

    pub fn healthy_count(&self) -> usize {
        let states = self.states.read().expect("health lock poisoned");
        states
            .values()
            .filter(|s| s.health.status == HealthState::Healthy)
            .count()
    }

    pub fn total_count(&self) -> usize {
        let states = self.states.read().expect("health lock poisoned");
        states.len()
    }
}

impl Default for HealthChecker {
    fn default() -> Self {
        Self::new(HealthConfig::default())
    }
}

impl HealthCheckerTrait for HealthChecker {
    fn check_node(&self, node_id: &NodeId) -> BoxFuture<'_, FederationResult<NodeHealth>> {
        let node_id = node_id.clone();

        Box::pin(async move {
            // Answer from the in-memory table; nodes that were never
            // registered are reported as unreachable.
            let states = self.states.read().expect("health lock poisoned");

            states
                .get(&node_id)
                .map(|s| s.health.clone())
                .ok_or(FederationError::NodeUnreachable(node_id))
        })
    }

    fn get_cached_health(&self, node_id: &NodeId) -> Option<NodeHealth> {
        let states = self.states.read().expect("health lock poisoned");
        states.get(node_id).map(|s| s.health.clone())
    }

    fn start_monitoring(&self, _interval: Duration) -> BoxFuture<'_, ()> {
        Box::pin(async move {
            self.monitoring.store(true, Ordering::SeqCst);
            // No background task is spawned here; the flag only records that
            // monitoring was requested.
        })
    }

    fn stop_monitoring(&self) -> BoxFuture<'_, ()> {
        Box::pin(async move {
            self.monitoring.store(false, Ordering::SeqCst);
        })
    }
}

#[derive(Debug, Clone)]
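/// Tuning parameters for [`CircuitBreaker`].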
pub struct CircuitBreakerConfig {
    /// Failures (without an intervening success) before the circuit opens.
    pub failure_threshold: u32,
    /// How long the circuit stays open before trial requests are allowed.
    pub reset_timeout: Duration,
    /// Successes required while half-open to close the circuit again.
    pub half_open_successes: u32,
}

impl Default for CircuitBreakerConfig {
    fn default() -> Self {
        Self {
            failure_threshold: 5,
            reset_timeout: Duration::from_secs(30),
            half_open_successes: 3,
        }
    }
}

#[derive(Debug, Clone)]
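/// Internal per-node circuit state: the current phase plus its counters.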
struct CircuitBreakerState {
    state: CircuitState,
    failures: u32,
    successes_in_half_open: u32,
    last_failure: Option<Instant>,
}

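/// Per-node circuit breaker guarding calls to remote nodes.
///
/// The circuit opens after `failure_threshold` failures without an intervening
/// success, becomes half-open once `reset_timeout` has elapsed, and closes
/// again after `half_open_successes` successful trial requests.
///
/// Illustrative usage (mirrors the unit tests below; the node name is made up):
///
/// ```ignore
/// let breaker = CircuitBreaker::default();
/// let node = NodeId("node-a".to_string());
/// if !breaker.is_open(&node) {
///     // ...send the request, then record the outcome...
///     breaker.record_success(&node);
/// }
/// ```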
pub struct CircuitBreaker {
    config: CircuitBreakerConfig,
    states: RwLock<HashMap<NodeId, CircuitBreakerState>>,
}

impl CircuitBreaker {
    pub fn new(config: CircuitBreakerConfig) -> Self {
        Self {
            config,
            states: RwLock::new(HashMap::new()),
        }
    }

    /// Returns the stored state for the node, or a default `Closed` state
    /// without inserting one.
    fn get_or_create_state(&self, node_id: &NodeId) -> CircuitBreakerState {
        let states = self.states.read().expect("circuit breaker lock poisoned");
        states.get(node_id).cloned().unwrap_or(CircuitBreakerState {
            state: CircuitState::Closed,
            failures: 0,
            successes_in_half_open: 0,
            last_failure: None,
        })
    }

    fn update_state(&self, node_id: &NodeId, state: CircuitBreakerState) {
        let mut states = self.states.write().expect("circuit breaker lock poisoned");
        states.insert(node_id.clone(), state);
    }

    pub fn all_states(&self) -> Vec<(NodeId, CircuitState)> {
        let states = self.states.read().expect("circuit breaker lock poisoned");
        states
            .iter()
            .map(|(node_id, state)| (node_id.clone(), state.state))
            .collect()
    }
}

impl Default for CircuitBreaker {
    fn default() -> Self {
        Self::new(CircuitBreakerConfig::default())
    }
}

impl CircuitBreakerTrait for CircuitBreaker {
    fn is_open(&self, node_id: &NodeId) -> bool {
        let state = self.get_or_create_state(node_id);

        match state.state {
            CircuitState::Open => {
                if let Some(last_failure) = state.last_failure {
                    // Once the reset timeout has elapsed, move to half-open
                    // and let a trial request through.
                    if last_failure.elapsed() >= self.config.reset_timeout {
                        let mut new_state = state;
                        new_state.state = CircuitState::HalfOpen;
                        new_state.successes_in_half_open = 0;
                        self.update_state(node_id, new_state);
                        return false;
                    }
                }
                true
            }
            CircuitState::HalfOpen => false,
            CircuitState::Closed => false,
        }
    }

    fn record_success(&self, node_id: &NodeId) {
        let mut state = self.get_or_create_state(node_id);

        match state.state {
            CircuitState::HalfOpen => {
                state.successes_in_half_open += 1;
                if state.successes_in_half_open >= self.config.half_open_successes {
                    // Enough trial successes: close the circuit and reset counters.
                    state.state = CircuitState::Closed;
                    state.failures = 0;
                    state.successes_in_half_open = 0;
                }
            }
            CircuitState::Closed => {
                // A success clears the failure streak.
                state.failures = 0;
            }
            CircuitState::Open => {
                // A success recorded while open closes the circuit immediately.
                state.state = CircuitState::Closed;
                state.failures = 0;
            }
        }

        self.update_state(node_id, state);
    }

    fn record_failure(&self, node_id: &NodeId) {
        let mut state = self.get_or_create_state(node_id);
        state.failures += 1;
        state.last_failure = Some(Instant::now());

        match state.state {
            CircuitState::Closed => {
                if state.failures >= self.config.failure_threshold {
                    state.state = CircuitState::Open;
                }
            }
            CircuitState::HalfOpen => {
                // A failed trial request reopens the circuit.
                state.state = CircuitState::Open;
                state.successes_in_half_open = 0;
            }
            CircuitState::Open => {
                // Already open; only the failure count and timestamp were updated.
            }
        }

        self.update_state(node_id, state);
    }

    fn state(&self, node_id: &NodeId) -> CircuitState {
        self.get_or_create_state(node_id).state
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_health_status_transitions() {
        let checker = HealthChecker::default();
        let node = NodeId("test-node".to_string());

        checker.register_node(node.clone());

        // Newly registered nodes start out as Unknown.
        let health = checker.get_cached_health(&node).expect("node should exist");
        assert_eq!(health.status, HealthState::Unknown);

        // Enough fast successes promote the node to Healthy.
        for _ in 0..3 {
            checker.report_success(&node, Duration::from_millis(50));
        }

        let health = checker.get_cached_health(&node).expect("node should exist");
        assert_eq!(health.status, HealthState::Healthy);
    }

    #[test]
    fn test_health_degraded_on_high_latency() {
        let checker = HealthChecker::default();
        let node = NodeId("slow-node".to_string());

        checker.register_node(node.clone());

        // A success that exceeds `degraded_latency` still degrades the node.
        checker.report_success(&node, Duration::from_secs(2));

        let health = checker.get_cached_health(&node).expect("node should exist");
        assert_eq!(health.status, HealthState::Degraded);
    }

    #[test]
    fn test_health_unhealthy_on_failures() {
        let config = HealthConfig {
            failure_threshold: 3,
            ..Default::default()
        };
        let checker = HealthChecker::new(config);
        let node = NodeId("failing-node".to_string());

        checker.register_node(node.clone());

        for _ in 0..3 {
            checker.report_failure(&node);
        }

        let health = checker.get_cached_health(&node).expect("node should exist");
        assert_eq!(health.status, HealthState::Unhealthy);
    }

    #[test]
    fn test_circuit_breaker_opens_on_failures() {
        let config = CircuitBreakerConfig {
            failure_threshold: 3,
            ..Default::default()
        };
        let breaker = CircuitBreaker::new(config);
        let node = NodeId("failing-node".to_string());

        assert!(!breaker.is_open(&node));
        assert_eq!(breaker.state(&node), CircuitState::Closed);

        for _ in 0..3 {
            breaker.record_failure(&node);
        }

        assert!(breaker.is_open(&node));
        assert_eq!(breaker.state(&node), CircuitState::Open);
    }

    #[test]
    fn test_circuit_breaker_success_resets() {
        let breaker = CircuitBreaker::default();
        let node = NodeId("flaky-node".to_string());

        breaker.record_failure(&node);
        breaker.record_failure(&node);

        breaker.record_success(&node);

        let state = breaker.get_or_create_state(&node);
        assert_eq!(state.failures, 0);
    }

    #[test]
    fn test_circuit_breaker_half_open_recovery() {
        let config = CircuitBreakerConfig {
            failure_threshold: 2,
            half_open_successes: 2,
            reset_timeout: Duration::from_millis(10),
        };
        let breaker = CircuitBreaker::new(config);
        let node = NodeId("recovering-node".to_string());

        breaker.record_failure(&node);
        breaker.record_failure(&node);
        assert_eq!(breaker.state(&node), CircuitState::Open);

        // Wait for the reset timeout to elapse.
        std::thread::sleep(Duration::from_millis(20));

        // The next check lets a trial request through and moves to half-open.
        assert!(!breaker.is_open(&node));
        assert_eq!(breaker.state(&node), CircuitState::HalfOpen);

        breaker.record_success(&node);
        breaker.record_success(&node);
        assert_eq!(breaker.state(&node), CircuitState::Closed);
    }

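    // --- HealthConfig tests ---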
    #[test]
    fn test_health_config_default() {
        let config = HealthConfig::default();
        assert_eq!(config.check_interval, Duration::from_secs(10));
        assert_eq!(config.probe_timeout, Duration::from_secs(5));
        assert_eq!(config.failure_threshold, 3);
        assert_eq!(config.recovery_threshold, 2);
        assert_eq!(config.degraded_latency, Duration::from_secs(1));
    }

    #[test]
    fn test_health_config_custom() {
        let config = HealthConfig {
            check_interval: Duration::from_secs(30),
            probe_timeout: Duration::from_secs(10),
            failure_threshold: 5,
            recovery_threshold: 3,
            degraded_latency: Duration::from_millis(500),
        };
        assert_eq!(config.check_interval, Duration::from_secs(30));
        assert_eq!(config.failure_threshold, 5);
    }

    #[test]
    fn test_health_config_clone() {
        let config = HealthConfig::default();
        let cloned = config.clone();
        assert_eq!(cloned.failure_threshold, config.failure_threshold);
    }

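    // --- HealthChecker tests ---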
    #[test]
    fn test_health_checker_register_and_deregister() {
        let checker = HealthChecker::default();
        let node = NodeId("temp-node".to_string());

        checker.register_node(node.clone());
        assert_eq!(checker.total_count(), 1);

        checker.deregister_node(&node);
        assert_eq!(checker.total_count(), 0);
    }

    #[test]
    fn test_health_checker_deregister_unknown_node() {
        let checker = HealthChecker::default();
        let unknown = NodeId("unknown".to_string());

        // Deregistering a node that was never registered is a no-op.
        checker.deregister_node(&unknown);
        assert_eq!(checker.total_count(), 0);
    }

    #[test]
    fn test_health_checker_all_statuses_empty() {
        let checker = HealthChecker::default();
        let statuses = checker.all_statuses();
        assert!(statuses.is_empty());
    }

    #[test]
    fn test_health_checker_all_statuses_multiple() {
        let checker = HealthChecker::default();
        checker.register_node(NodeId("n1".to_string()));
        checker.register_node(NodeId("n2".to_string()));
        checker.register_node(NodeId("n3".to_string()));

        let statuses = checker.all_statuses();
        assert_eq!(statuses.len(), 3);
    }

    #[test]
    fn test_health_checker_healthy_count_none() {
        let checker = HealthChecker::default();
        checker.register_node(NodeId("n1".to_string()));
        // Registered but not yet reported on, so the node is still Unknown.
        assert_eq!(checker.healthy_count(), 0);
    }

    #[test]
    fn test_health_checker_healthy_count_some() {
        let checker = HealthChecker::default();
        let n1 = NodeId("n1".to_string());
        let n2 = NodeId("n2".to_string());
        checker.register_node(n1.clone());
        checker.register_node(n2.clone());

        for _ in 0..3 {
            checker.report_success(&n1, Duration::from_millis(10));
        }

        assert_eq!(checker.healthy_count(), 1);
        assert_eq!(checker.total_count(), 2);
    }

    #[test]
    fn test_health_checker_is_monitoring_default() {
        let checker = HealthChecker::default();
        assert!(!checker.is_monitoring());
    }

    #[test]
    fn test_health_checker_report_success_unknown_node() {
        let checker = HealthChecker::default();
        let unknown = NodeId("unknown".to_string());

        // Reports for unregistered nodes are ignored.
        checker.report_success(&unknown, Duration::from_millis(50));
        assert_eq!(checker.total_count(), 0);
    }

    #[test]
    fn test_health_checker_report_failure_unknown_node() {
        let checker = HealthChecker::default();
        let unknown = NodeId("unknown".to_string());

        checker.report_failure(&unknown);
        assert_eq!(checker.total_count(), 0);
    }

    #[test]
    fn test_health_checker_degraded_then_healthy() {
        let config = HealthConfig {
            recovery_threshold: 2,
            ..Default::default()
        };
        let checker = HealthChecker::new(config);
        let node = NodeId("recovering".to_string());

        checker.register_node(node.clone());

        checker.report_failure(&node);
        let health = checker.get_cached_health(&node).expect("should exist");
        assert_eq!(health.status, HealthState::Degraded);

        for _ in 0..2 {
            checker.report_success(&node, Duration::from_millis(50));
        }
        let health = checker.get_cached_health(&node).expect("should exist");
        assert_eq!(health.status, HealthState::Healthy);
    }

    #[test]
    fn test_health_checker_failure_below_threshold_is_degraded() {
        let config = HealthConfig {
            failure_threshold: 5,
            ..Default::default()
        };
        let checker = HealthChecker::new(config);
        let node = NodeId("flaky".to_string());
        checker.register_node(node.clone());

        for _ in 0..3 {
            checker.report_failure(&node);
        }
        let health = checker.get_cached_health(&node).expect("should exist");
        assert_eq!(health.status, HealthState::Degraded);

        for _ in 0..2 {
            checker.report_failure(&node);
        }
        let health = checker.get_cached_health(&node).expect("should exist");
        assert_eq!(health.status, HealthState::Unhealthy);
    }

    #[test]
    fn test_health_checker_latency_moving_average() {
        let checker = HealthChecker::default();
        let node = NodeId("avg-node".to_string());
        checker.register_node(node.clone());

        for _ in 0..10 {
            checker.report_success(&node, Duration::from_millis(100));
        }

        let health = checker.get_cached_health(&node).expect("should exist");
        assert!(health.latency_p50.as_millis() > 0);
    }

    #[test]
    fn test_health_checker_get_cached_health_none() {
        let checker = HealthChecker::default();
        let unknown = NodeId("no-such-node".to_string());
        assert!(checker.get_cached_health(&unknown).is_none());
    }

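    // --- HealthStatus conversion tests ---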
    #[test]
    fn test_health_status_from_node_health() {
        let now = Instant::now();
        let health = NodeHealth {
            node_id: NodeId("test".to_string()),
            status: HealthState::Degraded,
            latency_p50: Duration::from_millis(100),
            latency_p99: Duration::from_millis(500),
            throughput: 200,
            gpu_utilization: Some(0.5),
            queue_depth: 5,
            last_check: now,
        };

        let status = HealthStatus::from(health);
        assert_eq!(status.node_id, NodeId("test".to_string()));
        assert_eq!(status.state, HealthState::Degraded);
        assert_eq!(status.latency_p50, Duration::from_millis(100));
        assert_eq!(status.latency_p99, Duration::from_millis(500));
        assert_eq!(status.queue_depth, 5);
    }

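    // --- HealthCheckerTrait (async) tests ---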
    #[tokio::test]
    async fn test_health_checker_check_node_registered() {
        let checker = HealthChecker::default();
        let node = NodeId("registered".to_string());
        checker.register_node(node.clone());

        let result = checker.check_node(&node).await;
        assert!(result.is_ok());
        let health = result.expect("check_node failed");
        assert_eq!(health.node_id, node);
        assert_eq!(health.status, HealthState::Unknown);
    }

    #[tokio::test]
    async fn test_health_checker_check_node_unregistered() {
        let checker = HealthChecker::default();
        let node = NodeId("missing".to_string());

        let result = checker.check_node(&node).await;
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            FederationError::NodeUnreachable(_)
        ));
    }

    #[tokio::test]
    async fn test_health_checker_start_stop_monitoring() {
        let checker = HealthChecker::default();

        assert!(!checker.is_monitoring());

        checker.start_monitoring(Duration::from_secs(10)).await;
        assert!(checker.is_monitoring());

        checker.stop_monitoring().await;
        assert!(!checker.is_monitoring());
    }

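    // --- CircuitBreakerConfig tests ---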
    #[test]
    fn test_circuit_breaker_config_default() {
        let config = CircuitBreakerConfig::default();
        assert_eq!(config.failure_threshold, 5);
        assert_eq!(config.reset_timeout, Duration::from_secs(30));
        assert_eq!(config.half_open_successes, 3);
    }

    #[test]
    fn test_circuit_breaker_config_custom() {
        let config = CircuitBreakerConfig {
            failure_threshold: 10,
            reset_timeout: Duration::from_secs(60),
            half_open_successes: 5,
        };
        assert_eq!(config.failure_threshold, 10);
    }

    #[test]
    fn test_circuit_breaker_config_clone() {
        let config = CircuitBreakerConfig::default();
        let cloned = config.clone();
        assert_eq!(cloned.half_open_successes, config.half_open_successes);
    }

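    // --- CircuitBreaker tests ---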
    #[test]
    fn test_circuit_breaker_all_states_empty() {
        let breaker = CircuitBreaker::default();
        let states = breaker.all_states();
        assert!(states.is_empty());
    }

    #[test]
    fn test_circuit_breaker_all_states_multiple() {
        let config = CircuitBreakerConfig {
            failure_threshold: 2,
            ..Default::default()
        };
        let breaker = CircuitBreaker::new(config);

        let n1 = NodeId("n1".to_string());
        let n2 = NodeId("n2".to_string());

        // Two failures trip n1's circuit; n2 stays closed.
        breaker.record_failure(&n1);
        breaker.record_failure(&n1);

        breaker.record_success(&n2);

        let states = breaker.all_states();
        assert_eq!(states.len(), 2);

        let n1_state = states.iter().find(|(id, _)| *id == n1).map(|(_, s)| *s);
        let n2_state = states.iter().find(|(id, _)| *id == n2).map(|(_, s)| *s);

        assert_eq!(n1_state, Some(CircuitState::Open));
        assert_eq!(n2_state, Some(CircuitState::Closed));
    }

    #[test]
    fn test_circuit_breaker_unknown_node_defaults_closed() {
        let breaker = CircuitBreaker::default();
        let unknown = NodeId("unknown".to_string());

        assert_eq!(breaker.state(&unknown), CircuitState::Closed);
        assert!(!breaker.is_open(&unknown));
    }

    #[test]
    fn test_circuit_breaker_record_success_in_open_resets() {
        let config = CircuitBreakerConfig {
            failure_threshold: 2,
            ..Default::default()
        };
        let breaker = CircuitBreaker::new(config);
        let node = NodeId("node".to_string());

        breaker.record_failure(&node);
        breaker.record_failure(&node);
        assert_eq!(breaker.state(&node), CircuitState::Open);

        // A success recorded while open closes the circuit again.
        breaker.record_success(&node);
        assert_eq!(breaker.state(&node), CircuitState::Closed);
    }

    #[test]
    fn test_circuit_breaker_failure_in_half_open_reopens() {
        let config = CircuitBreakerConfig {
            failure_threshold: 2,
            half_open_successes: 3,
            reset_timeout: Duration::from_millis(10),
        };
        let breaker = CircuitBreaker::new(config);
        let node = NodeId("node".to_string());

        breaker.record_failure(&node);
        breaker.record_failure(&node);
        assert_eq!(breaker.state(&node), CircuitState::Open);

        std::thread::sleep(Duration::from_millis(20));

        assert!(!breaker.is_open(&node));
        assert_eq!(breaker.state(&node), CircuitState::HalfOpen);

        // A failed trial request sends the circuit straight back to open.
        breaker.record_failure(&node);
        assert_eq!(breaker.state(&node), CircuitState::Open);
    }

    #[test]
    fn test_circuit_breaker_failure_while_already_open() {
        let config = CircuitBreakerConfig {
            failure_threshold: 2,
            ..Default::default()
        };
        let breaker = CircuitBreaker::new(config);
        let node = NodeId("node".to_string());

        breaker.record_failure(&node);
        breaker.record_failure(&node);
        assert_eq!(breaker.state(&node), CircuitState::Open);

        breaker.record_failure(&node);
        assert_eq!(breaker.state(&node), CircuitState::Open);
    }

    #[test]
    fn test_circuit_breaker_is_open_before_timeout() {
        let config = CircuitBreakerConfig {
            failure_threshold: 2,
            reset_timeout: Duration::from_secs(60),
            ..Default::default()
        };
        let breaker = CircuitBreaker::new(config);
        let node = NodeId("node".to_string());

        breaker.record_failure(&node);
        breaker.record_failure(&node);

        assert!(breaker.is_open(&node));
    }

    #[test]
    fn test_circuit_breaker_half_open_partial_success() {
        let config = CircuitBreakerConfig {
            failure_threshold: 2,
            half_open_successes: 3,
            reset_timeout: Duration::from_millis(10),
        };
        let breaker = CircuitBreaker::new(config);
        let node = NodeId("node".to_string());

        breaker.record_failure(&node);
        breaker.record_failure(&node);
        std::thread::sleep(Duration::from_millis(20));

        assert!(!breaker.is_open(&node));

        // Two of the required three successes: the circuit stays half-open.
        breaker.record_success(&node);
        breaker.record_success(&node);
        assert_eq!(breaker.state(&node), CircuitState::HalfOpen);

        breaker.record_success(&node);
        assert_eq!(breaker.state(&node), CircuitState::Closed);
    }
}