use super::types::*;
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::collections::VecDeque;
/// Per-layer vanishing/exploding gradient thresholds that adapt to the observed
/// gradient-norm distribution via an exponential moving average.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AdaptiveThresholds {
    /// Layer these thresholds belong to; copied into every alert produced.
    pub layer_name: String,
    /// Norms strictly below this value raise a `VanishingGradients` alert.
    pub vanishing_threshold: f64,
    /// Norms strictly above this value raise an `ExplodingGradients` alert.
    pub exploding_threshold: f64,
    /// EMA weight given to newly computed targets (`new` sets it to 0.1).
    pub adaptation_rate: f64,
    /// Rolling window of recent gradient norms, oldest at the front; kept at
    /// 100 entries or fewer by `update_thresholds`.
    pub recent_gradients: VecDeque<f64>,
    /// Time of construction or of the most recent threshold adaptation.
    pub last_updated: DateTime<Utc>,
}
impl AdaptiveThresholds {
    /// Maximum number of recent gradient norms kept in `recent_gradients`.
    const WINDOW_CAPACITY: usize = 100;
    /// Number of samples the window must hold before thresholds start adapting.
    const MIN_SAMPLES_FOR_ADAPTATION: usize = 10;
    /// Number of most recent history samples used to seed the window in `from_history`.
    const SEED_SAMPLES: usize = 50;
    /// Floor for the vanishing threshold so it never reaches zero or goes negative.
    const MIN_VANISHING_THRESHOLD: f64 = 1e-8;

    /// Creates thresholds for `layer_name` with the given initial bounds, an
    /// adaptation rate of `0.1` and an empty gradient window.
    pub fn new(layer_name: String, initial_vanishing: f64, initial_exploding: f64) -> Self {
        Self {
            layer_name,
            vanishing_threshold: initial_vanishing,
            exploding_threshold: initial_exploding,
            adaptation_rate: 0.1,
            recent_gradients: VecDeque::with_capacity(Self::WINDOW_CAPACITY),
            last_updated: Utc::now(),
        }
    }

    /// Records `gradient_norm` in the rolling window and, once the window holds at
    /// least `MIN_SAMPLES_FOR_ADAPTATION` samples, moves both thresholds toward the
    /// window's `mean - 2σ` (vanishing, floored at 1e-8) and `mean + 3σ` (exploding)
    /// using an exponential moving average weighted by `adaptation_rate`.
    ///
    /// Non-finite norms (NaN, ±∞) are not recorded and do not adapt the thresholds;
    /// they can still be evaluated with [`Self::check_thresholds`].
    pub fn update_thresholds(&mut self, gradient_norm: f64) {
        // One NaN/∞ in the window makes mean/σ non-finite; the EMA below would then
        // carry NaN into `exploding_threshold` forever, and because `x > NaN` is
        // always false, exploding gradients could never be reported again.
        if !gradient_norm.is_finite() {
            return;
        }
        // `while` rather than `if` so an oversized (e.g. deserialized) window is trimmed.
        while self.recent_gradients.len() >= Self::WINDOW_CAPACITY {
            self.recent_gradients.pop_front();
        }
        self.recent_gradients.push_back(gradient_norm);

        if self.recent_gradients.len() < Self::MIN_SAMPLES_FOR_ADAPTATION {
            return;
        }
        let (mean, std_dev) = Self::mean_and_std_dev(self.recent_gradients.iter());
        let (target_vanishing, target_exploding) = Self::thresholds_from_stats(mean, std_dev);
        let rate = self.adaptation_rate;
        self.vanishing_threshold =
            self.vanishing_threshold * (1.0 - rate) + target_vanishing * rate;
        self.exploding_threshold =
            self.exploding_threshold * (1.0 - rate) + target_exploding * rate;
        self.last_updated = Utc::now();
    }

    /// Returns the alerts triggered by `gradient_norm` under the current thresholds:
    /// a `VanishingGradients` alert if it is strictly below `vanishing_threshold` and
    /// an `ExplodingGradients` alert if it is strictly above `exploding_threshold`.
    ///
    /// A NaN norm compares false against both bounds and therefore yields no alert.
    pub fn check_thresholds(&self, gradient_norm: f64) -> Vec<GradientAlert> {
        let mut alerts = Vec::new();
        if gradient_norm < self.vanishing_threshold {
            alerts.push(GradientAlert::VanishingGradients {
                layer_name: self.layer_name.clone(),
                norm: gradient_norm,
                threshold: self.vanishing_threshold,
            });
        }
        if gradient_norm > self.exploding_threshold {
            alerts.push(GradientAlert::ExplodingGradients {
                layer_name: self.layer_name.clone(),
                norm: gradient_norm,
                threshold: self.exploding_threshold,
            });
        }
        alerts
    }

    /// Builds thresholds from a layer's recorded history.
    ///
    /// Initial bounds are `mean - 2σ` (floored at 1e-8) and `mean + 3σ` over all
    /// finite recorded norms. The window is seeded with the last `SEED_SAMPLES` of
    /// them, in their original order. This assumes `gradient_norms` is
    /// chronological (oldest first); TODO confirm against `GradientHistory`.
    /// With no finite norms, defaults of `1e-6` / `10.0` are used.
    pub fn from_history(history: &GradientHistory) -> Self {
        let layer_name = history.layer_name.clone();
        // Non-finite entries would turn the initial thresholds into NaN.
        let norms: Vec<f64> = history
            .gradient_norms
            .iter()
            .cloned()
            .filter(|norm| norm.is_finite())
            .collect();
        if norms.is_empty() {
            return Self::new(layer_name, 1e-6, 10.0);
        }
        let (mean, std_dev) = Self::mean_and_std_dev(norms.iter());
        let (initial_vanishing, initial_exploding) = Self::thresholds_from_stats(mean, std_dev);
        let mut thresholds = Self::new(layer_name, initial_vanishing, initial_exploding);
        // Seed oldest-to-newest: `update_thresholds` evicts from the front, so the
        // front must hold the oldest sample.
        let start = norms.len().saturating_sub(Self::SEED_SAMPLES);
        thresholds.recent_gradients.extend(norms[start..].iter().copied());
        thresholds
    }

    /// Population mean and standard deviation of a non-empty sequence of values.
    fn mean_and_std_dev<'a, I>(values: I) -> (f64, f64)
    where
        I: Iterator<Item = &'a f64> + Clone,
    {
        let count = values.clone().count() as f64;
        let mean = values.clone().sum::<f64>() / count;
        let variance = values.map(|&x| (x - mean).powi(2)).sum::<f64>() / count;
        (mean, variance.sqrt())
    }

    /// Target `(vanishing, exploding)` thresholds for the given mean and σ.
    fn thresholds_from_stats(mean: f64, std_dev: f64) -> (f64, f64) {
        (
            (mean - 2.0 * std_dev).max(Self::MIN_VANISHING_THRESHOLD),
            mean + 3.0 * std_dev,
        )
    }
}
/// Tracks a single layer's gradient norm over successive updates, together with its
/// first/second differences, a short rolling window and an anomaly score.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RealTimeGradientMonitor {
    /// Layer being monitored.
    pub layer_name: String,
    /// Norm passed to the most recent `update` call (0.0 before any update).
    pub current_gradient_norm: f64,
    /// First difference of the norm across successive `update` calls.
    pub gradient_velocity: f64,
    /// Second difference (change in velocity) across successive `update` calls.
    pub gradient_acceleration: f64,
    /// Up to 10 most recent norms, oldest at the front.
    pub stability_window: VecDeque<f64>,
    /// Score in [0, 1] computed by `update`; higher means the latest norm is more
    /// unusual relative to the window.
    pub anomaly_score: f64,
}
impl RealTimeGradientMonitor {
    /// Number of recent norms kept in `stability_window`.
    const WINDOW_CAPACITY: usize = 10;
    /// Window size (including the newest sample) required before scoring anomalies.
    const MIN_ANOMALY_SAMPLES: usize = 5;
    /// |z| at which the anomaly score saturates at 1.0.
    const Z_SCORE_CAP: f64 = 5.0;

    /// Creates a monitor for `layer_name` with an empty window and all metrics zeroed.
    pub fn new(layer_name: String) -> Self {
        Self {
            layer_name,
            current_gradient_norm: 0.0,
            gradient_velocity: 0.0,
            gradient_acceleration: 0.0,
            stability_window: VecDeque::with_capacity(Self::WINDOW_CAPACITY),
            anomaly_score: 0.0,
        }
    }

    /// Records a new gradient norm and refreshes velocity, acceleration, the rolling
    /// window and the anomaly score.
    ///
    /// Velocity needs one earlier sample and acceleration needs two. Until those
    /// exist, both are reported as `0.0` instead of being measured against the `0.0`
    /// placeholder that `new` stores in `current_gradient_norm`.
    pub fn update(&mut self, new_gradient_norm: f64) {
        let prior_samples = self.stability_window.len();
        let previous_norm = self.current_gradient_norm;
        let previous_velocity = self.gradient_velocity;

        self.current_gradient_norm = new_gradient_norm;
        self.gradient_velocity = if prior_samples >= 1 {
            new_gradient_norm - previous_norm
        } else {
            0.0
        };
        self.gradient_acceleration = if prior_samples >= 2 {
            self.gradient_velocity - previous_velocity
        } else {
            0.0
        };

        while self.stability_window.len() >= Self::WINDOW_CAPACITY {
            self.stability_window.pop_front();
        }
        self.stability_window.push_back(new_gradient_norm);
        self.anomaly_score = self.compute_anomaly_score();
    }

    /// Scores how unusual the current norm is relative to the samples before it.
    ///
    /// The score is `min(|z|, 5) / 5`, where `z` is the current norm's z-score
    /// against the population mean/σ of the *preceding* window samples. The newest
    /// sample is deliberately excluded from the baseline. Including it bounds |z|
    /// by √(n−1) (Samuelson's inequality): 3 for a 10-sample window, capping the
    /// score at 0.6, so a 0.7 anomaly threshold could never be reached.
    ///
    /// Returns 0.0 until the window holds `MIN_ANOMALY_SAMPLES` samples. With a
    /// perfectly flat baseline (σ = 0), any deviation scores 1.0 and none scores 0.0.
    /// Non-finite values yield 1.0.
    fn compute_anomaly_score(&self) -> f64 {
        let len = self.stability_window.len();
        if len < Self::MIN_ANOMALY_SAMPLES {
            return 0.0;
        }
        let baseline = self.stability_window.iter().take(len - 1);
        let count = (len - 1) as f64;
        let mean = baseline.clone().sum::<f64>() / count;
        let variance = baseline.map(|&x| (x - mean).powi(2)).sum::<f64>() / count;
        let std_dev = variance.sqrt();
        let deviation = (self.current_gradient_norm - mean).abs();
        if std_dev == 0.0 {
            return if deviation == 0.0 { 0.0 } else { 1.0 };
        }
        // `f64::min` returns the non-NaN operand, so a NaN z-score saturates to 1.0.
        (deviation / std_dev).min(Self::Z_SCORE_CAP) / Self::Z_SCORE_CAP
    }

    /// Returns `1 / (1 + d)`, where `d` is the mean squared deviation of the window
    /// from the *current* norm (not the variance about the window mean).
    ///
    /// Returns 1.0 while the window has fewer than 3 samples. The value depends on
    /// the absolute scale of the norms.
    pub fn get_stability_score(&self) -> f64 {
        if self.stability_window.len() < 3 {
            return 1.0;
        }
        let mean_sq_dev_from_current = self
            .stability_window
            .iter()
            .map(|&x| (x - self.current_gradient_norm).powi(2))
            .sum::<f64>()
            / self.stability_window.len() as f64;
        1.0 / (1.0 + mean_sq_dev_from_current)
    }

    /// True when the stability score is strictly greater than `threshold`.
    pub fn is_stable(&self, threshold: f64) -> bool {
        self.get_stability_score() > threshold
    }

    /// True when the window (at least 6 samples) changes direction at more than
    /// half as many points as it has samples. Flat steps do not count as changes.
    pub fn is_oscillating(&self) -> bool {
        if self.stability_window.len() < 6 {
            return false;
        }
        let values: Vec<f64> = self.stability_window.iter().copied().collect();
        let sign_changes = values
            .windows(3)
            .filter(|w| (w[1] - w[0]) * (w[2] - w[1]) < 0.0)
            .count();
        sign_changes > values.len() / 2
    }
}
/// Tunable settings for gradient monitoring. Defaults are provided by the
/// `Default` impl.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MonitoringConfig {
    /// Whether adaptive (EMA-based) thresholds are used; default `true`.
    pub enable_adaptive_thresholds: bool,
    /// Whether per-update real-time monitoring is enabled; default `true`.
    pub enable_real_time_monitoring: bool,
    /// Default 0.8. NOTE(review): presumably the threshold passed to
    /// `RealTimeGradientMonitor::is_stable` — confirm at the call site.
    pub stability_threshold: f64,
    /// Default 0.7. NOTE(review): presumably compared against `anomaly_score` —
    /// confirm at the call site.
    pub anomaly_threshold: f64,
    /// Default 1. NOTE(review): units (steps? batches?) not visible here — confirm.
    pub update_frequency: usize,
    /// Default 100. NOTE(review): presumably a sample count for history retention —
    /// confirm at the call site.
    pub history_window_size: usize,
}
impl Default for MonitoringConfig {
fn default() -> Self {
Self {
enable_adaptive_thresholds: true,
enable_real_time_monitoring: true,
stability_threshold: 0.8,
anomaly_threshold: 0.7,
update_frequency: 1,
history_window_size: 100,
}
}
}
/// Snapshot of a layer's monitoring outcome: health status, scores, alerts and
/// recommendations.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MonitoringResults {
    /// Layer this result describes.
    pub layer_name: String,
    /// Creation time, set by `new`.
    pub timestamp: DateTime<Utc>,
    /// Health status; starts `Healthy` and is recomputed each time `add_alert`
    /// is called.
    pub current_status: LayerHealth,
    /// Stability score; `new` initialises it to 1.0.
    pub stability_score: f64,
    /// Anomaly score; `new` initialises it to 0.0. Values above 0.7 contribute to a
    /// `Warning` status when the status is recomputed.
    pub anomaly_score: f64,
    /// Alerts collected for this layer, in insertion order.
    pub alerts: Vec<GradientAlert>,
    /// Human-readable recommendations, in insertion order.
    pub recommendations: Vec<String>,
}
impl MonitoringResults {
    /// Anomaly score strictly above which a layer without critical alerts is
    /// reported as `Warning`.
    const ANOMALY_WARNING_LEVEL: f64 = 0.7;

    /// Creates an empty, healthy result for `layer_name`, timestamped now, with a
    /// stability score of 1.0 and an anomaly score of 0.0.
    pub fn new(layer_name: String) -> Self {
        let timestamp = Utc::now();
        Self {
            current_status: LayerHealth::Healthy,
            stability_score: 1.0,
            anomaly_score: 0.0,
            alerts: Vec::new(),
            recommendations: Vec::new(),
            layer_name,
            timestamp,
        }
    }

    /// Appends `alert` and recomputes `current_status`.
    pub fn add_alert(&mut self, alert: GradientAlert) {
        self.alerts.push(alert);
        self.update_status();
    }

    /// Appends a recommendation; the status is left unchanged.
    pub fn add_recommendation(&mut self, recommendation: String) {
        self.recommendations.push(recommendation);
    }

    /// Derives `current_status` from the collected alerts and the anomaly score:
    /// - `Critical` if any alert is `ExplodingGradients` or `NoGradientFlow`;
    /// - otherwise `Warning` if there are any alerts or the anomaly score exceeds
    ///   the warning level;
    /// - otherwise `Healthy`.
    fn update_status(&mut self) {
        let is_critical = |alert: &GradientAlert| {
            matches!(
                alert,
                GradientAlert::ExplodingGradients { .. } | GradientAlert::NoGradientFlow { .. }
            )
        };
        let has_critical_alert = self.alerts.iter().any(is_critical);
        // Kept as `> level` (not `<=` on the negated branch) so a NaN score stays Healthy.
        let needs_attention =
            !self.alerts.is_empty() || self.anomaly_score > Self::ANOMALY_WARNING_LEVEL;
        self.current_status = match (has_critical_alert, needs_attention) {
            (true, _) => LayerHealth::Critical,
            (false, true) => LayerHealth::Warning,
            (false, false) => LayerHealth::Healthy,
        };
    }
}