import argparse
import csv
import os
import sys
# Command-line interface: where the per-stage metrics live and where the
# combined overview image should be written.
parser = argparse.ArgumentParser(
    description=__doc__,
    formatter_class=argparse.RawDescriptionHelpFormatter,
)
parser.add_argument(
    "--figures",
    default="figures",
    help="root directory containing per-stage sub-dirs (default: figures/)",
)
parser.add_argument(
    "--out",
    default=None,
    help="output file path (default: <figures>/overview.png)",
)
args = parser.parse_args()

figures_root = args.figures
# Without --out (or with an empty value) the image is placed inside the
# figures root itself.
out_path = args.out if args.out else os.path.join(figures_root, "overview.png")
# One row per curriculum stage:
#   (sub-directory name under the figures root, legend label, line colour).
# Keeping the three values side by side guarantees STAGES and COLORS stay the
# same length and in the same order.
_STAGE_TABLE = (
    ("stage1_mcq", "S1 · MCQ", "#1f77b4"),
    ("stage2_captioning", "S2 · Caption", "#2ca02c"),
    ("stage3_cot", "S3 · HAR CoT", "#ff7f0e"),
    ("stage4_sleep_cot", "S4 · Sleep CoT", "#9467bd"),
    ("stage5_ecg_cot", "S5 · ECG CoT", "#d62728"),
)

# (stage_id, label) pairs, in plotting order.
STAGES = [(stage_id, label) for stage_id, label, _color in _STAGE_TABLE]

# Line colour for each entry of STAGES, index-aligned.
COLORS = [color for _stage_id, _label, color in _STAGE_TABLE]
def load_csv(stage_dir):
    """Load the metrics table of one training stage.

    Reads ``<stage_dir>/metrics.csv`` and converts every cell to ``float``.

    Rows whose number of fields does not match the header are skipped with a
    warning on stderr. ``csv.DictReader`` fills missing cells with ``None`` and
    stores surplus cells under the key ``None``; either would make ``float()``
    raise a ``TypeError`` and abort the whole script, e.g. when the last line
    of the file was only partially written by an interrupted run.

    Args:
        stage_dir: Directory expected to contain ``metrics.csv``.

    Returns:
        A list with one ``{column_name: float}`` dict per usable row, in file
        order, or ``None`` when the file does not exist or contains no usable
        rows.

    Raises:
        ValueError: If a cell of a well-formed row is not a valid float
            literal (for example an empty cell).
    """
    path = os.path.join(stage_dir, "metrics.csv")
    if not os.path.exists(path):
        return None

    rows = []
    with open(path, newline="") as f:
        reader = csv.DictReader(f)
        for row in reader:
            # A ``None`` key means extra cells, a ``None`` value means missing
            # cells; in both cases the row cannot be mapped onto the header.
            if None in row or None in row.values():
                print(
                    f"warning: skipping malformed row at {path}:{reader.line_num}",
                    file=sys.stderr,
                )
                continue
            rows.append({k: float(v) for k, v in row.items()})

    # An empty table is reported the same way as a missing file.
    return rows or None
# Load every stage's metrics up front. Stages without a usable metrics.csv
# keep their slot (rows is None) so the stage order is preserved.
stage_data = [
    (label, color, load_csv(os.path.join(figures_root, stage_id)))
    for (stage_id, label), color in zip(STAGES, COLORS)
]

# Only stages that actually produced rows count as available.
available = [entry for entry in stage_data if entry[2]]
if not available:
    sys.exit(f"No metrics.csv files found under '{figures_root}'. Run training first.")

print(f"Loaded {len(available)} stage(s): {', '.join(l for l,_,_ in available)}")
# Lay the stages end to end on one shared x-axis.
#   global_x[i]  - x positions of stage i's rows (empty list if it has none)
#   boundaries   - x positions half a step before the first row of every
#                  non-empty stage except the first one plotted
global_x = []
boundaries = []
offset = 0
for _label, _color, stage_rows in stage_data:
    count = len(stage_rows) if stage_rows else 0
    # A separator is only needed once something has already been placed.
    if count and offset:
        boundaries.append(offset - 0.5)
    global_x.append(list(range(offset, offset + count)))
    offset += count

# Total number of rows across all stages.
total_epochs = offset
# matplotlib is imported only here, after the metrics files have been loaded
# and checked, and a missing installation is reported as a short message via
# sys.exit rather than as an ImportError traceback.
try:
    import matplotlib
    # Select the non-interactive Agg backend before pyplot is imported; the
    # script only writes an image file and never opens a window.
    matplotlib.use("Agg")
    import matplotlib.pyplot as plt
    import matplotlib.ticker as ticker
except ImportError:
    sys.exit("matplotlib is required: pip install matplotlib")
# 2x2 grid of panels: loss and perplexity on top, accuracy and macro recall
# below.
fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(14, 8))
fig.suptitle("OpenTSLM · curriculum training overview", fontsize=13, y=0.995)

# Manual margins and gaps between the panels.
panel_layout = {
    "hspace": 0.38,
    "wspace": 0.28,
    "left": 0.07,
    "right": 0.98,
    "top": 0.96,
    "bottom": 0.08,
}
fig.subplots_adjust(**panel_layout)

# Unpack row by row; same order as iterating axes.flat.
(ax_loss, ax_ppl), (ax_acc, ax_rec) = axes
def draw_boundaries(ax):
    """Draw a thin dashed grey vertical line on *ax* at every stage boundary.

    The x positions come from the module-level ``boundaries`` list. The lines
    use a low z-order (1) so the metric curves are drawn on top of them.
    """
    separator_style = {
        "color": "#aaaaaa",
        "linewidth": 0.8,
        "linestyle": "--",
        "zorder": 1,
    }
    for x_pos in boundaries:
        ax.axvline(x_pos, **separator_style)
def style_axis(ax, title, ylabel):
    """Apply the common look shared by all panels to *ax*.

    Args:
        ax: Matplotlib axes to style.
        title: Panel title.
        ylabel: Label for the y-axis; the x-axis is always "global epoch".
    """
    # Text: title and axis labels.
    ax.set_title(title, fontsize=10, pad=4)
    for set_label, text in ((ax.set_xlabel, "global epoch"), (ax.set_ylabel, ylabel)):
        set_label(text, fontsize=8)

    # Ticks: small labels, integer positions only on the x-axis.
    ax.tick_params(labelsize=7)
    integer_ticks = ticker.MaxNLocator(integer=True)
    ax.xaxis.set_major_locator(integer_ticks)

    # Frame: faint horizontal grid, light background, no top/right spines.
    ax.grid(axis="y", linewidth=0.4, alpha=0.5)
    ax.set_facecolor("#fafafa")
    ax.spines[["top", "right"]].set_visible(False)
# --- Panel 1: train / validation loss -------------------------------------
draw_boundaries(ax_loss)
for (label, color, records), epochs in zip(stage_data, global_x):
    if not records:
        # Stage has no metrics; nothing to draw.
        continue
    train_curve = [rec["train_loss"] for rec in records]
    val_curve = [rec["val_loss"] for rec in records]
    # Train loss is a faint dashed line; validation loss is the bold, labelled
    # line with a marker per point.
    ax_loss.plot(
        epochs, train_curve,
        color=color, linewidth=1.0, alpha=0.35, linestyle="--",
    )
    ax_loss.plot(
        epochs, val_curve,
        color=color, linewidth=2.0, label=label, zorder=3,
    )
    ax_loss.scatter(epochs, val_curve, color=color, s=28, zorder=4)

ax_loss.legend(
    fontsize=7,
    framealpha=0.85,
    edgecolor="#cccccc",
    loc="upper right",
    ncol=1,
)
style_axis(ax_loss, "Loss (val = solid · train = dashed)", "NLL loss")
# --- Panel 2: validation perplexity ---------------------------------------
draw_boundaries(ax_ppl)
# Pair each stage with its x positions and drop the stages without rows.
stages_with_rows = (
    (stage, xs) for stage, xs in zip(stage_data, global_x) if stage[2]
)
for (_label, color, records), epochs in stages_with_rows:
    perplexities = [rec["val_perplexity"] for rec in records]
    ax_ppl.plot(epochs, perplexities, color=color, linewidth=2.0, zorder=3)
    ax_ppl.scatter(epochs, perplexities, color=color, s=28, zorder=4)
style_axis(ax_ppl, "Val Perplexity exp(val_loss)", "perplexity")
# --- Panels 3 and 4: validation accuracy and macro recall ------------------
# Both panels are drawn the same way; only the CSV column and title differ.
metric_panels = [
    (ax_acc, "val_accuracy", "Val Token Accuracy (stages 1–2 n/a)"),
    (ax_rec, "val_macro_recall", "Val Macro Recall (stages 1–2 n/a)"),
]
for ax, key, title in metric_panels:
    draw_boundaries(ax)
    for (_label, color, records), epochs in zip(stage_data, global_x):
        if not records:
            continue
        series = [rec[key] for rec in records]
        # A column that is 0.0 in every row is treated as "not available" for
        # that stage and left out of the plot.
        if any(value != 0.0 for value in series):
            ax.plot(epochs, series, color=color, linewidth=2.0, zorder=3)
            ax.scatter(epochs, series, color=color, s=28, zorder=4)
    ax.set_ylim(bottom=0)
    # Derive the y-label from the column name, e.g. "val_macro_recall"
    # becomes "macro recall".
    y_label = key.replace("val_", "").replace("_", " ")
    style_axis(ax, title, y_label)
# Make sure the destination directory exists, then write the image.
out_dir = os.path.dirname(os.path.abspath(out_path))
os.makedirs(out_dir, exist_ok=True)
fig.savefig(out_path, dpi=150, bbox_inches="tight")
print(f"Saved → {out_path}")