use crate::autograd::Variable;
use crate::tensor::Tensor;
use std::collections::{HashMap, HashSet};
use std::fmt::Write;
use std::fs::File;
use std::io::Write as IoWrite;
use std::path::Path;
/// Accumulates the nodes and edges of a computation graph so it can be
/// rendered as Graphviz DOT text, summarised, or scanned for gradient issues.
pub struct GradientFlowVisualizer {
/// Every node recorded so far, in insertion order.
nodes: Vec<NodeInfo>,
/// Every directed edge recorded so far, in insertion order.
edges: Vec<EdgeInfo>,
/// Next id to hand out; incremented each time a node is created.
node_counter: usize,
/// Ids that `trace_from_variable` has already registered.
visited: HashSet<usize>,
}
/// Description of one node in the recorded gradient-flow graph.
#[derive(Debug, Clone)]
pub struct NodeInfo {
/// Identifier used to reference this node from edges and in the DOT output (`n<id>`).
id: usize,
/// Human-readable name shown in the rendered graph and in issue reports.
label: String,
/// Role of the node; selects the fill colour in the DOT output.
node_type: NodeType,
/// Shape of the underlying data; left empty for operation nodes.
shape: Vec<usize>,
/// L2 norm of the node's gradient, or `None` when no gradient was available.
gradient_norm: Option<f32>,
/// Whether the source variable required gradients (operation nodes store `true`).
// NOTE(review): this field is written but never read in this file, hence the
// leading underscore — confirm whether other modules rely on it.
_requires_grad: bool,
}
/// A directed edge between two recorded nodes.
#[derive(Debug, Clone)]
struct EdgeInfo {
/// Id of the node the edge starts at.
from: usize,
/// Id of the node the edge points to.
to: usize,
/// Text drawn on the edge; `add_operation` uses "forward" and "result".
label: String,
/// Optional gradient magnitude; when present `to_dot` derives the pen width from it.
// NOTE(review): nothing in this file ever stores `Some(..)` here — confirm
// whether a setter exists elsewhere.
gradient_magnitude: Option<f32>,
}
/// Role a node plays in the gradient-flow graph.
#[derive(Debug, Clone, PartialEq)]
pub enum NodeType {
/// A traced variable that does not require gradients.
Input,
/// A traced variable that requires gradients.
Parameter,
/// An operation node; the payload is the operation name.
Operation(String),
/// A traced variable whose label contains "loss" or "output".
Output,
}
// `Default` is provided so the visualizer can be used with `..Default::default()`
// and generic code; it is exactly the same as calling `new`.
impl Default for GradientFlowVisualizer {
/// Returns an empty visualizer, identical to `GradientFlowVisualizer::new()`.
fn default() -> Self {
Self::new()
}
}
impl GradientFlowVisualizer {
    /// Gradient norms below this value are reported as vanishing (parameters only).
    const VANISHING_THRESHOLD: f32 = 1e-6;
    /// Gradient norms above this value are reported as exploding.
    const EXPLODING_THRESHOLD: f32 = 1e3;
    /// Smallest pen width used for edges that carry a gradient magnitude.
    const MIN_PEN_WIDTH: f32 = 0.5;
    /// Largest pen width used for edges that carry a gradient magnitude.
    const MAX_PEN_WIDTH: f32 = 3.0;

    /// Creates an empty visualizer: no nodes, no edges, ids starting at 0.
    pub fn new() -> Self {
        Self {
            nodes: Vec::new(),
            edges: Vec::new(),
            node_counter: 0,
            visited: HashSet::new(),
        }
    }

    /// Records `var` as a node named `label` and returns the id assigned to it.
    ///
    /// The node stores the shape of `var.data` and, when a gradient is present
    /// and its lock can be read, the L2 norm of that gradient. Gradient elements
    /// that cannot be converted to `f32` contribute `0.0` to the norm.
    ///
    /// The node type is chosen as follows:
    /// * label contains `"loss"` or `"output"` -> [`NodeType::Output`]
    /// * otherwise `var.requires_grad` -> [`NodeType::Parameter`]
    /// * otherwise -> [`NodeType::Input`]
    ///
    /// # Panics
    ///
    /// Panics if `var.data.read()` returns an error (presumably a poisoned
    /// lock — confirm against the `Variable` definition).
    pub fn trace_from_variable<T>(&mut self, var: &Variable<T>, label: &str) -> usize
    where
        T: num_traits::Float
            + std::fmt::Debug
            + Send
            + Sync
            + 'static
            + ndarray::ScalarOperand
            + num_traits::FromPrimitive,
    {
        let node_id = self.node_counter;
        self.node_counter += 1;

        // NOTE(review): ids come from a counter that only grows (and is reset
        // together with `visited` in `clear`), so this insert appears to always
        // succeed. The guard is kept so the early-return contract is unchanged.
        if !self.visited.insert(node_id) {
            return node_id;
        }

        let shape = var.data.read().unwrap().shape().to_vec();

        // A gradient lock that cannot be read is treated the same as "no gradient".
        let gradient_norm = match var.grad.read() {
            Ok(grad_lock) => grad_lock.as_ref().map(|g| {
                let sum_of_squares: f32 = g
                    .data
                    .iter()
                    .map(|&x| x.to_f32().unwrap_or(0.0).powi(2))
                    .sum();
                sum_of_squares.sqrt()
            }),
            Err(_) => None,
        };

        let node_type = if label.contains("loss") || label.contains("output") {
            NodeType::Output
        } else if var.requires_grad {
            NodeType::Parameter
        } else {
            NodeType::Input
        };

        self.nodes.push(NodeInfo {
            id: node_id,
            label: label.to_string(),
            node_type,
            shape,
            gradient_norm,
            _requires_grad: var.requires_grad,
        });
        node_id
    }

    /// Adds an operation node named `op_name` and wires it into the graph.
    ///
    /// One `"forward"` edge is created from every id in `inputs` to the new
    /// operation node, followed by a single `"result"` edge from the operation
    /// node to `output`. The ids are not checked against the recorded nodes.
    pub fn add_operation(&mut self, op_name: &str, inputs: Vec<usize>, output: usize) {
        let op_id = self.node_counter;
        self.node_counter += 1;

        self.nodes.push(NodeInfo {
            id: op_id,
            label: op_name.to_string(),
            node_type: NodeType::Operation(op_name.to_string()),
            shape: Vec::new(),
            gradient_norm: None,
            _requires_grad: true,
        });

        self.edges.extend(inputs.into_iter().map(|input_id| EdgeInfo {
            from: input_id,
            to: op_id,
            label: "forward".to_string(),
            gradient_magnitude: None,
        }));
        self.edges.push(EdgeInfo {
            from: op_id,
            to: output,
            label: "result".to_string(),
            gradient_magnitude: None,
        });
    }

    /// Escapes `text` for use inside a double-quoted DOT label.
    ///
    /// `"` would terminate the quoted string and `\` starts an escape sequence,
    /// so both are prefixed with a backslash; every other character is copied
    /// unchanged.
    fn escape_dot_label(text: &str) -> String {
        let mut escaped = String::with_capacity(text.len());
        for ch in text.chars() {
            if matches!(ch, '"' | '\\') {
                escaped.push('\\');
            }
            escaped.push(ch);
        }
        escaped
    }

    /// Renders the recorded graph as Graphviz DOT text.
    ///
    /// Each node becomes `n<id>` with a label holding its name, shape and (when
    /// known) gradient norm, filled with a colour chosen by node type. Each edge
    /// carries its label and, when a gradient magnitude is stored, a `penwidth`
    /// derived from `log10(magnitude) + 2` limited to the range 0.5..=3.0.
    ///
    /// Node and edge labels are escaped so that quotes or backslashes in a label
    /// cannot break the generated DOT syntax.
    pub fn to_dot(&self) -> String {
        // Formatting into a `String` cannot fail, so the `unwrap`s below never panic.
        let mut dot = String::new();
        writeln!(&mut dot, "digraph GradientFlow {{").unwrap();
        writeln!(&mut dot, " rankdir=TB;").unwrap();
        writeln!(&mut dot, " node [shape=box, style=\"rounded,filled\"];").unwrap();
        writeln!(&mut dot, " edge [fontsize=10];").unwrap();
        writeln!(&mut dot).unwrap();

        for node in &self.nodes {
            let color = match node.node_type {
                NodeType::Input => "#e8f4f8",
                NodeType::Parameter => "#fff4e6",
                NodeType::Operation(_) => "#f0f8ff",
                NodeType::Output => "#ffe6e6",
            };
            let name = Self::escape_dot_label(&node.label);
            // `\\n` emits the two characters `\n`, which DOT renders as a line break.
            let label = match node.gradient_norm {
                Some(grad_norm) => format!(
                    "{}\\nshape: {:?}\\ngrad_norm: {:.4}",
                    name, node.shape, grad_norm
                ),
                None => format!("{}\\nshape: {:?}", name, node.shape),
            };
            writeln!(
                &mut dot,
                " n{} [label=\"{}\", fillcolor=\"{}\"];",
                node.id, label, color
            )
            .unwrap();
        }
        writeln!(&mut dot).unwrap();

        for edge in &self.edges {
            let style = match edge.gradient_magnitude {
                Some(magnitude) => {
                    let scaled = magnitude.log10() + 2.0;
                    // `log10` of a negative or NaN magnitude is NaN and `clamp`
                    // passes NaN through; fall back to the thinnest line instead
                    // of emitting `penwidth=NaN`.
                    let width = if scaled.is_nan() {
                        Self::MIN_PEN_WIDTH
                    } else {
                        scaled.clamp(Self::MIN_PEN_WIDTH, Self::MAX_PEN_WIDTH)
                    };
                    format!("penwidth={:.1}", width)
                }
                None => String::new(),
            };
            // The separator after the label is emitted even when `style` is empty;
            // DOT accepts the trailing comma and existing output stays unchanged.
            writeln!(
                &mut dot,
                " n{} -> n{} [label=\"{}\", {}];",
                edge.from,
                edge.to,
                Self::escape_dot_label(&edge.label),
                style
            )
            .unwrap();
        }

        writeln!(&mut dot, "}}").unwrap();
        dot
    }

    /// Writes the DOT rendering of the graph to `path`, creating or truncating the file.
    ///
    /// # Errors
    ///
    /// Returns any I/O error raised while creating or writing the file.
    pub fn save_to_file(&self, path: &Path) -> std::io::Result<()> {
        let dot_content = self.to_dot();
        let mut file = File::create(path)?;
        file.write_all(dot_content.as_bytes())?;
        Ok(())
    }

    /// Computes aggregate statistics over the recorded graph.
    ///
    /// Average, maximum and minimum are taken over the nodes that have a
    /// gradient norm; all three are `0.0` when no node has one.
    pub fn gradient_flow_summary(&self) -> GradientFlowSummary {
        let parameter_nodes = self
            .nodes
            .iter()
            .filter(|n| matches!(n.node_type, NodeType::Parameter))
            .count();

        let gradient_norms: Vec<f32> =
            self.nodes.iter().filter_map(|n| n.gradient_norm).collect();

        let avg_gradient_norm = if gradient_norms.is_empty() {
            0.0
        } else {
            gradient_norms.iter().sum::<f32>() / gradient_norms.len() as f32
        };
        let max_gradient_norm = gradient_norms.iter().copied().fold(0.0f32, f32::max);
        // Folding from +inf leaves a non-finite value when there is nothing to
        // fold (or when a norm itself is infinite); report 0.0 in that case.
        let min_seen = gradient_norms
            .iter()
            .copied()
            .fold(f32::INFINITY, f32::min);
        let min_gradient_norm = if min_seen.is_finite() { min_seen } else { 0.0 };

        GradientFlowSummary {
            total_nodes: self.nodes.len(),
            parameter_nodes,
            nodes_with_gradients: gradient_norms.len(),
            avg_gradient_norm,
            max_gradient_norm,
            min_gradient_norm,
            total_edges: self.edges.len(),
        }
    }

    /// Scans the recorded nodes for common gradient problems.
    ///
    /// The result lists, in node order, vanishing gradients (parameter nodes
    /// with a norm below `1e-6`) and exploding gradients (any node with a norm
    /// above `1e3`), followed by every parameter node that has no gradient at all.
    pub fn detect_issues(&self) -> Vec<GradientFlowIssue> {
        let mut issues = Vec::new();

        // Pass 1: nodes that do have a gradient norm.
        for node in &self.nodes {
            if let Some(grad_norm) = node.gradient_norm {
                let is_parameter = matches!(node.node_type, NodeType::Parameter);
                if is_parameter && grad_norm < Self::VANISHING_THRESHOLD {
                    issues.push(GradientFlowIssue::VanishingGradient {
                        node_label: node.label.clone(),
                        gradient_norm: grad_norm,
                    });
                }
                if grad_norm > Self::EXPLODING_THRESHOLD {
                    issues.push(GradientFlowIssue::ExplodingGradient {
                        node_label: node.label.clone(),
                        gradient_norm: grad_norm,
                    });
                }
            }
        }

        // Pass 2: parameters without any gradient. Kept as a separate pass so
        // these entries stay after all norm-based issues, as callers may expect.
        issues.extend(
            self.nodes
                .iter()
                .filter(|node| {
                    matches!(node.node_type, NodeType::Parameter) && node.gradient_norm.is_none()
                })
                .map(|node| GradientFlowIssue::DisconnectedParameter {
                    node_label: node.label.clone(),
                }),
        );

        issues
    }

    /// Removes all recorded nodes and edges and restarts id assignment at 0.
    pub fn clear(&mut self) {
        self.nodes.clear();
        self.edges.clear();
        self.visited.clear();
        self.node_counter = 0;
    }
}
/// Aggregate statistics describing a recorded gradient-flow graph.
#[derive(Debug, Clone)]
pub struct GradientFlowSummary {
    /// Number of nodes in the graph.
    pub total_nodes: usize,
    /// Number of nodes classified as parameters.
    pub parameter_nodes: usize,
    /// Number of nodes that carry a gradient norm.
    pub nodes_with_gradients: usize,
    /// Mean of the available gradient norms.
    pub avg_gradient_norm: f32,
    /// Largest available gradient norm.
    pub max_gradient_norm: f32,
    /// Smallest available gradient norm.
    pub min_gradient_norm: f32,
    /// Number of edges in the graph.
    pub total_edges: usize,
}

impl std::fmt::Display for GradientFlowSummary {
    /// Writes a heading followed by one line per field; norms use six decimals
    /// and every line, including the last, ends with a newline.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            total_nodes,
            parameter_nodes,
            nodes_with_gradients,
            avg_gradient_norm,
            max_gradient_norm,
            min_gradient_norm,
            total_edges,
        } = self;

        writeln!(f, "Gradient Flow Summary:")?;
        writeln!(f, " Total nodes: {}", total_nodes)?;
        writeln!(f, " Parameter nodes: {}", parameter_nodes)?;
        writeln!(f, " Nodes with gradients: {}", nodes_with_gradients)?;
        writeln!(f, " Average gradient norm: {:.6}", avg_gradient_norm)?;
        writeln!(f, " Max gradient norm: {:.6}", max_gradient_norm)?;
        writeln!(f, " Min gradient norm: {:.6}", min_gradient_norm)?;
        writeln!(f, " Total edges: {}", total_edges)
    }
}
/// A problem found while inspecting the gradients of a recorded graph.
#[derive(Debug, Clone)]
pub enum GradientFlowIssue {
    /// The gradient norm of the named node is very small.
    VanishingGradient {
        /// Label of the affected node.
        node_label: String,
        /// The offending gradient norm.
        gradient_norm: f32,
    },
    /// The gradient norm of the named node is very large.
    ExplodingGradient {
        /// Label of the affected node.
        node_label: String,
        /// The offending gradient norm.
        gradient_norm: f32,
    },
    /// The named parameter has no gradient at all.
    DisconnectedParameter {
        /// Label of the affected node.
        node_label: String,
    },
}

impl std::fmt::Display for GradientFlowIssue {
    /// Writes a one-line, human-readable description of the issue; norms are
    /// shown in scientific notation with two decimals.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::VanishingGradient {
                node_label: label,
                gradient_norm: norm,
            } => write!(f, "Vanishing gradient in '{}': norm = {:.2e}", label, norm),
            Self::ExplodingGradient {
                node_label: label,
                gradient_norm: norm,
            } => write!(f, "Exploding gradient in '{}': norm = {:.2e}", label, norm),
            Self::DisconnectedParameter { node_label: label } => write!(
                f,
                "Disconnected parameter '{}': no gradient computed",
                label
            ),
        }
    }
}
/// Keeps a bounded history of gradient L2 norms per name and classifies how
/// each history is evolving.
pub struct GradientFlowAnalyzer {
    /// Recorded norms per name, oldest first.
    gradient_history: HashMap<String, Vec<f32>>,
    /// Maximum number of norms retained per name.
    max_history_length: usize,
}

impl GradientFlowAnalyzer {
    /// Number of most recent samples that form the "recent" window in
    /// [`analyze_trends`](Self::analyze_trends).
    const RECENT_WINDOW: usize = 10;

    /// Creates an analyzer that keeps at most `max_history_length` norms per name.
    pub fn new(max_history_length: usize) -> Self {
        Self {
            gradient_history: HashMap::new(),
            max_history_length,
        }
    }

    /// Appends the L2 norm of `tensor` to the history stored under `name`.
    ///
    /// Elements that cannot be converted to `f32` contribute `0.0`. When the
    /// history grows beyond the configured maximum, the oldest entry is dropped.
    pub fn record_gradient<T>(&mut self, name: &str, tensor: &Tensor<T>)
    where
        T: num_traits::Float,
    {
        let norm = tensor
            .data
            .iter()
            .map(|&x| x.to_f32().unwrap_or(0.0).powi(2))
            .sum::<f32>()
            .sqrt();

        let history = self.gradient_history.entry(name.to_string()).or_default();
        history.push(norm);
        // Exactly one value is added per call, so removing one entry is enough
        // to get back under the limit.
        if history.len() > self.max_history_length {
            history.remove(0);
        }
    }

    /// Returns the recorded norms for `name`, oldest first, or `None` if nothing
    /// was recorded under that name.
    pub fn get_history(&self, name: &str) -> Option<&Vec<f32>> {
        self.gradient_history.get(name)
    }

    /// Classifies every history that holds at least two samples.
    ///
    /// The mean of the last ten samples (or of all samples when fewer exist) is
    /// compared with the mean of the whole history:
    /// * below 10% of the overall mean -> [`GradientTrend::Vanishing`]
    /// * above 10x the overall mean -> [`GradientTrend::Exploding`]
    /// * within 10% of the overall mean (inclusive) -> [`GradientTrend::Stable`]
    /// * otherwise [`GradientTrend::Increasing`] or [`GradientTrend::Decreasing`]
    ///
    /// Histories with fewer than two samples are left out of the result.
    pub fn analyze_trends(&self) -> HashMap<String, GradientTrend> {
        let mut trends = HashMap::new();
        for (name, history) in &self.gradient_history {
            if history.len() < 2 {
                continue;
            }

            let window_start = history.len().saturating_sub(Self::RECENT_WINDOW);
            let recent = &history[window_start..];
            let recent_avg = recent.iter().sum::<f32>() / recent.len() as f32;
            let overall_avg = history.iter().sum::<f32>() / history.len() as f32;

            let trend = if recent_avg < overall_avg * 0.1 {
                GradientTrend::Vanishing
            } else if recent_avg > overall_avg * 10.0 {
                GradientTrend::Exploding
            } else if (recent_avg - overall_avg).abs() <= overall_avg * 0.1 {
                // Inclusive comparison: with an all-zero history both averages
                // are 0 and the tolerance is 0, which a strict `<` would reject
                // and misreport as `Decreasing` although nothing changed.
                GradientTrend::Stable
            } else if recent_avg > overall_avg {
                GradientTrend::Increasing
            } else {
                GradientTrend::Decreasing
            };
            trends.insert(name.clone(), trend);
        }
        trends
    }

    /// Discards every recorded history.
    pub fn clear(&mut self) {
        self.gradient_history.clear();
    }
}
/// How the recent gradient norms of a tensor compare with its whole recorded history.
#[derive(Debug, Clone, PartialEq)]
pub enum GradientTrend {
/// The recent average is within 10% of the overall average.
Stable,
/// The recent average is above the overall average, outside the stable band.
Increasing,
/// The recent average is below the overall average, outside the stable band.
Decreasing,
/// The recent average dropped below 10% of the overall average.
Vanishing,
/// The recent average exceeds ten times the overall average.
Exploding,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a small input/parameter/output graph joined by one operation and
    /// checks that the DOT rendering mentions the nodes and the result edge.
    #[test]
    fn test_gradient_flow_visualizer() {
        let mut visualizer = GradientFlowVisualizer::new();

        let input_id = visualizer.node_counter;
        visualizer.node_counter += 1;
        visualizer.nodes.push(NodeInfo {
            id: input_id,
            label: "input".to_string(),
            node_type: NodeType::Input,
            shape: vec![32, 10],
            gradient_norm: None,
            _requires_grad: false,
        });

        let param_id = visualizer.node_counter;
        visualizer.node_counter += 1;
        visualizer.nodes.push(NodeInfo {
            id: param_id,
            label: "weight".to_string(),
            node_type: NodeType::Parameter,
            shape: vec![10, 5],
            gradient_norm: Some(0.5),
            _requires_grad: true,
        });

        // Register a real output node. A hard-coded output id would coincide
        // with the id the operation node receives and produce a self-loop.
        let output_id = visualizer.node_counter;
        visualizer.node_counter += 1;
        visualizer.nodes.push(NodeInfo {
            id: output_id,
            label: "output".to_string(),
            node_type: NodeType::Output,
            shape: vec![32, 5],
            gradient_norm: None,
            _requires_grad: true,
        });

        // `add_operation` takes the next free id for the operation node.
        let op_id = visualizer.node_counter;
        visualizer.add_operation("matmul", vec![input_id, param_id], output_id);

        let dot = visualizer.to_dot();
        assert!(dot.contains("digraph GradientFlow"));
        assert!(dot.contains("weight"));
        assert!(dot.contains("matmul"));
        assert!(dot.contains(&format!("n{} -> n{}", input_id, op_id)));
        assert!(dot.contains(&format!("n{} -> n{}", param_id, op_id)));
        assert!(dot.contains(&format!("n{} -> n{}", op_id, output_id)));
    }

    /// Five parameters with norms 0.1..=0.5 must be reflected in the summary.
    #[test]
    fn test_gradient_flow_summary() {
        let mut visualizer = GradientFlowVisualizer::new();
        for i in 0..5 {
            visualizer.nodes.push(NodeInfo {
                id: i,
                label: format!("param_{}", i),
                node_type: NodeType::Parameter,
                shape: vec![10, 10],
                gradient_norm: Some((i + 1) as f32 * 0.1),
                _requires_grad: true,
            });
        }

        let summary = visualizer.gradient_flow_summary();
        assert_eq!(summary.total_nodes, 5);
        assert_eq!(summary.parameter_nodes, 5);
        assert_eq!(summary.nodes_with_gradients, 5);
        assert_eq!(summary.total_edges, 0);
        assert!(summary.avg_gradient_norm > 0.0);
        assert!((summary.avg_gradient_norm - 0.3).abs() < 1e-5);
        assert!((summary.max_gradient_norm - 0.5).abs() < 1e-5);
        assert!((summary.min_gradient_norm - 0.1).abs() < 1e-5);
    }

    /// One vanishing and one exploding parameter yield exactly those two issues.
    #[test]
    fn test_issue_detection() {
        let mut visualizer = GradientFlowVisualizer::new();
        visualizer.nodes.push(NodeInfo {
            id: 0,
            label: "vanishing_param".to_string(),
            node_type: NodeType::Parameter,
            shape: vec![10],
            gradient_norm: Some(1e-7),
            _requires_grad: true,
        });
        visualizer.nodes.push(NodeInfo {
            id: 1,
            label: "exploding_param".to_string(),
            node_type: NodeType::Parameter,
            shape: vec![10],
            gradient_norm: Some(1e4),
            _requires_grad: true,
        });

        let issues = visualizer.detect_issues();
        assert_eq!(issues.len(), 2);
        assert!(matches!(
            issues[0],
            GradientFlowIssue::VanishingGradient { .. }
        ));
        assert!(matches!(
            issues[1],
            GradientFlowIssue::ExplodingGradient { .. }
        ));
    }

    /// Linearly growing gradients are recorded in full and classified as increasing.
    #[test]
    fn test_gradient_analyzer() {
        let mut analyzer = GradientFlowAnalyzer::new(100);
        for i in 0..20 {
            let scaled = Tensor::from_vec(
                vec![
                    0.1 * (i as f32 + 1.0),
                    0.2 * (i as f32 + 1.0),
                    0.3 * (i as f32 + 1.0),
                ],
                vec![3],
            );
            analyzer.record_gradient("weight", &scaled);
        }

        let history = analyzer.get_history("weight").unwrap();
        assert_eq!(history.len(), 20);

        let trends = analyzer.analyze_trends();
        assert!(trends.contains_key("weight"));
        // The mean of the last ten norms is about 1.5x the overall mean: outside
        // the stable band but far from the 10x "exploding" threshold.
        assert_eq!(trends.get("weight"), Some(&GradientTrend::Increasing));
    }
}