#![allow(dead_code)]
use crate::distributed_monitoring::{AlertSeverity, DistributedMonitor, NodeHealthStatus};
use crate::{TorshDistributedError, TorshResult};
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, VecDeque};
use std::sync::{Arc, Mutex, RwLock};
use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
use tracing::{debug, info, warn};
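/// Failure modes the fault-tolerance layer can detect and react to, ranging from
/// node- and communication-level problems to training-level divergence, plus an
/// open-ended `Custom` variant for caller-defined failures.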
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum FailureType {
NodeUnresponsive { node_id: String, last_seen: u64 },
CommunicationFailure {
source: String,
target: String,
error: String,
},
OutOfMemory {
node_id: String,
available_mb: u64,
requested_mb: u64,
},
GpuFailure {
node_id: String,
gpu_id: u32,
error: String,
},
TrainingDivergence { loss_value: f32, gradient_norm: f32 },
CheckpointFailure {
checkpoint_path: String,
error: String,
},
NetworkPartition { affected_nodes: Vec<String> },
StorageFailure { path: String, error: String },
ResourceExhaustion {
resource_type: String,
node_id: String,
},
Custom {
failure_name: String,
details: String,
},
}
impl std::fmt::Display for FailureType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
FailureType::NodeUnresponsive { node_id, .. } => {
write!(f, "Node {} is unresponsive", node_id)
}
FailureType::CommunicationFailure { source, target, .. } => {
write!(f, "Communication failure from {} to {}", source, target)
}
FailureType::OutOfMemory { node_id, .. } => {
write!(f, "Out of memory on node {}", node_id)
}
FailureType::GpuFailure {
node_id, gpu_id, ..
} => write!(f, "GPU {} failure on node {}", gpu_id, node_id),
FailureType::TrainingDivergence { loss_value, .. } => {
write!(f, "Training divergence detected (loss: {})", loss_value)
}
FailureType::CheckpointFailure {
checkpoint_path, ..
} => write!(f, "Checkpoint failure: {}", checkpoint_path),
FailureType::NetworkPartition { affected_nodes } => write!(
f,
"Network partition affecting {} nodes",
affected_nodes.len()
),
FailureType::StorageFailure { path, .. } => write!(f, "Storage failure: {}", path),
FailureType::ResourceExhaustion {
resource_type,
node_id,
} => write!(f, "{} exhaustion on node {}", resource_type, node_id),
FailureType::Custom { failure_name, .. } => {
write!(f, "Custom failure: {}", failure_name)
}
}
}
}
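/// Recovery actions the system can take in response to a [`FailureType`], from
/// node-local restarts up to cluster-wide scale-down and emergency stop.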
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum RecoveryStrategy {
RestartNode { node_id: String, max_attempts: u32 },
ExcludeNode { node_id: String },
LoadCheckpoint {
checkpoint_path: String,
rollback_steps: u32,
},
ReduceBatchSize {
new_batch_size: u32,
reduction_factor: f32,
},
RedistributeWork {
failed_nodes: Vec<String>,
target_nodes: Vec<String>,
},
ScaleDown {
new_world_size: u32,
keep_nodes: Vec<String>,
},
ResetTrainingState {
reset_optimizer: bool,
reset_lr_schedule: bool,
},
DegradedMode { disabled_features: Vec<String> },
EmergencyStop { save_checkpoint: bool },
Custom {
action_name: String,
parameters: HashMap<String, String>,
},
}
impl std::fmt::Display for RecoveryStrategy {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
RecoveryStrategy::RestartNode { node_id, .. } => write!(f, "Restart node {}", node_id),
RecoveryStrategy::ExcludeNode { node_id } => write!(f, "Exclude node {}", node_id),
RecoveryStrategy::LoadCheckpoint {
checkpoint_path, ..
} => write!(f, "Load checkpoint: {}", checkpoint_path),
RecoveryStrategy::ReduceBatchSize { new_batch_size, .. } => {
write!(f, "Reduce batch size to {}", new_batch_size)
}
RecoveryStrategy::RedistributeWork { failed_nodes, .. } => write!(
f,
"Redistribute work from {} failed nodes",
failed_nodes.len()
),
RecoveryStrategy::ScaleDown { new_world_size, .. } => {
write!(f, "Scale down to {} nodes", new_world_size)
}
RecoveryStrategy::ResetTrainingState { .. } => write!(f, "Reset training state"),
RecoveryStrategy::DegradedMode { disabled_features } => write!(
f,
"Enter degraded mode (disable {} features)",
disabled_features.len()
),
RecoveryStrategy::EmergencyStop { .. } => write!(f, "Emergency stop"),
RecoveryStrategy::Custom { action_name, .. } => {
write!(f, "Custom action: {}", action_name)
}
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum RecoveryStatus {
Pending,
InProgress { progress: f32, stage: String },
Completed { duration_ms: u64 },
Failed { error: String, retry_count: u32 },
Cancelled { reason: String },
}
impl std::fmt::Display for RecoveryStatus {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
RecoveryStatus::Pending => write!(f, "Pending"),
RecoveryStatus::InProgress { progress, stage } => {
write!(f, "In Progress ({:.1}%): {}", progress * 100.0, stage)
}
RecoveryStatus::Completed { duration_ms } => {
write!(f, "Completed in {}ms", duration_ms)
}
RecoveryStatus::Failed { error, retry_count } => {
write!(f, "Failed (attempt {}): {}", retry_count, error)
}
RecoveryStatus::Cancelled { reason } => write!(f, "Cancelled: {}", reason),
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FailureIncident {
pub id: String,
pub failure_type: FailureType,
pub detected_at: u64,
pub severity: AlertSeverity,
pub recovery_strategy: RecoveryStrategy,
pub recovery_status: RecoveryStatus,
pub recovery_started_at: Option<u64>,
pub recovery_completed_at: Option<u64>,
pub affected_nodes: Vec<String>,
pub context: HashMap<String, String>,
}
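/// Tunables for failure detection and recovery. The `Default` impl below uses a 30s
/// node timeout, 3 recovery attempts, automatic recovery and predictive detection
/// enabled, and 5-minute checkpoint and recovery windows.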
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FaultToleranceConfig {
pub node_timeout: Duration,
pub communication_timeout: Duration,
pub max_recovery_attempts: u32,
pub enable_automatic_recovery: bool,
pub enable_predictive_detection: bool,
pub checkpoint_interval: Duration,
pub max_incident_history: usize,
pub detection_sensitivity: f32,
pub recovery_timeout: Duration,
pub enable_health_checks: bool,
pub health_check_interval: Duration,
}
impl Default for FaultToleranceConfig {
fn default() -> Self {
Self {
node_timeout: Duration::from_secs(30),
communication_timeout: Duration::from_secs(10),
max_recovery_attempts: 3,
enable_automatic_recovery: true,
enable_predictive_detection: true,
checkpoint_interval: Duration::from_secs(300),
max_incident_history: 1000,
detection_sensitivity: 0.8,
recovery_timeout: Duration::from_secs(300),
enable_health_checks: true,
health_check_interval: Duration::from_secs(10),
}
}
}
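/// Central fault-tolerance coordinator: consumes monitoring data from a
/// [`DistributedMonitor`], detects failures, opens incidents, selects recovery
/// strategies, and tracks recovery progress and history.
///
/// A minimal usage sketch (not compiled as a doctest), mirroring the tests at the
/// bottom of this module:
///
/// ```ignore
/// let monitor = Arc::new(DistributedMonitor::new(MonitoringConfig::default(), false));
/// let ft = EnhancedFaultTolerance::new(FaultToleranceConfig::default(), monitor);
/// for failure in ft.detect_failures()? {
///     let incident_id = ft.handle_failure(failure)?;
///     tracing::info!("opened incident {incident_id}");
/// }
/// let status = ft.get_status()?;
/// assert!(status.system_health_score <= 1.0);
/// ```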
pub struct EnhancedFaultTolerance {
config: FaultToleranceConfig,
monitor: Arc<DistributedMonitor>,
active_incidents: Arc<RwLock<HashMap<String, FailureIncident>>>,
incident_history: Arc<Mutex<VecDeque<FailureIncident>>>,
node_status: Arc<RwLock<HashMap<String, NodeStatus>>>,
failure_predictor: Arc<Mutex<FailurePredictor>>,
recovery_executor: Arc<Mutex<RecoveryExecutor>>,
last_health_check: Arc<Mutex<Instant>>,
}
#[derive(Debug, Clone)]
struct NodeStatus {
last_heartbeat: Instant,
health_status: NodeHealthStatus,
consecutive_failures: u32,
is_excluded: bool,
last_metrics: Option<crate::distributed_monitoring::NodeMetrics>,
}
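/// Lightweight predictive-failure component: keeps one [`PredictionModel`] per node
/// and per metric (keyed `"{node_id}_cpu"`, `"{node_id}_memory"`, `"{node_id}_latency"`)
/// and averages their risks into a single per-node score.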
#[derive(Debug)]
struct FailurePredictor {
failure_patterns: HashMap<String, Vec<f32>>,
node_risk_scores: HashMap<String, f32>,
prediction_models: HashMap<String, PredictionModel>,
}
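/// Rolling window (last 100 samples) of a single metric, with a fitted trend slope
/// and variance used to estimate short-term failure risk.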
#[derive(Debug)]
struct PredictionModel {
historical_values: VecDeque<f32>,
trend_slope: f32,
variance: f32,
last_update: Instant,
}
impl PredictionModel {
fn new() -> Self {
Self {
historical_values: VecDeque::with_capacity(100),
trend_slope: 0.0,
variance: 0.0,
last_update: Instant::now(),
}
}
fn update(&mut self, value: f32) {
self.historical_values.push_back(value);
if self.historical_values.len() > 100 {
self.historical_values.pop_front();
}
if self.historical_values.len() >= 10 {
self.calculate_trend();
self.calculate_variance();
}
self.last_update = Instant::now();
}
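/// Fits an ordinary least-squares line of value against sample index:
/// `slope = (n * Σxy - Σx * Σy) / (n * Σx² - (Σx)²)`.
/// The denominator guard below skips the update for near-constant data.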
fn calculate_trend(&mut self) {
let values: Vec<f32> = self.historical_values.iter().cloned().collect();
if values.len() < 2 {
return;
}
let n = values.len() as f32;
let sum_x: f32 = (0..values.len()).map(|i| i as f32).sum();
let sum_y: f32 = values.iter().sum();
let sum_xy: f32 = values.iter().enumerate().map(|(i, &y)| i as f32 * y).sum();
let sum_x2: f32 = (0..values.len()).map(|i| (i as f32).powi(2)).sum();
let denominator = n * sum_x2 - sum_x.powi(2);
if denominator.abs() > 0.001 {
self.trend_slope = (n * sum_xy - sum_x * sum_y) / denominator;
}
}
fn calculate_variance(&mut self) {
if self.historical_values.len() < 2 {
return;
}
let mean: f32 =
self.historical_values.iter().sum::<f32>() / self.historical_values.len() as f32;
self.variance = self
.historical_values
.iter()
.map(|&x| (x - mean).powi(2))
.sum::<f32>()
/ self.historical_values.len() as f32;
}
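/// Heuristic risk in `[0, 1]`: a positive trend, high variance, and stale samples
/// (no update for more than 60s) each contribute a term, and the sum is capped at 1.0.
/// Returns 0.0 until at least 10 samples have been observed.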
fn predict_failure_risk(&self) -> f32 {
if self.historical_values.len() < 10 {
return 0.0;
}
let trend_risk = if self.trend_slope > 0.1 {
self.trend_slope * 2.0
} else {
0.0
};
let variance_risk = if self.variance > 1.0 {
(self.variance - 1.0) * 0.5
} else {
0.0
};
let staleness_risk = if self.last_update.elapsed().as_secs() > 60 {
0.3
} else {
0.0
};
(trend_risk + variance_risk + staleness_risk).min(1.0)
}
}
impl FailurePredictor {
fn new() -> Self {
Self {
failure_patterns: HashMap::new(),
node_risk_scores: HashMap::new(),
prediction_models: HashMap::new(),
}
}
fn update_node_metrics(
&mut self,
node_id: &str,
metrics: &crate::distributed_monitoring::NodeMetrics,
) {
{
let cpu_model = self
.prediction_models
.entry(format!("{}_cpu", node_id))
.or_insert_with(PredictionModel::new);
cpu_model.update(metrics.system_metrics.cpu_utilization);
}
{
let memory_model = self
.prediction_models
.entry(format!("{}_memory", node_id))
.or_insert_with(PredictionModel::new);
memory_model.update(metrics.system_metrics.memory_usage_mb as f32);
}
{
let latency_model = self
.prediction_models
.entry(format!("{}_latency", node_id))
.or_insert_with(PredictionModel::new);
latency_model.update(metrics.communication_metrics.avg_latency_us as f32);
}
let cpu_risk = self
.prediction_models
.get(&format!("{}_cpu", node_id))
.map_or(0.0, |model| model.predict_failure_risk());
let memory_risk = self
.prediction_models
.get(&format!("{}_memory", node_id))
.map_or(0.0, |model| model.predict_failure_risk());
let latency_risk = self
.prediction_models
.get(&format!("{}_latency", node_id))
.map_or(0.0, |model| model.predict_failure_risk());
let overall_risk = (cpu_risk + memory_risk + latency_risk) / 3.0;
self.node_risk_scores
.insert(node_id.to_string(), overall_risk);
}
fn get_node_risk_score(&self, node_id: &str) -> f32 {
self.node_risk_scores.get(node_id).copied().unwrap_or(0.0)
}
fn get_high_risk_nodes(&self, threshold: f32) -> Vec<String> {
self.node_risk_scores
.iter()
.filter(|(_, &risk)| risk > threshold)
.map(|(node_id, _)| node_id.clone())
.collect()
}
}
#[derive(Debug)]
struct RecoveryExecutor {
active_recoveries: HashMap<String, RecoveryExecution>,
}
#[derive(Debug)]
struct RecoveryExecution {
incident_id: String,
strategy: RecoveryStrategy,
start_time: Instant,
progress: f32,
current_stage: String,
retry_count: u32,
}
impl RecoveryExecutor {
fn new() -> Self {
Self {
active_recoveries: HashMap::new(),
}
}
fn start_recovery(
&mut self,
incident_id: String,
strategy: RecoveryStrategy,
) -> TorshResult<()> {
let execution = RecoveryExecution {
incident_id: incident_id.clone(),
strategy: strategy.clone(),
start_time: Instant::now(),
progress: 0.0,
current_stage: "Initializing recovery".to_string(),
retry_count: 0,
};
self.active_recoveries
.insert(incident_id.clone(), execution);
info!(
"Started recovery for incident {}: {}",
incident_id, strategy
);
Ok(())
}
fn update_recovery_progress(
&mut self,
incident_id: &str,
progress: f32,
stage: String,
) -> TorshResult<()> {
if let Some(execution) = self.active_recoveries.get_mut(incident_id) {
execution.progress = progress.clamp(0.0, 1.0);
execution.current_stage = stage;
debug!(
"Recovery {} progress: {:.1}% - {}",
incident_id,
progress * 100.0,
execution.current_stage
);
}
Ok(())
}
fn complete_recovery(
&mut self,
incident_id: &str,
success: bool,
error: Option<String>,
) -> TorshResult<RecoveryStatus> {
if let Some(execution) = self.active_recoveries.remove(incident_id) {
let duration_ms = execution.start_time.elapsed().as_millis() as u64;
if success {
info!(
"Recovery {} completed successfully in {}ms",
incident_id, duration_ms
);
Ok(RecoveryStatus::Completed { duration_ms })
} else {
let error_msg = error.unwrap_or_else(|| "Unknown error".to_string());
warn!("Recovery {} failed: {}", incident_id, error_msg);
Ok(RecoveryStatus::Failed {
error: error_msg,
retry_count: execution.retry_count,
})
}
} else {
Err(TorshDistributedError::communication_error(
"recovery_complete",
format!("Recovery {} not found", incident_id),
))
}
}
fn get_recovery_status(&self, incident_id: &str) -> Option<RecoveryStatus> {
self.active_recoveries
.get(incident_id)
.map(|execution| RecoveryStatus::InProgress {
progress: execution.progress,
stage: execution.current_stage.clone(),
})
}
}
impl EnhancedFaultTolerance {
pub fn new(config: FaultToleranceConfig, monitor: Arc<DistributedMonitor>) -> Self {
Self {
config: config.clone(),
monitor,
active_incidents: Arc::new(RwLock::new(HashMap::new())),
incident_history: Arc::new(Mutex::new(VecDeque::with_capacity(
config.max_incident_history,
))),
node_status: Arc::new(RwLock::new(HashMap::new())),
failure_predictor: Arc::new(Mutex::new(FailurePredictor::new())),
recovery_executor: Arc::new(Mutex::new(RecoveryExecutor::new())),
last_health_check: Arc::new(Mutex::new(Instant::now())),
}
}
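/// One detection pass: inspects current node metrics (responsiveness, divergence,
/// resource exhaustion), maps critical monitor alerts to failures, and, when
/// `enable_predictive_detection` is set, adds high-risk-node warnings from the predictor.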
pub fn detect_failures(&self) -> TorshResult<Vec<FailureType>> {
let mut detected_failures = Vec::new();
let alerts = self.monitor.get_active_alerts()?;
if let Some(current_metrics) = self.monitor.get_current_metrics()? {
self.check_node_responsiveness(&current_metrics, &mut detected_failures)?;
self.check_training_divergence(&current_metrics, &mut detected_failures)?;
self.check_resource_exhaustion(&current_metrics, &mut detected_failures)?;
}
for alert in alerts {
match alert.metric_name.as_str() {
"cpu_utilization" | "gpu_utilization"
if alert.severity >= AlertSeverity::Critical =>
{
detected_failures.push(FailureType::ResourceExhaustion {
resource_type: alert.metric_name.clone(),
node_id: alert.node_id.clone(),
});
}
"avg_latency_us" if alert.severity >= AlertSeverity::Critical => {
detected_failures.push(FailureType::CommunicationFailure {
source: alert.node_id.clone(),
target: "cluster".to_string(),
error: "High latency detected".to_string(),
});
}
_ => {}
}
}
if self.config.enable_predictive_detection {
self.detect_predictive_failures(&mut detected_failures)?;
}
Ok(detected_failures)
}
fn check_node_responsiveness(
&self,
metrics: &crate::distributed_monitoring::NodeMetrics,
failures: &mut Vec<FailureType>,
) -> TorshResult<()> {
let node_status = self.node_status.read().map_err(|e| {
TorshDistributedError::communication_error("node_status", format!("Lock error: {}", e))
})?;
if let Some(status) = node_status.get(&metrics.node_id) {
if status.last_heartbeat.elapsed() > self.config.node_timeout {
failures.push(FailureType::NodeUnresponsive {
node_id: metrics.node_id.clone(),
last_seen: SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("time should be after UNIX_EPOCH")
.as_millis() as u64,
});
}
}
Ok(())
}
fn check_training_divergence(
&self,
metrics: &crate::distributed_monitoring::NodeMetrics,
failures: &mut Vec<FailureType>,
) -> TorshResult<()> {
let loss = metrics.training_metrics.loss;
let gradient_norm = metrics.training_metrics.gradient_norm;
if !loss.is_finite() || !gradient_norm.is_finite() {
failures.push(FailureType::TrainingDivergence {
loss_value: loss,
gradient_norm,
});
} else if loss > 1000.0 || gradient_norm > 100.0 {
// Hard-coded divergence thresholds; finite but runaway values are reported once.
failures.push(FailureType::TrainingDivergence {
loss_value: loss,
gradient_norm,
});
}
Ok(())
}
fn check_resource_exhaustion(
&self,
metrics: &crate::distributed_monitoring::NodeMetrics,
failures: &mut Vec<FailureType>,
) -> TorshResult<()> {
// Hard-coded 32 GB system-memory budget; usage above it is treated as out-of-memory.
if metrics.system_metrics.memory_usage_mb > 32000 {
failures.push(FailureType::OutOfMemory {
node_id: metrics.node_id.clone(),
// Usage already exceeds the budget here, so saturate to 0 instead of underflowing.
available_mb: 32000u64.saturating_sub(metrics.system_metrics.memory_usage_mb),
requested_mb: metrics.system_metrics.memory_usage_mb,
});
}
// Hard-coded 20 GB GPU-memory budget; usage above it is reported as a GPU failure.
if metrics.system_metrics.gpu_memory_mb > 20000 {
failures.push(FailureType::GpuFailure {
node_id: metrics.node_id.clone(),
// The metrics do not identify the affected device, so GPU 0 is reported by default.
gpu_id: 0,
error: "GPU memory exhaustion".to_string(),
});
}
Ok(())
}
fn detect_predictive_failures(&self, failures: &mut Vec<FailureType>) -> TorshResult<()> {
let predictor = self.failure_predictor.lock().map_err(|e| {
TorshDistributedError::communication_error("predictor", format!("Lock error: {}", e))
})?;
let high_risk_nodes = predictor.get_high_risk_nodes(0.7);
for node_id in high_risk_nodes {
let risk_score = predictor.get_node_risk_score(&node_id);
warn!(
"Predictive failure detection: node {} has high risk score {:.2}",
node_id, risk_score
);
failures.push(FailureType::Custom {
failure_name: "Predictive Failure Risk".to_string(),
details: format!("Node {} has risk score {:.2}", node_id, risk_score),
});
}
Ok(())
}
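/// Opens a [`FailureIncident`] for the given failure, derives a recovery strategy and
/// severity, records it as active, and (when `enable_automatic_recovery` is set)
/// immediately executes the recovery. Returns the generated incident id.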
pub fn handle_failure(&self, failure: FailureType) -> TorshResult<String> {
let incident_id = format!(
"incident_{}_{}",
SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("system time should be after UNIX_EPOCH")
.as_millis(),
SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("system time should be after UNIX_EPOCH")
.as_nanos()
% 100000
);
let recovery_strategy = self.determine_recovery_strategy(&failure)?;
let incident = FailureIncident {
id: incident_id.clone(),
failure_type: failure.clone(),
detected_at: SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("system time should be after UNIX_EPOCH")
.as_millis() as u64,
severity: self.determine_failure_severity(&failure),
recovery_strategy: recovery_strategy.clone(),
recovery_status: RecoveryStatus::Pending,
recovery_started_at: None,
recovery_completed_at: None,
affected_nodes: self.get_affected_nodes(&failure),
context: HashMap::new(),
};
{
let mut active_incidents = self.active_incidents.write().map_err(|e| {
TorshDistributedError::communication_error(
"incidents",
format!("Lock error: {}", e),
)
})?;
active_incidents.insert(incident_id.clone(), incident);
}
info!(
"Handling failure: {} with strategy: {}",
failure, recovery_strategy
);
if self.config.enable_automatic_recovery {
self.execute_recovery(&incident_id, recovery_strategy)?;
}
Ok(incident_id)
}
fn determine_recovery_strategy(&self, failure: &FailureType) -> TorshResult<RecoveryStrategy> {
let strategy = match failure {
FailureType::NodeUnresponsive { node_id, .. } => RecoveryStrategy::RestartNode {
node_id: node_id.clone(),
max_attempts: 3,
},
FailureType::CommunicationFailure { source, .. } => RecoveryStrategy::ExcludeNode {
node_id: source.clone(),
},
FailureType::OutOfMemory { .. } => {
RecoveryStrategy::ReduceBatchSize {
new_batch_size: 16,
reduction_factor: 0.5,
}
}
FailureType::GpuFailure { node_id, .. } => RecoveryStrategy::ExcludeNode {
node_id: node_id.clone(),
},
FailureType::TrainingDivergence { .. } => RecoveryStrategy::LoadCheckpoint {
checkpoint_path: "latest_checkpoint".to_string(),
rollback_steps: 100,
},
FailureType::CheckpointFailure { .. } => RecoveryStrategy::EmergencyStop {
save_checkpoint: true,
},
FailureType::NetworkPartition { affected_nodes } => {
let healthy_nodes = self.get_healthy_nodes()?;
RecoveryStrategy::RedistributeWork {
failed_nodes: affected_nodes.clone(),
target_nodes: healthy_nodes,
}
}
FailureType::StorageFailure { .. } => RecoveryStrategy::DegradedMode {
disabled_features: vec!["checkpointing".to_string(), "logging".to_string()],
},
FailureType::ResourceExhaustion {
resource_type,
node_id,
} => match resource_type.as_str() {
"cpu_utilization" | "gpu_utilization" => RecoveryStrategy::ReduceBatchSize {
new_batch_size: 8,
reduction_factor: 0.75,
},
_ => RecoveryStrategy::ExcludeNode {
node_id: node_id.clone(),
},
},
FailureType::Custom { failure_name, .. } => {
if failure_name.contains("Predictive") {
RecoveryStrategy::Custom {
action_name: "PreventiveAction".to_string(),
parameters: [("monitoring".to_string(), "increased".to_string())]
.iter()
.cloned()
.collect(),
}
} else {
RecoveryStrategy::EmergencyStop {
save_checkpoint: true,
}
}
}
};
Ok(strategy)
}
fn determine_failure_severity(&self, failure: &FailureType) -> AlertSeverity {
match failure {
FailureType::NodeUnresponsive { .. } => AlertSeverity::Critical,
FailureType::CommunicationFailure { .. } => AlertSeverity::Warning,
FailureType::OutOfMemory { .. } => AlertSeverity::Critical,
FailureType::GpuFailure { .. } => AlertSeverity::Critical,
FailureType::TrainingDivergence { .. } => AlertSeverity::Emergency,
FailureType::CheckpointFailure { .. } => AlertSeverity::Warning,
FailureType::NetworkPartition { .. } => AlertSeverity::Critical,
FailureType::StorageFailure { .. } => AlertSeverity::Warning,
FailureType::ResourceExhaustion { .. } => AlertSeverity::Warning,
FailureType::Custom { .. } => AlertSeverity::Info,
}
}
fn get_affected_nodes(&self, failure: &FailureType) -> Vec<String> {
match failure {
FailureType::NodeUnresponsive { node_id, .. } => vec![node_id.clone()],
FailureType::CommunicationFailure { source, target, .. } => {
if target == "cluster" {
vec![source.clone()]
} else {
vec![source.clone(), target.clone()]
}
}
FailureType::OutOfMemory { node_id, .. } => vec![node_id.clone()],
FailureType::GpuFailure { node_id, .. } => vec![node_id.clone()],
FailureType::TrainingDivergence { .. } => vec![],
FailureType::CheckpointFailure { .. } => vec![],
FailureType::NetworkPartition { affected_nodes } => affected_nodes.clone(),
FailureType::StorageFailure { .. } => vec![],
FailureType::ResourceExhaustion { node_id, .. } => vec![node_id.clone()],
FailureType::Custom { .. } => vec![],
}
}
fn get_healthy_nodes(&self) -> TorshResult<Vec<String>> {
let node_status = self.node_status.read().map_err(|e| {
TorshDistributedError::communication_error(
"healthy_nodes",
format!("Lock error: {}", e),
)
})?;
let healthy_nodes = node_status
.iter()
.filter(|(_, status)| {
matches!(status.health_status, NodeHealthStatus::Healthy) && !status.is_excluded
})
.map(|(node_id, _)| node_id.clone())
.collect();
Ok(healthy_nodes)
}
pub fn execute_recovery(
&self,
incident_id: &str,
strategy: RecoveryStrategy,
) -> TorshResult<()> {
let mut executor = self.recovery_executor.lock().map_err(|e| {
TorshDistributedError::communication_error("executor", format!("Lock error: {}", e))
})?;
executor.start_recovery(incident_id.to_string(), strategy.clone())?;
{
let mut active_incidents = self.active_incidents.write().map_err(|e| {
TorshDistributedError::communication_error(
"incidents",
format!("Lock error: {}", e),
)
})?;
if let Some(incident) = active_incidents.get_mut(incident_id) {
incident.recovery_status = RecoveryStatus::InProgress {
progress: 0.0,
stage: "Starting recovery".to_string(),
};
incident.recovery_started_at = Some(
SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("time should be after UNIX_EPOCH")
.as_millis() as u64,
);
}
}
self.simulate_recovery_execution(&mut executor, incident_id, strategy)?;
Ok(())
}
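/// Simulated recovery execution: walks strategy-specific stages with short sleeps,
/// derives a pseudo-random success/failure outcome from the current time, and files
/// the finished incident into history. A real deployment would replace this with
/// calls into the actual orchestration layer.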
fn simulate_recovery_execution(
&self,
executor: &mut RecoveryExecutor,
incident_id: &str,
strategy: RecoveryStrategy,
) -> TorshResult<()> {
info!("Executing recovery strategy: {}", strategy);
let stages = match strategy {
RecoveryStrategy::RestartNode { .. } => {
vec![
(0.2, "Stopping node services".to_string()),
(0.5, "Restarting node".to_string()),
(0.8, "Reinitializing training".to_string()),
(1.0, "Recovery complete".to_string()),
]
}
RecoveryStrategy::LoadCheckpoint { .. } => {
vec![
(0.3, "Loading checkpoint".to_string()),
(0.6, "Restoring model state".to_string()),
(0.9, "Synchronizing nodes".to_string()),
(1.0, "Recovery complete".to_string()),
]
}
RecoveryStrategy::ReduceBatchSize { .. } => {
vec![
(0.5, "Reducing batch size".to_string()),
(0.8, "Redistributing work".to_string()),
(1.0, "Recovery complete".to_string()),
]
}
_ => {
vec![
(0.5, "Executing recovery".to_string()),
(1.0, "Recovery complete".to_string()),
]
}
};
for (progress, stage) in stages {
executor.update_recovery_progress(incident_id, progress, stage)?;
std::thread::sleep(Duration::from_millis(100));
}
let success = (SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("system time should be after UNIX_EPOCH")
.as_nanos()
% 10)
!= 0;
let recovery_status = executor.complete_recovery(
incident_id,
success,
if success {
None
} else {
Some("Simulated failure".to_string())
},
)?;
{
let mut active_incidents = self.active_incidents.write().map_err(|e| {
TorshDistributedError::communication_error(
"incidents",
format!("Lock error: {}", e),
)
})?;
if let Some(incident) = active_incidents.get_mut(incident_id) {
incident.recovery_status = recovery_status;
incident.recovery_completed_at = Some(
SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("time should be after UNIX_EPOCH")
.as_millis() as u64,
);
if matches!(
incident.recovery_status,
RecoveryStatus::Completed { .. } | RecoveryStatus::Failed { .. }
) {
let completed_incident = incident.clone();
drop(active_incidents);
let mut history = self.incident_history.lock().map_err(|e| {
TorshDistributedError::communication_error(
"history",
format!("Lock error: {}", e),
)
})?;
history.push_back(completed_incident);
if history.len() > self.config.max_incident_history {
history.pop_front();
}
}
}
}
Ok(())
}
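/// Records a heartbeat and the latest metrics for a node, resetting its
/// consecutive-failure count when the node reports healthy, and feeds the metrics
/// into the failure predictor when predictive detection is enabled.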
pub fn update_node_heartbeat(
&self,
node_id: String,
metrics: crate::distributed_monitoring::NodeMetrics,
) -> TorshResult<()> {
{
let mut node_status = self.node_status.write().map_err(|e| {
TorshDistributedError::communication_error(
"node_heartbeat",
format!("Lock error: {}", e),
)
})?;
let status = node_status
.entry(node_id.clone())
.or_insert_with(|| NodeStatus {
last_heartbeat: Instant::now(),
health_status: NodeHealthStatus::Healthy,
consecutive_failures: 0,
is_excluded: false,
last_metrics: None,
});
status.last_heartbeat = Instant::now();
status.health_status = metrics.health_status.clone();
status.last_metrics = Some(metrics.clone());
if matches!(metrics.health_status, NodeHealthStatus::Healthy) {
status.consecutive_failures = 0;
}
}
if self.config.enable_predictive_detection {
let mut predictor = self.failure_predictor.lock().map_err(|e| {
TorshDistributedError::communication_error(
"predictor_update",
format!("Lock error: {}", e),
)
})?;
predictor.update_node_metrics(&node_id, &metrics);
}
Ok(())
}
pub fn get_status(&self) -> TorshResult<FaultToleranceStatus> {
let active_incidents = self.active_incidents.read().map_err(|e| {
TorshDistributedError::communication_error("status", format!("Lock error: {}", e))
})?;
let node_status = self.node_status.read().map_err(|e| {
TorshDistributedError::communication_error("status", format!("Lock error: {}", e))
})?;
let total_nodes = node_status.len();
let healthy_nodes = node_status
.values()
.filter(|s| matches!(s.health_status, NodeHealthStatus::Healthy))
.count();
let excluded_nodes = node_status.values().filter(|s| s.is_excluded).count();
let active_incident_count = active_incidents.len();
let recovering_incidents = active_incidents
.values()
.filter(|i| matches!(i.recovery_status, RecoveryStatus::InProgress { .. }))
.count();
Ok(FaultToleranceStatus {
total_nodes,
healthy_nodes,
excluded_nodes,
active_incidents: active_incident_count,
recovering_incidents,
system_health_score: self.calculate_system_health_score(&node_status)?,
last_incident_time: active_incidents.values().map(|i| i.detected_at).max(),
timestamp_ms: SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("system time should be after UNIX_EPOCH")
.as_millis() as u64,
})
}
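/// Weighted average of node health (Healthy = 1.0, Degraded = 0.7, Critical = 0.3,
/// Failed = 0.0, Recovering = 0.5 + 0.3 * progress), clamped to `[0, 1]`.
/// An empty cluster scores 1.0.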
fn calculate_system_health_score(
&self,
node_status: &HashMap<String, NodeStatus>,
) -> TorshResult<f32> {
if node_status.is_empty() {
return Ok(1.0);
}
let total_nodes = node_status.len() as f32;
let healthy_weight = node_status
.values()
.map(|status| match status.health_status {
NodeHealthStatus::Healthy => 1.0,
NodeHealthStatus::Degraded { .. } => 0.7,
NodeHealthStatus::Critical { .. } => 0.3,
NodeHealthStatus::Failed { .. } => 0.0,
NodeHealthStatus::Recovering { progress } => 0.5 + progress * 0.3,
})
.sum::<f32>();
Ok((healthy_weight / total_nodes).clamp(0.0, 1.0))
}
pub fn get_incident_history(&self) -> TorshResult<Vec<FailureIncident>> {
let history = self.incident_history.lock().map_err(|e| {
TorshDistributedError::communication_error(
"incident_history",
format!("Lock error: {}", e),
)
})?;
Ok(history.iter().cloned().collect())
}
pub fn export_fault_tolerance_data(&self) -> TorshResult<FaultToleranceExport> {
let status = self.get_status()?;
let incident_history = self.get_incident_history()?;
let active_incidents = self.active_incidents.read().map_err(|e| {
TorshDistributedError::communication_error("export", format!("Lock error: {}", e))
})?;
let current_incidents: Vec<FailureIncident> = active_incidents.values().cloned().collect();
Ok(FaultToleranceExport {
status,
current_incidents,
incident_history,
config: self.config.clone(),
export_timestamp_ms: SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("system time should be after UNIX_EPOCH")
.as_millis() as u64,
})
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FaultToleranceStatus {
pub total_nodes: usize,
pub healthy_nodes: usize,
pub excluded_nodes: usize,
pub active_incidents: usize,
pub recovering_incidents: usize,
pub system_health_score: f32,
pub last_incident_time: Option<u64>,
pub timestamp_ms: u64,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FaultToleranceExport {
pub status: FaultToleranceStatus,
pub current_incidents: Vec<FailureIncident>,
pub incident_history: Vec<FailureIncident>,
pub config: FaultToleranceConfig,
pub export_timestamp_ms: u64,
}
#[cfg(test)]
mod tests {
use super::*;
use crate::distributed_monitoring::{DistributedMonitor, MonitoringConfig};
#[tokio::test]
async fn test_enhanced_fault_tolerance_creation() -> TorshResult<()> {
let monitor_config = MonitoringConfig::default();
let monitor = Arc::new(DistributedMonitor::new(monitor_config, false));
let ft_config = FaultToleranceConfig::default();
let fault_tolerance = EnhancedFaultTolerance::new(ft_config, monitor);
let status = fault_tolerance.get_status()?;
assert_eq!(status.total_nodes, 0);
assert_eq!(status.active_incidents, 0);
Ok(())
}
#[tokio::test]
async fn test_failure_detection() -> TorshResult<()> {
let monitor_config = MonitoringConfig::default();
let monitor = Arc::new(DistributedMonitor::new(monitor_config, false));
let ft_config = FaultToleranceConfig::default();
let fault_tolerance = EnhancedFaultTolerance::new(ft_config, monitor);
let failures = fault_tolerance.detect_failures()?;
assert!(failures.is_empty());
Ok(())
}
#[tokio::test]
async fn test_failure_handling() -> TorshResult<()> {
let monitor_config = MonitoringConfig::default();
let monitor = Arc::new(DistributedMonitor::new(monitor_config, false));
let ft_config = FaultToleranceConfig::default();
let fault_tolerance = EnhancedFaultTolerance::new(ft_config, monitor);
let failure = FailureType::NodeUnresponsive {
node_id: "test_node".to_string(),
last_seen: SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_millis() as u64,
};
let incident_id = fault_tolerance.handle_failure(failure)?;
assert!(!incident_id.is_empty());
let status = fault_tolerance.get_status()?;
assert_eq!(status.active_incidents, 1);
Ok(())
}
#[tokio::test]
async fn test_prediction_model() -> TorshResult<()> {
let mut model = PredictionModel::new();
for i in 0..20 {
model.update(50.0 + (i as f32 % 5.0));
}
let normal_risk = model.predict_failure_risk();
assert!(
(0.0..=1.0).contains(&normal_risk),
"Risk should be normalized"
);
for i in 20..40 {
model.update(60.0 + i as f32);
}
let high_risk = model.predict_failure_risk();
assert!(
(0.0..=1.0).contains(&high_risk),
"Risk should be normalized"
);
Ok(())
}
#[tokio::test]
async fn test_recovery_strategy_determination() -> TorshResult<()> {
let monitor_config = MonitoringConfig::default();
let monitor = Arc::new(DistributedMonitor::new(monitor_config, false));
let ft_config = FaultToleranceConfig::default();
let fault_tolerance = EnhancedFaultTolerance::new(ft_config, monitor);
let failure = FailureType::OutOfMemory {
node_id: "test_node".to_string(),
available_mb: 1000,
requested_mb: 8000,
};
let strategy = fault_tolerance.determine_recovery_strategy(&failure)?;
match strategy {
RecoveryStrategy::ReduceBatchSize { .. } => {}
_ => panic!("Unexpected recovery strategy for OOM failure"),
}
Ok(())
}
}