// Example: training_curves.rs
// SPDX-License-Identifier: MIT OR Apache-2.0
//! Simulate training/validation loss curves and plot them.
//!
//! This is a common visualization in ML: watching loss and accuracy
//! converge over training epochs.

use esoc_chart::prelude::{Figure, Result};
use esoc_color::Color;

10fn main() -> Result<()> {
11    // ── Simulate training curves ─────────────────────────────────────
12    let epochs: Vec<f64> = (1..=50).map(f64::from).collect();
13
14    // Training loss: exponential decay + noise
15    let mut rng = SimpleRng::new(42);
16    let train_loss: Vec<f64> = epochs
17        .iter()
18        .map(|&e| 2.0 * (-e / 15.0).exp() + 0.05 + rng.normal() * 0.02)
19        .collect();
20
21    // Validation loss: decays slower, starts overfitting around epoch 30
22    let val_loss: Vec<f64> = epochs
23        .iter()
24        .map(|&e| {
25            let base = 2.0 * (-e / 20.0).exp() + 0.1;
26            let overfit = if e > 30.0 { (e - 30.0) * 0.005 } else { 0.0 };
27            base + overfit + rng.normal() * 0.03
28        })
29        .collect();
30
31    // Training accuracy: rises from ~50% to ~98%
32    let train_acc: Vec<f64> = epochs
33        .iter()
34        .map(|&e| (0.98 - 0.48 * (-e / 12.0).exp()).min(1.0) + rng.normal() * 0.01)
35        .collect();
36
37    // Validation accuracy: rises but plateaus earlier
38    let val_acc: Vec<f64> = epochs
39        .iter()
40        .map(|&e| {
41            let base = 0.93 - 0.43 * (-e / 18.0).exp();
42            let overfit = if e > 30.0 { -(e - 30.0) * 0.002 } else { 0.0 };
43            (base + overfit + rng.normal() * 0.015).min(1.0)
44        })
45        .collect();
46
47    // ── Plot 1: Loss curves ──────────────────────────────────────────
48    let mut fig = Figure::new()
49        .size(750.0, 500.0)
50        .title("Training & Validation Loss");
51
52    let ax = fig.add_axes();
53    ax.x_label("Epoch").y_label("Loss");
54    ax.line(&epochs, &train_loss)
55        .label("Train Loss")
56        .color(Color::from_hex("#1f77b4").unwrap().into())
57        .width(2.0)
58        .done();
59    ax.line(&epochs, &val_loss)
60        .label("Val Loss")
61        .color(Color::from_hex("#ff7f0e").unwrap().into())
62        .width(2.0)
63        .dash(&[8.0, 4.0])
64        .done();
65
66    fig.save_svg("training_loss.svg")?;
67    println!("Saved training_loss.svg");
68
69    // ── Plot 2: Accuracy curves ──────────────────────────────────────
70    let mut fig2 = Figure::new()
71        .size(750.0, 500.0)
72        .title("Training & Validation Accuracy");
73
74    let ax2 = fig2.add_axes();
75    ax2.x_label("Epoch").y_label("Accuracy").y_range(0.4, 1.05);
76    ax2.line(&epochs, &train_acc)
77        .label("Train Accuracy")
78        .color(Color::from_hex("#2ca02c").unwrap().into())
79        .width(2.0)
80        .done();
81    ax2.line(&epochs, &val_acc)
82        .label("Val Accuracy")
83        .color(Color::from_hex("#d62728").unwrap().into())
84        .width(2.0)
85        .dash(&[8.0, 4.0])
86        .done();
87
88    fig2.save_svg("training_accuracy.svg")?;
89    println!("Saved training_accuracy.svg");
90
91    Ok(())
92}
93
/// Minimal deterministic PRNG (64-bit LCG) so this example needs no
/// external crates. Suitable for demo noise only — not statistics or
/// cryptography.
struct SimpleRng(u64);

impl SimpleRng {
    /// Build a generator from a fixed seed, giving reproducible curves.
    fn new(seed: u64) -> Self {
        SimpleRng(seed)
    }

    /// Advance the LCG state (Knuth's MMIX multiplier, increment 1)
    /// and return the new raw state.
    fn next_u64(&mut self) -> u64 {
        let advanced = self
            .0
            .wrapping_mul(6_364_136_223_846_793_005)
            .wrapping_add(1);
        self.0 = advanced;
        advanced
    }

    /// Uniform sample in [0, 1), built from the top 53 bits of the
    /// state (the low LCG bits are the weakest, so they are discarded).
    fn uniform(&mut self) -> f64 {
        let mantissa = self.next_u64() >> 11;
        mantissa as f64 / (1u64 << 53) as f64
    }

    /// Standard-normal sample via the Box–Muller transform; `u1` is
    /// floored at 1e-15 so `ln` never sees zero.
    fn normal(&mut self) -> f64 {
        let u1 = self.uniform().max(1e-15);
        let u2 = self.uniform();
        let radius = (-2.0 * u1.ln()).sqrt();
        let angle = 2.0 * std::f64::consts::PI * u2;
        radius * angle.cos()
    }
}
114}