apr_cli/federation/health.rs

//! Health Checker - Monitors node health across the federation
//!
//! Tracks latency, throughput, error rates, and GPU utilization.
//! Supports both active probing and passive health updates.
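//!
//! A minimal usage sketch (illustrative only; `NodeId` and the other shared
//! types come from `super::traits`, which is not shown here):
//!
//! ```ignore
//! let checker = HealthChecker::default();
//! let node = NodeId("gpu-worker-1".to_string());
//! checker.register_node(node.clone());
//! checker.report_success(&node, Duration::from_millis(42));
//! assert_eq!(checker.total_count(), 1);
//! ```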

use super::traits::*;
use std::collections::HashMap;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::RwLock;
use std::time::{Duration, Instant};

// ============================================================================
// Health Status (exported type)
// ============================================================================

/// Health status summary for external consumers
#[derive(Debug, Clone)]
pub struct HealthStatus {
    pub node_id: NodeId,
    pub state: HealthState,
    pub latency_p50: Duration,
    pub latency_p99: Duration,
    pub queue_depth: u32,
    pub last_updated: Instant,
}

impl From<NodeHealth> for HealthStatus {
    fn from(health: NodeHealth) -> Self {
        Self {
            node_id: health.node_id,
            state: health.status,
            latency_p50: health.latency_p50,
            latency_p99: health.latency_p99,
            queue_depth: health.queue_depth,
            last_updated: health.last_check,
        }
    }
}

// ============================================================================
// Health Checker Implementation
// ============================================================================

/// Configuration for health checking
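///
/// Individual fields can be overridden with struct-update syntax; the values
/// below are illustrative only:
///
/// ```ignore
/// let config = HealthConfig {
///     failure_threshold: 5,
///     ..HealthConfig::default()
/// };
/// ```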
#[derive(Debug, Clone)]
pub struct HealthConfig {
    /// How often to check health (default: 10s)
    pub check_interval: Duration,
    /// Timeout for health probes (default: 5s)
    pub probe_timeout: Duration,
    /// Number of failures before marking unhealthy (default: 3)
    pub failure_threshold: u32,
    /// Number of successes to recover from unhealthy (default: 2)
    pub recovery_threshold: u32,
    /// Latency threshold for degraded state (default: 1s)
    pub degraded_latency: Duration,
}

impl Default for HealthConfig {
    fn default() -> Self {
        Self {
            check_interval: Duration::from_secs(10),
            probe_timeout: Duration::from_secs(5),
            failure_threshold: 3,
            recovery_threshold: 2,
            degraded_latency: Duration::from_secs(1),
        }
    }
}

/// Internal tracking state for a node
#[derive(Debug, Clone)]
struct NodeState {
    health: NodeHealth,
    consecutive_failures: u32,
    consecutive_successes: u32,
}

/// In-memory health checker
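///
/// A minimal sketch of passive health reporting (node name and latency values
/// are illustrative):
///
/// ```ignore
/// let checker = HealthChecker::default();
/// let node = NodeId("node-a".to_string());
/// checker.register_node(node.clone());
/// checker.report_success(&node, Duration::from_millis(30));
/// checker.report_failure(&node);
/// println!("{}/{} nodes healthy", checker.healthy_count(), checker.total_count());
/// ```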
pub struct HealthChecker {
    config: HealthConfig,
    states: RwLock<HashMap<NodeId, NodeState>>,
    monitoring: AtomicBool,
}

impl HealthChecker {
    pub fn new(config: HealthConfig) -> Self {
        Self {
            config,
            states: RwLock::new(HashMap::new()),
            monitoring: AtomicBool::new(false),
        }
    }

    /// Register a node for health tracking
    pub fn register_node(&self, node_id: NodeId) {
        let mut states = self.states.write().expect("health lock poisoned");

        let health = NodeHealth {
            node_id: node_id.clone(),
            status: HealthState::Unknown,
            latency_p50: Duration::ZERO,
            latency_p99: Duration::ZERO,
            throughput: 0,
            gpu_utilization: None,
            queue_depth: 0,
            last_check: Instant::now(),
        };

        states.insert(
            node_id,
            NodeState {
                health,
                consecutive_failures: 0,
                consecutive_successes: 0,
            },
        );
    }

    /// Remove a node from health tracking
    pub fn deregister_node(&self, node_id: &NodeId) {
        let mut states = self.states.write().expect("health lock poisoned");
        states.remove(node_id);
    }

    /// Report a successful request (passive health update)
    pub fn report_success(&self, node_id: &NodeId, latency: Duration) {
        let mut states = self.states.write().expect("health lock poisoned");

        if let Some(state) = states.get_mut(node_id) {
            state.consecutive_failures = 0;
            state.consecutive_successes += 1;

            // Update the latency estimate with an exponentially weighted
            // moving average (90% previous estimate, 10% new sample)
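            // For example (hypothetical numbers): a previous estimate of
            // 100 ms and a new 200 ms sample give (100 * 9 + 200) / 10 = 110 ms.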
            let old_latency = state.health.latency_p50;
            state.health.latency_p50 = Duration::from_millis(
                (old_latency.as_millis() as u64 * 9 + latency.as_millis() as u64) / 10,
            );

            state.health.last_check = Instant::now();

            // Update status based on latency
            if latency > self.config.degraded_latency {
                state.health.status = HealthState::Degraded;
            } else if state.consecutive_successes >= self.config.recovery_threshold {
                state.health.status = HealthState::Healthy;
            }
        }
    }

    /// Report a failed request (passive health update)
    pub fn report_failure(&self, node_id: &NodeId) {
        let mut states = self.states.write().expect("health lock poisoned");

        if let Some(state) = states.get_mut(node_id) {
            state.consecutive_successes = 0;
            state.consecutive_failures += 1;
            state.health.last_check = Instant::now();

            if state.consecutive_failures >= self.config.failure_threshold {
                state.health.status = HealthState::Unhealthy;
            } else {
                state.health.status = HealthState::Degraded;
            }
        }
    }

    /// Get all health statuses
    pub fn all_statuses(&self) -> Vec<HealthStatus> {
        let states = self.states.read().expect("health lock poisoned");
        states
            .values()
            .map(|s| HealthStatus::from(s.health.clone()))
            .collect()
    }

    /// Check if monitoring is active
    pub fn is_monitoring(&self) -> bool {
        self.monitoring.load(Ordering::SeqCst)
    }

    /// Get count of healthy nodes
    pub fn healthy_count(&self) -> usize {
        let states = self.states.read().expect("health lock poisoned");
        states
            .values()
            .filter(|s| s.health.status == HealthState::Healthy)
            .count()
    }

    /// Get total node count
    pub fn total_count(&self) -> usize {
        let states = self.states.read().expect("health lock poisoned");
        states.len()
    }
}

impl Default for HealthChecker {
    fn default() -> Self {
        Self::new(HealthConfig::default())
    }
}

impl HealthCheckerTrait for HealthChecker {
    fn check_node(&self, node_id: &NodeId) -> BoxFuture<'_, FederationResult<NodeHealth>> {
        let node_id = node_id.clone();

        Box::pin(async move {
            // In production, this would make an HTTP/gRPC health probe.
            // For now, return the cached state, or an error if the node is not registered.
            let states = self.states.read().expect("health lock poisoned");

            states
                .get(&node_id)
                .map(|s| s.health.clone())
                .ok_or(FederationError::NodeUnreachable(node_id))
        })
    }

    fn get_cached_health(&self, node_id: &NodeId) -> Option<NodeHealth> {
        let states = self.states.read().expect("health lock poisoned");
        states.get(node_id).map(|s| s.health.clone())
    }

    fn start_monitoring(&self, _interval: Duration) -> BoxFuture<'_, ()> {
        Box::pin(async move {
            self.monitoring.store(true, Ordering::SeqCst);
            // In production, this would spawn a background task
            // that periodically probes all registered nodes
        })
    }

    fn stop_monitoring(&self) -> BoxFuture<'_, ()> {
        Box::pin(async move {
            self.monitoring.store(false, Ordering::SeqCst);
        })
    }
}

// ============================================================================
// Circuit Breaker Implementation
// ============================================================================
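//
// Summary of the state machine implemented below:
//   Closed   --(failures >= failure_threshold)------> Open
//   Open     --(reset_timeout elapsed, next check)--> HalfOpen
//   HalfOpen --(half_open_successes successes)------> Closed
//   HalfOpen --(any failure)------------------------> Open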

/// Circuit breaker configuration
#[derive(Debug, Clone)]
pub struct CircuitBreakerConfig {
    /// Failures to open circuit (default: 5)
    pub failure_threshold: u32,
    /// Time before half-open probe (default: 30s)
    pub reset_timeout: Duration,
    /// Successes in half-open to close (default: 3)
    pub half_open_successes: u32,
}

impl Default for CircuitBreakerConfig {
    fn default() -> Self {
        Self {
            failure_threshold: 5,
            reset_timeout: Duration::from_secs(30),
            half_open_successes: 3,
        }
    }
}

/// Per-node circuit state
#[derive(Debug, Clone)]
struct CircuitBreakerState {
    state: CircuitState,
    failures: u32,
    successes_in_half_open: u32,
    last_failure: Option<Instant>,
}

/// Circuit breaker implementation
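///
/// Typical call pattern (a sketch; the request-sending step is hypothetical):
///
/// ```ignore
/// let breaker = CircuitBreaker::default();
/// let node = NodeId("node-a".to_string());
/// if !breaker.is_open(&node) {
///     // ... send the request to `node` ...
///     breaker.record_success(&node); // or record_failure(&node) on error
/// }
/// ```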
pub struct CircuitBreaker {
    config: CircuitBreakerConfig,
    states: RwLock<HashMap<NodeId, CircuitBreakerState>>,
}

impl CircuitBreaker {
    pub fn new(config: CircuitBreakerConfig) -> Self {
        Self {
            config,
            states: RwLock::new(HashMap::new()),
        }
    }

    /// Return the tracked state for `node_id`, or a default `Closed` state if
    /// the node has not been seen yet. Nothing is inserted into the map until
    /// `update_state` is called.
    fn get_or_create_state(&self, node_id: &NodeId) -> CircuitBreakerState {
        let states = self.states.read().expect("circuit breaker lock poisoned");
        states.get(node_id).cloned().unwrap_or(CircuitBreakerState {
            state: CircuitState::Closed,
            failures: 0,
            successes_in_half_open: 0,
            last_failure: None,
        })
    }

    fn update_state(&self, node_id: &NodeId, state: CircuitBreakerState) {
        let mut states = self.states.write().expect("circuit breaker lock poisoned");
        states.insert(node_id.clone(), state);
    }

    /// Get all circuit breaker states
    pub fn all_states(&self) -> Vec<(NodeId, CircuitState)> {
        let states = self.states.read().expect("circuit breaker lock poisoned");
        states
            .iter()
            .map(|(node_id, state)| (node_id.clone(), state.state))
            .collect()
    }
}

impl Default for CircuitBreaker {
    fn default() -> Self {
        Self::new(CircuitBreakerConfig::default())
    }
}

impl CircuitBreakerTrait for CircuitBreaker {
    fn is_open(&self, node_id: &NodeId) -> bool {
        let state = self.get_or_create_state(node_id);

        match state.state {
            CircuitState::Open => {
                // Check if enough time has passed to try half-open
                if let Some(last_failure) = state.last_failure {
                    if last_failure.elapsed() >= self.config.reset_timeout {
                        // Transition to half-open
                        let mut new_state = state;
                        new_state.state = CircuitState::HalfOpen;
                        new_state.successes_in_half_open = 0;
                        self.update_state(node_id, new_state);
                        return false; // Allow one request through
                    }
                }
                true // Still open
            }
            CircuitState::HalfOpen => false, // Allow probe requests
            CircuitState::Closed => false,
        }
    }

    fn record_success(&self, node_id: &NodeId) {
        let mut state = self.get_or_create_state(node_id);

        match state.state {
            CircuitState::HalfOpen => {
                state.successes_in_half_open += 1;
                if state.successes_in_half_open >= self.config.half_open_successes {
                    // Transition back to closed
                    state.state = CircuitState::Closed;
                    state.failures = 0;
                    state.successes_in_half_open = 0;
                }
            }
            CircuitState::Closed => {
                // Reset failure counter on success
                state.failures = 0;
            }
            CircuitState::Open => {
                // Shouldn't happen, but reset just in case
                state.state = CircuitState::Closed;
                state.failures = 0;
            }
        }

        self.update_state(node_id, state);
    }

    fn record_failure(&self, node_id: &NodeId) {
        let mut state = self.get_or_create_state(node_id);
        state.failures += 1;
        state.last_failure = Some(Instant::now());

        match state.state {
            CircuitState::Closed => {
                if state.failures >= self.config.failure_threshold {
                    state.state = CircuitState::Open;
                }
            }
            CircuitState::HalfOpen => {
                // Any failure in half-open goes back to open
                state.state = CircuitState::Open;
                state.successes_in_half_open = 0;
            }
            CircuitState::Open => {
                // Already open, just update last_failure time
            }
        }

        self.update_state(node_id, state);
    }

    fn state(&self, node_id: &NodeId) -> CircuitState {
        self.get_or_create_state(node_id).state
    }
}

// ============================================================================
// Tests
// ============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_health_status_transitions() {
        let checker = HealthChecker::default();
        let node = NodeId("test-node".to_string());

        checker.register_node(node.clone());

        // Initially unknown
        let health = checker.get_cached_health(&node).expect("node should exist");
        assert_eq!(health.status, HealthState::Unknown);

        // Report successes to become healthy
        for _ in 0..3 {
            checker.report_success(&node, Duration::from_millis(50));
        }

        let health = checker.get_cached_health(&node).expect("node should exist");
        assert_eq!(health.status, HealthState::Healthy);
    }

    #[test]
    fn test_health_degraded_on_high_latency() {
        let checker = HealthChecker::default();
        let node = NodeId("slow-node".to_string());

        checker.register_node(node.clone());

        // Report high latency
        checker.report_success(&node, Duration::from_secs(2));

        let health = checker.get_cached_health(&node).expect("node should exist");
        assert_eq!(health.status, HealthState::Degraded);
    }

    #[test]
    fn test_health_unhealthy_on_failures() {
        let config = HealthConfig {
            failure_threshold: 3,
            ..Default::default()
        };
        let checker = HealthChecker::new(config);
        let node = NodeId("failing-node".to_string());

        checker.register_node(node.clone());

        // Report failures
        for _ in 0..3 {
            checker.report_failure(&node);
        }

        let health = checker.get_cached_health(&node).expect("node should exist");
        assert_eq!(health.status, HealthState::Unhealthy);
    }
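
    // Sketch of a check for `deregister_node`: after removal, the node no
    // longer appears in counts or cached health.
    #[test]
    fn test_deregister_removes_node() {
        let checker = HealthChecker::default();
        let node = NodeId("gone-node".to_string());

        checker.register_node(node.clone());
        assert_eq!(checker.total_count(), 1);

        checker.deregister_node(&node);
        assert_eq!(checker.total_count(), 0);
        assert!(checker.get_cached_health(&node).is_none());
    }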

    #[test]
    fn test_circuit_breaker_opens_on_failures() {
        let config = CircuitBreakerConfig {
            failure_threshold: 3,
            ..Default::default()
        };
        let breaker = CircuitBreaker::new(config);
        let node = NodeId("failing-node".to_string());

        // Initially closed
        assert!(!breaker.is_open(&node));
        assert_eq!(breaker.state(&node), CircuitState::Closed);

        // Record failures
        for _ in 0..3 {
            breaker.record_failure(&node);
        }

        assert!(breaker.is_open(&node));
        assert_eq!(breaker.state(&node), CircuitState::Open);
    }

    #[test]
    fn test_circuit_breaker_success_resets() {
        let breaker = CircuitBreaker::default();
        let node = NodeId("flaky-node".to_string());

        // Some failures
        breaker.record_failure(&node);
        breaker.record_failure(&node);

        // Success should reset
        breaker.record_success(&node);

        let state = breaker.get_or_create_state(&node);
        assert_eq!(state.failures, 0);
    }

    #[test]
    fn test_circuit_breaker_half_open_recovery() {
        let config = CircuitBreakerConfig {
            failure_threshold: 2,
            half_open_successes: 2,
            reset_timeout: Duration::from_millis(10),
        };
        let breaker = CircuitBreaker::new(config);
        let node = NodeId("recovering-node".to_string());

        // Open the circuit
        breaker.record_failure(&node);
        breaker.record_failure(&node);
        assert_eq!(breaker.state(&node), CircuitState::Open);

        // Wait for reset timeout
        std::thread::sleep(Duration::from_millis(20));

        // Should transition to half-open on next check
        assert!(!breaker.is_open(&node));
        assert_eq!(breaker.state(&node), CircuitState::HalfOpen);

        // Successes in half-open should close
        breaker.record_success(&node);
        breaker.record_success(&node);
        assert_eq!(breaker.state(&node), CircuitState::Closed);
    }
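
    // Sketch of the complementary path: a failure while half-open sends the
    // circuit straight back to open (mirrors the HalfOpen branch in
    // `record_failure`).
    #[test]
    fn test_circuit_breaker_half_open_failure_reopens() {
        let config = CircuitBreakerConfig {
            failure_threshold: 2,
            half_open_successes: 2,
            reset_timeout: Duration::from_millis(10),
        };
        let breaker = CircuitBreaker::new(config);
        let node = NodeId("relapsing-node".to_string());

        // Open the circuit, then wait out the reset timeout.
        breaker.record_failure(&node);
        breaker.record_failure(&node);
        std::thread::sleep(Duration::from_millis(20));

        // The next check moves the circuit to half-open.
        assert!(!breaker.is_open(&node));
        assert_eq!(breaker.state(&node), CircuitState::HalfOpen);

        // A single failure in half-open re-opens it.
        breaker.record_failure(&node);
        assert_eq!(breaker.state(&node), CircuitState::Open);
    }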
}