// entrenar/train/tui/charts/gradient_flow.rs
/// Text heatmap of per-layer gradient norms, drawn with Unicode shade blocks.
///
/// Rows correspond to `layer_names` and columns to `column_labels`. Each cell
/// stores a log-scale value written by [`GradientFlowHeatmap::update`]; cells
/// start at `0.0` until they are updated.
#[derive(Debug, Clone)]
pub struct GradientFlowHeatmap {
    /// Row labels, one per layer.
    pub(crate) layer_names: Vec<String>,
    /// Log-scale values indexed as `gradients[layer][column]`.
    pub(crate) gradients: Vec<Vec<f32>>,
    /// Column labels printed in the header row.
    pub(crate) column_labels: Vec<String>,
}

impl GradientFlowHeatmap {
    /// Shade ramp, ordered from the smallest value to the largest.
    const SHADES: [char; 4] = ['░', '▒', '▓', '█'];
    /// Width of the right-aligned layer-name column (one space follows it).
    const NAME_WIDTH: usize = 8;
    /// Width of one heatmap column: the shade run plus one trailing space.
    const CELL_WIDTH: usize = 5;

    /// Creates a heatmap with one row per layer name and one column per
    /// column label, with every cell initialised to `0.0`.
    pub fn new(layer_names: Vec<String>, column_labels: Vec<String>) -> Self {
        let num_layers = layer_names.len();
        let num_cols = column_labels.len();
        Self {
            layer_names,
            gradients: vec![vec![0.0; num_cols]; num_layers],
            column_labels,
        }
    }

    /// Records `grad_norm` for the cell at (`layer`, `col`) on a log scale.
    ///
    /// The stored value is `ln(max(grad_norm + 1e-8, f32::MIN_POSITIVE))`.
    /// Because `f32::max` ignores a NaN operand, a NaN or non-positive sum is
    /// stored as `ln(f32::MIN_POSITIVE)`; `+inf` is stored as `+inf`.
    ///
    /// An out-of-range `layer` or `col` is silently ignored.
    pub fn update(&mut self, layer: usize, col: usize, grad_norm: f32) {
        if col >= self.column_labels.len() {
            return;
        }
        // Checked access: the fields are crate-visible, so a row is not
        // guaranteed to be as long as `column_labels`.
        if let Some(cell) = self.gradients.get_mut(layer).and_then(|row| row.get_mut(col)) {
            *cell = (grad_norm + 1e-8).max(f32::MIN_POSITIVE).ln();
        }
    }

    /// Renders the heatmap as a multi-line string.
    ///
    /// Output is a title line, a header of centred column labels, then one
    /// line per row: the layer name right-aligned (or `?` when the row has no
    /// name) followed by a run of one shade character per cell.
    ///
    /// Finite values are scaled linearly between the smallest and largest
    /// finite value in the matrix. When that range is not above
    /// `f32::EPSILON`, every finite cell uses the mid-scale value `0.5`.
    /// `+inf` always maps to the darkest shade; NaN and `-inf` map to the
    /// lightest.
    pub fn render(&self) -> String {
        // Scale over finite values only: a single infinite entry would make
        // the range infinite and flatten every cell to the lightest shade.
        let (min, max) = self
            .gradients
            .iter()
            .flatten()
            .copied()
            .filter(|v| v.is_finite())
            .fold((f32::INFINITY, f32::NEG_INFINITY), |(lo, hi), v| (lo.min(v), hi.max(v)));
        let range = max - min;

        let top = Self::SHADES.len() - 1;
        let shade = |v: f32| -> char {
            let normalized = if v == f32::INFINITY {
                1.0
            } else if !v.is_finite() {
                0.0
            } else if range > f32::EPSILON {
                ((v - min) / range).clamp(0.0, 1.0)
            } else {
                0.5
            };
            // `normalized` is within [0, 1], so the cast cannot truncate; the
            // `min` is a final guard on the index.
            let idx = (normalized * top as f32).round() as usize;
            Self::SHADES[idx.min(top)]
        };

        let mut output = String::from("Gradient Flow (log scale):\n");

        // Header: skip the name column so labels sit above their cells.
        output.push_str(&" ".repeat(Self::NAME_WIDTH + 1));
        for label in &self.column_labels {
            output.push_str(&format!("{:^w$}", label, w = Self::CELL_WIDTH));
        }
        output.push('\n');

        for (i, row) in self.gradients.iter().enumerate() {
            let name = self.layer_names.get(i).map_or("?", String::as_str);
            output.push_str(&format!("{:>w$} ", name, w = Self::NAME_WIDTH));

            for &v in row {
                let c = shade(v);
                for _ in 0..(Self::CELL_WIDTH - 1) {
                    output.push(c);
                }
                output.push(' ');
            }
            output.push('\n');
        }

        output
    }
}