use std::{fs, path::{Path, PathBuf}};
use anyhow::Result;
use plotters::prelude::*;
/// Metrics recorded for a single training epoch.
#[derive(Debug, Clone)]
pub struct EpochMetrics {
    /// Epoch number this record belongs to.
    pub epoch: usize,
    /// Training loss for the epoch.
    pub train_loss: f64,
    /// Validation loss for the epoch.
    pub val_loss: f64,
    /// Validation perplexity; [`EpochMetrics::new`] fills it in as `exp(val_loss)`.
    pub val_perplexity: f64,
    /// Validation accuracy, stored as a fraction (the HTML report multiplies it by 100).
    pub val_accuracy: f64,
    /// Validation macro recall, stored as a fraction (the HTML report multiplies it by 100).
    pub val_macro_recall: f64,
}

impl EpochMetrics {
    /// Builds a record from the raw measurements, deriving the perplexity from `val_loss`.
    pub fn new(
        epoch: usize,
        train_loss: f64,
        val_loss: f64,
        val_accuracy: f64,
        val_macro_recall: f64,
    ) -> Self {
        // Perplexity is not passed in: it is always the exponential of the validation loss.
        let val_perplexity = f64::exp(val_loss);

        Self {
            epoch,
            train_loss,
            val_loss,
            val_perplexity,
            val_accuracy,
            val_macro_recall,
        }
    }
}
/// Computes the x-axis range and the number of x tick labels for a per-stage plot.
///
/// `epochs` holds the epoch numbers (as `f64`) in plotting order. The returned range
/// always starts at `-0.5` and ends half a unit past the *last* entry, so that the first
/// and last markers are not drawn on the plot border. An empty slice is treated as if its
/// last epoch were `1.0`.
///
/// The label count is the number of epochs, kept within `2..=11`.
fn epoch_x_range(epochs: &[f64]) -> (std::ops::Range<f64>, usize) {
    // Half-unit margin on both sides of the axis.
    const PAD: f64 = 0.5;

    let last = epochs.last().copied().unwrap_or(1.0);
    let n_labels = epochs.len().clamp(2, 11);

    ((-PAD)..(last + PAD), n_labels)
}
// Configures and draws the mesh (grid, tick labels, axis titles) shared by the per-stage
// plots.
//
// Arguments, in order: the chart, the number of x labels, the number of y labels, the
// x-axis title and the y-axis title.
//
// The expansion ends in `?`, so it may only be used inside a function whose error type can
// hold an `anyhow::Error`.
macro_rules! mesh {
    ($plot:expr, $x_ticks:expr, $y_ticks:expr, $x_title:expr, $y_title:expr) => {
        $plot
            .configure_mesh()
            .x_labels($x_ticks)
            .y_labels($y_ticks)
            // The x axis is `f64`; print each tick as the nearest whole number.
            .x_label_formatter(&|tick: &f64| (tick.round() as i64).to_string())
            // Light grid lines are painted white; bold ones are a faint black.
            .light_line_style(&WHITE)
            .bold_line_style(BLACK.mix(0.10))
            .x_desc($x_title)
            .y_desc($y_title)
            .draw()
            // Re-wrap the drawing error as an `anyhow` error carrying its message.
            .map_err(|e| anyhow::anyhow!("{e}"))?
    };
}
/// The epoch-by-epoch metric history of one named training stage.
///
/// `stage` doubles as the name of the sub-directory (below the figures root) that the
/// stage's CSV, plots and HTML index are written to and read from.
#[derive(Debug, Clone)]
pub struct StageMetrics {
    /// Stage name; also used as the output sub-directory name.
    pub stage: String,
    /// One entry per recorded epoch, in insertion order.
    pub history: Vec<EpochMetrics>,
}
impl StageMetrics {
/// Creates a stage with the given name and an empty history.
pub fn new(stage: &str) -> Self {
    let stage = stage.to_owned();
    let history = Vec::new();
    Self { stage, history }
}
pub fn from_csv(stage: &str, figures_root: &Path) -> anyhow::Result<Self> {
let path = figures_root.join(stage).join("metrics.csv");
let text = fs::read_to_string(&path)
.map_err(|e| anyhow::anyhow!("Cannot read {}: {e}", path.display()))?;
let mut history = Vec::new();
for line in text.lines().skip(1) { let cols: Vec<&str> = line.splitn(6, ',').collect();
if cols.len() < 6 { continue; }
let parse = |s: &str| s.trim().parse::<f64>().unwrap_or(0.0);
history.push(EpochMetrics {
epoch: cols[0].trim().parse::<usize>().unwrap_or(0),
train_loss: parse(cols[1]),
val_loss: parse(cols[2]),
val_perplexity: parse(cols[3]),
val_accuracy: parse(cols[4]),
val_macro_recall: parse(cols[5]),
});
}
if history.is_empty() {
anyhow::bail!(
"No data rows found in {}\n\
Run training first: cargo run --release -- train --stages {stage}",
path.display()
);
}
Ok(Self { stage: stage.to_string(), history })
}
/// Appends the metrics of one more epoch to the end of this stage's history.
pub fn push(&mut self, m: EpochMetrics) {
self.history.push(m);
}
/// Writes this stage's artefacts into `<figures_root>/<stage>/`.
///
/// With an empty history nothing is created and `Ok(())` is returned. Otherwise the
/// directory is created if needed, followed by `metrics.csv` and the loss, perplexity,
/// accuracy and recall plots, in that order; the first failure aborts the rest.
pub fn save(&self, figures_root: &Path) -> Result<()> {
    if !self.history.is_empty() {
        let out_dir = figures_root.join(&self.stage);
        fs::create_dir_all(&out_dir)?;

        self.write_csv(&out_dir)?;
        self.plot_loss(&out_dir)?;
        self.plot_perplexity(&out_dir)?;
        self.plot_accuracy(&out_dir)?;
        self.plot_recall(&out_dir)?;

        tracing::info!("{}: metrics saved → {}/", self.stage, out_dir.display());
    }
    Ok(())
}
/// Writes the whole history to `metrics.csv` inside `dir`, replacing any existing file.
///
/// The file starts with a header line, followed by one line per epoch. Losses, accuracy
/// and recall are written with six decimals, perplexity with four.
///
/// # Errors
///
/// Fails if the file cannot be created or if any write (including the final flush) fails.
fn write_csv(&self, dir: &Path) -> Result<()> {
    use std::io::{BufWriter, Write};

    let path = dir.join("metrics.csv");
    // Buffer the output so that each row does not become its own write syscall.
    let mut out = BufWriter::new(fs::File::create(&path)?);

    writeln!(out, "epoch,train_loss,val_loss,val_perplexity,val_accuracy,val_macro_recall")?;
    for m in &self.history {
        writeln!(
            out,
            "{},{:.6},{:.6},{:.4},{:.6},{:.6}",
            m.epoch, m.train_loss, m.val_loss,
            m.val_perplexity, m.val_accuracy, m.val_macro_recall
        )?;
    }

    // Flush explicitly: an error during the implicit flush on drop would be discarded.
    out.flush()?;
    Ok(())
}
fn plot_loss(&self, dir: &Path) -> Result<()> {
let path = dir.join("loss_curve.svg");
let root = SVGBackend::new(&path, (900, 500)).into_drawing_area();
root.fill(&WHITE)?;
let epochs: Vec<f64> = self.history.iter().map(|m| m.epoch as f64).collect();
let train_vals: Vec<f64> = self.history.iter().map(|m| m.train_loss).collect();
let val_vals: Vec<f64> = self.history.iter().map(|m| m.val_loss).collect();
let (x_range, n_xlabels) = epoch_x_range(&epochs);
let y_max = train_vals.iter().chain(val_vals.iter())
.cloned().fold(f64::NEG_INFINITY, f64::max).max(0.1) * 1.08;
let y_min = val_vals.iter() .cloned().fold(f64::INFINITY, f64::min)
.min(0.0).max(y_max - 10.0);
let mut chart = ChartBuilder::on(&root)
.caption(format!("{} — Loss", self.stage), ("sans-serif", 22))
.margin(24).x_label_area_size(44).y_label_area_size(64)
.build_cartesian_2d(x_range, y_min..y_max)?;
mesh!(chart, n_xlabels, 6, "Epoch", "NLL Loss");
chart.draw_series(LineSeries::new(
epochs.iter().zip(train_vals.iter()).map(|(&x, &y)| (x, y)),
ShapeStyle { color: BLUE.to_rgba(), filled: false, stroke_width: 2 },
))?.label("Train loss")
.legend(|(x, y)| PathElement::new(vec![(x, y), (x+22, y)],
ShapeStyle { color: BLUE.to_rgba(), filled: false, stroke_width: 2 }));
chart.draw_series(LineSeries::new(
epochs.iter().zip(val_vals.iter()).map(|(&x, &y)| (x, y)),
ShapeStyle { color: RED.to_rgba(), filled: false, stroke_width: 2 },
))?.label("Val loss")
.legend(|(x, y)| PathElement::new(vec![(x, y), (x+22, y)],
ShapeStyle { color: RED.to_rgba(), filled: false, stroke_width: 2 }));
chart.draw_series(epochs.iter().zip(train_vals.iter())
.map(|(&x, &y)| Circle::new((x, y), 5, BLUE.filled())))?;
chart.draw_series(epochs.iter().zip(val_vals.iter())
.map(|(&x, &y)| Circle::new((x, y), 5, RED.filled())))?;
chart.configure_series_labels()
.background_style(&WHITE.mix(0.9)).border_style(&BLACK.mix(0.4))
.draw()?;
root.present()?;
Ok(())
}
fn plot_perplexity(&self, dir: &Path) -> Result<()> {
let path = dir.join("perplexity.svg");
let root = SVGBackend::new(&path, (900, 500)).into_drawing_area();
root.fill(&WHITE)?;
let epochs: Vec<f64> = self.history.iter().map(|m| m.epoch as f64).collect();
let vals: Vec<f64> = self.history.iter().map(|m| m.val_perplexity).collect();
let (x_range, n_xlabels) = epoch_x_range(&epochs);
let y_max = vals.iter().cloned().fold(f64::NEG_INFINITY, f64::max).max(1.1) * 1.08;
let y_min = vals.iter().cloned().fold(f64::INFINITY, f64::min)
.max(0.0).max(y_max - 20.0);
let c = full_palette::PURPLE;
let mut chart = ChartBuilder::on(&root)
.caption(format!("{} — Perplexity", self.stage), ("sans-serif", 22))
.margin(24).x_label_area_size(44).y_label_area_size(72)
.build_cartesian_2d(x_range, y_min..y_max)?;
mesh!(chart, n_xlabels, 6, "Epoch", "Perplexity exp(val_loss)");
chart.draw_series(LineSeries::new(
epochs.iter().zip(vals.iter()).map(|(&x, &y)| (x, y)),
ShapeStyle { color: c.to_rgba(), filled: false, stroke_width: 2 },
))?.label("Val perplexity")
.legend(move |(x, y)| PathElement::new(vec![(x, y), (x+22, y)],
ShapeStyle { color: c.to_rgba(), filled: false, stroke_width: 2 }));
chart.draw_series(epochs.iter().zip(vals.iter())
.map(|(&x, &y)| Circle::new((x, y), 5, c.filled())))?;
chart.configure_series_labels()
.background_style(&WHITE.mix(0.9)).border_style(&BLACK.mix(0.4))
.draw()?;
root.present()?;
Ok(())
}
fn plot_accuracy(&self, dir: &Path) -> Result<()> {
let path = dir.join("accuracy.svg");
let root = SVGBackend::new(&path, (900, 500)).into_drawing_area();
root.fill(&WHITE)?;
let epochs: Vec<f64> = self.history.iter().map(|m| m.epoch as f64).collect();
let vals: Vec<f64> = self.history.iter().map(|m| m.val_accuracy).collect();
let (x_range, n_xlabels) = epoch_x_range(&epochs);
let y_max = vals.iter().cloned().fold(0.0f64, f64::max).max(0.05) * 1.2;
let y_max = y_max.min(1.0);
let c = full_palette::TEAL_700;
let mut chart = ChartBuilder::on(&root)
.caption(format!("{} — Token Accuracy", self.stage), ("sans-serif", 22))
.margin(24).x_label_area_size(44).y_label_area_size(64)
.build_cartesian_2d(x_range, 0f64..y_max)?;
mesh!(chart, n_xlabels, 5, "Epoch", "Token Accuracy");
chart.draw_series(AreaSeries::new(
epochs.iter().zip(vals.iter()).map(|(&x, &y)| (x, y)),
0.0, full_palette::TEAL_400.mix(0.25),
))?;
chart.draw_series(LineSeries::new(
epochs.iter().zip(vals.iter()).map(|(&x, &y)| (x, y)),
ShapeStyle { color: c.to_rgba(), filled: false, stroke_width: 2 },
))?.label("Val token accuracy")
.legend(move |(x, y)| PathElement::new(vec![(x, y), (x+22, y)],
ShapeStyle { color: c.to_rgba(), filled: false, stroke_width: 2 }));
chart.draw_series(epochs.iter().zip(vals.iter())
.map(|(&x, &y)| Circle::new((x, y), 5, c.filled())))?;
chart.configure_series_labels()
.background_style(&WHITE.mix(0.9)).border_style(&BLACK.mix(0.4))
.draw()?;
root.present()?;
Ok(())
}
fn plot_recall(&self, dir: &Path) -> Result<()> {
let path = dir.join("macro_recall.svg");
let root = SVGBackend::new(&path, (900, 500)).into_drawing_area();
root.fill(&WHITE)?;
let epochs: Vec<f64> = self.history.iter().map(|m| m.epoch as f64).collect();
let vals: Vec<f64> = self.history.iter().map(|m| m.val_macro_recall).collect();
let (x_range, n_xlabels) = epoch_x_range(&epochs);
let y_max = vals.iter().cloned().fold(0.0f64, f64::max).max(0.05) * 1.2;
let y_max = y_max.min(1.0);
let c = full_palette::ORANGE_700;
let mut chart = ChartBuilder::on(&root)
.caption(format!("{} — Macro Recall", self.stage), ("sans-serif", 22))
.margin(24).x_label_area_size(44).y_label_area_size(64)
.build_cartesian_2d(x_range, 0f64..y_max)?;
mesh!(chart, n_xlabels, 5, "Epoch", "Macro-averaged Recall");
chart.draw_series(AreaSeries::new(
epochs.iter().zip(vals.iter()).map(|(&x, &y)| (x, y)),
0.0, full_palette::ORANGE_400.mix(0.25),
))?;
chart.draw_series(LineSeries::new(
epochs.iter().zip(vals.iter()).map(|(&x, &y)| (x, y)),
ShapeStyle { color: c.to_rgba(), filled: false, stroke_width: 2 },
))?.label("Val macro recall")
.legend(move |(x, y)| PathElement::new(vec![(x, y), (x+22, y)],
ShapeStyle { color: c.to_rgba(), filled: false, stroke_width: 2 }));
chart.draw_series(epochs.iter().zip(vals.iter())
.map(|(&x, &y)| Circle::new((x, y), 5, c.filled())))?;
chart.configure_series_labels()
.background_style(&WHITE.mix(0.9)).border_style(&BLACK.mix(0.4))
.draw()?;
root.present()?;
Ok(())
}
}
/// Writes `<figures_root>/<stage>/index.html`, a self-contained report for one stage.
///
/// The page shows a one-line summary of the latest epoch, the four per-stage SVG plots
/// (`loss_curve.svg`, `perplexity.svg`, `accuracy.svg`, `macro_recall.svg`) inlined from
/// the stage directory, and a table with one row per epoch. A plot file that cannot be
/// read is left out (its cell stays empty) instead of failing the whole report.
///
/// The stage name is HTML-escaped before it is placed in the title and heading.
///
/// With an empty history nothing is written and `Ok(())` is returned.
///
/// # Errors
///
/// Fails if the stage directory cannot be created or `index.html` cannot be written.
pub fn write_html_index(metrics: &StageMetrics, figures_root: &Path) -> Result<()> {
    // Replaces the characters that are significant in HTML text and attribute values.
    fn escape_html(raw: &str) -> String {
        let mut out = String::with_capacity(raw.len());
        for ch in raw.chars() {
            match ch {
                '&' => out.push_str("&amp;"),
                '<' => out.push_str("&lt;"),
                '>' => out.push_str("&gt;"),
                '"' => out.push_str("&quot;"),
                '\'' => out.push_str("&#39;"),
                other => out.push(other),
            }
        }
        out
    }

    // The latest epoch feeds the summary line; no epochs means there is nothing to report.
    let latest = match metrics.history.last() {
        Some(m) => m,
        None => return Ok(()),
    };

    let dir = figures_root.join(&metrics.stage);
    fs::create_dir_all(&dir)?;

    // Best effort: a missing or unreadable plot becomes an empty string.
    let load = |name: &str| -> String {
        fs::read_to_string(dir.join(name)).unwrap_or_default()
    };
    let loss_svg = load("loss_curve.svg");
    let ppl_svg = load("perplexity.svg");
    let acc_svg = load("accuracy.svg");
    let recall_svg = load("macro_recall.svg");

    let summary = format!(
        "epoch {} | \
         train loss <b>{:.4}</b> | \
         val loss <b>{:.4}</b> | \
         perplexity <b>{:.2}</b> | \
         accuracy <b>{:.2}%</b> | \
         macro recall <b>{:.2}%</b>",
        latest.epoch,
        latest.train_loss, latest.val_loss, latest.val_perplexity,
        latest.val_accuracy * 100.0, latest.val_macro_recall * 100.0,
    );

    let rows = metrics.history.iter().map(|m| format!(
        " <tr><td>{}</td><td>{:.6}</td><td>{:.6}</td>\
         <td>{:.3}</td><td>{:.2}%</td><td>{:.2}%</td></tr>",
        m.epoch, m.train_loss, m.val_loss,
        m.val_perplexity,
        m.val_accuracy * 100.0, m.val_macro_recall * 100.0,
    )).collect::<Vec<_>>().join("\n");

    let html = format!(
r#"<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="utf-8">
<title>{stage} — training metrics</title>
<style>
body {{ font-family: system-ui, sans-serif; background: #f8f8f8; margin: 0; padding: 16px; }}
h1 {{ font-size: 1.2rem; margin: 0 0 4px; }}
.sub {{ font-size: .85rem; color: #555; margin: 0 0 16px; }}
.grid {{ display: grid; grid-template-columns: 1fr 1fr; gap: 12px; }}
.cell {{ background: #fff; border: 1px solid #ddd; border-radius: 6px;
padding: 8px; box-sizing: border-box; }}
.cell svg {{ width: 100%; height: auto; display: block; }}
table {{ border-collapse: collapse; font-size: .8rem; margin-top: 16px;
background: #fff; border: 1px solid #ddd; border-radius: 6px; overflow: hidden; }}
th, td {{ padding: 5px 12px; border-bottom: 1px solid #eee; text-align: right; }}
th {{ background: #f0f0f0; text-align: center; }}
tr:last-child td {{ border-bottom: none; }}
</style>
</head>
<body>
<h1>{stage}</h1>
<p class="sub">{summary}</p>
<div class="grid">
<div class="cell">{loss_svg}</div>
<div class="cell">{ppl_svg}</div>
<div class="cell">{acc_svg}</div>
<div class="cell">{recall_svg}</div>
</div>
<table>
<thead>
<tr>
<th>Epoch</th><th>Train loss</th><th>Val loss</th>
<th>Perplexity</th><th>Accuracy</th><th>Macro recall</th>
</tr>
</thead>
<tbody>
{rows}
</tbody>
</table>
</body>
</html>
"#,
        stage = escape_html(&metrics.stage),
        summary = summary,
        loss_svg = loss_svg,
        ppl_svg = ppl_svg,
        acc_svg = acc_svg,
        recall_svg = recall_svg,
        rows = rows,
    );

    fs::write(dir.join("index.html"), html)?;
    Ok(())
}
pub fn plot_curriculum_overview(
all_stages: &[&StageMetrics],
figures_root: &Path,
) -> Result<()> {
if all_stages.iter().all(|s| s.history.is_empty()) {
return Ok(());
}
let stage_colors: [RGBColor; 5] = [
RGBColor( 31, 119, 180), RGBColor( 44, 160, 44), RGBColor(255, 127, 14), RGBColor(148, 103, 189), RGBColor(214, 39, 40), ];
let short_labels = ["S1·MCQ", "S2·Cap.", "S3·HAR", "S4·Sleep", "S5·ECG"];
let mut g_xs: Vec<Vec<f64>> = Vec::new();
let mut boundaries: Vec<f64> = Vec::new();
let mut offset = 0usize;
for (i, stage) in all_stages.iter().enumerate() {
let n = stage.history.len();
let xs: Vec<f64> = (0..n).map(|j| (offset + j) as f64).collect();
if n > 0 && i > 0 { boundaries.push(offset as f64 - 0.5); }
g_xs.push(xs);
offset += n;
}
let x_lo = -0.5f64;
let x_hi = offset as f64 - 0.5;
let train_loss: Vec<Vec<f64>> = all_stages.iter()
.map(|s| s.history.iter().map(|m| m.train_loss).collect()).collect();
let val_loss: Vec<Vec<f64>> = all_stages.iter()
.map(|s| s.history.iter().map(|m| m.val_loss).collect()).collect();
let perplexity: Vec<Vec<f64>> = all_stages.iter()
.map(|s| s.history.iter().map(|m| m.val_perplexity).collect()).collect();
let accuracy: Vec<Vec<f64>> = all_stages.iter()
.map(|s| s.history.iter().map(|m| m.val_accuracy).collect()).collect();
let recall: Vec<Vec<f64>> = all_stages.iter()
.map(|s| s.history.iter().map(|m| m.val_macro_recall).collect()).collect();
let path = figures_root.join("curriculum_overview.svg");
let root = SVGBackend::new(&path, (1100, 750)).into_drawing_area();
root.fill(&WHITE)?;
let panels = root.margin(8, 8, 8, 8).split_evenly((2, 2));
macro_rules! draw_boundaries {
($chart:expr, $y_lo:expr, $y_hi:expr) => {
for &bx in &boundaries {
$chart.draw_series(std::iter::once(PathElement::new(
vec![(bx, $y_lo), (bx, $y_hi)],
ShapeStyle { color: BLACK.mix(0.20), filled: false, stroke_width: 1 },
)))?;
}
};
}
{
let flat: Vec<f64> = val_loss.iter().chain(train_loss.iter())
.flat_map(|v| v.iter().copied()).collect();
let y_hi = flat.iter().cloned().fold(0.0f64, f64::max).max(0.1) * 1.08;
let y_lo = flat.iter().cloned().fold(f64::INFINITY, f64::min)
.min(0.0).max(y_hi - 10.0);
let mut chart = ChartBuilder::on(&panels[0])
.caption("Loss (val = solid · train = faint)", ("sans-serif", 12))
.margin(10).x_label_area_size(28).y_label_area_size(54)
.build_cartesian_2d(x_lo..x_hi, y_lo..y_hi)?;
chart.configure_mesh().x_labels(5).y_labels(5)
.light_line_style(&WHITE).bold_line_style(BLACK.mix(0.08))
.x_desc("global epoch").y_desc("NLL loss").draw()?;
draw_boundaries!(chart, y_lo, y_hi);
for (i, ((xs, vl), tl)) in g_xs.iter()
.zip(val_loss.iter()).zip(train_loss.iter()).enumerate()
{
if xs.is_empty() { continue; }
let c = stage_colors.get(i).copied().unwrap_or(BLACK);
let lbl = short_labels.get(i).copied().unwrap_or("");
chart.draw_series(LineSeries::new(
xs.iter().zip(tl.iter()).map(|(&x, &y)| (x, y)),
ShapeStyle { color: c.mix(0.38), filled: false, stroke_width: 1 },
))?;
chart.draw_series(LineSeries::new(
xs.iter().zip(vl.iter()).map(|(&x, &y)| (x, y)),
ShapeStyle { color: c.to_rgba(), filled: false, stroke_width: 2 },
))?.label(lbl)
.legend(move |(lx, ly)| PathElement::new(
vec![(lx, ly), (lx + 18, ly)],
ShapeStyle { color: c.to_rgba(), filled: false, stroke_width: 2 },
));
chart.draw_series(
xs.iter().zip(vl.iter()).map(|(&x, &y)| Circle::new((x, y), 4, c.filled()))
)?;
}
chart.configure_series_labels()
.background_style(&WHITE.mix(0.85)).border_style(&BLACK.mix(0.3))
.draw()?;
}
{
let flat: Vec<f64> = perplexity.iter().flat_map(|v| v.iter().copied()).collect();
let y_hi = flat.iter().cloned().fold(0.0f64, f64::max).max(1.1) * 1.08;
let y_lo = flat.iter().cloned().fold(f64::INFINITY, f64::min)
.max(0.0).max(y_hi - 20.0);
let mut chart = ChartBuilder::on(&panels[1])
.caption("Val Perplexity exp(val_loss)", ("sans-serif", 12))
.margin(10).x_label_area_size(28).y_label_area_size(62)
.build_cartesian_2d(x_lo..x_hi, y_lo..y_hi)?;
chart.configure_mesh().x_labels(5).y_labels(5)
.light_line_style(&WHITE).bold_line_style(BLACK.mix(0.08))
.x_desc("global epoch").y_desc("perplexity").draw()?;
draw_boundaries!(chart, y_lo, y_hi);
for (i, (xs, vals)) in g_xs.iter().zip(perplexity.iter()).enumerate() {
if xs.is_empty() { continue; }
let c = stage_colors.get(i).copied().unwrap_or(BLACK);
chart.draw_series(LineSeries::new(
xs.iter().zip(vals.iter()).map(|(&x, &y)| (x, y)),
ShapeStyle { color: c.to_rgba(), filled: false, stroke_width: 2 },
))?;
chart.draw_series(
xs.iter().zip(vals.iter()).map(|(&x, &y)| Circle::new((x, y), 4, c.filled()))
)?;
}
}
{
let flat: Vec<f64> = accuracy.iter().flat_map(|v| v.iter().copied()).collect();
let y_hi = (flat.iter().cloned().fold(0.0f64, f64::max) * 1.20)
.min(1.0).max(0.05);
let mut chart = ChartBuilder::on(&panels[2])
.caption("Val Token Accuracy (stages 1–2 not tracked)", ("sans-serif", 12))
.margin(10).x_label_area_size(28).y_label_area_size(54)
.build_cartesian_2d(x_lo..x_hi, 0.0f64..y_hi)?;
chart.configure_mesh().x_labels(5).y_labels(5)
.light_line_style(&WHITE).bold_line_style(BLACK.mix(0.08))
.x_desc("global epoch").y_desc("accuracy [0–1]").draw()?;
draw_boundaries!(chart, 0.0, y_hi);
for (i, (xs, vals)) in g_xs.iter().zip(accuracy.iter()).enumerate() {
if xs.is_empty() || vals.iter().all(|&v| v == 0.0) { continue; }
let c = stage_colors.get(i).copied().unwrap_or(BLACK);
chart.draw_series(LineSeries::new(
xs.iter().zip(vals.iter()).map(|(&x, &y)| (x, y)),
ShapeStyle { color: c.to_rgba(), filled: false, stroke_width: 2 },
))?;
chart.draw_series(
xs.iter().zip(vals.iter()).map(|(&x, &y)| Circle::new((x, y), 4, c.filled()))
)?;
}
}
{
let flat: Vec<f64> = recall.iter().flat_map(|v| v.iter().copied()).collect();
let y_hi = (flat.iter().cloned().fold(0.0f64, f64::max) * 1.20)
.min(1.0).max(0.05);
let mut chart = ChartBuilder::on(&panels[3])
.caption("Val Macro Recall (stages 1–2 not tracked)", ("sans-serif", 12))
.margin(10).x_label_area_size(28).y_label_area_size(54)
.build_cartesian_2d(x_lo..x_hi, 0.0f64..y_hi)?;
chart.configure_mesh().x_labels(5).y_labels(5)
.light_line_style(&WHITE).bold_line_style(BLACK.mix(0.08))
.x_desc("global epoch").y_desc("macro recall [0–1]").draw()?;
draw_boundaries!(chart, 0.0, y_hi);
for (i, (xs, vals)) in g_xs.iter().zip(recall.iter()).enumerate() {
if xs.is_empty() || vals.iter().all(|&v| v == 0.0) { continue; }
let c = stage_colors.get(i).copied().unwrap_or(BLACK);
chart.draw_series(LineSeries::new(
xs.iter().zip(vals.iter()).map(|(&x, &y)| (x, y)),
ShapeStyle { color: c.to_rgba(), filled: false, stroke_width: 2 },
))?;
chart.draw_series(
xs.iter().zip(vals.iter()).map(|(&x, &y)| Circle::new((x, y), 4, c.filled()))
)?;
}
}
root.present()?;
tracing::info!("curriculum overview → {}", path.display());
Ok(())
}