
apr_cli/federation/health.rs

//! Health Checker - Monitors node health across the federation
//!
//! Tracks latency, throughput, error rates, and GPU utilization.
//! Supports both active probing and passive health updates.

use super::traits::*;
use std::collections::HashMap;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::RwLock;
use std::time::{Duration, Instant};

// ============================================================================
// Health Status (exported type)
// ============================================================================

/// Health status summary for external consumers
#[derive(Debug, Clone)]
pub struct HealthStatus {
    pub node_id: NodeId,
    pub state: HealthState,
    pub latency_p50: Duration,
    pub latency_p99: Duration,
    pub queue_depth: u32,
    pub last_updated: Instant,
}

impl From<NodeHealth> for HealthStatus {
    fn from(health: NodeHealth) -> Self {
        Self {
            node_id: health.node_id,
            state: health.status,
            latency_p50: health.latency_p50,
            latency_p99: health.latency_p99,
            queue_depth: health.queue_depth,
            last_updated: health.last_check,
        }
    }
}

// ============================================================================
// Health Checker Implementation
// ============================================================================

/// Configuration for health checking
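///
/// A minimal sketch of overriding one threshold while keeping the remaining
/// defaults (field names and defaults are the ones declared below; the
/// example is `ignore`d because the crate's public re-export path isn't
/// shown in this file):
///
/// ```ignore
/// let config = HealthConfig {
///     failure_threshold: 5,
///     ..Default::default()
/// };
/// assert_eq!(config.recovery_threshold, 2); // untouched default
/// ```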
#[derive(Debug, Clone)]
pub struct HealthConfig {
    /// How often to check health (default: 10s)
    pub check_interval: Duration,
    /// Timeout for health probes (default: 5s)
    pub probe_timeout: Duration,
    /// Number of failures before marking unhealthy (default: 3)
    pub failure_threshold: u32,
    /// Number of successes to recover from unhealthy (default: 2)
    pub recovery_threshold: u32,
    /// Latency threshold for degraded state (default: 1s)
    pub degraded_latency: Duration,
}

impl Default for HealthConfig {
    fn default() -> Self {
        Self {
            check_interval: Duration::from_secs(10),
            probe_timeout: Duration::from_secs(5),
            failure_threshold: 3,
            recovery_threshold: 2,
            degraded_latency: Duration::from_secs(1),
        }
    }
}

/// Internal tracking state for a node
#[derive(Debug, Clone)]
struct NodeState {
    health: NodeHealth,
    consecutive_failures: u32,
    consecutive_successes: u32,
}

/// In-memory health checker
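///
/// A minimal passive-update sketch using only the methods defined below; the
/// node name is arbitrary, and the example is `ignore`d because the crate's
/// public re-export path isn't shown in this file:
///
/// ```ignore
/// let checker = HealthChecker::default();
/// let node = NodeId("node-a".to_string());
///
/// checker.register_node(node.clone()); // starts as HealthState::Unknown
/// checker.report_success(&node, Duration::from_millis(40));
/// checker.report_success(&node, Duration::from_millis(45));
///
/// // Two consecutive fast successes meet the default recovery_threshold of 2.
/// assert_eq!(checker.healthy_count(), 1);
/// ```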
pub struct HealthChecker {
    config: HealthConfig,
    states: RwLock<HashMap<NodeId, NodeState>>,
    monitoring: AtomicBool,
}

impl HealthChecker {
    pub fn new(config: HealthConfig) -> Self {
        Self {
            config,
            states: RwLock::new(HashMap::new()),
            monitoring: AtomicBool::new(false),
        }
    }

    /// Register a node for health tracking
    pub fn register_node(&self, node_id: NodeId) {
        let mut states = self.states.write().expect("health lock poisoned");

        let health = NodeHealth {
            node_id: node_id.clone(),
            status: HealthState::Unknown,
            latency_p50: Duration::ZERO,
            latency_p99: Duration::ZERO,
            throughput: 0,
            gpu_utilization: None,
            queue_depth: 0,
            last_check: Instant::now(),
        };

        states.insert(
            node_id,
            NodeState {
                health,
                consecutive_failures: 0,
                consecutive_successes: 0,
            },
        );
    }

    /// Remove a node from health tracking
    pub fn deregister_node(&self, node_id: &NodeId) {
        let mut states = self.states.write().expect("health lock poisoned");
        states.remove(node_id);
    }

    /// Report a successful request (passive health update)
    pub fn report_success(&self, node_id: &NodeId, latency: Duration) {
        let mut states = self.states.write().expect("health lock poisoned");

        if let Some(state) = states.get_mut(node_id) {
            state.consecutive_failures = 0;
            state.consecutive_successes += 1;

            // Update the latency estimate with an exponentially weighted
            // moving average (9 parts previous estimate, 1 part new sample)
            let old_latency = state.health.latency_p50;
            state.health.latency_p50 = Duration::from_millis(
                (old_latency.as_millis() as u64 * 9 + latency.as_millis() as u64) / 10,
            );
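            // e.g. a previous estimate of 100 ms and a new 200 ms sample give (100 * 9 + 200) / 10 = 110 ms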

            state.health.last_check = Instant::now();

            // Update status based on latency
            if latency > self.config.degraded_latency {
                state.health.status = HealthState::Degraded;
            } else if state.consecutive_successes >= self.config.recovery_threshold {
                state.health.status = HealthState::Healthy;
            }
        }
    }

    /// Report a failed request (passive health update)
    pub fn report_failure(&self, node_id: &NodeId) {
        let mut states = self.states.write().expect("health lock poisoned");

        if let Some(state) = states.get_mut(node_id) {
            state.consecutive_successes = 0;
            state.consecutive_failures += 1;
            state.health.last_check = Instant::now();

            if state.consecutive_failures >= self.config.failure_threshold {
                state.health.status = HealthState::Unhealthy;
            } else {
                state.health.status = HealthState::Degraded;
            }
        }
    }

    /// Get all health statuses
    pub fn all_statuses(&self) -> Vec<HealthStatus> {
        let states = self.states.read().expect("health lock poisoned");
        states
            .values()
            .map(|s| HealthStatus::from(s.health.clone()))
            .collect()
    }

    /// Check if monitoring is active
    pub fn is_monitoring(&self) -> bool {
        self.monitoring.load(Ordering::SeqCst)
    }

    /// Get count of healthy nodes
    pub fn healthy_count(&self) -> usize {
        let states = self.states.read().expect("health lock poisoned");
        states
            .values()
            .filter(|s| s.health.status == HealthState::Healthy)
            .count()
    }

    /// Get total node count
    pub fn total_count(&self) -> usize {
        let states = self.states.read().expect("health lock poisoned");
        states.len()
    }
}

impl Default for HealthChecker {
    fn default() -> Self {
        Self::new(HealthConfig::default())
    }
}

impl HealthCheckerTrait for HealthChecker {
    fn check_node(&self, node_id: &NodeId) -> BoxFuture<'_, FederationResult<NodeHealth>> {
        let node_id = node_id.clone();

        Box::pin(async move {
            // In production, this would make an HTTP/gRPC health probe
            // For now, return cached state or Unknown
            let states = self.states.read().expect("health lock poisoned");

            states
                .get(&node_id)
                .map(|s| s.health.clone())
                .ok_or(FederationError::NodeUnreachable(node_id))
        })
    }

    fn get_cached_health(&self, node_id: &NodeId) -> Option<NodeHealth> {
        let states = self.states.read().expect("health lock poisoned");
        states.get(node_id).map(|s| s.health.clone())
    }

    fn start_monitoring(&self, _interval: Duration) -> BoxFuture<'_, ()> {
        Box::pin(async move {
            self.monitoring.store(true, Ordering::SeqCst);
            // In production, this would spawn a background task
            // that periodically probes all registered nodes
        })
    }

    fn stop_monitoring(&self) -> BoxFuture<'_, ()> {
        Box::pin(async move {
            self.monitoring.store(false, Ordering::SeqCst);
        })
    }
}

// ============================================================================
// Circuit Breaker Implementation
// ============================================================================

/// Circuit breaker configuration
#[derive(Debug, Clone)]
pub struct CircuitBreakerConfig {
    /// Failures to open circuit (default: 5)
    pub failure_threshold: u32,
    /// Time before half-open probe (default: 30s)
    pub reset_timeout: Duration,
    /// Successes in half-open to close (default: 3)
    pub half_open_successes: u32,
}

impl Default for CircuitBreakerConfig {
    fn default() -> Self {
        Self {
            failure_threshold: 5,
            reset_timeout: Duration::from_secs(30),
            half_open_successes: 3,
        }
    }
}

/// Per-node circuit state
#[derive(Debug, Clone)]
struct CircuitBreakerState {
    state: CircuitState,
    failures: u32,
    successes_in_half_open: u32,
    last_failure: Option<Instant>,
}

/// Circuit breaker implementation
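///
/// A minimal sketch of the failure/recovery flow driven through the
/// `CircuitBreakerTrait` methods implemented below; the node name is
/// arbitrary, and the example is `ignore`d because the crate's public
/// re-export path isn't shown in this file:
///
/// ```ignore
/// let breaker = CircuitBreaker::default();
/// let node = NodeId("node-a".to_string());
///
/// // Five failures (the default failure_threshold) open the circuit.
/// for _ in 0..5 {
///     breaker.record_failure(&node);
/// }
/// assert!(breaker.is_open(&node));
///
/// // Once reset_timeout has elapsed, the next is_open() call lets a probe
/// // through in the half-open state; half_open_successes successes close it.
/// ```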
pub struct CircuitBreaker {
    config: CircuitBreakerConfig,
    states: RwLock<HashMap<NodeId, CircuitBreakerState>>,
}

impl CircuitBreaker {
    pub fn new(config: CircuitBreakerConfig) -> Self {
        Self {
            config,
            states: RwLock::new(HashMap::new()),
        }
    }

    fn get_or_create_state(&self, node_id: &NodeId) -> CircuitBreakerState {
        let states = self.states.read().expect("circuit breaker lock poisoned");
        states.get(node_id).cloned().unwrap_or(CircuitBreakerState {
            state: CircuitState::Closed,
            failures: 0,
            successes_in_half_open: 0,
            last_failure: None,
        })
    }

    fn update_state(&self, node_id: &NodeId, state: CircuitBreakerState) {
        let mut states = self.states.write().expect("circuit breaker lock poisoned");
        states.insert(node_id.clone(), state);
    }

    /// Get all circuit breaker states
    pub fn all_states(&self) -> Vec<(NodeId, CircuitState)> {
        let states = self.states.read().expect("circuit breaker lock poisoned");
        states
            .iter()
            .map(|(node_id, state)| (node_id.clone(), state.state))
            .collect()
    }
}

impl Default for CircuitBreaker {
    fn default() -> Self {
        Self::new(CircuitBreakerConfig::default())
    }
}

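// State transitions implemented by the trait methods below:
//   Closed   -> Open     after `failure_threshold` consecutive failures
//   Open     -> HalfOpen once `reset_timeout` has elapsed since the last
//                        failure (checked lazily inside `is_open`)
//   HalfOpen -> Closed   after `half_open_successes` successes
//   HalfOpen -> Open     on any failure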
impl CircuitBreakerTrait for CircuitBreaker {
    fn is_open(&self, node_id: &NodeId) -> bool {
        let state = self.get_or_create_state(node_id);

        match state.state {
            CircuitState::Open => {
                // Check if enough time has passed to try half-open
                if let Some(last_failure) = state.last_failure {
                    if last_failure.elapsed() >= self.config.reset_timeout {
                        // Transition to half-open
                        let mut new_state = state;
                        new_state.state = CircuitState::HalfOpen;
                        new_state.successes_in_half_open = 0;
                        self.update_state(node_id, new_state);
                        return false; // Allow one request through
                    }
                }
                true // Still open
            }
            CircuitState::HalfOpen => false, // Allow probe requests
            CircuitState::Closed => false,
        }
    }

    fn record_success(&self, node_id: &NodeId) {
        let mut state = self.get_or_create_state(node_id);

        match state.state {
            CircuitState::HalfOpen => {
                state.successes_in_half_open += 1;
                if state.successes_in_half_open >= self.config.half_open_successes {
                    // Transition back to closed
                    state.state = CircuitState::Closed;
                    state.failures = 0;
                    state.successes_in_half_open = 0;
                }
            }
            CircuitState::Closed => {
                // Reset failure counter on success
                state.failures = 0;
            }
            CircuitState::Open => {
                // Shouldn't happen, but reset just in case
                state.state = CircuitState::Closed;
                state.failures = 0;
            }
        }

        self.update_state(node_id, state);
    }

    fn record_failure(&self, node_id: &NodeId) {
        let mut state = self.get_or_create_state(node_id);
        state.failures += 1;
        state.last_failure = Some(Instant::now());

        match state.state {
            CircuitState::Closed => {
                if state.failures >= self.config.failure_threshold {
                    state.state = CircuitState::Open;
                }
            }
            CircuitState::HalfOpen => {
                // Any failure in half-open goes back to open
                state.state = CircuitState::Open;
                state.successes_in_half_open = 0;
            }
            CircuitState::Open => {
                // Already open, just update last_failure time
            }
        }

        self.update_state(node_id, state);
    }

    fn state(&self, node_id: &NodeId) -> CircuitState {
        self.get_or_create_state(node_id).state
    }
}

// ============================================================================
// Tests
// ============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_health_status_transitions() {
        let checker = HealthChecker::default();
        let node = NodeId("test-node".to_string());

        checker.register_node(node.clone());

        // Initially unknown
        let health = checker.get_cached_health(&node).expect("node should exist");
        assert_eq!(health.status, HealthState::Unknown);

        // Report successes to become healthy
        for _ in 0..3 {
            checker.report_success(&node, Duration::from_millis(50));
        }

        let health = checker.get_cached_health(&node).expect("node should exist");
        assert_eq!(health.status, HealthState::Healthy);
    }

    #[test]
    fn test_health_degraded_on_high_latency() {
        let checker = HealthChecker::default();
        let node = NodeId("slow-node".to_string());

        checker.register_node(node.clone());

        // Report high latency
        checker.report_success(&node, Duration::from_secs(2));

        let health = checker.get_cached_health(&node).expect("node should exist");
        assert_eq!(health.status, HealthState::Degraded);
    }

    #[test]
    fn test_health_unhealthy_on_failures() {
        let config = HealthConfig {
            failure_threshold: 3,
            ..Default::default()
        };
        let checker = HealthChecker::new(config);
        let node = NodeId("failing-node".to_string());

        checker.register_node(node.clone());

        // Report failures
        for _ in 0..3 {
            checker.report_failure(&node);
        }

        let health = checker.get_cached_health(&node).expect("node should exist");
        assert_eq!(health.status, HealthState::Unhealthy);
    }

    #[test]
    fn test_circuit_breaker_opens_on_failures() {
        let config = CircuitBreakerConfig {
            failure_threshold: 3,
            ..Default::default()
        };
        let breaker = CircuitBreaker::new(config);
        let node = NodeId("failing-node".to_string());

        // Initially closed
        assert!(!breaker.is_open(&node));
        assert_eq!(breaker.state(&node), CircuitState::Closed);

        // Record failures
        for _ in 0..3 {
            breaker.record_failure(&node);
        }

        assert!(breaker.is_open(&node));
        assert_eq!(breaker.state(&node), CircuitState::Open);
    }

    #[test]
    fn test_circuit_breaker_success_resets() {
        let breaker = CircuitBreaker::default();
        let node = NodeId("flaky-node".to_string());

        // Some failures
        breaker.record_failure(&node);
        breaker.record_failure(&node);

        // Success should reset
        breaker.record_success(&node);

        let state = breaker.get_or_create_state(&node);
        assert_eq!(state.failures, 0);
    }

    #[test]
    fn test_circuit_breaker_half_open_recovery() {
        let config = CircuitBreakerConfig {
            failure_threshold: 2,
            half_open_successes: 2,
            reset_timeout: Duration::from_millis(10),
        };
        let breaker = CircuitBreaker::new(config);
        let node = NodeId("recovering-node".to_string());

        // Open the circuit
        breaker.record_failure(&node);
        breaker.record_failure(&node);
        assert_eq!(breaker.state(&node), CircuitState::Open);

        // Wait for reset timeout
        std::thread::sleep(Duration::from_millis(20));

        // Should transition to half-open on next check
        assert!(!breaker.is_open(&node));
        assert_eq!(breaker.state(&node), CircuitState::HalfOpen);

        // Successes in half-open should close
        breaker.record_success(&node);
        breaker.record_success(&node);
        assert_eq!(breaker.state(&node), CircuitState::Closed);
    }

    // =========================================================================
    // HealthConfig tests
    // =========================================================================

    #[test]
    fn test_health_config_default() {
        let config = HealthConfig::default();
        assert_eq!(config.check_interval, Duration::from_secs(10));
        assert_eq!(config.probe_timeout, Duration::from_secs(5));
        assert_eq!(config.failure_threshold, 3);
        assert_eq!(config.recovery_threshold, 2);
        assert_eq!(config.degraded_latency, Duration::from_secs(1));
    }

    #[test]
    fn test_health_config_custom() {
        let config = HealthConfig {
            check_interval: Duration::from_secs(30),
            probe_timeout: Duration::from_secs(10),
            failure_threshold: 5,
            recovery_threshold: 3,
            degraded_latency: Duration::from_millis(500),
        };
        assert_eq!(config.check_interval, Duration::from_secs(30));
        assert_eq!(config.failure_threshold, 5);
    }

    #[test]
    fn test_health_config_clone() {
        let config = HealthConfig::default();
        let cloned = config.clone();
        assert_eq!(cloned.failure_threshold, config.failure_threshold);
    }

    // =========================================================================
    // HealthChecker extended tests
    // =========================================================================

    #[test]
    fn test_health_checker_register_and_deregister() {
        let checker = HealthChecker::default();
        let node = NodeId("temp-node".to_string());

        checker.register_node(node.clone());
        assert_eq!(checker.total_count(), 1);

        checker.deregister_node(&node);
        assert_eq!(checker.total_count(), 0);
    }

    #[test]
    fn test_health_checker_deregister_unknown_node() {
        let checker = HealthChecker::default();
        let unknown = NodeId("unknown".to_string());

        // Deregistering a non-existent node should not panic
        checker.deregister_node(&unknown);
        assert_eq!(checker.total_count(), 0);
    }

    #[test]
    fn test_health_checker_all_statuses_empty() {
        let checker = HealthChecker::default();
        let statuses = checker.all_statuses();
        assert!(statuses.is_empty());
    }

    #[test]
    fn test_health_checker_all_statuses_multiple() {
        let checker = HealthChecker::default();
        checker.register_node(NodeId("n1".to_string()));
        checker.register_node(NodeId("n2".to_string()));
        checker.register_node(NodeId("n3".to_string()));

        let statuses = checker.all_statuses();
        assert_eq!(statuses.len(), 3);
    }

    #[test]
    fn test_health_checker_healthy_count_none() {
        let checker = HealthChecker::default();
        checker.register_node(NodeId("n1".to_string()));
        // All nodes start Unknown, not Healthy
        assert_eq!(checker.healthy_count(), 0);
    }

    #[test]
    fn test_health_checker_healthy_count_some() {
        let checker = HealthChecker::default();
        let n1 = NodeId("n1".to_string());
        let n2 = NodeId("n2".to_string());
        checker.register_node(n1.clone());
        checker.register_node(n2.clone());

        // Make n1 healthy
        for _ in 0..3 {
            checker.report_success(&n1, Duration::from_millis(10));
        }
        // Leave n2 as Unknown

        assert_eq!(checker.healthy_count(), 1);
        assert_eq!(checker.total_count(), 2);
    }

    #[test]
    fn test_health_checker_is_monitoring_default() {
        let checker = HealthChecker::default();
        assert!(!checker.is_monitoring());
    }

    #[test]
    fn test_health_checker_report_success_unknown_node() {
        let checker = HealthChecker::default();
        let unknown = NodeId("unknown".to_string());

        // Reporting on unknown node should not panic
        checker.report_success(&unknown, Duration::from_millis(50));
        assert_eq!(checker.total_count(), 0);
    }

    #[test]
    fn test_health_checker_report_failure_unknown_node() {
        let checker = HealthChecker::default();
        let unknown = NodeId("unknown".to_string());

        // Reporting on unknown node should not panic
        checker.report_failure(&unknown);
        assert_eq!(checker.total_count(), 0);
    }

    #[test]
    fn test_health_checker_degraded_then_healthy() {
        let config = HealthConfig {
            recovery_threshold: 2,
            ..Default::default()
        };
        let checker = HealthChecker::new(config);
        let node = NodeId("recovering".to_string());

        checker.register_node(node.clone());

        // First failure -> degraded
        checker.report_failure(&node);
        let health = checker.get_cached_health(&node).expect("should exist");
        assert_eq!(health.status, HealthState::Degraded);

        // Successful recoveries
        for _ in 0..2 {
            checker.report_success(&node, Duration::from_millis(50));
        }
        let health = checker.get_cached_health(&node).expect("should exist");
        assert_eq!(health.status, HealthState::Healthy);
    }

    #[test]
    fn test_health_checker_failure_below_threshold_is_degraded() {
        let config = HealthConfig {
            failure_threshold: 5,
            ..Default::default()
        };
        let checker = HealthChecker::new(config);
        let node = NodeId("flaky".to_string());
        checker.register_node(node.clone());

        // 3 failures (below threshold of 5)
        for _ in 0..3 {
            checker.report_failure(&node);
        }
        let health = checker.get_cached_health(&node).expect("should exist");
        assert_eq!(health.status, HealthState::Degraded);

        // 2 more failures (reach threshold of 5)
        for _ in 0..2 {
            checker.report_failure(&node);
        }
        let health = checker.get_cached_health(&node).expect("should exist");
        assert_eq!(health.status, HealthState::Unhealthy);
    }

    #[test]
    fn test_health_checker_latency_moving_average() {
        let checker = HealthChecker::default();
        let node = NodeId("avg-node".to_string());
        checker.register_node(node.clone());

        // Report several successes with known latency
        for _ in 0..10 {
            checker.report_success(&node, Duration::from_millis(100));
        }

        let health = checker.get_cached_health(&node).expect("should exist");
        // After enough iterations, p50 should approach 100ms
        assert!(health.latency_p50.as_millis() > 0);
    }

    #[test]
    fn test_health_checker_get_cached_health_none() {
        let checker = HealthChecker::default();
        let unknown = NodeId("no-such-node".to_string());
        assert!(checker.get_cached_health(&unknown).is_none());
    }

    // =========================================================================
    // HealthStatus From<NodeHealth> tests
    // =========================================================================

    #[test]
    fn test_health_status_from_node_health() {
        let now = Instant::now();
        let health = NodeHealth {
            node_id: NodeId("test".to_string()),
            status: HealthState::Degraded,
            latency_p50: Duration::from_millis(100),
            latency_p99: Duration::from_millis(500),
            throughput: 200,
            gpu_utilization: Some(0.5),
            queue_depth: 5,
            last_check: now,
        };

        let status = HealthStatus::from(health);
        assert_eq!(status.node_id, NodeId("test".to_string()));
        assert_eq!(status.state, HealthState::Degraded);
        assert_eq!(status.latency_p50, Duration::from_millis(100));
        assert_eq!(status.latency_p99, Duration::from_millis(500));
        assert_eq!(status.queue_depth, 5);
    }

    // =========================================================================
    // HealthCheckerTrait async tests
    // =========================================================================

    #[tokio::test]
    async fn test_health_checker_check_node_registered() {
        let checker = HealthChecker::default();
        let node = NodeId("registered".to_string());
        checker.register_node(node.clone());

        let result = checker.check_node(&node).await;
        assert!(result.is_ok());
        let health = result.expect("check_node failed");
        assert_eq!(health.node_id, node);
        assert_eq!(health.status, HealthState::Unknown);
    }

    #[tokio::test]
    async fn test_health_checker_check_node_unregistered() {
        let checker = HealthChecker::default();
        let node = NodeId("missing".to_string());

        let result = checker.check_node(&node).await;
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            FederationError::NodeUnreachable(_)
        ));
    }

    #[tokio::test]
    async fn test_health_checker_start_stop_monitoring() {
        let checker = HealthChecker::default();

        assert!(!checker.is_monitoring());

        checker.start_monitoring(Duration::from_secs(10)).await;
        assert!(checker.is_monitoring());

        checker.stop_monitoring().await;
        assert!(!checker.is_monitoring());
    }

    // =========================================================================
    // CircuitBreakerConfig tests
    // =========================================================================

    #[test]
    fn test_circuit_breaker_config_default() {
        let config = CircuitBreakerConfig::default();
        assert_eq!(config.failure_threshold, 5);
        assert_eq!(config.reset_timeout, Duration::from_secs(30));
        assert_eq!(config.half_open_successes, 3);
    }

    #[test]
    fn test_circuit_breaker_config_custom() {
        let config = CircuitBreakerConfig {
            failure_threshold: 10,
            reset_timeout: Duration::from_secs(60),
            half_open_successes: 5,
        };
        assert_eq!(config.failure_threshold, 10);
    }

    #[test]
    fn test_circuit_breaker_config_clone() {
        let config = CircuitBreakerConfig::default();
        let cloned = config.clone();
        assert_eq!(cloned.half_open_successes, config.half_open_successes);
    }

    // =========================================================================
    // CircuitBreaker extended tests
    // =========================================================================

    #[test]
    fn test_circuit_breaker_all_states_empty() {
        let breaker = CircuitBreaker::default();
        let states = breaker.all_states();
        assert!(states.is_empty());
    }

    #[test]
    fn test_circuit_breaker_all_states_multiple() {
        let config = CircuitBreakerConfig {
            failure_threshold: 2,
            ..Default::default()
        };
        let breaker = CircuitBreaker::new(config);

        let n1 = NodeId("n1".to_string());
        let n2 = NodeId("n2".to_string());

        // n1: make it open
        breaker.record_failure(&n1);
        breaker.record_failure(&n1);

        // n2: keep it closed
        breaker.record_success(&n2);

        let states = breaker.all_states();
        assert_eq!(states.len(), 2);

        let n1_state = states.iter().find(|(id, _)| *id == n1).map(|(_, s)| *s);
        let n2_state = states.iter().find(|(id, _)| *id == n2).map(|(_, s)| *s);

        assert_eq!(n1_state, Some(CircuitState::Open));
        assert_eq!(n2_state, Some(CircuitState::Closed));
    }

    #[test]
    fn test_circuit_breaker_unknown_node_defaults_closed() {
        let breaker = CircuitBreaker::default();
        let unknown = NodeId("unknown".to_string());

        assert_eq!(breaker.state(&unknown), CircuitState::Closed);
        assert!(!breaker.is_open(&unknown));
    }

    #[test]
    fn test_circuit_breaker_record_success_in_open_resets() {
        let config = CircuitBreakerConfig {
            failure_threshold: 2,
            ..Default::default()
        };
        let breaker = CircuitBreaker::new(config);
        let node = NodeId("node".to_string());

        // Open circuit
        breaker.record_failure(&node);
        breaker.record_failure(&node);
        assert_eq!(breaker.state(&node), CircuitState::Open);

        // Record success while open (edge case)
        breaker.record_success(&node);
        // Should reset to closed
        assert_eq!(breaker.state(&node), CircuitState::Closed);
    }

    #[test]
    fn test_circuit_breaker_failure_in_half_open_reopens() {
        let config = CircuitBreakerConfig {
            failure_threshold: 2,
            half_open_successes: 3,
            reset_timeout: Duration::from_millis(10),
        };
        let breaker = CircuitBreaker::new(config);
        let node = NodeId("node".to_string());

        // Open circuit
        breaker.record_failure(&node);
        breaker.record_failure(&node);
        assert_eq!(breaker.state(&node), CircuitState::Open);

        // Wait for reset
        std::thread::sleep(Duration::from_millis(20));

        // Transition to half-open
        assert!(!breaker.is_open(&node));
        assert_eq!(breaker.state(&node), CircuitState::HalfOpen);

        // Failure in half-open -> back to open
        breaker.record_failure(&node);
        assert_eq!(breaker.state(&node), CircuitState::Open);
    }

    #[test]
    fn test_circuit_breaker_failure_while_already_open() {
        let config = CircuitBreakerConfig {
            failure_threshold: 2,
            ..Default::default()
        };
        let breaker = CircuitBreaker::new(config);
        let node = NodeId("node".to_string());

        // Open circuit
        breaker.record_failure(&node);
        breaker.record_failure(&node);
        assert_eq!(breaker.state(&node), CircuitState::Open);

        // Additional failure while open should not panic
        breaker.record_failure(&node);
        assert_eq!(breaker.state(&node), CircuitState::Open);
    }

    #[test]
    fn test_circuit_breaker_is_open_before_timeout() {
        let config = CircuitBreakerConfig {
            failure_threshold: 2,
            reset_timeout: Duration::from_secs(60), // Long timeout
            ..Default::default()
        };
        let breaker = CircuitBreaker::new(config);
        let node = NodeId("node".to_string());

        breaker.record_failure(&node);
        breaker.record_failure(&node);

        // Still open (timeout hasn't passed)
        assert!(breaker.is_open(&node));
    }

    #[test]
    fn test_circuit_breaker_half_open_partial_success() {
        let config = CircuitBreakerConfig {
            failure_threshold: 2,
            half_open_successes: 3,
            reset_timeout: Duration::from_millis(10),
        };
        let breaker = CircuitBreaker::new(config);
        let node = NodeId("node".to_string());

        // Open circuit
        breaker.record_failure(&node);
        breaker.record_failure(&node);
        std::thread::sleep(Duration::from_millis(20));

        // Transition to half-open
        assert!(!breaker.is_open(&node));

        // 2 successes (need 3)
        breaker.record_success(&node);
        breaker.record_success(&node);
        assert_eq!(breaker.state(&node), CircuitState::HalfOpen);

        // 3rd success -> closed
        breaker.record_success(&node);
        assert_eq!(breaker.state(&node), CircuitState::Closed);
    }
}