use super::traits::*;
use std::collections::HashMap;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::RwLock;
use std::time::{Duration, Instant};
/// Flattened view of a node's health record, as returned by
/// `HealthChecker::all_statuses`.
///
/// Built from a `NodeHealth` via the `From` impl below; it carries a subset of
/// that record's fields, some under different names.
#[derive(Debug, Clone)]
pub struct HealthStatus {
/// Identifier of the node this status describes.
pub node_id: NodeId,
/// Health classification, copied from `NodeHealth::status`.
pub state: HealthState,
/// Copied from `NodeHealth::latency_p50`.
/// NOTE(review): the name suggests a median, but `HealthChecker` maintains
/// this value as a 9:1 weighted running average — confirm intended meaning.
pub latency_p50: Duration,
/// Copied from `NodeHealth::latency_p99`. Nothing in this file updates it
/// after registration, where it starts at zero.
pub latency_p99: Duration,
/// Copied from `NodeHealth::queue_depth`.
pub queue_depth: u32,
/// Copied from `NodeHealth::last_check`.
pub last_updated: Instant,
}
impl From<NodeHealth> for HealthStatus {
fn from(health: NodeHealth) -> Self {
Self {
node_id: health.node_id,
state: health.status,
latency_p50: health.latency_p50,
latency_p99: health.latency_p99,
queue_depth: health.queue_depth,
last_updated: health.last_check,
}
}
}
/// Tunables for `HealthChecker`.
#[derive(Debug, Clone)]
pub struct HealthConfig {
/// NOTE(review): not read anywhere in this file — presumably the intended
/// period between active health checks; confirm against callers.
pub check_interval: Duration,
/// NOTE(review): not read anywhere in this file — presumably the time
/// budget for a single probe; confirm against callers.
pub probe_timeout: Duration,
/// Number of consecutive reported failures at which a node is marked
/// `HealthState::Unhealthy` (below that, a failure marks it `Degraded`).
pub failure_threshold: u32,
/// Number of consecutive reported successes at which a node is marked
/// `HealthState::Healthy`.
pub recovery_threshold: u32,
/// A reported success whose latency is strictly greater than this marks the
/// node `HealthState::Degraded`.
pub degraded_latency: Duration,
}
impl Default for HealthConfig {
fn default() -> Self {
Self {
check_interval: Duration::from_secs(10),
probe_timeout: Duration::from_secs(5),
failure_threshold: 3,
recovery_threshold: 2,
degraded_latency: Duration::from_secs(1),
}
}
}
/// Per-node bookkeeping kept inside `HealthChecker`.
#[derive(Debug, Clone)]
struct NodeState {
/// Latest health record for the node.
health: NodeHealth,
/// Failures reported since the last success; reset to 0 by a success.
consecutive_failures: u32,
/// Successes reported since the last failure; reset to 0 by a failure.
consecutive_successes: u32,
}
/// Tracks per-node health from success/failure reports supplied by callers.
///
/// All state lives behind an `RwLock`, so every method takes `&self`.
pub struct HealthChecker {
/// Thresholds used when classifying nodes.
config: HealthConfig,
/// Health record and streak counters for each registered node.
states: RwLock<HashMap<NodeId, NodeState>>,
/// Flag toggled by `start_monitoring` / `stop_monitoring` and read by
/// `is_monitoring`. No background task is spawned in this file.
monitoring: AtomicBool,
}
impl HealthChecker {
    /// Creates a checker with no registered nodes and the monitoring flag off.
    pub fn new(config: HealthConfig) -> Self {
        Self {
            config,
            states: RwLock::new(HashMap::new()),
            monitoring: AtomicBool::new(false),
        }
    }

    /// Starts tracking `node_id` with a fresh record: status `Unknown`, zero
    /// latencies, zero counters.
    ///
    /// Registering an id that is already tracked replaces its record, which
    /// discards the previous status and streak counters.
    ///
    /// # Panics
    /// Panics if the internal lock is poisoned.
    pub fn register_node(&self, node_id: NodeId) {
        // Build the record before taking the lock to keep the critical section short.
        let health = NodeHealth {
            node_id: node_id.clone(),
            status: HealthState::Unknown,
            latency_p50: Duration::ZERO,
            latency_p99: Duration::ZERO,
            throughput: 0,
            gpu_utilization: None,
            queue_depth: 0,
            last_check: Instant::now(),
        };
        let entry = NodeState {
            health,
            consecutive_failures: 0,
            consecutive_successes: 0,
        };

        let mut states = self.states.write().expect("health lock poisoned");
        states.insert(node_id, entry);
    }

    /// Stops tracking `node_id`. Unknown ids are ignored.
    ///
    /// # Panics
    /// Panics if the internal lock is poisoned.
    pub fn deregister_node(&self, node_id: &NodeId) {
        let mut states = self.states.write().expect("health lock poisoned");
        states.remove(node_id);
    }

    /// Records a successful interaction with `node_id` that took `latency`.
    ///
    /// Effects on a registered node (unregistered ids are ignored):
    /// * the failure streak is cleared and the success streak grows by one;
    /// * `latency_p50` becomes `(9 * previous + latency) / 10`;
    /// * `last_check` is set to now;
    /// * status becomes `Degraded` if `latency` exceeds
    ///   `config.degraded_latency`; otherwise `Healthy` once the success
    ///   streak reaches `config.recovery_threshold`; otherwise it is left
    ///   unchanged.
    ///
    /// # Panics
    /// Panics if the internal lock is poisoned.
    pub fn report_success(&self, node_id: &NodeId, latency: Duration) {
        let mut states = self.states.write().expect("health lock poisoned");
        if let Some(state) = states.get_mut(node_id) {
            state.consecutive_failures = 0;
            // Saturate rather than overflow on a very long success streak.
            state.consecutive_successes = state.consecutive_successes.saturating_add(1);

            // Average on `Duration` directly (nanosecond precision). Doing this
            // in whole milliseconds truncated every sub-10 ms sample to zero
            // when starting from zero, so fast nodes never moved the average.
            state.health.latency_p50 = state
                .health
                .latency_p50
                .saturating_mul(9)
                .saturating_add(latency)
                / 10;
            state.health.last_check = Instant::now();

            if latency > self.config.degraded_latency {
                state.health.status = HealthState::Degraded;
            } else if state.consecutive_successes >= self.config.recovery_threshold {
                state.health.status = HealthState::Healthy;
            }
        }
    }

    /// Records a failed interaction with `node_id`.
    ///
    /// On a registered node (unregistered ids are ignored) this clears the
    /// success streak, grows the failure streak by one, sets `last_check` to
    /// now, and sets the status to `Unhealthy` once the failure streak reaches
    /// `config.failure_threshold`, or `Degraded` before that.
    ///
    /// # Panics
    /// Panics if the internal lock is poisoned.
    pub fn report_failure(&self, node_id: &NodeId) {
        let mut states = self.states.write().expect("health lock poisoned");
        if let Some(state) = states.get_mut(node_id) {
            state.consecutive_successes = 0;
            // Saturate rather than overflow if a node keeps failing indefinitely.
            state.consecutive_failures = state.consecutive_failures.saturating_add(1);
            state.health.last_check = Instant::now();

            state.health.status = if state.consecutive_failures >= self.config.failure_threshold {
                HealthState::Unhealthy
            } else {
                HealthState::Degraded
            };
        }
    }

    /// Returns a status snapshot for every registered node, in the map's
    /// (unspecified) iteration order.
    ///
    /// # Panics
    /// Panics if the internal lock is poisoned.
    pub fn all_statuses(&self) -> Vec<HealthStatus> {
        let states = self.states.read().expect("health lock poisoned");
        states
            .values()
            .map(|s| HealthStatus::from(s.health.clone()))
            .collect()
    }

    /// Returns the current value of the monitoring flag.
    pub fn is_monitoring(&self) -> bool {
        self.monitoring.load(Ordering::SeqCst)
    }

    /// Number of registered nodes whose status is exactly `Healthy`.
    ///
    /// # Panics
    /// Panics if the internal lock is poisoned.
    pub fn healthy_count(&self) -> usize {
        let states = self.states.read().expect("health lock poisoned");
        states
            .values()
            .filter(|s| s.health.status == HealthState::Healthy)
            .count()
    }

    /// Number of registered nodes, regardless of status.
    ///
    /// # Panics
    /// Panics if the internal lock is poisoned.
    pub fn total_count(&self) -> usize {
        let states = self.states.read().expect("health lock poisoned");
        states.len()
    }
}
impl Default for HealthChecker {
fn default() -> Self {
Self::new(HealthConfig::default())
}
}
impl HealthCheckerTrait for HealthChecker {
    /// Resolves to the stored health record for `node_id`.
    ///
    /// The future only reads the in-memory table; it does not contact the
    /// node. It yields `FederationError::NodeUnreachable` when the id is not
    /// registered.
    ///
    /// # Panics
    /// Polling the future panics if the internal lock is poisoned.
    fn check_node(&self, node_id: &NodeId) -> BoxFuture<'_, FederationResult<NodeHealth>> {
        // The future must own its key, so clone before moving into it.
        let node_id = node_id.clone();
        Box::pin(async move {
            let table = self.states.read().expect("health lock poisoned");
            match table.get(&node_id) {
                Some(entry) => Ok(entry.health.clone()),
                None => Err(FederationError::NodeUnreachable(node_id)),
            }
        })
    }

    /// Returns a copy of the stored health record, or `None` for an
    /// unregistered id.
    ///
    /// # Panics
    /// Panics if the internal lock is poisoned.
    fn get_cached_health(&self, node_id: &NodeId) -> Option<NodeHealth> {
        let table = self.states.read().expect("health lock poisoned");
        let entry = table.get(node_id)?;
        Some(entry.health.clone())
    }

    /// Returns a future that sets the monitoring flag when polled.
    ///
    /// `_interval` is ignored and no background task is started here.
    fn start_monitoring(&self, _interval: Duration) -> BoxFuture<'_, ()> {
        let flag = &self.monitoring;
        Box::pin(async move { flag.store(true, Ordering::SeqCst) })
    }

    /// Returns a future that clears the monitoring flag when polled.
    fn stop_monitoring(&self) -> BoxFuture<'_, ()> {
        let flag = &self.monitoring;
        Box::pin(async move { flag.store(false, Ordering::SeqCst) })
    }
}
/// Tunables for `CircuitBreaker`.
#[derive(Debug, Clone)]
pub struct CircuitBreakerConfig {
/// Failure count at which a closed circuit is switched to open.
pub failure_threshold: u32,
/// Time since the most recent failure after which `is_open` moves an open
/// circuit to half-open.
pub reset_timeout: Duration,
/// Number of successes recorded while half-open that closes the circuit.
pub half_open_successes: u32,
}
impl Default for CircuitBreakerConfig {
fn default() -> Self {
Self {
failure_threshold: 5,
reset_timeout: Duration::from_secs(30),
half_open_successes: 3,
}
}
}
/// Per-node circuit bookkeeping kept inside `CircuitBreaker`.
#[derive(Debug, Clone)]
struct CircuitBreakerState {
/// Current position of the circuit (closed, open, or half-open).
state: CircuitState,
/// Failure counter: bumped by every recorded failure and reset to 0 by a
/// success while closed, or whenever the circuit closes.
failures: u32,
/// Successes recorded since the circuit last entered half-open.
successes_in_half_open: u32,
/// When the most recent failure was recorded; `None` until the first one.
last_failure: Option<Instant>,
}
/// Per-node circuit breaker driven by success/failure reports from callers.
///
/// State lives behind an `RwLock`, so every method takes `&self`.
pub struct CircuitBreaker {
/// Thresholds and timeout governing state transitions.
config: CircuitBreakerConfig,
/// Circuit bookkeeping for each node that has had a state stored; nodes
/// absent from the map are treated as closed.
states: RwLock<HashMap<NodeId, CircuitBreakerState>>,
}
impl CircuitBreaker {
    /// Creates a breaker that tracks no nodes yet.
    pub fn new(config: CircuitBreakerConfig) -> Self {
        Self {
            config,
            states: RwLock::new(HashMap::new()),
        }
    }

    /// State assumed for a node with no stored entry: closed, zero counters,
    /// no failure recorded.
    fn initial_state() -> CircuitBreakerState {
        CircuitBreakerState {
            state: CircuitState::Closed,
            failures: 0,
            successes_in_half_open: 0,
            last_failure: None,
        }
    }

    /// Returns a snapshot of the node's state, or the initial closed state if
    /// the node has no entry. Despite the name, nothing is inserted.
    ///
    /// The snapshot is detached from the map: use it for read-only decisions
    /// only, and go through `with_state_mut` for any read-modify-write.
    fn get_or_create_state(&self, node_id: &NodeId) -> CircuitBreakerState {
        let states = self.states.read().expect("circuit breaker lock poisoned");
        states
            .get(node_id)
            .cloned()
            .unwrap_or_else(Self::initial_state)
    }

    /// Applies `mutate` to the node's entry while holding the write lock,
    /// inserting the initial closed state first if the node has no entry.
    ///
    /// Keeping the read, the modification and the write under one lock is what
    /// prevents concurrent reports from overwriting each other's updates.
    fn with_state_mut<F>(&self, node_id: &NodeId, mutate: F)
    where
        F: FnOnce(&mut CircuitBreakerState),
    {
        let mut states = self.states.write().expect("circuit breaker lock poisoned");
        let entry = states
            .entry(node_id.clone())
            .or_insert_with(Self::initial_state);
        mutate(entry);
    }

    /// True when the circuit is open and at least `reset_timeout` has passed
    /// since the last recorded failure.
    fn reset_due(&self, state: &CircuitBreakerState) -> bool {
        matches!(state.state, CircuitState::Open)
            && state
                .last_failure
                .map_or(false, |at| at.elapsed() >= self.config.reset_timeout)
    }

    /// Returns the circuit position of every node that has a stored entry, in
    /// the map's (unspecified) iteration order.
    ///
    /// # Panics
    /// Panics if the internal lock is poisoned.
    pub fn all_states(&self) -> Vec<(NodeId, CircuitState)> {
        let states = self.states.read().expect("circuit breaker lock poisoned");
        states
            .iter()
            .map(|(node_id, state)| (node_id.clone(), state.state))
            .collect()
    }
}

impl Default for CircuitBreaker {
    /// Builds a breaker configured with `CircuitBreakerConfig::default()`.
    fn default() -> Self {
        Self::new(CircuitBreakerConfig::default())
    }
}

impl CircuitBreakerTrait for CircuitBreaker {
    /// Reports whether calls to `node_id` should currently be rejected.
    ///
    /// Closed and half-open circuits (and nodes with no entry) return `false`.
    /// An open circuit returns `true` until `reset_timeout` has elapsed since
    /// the last failure; the first call after that moves the circuit to
    /// half-open, clears the half-open success counter, and returns `false`.
    ///
    /// # Panics
    /// Panics if the internal lock is poisoned.
    fn is_open(&self, node_id: &NodeId) -> bool {
        // Fast path: a read lock settles every case except "open and due for reset".
        let snapshot = self.get_or_create_state(node_id);
        match snapshot.state {
            CircuitState::Closed | CircuitState::HalfOpen => return false,
            CircuitState::Open => {
                if !self.reset_due(&snapshot) {
                    return true;
                }
            }
        }

        // Slow path: re-check under the write lock, since another thread may
        // have recorded a new failure or already performed the transition
        // between the snapshot above and this point.
        let mut states = self.states.write().expect("circuit breaker lock poisoned");
        match states.get_mut(node_id) {
            Some(state) => {
                if self.reset_due(state) {
                    state.state = CircuitState::HalfOpen;
                    state.successes_in_half_open = 0;
                    false
                } else {
                    matches!(state.state, CircuitState::Open)
                }
            }
            None => false,
        }
    }

    /// Records a successful call to `node_id`.
    ///
    /// * Half-open: counts the success and closes the circuit (resetting both
    ///   counters) once `half_open_successes` is reached.
    /// * Closed: resets the failure counter.
    /// * Open: closes the circuit immediately and resets the failure counter.
    ///
    /// A node with no entry gets one in the closed state.
    ///
    /// # Panics
    /// Panics if the internal lock is poisoned.
    fn record_success(&self, node_id: &NodeId) {
        let needed = self.config.half_open_successes;
        self.with_state_mut(node_id, |state| match state.state {
            CircuitState::HalfOpen => {
                state.successes_in_half_open = state.successes_in_half_open.saturating_add(1);
                if state.successes_in_half_open >= needed {
                    state.state = CircuitState::Closed;
                    state.failures = 0;
                    state.successes_in_half_open = 0;
                }
            }
            CircuitState::Closed => {
                state.failures = 0;
            }
            CircuitState::Open => {
                state.state = CircuitState::Closed;
                state.failures = 0;
            }
        });
    }

    /// Records a failed call to `node_id`.
    ///
    /// Always bumps the failure counter and refreshes the last-failure
    /// timestamp (so failures reported while open push the reset window out).
    ///
    /// * Closed: opens the circuit once `failure_threshold` is reached.
    /// * Half-open: re-opens the circuit and clears the half-open counter.
    /// * Open: no further state change.
    ///
    /// # Panics
    /// Panics if the internal lock is poisoned.
    fn record_failure(&self, node_id: &NodeId) {
        let threshold = self.config.failure_threshold;
        self.with_state_mut(node_id, |state| {
            // Saturate: the counter keeps growing while the circuit stays open.
            state.failures = state.failures.saturating_add(1);
            state.last_failure = Some(Instant::now());
            match state.state {
                CircuitState::Closed => {
                    if state.failures >= threshold {
                        state.state = CircuitState::Open;
                    }
                }
                CircuitState::HalfOpen => {
                    state.state = CircuitState::Open;
                    state.successes_in_half_open = 0;
                }
                CircuitState::Open => {}
            }
        });
    }

    /// Returns the stored circuit position for `node_id` (closed if the node
    /// has no entry). Unlike `is_open`, this never triggers the open to
    /// half-open transition.
    ///
    /// # Panics
    /// Panics if the internal lock is poisoned.
    fn state(&self, node_id: &NodeId) -> CircuitState {
        self.get_or_create_state(node_id).state
    }
}
// Unit tests for this module live in `health_tests.rs` (located via `#[path]`)
// and are compiled only for test builds.
#[cfg(test)]
#[path = "health_tests.rs"]
mod tests;