// proof_engine/ml/training_viz.rs
//! Training visualization: loss landscapes, gradient flow, weight distributions, dashboards.

use glam::{Vec2, Vec4};

/// Statistics for a single training epoch.
#[derive(Debug, Clone)]
pub struct EpochStats {
    /// Zero-based epoch index.
    pub epoch: usize,
    /// Training loss at the end of the epoch.
    pub loss: f32,
    /// Accuracy for the epoch (a fraction; `summary` renders it as %).
    pub accuracy: f32,
    /// Learning rate used during the epoch.
    pub lr: f32,
    /// Global gradient norm (0.0 when not recorded).
    pub grad_norm: f32,
    /// Per-layer weight norms (empty when not recorded).
    pub weight_norms: Vec<f32>,
}

impl EpochStats {
    /// Create stats from the core metrics; `grad_norm` and `weight_norms`
    /// default to zero/empty and can be filled in afterwards.
    pub fn new(epoch: usize, loss: f32, accuracy: f32, lr: f32) -> Self {
        Self { epoch, loss, accuracy, lr, grad_norm: 0.0, weight_norms: Vec::new() }
    }
}

/// Log of training progress over multiple epochs.
#[derive(Debug, Clone, Default)]
pub struct TrainingLog {
    pub epochs: Vec<EpochStats>,
}

impl TrainingLog {
    /// Create an empty log.
    pub fn new() -> Self {
        Self::default()
    }

    /// Append the stats for one epoch.
    pub fn push(&mut self, stats: EpochStats) {
        self.epochs.push(stats);
    }

    /// Number of recorded epochs.
    pub fn len(&self) -> usize {
        self.epochs.len()
    }

    /// True when no epochs have been recorded yet.
    pub fn is_empty(&self) -> bool {
        self.epochs.is_empty()
    }

    /// Lowest loss recorded so far, or `None` for an empty log.
    pub fn best_loss(&self) -> Option<f32> {
        self.epochs.iter().map(|e| e.loss).reduce(f32::min)
    }

    /// Highest accuracy recorded so far, or `None` for an empty log.
    pub fn best_accuracy(&self) -> Option<f32> {
        self.epochs.iter().map(|e| e.accuracy).reduce(f32::max)
    }

    /// Return the loss values as a vector for plotting.
    pub fn loss_curve(&self) -> Vec<f32> {
        self.epochs.iter().map(|e| e.loss).collect()
    }

    /// Return the accuracy values as a vector for plotting.
    pub fn accuracy_curve(&self) -> Vec<f32> {
        self.epochs.iter().map(|e| e.accuracy).collect()
    }

    /// Return learning rate schedule.
    pub fn lr_schedule(&self) -> Vec<f32> {
        self.epochs.iter().map(|e| e.lr).collect()
    }

    /// Compute smoothed loss using an exponential moving average.
    /// `alpha` is the weight kept from the previous EMA value (higher =
    /// smoother); the first epoch seeds the average with its raw loss.
    pub fn smoothed_loss(&self, alpha: f32) -> Vec<f32> {
        let mut out = Vec::with_capacity(self.epochs.len());
        let mut ema = 0.0f32;
        for (i, e) in self.epochs.iter().enumerate() {
            ema = if i == 0 { e.loss } else { alpha * ema + (1.0 - alpha) * e.loss };
            out.push(ema);
        }
        out
    }
}

// ── Loss Landscape ──────────────────────────────────────────────────────

/// A 2D grid of loss values sampled along two parameter directions.
#[derive(Debug, Clone)]
pub struct LossLandscape {
    /// Loss values in row-major order (height x width).
    pub values: Vec<f32>,
    pub width: usize,
    pub height: usize,
    /// Range of the x-axis parameter.
    pub x_range: (f32, f32),
    /// Range of the y-axis parameter.
    pub y_range: (f32, f32),
}

impl LossLandscape {
    /// Create a zero-filled landscape with the given dimensions and axis ranges.
    pub fn new(width: usize, height: usize, x_range: (f32, f32), y_range: (f32, f32)) -> Self {
        let values = vec![0.0; width * height];
        Self { values, width, height, x_range, y_range }
    }

    /// Write a loss value at (x, y); out-of-bounds writes are silently ignored.
    pub fn set(&mut self, x: usize, y: usize, val: f32) {
        if x >= self.width || y >= self.height {
            return;
        }
        self.values[y * self.width + x] = val;
    }

    /// Read the loss at (x, y); out-of-bounds reads yield 0.0.
    pub fn get(&self, x: usize, y: usize) -> f32 {
        if x >= self.width || y >= self.height {
            return 0.0;
        }
        self.values[y * self.width + x]
    }

    /// Smallest sampled loss (`f32::INFINITY` for an empty grid).
    pub fn min_loss(&self) -> f32 {
        self.values.iter().copied().fold(f32::INFINITY, f32::min)
    }

    /// Largest sampled loss (`f32::NEG_INFINITY` for an empty grid).
    pub fn max_loss(&self) -> f32 {
        self.values.iter().copied().fold(f32::NEG_INFINITY, f32::max)
    }

    /// Generate a synthetic loss landscape (for testing/demo).
    /// Samples a log-compressed Rosenbrock-like function over [-2, 2]².
    pub fn generate_synthetic(width: usize, height: usize) -> Self {
        let mut out = Self::new(width, height, (-2.0, 2.0), (-2.0, 2.0));
        for y in 0..height {
            let py = -2.0 + 4.0 * y as f32 / height.saturating_sub(1).max(1) as f32;
            for x in 0..width {
                let px = -2.0 + 4.0 * x as f32 / width.saturating_sub(1).max(1) as f32;
                // Rosenbrock-like: (1 - x)² + 100·(y - x²)², then log-compressed
                // (clamped at -5 so ln(0) at the minimum stays finite).
                let val = (1.0 - px).powi(2) + 100.0 * (py - px * px).powi(2);
                out.set(x, y, val.ln().max(-5.0));
            }
        }
        out
    }
}

153/// Render a loss landscape as colored points: (position, loss_value, color).
154pub fn render_loss_landscape(landscape: &LossLandscape) -> Vec<(Vec2, f32, Vec4)> {
155    let min_loss = landscape.min_loss();
156    let max_loss = landscape.max_loss();
157    let range = (max_loss - min_loss).max(1e-6);
158
159    let mut points = Vec::with_capacity(landscape.width * landscape.height);
160    for y in 0..landscape.height {
161        for x in 0..landscape.width {
162            let loss = landscape.get(x, y);
163            let t = (loss - min_loss) / range; // 0..1
164
165            let px = landscape.x_range.0
166                + (landscape.x_range.1 - landscape.x_range.0) * x as f32 / (landscape.width - 1).max(1) as f32;
167            let py = landscape.y_range.0
168                + (landscape.y_range.1 - landscape.y_range.0) * y as f32 / (landscape.height - 1).max(1) as f32;
169
170            // Color: blue (low loss) -> red (high loss)
171            let r = t;
172            let g = (1.0 - (2.0 * t - 1.0).abs()).max(0.0);
173            let b = 1.0 - t;
174            let color = Vec4::new(r, g, b, 1.0);
175
176            points.push((Vec2::new(px, py), loss, color));
177        }
178    }
179    points
180}
181
182// ── Gradient Flow Visualization ─────────────────────────────────────────
183
184/// Gradient magnitude per layer, rendered as a horizontal bar chart.
185pub struct GradientFlowViz;
186
187impl GradientFlowViz {
188    /// Render gradient norms as bar chart data: (layer_index, bar_height, color).
189    pub fn render(layer_names: &[String], grad_norms: &[f32]) -> Vec<(usize, f32, Vec4)> {
190        let max_norm = grad_norms.iter().cloned().fold(0.0f32, f32::max).max(1e-6);
191        layer_names.iter().enumerate().zip(grad_norms).map(|((i, _name), &norm)| {
192            let t = norm / max_norm;
193            // Green (healthy) -> Yellow (warning) -> Red (vanishing/exploding)
194            let color = if t < 0.01 {
195                // Vanishing gradient: red
196                Vec4::new(1.0, 0.0, 0.0, 1.0)
197            } else if t > 0.8 {
198                // Potentially exploding: orange
199                Vec4::new(1.0, 0.5, 0.0, 1.0)
200            } else {
201                // Healthy: green
202                Vec4::new(0.2, 0.8, 0.2, 1.0)
203            };
204            (i, t, color)
205        }).collect()
206    }
207
208    /// Check for vanishing gradients (any layer with norm < threshold).
209    pub fn detect_vanishing(grad_norms: &[f32], threshold: f32) -> Vec<usize> {
210        grad_norms.iter().enumerate()
211            .filter(|(_, &n)| n < threshold)
212            .map(|(i, _)| i)
213            .collect()
214    }
215
216    /// Check for exploding gradients (any layer with norm > threshold).
217    pub fn detect_exploding(grad_norms: &[f32], threshold: f32) -> Vec<usize> {
218        grad_norms.iter().enumerate()
219            .filter(|(_, &n)| n > threshold)
220            .map(|(i, _)| i)
221            .collect()
222    }
223}
224
// ── Weight Distribution Visualization ───────────────────────────────────

/// Render weight histograms per layer.
pub struct WeightDistViz;

impl WeightDistViz {
    /// Compute a histogram of weight values.
    /// Returns (bin_centers, counts) for the given number of bins; both
    /// vectors are empty when there are no weights or no bins.
    pub fn histogram(weights: &[f32], num_bins: usize) -> (Vec<f32>, Vec<u32>) {
        if weights.is_empty() || num_bins == 0 {
            return (Vec::new(), Vec::new());
        }
        let lo = weights.iter().copied().fold(f32::INFINITY, f32::min);
        let hi = weights.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        let bin_width = (hi - lo).max(1e-8) / num_bins as f32;

        let mut counts = vec![0u32; num_bins];
        for &w in weights {
            // Clamp so w == hi lands in the last bin instead of one past it.
            let bin = (((w - lo) / bin_width) as usize).min(num_bins - 1);
            counts[bin] += 1;
        }

        let centers = (0..num_bins)
            .map(|i| lo + (i as f32 + 0.5) * bin_width)
            .collect();

        (centers, counts)
    }

    /// Compute summary statistics of a weight array.
    /// An empty slice yields all-zero stats with sparsity 1.0.
    pub fn stats(weights: &[f32]) -> WeightStats {
        if weights.is_empty() {
            return WeightStats { mean: 0.0, std: 0.0, min: 0.0, max: 0.0, sparsity: 1.0 };
        }
        let n = weights.len() as f32;
        let mean = weights.iter().sum::<f32>() / n;
        // Population variance (divides by n, not n - 1).
        let var = weights.iter().map(|w| (w - mean) * (w - mean)).sum::<f32>() / n;
        let min = weights.iter().copied().fold(f32::INFINITY, f32::min);
        let max = weights.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        let near_zero = weights.iter().filter(|w| w.abs() < 1e-8).count() as f32;
        WeightStats { mean, std: var.sqrt(), min, max, sparsity: near_zero / n }
    }
}

/// Summary statistics for a layer's weights.
#[derive(Debug, Clone)]
pub struct WeightStats {
    pub mean: f32,
    pub std: f32,
    pub min: f32,
    pub max: f32,
    /// Fraction of weights with |w| < 1e-8.
    pub sparsity: f32,
}

280// ── Activation Map Visualization ────────────────────────────────────────
281
282/// Render intermediate activations as colored grids.
283pub struct ActivationMapViz;
284
285impl ActivationMapViz {
286    /// Render a 2-D activation map (H, W) as colored cells.
287    /// Returns Vec of (position, value, color).
288    pub fn render_2d(activations: &[f32], height: usize, width: usize) -> Vec<(Vec2, f32, Vec4)> {
289        assert_eq!(activations.len(), height * width);
290        let min_a = activations.iter().cloned().fold(f32::INFINITY, f32::min);
291        let max_a = activations.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
292        let range = (max_a - min_a).max(1e-8);
293
294        let mut points = Vec::with_capacity(height * width);
295        for y in 0..height {
296            for x in 0..width {
297                let val = activations[y * width + x];
298                let t = (val - min_a) / range;
299                // Viridis-like colormap
300                let r = (0.267 + 0.004 * t + 1.3 * t * t - 0.6 * t * t * t).clamp(0.0, 1.0);
301                let g = (0.004 + 1.0 * t - 0.15 * t * t).clamp(0.0, 1.0);
302                let b = (0.329 + 1.4 * t - 1.75 * t * t + 0.5 * t * t * t).clamp(0.0, 1.0);
303                points.push((
304                    Vec2::new(x as f32, y as f32),
305                    val,
306                    Vec4::new(r, g, b, 1.0),
307                ));
308            }
309        }
310        points
311    }
312
313    /// Render multiple channels stacked vertically.
314    pub fn render_multichannel(
315        activations: &[f32],
316        channels: usize,
317        height: usize,
318        width: usize,
319    ) -> Vec<(Vec2, f32, Vec4)> {
320        assert_eq!(activations.len(), channels * height * width);
321        let mut all_points = Vec::new();
322        for c in 0..channels {
323            let offset = c * height * width;
324            let channel_data = &activations[offset..offset + height * width];
325            let mut pts = Self::render_2d(channel_data, height, width);
326            // Offset y position by channel
327            let y_off = c as f32 * (height as f32 + 1.0);
328            for p in &mut pts {
329                p.0.y += y_off;
330            }
331            all_points.extend(pts);
332        }
333        all_points
334    }
335}
336
337// ── Training Dashboard ──────────────────────────────────────────────────
338
339/// Composite training dashboard combining multiple visualizations.
340pub struct TrainingDashboard {
341    pub log: TrainingLog,
342    pub layer_names: Vec<String>,
343}
344
345impl TrainingDashboard {
346    pub fn new(log: TrainingLog, layer_names: Vec<String>) -> Self {
347        Self { log, layer_names }
348    }
349
350    /// Render the loss curve as a series of 2D points.
351    pub fn render_loss_curve(&self) -> Vec<Vec2> {
352        let n = self.log.len();
353        if n == 0 { return vec![]; }
354        let max_loss = self.log.epochs.iter().map(|e| e.loss).fold(0.0f32, f32::max).max(1e-6);
355        self.log.epochs.iter().enumerate().map(|(i, e)| {
356            Vec2::new(i as f32 / n as f32, e.loss / max_loss)
357        }).collect()
358    }
359
360    /// Render the accuracy curve as a series of 2D points.
361    pub fn render_accuracy_curve(&self) -> Vec<Vec2> {
362        let n = self.log.len();
363        if n == 0 { return vec![]; }
364        self.log.epochs.iter().enumerate().map(|(i, e)| {
365            Vec2::new(i as f32 / n as f32, e.accuracy)
366        }).collect()
367    }
368
369    /// Render the learning rate schedule.
370    pub fn render_lr_curve(&self) -> Vec<Vec2> {
371        let n = self.log.len();
372        if n == 0 { return vec![]; }
373        let max_lr = self.log.epochs.iter().map(|e| e.lr).fold(0.0f32, f32::max).max(1e-8);
374        self.log.epochs.iter().enumerate().map(|(i, e)| {
375            Vec2::new(i as f32 / n as f32, e.lr / max_lr)
376        }).collect()
377    }
378
379    /// Render gradient flow for the latest epoch.
380    pub fn render_gradient_flow(&self) -> Vec<(usize, f32, Vec4)> {
381        if let Some(last) = self.log.epochs.last() {
382            if last.weight_norms.len() == self.layer_names.len() {
383                // Use weight norms as proxy for gradient norms if grad_norm not per-layer
384                return GradientFlowViz::render(&self.layer_names, &last.weight_norms);
385            }
386        }
387        vec![]
388    }
389
390    /// Summary string for the current training state.
391    pub fn summary(&self) -> String {
392        let n = self.log.len();
393        if n == 0 { return "No training data".to_string(); }
394        let last = &self.log.epochs[n - 1];
395        let best_loss = self.log.best_loss().unwrap_or(0.0);
396        let best_acc = self.log.best_accuracy().unwrap_or(0.0);
397        format!(
398            "Epoch {}/{}: loss={:.4} acc={:.2}% lr={:.6} | best_loss={:.4} best_acc={:.2}%",
399            last.epoch, n, last.loss, last.accuracy * 100.0, last.lr,
400            best_loss, best_acc * 100.0
401        )
402    }
403}
404
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_training_log() {
        let mut log = TrainingLog::new();
        assert!(log.is_empty());
        for (epoch, loss, acc, lr) in [(0, 2.5, 0.1, 0.01), (1, 1.5, 0.5, 0.01), (2, 0.8, 0.8, 0.005)] {
            log.push(EpochStats::new(epoch, loss, acc, lr));
        }
        assert_eq!(log.len(), 3);
        assert!((log.best_loss().unwrap() - 0.8).abs() < 1e-5);
        assert!((log.best_accuracy().unwrap() - 0.8).abs() < 1e-5);
    }

    #[test]
    fn test_loss_curve() {
        let mut log = TrainingLog::new();
        log.push(EpochStats::new(0, 2.0, 0.1, 0.01));
        log.push(EpochStats::new(1, 1.0, 0.5, 0.01));
        assert_eq!(log.loss_curve(), vec![2.0, 1.0]);
    }

    #[test]
    fn test_smoothed_loss() {
        let mut log = TrainingLog::new();
        for epoch in 0..10 {
            log.push(EpochStats::new(epoch, 10.0 - epoch as f32, 0.0, 0.01));
        }
        let smoothed = log.smoothed_loss(0.9);
        assert_eq!(smoothed.len(), 10);
        // The EMA lags the strictly decreasing raw loss.
        assert!(smoothed[9] > log.epochs[9].loss);
    }

    #[test]
    fn test_loss_landscape() {
        let landscape = LossLandscape::generate_synthetic(10, 10);
        assert_eq!(landscape.values.len(), 100);
        assert!(landscape.min_loss() < landscape.max_loss());
    }

    #[test]
    fn test_render_loss_landscape() {
        let points = render_loss_landscape(&LossLandscape::generate_synthetic(5, 5));
        assert_eq!(points.len(), 25);
        for (pos, _loss, color) in &points {
            assert!((-2.0..=2.0).contains(&pos.x));
            assert!((-2.0..=2.0).contains(&pos.y));
            assert_eq!(color.w, 1.0); // full alpha
        }
    }

    #[test]
    fn test_gradient_flow_viz() {
        let names: Vec<String> = vec!["dense_0".into(), "dense_1".into(), "dense_2".into()];
        let bars = GradientFlowViz::render(&names, &[0.5, 0.001, 0.3]);
        assert_eq!(bars.len(), 3);
    }

    #[test]
    fn test_detect_vanishing() {
        assert_eq!(
            GradientFlowViz::detect_vanishing(&[0.5, 0.001, 0.0001, 0.3], 0.01),
            vec![1, 2]
        );
    }

    #[test]
    fn test_detect_exploding() {
        assert_eq!(
            GradientFlowViz::detect_exploding(&[0.5, 100.0, 0.3, 200.0], 50.0),
            vec![1, 3]
        );
    }

    #[test]
    fn test_weight_histogram() {
        let weights: Vec<f32> = (1..=10).map(|i| i as f32 / 10.0).collect();
        let (centers, counts) = WeightDistViz::histogram(&weights, 5);
        assert_eq!(centers.len(), 5);
        assert_eq!(counts.len(), 5);
        assert_eq!(counts.iter().sum::<u32>(), 10);
    }

    #[test]
    fn test_weight_stats() {
        let stats = WeightDistViz::stats(&[1.0, 2.0, 3.0, 4.0, 5.0]);
        assert!((stats.mean - 3.0).abs() < 1e-5);
        assert!(stats.std > 0.0);
        assert_eq!(stats.min, 1.0);
        assert_eq!(stats.max, 5.0);
        assert_eq!(stats.sparsity, 0.0);
    }

    #[test]
    fn test_activation_map_2d() {
        let acts = [0.0, 0.5, 1.0, 0.25, 0.75, 0.5, 0.1, 0.9, 0.4];
        assert_eq!(ActivationMapViz::render_2d(&acts, 3, 3).len(), 9);
    }

    #[test]
    fn test_activation_map_multichannel() {
        // 2 channels of 3x3 each.
        let acts = vec![0.0; 18];
        assert_eq!(ActivationMapViz::render_multichannel(&acts, 2, 3, 3).len(), 18);
    }

    #[test]
    fn test_training_dashboard() {
        let mut log = TrainingLog::new();
        log.push(EpochStats { epoch: 0, loss: 2.0, accuracy: 0.1, lr: 0.01, grad_norm: 0.5, weight_norms: vec![0.5, 0.3] });
        log.push(EpochStats { epoch: 1, loss: 1.0, accuracy: 0.5, lr: 0.01, grad_norm: 0.4, weight_norms: vec![0.4, 0.3] });

        let dash = TrainingDashboard::new(log, vec!["dense_0".into(), "dense_1".into()]);
        assert_eq!(dash.render_loss_curve().len(), 2);
        assert_eq!(dash.render_accuracy_curve().len(), 2);
        let summary = dash.summary();
        assert!(summary.contains("Epoch"));
        assert!(summary.contains("loss="));
    }

    #[test]
    fn test_dashboard_gradient_flow() {
        let mut log = TrainingLog::new();
        log.push(EpochStats {
            epoch: 0,
            loss: 1.0,
            accuracy: 0.5,
            lr: 0.01,
            grad_norm: 0.5,
            weight_norms: vec![0.5, 0.3, 0.1],
        });
        let dash = TrainingDashboard::new(log, vec!["a".into(), "b".into(), "c".into()]);
        assert_eq!(dash.render_gradient_flow().len(), 3);
    }
}