#![allow(dead_code)]
use crate::{TorshDistributedError, TorshResult};
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, VecDeque};
use std::sync::{Arc, Mutex, RwLock};
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use tracing::{info, warn};
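/// Point-in-time snapshot of host and accelerator utilization for one node.
///
/// Values are currently produced by `DistributedMonitor::collect_system_metrics`,
/// which simulates readings; wire in real probes before relying on them.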
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SystemMetrics {
pub cpu_utilization: f32,
pub memory_usage_mb: u64,
pub gpu_utilization: f32,
pub gpu_memory_mb: u64,
pub network_bandwidth_mbps: f32,
pub disk_io_mbps: f32,
pub temperature_celsius: f32,
pub power_watts: f32,
pub timestamp_ms: u64,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingMetrics {
pub epoch: u32,
pub batch: u32,
pub loss: f32,
pub learning_rate: f32,
pub gradient_norm: f32,
pub throughput_samples_per_sec: f32,
pub batch_time_ms: u64,
pub batch_memory_mb: u64,
pub timestamp_ms: u64,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CommunicationMetrics {
pub allreduce_ops_per_sec: f32,
pub allgather_ops_per_sec: f32,
pub broadcast_ops_per_sec: f32,
pub p2p_ops_per_sec: f32,
pub avg_latency_us: u64,
pub comm_bandwidth_mbps: f32,
pub failed_ops_count: u32,
pub efficiency_score: f32,
pub timestamp_ms: u64,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum NodeHealthStatus {
Healthy,
Degraded { reason: String },
Critical { reason: String },
Failed { reason: String },
Recovering { progress: f32 },
}
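/// Caller-supplied training state consumed by
/// `DistributedMonitor::update_node_metrics`; system and communication
/// metrics are collected internally.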
#[derive(Debug, Clone)]
pub struct NodeMetricsUpdate {
pub node_id: String,
pub rank: u32,
pub world_size: u32,
pub training_loss: f32,
pub learning_rate: f32,
pub epoch: u32,
pub batch: u32,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodeMetrics {
pub node_id: String,
pub rank: u32,
pub world_size: u32,
pub system_metrics: SystemMetrics,
pub training_metrics: TrainingMetrics,
pub communication_metrics: CommunicationMetrics,
pub health_status: NodeHealthStatus,
pub custom_metrics: HashMap<String, f64>,
}
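/// Alert severity, ordered from least to most severe. The derived `Ord`
/// relies on the declaration order of the variants.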
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord)]
pub enum AlertSeverity {
Info,
Warning,
Critical,
Emergency,
}
impl std::fmt::Display for AlertSeverity {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
AlertSeverity::Info => write!(f, "INFO"),
AlertSeverity::Warning => write!(f, "WARNING"),
AlertSeverity::Critical => write!(f, "CRITICAL"),
AlertSeverity::Emergency => write!(f, "EMERGENCY"),
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Alert {
pub id: String,
pub severity: AlertSeverity,
pub message: String,
pub node_id: String,
pub metric_name: String,
pub current_value: f64,
pub threshold_value: f64,
pub timestamp_ms: u64,
pub is_active: bool,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MonitoringConfig {
pub collection_interval: Duration,
pub history_buffer_size: usize,
pub enable_gpu_monitoring: bool,
pub enable_comm_analysis: bool,
pub alert_thresholds: AlertThresholds,
pub max_alerts: usize,
pub enable_anomaly_detection: bool,
pub anomaly_sensitivity: f32,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AlertThresholds {
pub cpu_warning_pct: f32,
pub cpu_critical_pct: f32,
pub memory_warning_pct: f32,
pub memory_critical_pct: f32,
pub gpu_warning_pct: f32,
pub gpu_critical_pct: f32,
pub latency_warning_us: u64,
pub latency_critical_us: u64,
pub throughput_degradation_warning_pct: f32,
pub throughput_degradation_critical_pct: f32,
}
impl Default for MonitoringConfig {
fn default() -> Self {
Self {
collection_interval: Duration::from_secs(5),
history_buffer_size: 1000,
enable_gpu_monitoring: true,
enable_comm_analysis: true,
alert_thresholds: AlertThresholds::default(),
max_alerts: 10000,
enable_anomaly_detection: true,
anomaly_sensitivity: 0.7,
}
}
}
impl Default for AlertThresholds {
fn default() -> Self {
Self {
cpu_warning_pct: 80.0,
cpu_critical_pct: 95.0,
memory_warning_pct: 80.0,
memory_critical_pct: 95.0,
gpu_warning_pct: 85.0,
gpu_critical_pct: 98.0,
            latency_warning_us: 10000,
            latency_critical_us: 50000,
            throughput_degradation_warning_pct: 20.0,
throughput_degradation_critical_pct: 50.0,
}
}
}
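/// Per-node metrics collector with alerting and anomaly detection. Nodes
/// created with `is_coordinator = true` additionally aggregate metrics from
/// every reporting node and can serve [`ClusterSummary`] snapshots.
///
/// A minimal usage sketch (assuming this module's types are in scope; adjust
/// paths to wherever the crate exposes them):
///
/// ```ignore
/// let monitor = DistributedMonitor::new(MonitoringConfig::default(), true);
/// monitor.update_node_metrics(NodeMetricsUpdate {
///     node_id: "node-0".to_string(),
///     rank: 0,
///     world_size: 1,
///     training_loss: 0.42,
///     learning_rate: 1e-3,
///     epoch: 0,
///     batch: 0,
/// })?;
/// let summary = monitor.get_cluster_summary()?;
/// println!("healthy: {}/{}", summary.healthy_nodes, summary.total_nodes);
/// ```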
pub struct DistributedMonitor {
config: MonitoringConfig,
current_metrics: Arc<RwLock<Option<NodeMetrics>>>,
metrics_history: Arc<Mutex<VecDeque<NodeMetrics>>>,
all_nodes_metrics: Arc<RwLock<HashMap<String, NodeMetrics>>>,
active_alerts: Arc<Mutex<Vec<Alert>>>,
alert_history: Arc<Mutex<VecDeque<Alert>>>,
performance_baselines: Arc<RwLock<HashMap<String, f64>>>,
anomaly_detector: Arc<Mutex<AnomalyDetector>>,
is_coordinator: bool,
}
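/// Streaming anomaly detector based on exponentially weighted moving
/// averages (EWMA): for each metric it tracks `avg ← α·x + (1 − α)·avg` plus
/// an EWMA of the absolute deviation as a standard-deviation proxy, and
/// flags values whose z-score exceeds `threshold_multiplier`. The multiplier
/// is derived from the configured sensitivity as `max(2.0 − sensitivity, 1.0)`,
/// so higher sensitivity flags smaller deviations.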
#[derive(Debug)]
struct AnomalyDetector {
moving_averages: HashMap<String, f64>,
standard_deviations: HashMap<String, f64>,
sample_counts: HashMap<String, usize>,
threshold_multiplier: f64,
}
impl AnomalyDetector {
fn new(sensitivity: f32) -> Self {
Self {
moving_averages: HashMap::new(),
standard_deviations: HashMap::new(),
sample_counts: HashMap::new(),
            threshold_multiplier: (2.0 - sensitivity as f64).max(1.0),
        }
}
fn update_metric(&mut self, metric_name: &str, value: f64) {
let avg = self
.moving_averages
.entry(metric_name.to_string())
.or_insert(value);
let count = self
.sample_counts
.entry(metric_name.to_string())
.or_insert(0);
        // Exponentially weighted moving average; alpha is the smoothing factor.
        let alpha = 0.1;
        *avg = alpha * value + (1.0 - alpha) * *avg;
*count += 1;
if *count > 1 {
            // EWMA of the absolute deviation doubles as a lightweight
            // standard-deviation estimate (sqrt of the squared deviation).
            let abs_deviation = (value - *avg).abs();
            let std_dev = self
                .standard_deviations
                .entry(metric_name.to_string())
                .or_insert(0.0);
            *std_dev = alpha * abs_deviation + (1.0 - alpha) * *std_dev;
}
}
fn is_anomaly(&self, metric_name: &str, value: f64) -> bool {
if let (Some(&avg), Some(&std_dev)) = (
self.moving_averages.get(metric_name),
self.standard_deviations.get(metric_name),
) {
            // Clamp the deviation away from zero to avoid dividing by ~0,
            // then flag values whose z-score exceeds the multiplier.
            let z_score = (value - avg).abs() / std_dev.max(0.01);
            z_score > self.threshold_multiplier
        } else {
            // Not enough samples recorded yet to judge.
            false
        }
}
}
impl DistributedMonitor {
pub fn new(config: MonitoringConfig, is_coordinator: bool) -> Self {
let anomaly_detector = AnomalyDetector::new(config.anomaly_sensitivity);
Self {
config: config.clone(),
current_metrics: Arc::new(RwLock::new(None)),
metrics_history: Arc::new(Mutex::new(VecDeque::with_capacity(
config.history_buffer_size,
))),
all_nodes_metrics: Arc::new(RwLock::new(HashMap::new())),
active_alerts: Arc::new(Mutex::new(Vec::new())),
alert_history: Arc::new(Mutex::new(VecDeque::with_capacity(config.max_alerts))),
performance_baselines: Arc::new(RwLock::new(HashMap::new())),
anomaly_detector: Arc::new(Mutex::new(anomaly_detector)),
is_coordinator,
}
}
pub fn collect_system_metrics(&self) -> TorshResult<SystemMetrics> {
let timestamp_ms = SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("time should be after UNIX_EPOCH")
.as_millis() as u64;
        // Simulated readings derived from a slow sinusoid; swap in real
        // system probes (e.g. procfs/NVML) for production use.
        let base_time = timestamp_ms % 100000;
        let variation = (base_time as f32 / 1000.0).sin();
        Ok(SystemMetrics {
            cpu_utilization: 45.0 + variation * 20.0,
            memory_usage_mb: (8000.0 + variation * 2000.0) as u64,
            gpu_utilization: 80.0 + variation * 15.0,
            gpu_memory_mb: (16000.0 + variation * 4000.0) as u64,
            network_bandwidth_mbps: 1000.0 + variation * 500.0,
            disk_io_mbps: 200.0 + variation * 100.0,
            temperature_celsius: 65.0 + variation * 10.0,
            power_watts: 250.0 + variation * 50.0,
            timestamp_ms,
        })
}
pub fn collect_training_metrics(
&self,
current_loss: f32,
current_lr: f32,
epoch: u32,
batch: u32,
) -> TorshResult<TrainingMetrics> {
let timestamp_ms = SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("time should be after UNIX_EPOCH")
.as_millis() as u64;
        // Placeholder derivations: gradient norm, throughput, and batch cost
        // are estimated from the loss rather than measured.
        let gradient_norm = current_loss * 0.1 + 0.5;
        let throughput = 1000.0 / (current_loss + 0.1);
        let batch_time_ms = (1000.0 / throughput * 32.0) as u64;
        let batch_memory_mb = 2000 + (batch_time_ms / 10);
Ok(TrainingMetrics {
epoch,
batch,
loss: current_loss,
learning_rate: current_lr,
gradient_norm,
throughput_samples_per_sec: throughput,
batch_time_ms,
batch_memory_mb,
timestamp_ms,
})
}
pub fn collect_communication_metrics(&self) -> TorshResult<CommunicationMetrics> {
let timestamp_ms = SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("time should be after UNIX_EPOCH")
.as_millis() as u64;
        // Placeholder constants standing in for measured backend statistics.
        let base_ops = 10.0;
        let network_quality = 0.8;
Ok(CommunicationMetrics {
allreduce_ops_per_sec: base_ops * network_quality,
allgather_ops_per_sec: base_ops * 0.5 * network_quality,
broadcast_ops_per_sec: base_ops * 0.3 * network_quality,
p2p_ops_per_sec: base_ops * 0.2 * network_quality,
avg_latency_us: ((1.0 - network_quality) * 20000.0 + 1000.0) as u64,
comm_bandwidth_mbps: 800.0 * network_quality,
failed_ops_count: if network_quality < 0.9 { 1 } else { 0 },
efficiency_score: network_quality,
timestamp_ms,
})
}
pub fn update_node_metrics(&self, params: NodeMetricsUpdate) -> TorshResult<()> {
let NodeMetricsUpdate {
node_id,
rank,
world_size,
training_loss,
learning_rate,
epoch,
batch,
} = params;
let system_metrics = self.collect_system_metrics()?;
let training_metrics =
self.collect_training_metrics(training_loss, learning_rate, epoch, batch)?;
let communication_metrics = self.collect_communication_metrics()?;
let health_status =
self.assess_node_health(&system_metrics, &training_metrics, &communication_metrics)?;
let node_metrics = NodeMetrics {
node_id: node_id.clone(),
rank,
world_size,
system_metrics,
training_metrics,
communication_metrics,
health_status,
custom_metrics: HashMap::new(),
};
{
let mut current = self.current_metrics.write().map_err(|e| {
TorshDistributedError::communication_error(
"metrics_update",
format!("Lock error: {}", e),
)
})?;
*current = Some(node_metrics.clone());
}
{
let mut history = self.metrics_history.lock().map_err(|e| {
TorshDistributedError::communication_error(
"metrics_history",
format!("Lock error: {}", e),
)
})?;
history.push_back(node_metrics.clone());
if history.len() > self.config.history_buffer_size {
history.pop_front();
}
}
if self.is_coordinator {
let mut all_nodes = self.all_nodes_metrics.write().map_err(|e| {
TorshDistributedError::communication_error(
"all_nodes_update",
format!("Lock error: {}", e),
)
})?;
all_nodes.insert(node_id.clone(), node_metrics.clone());
}
self.check_and_generate_alerts(&node_metrics)?;
if self.config.enable_anomaly_detection {
self.update_anomaly_detection(&node_metrics)?;
}
info!(
"Updated metrics for node {} (rank {}): health={:?}",
node_id, rank, node_metrics.health_status
);
Ok(())
}
fn assess_node_health(
&self,
system: &SystemMetrics,
_training: &TrainingMetrics,
comm: &CommunicationMetrics,
) -> TorshResult<NodeHealthStatus> {
let thresholds = &self.config.alert_thresholds;
if system.cpu_utilization > thresholds.cpu_critical_pct {
return Ok(NodeHealthStatus::Critical {
reason: format!("CPU utilization at {:.1}%", system.cpu_utilization),
});
}
if system.gpu_utilization > thresholds.gpu_critical_pct {
return Ok(NodeHealthStatus::Critical {
reason: format!("GPU utilization at {:.1}%", system.gpu_utilization),
});
}
if comm.avg_latency_us > thresholds.latency_critical_us {
return Ok(NodeHealthStatus::Critical {
reason: format!("Communication latency at {}μs", comm.avg_latency_us),
});
}
if system.cpu_utilization > thresholds.cpu_warning_pct
|| system.gpu_utilization > thresholds.gpu_warning_pct
|| comm.avg_latency_us > thresholds.latency_warning_us
{
return Ok(NodeHealthStatus::Degraded {
reason: "Performance metrics above warning thresholds".to_string(),
});
}
if comm.efficiency_score < 0.7 {
return Ok(NodeHealthStatus::Degraded {
reason: format!("Communication efficiency at {:.2}", comm.efficiency_score),
});
}
Ok(NodeHealthStatus::Healthy)
}
fn check_and_generate_alerts(&self, metrics: &NodeMetrics) -> TorshResult<()> {
let thresholds = &self.config.alert_thresholds;
let timestamp_ms = SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("time should be after UNIX_EPOCH")
.as_millis() as u64;
let mut new_alerts = Vec::new();
if metrics.system_metrics.cpu_utilization > thresholds.cpu_critical_pct {
new_alerts.push(Alert {
id: format!("cpu_critical_{}_{}", metrics.node_id, timestamp_ms),
severity: AlertSeverity::Critical,
message: format!(
"CPU utilization critically high on node {}",
metrics.node_id
),
node_id: metrics.node_id.clone(),
metric_name: "cpu_utilization".to_string(),
current_value: metrics.system_metrics.cpu_utilization as f64,
threshold_value: thresholds.cpu_critical_pct as f64,
timestamp_ms,
is_active: true,
});
} else if metrics.system_metrics.cpu_utilization > thresholds.cpu_warning_pct {
new_alerts.push(Alert {
id: format!("cpu_warning_{}_{}", metrics.node_id, timestamp_ms),
severity: AlertSeverity::Warning,
message: format!("CPU utilization high on node {}", metrics.node_id),
node_id: metrics.node_id.clone(),
metric_name: "cpu_utilization".to_string(),
current_value: metrics.system_metrics.cpu_utilization as f64,
threshold_value: thresholds.cpu_warning_pct as f64,
timestamp_ms,
is_active: true,
});
}
if metrics.system_metrics.gpu_utilization > thresholds.gpu_critical_pct {
new_alerts.push(Alert {
id: format!("gpu_critical_{}_{}", metrics.node_id, timestamp_ms),
severity: AlertSeverity::Critical,
message: format!(
"GPU utilization critically high on node {}",
metrics.node_id
),
node_id: metrics.node_id.clone(),
metric_name: "gpu_utilization".to_string(),
current_value: metrics.system_metrics.gpu_utilization as f64,
threshold_value: thresholds.gpu_critical_pct as f64,
timestamp_ms,
is_active: true,
});
}
if metrics.communication_metrics.avg_latency_us > thresholds.latency_critical_us {
new_alerts.push(Alert {
id: format!("latency_critical_{}_{}", metrics.node_id, timestamp_ms),
severity: AlertSeverity::Critical,
message: format!(
"Communication latency critically high on node {}",
metrics.node_id
),
node_id: metrics.node_id.clone(),
metric_name: "avg_latency_us".to_string(),
current_value: metrics.communication_metrics.avg_latency_us as f64,
threshold_value: thresholds.latency_critical_us as f64,
timestamp_ms,
is_active: true,
});
}
if !new_alerts.is_empty() {
let mut active_alerts = self.active_alerts.lock().map_err(|e| {
TorshDistributedError::communication_error(
"alerts_update",
format!("Lock error: {}", e),
)
})?;
for alert in &new_alerts {
warn!("Generated alert: {} - {}", alert.severity, alert.message);
active_alerts.push(alert.clone());
}
let mut alert_history = self.alert_history.lock().map_err(|e| {
TorshDistributedError::communication_error(
"alert_history",
format!("Lock error: {}", e),
)
})?;
for alert in new_alerts {
alert_history.push_back(alert);
if alert_history.len() > self.config.max_alerts {
alert_history.pop_front();
}
}
}
Ok(())
}
fn update_anomaly_detection(&self, metrics: &NodeMetrics) -> TorshResult<()> {
if !self.config.enable_anomaly_detection {
return Ok(());
}
let mut detector = self.anomaly_detector.lock().map_err(|e| {
TorshDistributedError::communication_error(
"anomaly_detector",
format!("Lock error: {}", e),
)
})?;
        // Each metric keys its own EWMA state, so build the list once: feed
        // every observation to the detector, then re-check each value against
        // the refreshed baselines below.
        let metrics_to_check = [
            (
                "cpu_utilization",
                metrics.system_metrics.cpu_utilization as f64,
            ),
            (
                "gpu_utilization",
                metrics.system_metrics.gpu_utilization as f64,
            ),
            (
                "throughput",
                metrics.training_metrics.throughput_samples_per_sec as f64,
            ),
            (
                "comm_latency",
                metrics.communication_metrics.avg_latency_us as f64,
            ),
            (
                "comm_efficiency",
                metrics.communication_metrics.efficiency_score as f64,
            ),
        ];
        for (metric_name, value) in &metrics_to_check {
            detector.update_metric(metric_name, *value);
        }
for (metric_name, value) in &metrics_to_check {
if detector.is_anomaly(metric_name, *value) {
warn!(
"Anomaly detected: {} = {:.2} on node {}",
metric_name, value, metrics.node_id
);
let timestamp_ms = SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("time should be after UNIX_EPOCH")
.as_millis() as u64;
let alert = Alert {
id: format!("anomaly_{}_{}", metrics.node_id, timestamp_ms),
severity: AlertSeverity::Warning,
message: format!(
"Anomaly detected in {} on node {}",
metric_name, metrics.node_id
),
node_id: metrics.node_id.clone(),
metric_name: metric_name.to_string(),
current_value: *value,
                    // Anomalies are judged against a learned baseline, so
                    // there is no fixed threshold to report.
                    threshold_value: 0.0,
                    timestamp_ms,
is_active: true,
};
let mut active_alerts = self.active_alerts.lock().map_err(|e| {
TorshDistributedError::communication_error(
"anomaly_alerts",
format!("Lock error: {}", e),
)
})?;
active_alerts.push(alert);
}
}
Ok(())
}
pub fn get_current_metrics(&self) -> TorshResult<Option<NodeMetrics>> {
let current = self.current_metrics.read().map_err(|e| {
TorshDistributedError::communication_error(
"get_current_metrics",
format!("Lock error: {}", e),
)
})?;
Ok(current.clone())
}
pub fn get_metrics_history(&self) -> TorshResult<Vec<NodeMetrics>> {
let history = self.metrics_history.lock().map_err(|e| {
TorshDistributedError::communication_error(
"get_metrics_history",
format!("Lock error: {}", e),
)
})?;
Ok(history.iter().cloned().collect())
}
pub fn get_active_alerts(&self) -> TorshResult<Vec<Alert>> {
let alerts = self.active_alerts.lock().map_err(|e| {
TorshDistributedError::communication_error(
"get_active_alerts",
format!("Lock error: {}", e),
)
})?;
Ok(alerts.clone())
}
pub fn get_cluster_summary(&self) -> TorshResult<ClusterSummary> {
if !self.is_coordinator {
return Err(TorshDistributedError::communication_error(
"cluster_summary",
"Only coordinator nodes can access cluster summary".to_string(),
));
}
let all_nodes = self.all_nodes_metrics.read().map_err(|e| {
TorshDistributedError::communication_error(
"cluster_summary",
format!("Lock error: {}", e),
)
})?;
let total_nodes = all_nodes.len();
let healthy_nodes = all_nodes
.values()
.filter(|n| matches!(n.health_status, NodeHealthStatus::Healthy))
.count();
let degraded_nodes = all_nodes
.values()
.filter(|n| matches!(n.health_status, NodeHealthStatus::Degraded { .. }))
.count();
let critical_nodes = all_nodes
.values()
.filter(|n| matches!(n.health_status, NodeHealthStatus::Critical { .. }))
.count();
let failed_nodes = all_nodes
.values()
.filter(|n| matches!(n.health_status, NodeHealthStatus::Failed { .. }))
.count();
let total_cpu_util: f32 = all_nodes
.values()
.map(|n| n.system_metrics.cpu_utilization)
.sum();
let avg_cpu_util = if total_nodes > 0 {
total_cpu_util / total_nodes as f32
} else {
0.0
};
let total_gpu_util: f32 = all_nodes
.values()
.map(|n| n.system_metrics.gpu_utilization)
.sum();
let avg_gpu_util = if total_nodes > 0 {
total_gpu_util / total_nodes as f32
} else {
0.0
};
let total_throughput: f32 = all_nodes
.values()
.map(|n| n.training_metrics.throughput_samples_per_sec)
.sum();
let avg_comm_latency: u64 = if total_nodes > 0 {
all_nodes
.values()
.map(|n| n.communication_metrics.avg_latency_us)
.sum::<u64>()
/ total_nodes as u64
} else {
0
};
Ok(ClusterSummary {
total_nodes,
healthy_nodes,
degraded_nodes,
critical_nodes,
failed_nodes,
avg_cpu_utilization: avg_cpu_util,
avg_gpu_utilization: avg_gpu_util,
total_throughput,
avg_communication_latency_us: avg_comm_latency,
timestamp_ms: SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("time should be after UNIX_EPOCH")
.as_millis() as u64,
})
}
    /// Drops alerts whose `is_active` flag has been cleared and returns how
    /// many were removed. Nothing in this module clears the flag yet, so
    /// callers are expected to mark alerts resolved before invoking this.
    pub fn clear_resolved_alerts(&self) -> TorshResult<usize> {
let mut active_alerts = self.active_alerts.lock().map_err(|e| {
TorshDistributedError::communication_error("clear_alerts", format!("Lock error: {}", e))
})?;
let initial_count = active_alerts.len();
active_alerts.retain(|alert| alert.is_active);
let cleared_count = initial_count - active_alerts.len();
info!("Cleared {} resolved alerts", cleared_count);
Ok(cleared_count)
}
pub fn export_monitoring_data(&self) -> TorshResult<MonitoringExport> {
let current_metrics = self.get_current_metrics()?;
let metrics_history = self.get_metrics_history()?;
let active_alerts = self.get_active_alerts()?;
let cluster_summary = if self.is_coordinator {
Some(self.get_cluster_summary()?)
} else {
None
};
Ok(MonitoringExport {
current_metrics,
metrics_history,
active_alerts,
cluster_summary,
export_timestamp_ms: SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("time should be after UNIX_EPOCH")
.as_millis() as u64,
})
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClusterSummary {
pub total_nodes: usize,
pub healthy_nodes: usize,
pub degraded_nodes: usize,
pub critical_nodes: usize,
pub failed_nodes: usize,
pub avg_cpu_utilization: f32,
pub avg_gpu_utilization: f32,
pub total_throughput: f32,
pub avg_communication_latency_us: u64,
pub timestamp_ms: u64,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MonitoringExport {
pub current_metrics: Option<NodeMetrics>,
pub metrics_history: Vec<NodeMetrics>,
pub active_alerts: Vec<Alert>,
pub cluster_summary: Option<ClusterSummary>,
pub export_timestamp_ms: u64,
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn test_distributed_monitor_creation() -> TorshResult<()> {
let config = MonitoringConfig::default();
let monitor = DistributedMonitor::new(config, false);
let current_metrics = monitor.get_current_metrics()?;
assert!(current_metrics.is_none());
Ok(())
}
#[tokio::test]
async fn test_system_metrics_collection() -> TorshResult<()> {
let config = MonitoringConfig::default();
let monitor = DistributedMonitor::new(config, false);
let metrics = monitor.collect_system_metrics()?;
assert!(metrics.cpu_utilization >= 0.0 && metrics.cpu_utilization <= 100.0);
assert!(metrics.gpu_utilization >= 0.0 && metrics.gpu_utilization <= 100.0);
assert!(metrics.memory_usage_mb > 0);
Ok(())
}
#[tokio::test]
async fn test_node_metrics_update() -> TorshResult<()> {
let config = MonitoringConfig::default();
let monitor = DistributedMonitor::new(config, false);
monitor.update_node_metrics(NodeMetricsUpdate {
node_id: "test_node".to_string(),
rank: 0,
world_size: 4,
training_loss: 0.5,
learning_rate: 0.001,
epoch: 10,
batch: 100,
})?;
let current_metrics = monitor.get_current_metrics()?;
assert!(current_metrics.is_some());
let metrics = current_metrics.unwrap();
assert_eq!(metrics.node_id, "test_node");
assert_eq!(metrics.rank, 0);
assert_eq!(metrics.world_size, 4);
Ok(())
}
#[tokio::test]
async fn test_alert_generation() -> TorshResult<()> {
let mut config = MonitoringConfig::default();
config.alert_thresholds.cpu_warning_pct = 50.0;
let monitor = DistributedMonitor::new(config, false);
monitor.update_node_metrics(NodeMetricsUpdate {
node_id: "test_node".to_string(),
rank: 0,
world_size: 1,
training_loss: 0.5,
learning_rate: 0.001,
epoch: 1,
batch: 1,
})?;
        let alerts = monitor.get_active_alerts()?;
        // Whether the CPU alert fires depends on the simulated reading, so
        // only verify that any generated alerts are well-formed and active.
        for alert in &alerts {
            assert_eq!(alert.node_id, "test_node");
            assert!(alert.is_active);
        }
Ok(())
}
#[tokio::test]
async fn test_anomaly_detection() -> TorshResult<()> {
let mut detector = AnomalyDetector::new(0.7);
for i in 0..50 {
detector.update_metric("test_metric", 50.0 + (i as f64 % 10.0));
}
assert!(!detector.is_anomaly("test_metric", 55.0));
assert!(detector.is_anomaly("test_metric", 200.0));
Ok(())
}
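    // Sketch of a bound check on the rolling history buffer: with a small
    // `history_buffer_size`, repeated updates must not grow the history past
    // the configured capacity (oldest entries are evicted first).
    #[tokio::test]
    async fn test_metrics_history_is_bounded() -> TorshResult<()> {
        let config = MonitoringConfig {
            history_buffer_size: 3,
            ..MonitoringConfig::default()
        };
        let monitor = DistributedMonitor::new(config, false);
        for batch in 0..10 {
            monitor.update_node_metrics(NodeMetricsUpdate {
                node_id: "test_node".to_string(),
                rank: 0,
                world_size: 1,
                training_loss: 0.5,
                learning_rate: 0.001,
                epoch: 1,
                batch,
            })?;
        }
        let history = monitor.get_metrics_history()?;
        assert_eq!(history.len(), 3);
        // The retained entries should be the most recent batches.
        assert_eq!(history.last().unwrap().training_metrics.batch, 9);
        Ok(())
    }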
#[tokio::test]
async fn test_monitoring_export() -> TorshResult<()> {
let config = MonitoringConfig::default();
let monitor = DistributedMonitor::new(config, false);
monitor.update_node_metrics(NodeMetricsUpdate {
node_id: "test_node".to_string(),
rank: 0,
world_size: 1,
training_loss: 0.5,
learning_rate: 0.001,
epoch: 1,
batch: 1,
})?;
let export = monitor.export_monitoring_data()?;
assert!(export.current_metrics.is_some());
assert!(!export.metrics_history.is_empty());
Ok(())
}
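    // Sketch of the coordinator-only contract for cluster summaries:
    // non-coordinators are rejected, while a coordinator aggregates whatever
    // nodes have reported in.
    #[tokio::test]
    async fn test_cluster_summary_requires_coordinator() -> TorshResult<()> {
        let worker = DistributedMonitor::new(MonitoringConfig::default(), false);
        assert!(worker.get_cluster_summary().is_err());
        let coordinator = DistributedMonitor::new(MonitoringConfig::default(), true);
        coordinator.update_node_metrics(NodeMetricsUpdate {
            node_id: "node_a".to_string(),
            rank: 0,
            world_size: 2,
            training_loss: 0.5,
            learning_rate: 0.001,
            epoch: 1,
            batch: 1,
        })?;
        let summary = coordinator.get_cluster_summary()?;
        assert_eq!(summary.total_nodes, 1);
        Ok(())
    }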
}