opentslm 0.1.0

Rust implementation of OpenTSLM using Burn, WGPU, and llama.cpp
#!/usr/bin/env python3
"""
plot_overview.py — merge all per-stage metrics.csv files into one figure.

Usage:
    python3 plot_overview.py                     # reads figures/  writes figures/overview.png
    python3 plot_overview.py --figures figures/  # explicit figures root
    python3 plot_overview.py --out overview.png  # custom output path

Each stage gets its own colour.  Thin vertical grey lines mark stage
boundaries.  A legend in the top-left panel identifies each colour.
Accuracy and recall panels only show stages 3-5 (stages 1-2 set those
to 0.0 because the targets are free-form, not class labels).
"""

import argparse
import csv
import os
import sys

# ── argument parsing ─────────────────────────────────────────────────────────

# The module docstring doubles as the --help description; the raw formatter
# keeps its line breaks and indentation intact.
parser = argparse.ArgumentParser(
    formatter_class=argparse.RawDescriptionHelpFormatter,
    description=__doc__,
)
parser.add_argument(
    "--figures",
    default="figures",
    help="root directory containing per-stage sub-dirs (default: figures/)",
)
parser.add_argument(
    "--out",
    default=None,
    help="output file path (default: <figures>/overview.png)",
)
args = parser.parse_args()

figures_root = args.figures
# A missing (or empty) --out falls back to overview.png inside the figures root.
if args.out:
    out_path = args.out
else:
    out_path = os.path.join(figures_root, "overview.png")

# ── stage config ─────────────────────────────────────────────────────────────

# One entry per curriculum stage, in training order:
#   (sub-directory name under the figures root, label shown in the legend)
STAGES = [
    ("stage1_mcq", "S1 · MCQ"),
    ("stage2_captioning", "S2 · Caption"),
    ("stage3_cot", "S3 · HAR CoT"),
    ("stage4_sleep_cot", "S4 · Sleep CoT"),
    ("stage5_ecg_cot", "S5 · ECG CoT"),
]

# One colour per stage, paired with STAGES by position.  These are the first
# five tab10 colours, listed in a different order from matplotlib's default
# cycle; per the original note they match the Rust plotters charts.
COLORS = [
    "#1f77b4",  # blue
    "#2ca02c",  # green
    "#ff7f0e",  # orange
    "#9467bd",  # purple
    "#d62728",  # red
]

# ── load CSVs ────────────────────────────────────────────────────────────────

def load_csv(stage_dir):
    """Load ``metrics.csv`` from *stage_dir* as a list of per-row dicts.

    Every row becomes a dict mapping column name to ``float``.  Cells that
    are blank or missing (a row shorter than the header, e.g. a file that is
    still being written) become NaN instead of aborting the whole script;
    fields beyond the header are ignored.

    Args:
        stage_dir: directory expected to contain ``metrics.csv``.

    Returns:
        A non-empty list of ``{column: float}`` dicts, or ``None`` when the
        file does not exist or contains no data rows.

    Raises:
        ValueError: if a non-blank cell cannot be parsed as a float.
    """
    path = os.path.join(stage_dir, "metrics.csv")
    if not os.path.exists(path):
        return None

    def to_float(cell):
        # DictReader fills cells missing from a short row with None.
        if cell is None or not cell.strip():
            return float("nan")
        return float(cell)

    rows = []
    with open(path, newline="") as f:
        for record in csv.DictReader(f):
            rows.append({
                name: to_float(cell)
                for name, cell in record.items()
                # Surplus fields are collected in a list under the key None.
                if name is not None
            })
    return rows or None

# One (label, colour, rows) triple per configured stage; rows is None when
# that stage has no metrics.csv or the file holds no data rows.
stage_data = [
    (label, color, load_csv(os.path.join(figures_root, stage_id)))
    for (stage_id, label), color in zip(STAGES, COLORS)
]

# Stages that actually produced metrics; nothing can be plotted without any.
available = [entry for entry in stage_data if entry[2]]
if not available:
    sys.exit(f"No metrics.csv files found under '{figures_root}'. Run training first.")

loaded_labels = ', '.join(entry[0] for entry in available)
print(f"Loaded {len(available)} stage(s): {loaded_labels}")

# ── build global epoch axis ───────────────────────────────────────────────────
# Each stage's epochs are concatenated on a single x-axis.

# Number of rows (epochs) each stage contributes; 0 for stages without data.
epoch_counts = [len(rows) if rows else 0 for _lbl, _col, rows in stage_data]

global_x = []     # per stage: that stage's x positions on the shared axis
boundaries = []   # x positions of the separators between consecutive stages
offset = 0        # running total of epochs placed so far
for count in epoch_counts:
    start, offset = offset, offset + count
    # Put a separator halfway between the previous stage's last epoch and
    # this stage's first one; the first plotted stage gets none.
    if count and start:
        boundaries.append(start - 0.5)
    global_x.append(list(range(start, offset)))

total_epochs = offset

# ── plot ──────────────────────────────────────────────────────────────────────

# matplotlib is imported here, after the CSVs are loaded, and a missing
# installation is reported as a one-line hint rather than a traceback.
try:
    import matplotlib
    matplotlib.use("Agg")  # select the Agg backend before pyplot is loaded; no display needed
    from matplotlib import pyplot as plt
    from matplotlib import ticker
except ImportError:
    sys.exit("matplotlib is required:  pip install matplotlib")

# 2 x 2 grid of panels under one shared title.
fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(14, 8))
fig.suptitle("OpenTSLM · curriculum training overview", y=0.995, fontsize=13)
fig.subplots_adjust(
    left=0.07, right=0.98, bottom=0.08, top=0.96,
    wspace=0.28, hspace=0.38,
)

# Row by row: loss | perplexity on top, accuracy | recall underneath.
(ax_loss, ax_ppl), (ax_acc, ax_rec) = axes

def draw_boundaries(ax):
    """Draw a thin dashed grey vertical line on *ax* at every stage boundary.

    Positions are read from the module-level ``boundaries`` list.  The low
    zorder keeps the separators underneath the plotted curves.
    """
    separator_style = dict(color="#aaaaaa", linewidth=0.8, linestyle="--", zorder=1)
    for x_pos in boundaries:
        ax.axvline(x_pos, **separator_style)

def style_axis(ax, title, ylabel):
    """Apply the shared panel styling to *ax*.

    Sets the panel title and y-axis label, labels the x-axis "global epoch",
    restricts x ticks to integers, adds faint horizontal grid lines, uses a
    light background and hides the top and right spines.

    Args:
        ax: the matplotlib axes to style.
        title: text shown above the panel.
        ylabel: text shown along the y-axis.
    """
    label_size = 8

    # Text and tick labels.
    ax.set_title(title, fontsize=10, pad=4)
    ax.set_xlabel("global epoch", fontsize=label_size)
    ax.set_ylabel(ylabel, fontsize=label_size)
    ax.tick_params(labelsize=7)

    # Epoch indices are whole numbers, so only integer tick positions are used.
    ax.xaxis.set_major_locator(ticker.MaxNLocator(integer=True))

    # Background, grid and frame.
    ax.grid(axis="y", linewidth=0.4, alpha=0.5)
    ax.set_facecolor("#fafafa")
    for side in ("top", "right"):
        ax.spines[side].set_visible(False)

# ── Panel 0: Loss ─────────────────────────────────────────────────────────────
draw_boundaries(ax_loss)
for (label, color, rows), xs in zip(stage_data, global_x):
    if rows:
        train_loss = [row["train_loss"] for row in rows]
        val_loss = [row["val_loss"] for row in rows]
        # Train loss: thin, faint, dashed and without a legend entry.
        ax_loss.plot(xs, train_loss, color=color, linewidth=1.0, alpha=0.35, linestyle="--")
        # Validation loss: bold solid line with markers; carries the stage label.
        ax_loss.plot(xs, val_loss, color=color, linewidth=2.0, label=label, zorder=3)
        ax_loss.scatter(xs, val_loss, color=color, s=28, zorder=4)

# The only legend in the figure; it maps each colour to its stage.
ax_loss.legend(
    loc="upper right",
    ncol=1,
    fontsize=7,
    framealpha=0.85,
    edgecolor="#cccccc",
)
style_axis(ax_loss, "Loss  (val = solid · train = dashed)", "NLL loss")

# ── Panel 1: Perplexity ───────────────────────────────────────────────────────
draw_boundaries(ax_ppl)
for (_label, color, rows), xs in zip(stage_data, global_x):
    if rows:
        # One solid line with markers per stage, in that stage's colour.
        perplexity = [row["val_perplexity"] for row in rows]
        ax_ppl.plot(xs, perplexity, color=color, linewidth=2.0, zorder=3)
        ax_ppl.scatter(xs, perplexity, color=color, s=28, zorder=4)
style_axis(ax_ppl, "Val Perplexity  exp(val_loss)", "perplexity")

# ── Panels 2 & 3: Accuracy / Recall (stages 1–2 are all-zero → skip) ─────────
# Both panels are drawn the same way; only the axes, CSV column and title differ.
metric_panels = (
    (ax_acc, "val_accuracy",     "Val Token Accuracy  (stages 1–2 n/a)"),
    (ax_rec, "val_macro_recall", "Val Macro Recall    (stages 1–2 n/a)"),
)
for ax, key, title in metric_panels:
    draw_boundaries(ax)
    for (_label, color, rows), xs in zip(stage_data, global_x):
        if not rows:
            continue
        vals = [row[key] for row in rows]
        # A series that is 0.0 at every epoch is treated as "not applicable"
        # and left out of the panel.
        if any(v != 0.0 for v in vals):
            ax.plot(xs, vals, color=color, linewidth=2.0, zorder=3)
            ax.scatter(xs, vals, color=color, s=28, zorder=4)
    ax.set_ylim(bottom=0)
    # Derive the y-axis label from the column name, e.g. "macro recall".
    ylabel = key.replace("val_", "").replace("_", " ")
    style_axis(ax, title, ylabel)

# ── save ──────────────────────────────────────────────────────────────────────
# Create the destination directory first so a fresh --out location works.
out_dir = os.path.dirname(os.path.abspath(out_path))
os.makedirs(out_dir, exist_ok=True)
fig.savefig(out_path, bbox_inches="tight", dpi=150)
print(f"Saved → {out_path}")