use super::traits::*;
use std::collections::HashMap;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::RwLock;
use std::time::{Duration, Instant};

/// Point-in-time health snapshot for a single node, as exposed to callers.
#[derive(Debug, Clone)]
pub struct HealthStatus {
    pub node_id: NodeId,
    pub state: HealthState,
    pub latency_p50: Duration,
    pub latency_p99: Duration,
    pub queue_depth: u32,
    pub last_updated: Instant,
}

impl From<NodeHealth> for HealthStatus {
    fn from(health: NodeHealth) -> Self {
        Self {
            node_id: health.node_id,
            state: health.status,
            latency_p50: health.latency_p50,
            latency_p99: health.latency_p99,
            queue_depth: health.queue_depth,
            last_updated: health.last_check,
        }
    }
}

/// Tunables for node health tracking.
#[derive(Debug, Clone)]
pub struct HealthConfig {
    /// How often nodes are probed.
    pub check_interval: Duration,
    /// Timeout for a single health probe.
    pub probe_timeout: Duration,
    /// Consecutive failures before a node is marked unhealthy.
    pub failure_threshold: u32,
    /// Consecutive successes before a node is marked healthy again.
    pub recovery_threshold: u32,
    /// Latency above which a successful request still marks the node degraded.
    pub degraded_latency: Duration,
}

impl Default for HealthConfig {
    fn default() -> Self {
        Self {
            check_interval: Duration::from_secs(10),
            probe_timeout: Duration::from_secs(5),
            failure_threshold: 3,
            recovery_threshold: 2,
            degraded_latency: Duration::from_secs(1),
        }
    }
}
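
// Usage sketch: tightening the defaults for a latency-sensitive deployment.
// The specific values here are illustrative, not recommendations.
//
//     let config = HealthConfig {
//         check_interval: Duration::from_secs(5),
//         degraded_latency: Duration::from_millis(250),
//         ..Default::default()
//     };
//     let checker = HealthChecker::new(config);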

/// Internal per-node bookkeeping for the health checker.
#[derive(Debug, Clone)]
struct NodeState {
    health: NodeHealth,
    consecutive_failures: u32,
    consecutive_successes: u32,
}

pub struct HealthChecker {
    /// Thresholds and timing configuration.
    config: HealthConfig,
    states: RwLock<HashMap<NodeId, NodeState>>,
    monitoring: AtomicBool,
}

impl HealthChecker {
    pub fn new(config: HealthConfig) -> Self {
        Self {
            config,
            states: RwLock::new(HashMap::new()),
            monitoring: AtomicBool::new(false),
        }
    }

    pub fn register_node(&self, node_id: NodeId) {
        let mut states = self.states.write().expect("health lock poisoned");

        // Newly registered nodes start as Unknown until a success or failure is reported.
        let health = NodeHealth {
            node_id: node_id.clone(),
            status: HealthState::Unknown,
            latency_p50: Duration::ZERO,
            latency_p99: Duration::ZERO,
            throughput: 0,
            gpu_utilization: None,
            queue_depth: 0,
            last_check: Instant::now(),
        };

        states.insert(
            node_id,
            NodeState {
                health,
                consecutive_failures: 0,
                consecutive_successes: 0,
            },
        );
    }

    pub fn deregister_node(&self, node_id: &NodeId) {
        let mut states = self.states.write().expect("health lock poisoned");
        states.remove(node_id);
    }

    pub fn report_success(&self, node_id: &NodeId, latency: Duration) {
        let mut states = self.states.write().expect("health lock poisoned");

        if let Some(state) = states.get_mut(node_id) {
            state.consecutive_failures = 0;
            state.consecutive_successes += 1;

            // Exponentially weighted moving average: 90% previous estimate, 10% new sample.
            let old_latency = state.health.latency_p50;
            state.health.latency_p50 = Duration::from_millis(
                (old_latency.as_millis() as u64 * 9 + latency.as_millis() as u64) / 10,
            );

            state.health.last_check = Instant::now();

            // A slow response marks the node degraded even though the request succeeded.
            if latency > self.config.degraded_latency {
                state.health.status = HealthState::Degraded;
            } else if state.consecutive_successes >= self.config.recovery_threshold {
                state.health.status = HealthState::Healthy;
            }
        }
    }

    pub fn report_failure(&self, node_id: &NodeId) {
        let mut states = self.states.write().expect("health lock poisoned");

        if let Some(state) = states.get_mut(node_id) {
            state.consecutive_successes = 0;
            state.consecutive_failures += 1;
            state.health.last_check = Instant::now();

            if state.consecutive_failures >= self.config.failure_threshold {
                state.health.status = HealthState::Unhealthy;
            } else {
                state.health.status = HealthState::Degraded;
            }
        }
    }

    pub fn all_statuses(&self) -> Vec<HealthStatus> {
        let states = self.states.read().expect("health lock poisoned");
        states
            .values()
            .map(|s| HealthStatus::from(s.health.clone()))
            .collect()
    }

    pub fn is_monitoring(&self) -> bool {
        self.monitoring.load(Ordering::SeqCst)
    }

    pub fn healthy_count(&self) -> usize {
        let states = self.states.read().expect("health lock poisoned");
        states
            .values()
            .filter(|s| s.health.status == HealthState::Healthy)
            .count()
    }

    pub fn total_count(&self) -> usize {
        let states = self.states.read().expect("health lock poisoned");
        states.len()
    }
}

impl Default for HealthChecker {
    fn default() -> Self {
        Self::new(HealthConfig::default())
    }
}
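
// Usage sketch: the intended request-driven flow. Assumes `NodeId` wraps a
// `String`, as in the tests below; values are illustrative.
//
//     let checker = HealthChecker::default();
//     let node = NodeId("worker-1".to_string());
//     checker.register_node(node.clone());
//     checker.report_success(&node, Duration::from_millis(40));
//     checker.report_success(&node, Duration::from_millis(55));
//     assert_eq!(checker.healthy_count(), 1); // two fast successes => Healthy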

impl HealthCheckerTrait for HealthChecker {
    fn check_node(&self, node_id: &NodeId) -> BoxFuture<'_, FederationResult<NodeHealth>> {
        let node_id = node_id.clone();

        Box::pin(async move {
            let states = self.states.read().expect("health lock poisoned");

            // Returns the most recently recorded health; no network probe is issued here.
            states
                .get(&node_id)
                .map(|s| s.health.clone())
                .ok_or(FederationError::NodeUnreachable(node_id))
        })
    }

    fn get_cached_health(&self, node_id: &NodeId) -> Option<NodeHealth> {
        let states = self.states.read().expect("health lock poisoned");
        states.get(node_id).map(|s| s.health.clone())
    }

    fn start_monitoring(&self, _interval: Duration) -> BoxFuture<'_, ()> {
        Box::pin(async move {
            // Only records that monitoring is active; periodic probing is driven elsewhere.
            self.monitoring.store(true, Ordering::SeqCst);
        })
    }

    fn stop_monitoring(&self) -> BoxFuture<'_, ()> {
        Box::pin(async move {
            self.monitoring.store(false, Ordering::SeqCst);
        })
    }
}

/// Tunables for the per-node circuit breaker.
#[derive(Debug, Clone)]
pub struct CircuitBreakerConfig {
    /// Consecutive failures that trip the circuit open.
    pub failure_threshold: u32,
    /// How long the circuit stays open before a half-open trial is allowed.
    pub reset_timeout: Duration,
    /// Successes required in half-open before the circuit closes again.
    pub half_open_successes: u32,
}

impl Default for CircuitBreakerConfig {
    fn default() -> Self {
        Self {
            failure_threshold: 5,
            reset_timeout: Duration::from_secs(30),
            half_open_successes: 3,
        }
    }
}
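
// Usage sketch: a more aggressive breaker for a flaky backend. Values are
// illustrative only.
//
//     let breaker = CircuitBreaker::new(CircuitBreakerConfig {
//         failure_threshold: 2,
//         reset_timeout: Duration::from_secs(5),
//         half_open_successes: 1,
//     });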

/// Internal per-node breaker state.
#[derive(Debug, Clone)]
struct CircuitBreakerState {
    state: CircuitState,
    failures: u32,
    successes_in_half_open: u32,
    last_failure: Option<Instant>,
}

pub struct CircuitBreaker {
    /// Thresholds and timing configuration.
    config: CircuitBreakerConfig,
    states: RwLock<HashMap<NodeId, CircuitBreakerState>>,
}

impl CircuitBreaker {
    pub fn new(config: CircuitBreakerConfig) -> Self {
        Self {
            config,
            states: RwLock::new(HashMap::new()),
        }
    }

    fn get_or_create_state(&self, node_id: &NodeId) -> CircuitBreakerState {
        let states = self.states.read().expect("circuit breaker lock poisoned");
        states.get(node_id).cloned().unwrap_or(CircuitBreakerState {
            state: CircuitState::Closed,
            failures: 0,
            successes_in_half_open: 0,
            last_failure: None,
        })
    }

    fn update_state(&self, node_id: &NodeId, state: CircuitBreakerState) {
        let mut states = self.states.write().expect("circuit breaker lock poisoned");
        states.insert(node_id.clone(), state);
    }

    pub fn all_states(&self) -> Vec<(NodeId, CircuitState)> {
        let states = self.states.read().expect("circuit breaker lock poisoned");
        states
            .iter()
            .map(|(node_id, state)| (node_id.clone(), state.state))
            .collect()
    }
}

impl Default for CircuitBreaker {
    fn default() -> Self {
        Self::new(CircuitBreakerConfig::default())
    }
}

impl CircuitBreakerTrait for CircuitBreaker {
    fn is_open(&self, node_id: &NodeId) -> bool {
        let state = self.get_or_create_state(node_id);

        match state.state {
            CircuitState::Open => {
                if let Some(last_failure) = state.last_failure {
                    if last_failure.elapsed() >= self.config.reset_timeout {
                        // The reset timeout has elapsed: move to half-open and let a trial request through.
                        let mut new_state = state;
                        new_state.state = CircuitState::HalfOpen;
                        new_state.successes_in_half_open = 0;
                        self.update_state(node_id, new_state);
                        return false;
                    }
                }
                true
            }
            CircuitState::HalfOpen => false,
            CircuitState::Closed => false,
        }
    }

    fn record_success(&self, node_id: &NodeId) {
        let mut state = self.get_or_create_state(node_id);

        match state.state {
            CircuitState::HalfOpen => {
                state.successes_in_half_open += 1;
                if state.successes_in_half_open >= self.config.half_open_successes {
                    // Enough trial successes: close the circuit.
                    state.state = CircuitState::Closed;
                    state.failures = 0;
                    state.successes_in_half_open = 0;
                }
            }
            CircuitState::Closed => {
                // A success while closed clears the failure streak.
                state.failures = 0;
            }
            CircuitState::Open => {
                // A success while open closes the circuit immediately.
                state.state = CircuitState::Closed;
                state.failures = 0;
            }
        }

        self.update_state(node_id, state);
    }

    fn record_failure(&self, node_id: &NodeId) {
        let mut state = self.get_or_create_state(node_id);
        state.failures += 1;
        state.last_failure = Some(Instant::now());

        match state.state {
            CircuitState::Closed => {
                if state.failures >= self.config.failure_threshold {
                    state.state = CircuitState::Open;
                }
            }
            CircuitState::HalfOpen => {
                // Any failure during the half-open trial re-opens the circuit.
                state.state = CircuitState::Open;
                state.successes_in_half_open = 0;
            }
            CircuitState::Open => {
                // Already open; the failure time recorded above restarts the reset timer.
            }
        }

        self.update_state(node_id, state);
    }

    fn state(&self, node_id: &NodeId) -> CircuitState {
        self.get_or_create_state(node_id).state
    }
}
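
// Usage sketch: gating an outbound call on the breaker. `send_request` is a
// hypothetical transport call; substitute whatever issues the actual request.
//
//     if breaker.is_open(&node) {
//         return Err(FederationError::NodeUnreachable(node.clone()));
//     }
//     match send_request(&node) {
//         Ok(_response) => breaker.record_success(&node),
//         Err(_) => breaker.record_failure(&node),
//     }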

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_health_status_transitions() {
        let checker = HealthChecker::default();
        let node = NodeId("test-node".to_string());

        checker.register_node(node.clone());

        // Newly registered nodes start out Unknown.
        let health = checker.get_cached_health(&node).expect("node should exist");
        assert_eq!(health.status, HealthState::Unknown);

        // Enough consecutive fast successes flip the node to Healthy.
        for _ in 0..3 {
            checker.report_success(&node, Duration::from_millis(50));
        }

        let health = checker.get_cached_health(&node).expect("node should exist");
        assert_eq!(health.status, HealthState::Healthy);
    }

    #[test]
    fn test_health_degraded_on_high_latency() {
        let checker = HealthChecker::default();
        let node = NodeId("slow-node".to_string());

        checker.register_node(node.clone());

        // A success above the degraded-latency threshold still marks the node Degraded.
        checker.report_success(&node, Duration::from_secs(2));

        let health = checker.get_cached_health(&node).expect("node should exist");
        assert_eq!(health.status, HealthState::Degraded);
    }

    #[test]
    fn test_health_unhealthy_on_failures() {
        let config = HealthConfig {
            failure_threshold: 3,
            ..Default::default()
        };
        let checker = HealthChecker::new(config);
        let node = NodeId("failing-node".to_string());

        checker.register_node(node.clone());

        // Hitting the failure threshold marks the node Unhealthy.
        for _ in 0..3 {
            checker.report_failure(&node);
        }

        let health = checker.get_cached_health(&node).expect("node should exist");
        assert_eq!(health.status, HealthState::Unhealthy);
    }

    #[test]
    fn test_circuit_breaker_opens_on_failures() {
        let config = CircuitBreakerConfig {
            failure_threshold: 3,
            ..Default::default()
        };
        let breaker = CircuitBreaker::new(config);
        let node = NodeId("failing-node".to_string());

        // Unknown nodes default to a closed circuit.
        assert!(!breaker.is_open(&node));
        assert_eq!(breaker.state(&node), CircuitState::Closed);

        for _ in 0..3 {
            breaker.record_failure(&node);
        }

        assert!(breaker.is_open(&node));
        assert_eq!(breaker.state(&node), CircuitState::Open);
    }

    #[test]
    fn test_circuit_breaker_success_resets() {
        let breaker = CircuitBreaker::default();
        let node = NodeId("flaky-node".to_string());

        breaker.record_failure(&node);
        breaker.record_failure(&node);

        // A success while still closed clears the failure streak.
        breaker.record_success(&node);

        let state = breaker.get_or_create_state(&node);
        assert_eq!(state.failures, 0);
    }

    #[test]
    fn test_circuit_breaker_half_open_recovery() {
        let config = CircuitBreakerConfig {
            failure_threshold: 2,
            half_open_successes: 2,
            reset_timeout: Duration::from_millis(10),
        };
        let breaker = CircuitBreaker::new(config);
        let node = NodeId("recovering-node".to_string());

        breaker.record_failure(&node);
        breaker.record_failure(&node);
        assert_eq!(breaker.state(&node), CircuitState::Open);

        // Wait out the reset timeout so a half-open trial is allowed.
        std::thread::sleep(Duration::from_millis(20));

        // is_open performs the Open -> HalfOpen transition once the timeout has elapsed.
        assert!(!breaker.is_open(&node));
        assert_eq!(breaker.state(&node), CircuitState::HalfOpen);

        // Two successes in half-open close the circuit again.
        breaker.record_success(&node);
        breaker.record_success(&node);
        assert_eq!(breaker.state(&node), CircuitState::Closed);
    }
}