use parking_lot::RwLock;
use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
use crate::error::{Result, RingKernelError};
use crate::runtime::KernelId;
/// Coarse health classification reported by a check.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum HealthStatus {
    Healthy,
    Degraded,
    Unhealthy,
    Unknown,
}
impl HealthStatus {
    /// Whether the subject can still serve traffic: `Healthy` and
    /// `Degraded` both count as healthy; `Unknown` does not.
    pub fn is_healthy(&self) -> bool {
        match self {
            HealthStatus::Healthy | HealthStatus::Degraded => true,
            HealthStatus::Unhealthy | HealthStatus::Unknown => false,
        }
    }
    /// True only for `Unhealthy`; `Unknown` is neither healthy nor
    /// unhealthy.
    pub fn is_unhealthy(&self) -> bool {
        *self == HealthStatus::Unhealthy
    }
}
/// Outcome of a single run of a health check.
#[derive(Debug, Clone)]
pub struct HealthCheckResult {
    /// Name of the check that produced this result.
    pub name: String,
    /// Status the check reported.
    pub status: HealthStatus,
    /// Optional human-readable detail; most checks leave this unset.
    pub message: Option<String>,
    /// How long the check took to run.
    pub duration: Duration,
    /// When the run finished.
    pub checked_at: Instant,
}
/// Shared async predicate forming the body of a [`HealthCheck`]: each call
/// produces a future resolving to the current [`HealthStatus`].
pub type HealthCheckFn =
    Arc<dyn Fn() -> Pin<Box<dyn Future<Output = HealthStatus> + Send>> + Send + Sync>;
/// A named async health check with liveness/readiness classification and a
/// per-run timeout, caching its most recent result.
pub struct HealthCheck {
    /// Human-readable identifier, copied into each result.
    pub name: String,
    // The async predicate executed by `check()`.
    check_fn: HealthCheckFn,
    /// Whether this check participates in liveness probing.
    pub is_liveness: bool,
    /// Whether this check participates in readiness probing.
    pub is_readiness: bool,
    /// Per-run time budget (defaults to 5s in `new`).
    pub timeout: Duration,
    // Most recent result, if the check has run at least once.
    last_result: RwLock<Option<HealthCheckResult>>,
}
impl HealthCheck {
pub fn new(name: impl Into<String>, check_fn: HealthCheckFn) -> Self {
Self {
name: name.into(),
check_fn,
is_liveness: false,
is_readiness: false,
timeout: Duration::from_secs(5),
last_result: RwLock::new(None),
}
}
pub fn liveness(mut self) -> Self {
self.is_liveness = true;
self
}
pub fn readiness(mut self) -> Self {
self.is_readiness = true;
self
}
pub fn timeout(mut self, timeout: Duration) -> Self {
self.timeout = timeout;
self
}
pub async fn check(&self) -> HealthCheckResult {
let start = Instant::now();
let status = (self.check_fn)().await;
let duration = start.elapsed();
let result = HealthCheckResult {
name: self.name.clone(),
status,
message: None,
duration,
checked_at: Instant::now(),
};
*self.last_result.write() = Some(result.clone());
result
}
pub fn last_result(&self) -> Option<HealthCheckResult> {
self.last_result.read().clone()
}
}
/// Registry of [`HealthCheck`]s with aggregate liveness/readiness queries.
pub struct HealthChecker {
    // Registered checks; cloned out of the lock before running so user
    // futures never execute while the lock is held.
    checks: RwLock<Vec<Arc<HealthCheck>>>,
    // Intended background polling cadence; no poller is spawned in this
    // file, so the field is currently unread.
    #[allow(dead_code)]
    check_interval: Duration,
    // Flag for a background polling loop; not driven by code in this file.
    #[allow(dead_code)]
    running: std::sync::atomic::AtomicBool,
}
impl HealthChecker {
pub fn new() -> Arc<Self> {
Arc::new(Self {
checks: RwLock::new(Vec::new()),
check_interval: Duration::from_secs(10),
running: std::sync::atomic::AtomicBool::new(false),
})
}
pub fn with_interval(self: Arc<Self>, interval: Duration) -> Arc<Self> {
let _ = interval;
self
}
pub fn register(&self, check: HealthCheck) {
self.checks.write().push(Arc::new(check));
}
pub fn register_liveness<F, Fut>(&self, name: impl Into<String>, check_fn: F)
where
F: Fn() -> Fut + Send + Sync + 'static,
Fut: Future<Output = bool> + Send + 'static,
{
let name = name.into();
let check = HealthCheck::new(
name,
Arc::new(move || {
let fut = check_fn();
Box::pin(async move {
if fut.await {
HealthStatus::Healthy
} else {
HealthStatus::Unhealthy
}
}) as Pin<Box<dyn Future<Output = HealthStatus> + Send>>
}),
)
.liveness();
self.register(check);
}
pub fn register_readiness<F, Fut>(&self, name: impl Into<String>, check_fn: F)
where
F: Fn() -> Fut + Send + Sync + 'static,
Fut: Future<Output = bool> + Send + 'static,
{
let name = name.into();
let check = HealthCheck::new(
name,
Arc::new(move || {
let fut = check_fn();
Box::pin(async move {
if fut.await {
HealthStatus::Healthy
} else {
HealthStatus::Unhealthy
}
}) as Pin<Box<dyn Future<Output = HealthStatus> + Send>>
}),
)
.readiness();
self.register(check);
}
pub async fn check_all(&self) -> Vec<HealthCheckResult> {
let checks = self.checks.read().clone();
let mut results = Vec::with_capacity(checks.len());
for check in checks {
results.push(check.check().await);
}
results
}
pub async fn check_liveness(&self) -> Vec<HealthCheckResult> {
let checks = self.checks.read().clone();
let mut results = Vec::new();
for check in checks.iter().filter(|c| c.is_liveness) {
results.push(check.check().await);
}
results
}
pub async fn check_readiness(&self) -> Vec<HealthCheckResult> {
let checks = self.checks.read().clone();
let mut results = Vec::new();
for check in checks.iter().filter(|c| c.is_readiness) {
results.push(check.check().await);
}
results
}
pub async fn is_alive(&self) -> bool {
let results = self.check_liveness().await;
results.iter().all(|r| r.status.is_healthy())
}
pub async fn is_ready(&self) -> bool {
let results = self.check_readiness().await;
results.iter().all(|r| r.status.is_healthy())
}
pub async fn aggregate_status(&self) -> HealthStatus {
let results = self.check_all().await;
if results.is_empty() {
return HealthStatus::Unknown;
}
let all_healthy = results.iter().all(|r| r.status == HealthStatus::Healthy);
let any_unhealthy = results.iter().any(|r| r.status == HealthStatus::Unhealthy);
if all_healthy {
HealthStatus::Healthy
} else if any_unhealthy {
HealthStatus::Unhealthy
} else {
HealthStatus::Degraded
}
}
pub fn check_count(&self) -> usize {
self.checks.read().len()
}
}
impl Default for HealthChecker {
    /// Same initial state as `HealthChecker::new`, but without the `Arc`
    /// wrapper (`new` returns `Arc<Self>`, so it cannot be reused here).
    fn default() -> Self {
        Self {
            checks: RwLock::new(Vec::new()),
            check_interval: Duration::from_secs(10),
            running: std::sync::atomic::AtomicBool::new(false),
        }
    }
}
/// State of a [`CircuitBreaker`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CircuitState {
    /// Requests flow normally; failures are being counted.
    Closed,
    /// Requests are rejected until the recovery timeout elapses.
    Open,
    /// A limited number of probe requests are allowed to test recovery.
    HalfOpen,
}
/// Tunables for a [`CircuitBreaker`].
#[derive(Debug, Clone)]
pub struct CircuitBreakerConfig {
    /// Consecutive failures (while `Closed`) that open the circuit.
    pub failure_threshold: u32,
    /// Successes (while `HalfOpen`) required to close the circuit.
    pub success_threshold: u32,
    /// How long the circuit stays `Open` before probing (`HalfOpen`).
    pub recovery_timeout: Duration,
    /// Sliding failure window. NOTE(review): not consulted anywhere in
    /// this file — presumably reserved for windowed failure counting.
    pub window_duration: Duration,
    /// Maximum concurrent probe requests while `HalfOpen`.
    pub half_open_max_requests: u32,
}
impl Default for CircuitBreakerConfig {
    /// Defaults: open after 5 failures, close after 3 probe successes,
    /// 30s recovery timeout, 60s window, 3 concurrent probes.
    fn default() -> Self {
        Self {
            failure_threshold: 5,
            success_threshold: 3,
            recovery_timeout: Duration::from_secs(30),
            window_duration: Duration::from_secs(60),
            half_open_max_requests: 3,
        }
    }
}
/// Classic three-state circuit breaker guarding an unreliable operation.
pub struct CircuitBreaker {
    // Thresholds and timeouts; immutable after construction.
    config: CircuitBreakerConfig,
    // Current state; `Open -> HalfOpen` happens lazily inside `state()`.
    state: RwLock<CircuitState>,
    // Consecutive failures observed while Closed.
    failure_count: AtomicU32,
    // Probe successes observed while HalfOpen.
    success_count: AtomicU32,
    // When the circuit last opened; None while Closed.
    opened_at: RwLock<Option<Instant>>,
    // In-flight probe requests while HalfOpen.
    half_open_requests: AtomicU32,
    // Lifetime counters for `stats()`.
    total_requests: AtomicU64,
    total_failures: AtomicU64,
    total_rejections: AtomicU64,
}
impl CircuitBreaker {
    /// Creates a breaker with the default configuration.
    pub fn new() -> Arc<Self> {
        Self::with_config(CircuitBreakerConfig::default())
    }
    /// Creates a breaker with the given configuration, starting `Closed`.
    pub fn with_config(config: CircuitBreakerConfig) -> Arc<Self> {
        Arc::new(Self {
            config,
            state: RwLock::new(CircuitState::Closed),
            failure_count: AtomicU32::new(0),
            success_count: AtomicU32::new(0),
            opened_at: RwLock::new(None),
            half_open_requests: AtomicU32::new(0),
            total_requests: AtomicU64::new(0),
            total_failures: AtomicU64::new(0),
            total_rejections: AtomicU64::new(0),
        })
    }
    /// Current state, lazily transitioning `Open` -> `HalfOpen` once the
    /// recovery timeout has elapsed.
    pub fn state(&self) -> CircuitState {
        let current_state = *self.state.read();
        if current_state == CircuitState::Open {
            let recovery_due = (*self.opened_at.read())
                .map(|opened_at| opened_at.elapsed() >= self.config.recovery_timeout)
                .unwrap_or(false);
            if recovery_due {
                let mut state = self.state.write();
                // Fix: re-check under the write lock — another thread may
                // have transitioned or reset the breaker between the read
                // above and acquiring the write lock; blindly overwriting
                // could clobber a `Closed` state back to `HalfOpen`.
                if *state == CircuitState::Open {
                    *state = CircuitState::HalfOpen;
                    self.half_open_requests.store(0, Ordering::SeqCst);
                    self.success_count.store(0, Ordering::SeqCst);
                }
                return *state;
            }
        }
        current_state
    }
    /// Whether a new request may proceed under the current state.
    pub fn is_allowed(&self) -> bool {
        match self.state() {
            CircuitState::Closed => true,
            CircuitState::Open => false,
            CircuitState::HalfOpen => {
                self.half_open_requests.load(Ordering::SeqCst) < self.config.half_open_max_requests
            }
        }
    }
    /// Records a successful call; in `HalfOpen`, enough successes close
    /// the breaker.
    pub fn record_success(&self) {
        self.total_requests.fetch_add(1, Ordering::Relaxed);
        let state = self.state();
        if state == CircuitState::HalfOpen {
            let success_count = self.success_count.fetch_add(1, Ordering::SeqCst) + 1;
            self.release_half_open();
            if success_count >= self.config.success_threshold {
                self.close();
            }
        }
    }
    /// Records a failed call; enough failures open the breaker, and any
    /// failure while `HalfOpen` re-opens it immediately.
    pub fn record_failure(&self) {
        self.total_requests.fetch_add(1, Ordering::Relaxed);
        self.total_failures.fetch_add(1, Ordering::Relaxed);
        let state = self.state();
        match state {
            CircuitState::Closed => {
                let failure_count = self.failure_count.fetch_add(1, Ordering::SeqCst) + 1;
                if failure_count >= self.config.failure_threshold {
                    self.open();
                }
            }
            CircuitState::HalfOpen => {
                self.release_half_open();
                self.open();
            }
            CircuitState::Open => {}
        }
    }
    /// Records a request that was rejected because the breaker disallowed
    /// it.
    pub fn record_rejection(&self) {
        self.total_rejections.fetch_add(1, Ordering::Relaxed);
    }
    /// Decrements the half-open in-flight counter without underflowing.
    ///
    /// Fix: `record_success`/`record_failure` are public and can be called
    /// while `HalfOpen` without a matching `acquire_half_open` (e.g. by
    /// callers that gated on `is_allowed` directly). The previous plain
    /// `fetch_sub(1)` would wrap the counter to `u32::MAX` in that case,
    /// permanently exhausting the probe budget.
    fn release_half_open(&self) {
        let _ = self
            .half_open_requests
            .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |v| v.checked_sub(1));
    }
    // Transition to Open and stamp the time for the recovery timeout.
    fn open(&self) {
        *self.state.write() = CircuitState::Open;
        *self.opened_at.write() = Some(Instant::now());
    }
    // Transition to Closed and clear all transient counters.
    fn close(&self) {
        *self.state.write() = CircuitState::Closed;
        *self.opened_at.write() = None;
        self.failure_count.store(0, Ordering::SeqCst);
        self.success_count.store(0, Ordering::SeqCst);
    }
    /// Force-closes the breaker and clears transient counters.
    pub fn reset(&self) {
        self.close();
    }
    /// Reserves a half-open probe slot; always succeeds outside `HalfOpen`.
    ///
    /// Fix: uses a CAS loop instead of the previous load-then-add, which
    /// let concurrent callers race past `half_open_max_requests`.
    fn acquire_half_open(&self) -> bool {
        if self.state() != CircuitState::HalfOpen {
            return true;
        }
        self.half_open_requests
            .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |v| {
                if v < self.config.half_open_max_requests {
                    Some(v + 1)
                } else {
                    None
                }
            })
            .is_ok()
    }
    /// Runs `operation` through the breaker: rejects when open or when the
    /// half-open probe budget is exhausted, otherwise records the outcome.
    pub async fn execute<F, Fut, T, E>(&self, operation: F) -> Result<T>
    where
        F: FnOnce() -> Fut,
        Fut: Future<Output = std::result::Result<T, E>>,
        E: std::fmt::Display,
    {
        if !self.is_allowed() {
            self.record_rejection();
            return Err(RingKernelError::BackendError(
                "Circuit breaker is open".to_string(),
            ));
        }
        if !self.acquire_half_open() {
            self.record_rejection();
            return Err(RingKernelError::BackendError(
                "Circuit breaker half-open limit reached".to_string(),
            ));
        }
        match operation().await {
            Ok(result) => {
                self.record_success();
                Ok(result)
            }
            Err(e) => {
                self.record_failure();
                Err(RingKernelError::BackendError(format!(
                    "Operation failed: {}",
                    e
                )))
            }
        }
    }
    /// Point-in-time counters for monitoring.
    pub fn stats(&self) -> CircuitBreakerStats {
        CircuitBreakerStats {
            state: self.state(),
            total_requests: self.total_requests.load(Ordering::Relaxed),
            total_failures: self.total_failures.load(Ordering::Relaxed),
            total_rejections: self.total_rejections.load(Ordering::Relaxed),
            failure_count: self.failure_count.load(Ordering::Relaxed),
            success_count: self.success_count.load(Ordering::Relaxed),
        }
    }
}
impl Default for CircuitBreaker {
    /// Same initial state as `CircuitBreaker::new`, but without the `Arc`
    /// wrapper (`new` returns `Arc<Self>`, so it cannot be reused here).
    fn default() -> Self {
        Self {
            config: CircuitBreakerConfig::default(),
            state: RwLock::new(CircuitState::Closed),
            failure_count: AtomicU32::new(0),
            success_count: AtomicU32::new(0),
            opened_at: RwLock::new(None),
            half_open_requests: AtomicU32::new(0),
            total_requests: AtomicU64::new(0),
            total_failures: AtomicU64::new(0),
            total_rejections: AtomicU64::new(0),
        }
    }
}
/// Point-in-time snapshot of a [`CircuitBreaker`]'s counters.
#[derive(Debug, Clone)]
pub struct CircuitBreakerStats {
    /// State at snapshot time.
    pub state: CircuitState,
    /// Lifetime requests recorded (successes + failures).
    pub total_requests: u64,
    /// Lifetime failures recorded.
    pub total_failures: u64,
    /// Lifetime requests rejected by the breaker.
    pub total_rejections: u64,
    /// Transient consecutive-failure counter (reset on close).
    pub failure_count: u32,
    /// Transient half-open success counter (reset on close).
    pub success_count: u32,
}
/// Delay schedule between retry attempts.
#[derive(Debug, Clone)]
pub enum BackoffStrategy {
    /// Same delay for every attempt.
    Fixed(Duration),
    /// Delay grows as `initial * (attempt + 1)`, capped at `max`.
    Linear {
        initial: Duration,
        max: Duration,
    },
    /// Delay grows as `initial * multiplier^attempt`, capped at `max`.
    Exponential {
        initial: Duration,
        max: Duration,
        multiplier: f64,
    },
    /// No delay between attempts.
    None,
}
impl BackoffStrategy {
    /// Returns the delay before retry number `attempt` (0-based), never
    /// exceeding the strategy's `max`.
    ///
    /// Fix: scaling is now done in `f64` seconds and clamped *before*
    /// converting back to `Duration`. The previous `Duration::mul_f64`
    /// followed by `min(max)` panicked once the scale factor overflowed
    /// `Duration` (e.g. exponential backoff at a large attempt count),
    /// because the panic fired before the cap was applied.
    pub fn delay(&self, attempt: u32) -> Duration {
        match self {
            BackoffStrategy::Fixed(d) => *d,
            BackoffStrategy::Linear { initial, max } => {
                Self::scale_capped(*initial, (attempt as f64) + 1.0, *max)
            }
            BackoffStrategy::Exponential {
                initial,
                max,
                multiplier,
            } => Self::scale_capped(*initial, multiplier.powi(attempt as i32), *max),
            BackoffStrategy::None => Duration::ZERO,
        }
    }
    /// Computes `initial * factor` clamped to `[0, max]`, without the
    /// overflow panic of `Duration::mul_f64`.
    fn scale_capped(initial: Duration, factor: f64, max: Duration) -> Duration {
        let secs = initial.as_secs_f64() * factor;
        if !secs.is_finite() || secs >= max.as_secs_f64() {
            max
        } else if secs <= 0.0 {
            Duration::ZERO
        } else {
            Duration::from_secs_f64(secs)
        }
    }
}
/// Retry configuration: attempt budget, backoff schedule, jitter, and an
/// optional retryability predicate.
#[derive(Clone)]
pub struct RetryPolicy {
    /// Maximum number of times the operation is attempted.
    pub max_attempts: u32,
    /// Delay schedule between attempts.
    pub backoff: BackoffStrategy,
    /// Whether delays are randomized to decorrelate retries.
    pub jitter: bool,
    // Optional predicate over the error's display text; when absent every
    // error is considered retryable.
    #[allow(clippy::type_complexity)]
    retryable: Option<Arc<dyn Fn(&str) -> bool + Send + Sync>>,
}
impl std::fmt::Debug for RetryPolicy {
    /// Manual `Debug` because the `retryable` closure is not `Debug`;
    /// only its presence is reported.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("RetryPolicy")
            .field("max_attempts", &self.max_attempts)
            .field("backoff", &self.backoff)
            .field("jitter", &self.jitter)
            .field("retryable", &self.retryable.is_some())
            .finish()
    }
}
impl RetryPolicy {
    /// Builds a policy with `max_attempts` tries, exponential backoff
    /// (100ms doubling, capped at 30s), and jitter enabled.
    pub fn new(max_attempts: u32) -> Self {
        Self {
            max_attempts,
            backoff: BackoffStrategy::Exponential {
                initial: Duration::from_millis(100),
                max: Duration::from_secs(30),
                multiplier: 2.0,
            },
            jitter: true,
            retryable: None,
        }
    }
    /// Replaces the backoff strategy (builder style).
    pub fn with_backoff(mut self, backoff: BackoffStrategy) -> Self {
        self.backoff = backoff;
        self
    }
    /// Disables delay randomization (builder style).
    pub fn without_jitter(mut self) -> Self {
        self.jitter = false;
        self
    }
    /// Installs a predicate deciding, from an error's display text,
    /// whether the failure is worth retrying (builder style).
    pub fn with_retryable<F>(mut self, predicate: F) -> Self
    where
        F: Fn(&str) -> bool + Send + Sync + 'static,
    {
        self.retryable = Some(Arc::new(predicate));
        self
    }
    /// Whether `error` should be retried; everything is retryable when no
    /// predicate is installed.
    pub fn is_retryable(&self, error: &str) -> bool {
        match &self.retryable {
            Some(predicate) => predicate(error),
            None => true,
        }
    }
    /// Delay before retry number `attempt`, with multiplicative jitter
    /// applied when enabled and the base delay is non-zero.
    ///
    /// NOTE(review): the jitter factor is `0.75 + (r % 50) / 200`, i.e. a
    /// range of roughly [0.75, 1.0) that only ever *shortens* the base
    /// delay — confirm whether symmetric (±25%) jitter was intended.
    pub fn get_delay(&self, attempt: u32) -> Duration {
        let base = self.backoff.delay(attempt);
        if !self.jitter || base == Duration::ZERO {
            return base;
        }
        let factor = 0.75 + (rand_u64() % 50) as f64 / 200.0;
        base.mul_f64(factor)
    }
    /// Runs `operation` up to `max_attempts` times, sleeping the backoff
    /// delay between failures.
    ///
    /// Returns the first success, a non-retryable error immediately, or
    /// an exhaustion error carrying the last failure's text.
    pub async fn execute<F, Fut, T, E>(&self, mut operation: F) -> Result<T>
    where
        F: FnMut() -> Fut,
        Fut: Future<Output = std::result::Result<T, E>>,
        E: std::fmt::Display,
    {
        let mut last_error = String::new();
        for attempt in 0..self.max_attempts {
            let failure = match operation().await {
                Ok(value) => return Ok(value),
                Err(err) => err,
            };
            last_error = format!("{}", failure);
            if !self.is_retryable(&last_error) {
                return Err(RingKernelError::BackendError(format!(
                    "Non-retryable error: {}",
                    last_error
                )));
            }
            // No sleep after the final attempt — fall through to the
            // exhaustion error instead.
            if attempt + 1 >= self.max_attempts {
                break;
            }
            tokio::time::sleep(self.get_delay(attempt)).await;
        }
        Err(RingKernelError::BackendError(format!(
            "Operation failed after {} attempts: {}",
            self.max_attempts, last_error
        )))
    }
}
impl Default for RetryPolicy {
    /// Three attempts with the standard exponential backoff and jitter.
    fn default() -> Self {
        Self::new(3)
    }
}
/// Cheap pseudo-random `u64` for retry jitter, obtained by hashing the
/// current wall-clock time together with the calling thread's id.
///
/// Not cryptographically secure and not uniformly distributed — it only
/// needs to decorrelate retry delays across threads and moments in time.
fn rand_u64() -> u64 {
    use std::hash::{Hash, Hasher};
    let mut hasher = std::collections::hash_map::DefaultHasher::new();
    (std::time::SystemTime::now(), std::thread::current().id()).hash(&mut hasher);
    hasher.finish()
}
/// Graduated degradation severity, ordered from `Normal` (best) to
/// `Critical` (worst).
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum DegradationLevel {
    Normal = 0,
    Light = 1,
    Moderate = 2,
    Severe = 3,
    Critical = 4,
}
impl DegradationLevel {
    /// Steps one level toward `Critical`; `Critical` saturates.
    pub fn next_worse(self) -> Self {
        match self {
            Self::Normal => Self::Light,
            Self::Light => Self::Moderate,
            Self::Moderate => Self::Severe,
            Self::Severe | Self::Critical => Self::Critical,
        }
    }
    /// Steps one level toward `Normal`; `Normal` saturates.
    pub fn next_better(self) -> Self {
        match self {
            Self::Normal | Self::Light => Self::Normal,
            Self::Moderate => Self::Light,
            Self::Severe => Self::Moderate,
            Self::Critical => Self::Severe,
        }
    }
}
/// Thresholds and ratio governing load shedding.
#[derive(Debug, Clone)]
pub struct LoadSheddingPolicy {
    /// Queue depth above which shedding should trigger. NOTE(review): not
    /// consulted by `DegradationManager` in this file — presumably read by
    /// external monitoring that calls `set_level`.
    pub queue_threshold: usize,
    /// CPU utilization threshold (0..1). NOTE(review): not consulted in
    /// this file.
    pub cpu_threshold: f64,
    /// Memory utilization threshold (0..1). NOTE(review): not consulted in
    /// this file.
    pub memory_threshold: f64,
    /// Base fraction of requests to shed per degradation-level step
    /// (multiplied by the level factor in `should_shed`).
    pub shed_ratio: f64,
}
impl Default for LoadSheddingPolicy {
    /// Defaults: 10k queue depth, 90% CPU, 85% memory, 10% base shed.
    fn default() -> Self {
        Self {
            queue_threshold: 10000,
            cpu_threshold: 0.9,
            memory_threshold: 0.85,
            shed_ratio: 0.1,
        }
    }
}
/// Tracks the current degradation level, notifies listeners on changes,
/// and decides which requests to shed.
pub struct DegradationManager {
    // Current level; starts at Normal.
    level: RwLock<DegradationLevel>,
    // Shedding tunables (only `shed_ratio` is read in this file).
    policy: LoadSheddingPolicy,
    // Listeners invoked with (old, new) on every level transition.
    #[allow(clippy::type_complexity)]
    callbacks: RwLock<Vec<Arc<dyn Fn(DegradationLevel, DegradationLevel) + Send + Sync>>>,
    // Modulo-100 counter driving the deterministic shed pattern.
    shed_counter: AtomicU64,
    // All requests seen by `should_shed`.
    total_requests: AtomicU64,
    // Requests actually shed.
    shed_requests: AtomicU64,
}
impl DegradationManager {
pub fn new() -> Arc<Self> {
Arc::new(Self {
level: RwLock::new(DegradationLevel::Normal),
policy: LoadSheddingPolicy::default(),
callbacks: RwLock::new(Vec::new()),
shed_counter: AtomicU64::new(0),
total_requests: AtomicU64::new(0),
shed_requests: AtomicU64::new(0),
})
}
pub fn with_policy(policy: LoadSheddingPolicy) -> Arc<Self> {
Arc::new(Self {
level: RwLock::new(DegradationLevel::Normal),
policy,
callbacks: RwLock::new(Vec::new()),
shed_counter: AtomicU64::new(0),
total_requests: AtomicU64::new(0),
shed_requests: AtomicU64::new(0),
})
}
pub fn level(&self) -> DegradationLevel {
*self.level.read()
}
pub fn set_level(&self, new_level: DegradationLevel) {
let old_level = *self.level.read();
if old_level != new_level {
*self.level.write() = new_level;
let callbacks = self.callbacks.read().clone();
for callback in callbacks {
callback(old_level, new_level);
}
}
}
pub fn on_level_change<F>(&self, callback: F)
where
F: Fn(DegradationLevel, DegradationLevel) + Send + Sync + 'static,
{
self.callbacks.write().push(Arc::new(callback));
}
pub fn should_shed(&self) -> bool {
self.total_requests.fetch_add(1, Ordering::Relaxed);
let level = self.level();
if level == DegradationLevel::Normal {
return false;
}
let base_ratio = self.policy.shed_ratio;
let level_factor = match level {
DegradationLevel::Normal => 0.0,
DegradationLevel::Light => 1.0,
DegradationLevel::Moderate => 2.0,
DegradationLevel::Severe => 3.0,
DegradationLevel::Critical => 4.0,
};
let shed_probability = (base_ratio * level_factor).min(0.9);
let counter = self.shed_counter.fetch_add(1, Ordering::Relaxed);
let should_shed = (counter % 100) < (shed_probability * 100.0) as u64;
if should_shed {
self.shed_requests.fetch_add(1, Ordering::Relaxed);
}
should_shed
}
pub fn is_feature_disabled(&self, required_level: DegradationLevel) -> bool {
self.level() > required_level
}
pub fn stats(&self) -> DegradationStats {
let total = self.total_requests.load(Ordering::Relaxed);
let shed = self.shed_requests.load(Ordering::Relaxed);
DegradationStats {
level: self.level(),
total_requests: total,
shed_requests: shed,
shed_ratio: if total > 0 {
shed as f64 / total as f64
} else {
0.0
},
}
}
}
impl Default for DegradationManager {
    /// Same initial state as `DegradationManager::new`, but without the
    /// `Arc` wrapper (`new` returns `Arc<Self>`, so it cannot be reused).
    fn default() -> Self {
        Self {
            level: RwLock::new(DegradationLevel::Normal),
            policy: LoadSheddingPolicy::default(),
            callbacks: RwLock::new(Vec::new()),
            shed_counter: AtomicU64::new(0),
            total_requests: AtomicU64::new(0),
            shed_requests: AtomicU64::new(0),
        }
    }
}
/// Point-in-time snapshot of a [`DegradationManager`]'s shedding counters.
#[derive(Debug, Clone)]
pub struct DegradationStats {
    /// Level at snapshot time.
    pub level: DegradationLevel,
    /// Requests evaluated by `should_shed`.
    pub total_requests: u64,
    /// Requests that were shed.
    pub shed_requests: u64,
    /// `shed_requests / total_requests`, or 0.0 when no requests yet.
    pub shed_ratio: f64,
}
/// Per-kernel health record maintained by the [`KernelWatchdog`].
#[derive(Debug, Clone)]
pub struct KernelHealth {
    /// Kernel being tracked.
    pub kernel_id: KernelId,
    /// When the kernel last reported a heartbeat.
    pub last_heartbeat: Instant,
    /// Current classification (escalated by `check_all` on stale
    /// heartbeats, reset by `heartbeat`).
    pub status: HealthStatus,
    /// Consecutive watchdog sweeps with a stale heartbeat.
    pub failure_count: u32,
    /// Last reported throughput metric.
    pub messages_per_sec: f64,
    /// Last reported queue depth.
    pub queue_depth: usize,
}
/// Monitors kernel heartbeats and escalates kernels whose heartbeats go
/// stale, notifying registered callbacks.
pub struct KernelWatchdog {
    // Health record per watched kernel.
    kernels: RwLock<HashMap<KernelId, KernelHealth>>,
    // A heartbeat older than this counts as a missed sweep in `check_all`.
    heartbeat_timeout: Duration,
    // Intended sweep cadence; no background loop is spawned in this file.
    #[allow(dead_code)]
    check_interval: Duration,
    // Missed sweeps after which a kernel becomes Unhealthy (Degraded
    // before that).
    failure_threshold: u32,
    // Flag for a background sweep loop; not driven by code in this file.
    #[allow(dead_code)]
    running: std::sync::atomic::AtomicBool,
    // Listeners invoked from `check_all` for each Unhealthy kernel.
    #[allow(clippy::type_complexity)]
    callbacks: RwLock<Vec<Arc<dyn Fn(&KernelHealth) + Send + Sync>>>,
}
impl KernelWatchdog {
pub fn new() -> Arc<Self> {
Arc::new(Self {
kernels: RwLock::new(HashMap::new()),
heartbeat_timeout: Duration::from_secs(30),
check_interval: Duration::from_secs(5),
failure_threshold: 3,
running: std::sync::atomic::AtomicBool::new(false),
callbacks: RwLock::new(Vec::new()),
})
}
pub fn with_heartbeat_timeout(self: Arc<Self>, timeout: Duration) -> Arc<Self> {
let _ = timeout; self
}
pub fn watch(&self, kernel_id: KernelId) {
let health = KernelHealth {
kernel_id: kernel_id.clone(),
last_heartbeat: Instant::now(),
status: HealthStatus::Healthy,
failure_count: 0,
messages_per_sec: 0.0,
queue_depth: 0,
};
self.kernels.write().insert(kernel_id, health);
}
pub fn unwatch(&self, kernel_id: &KernelId) {
self.kernels.write().remove(kernel_id);
}
pub fn heartbeat(&self, kernel_id: &KernelId) {
if let Some(health) = self.kernels.write().get_mut(kernel_id) {
health.last_heartbeat = Instant::now();
health.failure_count = 0;
if health.status == HealthStatus::Unhealthy {
health.status = HealthStatus::Healthy;
}
}
}
pub fn update_metrics(&self, kernel_id: &KernelId, messages_per_sec: f64, queue_depth: usize) {
if let Some(health) = self.kernels.write().get_mut(kernel_id) {
health.messages_per_sec = messages_per_sec;
health.queue_depth = queue_depth;
}
}
pub fn check_all(&self) -> Vec<KernelHealth> {
let now = Instant::now();
let mut kernels = self.kernels.write();
let mut results = Vec::with_capacity(kernels.len());
for health in kernels.values_mut() {
if now.duration_since(health.last_heartbeat) > self.heartbeat_timeout {
health.failure_count += 1;
if health.failure_count >= self.failure_threshold {
health.status = HealthStatus::Unhealthy;
} else {
health.status = HealthStatus::Degraded;
}
}
results.push(health.clone());
}
drop(kernels);
let callbacks = self.callbacks.read().clone();
for health in results
.iter()
.filter(|h| h.status == HealthStatus::Unhealthy)
{
for callback in &callbacks {
callback(health);
}
}
results
}
pub fn on_unhealthy<F>(&self, callback: F)
where
F: Fn(&KernelHealth) + Send + Sync + 'static,
{
self.callbacks.write().push(Arc::new(callback));
}
pub fn get_health(&self, kernel_id: &KernelId) -> Option<KernelHealth> {
self.kernels.read().get(kernel_id).cloned()
}
pub fn unhealthy_kernels(&self) -> Vec<KernelHealth> {
self.kernels
.read()
.values()
.filter(|h| h.status == HealthStatus::Unhealthy)
.cloned()
.collect()
}
pub fn watched_count(&self) -> usize {
self.kernels.read().len()
}
}
impl Default for KernelWatchdog {
    /// Same initial state as `KernelWatchdog::new`, but without the `Arc`
    /// wrapper (`new` returns `Arc<Self>`, so it cannot be reused here).
    fn default() -> Self {
        Self {
            kernels: RwLock::new(HashMap::new()),
            heartbeat_timeout: Duration::from_secs(30),
            check_interval: Duration::from_secs(5),
            failure_threshold: 3,
            running: std::sync::atomic::AtomicBool::new(false),
            callbacks: RwLock::new(Vec::new()),
        }
    }
}
/// Strategy applied when a kernel failure is detected; defaults to the
/// least intrusive option, `Notify`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum RecoveryPolicy {
    Restart,
    Migrate,
    Checkpoint,
    #[default]
    Notify,
    Escalate,
    Circuit,
}
impl RecoveryPolicy {
    /// Relative severity rank, 1 (mildest) through 6 (most drastic).
    pub fn severity(&self) -> u8 {
        match self {
            Self::Notify => 1,
            Self::Checkpoint => 2,
            Self::Restart => 3,
            Self::Circuit => 4,
            Self::Migrate => 5,
            Self::Escalate => 6,
        }
    }
    /// Whether a human must act: `Notify` and `Escalate` do not resolve
    /// the failure automatically.
    pub fn requires_intervention(&self) -> bool {
        matches!(self, Self::Notify | Self::Escalate)
    }
}
/// Tunables governing automatic kernel recovery.
#[derive(Debug, Clone)]
pub struct RecoveryConfig {
    /// Restart/recovery attempt budget per kernel before escalating.
    pub max_restart_attempts: u32,
    /// Pause before a restart. NOTE(review): not read by the built-in
    /// default recovery in this file — presumably for custom handlers.
    pub restart_delay: Duration,
    /// Snapshot state before restarting. NOTE(review): not consulted by
    /// code in this file.
    pub checkpoint_before_restart: bool,
    /// Move the kernel to another device on device errors. NOTE(review):
    /// not consulted by code in this file.
    pub migrate_on_device_error: bool,
    /// Minimum spacing between recoveries of the same kernel (enforced by
    /// `RecoveryManager::should_recover`).
    pub recovery_cooldown: Duration,
    /// Per-failure-type policy mapping; missing entries fall back to
    /// `Notify` via `policy_for`.
    pub policies: HashMap<FailureType, RecoveryPolicy>,
}
impl Default for RecoveryConfig {
fn default() -> Self {
let mut policies = HashMap::new();
policies.insert(FailureType::Timeout, RecoveryPolicy::Restart);
policies.insert(FailureType::Crash, RecoveryPolicy::Restart);
policies.insert(FailureType::DeviceError, RecoveryPolicy::Migrate);
policies.insert(FailureType::ResourceExhausted, RecoveryPolicy::Circuit);
policies.insert(FailureType::Unknown, RecoveryPolicy::Notify);
Self {
max_restart_attempts: 3,
restart_delay: Duration::from_secs(5),
checkpoint_before_restart: true,
migrate_on_device_error: true,
recovery_cooldown: Duration::from_secs(60),
policies,
}
}
}
impl RecoveryConfig {
    /// Starts a fluent builder seeded with the default configuration.
    pub fn builder() -> RecoveryConfigBuilder {
        RecoveryConfigBuilder::new()
    }
    /// Cautious preset: one restart attempt, checkpoint first, and every
    /// `Restart` policy downgraded to `Notify`.
    pub fn conservative() -> Self {
        let mut config = Self {
            max_restart_attempts: 1,
            checkpoint_before_restart: true,
            ..Self::default()
        };
        for policy in config.policies.values_mut() {
            if matches!(*policy, RecoveryPolicy::Restart) {
                *policy = RecoveryPolicy::Notify;
            }
        }
        config
    }
    /// Aggressive preset: more restart attempts, shorter delays, and no
    /// checkpointing before restart.
    pub fn aggressive() -> Self {
        Self {
            max_restart_attempts: 5,
            checkpoint_before_restart: false,
            restart_delay: Duration::from_secs(1),
            recovery_cooldown: Duration::from_secs(10),
            ..Self::default()
        }
    }
    /// Policy configured for `failure_type`, falling back to `Notify`.
    pub fn policy_for(&self, failure_type: FailureType) -> RecoveryPolicy {
        self.policies
            .get(&failure_type)
            .copied()
            .unwrap_or(RecoveryPolicy::Notify)
    }
}
/// Fluent builder for [`RecoveryConfig`], seeded with its defaults.
#[derive(Debug, Default)]
pub struct RecoveryConfigBuilder {
    // The configuration being assembled; returned as-is by `build`.
    config: RecoveryConfig,
}
impl RecoveryConfigBuilder {
    /// Starts from `RecoveryConfig::default()`.
    pub fn new() -> Self {
        Self {
            config: RecoveryConfig::default(),
        }
    }
    /// Sets the per-kernel restart attempt budget.
    pub fn max_restart_attempts(mut self, attempts: u32) -> Self {
        self.config.max_restart_attempts = attempts;
        self
    }
    /// Sets the pause before a restart.
    pub fn restart_delay(mut self, delay: Duration) -> Self {
        self.config.restart_delay = delay;
        self
    }
    /// Enables/disables checkpointing before restart.
    pub fn checkpoint_before_restart(mut self, enabled: bool) -> Self {
        self.config.checkpoint_before_restart = enabled;
        self
    }
    /// Enables/disables migration on device errors.
    pub fn migrate_on_device_error(mut self, enabled: bool) -> Self {
        self.config.migrate_on_device_error = enabled;
        self
    }
    /// Sets the minimum spacing between recoveries of the same kernel.
    pub fn recovery_cooldown(mut self, cooldown: Duration) -> Self {
        self.config.recovery_cooldown = cooldown;
        self
    }
    /// Overrides the policy for one failure type.
    pub fn policy(mut self, failure_type: FailureType, policy: RecoveryPolicy) -> Self {
        self.config.policies.insert(failure_type, policy);
        self
    }
    /// Finalizes the configuration.
    pub fn build(self) -> RecoveryConfig {
        self.config
    }
}
/// Classification of a detected kernel failure, used to select a
/// [`RecoveryPolicy`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum FailureType {
    Timeout,
    Crash,
    DeviceError,
    ResourceExhausted,
    QueueOverflow,
    StateCorruption,
    Unknown,
}
impl FailureType {
    /// Short human-readable label for logs and alerts.
    pub fn description(&self) -> &'static str {
        match self {
            Self::Timeout => "Kernel heartbeat timeout",
            Self::Crash => "Kernel crash",
            Self::DeviceError => "GPU device error",
            Self::ResourceExhausted => "Resource exhaustion",
            Self::QueueOverflow => "Message queue overflow",
            Self::StateCorruption => "State corruption detected",
            Self::Unknown => "Unknown failure",
        }
    }
}
/// A planned recovery step for one kernel failure.
#[derive(Debug, Clone)]
pub struct RecoveryAction {
    /// Kernel to recover.
    pub kernel_id: KernelId,
    /// Why recovery was triggered.
    pub failure_type: FailureType,
    /// Strategy chosen for this attempt.
    pub policy: RecoveryPolicy,
    /// 1-based attempt number for this kernel.
    pub attempt: u32,
    /// When the action was created.
    pub created_at: Instant,
    /// Free-form key/value context attached via `with_context`.
    pub context: HashMap<String, String>,
}
impl RecoveryAction {
    /// Creates a first-attempt action timestamped now, with empty context.
    pub fn new(kernel_id: KernelId, failure_type: FailureType, policy: RecoveryPolicy) -> Self {
        Self {
            kernel_id,
            failure_type,
            policy,
            attempt: 1,
            created_at: Instant::now(),
            context: HashMap::new(),
        }
    }
    /// Attaches one context key/value pair (builder style).
    pub fn with_context(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        let _ = self.context.insert(key.into(), value.into());
        self
    }
    /// Overrides the attempt counter (builder style).
    pub fn with_attempt(self, attempt: u32) -> Self {
        Self { attempt, ..self }
    }
}
/// Outcome of executing a [`RecoveryAction`].
#[derive(Debug, Clone)]
pub struct RecoveryResult {
    /// The action that was executed.
    pub action: RecoveryAction,
    /// Whether the recovery succeeded.
    pub success: bool,
    /// Failure description when `success` is false.
    pub error: Option<String>,
    /// How long the recovery took.
    pub duration: Duration,
    /// Suggested follow-up policy after a failure, if any.
    pub next_action: Option<RecoveryPolicy>,
}
impl RecoveryResult {
pub fn success(action: RecoveryAction, duration: Duration) -> Self {
Self {
action,
success: true,
error: None,
duration,
next_action: None,
}
}
pub fn failure(action: RecoveryAction, error: String, duration: Duration) -> Self {
Self {
action,
success: false,
error: Some(error),
duration,
next_action: Some(RecoveryPolicy::Escalate),
}
}
pub fn failure_with_next(
action: RecoveryAction,
error: String,
duration: Duration,
next: RecoveryPolicy,
) -> Self {
Self {
action,
success: false,
error: Some(error),
duration,
next_action: Some(next),
}
}
}
/// Shared async callback implementing one [`RecoveryPolicy`]: receives the
/// action and returns a future resolving to its outcome.
pub type RecoveryHandler = Arc<
    dyn Fn(&RecoveryAction) -> Pin<Box<dyn Future<Output = RecoveryResult> + Send>> + Send + Sync,
>;
/// Coordinates automatic kernel recovery: policy selection, handler
/// dispatch, cooldowns, per-kernel attempt budgets, history, and stats.
pub struct RecoveryManager {
    // Active configuration; replaceable at runtime via `set_config`.
    config: RwLock<RecoveryConfig>,
    // Custom per-policy handlers; policies without one use the built-in
    // default behavior in `default_recovery`.
    handlers: RwLock<HashMap<RecoveryPolicy, RecoveryHandler>>,
    // Completed recovery results per kernel.
    history: RwLock<HashMap<KernelId, Vec<RecoveryResult>>>,
    // Recovery attempts per kernel since the last success.
    attempts: RwLock<HashMap<KernelId, u32>>,
    // Timestamp of the last recovery per kernel, for cooldown checks.
    last_recovery: RwLock<HashMap<KernelId, Instant>>,
    // Lifetime counters behind `stats()`.
    stats: RecoveryStats,
    // Master switch consulted by `should_recover`.
    enabled: std::sync::atomic::AtomicBool,
}
impl RecoveryManager {
    /// Creates a manager with the default [`RecoveryConfig`].
    pub fn new() -> Self {
        Self::with_config(RecoveryConfig::default())
    }
    /// Creates an enabled manager with the given configuration.
    pub fn with_config(config: RecoveryConfig) -> Self {
        Self {
            config: RwLock::new(config),
            handlers: RwLock::new(HashMap::new()),
            history: RwLock::new(HashMap::new()),
            attempts: RwLock::new(HashMap::new()),
            last_recovery: RwLock::new(HashMap::new()),
            stats: RecoveryStats::default(),
            enabled: std::sync::atomic::AtomicBool::new(true),
        }
    }
    /// Globally enables or disables automatic recovery.
    pub fn set_enabled(&self, enabled: bool) {
        self.enabled.store(enabled, Ordering::SeqCst);
    }
    /// Whether automatic recovery is currently enabled.
    pub fn is_enabled(&self) -> bool {
        self.enabled.load(Ordering::SeqCst)
    }
    /// Replaces the active configuration.
    pub fn set_config(&self, config: RecoveryConfig) {
        *self.config.write() = config;
    }
    /// Clone of the active configuration.
    pub fn config(&self) -> RecoveryConfig {
        self.config.read().clone()
    }
    /// Installs a custom handler for one policy, overriding the built-in
    /// default behavior for it.
    pub fn register_handler(&self, policy: RecoveryPolicy, handler: RecoveryHandler) {
        self.handlers.write().insert(policy, handler);
    }
    /// Whether recovery may run for this kernel now: the manager is
    /// enabled and the per-kernel cooldown has elapsed (or the kernel has
    /// never been recovered).
    pub fn should_recover(&self, kernel_id: &KernelId) -> bool {
        if !self.is_enabled() {
            return false;
        }
        let config = self.config.read();
        let last_recovery = self.last_recovery.read();
        if let Some(last) = last_recovery.get(kernel_id) {
            last.elapsed() >= config.recovery_cooldown
        } else {
            true
        }
    }
    /// Chooses the recovery action for a failure: the configured policy
    /// for `failure_type`, escalating once the attempt budget is spent.
    pub fn determine_action(
        &self,
        kernel_id: &KernelId,
        failure_type: FailureType,
    ) -> RecoveryAction {
        let config = self.config.read();
        let attempts = self.attempts.read();
        let current_attempt = attempts.get(kernel_id).copied().unwrap_or(0) + 1;
        let policy = if current_attempt > config.max_restart_attempts {
            RecoveryPolicy::Escalate
        } else {
            config.policy_for(failure_type)
        };
        RecoveryAction::new(kernel_id.clone(), failure_type, policy).with_attempt(current_attempt)
    }
    /// Executes a recovery action and records its outcome.
    ///
    /// Bookkeeping: bumps the per-kernel attempt counter and cooldown
    /// timestamp, dispatches to a registered handler (or the built-in
    /// default), updates stats, and appends the result to the kernel's
    /// history. A success clears the kernel's attempt counter.
    pub async fn recover(&self, action: RecoveryAction) -> RecoveryResult {
        let kernel_id = action.kernel_id.clone();
        let policy = action.policy;
        *self.attempts.write().entry(kernel_id.clone()).or_insert(0) += 1;
        self.last_recovery
            .write()
            .insert(kernel_id.clone(), Instant::now());
        // Fix: count every attempt. Previously only attempts routed to a
        // custom handler were counted, while successes/failures were
        // always counted, so `success_rate()` could exceed 1.0 when the
        // default recovery path was used.
        self.stats.attempts.fetch_add(1, Ordering::Relaxed);
        let handler = self.handlers.read().get(&policy).cloned();
        let result = match handler {
            Some(handler) => handler(&action).await,
            None => self.default_recovery(&action).await,
        };
        if result.success {
            self.stats.successes.fetch_add(1, Ordering::Relaxed);
            self.attempts.write().remove(&kernel_id);
        } else {
            self.stats.failures.fetch_add(1, Ordering::Relaxed);
        }
        self.history
            .write()
            .entry(kernel_id)
            .or_default()
            .push(result.clone());
        result
    }
    /// Built-in behavior when no handler is registered for the policy.
    ///
    /// All policies report success as stubs, except `Restart` beyond the
    /// attempt budget and `Escalate`, which always fail to signal that
    /// manual intervention is required.
    async fn default_recovery(&self, action: &RecoveryAction) -> RecoveryResult {
        let start = Instant::now();
        match action.policy {
            RecoveryPolicy::Notify => RecoveryResult::success(action.clone(), start.elapsed()),
            RecoveryPolicy::Checkpoint => RecoveryResult::success(action.clone(), start.elapsed()),
            RecoveryPolicy::Restart => {
                let config = self.config.read();
                if action.attempt > config.max_restart_attempts {
                    RecoveryResult::failure_with_next(
                        action.clone(),
                        "Max restart attempts exceeded".to_string(),
                        start.elapsed(),
                        RecoveryPolicy::Escalate,
                    )
                } else {
                    RecoveryResult::success(action.clone(), start.elapsed())
                }
            }
            RecoveryPolicy::Migrate => RecoveryResult::success(action.clone(), start.elapsed()),
            RecoveryPolicy::Circuit => RecoveryResult::success(action.clone(), start.elapsed()),
            RecoveryPolicy::Escalate => RecoveryResult::failure(
                action.clone(),
                "Manual intervention required".to_string(),
                start.elapsed(),
            ),
        }
    }
    /// Recovery history for one kernel (empty if none).
    pub fn get_history(&self, kernel_id: &KernelId) -> Vec<RecoveryResult> {
        self.history
            .read()
            .get(kernel_id)
            .cloned()
            .unwrap_or_default()
    }
    /// Clears all history, attempt counters, and cooldown timestamps.
    pub fn clear_history(&self) {
        self.history.write().clear();
        self.attempts.write().clear();
        self.last_recovery.write().clear();
    }
    /// Point-in-time recovery statistics.
    pub fn stats(&self) -> RecoveryStatsSnapshot {
        RecoveryStatsSnapshot {
            attempts: self.stats.attempts.load(Ordering::Relaxed),
            successes: self.stats.successes.load(Ordering::Relaxed),
            failures: self.stats.failures.load(Ordering::Relaxed),
            kernels_tracked: self.history.read().len(),
        }
    }
}
impl Default for RecoveryManager {
fn default() -> Self {
Self::new()
}
}
/// Internal atomic counters behind [`RecoveryStatsSnapshot`].
#[derive(Default)]
struct RecoveryStats {
    // Recovery attempts dispatched.
    attempts: AtomicU64,
    // Attempts that reported success.
    successes: AtomicU64,
    // Attempts that reported failure.
    failures: AtomicU64,
}
/// Point-in-time copy of a [`RecoveryManager`]'s counters.
#[derive(Debug, Clone, Default)]
pub struct RecoveryStatsSnapshot {
    /// Recovery attempts dispatched.
    pub attempts: u64,
    /// Attempts that reported success.
    pub successes: u64,
    /// Attempts that reported failure.
    pub failures: u64,
    /// Kernels with at least one history entry.
    pub kernels_tracked: usize,
}
impl RecoveryStatsSnapshot {
    /// Fraction of attempts that succeeded; defined as 1.0 when nothing
    /// has been attempted yet (vacuous success).
    pub fn success_rate(&self) -> f64 {
        match self.attempts {
            0 => 1.0,
            n => self.successes as f64 / n as f64,
        }
    }
}
// Unit tests for the resilience module: health checks, circuit breaker,
// backoff/retry, graceful degradation, the kernel watchdog, and the
// recovery manager.
#[cfg(test)]
mod tests {
use super::*;
// Degraded still counts as healthy; only Unhealthy is treated as down.
#[test]
fn test_health_status() {
assert!(HealthStatus::Healthy.is_healthy());
assert!(HealthStatus::Degraded.is_healthy());
assert!(!HealthStatus::Unhealthy.is_healthy());
assert!(HealthStatus::Unhealthy.is_unhealthy());
}
// Registering passing liveness + readiness checks makes both probes pass.
#[tokio::test]
async fn test_health_checker() {
let checker = HealthChecker::new();
checker.register_liveness("test_alive", || async { true });
checker.register_readiness("test_ready", || async { true });
assert_eq!(checker.check_count(), 2);
assert!(checker.is_alive().await);
assert!(checker.is_ready().await);
}
// A failing liveness check flips the aggregate status to Unhealthy.
#[tokio::test]
async fn test_health_checker_unhealthy() {
let checker = HealthChecker::new();
checker.register_liveness("failing_check", || async { false });
assert!(!checker.is_alive().await);
let status = checker.aggregate_status().await;
assert_eq!(status, HealthStatus::Unhealthy);
}
// A fresh breaker starts Closed and admits requests.
#[test]
fn test_circuit_breaker_initial_state() {
let breaker = CircuitBreaker::new();
assert_eq!(breaker.state(), CircuitState::Closed);
assert!(breaker.is_allowed());
}
// The breaker opens exactly when failures reach the configured threshold.
#[test]
fn test_circuit_breaker_opens_on_failures() {
let config = CircuitBreakerConfig {
failure_threshold: 3,
..Default::default()
};
let breaker = CircuitBreaker::with_config(config);
breaker.record_failure();
breaker.record_failure();
assert_eq!(breaker.state(), CircuitState::Closed);
breaker.record_failure();
assert_eq!(breaker.state(), CircuitState::Open);
assert!(!breaker.is_allowed());
}
// reset() returns an Open breaker to Closed.
#[test]
fn test_circuit_breaker_reset() {
let config = CircuitBreakerConfig {
failure_threshold: 1,
..Default::default()
};
let breaker = CircuitBreaker::with_config(config);
breaker.record_failure();
assert_eq!(breaker.state(), CircuitState::Open);
breaker.reset();
assert_eq!(breaker.state(), CircuitState::Closed);
}
// Fixed backoff ignores the attempt number.
#[test]
fn test_backoff_strategy_fixed() {
let backoff = BackoffStrategy::Fixed(Duration::from_secs(1));
assert_eq!(backoff.delay(0), Duration::from_secs(1));
assert_eq!(backoff.delay(5), Duration::from_secs(1));
}
// Exponential backoff: delay(n) = initial * multiplier^n (100ms doubling).
#[test]
fn test_backoff_strategy_exponential() {
let backoff = BackoffStrategy::Exponential {
initial: Duration::from_millis(100),
max: Duration::from_secs(10),
multiplier: 2.0,
};
assert_eq!(backoff.delay(0), Duration::from_millis(100));
assert_eq!(backoff.delay(1), Duration::from_millis(200));
assert_eq!(backoff.delay(2), Duration::from_millis(400));
}
// Linear backoff grows by `initial` each attempt and clamps at `max`
// (attempt 9 would be 1000ms, which equals the cap here).
#[test]
fn test_backoff_strategy_linear() {
let backoff = BackoffStrategy::Linear {
initial: Duration::from_millis(100),
max: Duration::from_secs(1),
};
assert_eq!(backoff.delay(0), Duration::from_millis(100));
assert_eq!(backoff.delay(1), Duration::from_millis(200));
assert_eq!(backoff.delay(9), Duration::from_secs(1)); }
// A first-try success returns immediately with the value.
#[tokio::test]
async fn test_retry_policy_success() {
let policy = RetryPolicy::new(3);
let result: Result<i32> = policy.execute(|| async { Ok::<_, &str>(42) }).await;
assert!(result.is_ok());
assert_eq!(result.unwrap(), 42);
}
// The degradation level starts at Normal and can be raised explicitly.
#[test]
fn test_degradation_manager_levels() {
let manager = DegradationManager::new();
assert_eq!(manager.level(), DegradationLevel::Normal);
manager.set_level(DegradationLevel::Moderate);
assert_eq!(manager.level(), DegradationLevel::Moderate);
}
// At Severe, features gated at Moderate/Normal are disabled while
// features gated at Critical remain enabled.
#[test]
fn test_degradation_feature_disabled() {
let manager = DegradationManager::new();
manager.set_level(DegradationLevel::Severe);
assert!(!manager.is_feature_disabled(DegradationLevel::Critical));
assert!(manager.is_feature_disabled(DegradationLevel::Moderate));
assert!(manager.is_feature_disabled(DegradationLevel::Normal));
}
// A watched kernel that just heartbeated reports Healthy.
#[test]
fn test_kernel_watchdog() {
let watchdog = KernelWatchdog::new();
let kernel_id = KernelId::new("test_kernel");
watchdog.watch(kernel_id.clone());
assert_eq!(watchdog.watched_count(), 1);
watchdog.heartbeat(&kernel_id);
let health = watchdog.get_health(&kernel_id).unwrap();
assert_eq!(health.status, HealthStatus::Healthy);
}
// update_metrics stores throughput and queue depth verbatim.
#[test]
fn test_kernel_watchdog_metrics() {
let watchdog = KernelWatchdog::new();
let kernel_id = KernelId::new("test_kernel");
watchdog.watch(kernel_id.clone());
watchdog.update_metrics(&kernel_id, 1000.0, 50);
let health = watchdog.get_health(&kernel_id).unwrap();
assert_eq!(health.messages_per_sec, 1000.0);
assert_eq!(health.queue_depth, 50);
}
// Policies are strictly ordered: Notify < Restart < Migrate < Escalate.
#[test]
fn test_recovery_policy_severity() {
assert!(RecoveryPolicy::Notify.severity() < RecoveryPolicy::Restart.severity());
assert!(RecoveryPolicy::Restart.severity() < RecoveryPolicy::Migrate.severity());
assert!(RecoveryPolicy::Migrate.severity() < RecoveryPolicy::Escalate.severity());
}
// Only Notify and Escalate require a human in the loop; Restart and
// Migrate are fully automatic.
#[test]
fn test_recovery_policy_requires_intervention() {
assert!(RecoveryPolicy::Notify.requires_intervention());
assert!(RecoveryPolicy::Escalate.requires_intervention());
assert!(!RecoveryPolicy::Restart.requires_intervention());
assert!(!RecoveryPolicy::Migrate.requires_intervention());
}
// Default config: 3 restart attempts, checkpoint first, restart on
// timeout, migrate on device errors.
#[test]
fn test_recovery_config_default() {
let config = RecoveryConfig::default();
assert_eq!(config.max_restart_attempts, 3);
assert!(config.checkpoint_before_restart);
assert!(config.migrate_on_device_error);
assert_eq!(
config.policy_for(FailureType::Timeout),
RecoveryPolicy::Restart
);
assert_eq!(
config.policy_for(FailureType::DeviceError),
RecoveryPolicy::Migrate
);
}
// Conservative preset: a single attempt, timeouts only notify.
#[test]
fn test_recovery_config_conservative() {
let config = RecoveryConfig::conservative();
assert_eq!(config.max_restart_attempts, 1);
assert_eq!(
config.policy_for(FailureType::Timeout),
RecoveryPolicy::Notify
);
}
// Aggressive preset: more attempts, no checkpointing, short delay.
#[test]
fn test_recovery_config_aggressive() {
let config = RecoveryConfig::aggressive();
assert_eq!(config.max_restart_attempts, 5);
assert!(!config.checkpoint_before_restart);
assert_eq!(config.restart_delay, Duration::from_secs(1));
}
// The builder round-trips every field it sets, including per-failure
// policy overrides.
#[test]
fn test_recovery_config_builder() {
let config = RecoveryConfig::builder()
.max_restart_attempts(10)
.restart_delay(Duration::from_secs(2))
.checkpoint_before_restart(false)
.recovery_cooldown(Duration::from_secs(30))
.policy(FailureType::Crash, RecoveryPolicy::Migrate)
.build();
assert_eq!(config.max_restart_attempts, 10);
assert_eq!(config.restart_delay, Duration::from_secs(2));
assert!(!config.checkpoint_before_restart);
assert_eq!(config.recovery_cooldown, Duration::from_secs(30));
assert_eq!(
config.policy_for(FailureType::Crash),
RecoveryPolicy::Migrate
);
}
// Failure descriptions are stable, human-readable strings.
#[test]
fn test_failure_type_description() {
assert_eq!(
FailureType::Timeout.description(),
"Kernel heartbeat timeout"
);
assert_eq!(FailureType::Crash.description(), "Kernel crash");
assert_eq!(FailureType::DeviceError.description(), "GPU device error");
}
// Builder-style helpers attach context entries and an attempt number
// to a RecoveryAction.
#[test]
fn test_recovery_action() {
let kernel_id = KernelId::new("test_kernel");
let action = RecoveryAction::new(
kernel_id.clone(),
FailureType::Timeout,
RecoveryPolicy::Restart,
)
.with_context("reason", "heartbeat missed")
.with_attempt(2);
assert_eq!(action.kernel_id, kernel_id);
assert_eq!(action.failure_type, FailureType::Timeout);
assert_eq!(action.policy, RecoveryPolicy::Restart);
assert_eq!(action.attempt, 2);
assert_eq!(
action.context.get("reason"),
Some(&"heartbeat missed".to_string())
);
}
// success() carries no error and no follow-up; failure() records the
// error string and suggests Escalate as the next action.
#[test]
fn test_recovery_result() {
let action = RecoveryAction::new(
KernelId::new("test"),
FailureType::Crash,
RecoveryPolicy::Restart,
);
let success = RecoveryResult::success(action.clone(), Duration::from_millis(100));
assert!(success.success);
assert!(success.error.is_none());
assert!(success.next_action.is_none());
let failure = RecoveryResult::failure(
action.clone(),
"Failed".to_string(),
Duration::from_millis(50),
);
assert!(!failure.success);
assert_eq!(failure.error, Some("Failed".to_string()));
assert_eq!(failure.next_action, Some(RecoveryPolicy::Escalate));
}
// A fresh manager is enabled and starts with zeroed counters.
#[test]
fn test_recovery_manager_creation() {
let manager = RecoveryManager::new();
assert!(manager.is_enabled());
let stats = manager.stats();
assert_eq!(stats.attempts, 0);
assert_eq!(stats.successes, 0);
assert_eq!(stats.failures, 0);
}
// set_enabled toggles the manager on and off.
#[test]
fn test_recovery_manager_enable_disable() {
let manager = RecoveryManager::new();
assert!(manager.is_enabled());
manager.set_enabled(false);
assert!(!manager.is_enabled());
manager.set_enabled(true);
assert!(manager.is_enabled());
}
// The first timeout for a kernel maps to a Restart action, attempt 1.
#[test]
fn test_recovery_manager_determine_action() {
let manager = RecoveryManager::new();
let kernel_id = KernelId::new("test_kernel");
let action = manager.determine_action(&kernel_id, FailureType::Timeout);
assert_eq!(action.kernel_id, kernel_id);
assert_eq!(action.failure_type, FailureType::Timeout);
assert_eq!(action.policy, RecoveryPolicy::Restart);
assert_eq!(action.attempt, 1);
}
// Disabling the manager vetoes recovery regardless of cooldown state.
#[test]
fn test_recovery_manager_should_recover() {
let config = RecoveryConfig::builder()
.recovery_cooldown(Duration::from_millis(10))
.build();
let manager = RecoveryManager::with_config(config);
let kernel_id = KernelId::new("test_kernel");
assert!(manager.should_recover(&kernel_id));
manager.set_enabled(false);
assert!(!manager.should_recover(&kernel_id));
}
// A successful Notify recovery bumps the success counter and is
// recorded in the per-kernel history.
#[tokio::test]
async fn test_recovery_manager_recover() {
let manager = RecoveryManager::new();
let kernel_id = KernelId::new("test_kernel");
let action = RecoveryAction::new(
kernel_id.clone(),
FailureType::Timeout,
RecoveryPolicy::Notify,
);
let result = manager.recover(action).await;
assert!(result.success);
let stats = manager.stats();
assert_eq!(stats.successes, 1);
assert_eq!(stats.kernels_tracked, 1);
let history = manager.get_history(&kernel_id);
assert_eq!(history.len(), 1);
}
// success_rate is successes/attempts, defined as 1.0 when nothing
// was attempted.
#[test]
fn test_recovery_stats_snapshot_success_rate() {
let stats = RecoveryStatsSnapshot {
attempts: 10,
successes: 8,
failures: 2,
kernels_tracked: 3,
};
assert!((stats.success_rate() - 0.8).abs() < 0.001);
let empty = RecoveryStatsSnapshot::default();
assert_eq!(empty.success_rate(), 1.0);
}
// clear_history wipes both the result history and the attempt counters
// (white-box: seeds `attempts` directly via the private field).
#[test]
fn test_recovery_manager_clear_history() {
let manager = RecoveryManager::new();
let kernel_id = KernelId::new("test_kernel");
manager.attempts.write().insert(kernel_id.clone(), 5);
manager.clear_history();
assert!(manager.get_history(&kernel_id).is_empty());
assert!(manager.attempts.read().is_empty());
}
}