1use glam::{Vec2, Vec4};
4
/// Metrics recorded for a single training epoch.
#[derive(Debug, Clone)]
pub struct EpochStats {
    /// Index of this epoch within the run.
    pub epoch: usize,
    /// Loss value for this epoch.
    pub loss: f32,
    /// Accuracy as a fraction in [0, 1] (rendered as a percentage by
    /// `TrainingDashboard::summary`).
    pub accuracy: f32,
    /// Learning rate in effect for this epoch.
    pub lr: f32,
    /// Scalar gradient-norm summary; `EpochStats::new` initializes it to 0.0.
    pub grad_norm: f32,
    /// Per-layer weight norms; `EpochStats::new` leaves this empty.
    pub weight_norms: Vec<f32>,
}
15
16impl EpochStats {
17 pub fn new(epoch: usize, loss: f32, accuracy: f32, lr: f32) -> Self {
18 Self { epoch, loss, accuracy, lr, grad_norm: 0.0, weight_norms: Vec::new() }
19 }
20}
21
/// Append-only record of per-epoch statistics for one training run.
#[derive(Debug, Clone)]
pub struct TrainingLog {
    /// Epoch entries in insertion (chronological) order.
    pub epochs: Vec<EpochStats>,
}
27
28impl TrainingLog {
29 pub fn new() -> Self {
30 Self { epochs: Vec::new() }
31 }
32
33 pub fn push(&mut self, stats: EpochStats) {
34 self.epochs.push(stats);
35 }
36
37 pub fn len(&self) -> usize {
38 self.epochs.len()
39 }
40
41 pub fn is_empty(&self) -> bool {
42 self.epochs.is_empty()
43 }
44
45 pub fn best_loss(&self) -> Option<f32> {
46 self.epochs.iter().map(|e| e.loss).fold(None, |acc, v| {
47 Some(match acc { None => v, Some(a) => a.min(v) })
48 })
49 }
50
51 pub fn best_accuracy(&self) -> Option<f32> {
52 self.epochs.iter().map(|e| e.accuracy).fold(None, |acc, v| {
53 Some(match acc { None => v, Some(a) => a.max(v) })
54 })
55 }
56
57 pub fn loss_curve(&self) -> Vec<f32> {
59 self.epochs.iter().map(|e| e.loss).collect()
60 }
61
62 pub fn accuracy_curve(&self) -> Vec<f32> {
64 self.epochs.iter().map(|e| e.accuracy).collect()
65 }
66
67 pub fn lr_schedule(&self) -> Vec<f32> {
69 self.epochs.iter().map(|e| e.lr).collect()
70 }
71
72 pub fn smoothed_loss(&self, alpha: f32) -> Vec<f32> {
74 let mut smoothed = Vec::with_capacity(self.epochs.len());
75 let mut ema = 0.0f32;
76 for (i, e) in self.epochs.iter().enumerate() {
77 if i == 0 {
78 ema = e.loss;
79 } else {
80 ema = alpha * ema + (1.0 - alpha) * e.loss;
81 }
82 smoothed.push(ema);
83 }
84 smoothed
85 }
86}
87
/// Dense 2-D grid of loss values sampled over a plane in parameter space.
#[derive(Debug, Clone)]
pub struct LossLandscape {
    /// Row-major grid values, indexed as `values[y * width + x]`.
    pub values: Vec<f32>,
    /// Number of columns (samples along x).
    pub width: usize,
    /// Number of rows (samples along y).
    pub height: usize,
    /// (min, max) extent covered along x, in parameter-space units.
    pub x_range: (f32, f32),
    /// (min, max) extent covered along y, in parameter-space units.
    pub y_range: (f32, f32),
}
102
103impl LossLandscape {
104 pub fn new(width: usize, height: usize, x_range: (f32, f32), y_range: (f32, f32)) -> Self {
105 Self {
106 values: vec![0.0; width * height],
107 width,
108 height,
109 x_range,
110 y_range,
111 }
112 }
113
114 pub fn set(&mut self, x: usize, y: usize, val: f32) {
115 if x < self.width && y < self.height {
116 self.values[y * self.width + x] = val;
117 }
118 }
119
120 pub fn get(&self, x: usize, y: usize) -> f32 {
121 if x < self.width && y < self.height {
122 self.values[y * self.width + x]
123 } else {
124 0.0
125 }
126 }
127
128 pub fn min_loss(&self) -> f32 {
129 self.values.iter().cloned().fold(f32::INFINITY, f32::min)
130 }
131
132 pub fn max_loss(&self) -> f32 {
133 self.values.iter().cloned().fold(f32::NEG_INFINITY, f32::max)
134 }
135
136 pub fn generate_synthetic(width: usize, height: usize) -> Self {
139 let mut landscape = Self::new(width, height, (-2.0, 2.0), (-2.0, 2.0));
140 for y in 0..height {
141 for x in 0..width {
142 let px = -2.0 + 4.0 * x as f32 / (width - 1).max(1) as f32;
143 let py = -2.0 + 4.0 * y as f32 / (height - 1).max(1) as f32;
144 let val = (1.0 - px).powi(2) + 100.0 * (py - px * px).powi(2);
146 landscape.set(x, y, val.ln().max(-5.0));
147 }
148 }
149 landscape
150 }
151}
152
153pub fn render_loss_landscape(landscape: &LossLandscape) -> Vec<(Vec2, f32, Vec4)> {
155 let min_loss = landscape.min_loss();
156 let max_loss = landscape.max_loss();
157 let range = (max_loss - min_loss).max(1e-6);
158
159 let mut points = Vec::with_capacity(landscape.width * landscape.height);
160 for y in 0..landscape.height {
161 for x in 0..landscape.width {
162 let loss = landscape.get(x, y);
163 let t = (loss - min_loss) / range; let px = landscape.x_range.0
166 + (landscape.x_range.1 - landscape.x_range.0) * x as f32 / (landscape.width - 1).max(1) as f32;
167 let py = landscape.y_range.0
168 + (landscape.y_range.1 - landscape.y_range.0) * y as f32 / (landscape.height - 1).max(1) as f32;
169
170 let r = t;
172 let g = (1.0 - (2.0 * t - 1.0).abs()).max(0.0);
173 let b = 1.0 - t;
174 let color = Vec4::new(r, g, b, 1.0);
175
176 points.push((Vec2::new(px, py), loss, color));
177 }
178 }
179 points
180}
181
/// Namespace for gradient-flow bar rendering and pathology detection.
pub struct GradientFlowViz;
186
187impl GradientFlowViz {
188 pub fn render(layer_names: &[String], grad_norms: &[f32]) -> Vec<(usize, f32, Vec4)> {
190 let max_norm = grad_norms.iter().cloned().fold(0.0f32, f32::max).max(1e-6);
191 layer_names.iter().enumerate().zip(grad_norms).map(|((i, _name), &norm)| {
192 let t = norm / max_norm;
193 let color = if t < 0.01 {
195 Vec4::new(1.0, 0.0, 0.0, 1.0)
197 } else if t > 0.8 {
198 Vec4::new(1.0, 0.5, 0.0, 1.0)
200 } else {
201 Vec4::new(0.2, 0.8, 0.2, 1.0)
203 };
204 (i, t, color)
205 }).collect()
206 }
207
208 pub fn detect_vanishing(grad_norms: &[f32], threshold: f32) -> Vec<usize> {
210 grad_norms.iter().enumerate()
211 .filter(|(_, &n)| n < threshold)
212 .map(|(i, _)| i)
213 .collect()
214 }
215
216 pub fn detect_exploding(grad_norms: &[f32], threshold: f32) -> Vec<usize> {
218 grad_norms.iter().enumerate()
219 .filter(|(_, &n)| n > threshold)
220 .map(|(i, _)| i)
221 .collect()
222 }
223}
224
/// Namespace for weight-distribution histograms and summary statistics.
pub struct WeightDistViz;
229
230impl WeightDistViz {
231 pub fn histogram(weights: &[f32], num_bins: usize) -> (Vec<f32>, Vec<u32>) {
234 if weights.is_empty() || num_bins == 0 {
235 return (vec![], vec![]);
236 }
237 let min_w = weights.iter().cloned().fold(f32::INFINITY, f32::min);
238 let max_w = weights.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
239 let range = (max_w - min_w).max(1e-8);
240 let bin_width = range / num_bins as f32;
241
242 let mut counts = vec![0u32; num_bins];
243 for &w in weights {
244 let bin = ((w - min_w) / bin_width) as usize;
245 let bin = bin.min(num_bins - 1);
246 counts[bin] += 1;
247 }
248
249 let centers: Vec<f32> = (0..num_bins)
250 .map(|i| min_w + (i as f32 + 0.5) * bin_width)
251 .collect();
252
253 (centers, counts)
254 }
255
256 pub fn stats(weights: &[f32]) -> WeightStats {
258 if weights.is_empty() {
259 return WeightStats { mean: 0.0, std: 0.0, min: 0.0, max: 0.0, sparsity: 1.0 };
260 }
261 let n = weights.len() as f32;
262 let mean: f32 = weights.iter().sum::<f32>() / n;
263 let var: f32 = weights.iter().map(|w| (w - mean) * (w - mean)).sum::<f32>() / n;
264 let min = weights.iter().cloned().fold(f32::INFINITY, f32::min);
265 let max = weights.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
266 let zeros = weights.iter().filter(|&&w| w.abs() < 1e-8).count();
267 WeightStats { mean, std: var.sqrt(), min, max, sparsity: zeros as f32 / n }
268 }
269}
270
/// Summary statistics for one set of weights (see `WeightDistViz::stats`).
#[derive(Debug, Clone)]
pub struct WeightStats {
    /// Arithmetic mean of the weights.
    pub mean: f32,
    /// Population standard deviation (variance divides by n, not n - 1).
    pub std: f32,
    /// Smallest weight.
    pub min: f32,
    /// Largest weight.
    pub max: f32,
    /// Fraction of weights with |w| < 1e-8, in [0, 1].
    pub sparsity: f32,
}
279
/// Namespace for rendering activation maps as colored point grids.
pub struct ActivationMapViz;
284
285impl ActivationMapViz {
286 pub fn render_2d(activations: &[f32], height: usize, width: usize) -> Vec<(Vec2, f32, Vec4)> {
289 assert_eq!(activations.len(), height * width);
290 let min_a = activations.iter().cloned().fold(f32::INFINITY, f32::min);
291 let max_a = activations.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
292 let range = (max_a - min_a).max(1e-8);
293
294 let mut points = Vec::with_capacity(height * width);
295 for y in 0..height {
296 for x in 0..width {
297 let val = activations[y * width + x];
298 let t = (val - min_a) / range;
299 let r = (0.267 + 0.004 * t + 1.3 * t * t - 0.6 * t * t * t).clamp(0.0, 1.0);
301 let g = (0.004 + 1.0 * t - 0.15 * t * t).clamp(0.0, 1.0);
302 let b = (0.329 + 1.4 * t - 1.75 * t * t + 0.5 * t * t * t).clamp(0.0, 1.0);
303 points.push((
304 Vec2::new(x as f32, y as f32),
305 val,
306 Vec4::new(r, g, b, 1.0),
307 ));
308 }
309 }
310 points
311 }
312
313 pub fn render_multichannel(
315 activations: &[f32],
316 channels: usize,
317 height: usize,
318 width: usize,
319 ) -> Vec<(Vec2, f32, Vec4)> {
320 assert_eq!(activations.len(), channels * height * width);
321 let mut all_points = Vec::new();
322 for c in 0..channels {
323 let offset = c * height * width;
324 let channel_data = &activations[offset..offset + height * width];
325 let mut pts = Self::render_2d(channel_data, height, width);
326 let y_off = c as f32 * (height as f32 + 1.0);
328 for p in &mut pts {
329 p.0.y += y_off;
330 }
331 all_points.extend(pts);
332 }
333 all_points
334 }
335}
336
/// Bundles a training log with layer names to render dashboard views.
pub struct TrainingDashboard {
    /// Per-epoch history to visualize.
    pub log: TrainingLog,
    /// Layer names, matched by index against per-layer norm vectors.
    pub layer_names: Vec<String>,
}
344
345impl TrainingDashboard {
346 pub fn new(log: TrainingLog, layer_names: Vec<String>) -> Self {
347 Self { log, layer_names }
348 }
349
350 pub fn render_loss_curve(&self) -> Vec<Vec2> {
352 let n = self.log.len();
353 if n == 0 { return vec![]; }
354 let max_loss = self.log.epochs.iter().map(|e| e.loss).fold(0.0f32, f32::max).max(1e-6);
355 self.log.epochs.iter().enumerate().map(|(i, e)| {
356 Vec2::new(i as f32 / n as f32, e.loss / max_loss)
357 }).collect()
358 }
359
360 pub fn render_accuracy_curve(&self) -> Vec<Vec2> {
362 let n = self.log.len();
363 if n == 0 { return vec![]; }
364 self.log.epochs.iter().enumerate().map(|(i, e)| {
365 Vec2::new(i as f32 / n as f32, e.accuracy)
366 }).collect()
367 }
368
369 pub fn render_lr_curve(&self) -> Vec<Vec2> {
371 let n = self.log.len();
372 if n == 0 { return vec![]; }
373 let max_lr = self.log.epochs.iter().map(|e| e.lr).fold(0.0f32, f32::max).max(1e-8);
374 self.log.epochs.iter().enumerate().map(|(i, e)| {
375 Vec2::new(i as f32 / n as f32, e.lr / max_lr)
376 }).collect()
377 }
378
379 pub fn render_gradient_flow(&self) -> Vec<(usize, f32, Vec4)> {
381 if let Some(last) = self.log.epochs.last() {
382 if last.weight_norms.len() == self.layer_names.len() {
383 return GradientFlowViz::render(&self.layer_names, &last.weight_norms);
385 }
386 }
387 vec![]
388 }
389
390 pub fn summary(&self) -> String {
392 let n = self.log.len();
393 if n == 0 { return "No training data".to_string(); }
394 let last = &self.log.epochs[n - 1];
395 let best_loss = self.log.best_loss().unwrap_or(0.0);
396 let best_acc = self.log.best_accuracy().unwrap_or(0.0);
397 format!(
398 "Epoch {}/{}: loss={:.4} acc={:.2}% lr={:.6} | best_loss={:.4} best_acc={:.2}%",
399 last.epoch, n, last.loss, last.accuracy * 100.0, last.lr,
400 best_loss, best_acc * 100.0
401 )
402 }
403}
404
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_training_log() {
        let mut log = TrainingLog::new();
        assert!(log.is_empty());
        for stats in [
            EpochStats::new(0, 2.5, 0.1, 0.01),
            EpochStats::new(1, 1.5, 0.5, 0.01),
            EpochStats::new(2, 0.8, 0.8, 0.005),
        ] {
            log.push(stats);
        }
        assert_eq!(log.len(), 3);
        // Best loss is the minimum, best accuracy the maximum.
        assert!((log.best_loss().unwrap() - 0.8).abs() < 1e-5);
        assert!((log.best_accuracy().unwrap() - 0.8).abs() < 1e-5);
    }

    #[test]
    fn test_loss_curve() {
        let mut log = TrainingLog::new();
        log.push(EpochStats::new(0, 2.0, 0.1, 0.01));
        log.push(EpochStats::new(1, 1.0, 0.5, 0.01));
        assert_eq!(log.loss_curve(), vec![2.0, 1.0]);
    }

    #[test]
    fn test_smoothed_loss() {
        let mut log = TrainingLog::new();
        for epoch in 0..10 {
            log.push(EpochStats::new(epoch, 10.0 - epoch as f32, 0.0, 0.01));
        }
        let ema = log.smoothed_loss(0.9);
        assert_eq!(ema.len(), 10);
        // Loss decreases, so the EMA lags above the raw final loss.
        assert!(ema[9] > log.epochs[9].loss);
    }

    #[test]
    fn test_loss_landscape() {
        let surface = LossLandscape::generate_synthetic(10, 10);
        assert_eq!(surface.values.len(), 100);
        assert!(surface.min_loss() < surface.max_loss());
    }

    #[test]
    fn test_render_loss_landscape() {
        let surface = LossLandscape::generate_synthetic(5, 5);
        let rendered = render_loss_landscape(&surface);
        assert_eq!(rendered.len(), 25);
        for (pos, _loss, color) in &rendered {
            // Positions stay inside the synthetic [-2, 2] ranges.
            assert!(pos.x >= -2.0 && pos.x <= 2.0);
            assert!(pos.y >= -2.0 && pos.y <= 2.0);
            // Alpha channel is always opaque.
            assert!(color.w == 1.0);
        }
    }

    #[test]
    fn test_gradient_flow_viz() {
        let layers = vec!["dense_0".into(), "dense_1".into(), "dense_2".into()];
        let norms = [0.5, 0.001, 0.3];
        assert_eq!(GradientFlowViz::render(&layers, &norms).len(), 3);
    }

    #[test]
    fn test_detect_vanishing() {
        let norms = [0.5, 0.001, 0.0001, 0.3];
        assert_eq!(GradientFlowViz::detect_vanishing(&norms, 0.01), vec![1, 2]);
    }

    #[test]
    fn test_detect_exploding() {
        let norms = [0.5, 100.0, 0.3, 200.0];
        assert_eq!(GradientFlowViz::detect_exploding(&norms, 50.0), vec![1, 3]);
    }

    #[test]
    fn test_weight_histogram() {
        let weights: Vec<f32> = (1..=10).map(|i| i as f32 / 10.0).collect();
        let (centers, counts) = WeightDistViz::histogram(&weights, 5);
        assert_eq!(centers.len(), 5);
        assert_eq!(counts.len(), 5);
        // Every weight lands in exactly one bin.
        assert_eq!(counts.iter().sum::<u32>(), 10);
    }

    #[test]
    fn test_weight_stats() {
        let weights = [1.0, 2.0, 3.0, 4.0, 5.0];
        let summary = WeightDistViz::stats(&weights);
        assert!((summary.mean - 3.0).abs() < 1e-5);
        assert!(summary.std > 0.0);
        assert_eq!(summary.min, 1.0);
        assert_eq!(summary.max, 5.0);
        assert_eq!(summary.sparsity, 0.0);
    }

    #[test]
    fn test_activation_map_2d() {
        let acts = [0.0, 0.5, 1.0, 0.25, 0.75, 0.5, 0.1, 0.9, 0.4];
        assert_eq!(ActivationMapViz::render_2d(&acts, 3, 3).len(), 9);
    }

    #[test]
    fn test_activation_map_multichannel() {
        let acts = vec![0.0; 2 * 3 * 3];
        let points = ActivationMapViz::render_multichannel(&acts, 2, 3, 3);
        assert_eq!(points.len(), 18);
    }

    #[test]
    fn test_training_dashboard() {
        let mut log = TrainingLog::new();
        log.push(EpochStats {
            epoch: 0,
            loss: 2.0,
            accuracy: 0.1,
            lr: 0.01,
            grad_norm: 0.5,
            weight_norms: vec![0.5, 0.3],
        });
        log.push(EpochStats {
            epoch: 1,
            loss: 1.0,
            accuracy: 0.5,
            lr: 0.01,
            grad_norm: 0.4,
            weight_norms: vec![0.4, 0.3],
        });

        let dash = TrainingDashboard::new(log, vec!["dense_0".into(), "dense_1".into()]);
        assert_eq!(dash.render_loss_curve().len(), 2);
        assert_eq!(dash.render_accuracy_curve().len(), 2);
        let summary = dash.summary();
        assert!(summary.contains("Epoch"));
        assert!(summary.contains("loss="));
    }

    #[test]
    fn test_dashboard_gradient_flow() {
        let mut log = TrainingLog::new();
        log.push(EpochStats {
            epoch: 0,
            loss: 1.0,
            accuracy: 0.5,
            lr: 0.01,
            grad_norm: 0.5,
            weight_norms: vec![0.5, 0.3, 0.1],
        });
        let dash = TrainingDashboard::new(log, vec!["a".into(), "b".into(), "c".into()]);
        assert_eq!(dash.render_gradient_flow().len(), 3);
    }
}
547}