Skip to main content

entrenar/train/tui/charts/
gradient_flow.rs

//! Gradient flow heatmap for visualizing per-layer gradients.
2
/// Gradient flow heatmap for visualizing per-layer gradients.
///
/// Each row is a layer and each column a sub-component (e.g. `Q`, `K`, `V`, `O`,
/// `FFN`). Cells hold the natural log of a gradient norm and are drawn with four
/// shading levels by [`GradientFlowHeatmap::render`].
#[derive(Debug, Clone)]
pub struct GradientFlowHeatmap {
    /// Layer names, one per row.
    pub(crate) layer_names: Vec<String>,
    /// Natural-log gradient magnitudes, indexed as `gradients[layer][column]`.
    pub(crate) gradients: Vec<Vec<f32>>,
    /// Column labels (Q, K, V, O, FFN, etc.)
    pub(crate) column_labels: Vec<String>,
}

impl GradientFlowHeatmap {
    /// Shading characters, ordered from lowest to highest relative magnitude.
    const HEATMAP_CHARS: [char; 4] = ['░', '▒', '▓', '█'];

    /// Create a new gradient flow heatmap.
    ///
    /// The grid has one row per entry of `layer_names` and one column per entry
    /// of `column_labels`; every cell starts at `0.0`.
    pub fn new(layer_names: Vec<String>, column_labels: Vec<String>) -> Self {
        let gradients = vec![vec![0.0; column_labels.len()]; layer_names.len()];
        Self { layer_names, gradients, column_labels }
    }

    /// Update gradient for a specific layer and column.
    ///
    /// Stores `ln(grad_norm + 1e-8)` so the heatmap is on a log scale. The
    /// argument of `ln` is clamped to at least `f32::MIN_POSITIVE`, so negative
    /// and NaN inputs are stored as `ln(f32::MIN_POSITIVE)`; an infinite norm is
    /// stored as `+inf`. Out-of-range `layer` or `col` indices are ignored.
    pub fn update(&mut self, layer: usize, col: usize, grad_norm: f32) {
        if col >= self.column_labels.len() {
            return;
        }
        // Checked access: `gradients` is crate-visible, so rows may not all match
        // `column_labels` in length; never panic on a ragged grid.
        if let Some(cell) = self.gradients.get_mut(layer).and_then(|row| row.get_mut(col)) {
            // The epsilon avoids ln(0); the clamp keeps ln's argument positive.
            *cell = (grad_norm + 1e-8).max(f32::MIN_POSITIVE).ln();
        }
    }

    /// Render to string.
    ///
    /// Finite cells are normalized linearly between the smallest and largest
    /// finite values and bucketed into four shades; if all finite values are
    /// (nearly) equal they are drawn with the third shade. Non-finite cells are
    /// excluded from normalization so a single overflow cannot flatten the map:
    /// `+inf` is drawn with the darkest shade, NaN and `-inf` with the lightest.
    #[must_use]
    pub fn render(&self) -> String {
        // Normalize over finite values only. Including `+inf` would make the
        // range infinite and collapse every other cell to the lightest shade.
        let (min, max) = self
            .gradients
            .iter()
            .flatten()
            .copied()
            .filter(|v| v.is_finite())
            .fold((f32::INFINITY, f32::NEG_INFINITY), |(lo, hi), v| (lo.min(v), hi.max(v)));
        let range = max - min;

        let mut output = String::new();
        output.push_str("Gradient Flow (log scale):\n");

        // Header: 9 columns of padding line up with the `{name:>8} ` row prefix.
        output.push_str("         ");
        for label in &self.column_labels {
            output.push_str(&format!("{label:^5}"));
        }
        output.push('\n');

        for (i, row) in self.gradients.iter().enumerate() {
            let name = self.layer_names.get(i).map_or("?", String::as_str);
            output.push_str(&format!("{name:>8} "));
            for &v in row {
                let c = Self::HEATMAP_CHARS[Self::shade_level(v, min, range)];
                output.extend([c, c, c, c, ' ']);
            }
            output.push('\n');
        }

        output
    }

    /// Map one stored log-magnitude to an index into `HEATMAP_CHARS`, given the
    /// minimum and range of the finite values in the grid.
    fn shade_level(v: f32, min: f32, range: f32) -> usize {
        let max_level = Self::HEATMAP_CHARS.len() - 1;
        if v.is_infinite() {
            return if v.is_sign_positive() { max_level } else { 0 };
        }
        if v.is_nan() {
            return 0;
        }
        let normalized =
            if range > f32::EPSILON { ((v - min) / range).clamp(0.0, 1.0) } else { 0.5 };
        // `as usize` saturates (NaN -> 0); `min` guards the upper bound.
        ((normalized * max_level as f32).round() as usize).min(max_level)
    }
}