use anyhow::Result;
use chrono::{DateTime, Duration, Utc};
use std::collections::VecDeque;
use super::types::{
ConvergenceStatus, LayerActivationStats, ModelDiagnosticAlert, ModelPerformanceMetrics,
TrainingDynamics, TrainingStability,
};
/// Central registry for model-diagnostic alerts: tracks currently active
/// alerts, keeps a bounded history of resolved ones, and evaluates incoming
/// metrics against configurable thresholds and an optional baseline.
#[derive(Debug)]
pub struct AlertManager {
    // Alerts that have fired and are not yet resolved.
    active_alerts: Vec<ActiveAlert>,
    // Resolved alerts, capped at `config.max_history_size` (oldest dropped first).
    alert_history: VecDeque<HistoricalAlert>,
    // Behavioral settings: cooldowns, notification channels, history size.
    config: AlertConfig,
    // Numeric trigger levels for the various alert checks.
    thresholds: AlertThresholds,
    // Reference metrics for degradation comparisons; `None` until established.
    performance_baseline: Option<PerformanceBaseline>,
}
/// Tunable behavior for [`AlertManager`].
#[derive(Debug, Clone)]
pub struct AlertConfig {
    /// Maximum number of historical alerts retained; oldest entries are dropped first.
    pub max_history_size: usize,
    /// Window during which a same-variant alert is treated as a duplicate.
    pub duplicate_alert_cooldown: Duration,
    /// Severities the manager should care about.
    /// NOTE(review): not consulted anywhere in this module — confirm intended use.
    pub monitored_severities: Vec<AlertSeverity>,
    /// Whether alerts may be resolved automatically.
    /// NOTE(review): not read in this module — presumably used by callers.
    pub auto_resolve_alerts: bool,
    /// Where and how notifications are emitted when an alert fires.
    pub notification_settings: NotificationSettings,
}
/// Numeric trigger levels used by the alert checks in [`AlertManager`].
#[derive(Debug, Clone)]
pub struct AlertThresholds {
    /// Relative degradation (percent) of loss/throughput that raises an alert.
    pub performance_degradation_percent: f64,
    /// Absolute memory usage (MB) above which a memory alert fires.
    pub memory_usage_threshold_mb: f64,
    /// Memory growth rate (MB per step) considered a leak.
    /// NOTE(review): not read in this module — presumably consumed elsewhere.
    pub memory_leak_threshold_mb_per_step: f64,
    /// Variance level reported with training-instability alerts.
    pub training_instability_variance: f64,
    /// Fraction of dead neurons in a layer that triggers an architectural concern.
    pub dead_neuron_ratio_threshold: f64,
    /// Fraction of saturated neurons in a layer that triggers an architectural concern.
    pub saturated_neuron_ratio_threshold: f64,
    /// Plateau length (in steps) before a convergence alert fires.
    pub plateau_duration_threshold: usize,
    /// Threshold for learning-rate adjustment logic.
    /// NOTE(review): not read in this module — presumably consumed elsewhere.
    pub learning_rate_adjustment_threshold: f64,
}
/// Reference metrics captured at a known point in training; later metrics are
/// compared against these values to detect performance degradation.
#[derive(Debug, Clone)]
pub struct PerformanceBaseline {
    /// Loss value at the time the baseline was established.
    pub baseline_loss: f64,
    /// Throughput (samples/sec) at baseline time.
    pub baseline_throughput: f64,
    /// Memory usage (MB) at baseline time.
    pub baseline_memory_mb: f64,
    /// Accuracy at baseline time, if it was being measured.
    pub baseline_accuracy: Option<f64>,
    /// When this baseline was captured.
    pub established_at: DateTime<Utc>,
}
/// An alert that has fired and has not yet been resolved.
#[derive(Debug, Clone)]
pub struct ActiveAlert {
    /// The underlying diagnostic condition.
    pub alert: ModelDiagnosticAlert,
    /// How serious the condition is judged to be.
    pub severity: AlertSeverity,
    /// When the alert first fired.
    pub triggered_at: DateTime<Utc>,
    /// Number of times this condition has triggered.
    pub trigger_count: usize,
    /// Suggested remediation steps for the operator.
    pub recommended_actions: Vec<String>,
    /// Current lifecycle state of the alert.
    pub status: AlertStatus,
}
/// A past alert retained in the bounded history after resolution.
#[derive(Debug, Clone)]
pub struct HistoricalAlert {
    /// The underlying diagnostic condition.
    pub alert: ModelDiagnosticAlert,
    /// Severity the alert carried while active.
    pub severity: AlertSeverity,
    /// When the alert first fired.
    pub triggered_at: DateTime<Utc>,
    /// When it was resolved, if known.
    pub resolved_at: Option<DateTime<Utc>>,
    /// Free-form description of how it was resolved.
    pub resolution_method: Option<String>,
    /// Time from trigger to resolution, when both timestamps are known.
    pub duration: Option<Duration>,
}
/// Alert severity, ordered from least to most urgent.
///
/// The derived `Ord` follows variant declaration order, so
/// `Info < Warning < Critical < Emergency`.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub enum AlertSeverity {
    Info,
    Warning,
    Critical,
    Emergency,
}
/// Lifecycle state of an active alert.
#[derive(Debug, Clone, PartialEq)]
pub enum AlertStatus {
    /// Newly fired, not yet looked at.
    Active,
    /// Seen by an operator but not yet acted on.
    Acknowledged,
    /// Someone is actively investigating the condition.
    InvestigationInProgress,
    /// Condition cleared; eligible to be moved to history.
    Resolved,
    /// Determined not to be a real problem.
    FalsePositive,
}
/// Controls where alert notifications are delivered.
#[derive(Debug, Clone)]
pub struct NotificationSettings {
    /// Print alerts to stdout when they fire.
    pub console_notifications: bool,
    /// Append alerts to a log file (path in `log_file_path`).
    /// NOTE(review): file logging is not actually implemented in this module yet.
    pub file_logging: bool,
    /// Destination log file; required when `file_logging` is true.
    pub log_file_path: Option<String>,
    /// Post alerts to a webhook (URL in `webhook_url`).
    /// NOTE(review): webhook delivery is not actually implemented in this module yet.
    pub webhook_notifications: bool,
    /// Destination webhook; required when `webhook_notifications` is true.
    pub webhook_url: Option<String>,
}
impl Default for AlertConfig {
fn default() -> Self {
Self {
max_history_size: 1000,
duplicate_alert_cooldown: Duration::minutes(5),
monitored_severities: vec![
AlertSeverity::Warning,
AlertSeverity::Critical,
AlertSeverity::Emergency,
],
auto_resolve_alerts: true,
notification_settings: NotificationSettings::default(),
}
}
}
impl Default for NotificationSettings {
fn default() -> Self {
Self {
console_notifications: true,
file_logging: false,
log_file_path: None,
webhook_notifications: false,
webhook_url: None,
}
}
}
impl Default for AlertThresholds {
fn default() -> Self {
Self {
performance_degradation_percent: 10.0,
memory_usage_threshold_mb: 8192.0, memory_leak_threshold_mb_per_step: 1.0,
training_instability_variance: 0.1,
dead_neuron_ratio_threshold: 0.1,
saturated_neuron_ratio_threshold: 0.05,
plateau_duration_threshold: 100,
learning_rate_adjustment_threshold: 0.01,
}
}
}
impl AlertManager {
    /// Creates a manager with default configuration/thresholds and no baseline.
    pub fn new() -> Self {
        Self {
            active_alerts: Vec::new(),
            alert_history: VecDeque::new(),
            config: AlertConfig::default(),
            thresholds: AlertThresholds::default(),
            performance_baseline: None,
        }
    }

    /// Creates a manager with caller-supplied configuration and thresholds.
    pub fn with_config(config: AlertConfig, thresholds: AlertThresholds) -> Self {
        Self {
            active_alerts: Vec::new(),
            alert_history: VecDeque::new(),
            config,
            thresholds,
            performance_baseline: None,
        }
    }

    /// Replaces the performance baseline used for degradation comparisons.
    pub fn set_performance_baseline(&mut self, baseline: PerformanceBaseline) {
        self.performance_baseline = Some(baseline);
    }

    /// Snapshots `metrics` as the new baseline, timestamped with the current time.
    pub fn establish_baseline_from_metrics(&mut self, metrics: &ModelPerformanceMetrics) {
        self.performance_baseline = Some(PerformanceBaseline {
            baseline_loss: metrics.loss,
            baseline_throughput: metrics.throughput_samples_per_sec,
            baseline_memory_mb: metrics.memory_usage_mb,
            baseline_accuracy: metrics.accuracy,
            established_at: Utc::now(),
        });
    }

    /// Checks `metrics` for loss/throughput degradation relative to the
    /// baseline (when one is set) and for excessive absolute memory usage.
    ///
    /// Triggered alerts are registered via [`Self::add_alert`] and also
    /// returned so the caller can react immediately.
    pub fn process_performance_metrics(
        &mut self,
        metrics: &ModelPerformanceMetrics,
    ) -> Result<Vec<ModelDiagnosticAlert>> {
        let mut new_alerts = Vec::new();
        if let Some(baseline) = &self.performance_baseline {
            // Skip the relative comparison when the baseline value is zero or
            // non-finite: dividing by it would yield inf/NaN percentages and
            // either spurious or silently missed alerts.
            if baseline.baseline_loss.is_finite() && baseline.baseline_loss.abs() > f64::EPSILON {
                let loss_degradation =
                    ((metrics.loss - baseline.baseline_loss) / baseline.baseline_loss) * 100.0;
                if loss_degradation > self.thresholds.performance_degradation_percent {
                    new_alerts.push(ModelDiagnosticAlert::PerformanceDegradation {
                        metric: "loss".to_string(),
                        current: metrics.loss,
                        previous_avg: baseline.baseline_loss,
                        degradation_percent: loss_degradation,
                    });
                }
            }
            if baseline.baseline_throughput.is_finite()
                && baseline.baseline_throughput.abs() > f64::EPSILON
            {
                // For throughput lower is worse, hence `baseline - current`.
                let throughput_degradation = ((baseline.baseline_throughput
                    - metrics.throughput_samples_per_sec)
                    / baseline.baseline_throughput)
                    * 100.0;
                if throughput_degradation > self.thresholds.performance_degradation_percent {
                    new_alerts.push(ModelDiagnosticAlert::PerformanceDegradation {
                        metric: "throughput".to_string(),
                        current: metrics.throughput_samples_per_sec,
                        previous_avg: baseline.baseline_throughput,
                        degradation_percent: throughput_degradation,
                    });
                }
            }
        }
        if metrics.memory_usage_mb > self.thresholds.memory_usage_threshold_mb {
            new_alerts.push(ModelDiagnosticAlert::MemoryLeak {
                current_usage_mb: metrics.memory_usage_mb,
                // Growth-rate tracking is not implemented here; this check only
                // reflects the absolute usage threshold.
                growth_rate_mb_per_step: 0.0,
            });
        }
        for alert in &new_alerts {
            let severity = self.determine_alert_severity(alert);
            self.add_alert(alert.clone(), severity)?;
        }
        Ok(new_alerts)
    }

    /// Inspects training dynamics for instability and convergence problems.
    ///
    /// Triggered alerts are registered via [`Self::add_alert`] and returned.
    pub fn process_training_dynamics(
        &mut self,
        dynamics: &TrainingDynamics,
    ) -> Result<Vec<ModelDiagnosticAlert>> {
        let mut new_alerts = Vec::new();
        if matches!(
            dynamics.training_stability,
            TrainingStability::Unstable | TrainingStability::HighVariance
        ) {
            new_alerts.push(ModelDiagnosticAlert::TrainingInstability {
                // The measured variance is not available from `TrainingDynamics`
                // here; only the configured threshold is reported alongside it.
                variance: 0.0,
                threshold: self.thresholds.training_instability_variance,
            });
        }
        match dynamics.convergence_status {
            ConvergenceStatus::Diverging => {
                new_alerts.push(ModelDiagnosticAlert::ConvergenceIssue {
                    issue_type: ConvergenceStatus::Diverging,
                    // Divergence is alerted immediately, without duration tracking.
                    duration_steps: 0,
                });
            },
            ConvergenceStatus::Plateau => {
                // Only alert on plateaus that persist past the configured tolerance.
                if let Some(plateau_info) = &dynamics.plateau_detection {
                    if plateau_info.duration_steps > self.thresholds.plateau_duration_threshold {
                        new_alerts.push(ModelDiagnosticAlert::ConvergenceIssue {
                            issue_type: ConvergenceStatus::Plateau,
                            duration_steps: plateau_info.duration_steps,
                        });
                    }
                }
            },
            _ => {},
        }
        for alert in &new_alerts {
            let severity = self.determine_alert_severity(alert);
            self.add_alert(alert.clone(), severity)?;
        }
        Ok(new_alerts)
    }

    /// Flags layers with too many dead or saturated neurons as architectural
    /// concerns. Triggered alerts are registered and returned.
    pub fn process_layer_stats(
        &mut self,
        stats: &LayerActivationStats,
    ) -> Result<Vec<ModelDiagnosticAlert>> {
        let mut new_alerts = Vec::new();
        if stats.dead_neurons_ratio > self.thresholds.dead_neuron_ratio_threshold {
            new_alerts.push(ModelDiagnosticAlert::ArchitecturalConcern {
                concern: format!(
                    "High dead neuron ratio in layer {}: {:.2}%",
                    stats.layer_name,
                    stats.dead_neurons_ratio * 100.0
                ),
                recommendation: "Consider adjusting learning rate or initialization".to_string(),
            });
        }
        if stats.saturated_neurons_ratio > self.thresholds.saturated_neuron_ratio_threshold {
            new_alerts.push(ModelDiagnosticAlert::ArchitecturalConcern {
                concern: format!(
                    "High saturated neuron ratio in layer {}: {:.2}%",
                    stats.layer_name,
                    stats.saturated_neurons_ratio * 100.0
                ),
                recommendation: "Consider adjusting activation function or scaling".to_string(),
            });
        }
        for alert in &new_alerts {
            let severity = self.determine_alert_severity(alert);
            self.add_alert(alert.clone(), severity)?;
        }
        Ok(new_alerts)
    }

    /// Registers an alert and sends notifications.
    ///
    /// If a same-variant alert is already active within the duplicate
    /// cooldown window, the existing entry's `trigger_count` is incremented
    /// instead of adding a new entry (previously duplicates were dropped
    /// silently and `trigger_count` could never move past 1).
    pub fn add_alert(
        &mut self,
        alert: ModelDiagnosticAlert,
        severity: AlertSeverity,
    ) -> Result<()> {
        if self.is_duplicate_alert(&alert) {
            let cooldown_threshold = Utc::now() - self.config.duplicate_alert_cooldown;
            if let Some(existing) = self.active_alerts.iter_mut().find(|active| {
                active.triggered_at > cooldown_threshold
                    && std::mem::discriminant(&active.alert) == std::mem::discriminant(&alert)
            }) {
                existing.trigger_count += 1;
            }
            return Ok(());
        }
        let recommended_actions = self.generate_recommended_actions(&alert);
        self.active_alerts.push(ActiveAlert {
            alert: alert.clone(),
            severity: severity.clone(),
            triggered_at: Utc::now(),
            trigger_count: 1,
            recommended_actions,
            status: AlertStatus::Active,
        });
        self.send_notification(&alert, &severity)?;
        Ok(())
    }

    /// Resolves the active alert at `alert_index`, moving it into history.
    ///
    /// # Errors
    /// Returns an error when `alert_index` is out of bounds.
    pub fn resolve_alert(&mut self, alert_index: usize, resolution_method: String) -> Result<()> {
        if alert_index >= self.active_alerts.len() {
            return Err(anyhow::anyhow!("Invalid alert index"));
        }
        let mut active_alert = self.active_alerts.remove(alert_index);
        active_alert.status = AlertStatus::Resolved;
        // Capture the timestamp once so `resolved_at` and `duration` agree.
        let now = Utc::now();
        self.add_to_history(HistoricalAlert {
            alert: active_alert.alert,
            severity: active_alert.severity,
            triggered_at: active_alert.triggered_at,
            resolved_at: Some(now),
            resolution_method: Some(resolution_method),
            duration: Some(now - active_alert.triggered_at),
        });
        Ok(())
    }

    /// Returns all currently active alerts.
    pub fn get_active_alerts(&self) -> &[ActiveAlert] {
        &self.active_alerts
    }

    /// Returns the active alerts whose severity equals `severity`.
    pub fn get_alerts_by_severity(&self, severity: AlertSeverity) -> Vec<&ActiveAlert> {
        self.active_alerts.iter().filter(|alert| alert.severity == severity).collect()
    }

    /// Computes per-severity counts plus active/historical totals.
    pub fn get_alert_statistics(&self) -> AlertStatistics {
        let mut stats = AlertStatistics::default();
        for alert in &self.active_alerts {
            match alert.severity {
                AlertSeverity::Info => stats.info_count += 1,
                AlertSeverity::Warning => stats.warning_count += 1,
                AlertSeverity::Critical => stats.critical_count += 1,
                AlertSeverity::Emergency => stats.emergency_count += 1,
            }
        }
        stats.total_active = self.active_alerts.len();
        stats.total_historical = self.alert_history.len();
        stats
    }

    /// Moves every active alert marked `Resolved` into the history, stamped
    /// as auto-resolved at the current time.
    pub fn clear_resolved_alerts(&mut self) {
        let now = Utc::now();
        let mut resolved_alerts = Vec::new();
        self.active_alerts.retain(|alert| {
            if matches!(alert.status, AlertStatus::Resolved) {
                resolved_alerts.push(HistoricalAlert {
                    alert: alert.alert.clone(),
                    severity: alert.severity.clone(),
                    triggered_at: alert.triggered_at,
                    resolved_at: Some(now),
                    resolution_method: Some("Auto-resolved".to_string()),
                    duration: Some(now - alert.triggered_at),
                });
                false
            } else {
                true
            }
        });
        for historical in resolved_alerts {
            self.add_to_history(historical);
        }
    }

    /// Maps an alert to a severity based on its variant and magnitude.
    fn determine_alert_severity(&self, alert: &ModelDiagnosticAlert) -> AlertSeverity {
        match alert {
            ModelDiagnosticAlert::PerformanceDegradation {
                degradation_percent,
                ..
            } => {
                if *degradation_percent > 50.0 {
                    AlertSeverity::Critical
                } else if *degradation_percent > 25.0 {
                    AlertSeverity::Warning
                } else {
                    AlertSeverity::Info
                }
            },
            ModelDiagnosticAlert::MemoryLeak {
                current_usage_mb, ..
            } => {
                // Escalate by absolute footprint: >16 GiB emergency, >8 GiB critical.
                if *current_usage_mb > 16384.0 {
                    AlertSeverity::Emergency
                } else if *current_usage_mb > 8192.0 {
                    AlertSeverity::Critical
                } else {
                    AlertSeverity::Warning
                }
            },
            ModelDiagnosticAlert::TrainingInstability { .. } => AlertSeverity::Warning,
            ModelDiagnosticAlert::ConvergenceIssue { issue_type, .. } => match issue_type {
                ConvergenceStatus::Diverging => AlertSeverity::Critical,
                ConvergenceStatus::Plateau => AlertSeverity::Warning,
                _ => AlertSeverity::Info,
            },
            ModelDiagnosticAlert::ArchitecturalConcern { .. } => AlertSeverity::Info,
        }
    }

    /// Returns true when a same-variant alert is already active within the
    /// duplicate cooldown window. Comparison is by enum variant only, not by
    /// the variant's payload.
    fn is_duplicate_alert(&self, alert: &ModelDiagnosticAlert) -> bool {
        let now = Utc::now();
        let cooldown_threshold = now - self.config.duplicate_alert_cooldown;
        self.active_alerts.iter().any(|active| {
            active.triggered_at > cooldown_threshold
                && std::mem::discriminant(&active.alert) == std::mem::discriminant(alert)
        })
    }

    /// Produces operator-facing remediation suggestions for an alert.
    fn generate_recommended_actions(&self, alert: &ModelDiagnosticAlert) -> Vec<String> {
        match alert {
            ModelDiagnosticAlert::PerformanceDegradation { metric, .. } => {
                vec![
                    format!("Investigate {} degradation causes", metric),
                    "Check for data quality issues".to_string(),
                    "Review recent configuration changes".to_string(),
                    "Consider adjusting learning rate".to_string(),
                ]
            },
            ModelDiagnosticAlert::MemoryLeak { .. } => {
                vec![
                    "Monitor memory usage patterns".to_string(),
                    "Check for gradient accumulation issues".to_string(),
                    "Review batch size configuration".to_string(),
                    "Consider implementing memory cleanup".to_string(),
                ]
            },
            ModelDiagnosticAlert::TrainingInstability { .. } => {
                vec![
                    "Reduce learning rate".to_string(),
                    "Enable gradient clipping".to_string(),
                    "Check data preprocessing".to_string(),
                    "Consider using learning rate scheduling".to_string(),
                ]
            },
            ModelDiagnosticAlert::ConvergenceIssue { issue_type, .. } => match issue_type {
                ConvergenceStatus::Diverging => vec![
                    "Immediately reduce learning rate".to_string(),
                    "Check gradient magnitudes".to_string(),
                    "Review loss function implementation".to_string(),
                ],
                ConvergenceStatus::Plateau => vec![
                    "Consider learning rate annealing".to_string(),
                    "Try different optimization algorithm".to_string(),
                    "Evaluate model capacity".to_string(),
                ],
                _ => vec!["Monitor training progress".to_string()],
            },
            ModelDiagnosticAlert::ArchitecturalConcern { recommendation, .. } => {
                vec![recommendation.clone()]
            },
        }
    }

    /// Emits notifications for a fired alert on the configured channels.
    /// Only console output is implemented; file logging and webhooks are
    /// placeholders awaiting implementation.
    fn send_notification(
        &self,
        alert: &ModelDiagnosticAlert,
        severity: &AlertSeverity,
    ) -> Result<()> {
        if self.config.notification_settings.console_notifications {
            println!("[{:?}] Alert: {:?}", severity, alert);
        }
        if self.config.notification_settings.file_logging {
            if let Some(log_path) = &self.config.notification_settings.log_file_path {
                // TODO: implement file logging.
                let _ = log_path;
            }
        }
        if self.config.notification_settings.webhook_notifications {
            if let Some(webhook_url) = &self.config.notification_settings.webhook_url {
                // TODO: implement webhook delivery.
                let _ = webhook_url;
            }
        }
        Ok(())
    }

    /// Appends to the history, evicting the oldest entries so the length
    /// never exceeds `config.max_history_size`.
    fn add_to_history(&mut self, historical_alert: HistoricalAlert) {
        self.alert_history.push_back(historical_alert);
        while self.alert_history.len() > self.config.max_history_size {
            self.alert_history.pop_front();
        }
    }
}
/// Aggregate counts over active and historical alerts.
#[derive(Debug, Default)]
pub struct AlertStatistics {
    /// Number of active Info-severity alerts.
    pub info_count: usize,
    /// Number of active Warning-severity alerts.
    pub warning_count: usize,
    /// Number of active Critical-severity alerts.
    pub critical_count: usize,
    /// Number of active Emergency-severity alerts.
    pub emergency_count: usize,
    /// Total active alerts across all severities.
    pub total_active: usize,
    /// Total alerts retained in history.
    pub total_historical: usize,
}
impl Default for AlertManager {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    // A fresh manager starts with no active or historical alerts.
    #[test]
    fn test_alert_manager_creation() {
        let manager = AlertManager::new();
        assert_eq!(manager.active_alerts.len(), 0);
        assert_eq!(manager.alert_history.len(), 0);
    }

    // Adding an alert stores it in the active list.
    #[test]
    fn test_add_alert() {
        let mut manager = AlertManager::new();
        let alert = ModelDiagnosticAlert::PerformanceDegradation {
            metric: "loss".to_string(),
            current: 1.5,
            previous_avg: 1.0,
            degradation_percent: 50.0,
        };
        manager.add_alert(alert, AlertSeverity::Warning).expect("add operation failed");
        assert_eq!(manager.active_alerts.len(), 1);
    }

    // Degradation above 50% must be classified as Critical.
    #[test]
    fn test_alert_severity_determination() {
        let manager = AlertManager::new();
        let high_degradation = ModelDiagnosticAlert::PerformanceDegradation {
            metric: "loss".to_string(),
            current: 2.0,
            previous_avg: 1.0,
            degradation_percent: 60.0,
        };
        let severity = manager.determine_alert_severity(&high_degradation);
        assert_eq!(severity, AlertSeverity::Critical);
    }

    // A same-variant alert fired again within the cooldown window must not
    // create a second active entry.
    #[test]
    fn test_duplicate_alert_detection() {
        let mut manager = AlertManager::new();
        let alert = ModelDiagnosticAlert::TrainingInstability {
            variance: 0.2,
            threshold: 0.1,
        };
        manager
            .add_alert(alert.clone(), AlertSeverity::Warning)
            .expect("add operation failed");
        assert_eq!(manager.active_alerts.len(), 1);
        manager.add_alert(alert, AlertSeverity::Warning).expect("add operation failed");
        assert_eq!(manager.active_alerts.len(), 1);
    }
}