use crate::error::{OptimError, Result};
use scirs2_core::ndarray::{Array1, ScalarOperand};
use scirs2_core::numeric::Float;
use std::collections::HashMap;
use std::fmt::Debug;
/// Thresholds and limits that control how [`GradientFlowAnalyzer`] summarises and
/// diagnoses gradients.
#[derive(Debug, Clone)]
pub struct GradientFlowConfig {
    /// A layer whose latest mean gradient magnitude is below this value is reported as
    /// vanishing; entries with a magnitude below it also count towards `sparsity`.
    pub vanishing_threshold: f64,
    /// A layer whose latest maximum gradient magnitude exceeds this value is reported
    /// as exploding.
    pub exploding_threshold: f64,
    /// Number of equal-width bins in each recorded magnitude histogram.
    pub histogram_bins: usize,
    /// Maximum number of recorded steps retained per layer.
    pub max_history: usize,
}

impl Default for GradientFlowConfig {
    /// Vanishing below `1e-7`, exploding above `1e3`, 50 histogram bins and a
    /// 100-step history per layer.
    fn default() -> Self {
        let vanishing_threshold = 1e-7;
        let exploding_threshold = 1e3;
        let histogram_bins = 50;
        let max_history = 100;
        Self {
            vanishing_threshold,
            exploding_threshold,
            histogram_bins,
            max_history,
        }
    }
}
/// Summary of one layer's gradients at a single recording step.
///
/// Every magnitude statistic is computed over the absolute gradient values `|g|`.
#[derive(Debug, Clone)]
pub struct LayerGradientStats<A> {
    /// Name of the layer these statistics describe.
    pub layer_name: String,
    /// Mean gradient magnitude.
    pub mean_norm: A,
    /// Largest gradient magnitude.
    pub max_norm: A,
    /// Smallest gradient magnitude.
    pub min_norm: A,
    /// Variance of the gradient magnitudes (not of the signed gradients).
    pub variance: A,
    /// Fraction of entries whose magnitude is strictly below the vanishing threshold.
    pub sparsity: A,
    /// Per-bin counts of gradient magnitudes over equal-width bins spanning `[0, max_norm]`.
    pub histogram: Vec<usize>,
}
/// Aggregate verdict produced by a gradient health report.
#[derive(Debug, Clone, PartialEq)]
pub enum GradientHealth {
    /// No layer is vanishing or exploding.
    Healthy,
    /// Some layers are vanishing, but no more than half of them, and none is exploding.
    Warning,
    /// At least one layer is exploding, or more than half of the layers are vanishing.
    Critical,
}
/// Diagnosis of every recorded layer, based on each layer's most recent statistics.
#[derive(Debug, Clone)]
pub struct GradientHealthReport {
    /// Layers whose latest mean gradient magnitude fell below the vanishing threshold,
    /// sorted by name.
    pub vanishing_layers: Vec<String>,
    /// Layers flagged as exploding, sorted by name.
    pub exploding_layers: Vec<String>,
    /// Layers flagged as neither vanishing nor exploding, in first-recorded order.
    pub healthy_layers: Vec<String>,
    /// Aggregate verdict across all layers.
    pub overall_health: GradientHealth,
    /// Human-readable suggestions for addressing the detected problems.
    pub recommendations: Vec<String>,
}
/// Records per-layer gradient statistics over time and diagnoses vanishing or
/// exploding gradients from each layer's most recent step.
pub struct GradientFlowAnalyzer<A> {
    /// Thresholds and limits applied when recording and diagnosing.
    config: GradientFlowConfig,
    /// Per-layer history of statistics, oldest first, capped at `config.max_history`.
    layer_stats: HashMap<String, Vec<LayerGradientStats<A>>>,
    /// Layer names in the order they were first recorded; drives report and chart order.
    layer_order: Vec<String>,
}
impl<A> GradientFlowAnalyzer<A>
where
    A: Float + ScalarOperand + Debug + std::iter::Sum,
{
    /// Creates an analyzer with no recorded layers.
    pub fn new(config: GradientFlowConfig) -> Self {
        Self {
            config,
            layer_stats: HashMap::new(),
            layer_order: Vec::new(),
        }
    }

    /// Summarises one gradient vector for `layer_name` and appends the result to that
    /// layer's history.
    ///
    /// Every statistic is computed over the gradient magnitudes `|g|`:
    /// - `mean_norm`, `max_norm`, `min_norm`: mean / largest / smallest magnitude
    ///   (NaN entries are skipped by the max/min scans but propagate into the mean);
    /// - `variance`: population variance of the magnitudes;
    /// - `sparsity`: fraction of magnitudes strictly below `config.vanishing_threshold`;
    /// - `histogram`: counts over `config.histogram_bins` equal-width bins on
    ///   `[0, max_norm]` (empty when `histogram_bins == 0`).
    ///
    /// A layer name seen for the first time is appended to the display order used by
    /// [`get_health_report`](Self::get_health_report) and
    /// [`render_flow_chart`](Self::render_flow_chart). Each layer keeps at most
    /// `config.max_history` entries; the oldest are evicted first.
    ///
    /// # Errors
    ///
    /// [`OptimError::InvalidParameter`] if `gradients` is empty;
    /// [`OptimError::ComputationError`] if a length, count, bin count or threshold
    /// cannot be represented as `A`.
    pub fn record_gradients(
        &mut self,
        layer_name: &str,
        gradients: &Array1<A>,
    ) -> Result<LayerGradientStats<A>> {
        let len = gradients.len();
        if len == 0 {
            return Err(OptimError::InvalidParameter(
                "Gradients array must not be empty".to_string(),
            ));
        }
        let len_a = A::from(len).ok_or_else(|| {
            OptimError::ComputationError("Failed to convert length to float".to_string())
        })?;
        let abs_grads: Vec<A> = gradients.iter().map(|&g| g.abs()).collect();
        let mean_norm = abs_grads.iter().copied().sum::<A>() / len_a;
        // Comparison-based folds: NaN never compares greater/less, so NaN entries are skipped.
        let max_norm = abs_grads
            .iter()
            .copied()
            .fold(A::neg_infinity(), |a, b| if b > a { b } else { a });
        let min_norm = abs_grads
            .iter()
            .copied()
            .fold(A::infinity(), |a, b| if b < a { b } else { a });
        // Two-pass variance: summing squared deviations from the mean avoids the
        // catastrophic cancellation of `E[x^2] - E[x]^2`, which can lose all precision
        // (or even go negative) when magnitudes are large relative to their spread.
        let variance = abs_grads
            .iter()
            .map(|&g| {
                let dev = g - mean_norm;
                dev * dev
            })
            .sum::<A>()
            / len_a;
        let vanishing_thresh = A::from(self.config.vanishing_threshold).ok_or_else(|| {
            OptimError::ComputationError(
                "Failed to convert vanishing threshold to float".to_string(),
            )
        })?;
        let near_zero_count = abs_grads.iter().filter(|&&g| g < vanishing_thresh).count();
        let sparsity = A::from(near_zero_count).ok_or_else(|| {
            OptimError::ComputationError("Failed to convert count to float".to_string())
        })? / len_a;
        let histogram = self.compute_histogram(&abs_grads, max_norm)?;
        let stats = LayerGradientStats {
            layer_name: layer_name.to_string(),
            mean_norm,
            max_norm,
            min_norm,
            variance,
            sparsity,
            histogram,
        };
        // `layer_stats` and `layer_order` are only ever mutated together, so a missing
        // key means this is the layer's first appearance (O(1) instead of a linear scan).
        if !self.layer_stats.contains_key(layer_name) {
            self.layer_order.push(layer_name.to_string());
        }
        let max_history = self.config.max_history;
        let history = self.layer_stats.entry(layer_name.to_string()).or_default();
        history.push(stats.clone());
        if history.len() > max_history {
            let excess = history.len() - max_history;
            history.drain(..excess);
        }
        Ok(stats)
    }

    /// Bins `abs_grads` into `config.histogram_bins` equal-width bins over `[0, max_val]`.
    ///
    /// Returns an empty histogram when zero bins are configured, and puts every entry
    /// in bin 0 when `max_val <= 0`.
    fn compute_histogram(&self, abs_grads: &[A], max_val: A) -> Result<Vec<usize>> {
        let bins = self.config.histogram_bins;
        // Without this guard `histogram[0]` and `bins - 1` below would panic.
        if bins == 0 {
            return Ok(Vec::new());
        }
        let mut histogram = vec![0usize; bins];
        if max_val <= A::zero() {
            histogram[0] = abs_grads.len();
            return Ok(histogram);
        }
        let bins_a = A::from(bins).ok_or_else(|| {
            OptimError::ComputationError("Failed to convert bins to float".to_string())
        })?;
        let last_bin = bins - 1;
        for &val in abs_grads {
            let scaled = ((val / max_val) * bins_a)
                .to_f64()
                .ok_or_else(|| OptimError::ComputationError("Failed to convert to f64".to_string()))?;
            // Float-to-int `as` saturates (NaN -> 0); `val == max_val` lands on `bins`,
            // so clamp into the last bin.
            let bin_idx = (scaled as usize).min(last_bin);
            histogram[bin_idx] += 1;
        }
        Ok(histogram)
    }

    /// Returns the names of layers whose latest mean gradient magnitude is below
    /// `config.vanishing_threshold`, sorted by name.
    pub fn detect_vanishing_gradients(&self) -> Vec<String> {
        let threshold = self.config.vanishing_threshold;
        let mut vanishing: Vec<String> = self
            .layer_stats
            .iter()
            .filter_map(|(name, history)| history.last().map(|latest| (name, latest)))
            .filter(|(_, latest)| latest.mean_norm.to_f64().unwrap_or(0.0) < threshold)
            .map(|(name, _)| name.clone())
            .collect();
        vanishing.sort();
        vanishing
    }

    /// Returns the names of layers whose latest gradients look exploding, sorted by name.
    ///
    /// A layer is flagged when its latest `max_norm` exceeds `config.exploding_threshold`,
    /// or when its latest `mean_norm` or `max_norm` is not finite. The second rule catches
    /// NaN/infinite gradients: the max scan skips NaN entries, so a layer full of NaNs
    /// would otherwise never cross the threshold and be reported as healthy.
    pub fn detect_exploding_gradients(&self) -> Vec<String> {
        let threshold = self.config.exploding_threshold;
        let mut exploding: Vec<String> = self
            .layer_stats
            .iter()
            .filter_map(|(name, history)| history.last().map(|latest| (name, latest)))
            .filter(|(_, latest)| {
                let mean = latest.mean_norm.to_f64().unwrap_or(0.0);
                let max = latest.max_norm.to_f64().unwrap_or(0.0);
                !mean.is_finite() || !max.is_finite() || max > threshold
            })
            .map(|(name, _)| name.clone())
            .collect();
        exploding.sort();
        exploding
    }

    /// Classifies every recorded layer from its latest statistics and suggests remedies.
    ///
    /// `overall_health` is [`GradientHealth::Critical`] when any layer is exploding or
    /// more than half of the layers are vanishing, [`GradientHealth::Warning`] when some
    /// (at most half) are vanishing, and [`GradientHealth::Healthy`] otherwise.
    pub fn get_health_report(&self) -> GradientHealthReport {
        let vanishing = self.detect_vanishing_gradients();
        let exploding = self.detect_exploding_gradients();
        let healthy: Vec<String> = self
            .layer_order
            .iter()
            .filter(|&name| !vanishing.contains(name) && !exploding.contains(name))
            .cloned()
            .collect();
        let overall_health = if !exploding.is_empty() {
            GradientHealth::Critical
        } else if !vanishing.is_empty() {
            // Integer halving makes this "strictly more than half of the layers".
            if vanishing.len() > self.layer_order.len() / 2 {
                GradientHealth::Critical
            } else {
                GradientHealth::Warning
            }
        } else {
            GradientHealth::Healthy
        };
        let mut recommendations = Vec::new();
        if !vanishing.is_empty() {
            recommendations.push(format!(
                "Vanishing gradients detected in {} layer(s): consider using residual connections, \
                 batch normalization, or switching to ReLU-family activations.",
                vanishing.len()
            ));
            recommendations
                .push("Consider using gradient scaling or a smaller model depth.".to_string());
        }
        if !exploding.is_empty() {
            recommendations.push(format!(
                "Exploding gradients detected in {} layer(s): apply gradient clipping \
                 (e.g., max norm clipping) or reduce learning rate.",
                exploding.len()
            ));
            recommendations.push(
                "Consider weight initialization with smaller variance (e.g., He or Xavier init)."
                    .to_string(),
            );
        }
        if vanishing.is_empty() && exploding.is_empty() {
            recommendations.push("Gradient flow appears healthy across all layers.".to_string());
        }
        GradientHealthReport {
            vanishing_layers: vanishing,
            exploding_layers: exploding,
            healthy_layers: healthy,
            overall_health,
            recommendations,
        }
    }

    /// Renders each layer's latest mean gradient magnitude as a horizontal SVG bar chart,
    /// one bar per layer in first-recorded order.
    ///
    /// Bars are red for exploding layers, amber for vanishing layers and green
    /// otherwise. Lengths are scaled against the largest finite mean (minimum 1px); a
    /// layer whose mean is NaN or infinite is drawn at full width. Layer names are
    /// XML-escaped before being embedded in the document.
    ///
    /// # Errors
    ///
    /// [`OptimError::InvalidState`] if no gradients have been recorded yet.
    pub fn render_flow_chart(&self) -> Result<String> {
        if self.layer_order.is_empty() {
            return Err(OptimError::InvalidState(
                "No gradient data recorded yet".to_string(),
            ));
        }
        let vanishing = self.detect_vanishing_gradients();
        let exploding = self.detect_exploding_gradients();
        let bar_width: usize = 40;
        let bar_spacing: usize = 10;
        let margin_left: usize = 150;
        let margin_top: usize = 40;
        let chart_width: usize = 400;
        let num_layers = self.layer_order.len();
        let total_height = margin_top + num_layers * (bar_width + bar_spacing) + 40;
        let total_width = margin_left + chart_width + 60;
        let mut svg = format!(
            r#"<svg xmlns="http://www.w3.org/2000/svg" width="{}" height="{}" viewBox="0 0 {} {}">"#,
            total_width, total_height, total_width, total_height
        );
        svg.push('\n');
        svg.push_str(&format!(
            r#" <text x="{}" y="25" text-anchor="middle" font-size="16" font-weight="bold">Gradient Flow Analysis</text>"#,
            total_width / 2
        ));
        svg.push('\n');
        // Scale against the largest *finite* mean so one NaN/inf layer cannot collapse
        // every other bar to the 1px minimum.
        let max_mean = self
            .layer_order
            .iter()
            .map(|name| self.latest_mean_f64(name))
            .filter(|value| value.is_finite())
            .fold(0.0f64, f64::max);
        let max_mean = if max_mean > 0.0 { max_mean } else { 1.0 };
        for (i, name) in self.layer_order.iter().enumerate() {
            let y = margin_top + i * (bar_width + bar_spacing);
            let mean_val = self.latest_mean_f64(name);
            let bar_len = if mean_val.is_finite() {
                ((mean_val / max_mean) * chart_width as f64).max(1.0) as usize
            } else {
                chart_width
            };
            let color = if exploding.contains(name) {
                "#ff4444"
            } else if vanishing.contains(name) {
                "#ffaa00"
            } else {
                "#44bb44"
            };
            svg.push_str(&format!(
                r#" <text x="{}" y="{}" text-anchor="end" font-size="12" dominant-baseline="middle">{}</text>"#,
                margin_left - 10,
                y + bar_width / 2,
                escape_xml(name)
            ));
            svg.push('\n');
            svg.push_str(&format!(
                r#" <rect x="{}" y="{}" width="{}" height="{}" fill="{}" rx="3" ry="3"/>"#,
                margin_left, y, bar_len, bar_width, color
            ));
            svg.push('\n');
            svg.push_str(&format!(
                r#" <text x="{}" y="{}" font-size="10" dominant-baseline="middle">{:.2e}</text>"#,
                margin_left + bar_len + 5,
                y + bar_width / 2,
                mean_val
            ));
            svg.push('\n');
        }
        svg.push_str("</svg>");
        Ok(svg)
    }

    /// Returns the recorded history (oldest first) for `layer_name`, if any.
    pub fn get_layer_history(&self, layer_name: &str) -> Option<&Vec<LayerGradientStats<A>>> {
        self.layer_stats.get(layer_name)
    }

    /// Forgets every recorded layer, including the display order.
    pub fn clear_history(&mut self) {
        self.layer_stats.clear();
        self.layer_order.clear();
    }

    /// Latest `mean_norm` of `layer_name` as `f64`, or `0.0` when the layer has no
    /// history or the value cannot be converted.
    fn latest_mean_f64(&self, layer_name: &str) -> f64 {
        self.layer_stats
            .get(layer_name)
            .and_then(|history| history.last())
            .and_then(|latest| latest.mean_norm.to_f64())
            .unwrap_or(0.0)
    }
}

/// Escapes the XML special characters so arbitrary layer names can be embedded as SVG
/// text content without producing malformed markup.
fn escape_xml(text: &str) -> String {
    let mut escaped = String::with_capacity(text.len());
    for ch in text.chars() {
        match ch {
            '&' => escaped.push_str("&amp;"),
            '<' => escaped.push_str("&lt;"),
            '>' => escaped.push_str("&gt;"),
            '"' => escaped.push_str("&quot;"),
            '\'' => escaped.push_str("&apos;"),
            _ => escaped.push(ch),
        }
    }
    escaped
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array1;

    /// Feeds `values` to `analyzer` as the gradients of `layer`, failing the test if the
    /// analyzer rejects them.
    fn record(analyzer: &mut GradientFlowAnalyzer<f64>, layer: &str, values: &[f64]) {
        analyzer
            .record_gradients(layer, &Array1::from_vec(values.to_vec()))
            .expect("Should record");
    }

    /// True when `layer` appears in `names`.
    fn contains_layer(names: &[String], layer: &str) -> bool {
        names.iter().any(|candidate| candidate == layer)
    }

    #[test]
    fn test_record_gradients_basic() {
        let mut analyzer = GradientFlowAnalyzer::<f64>::new(GradientFlowConfig::default());
        let grads = Array1::from_vec(vec![0.1, -0.2, 0.3, -0.4, 0.5]);
        let stats = analyzer
            .record_gradients("layer1", &grads)
            .expect("Should record gradients");
        let approx = |actual: f64, expected: f64| (actual - expected).abs() < 1e-10;

        assert_eq!(stats.layer_name, "layer1");
        assert!(approx(stats.mean_norm, 0.3));
        assert!(approx(stats.max_norm, 0.5));
        assert!(approx(stats.min_norm, 0.1));
        assert!(approx(stats.sparsity, 0.0));
        assert_eq!(stats.histogram.iter().sum::<usize>(), 5);

        let history_len = analyzer.get_layer_history("layer1").map(Vec::len);
        assert_eq!(history_len, Some(1));
    }

    #[test]
    fn test_detect_vanishing_gradients() {
        let mut analyzer = GradientFlowAnalyzer::<f64>::new(GradientFlowConfig {
            vanishing_threshold: 1e-7,
            ..Default::default()
        });
        record(&mut analyzer, "healthy_layer", &[0.01, 0.02, 0.015, 0.008]);
        record(&mut analyzer, "vanishing_layer", &[1e-9, 1e-10, 1e-8, 1e-11]);

        let flagged = analyzer.detect_vanishing_gradients();
        assert!(contains_layer(&flagged, "vanishing_layer"));
        assert!(!contains_layer(&flagged, "healthy_layer"));
    }

    #[test]
    fn test_detect_exploding_gradients() {
        let mut analyzer = GradientFlowAnalyzer::<f64>::new(GradientFlowConfig {
            exploding_threshold: 1e3,
            ..Default::default()
        });
        record(&mut analyzer, "normal_layer", &[0.5, 1.0, 0.3, 0.8]);
        record(&mut analyzer, "exploding_layer", &[5000.0, 10000.0, 3000.0, 8000.0]);

        let flagged = analyzer.detect_exploding_gradients();
        assert!(contains_layer(&flagged, "exploding_layer"));
        assert!(!contains_layer(&flagged, "normal_layer"));
    }

    #[test]
    fn test_health_report_generation() {
        let mut analyzer = GradientFlowAnalyzer::<f64>::new(GradientFlowConfig::default());
        record(&mut analyzer, "fc1", &[0.01, 0.02, 0.015]);
        record(&mut analyzer, "fc2", &[1e-10, 1e-11, 1e-9]);
        record(&mut analyzer, "fc3", &[5000.0, 10000.0, 8000.0]);

        let report = analyzer.get_health_report();
        assert!(contains_layer(&report.vanishing_layers, "fc2"));
        assert!(contains_layer(&report.exploding_layers, "fc3"));
        assert!(contains_layer(&report.healthy_layers, "fc1"));
        assert_eq!(report.overall_health, GradientHealth::Critical);
        assert!(!report.recommendations.is_empty());
    }

    #[test]
    fn test_render_flow_chart_svg() {
        let mut analyzer = GradientFlowAnalyzer::<f64>::new(GradientFlowConfig::default());
        record(&mut analyzer, "conv1", &[0.01, 0.02, 0.015]);
        record(&mut analyzer, "conv2", &[0.005, 0.003, 0.004]);
        record(&mut analyzer, "fc1", &[0.1, 0.08, 0.12]);

        let svg = analyzer
            .render_flow_chart()
            .expect("Should render flow chart");
        assert!(svg.starts_with("<svg"));
        assert!(svg.ends_with("</svg>"));
        for expected in ["conv1", "conv2", "fc1", "Gradient Flow Analysis"].iter() {
            assert!(svg.contains(*expected));
        }
    }
}