import json
import subprocess
import sys
from datetime import datetime
from pathlib import Path
# matplotlib is optional: without it the report is still produced, just
# without charts (HAS_MATPLOTLIB tells the caller which case applies).
try:
    import matplotlib

    # Select the non-GUI, file-only Agg backend before pyplot is imported.
    matplotlib.use("Agg")
    import matplotlib.pyplot as plt
    import matplotlib.ticker as ticker
except ImportError:
    HAS_MATPLOTLIB = False
else:
    HAS_MATPLOTLIB = True


# Directory holding this script (parent path resolved to an absolute path).
SCRIPT_DIR = Path(__file__).parent.resolve()
# Project root, one level above the script; target/criterion and
# bench-results are located relative to it.
ROOT = SCRIPT_DIR.parent
# Where the generated SVG charts are written.
CHARTS_DIR = ROOT.joinpath("bench-results", "charts")


_PALETTE: dict[str, str] = {
"fft": "#2563eb", "ifft": "#059669", "roundtrip": "#dc2626",
"fft_batch_batch_size": "#7c3aed", "fft_batch_signal_len": "#a855f7",
"ifft_batch_batch_size": "#0891b2", "ifft_batch_signal_len": "#14b8a6",
"roundtrip_batch": "#d97706", "roundtrip_batch_signal_len": "#ea580c",
"fft_radix4_outer": "#1d4ed8", "ifft_radix4_outer": "#047857", "roundtrip_radix4_outer": "#b91c1c",
"fft_batch_radix4_outer": "#4338ca", "ifft_batch_radix4_outer": "#0e7490", "roundtrip_batch_radix4_outer": "#b45309", }
_VS_PALETTE: dict[str, dict[str, str]] = {
"fft_batch_vs_sequential": {
"batch": "#7c3aed", "sequential": "#93c5fd", },
"ifft_batch_vs_sequential": {
"batch": "#0891b2", "sequential": "#6ee7b7", },
}
_MARKERS: dict[str, str] = {
"fft": "o", "ifft": "s", "roundtrip": "^", "fft_batch_batch_size": "D", "fft_batch_signal_len": "D",
"ifft_batch_batch_size": "P", "ifft_batch_signal_len": "P",
"roundtrip_batch": "v", "roundtrip_batch_signal_len": "v",
"fft_radix4_outer": "o",
"ifft_radix4_outer": "s",
"roundtrip_radix4_outer": "^",
"fft_batch_radix4_outer": "D",
"ifft_batch_radix4_outer": "P",
"roundtrip_batch_radix4_outer": "v",
}
_VS_MARKERS: dict[str, str] = {"batch": "D", "sequential": "o"}
_VS_LS: dict[str, str] = {"batch": "-", "sequential": "--"}
_FALLBACK_COLOR = "#6b7280" _FALLBACK_MARKER = "o"
_LABELS: dict[str, str] = {
"fft": "fft",
"ifft": "ifft",
"roundtrip": "roundtrip",
"fft_batch_batch_size": "fft_batch (sweep batch)",
"fft_batch_signal_len": "fft_batch (batch=16)",
"ifft_batch_batch_size": "ifft_batch (sweep batch)",
"ifft_batch_signal_len": "ifft_batch (batch=16)",
"roundtrip_batch": "roundtrip_batch (sweep batch)",
"roundtrip_batch_signal_len": "roundtrip_batch (batch=16)",
"fft_radix4_outer": "fft (radix-4 outer, scalar)",
"ifft_radix4_outer": "ifft (radix-4 outer, scalar)",
"roundtrip_radix4_outer": "roundtrip (radix-4 outer, scalar)",
"fft_batch_radix4_outer": "fft_batch (radix-4 outer, batch=16)",
"ifft_batch_radix4_outer": "ifft_batch (radix-4 outer, batch=16)",
"roundtrip_batch_radix4_outer": "roundtrip_batch (radix-4 outer, batch=16)",
}
_SIGNAL_LEN_GROUPS = frozenset({
"fft", "ifft", "roundtrip",
"fft_batch_signal_len",
"ifft_batch_signal_len",
"roundtrip_batch_signal_len",
"fft_radix4_outer",
"ifft_radix4_outer",
"roundtrip_radix4_outer",
"fft_batch_radix4_outer",
"ifft_batch_radix4_outer",
"roundtrip_batch_radix4_outer",
})
_BATCH_SIZE_GROUPS = frozenset({
"fft_batch_batch_size",
"ifft_batch_batch_size",
"roundtrip_batch",
})
_VS_GROUPS = frozenset({
"fft_batch_vs_sequential",
"ifft_batch_vs_sequential",
})
def _fmt_time(ns: float) -> str:
if ns < 1_000:
return f"{ns:.2f} ns"
if ns < 1_000_000:
return f"{ns / 1_000:.2f} µs"
if ns < 1_000_000_000:
return f"{ns / 1_000_000:.2f} ms"
return f"{ns / 1_000_000_000:.2f} s"
def _fmt_throughput(elem_per_s: float) -> str:
if elem_per_s < 1_000:
return f"{elem_per_s:.2f} elem/s"
if elem_per_s < 1_000_000:
return f"{elem_per_s / 1_000:.2f} Kelem/s"
if elem_per_s < 1_000_000_000:
return f"{elem_per_s / 1_000_000:.2f} Melem/s"
return f"{elem_per_s / 1_000_000_000:.2f} Gelem/s"
def collect_results(criterion_dir: Path) -> list[dict]:
    """Gather one result row per Criterion benchmark found under *criterion_dir*.

    The directory tree is walked depth-first in sorted order by ``_walk``;
    the layout of each row is defined by ``_parse_leaf``.
    """
    results: list[dict] = []
    _walk(root=criterion_dir, node=criterion_dir, rows=results)
    return results


def _walk(root: Path, node: Path, rows: list[dict]) -> None:
if (node / "new" / "estimates.json").exists():
_parse_leaf(root, node, rows)
return
for child in sorted(node.iterdir()):
if child.is_dir() and child.name not in ("report", "new"):
_walk(root, child, rows)
def _parse_leaf(root: Path, node: Path, rows: list[dict]) -> None:
est_file = node / "new" / "estimates.json"
bm_file = node / "new" / "benchmark.json"
if not (est_file.exists() and bm_file.exists()):
return
rel = node.relative_to(root)
parts = rel.parts depth = len(parts)
if depth < 2:
return
if depth == 2:
group, raw_param = parts[0], parts[1]
sub_series: str | None = None
elif depth == 3:
group, sub_series, raw_param = parts[0], parts[1], parts[2]
else:
group, raw_param, sub_series = "/".join(parts[:-1]), parts[-1], None
try:
n = int(raw_param)
except ValueError:
n = 0
est = json.loads(est_file.read_text())
bm = json.loads(bm_file.read_text())
mean_ns = est["mean"]["point_estimate"]
lo_ns = est["mean"]["confidence_interval"]["lower_bound"]
hi_ns = est["mean"]["confidence_interval"]["upper_bound"]
std_ns = est["std_dev"]["point_estimate"]
throughput_str = ""
throughput_mels = 0.0
tp = bm.get("throughput")
if tp and "Elements" in tp:
elem_per_s = tp["Elements"] / (mean_ns / 1e9)
throughput_str = _fmt_throughput(elem_per_s)
throughput_mels = elem_per_s / 1e6
rows.append(dict(
group=group, raw_param=raw_param, sub_series=sub_series, n=n,
mean=mean_ns, lo=lo_ns, hi=hi_ns, std=std_ns,
throughput=throughput_str,
throughput_mels=throughput_mels,
))
def _apply_style() -> None:
    """Install the report's light, minimal look into matplotlib's global rcParams."""
    canvas = {
        "figure.facecolor": "white",
        "axes.facecolor": "white",
        "axes.edgecolor": "#d1d5db",
        "axes.linewidth": 0.8,
    }
    grid = {
        "grid.color": "#e5e7eb",
        "grid.linewidth": 0.7,
        "grid.linestyle": "--",
    }
    text = {
        "font.family": "sans-serif",
        "font.size": 11,
        "axes.titlesize": 13,
        "axes.titleweight": "bold",
        "axes.labelsize": 11,
        "axes.labelcolor": "#374151",
        "xtick.color": "#6b7280",
        "ytick.color": "#6b7280",
    }
    legend = {
        "legend.fontsize": 9,
        "legend.framealpha": 0.95,
        "legend.edgecolor": "#d1d5db",
    }
    lines = {
        "lines.linewidth": 2.0,
        "lines.markersize": 6,
    }
    plt.rcParams.update({**canvas, **grid, **text, **legend, **lines})


def _x_formatter(x, _):
return f"{int(x):,}"
def _build_series(rows: list[dict], groups: list[str]) -> dict[str, dict]:
series: dict[str, dict] = {}
for r in rows:
g = r["group"]
if g not in groups or r["sub_series"] is not None:
continue
if g not in series:
series[g] = {"n": [], "mean_us": [], "lo_us": [], "hi_us": [],
"throughput_mels": []}
series[g]["n"].append(r["n"])
series[g]["mean_us"].append(r["mean"] / 1_000)
series[g]["lo_us"].append(r["lo"] / 1_000)
series[g]["hi_us"].append(r["hi"] / 1_000)
series[g]["throughput_mels"].append(r["throughput_mels"])
for d in series.values():
order = sorted(range(len(d["n"])), key=lambda i: d["n"][i])
for k in d:
d[k] = [d[k][i] for i in order]
return series
def _line_chart(
    series: dict[str, dict],
    filename: str,
    title: str,
    xlabel: str,
    ylabel: str,
    metric: str,
    xscale: str = "log2",
    scalar_groups: frozenset[str] = frozenset(),
) -> Path:
    """Draw one line per series and save the figure as SVG in ``CHARTS_DIR``.

    Args:
        series: Group name → column dict as built by ``_build_series``
            (keys ``n``, ``mean_us``, ``lo_us``, ``hi_us``,
            ``throughput_mels``). Series are drawn in sorted name order.
        filename: Output file name inside ``CHARTS_DIR``.
        title: Axes title.
        xlabel: X-axis label.
        ylabel: Y-axis label.
        metric: Column plotted on the y axis. For ``"mean_us"`` the
            ``lo_us``..``hi_us`` band is shaded behind each line as well.
        xscale: ``"log2"`` gives a base-2 logarithmic x axis; any other
            value leaves matplotlib's default (linear) scale.
        scalar_groups: When non-empty, groups in this set are drawn solid
            and all other groups dashed; when empty every line is solid.

    Returns:
        Path of the written SVG file.
    """
    fig, ax = plt.subplots(figsize=(8, 4.5))
    try:
        for gname in sorted(series):
            d = series[gname]
            color = _PALETTE.get(gname, _FALLBACK_COLOR)
            dashed = bool(scalar_groups) and gname not in scalar_groups
            ax.plot(
                d["n"], d[metric],
                marker=_MARKERS.get(gname, _FALLBACK_MARKER),
                color=color,
                label=_LABELS.get(gname, gname),
                linestyle="--" if dashed else "-",
                zorder=3,
            )
            if metric == "mean_us":
                # Shade the lo/hi (confidence-interval) band behind the line.
                ax.fill_between(d["n"], d["lo_us"], d["hi_us"],
                                alpha=0.12, color=color, zorder=2)

        if xscale == "log2":
            ax.set_xscale("log", base=2)
        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.xaxis.set_major_formatter(ticker.FuncFormatter(_x_formatter))
        if metric == "mean_us":
            # "g" keeps fractional ticks distinct: the former ".0f" printed
            # 0.5 as "0" and 1.5 and 2.5 both as "2", giving duplicate labels
            # for sub-microsecond ranges. "," groups thousands like the x axis.
            ax.yaxis.set_major_formatter(
                ticker.FuncFormatter(lambda y, _: f"{y:,g}")
            )
        ax.set_ylim(bottom=0)
        ax.grid(True, which="both")
        ax.legend()
        fig.tight_layout()
        path = CHARTS_DIR / filename
        fig.savefig(path, format="svg", bbox_inches="tight")
    finally:
        # Release the figure even if drawing or saving fails.
        plt.close(fig)
    return path


def _vs_chart(rows: list[dict], group_names: list[str]) -> Path:
    """Draw side-by-side "batch vs sequential" throughput panels, one per group.

    Panels follow the sorted group names. In each panel the group's rows that
    carry a ``sub_series`` are split by that name and plotted as throughput
    against batch size on a base-2 log x axis. The figure is saved as
    ``vs_sequential.svg`` in ``CHARTS_DIR``.

    Returns:
        Path of the written SVG file.
    """
    count = len(group_names)
    # squeeze=False always returns a 2-D array of Axes, even for one panel.
    fig, axes_grid = plt.subplots(
        1, count, figsize=(6.5 * count, 4.5), sharey=False, squeeze=False
    )
    for ax, gname in zip(axes_grid[0], sorted(group_names)):
        by_series: dict[str, list[dict]] = {}
        for row in rows:
            if row["group"] == gname and row["sub_series"] is not None:
                by_series.setdefault(row["sub_series"], []).append(row)

        palette = _VS_PALETTE.get(gname, {})
        for name in sorted(by_series):
            points = sorted(by_series[name], key=lambda row: row["n"])
            ax.plot(
                [p["n"] for p in points],
                [p["throughput_mels"] for p in points],
                marker=_VS_MARKERS.get(name, _FALLBACK_MARKER),
                color=palette.get(name, _FALLBACK_COLOR),
                label=name,
                linestyle=_VS_LS.get(name, "-"),
                linewidth=2,
                zorder=3,
            )

        ax.set_xscale("log", base=2)
        ax.set_xlabel("Batch size")
        ax.set_ylabel("Throughput (Melem/s)")
        transform = gname.replace("_batch_vs_sequential", "").upper()
        ax.set_title(f"{transform}: batch vs sequential")
        ax.xaxis.set_major_formatter(ticker.FuncFormatter(_x_formatter))
        ax.set_ylim(bottom=0)
        ax.grid(True, which="both")
        ax.legend()

    fig.suptitle("Batch vs Sequential Throughput", fontsize=13,
                 fontweight="bold", y=1.02)
    fig.tight_layout()
    out_path = CHARTS_DIR / "vs_sequential.svg"
    fig.savefig(out_path, format="svg", bbox_inches="tight")
    plt.close(fig)
    return out_path


def generate_charts(rows: list[dict]) -> dict[str, Path]:
    """Write every chart whose benchmark groups occur in *rows*.

    Charts (key → file), each produced only when its groups are present:

    * ``latency`` / ``throughput`` — scalar fft/ifft/roundtrip vs N
    * ``batch_signal`` — all signal-length groups, scalar baselines solid and
      the others dashed (only when a non-baseline group is present)
    * ``batch_size`` — batch-size sweeps
    * ``vs_sequential`` — batch vs sequential panels
    * ``radix4_outer`` / ``radix4_batch_outer`` — radix-4 outer-stage groups

    Returns:
        Mapping from chart key to the written SVG path, in the order above.
    """
    CHARTS_DIR.mkdir(parents=True, exist_ok=True)
    _apply_style()

    baseline = frozenset({"fft", "ifft", "roundtrip"})
    present = {r["group"] for r in rows}
    sig_x = "Signal length N"
    tput_y = "Throughput (Melem/s)"
    charts: dict[str, Path] = {}

    def pick(candidates) -> dict[str, dict]:
        # Series for those candidate groups that actually occur in *rows*.
        return _build_series(rows, sorted(frozenset(candidates) & present))

    base = pick(_SIGNAL_LEN_GROUPS & baseline)
    if base:
        charts["latency"] = _line_chart(
            base, "latency.svg",
            "GPU FFT/IFFT — Latency vs Signal Length",
            sig_x, "Latency (µs)", metric="mean_us",
        )
        charts["throughput"] = _line_chart(
            base, "throughput.svg",
            "GPU FFT/IFFT — Throughput vs Signal Length",
            sig_x, tput_y, metric="throughput_mels",
        )

    by_len = pick(_SIGNAL_LEN_GROUPS)
    if any(g not in baseline for g in by_len):
        charts["batch_signal"] = _line_chart(
            by_len, "batch_signal.svg",
            "Throughput vs Signal Length — scalar vs batch×16",
            sig_x, tput_y, metric="throughput_mels",
            scalar_groups=baseline,
        )

    by_batch = pick(_BATCH_SIZE_GROUPS)
    if by_batch:
        charts["batch_size"] = _line_chart(
            by_batch, "batch_size.svg",
            "Batch FFT/IFFT — Throughput vs Batch Size",
            "Batch size", tput_y, metric="throughput_mels",
        )

    vs_groups = sorted(_VS_GROUPS & present)
    if vs_groups:
        charts["vs_sequential"] = _vs_chart(rows, vs_groups)

    r4_scalar = pick({"fft_radix4_outer", "ifft_radix4_outer",
                      "roundtrip_radix4_outer"})
    if r4_scalar:
        charts["radix4_outer"] = _line_chart(
            r4_scalar, "radix4_outer.svg",
            "Radix-4 Outer Stages — Scalar Throughput vs N",
            sig_x, tput_y, metric="throughput_mels",
        )

    r4_batch = pick({"fft_batch_radix4_outer", "ifft_batch_radix4_outer",
                     "roundtrip_batch_radix4_outer"})
    if r4_batch:
        charts["radix4_batch_outer"] = _line_chart(
            r4_batch, "radix4_batch_outer.svg",
            "Radix-4 Outer Stages — Batch Throughput vs N (batch=16)",
            sig_x, tput_y, metric="throughput_mels",
        )

    return charts


def _git(args: list[str]) -> str:
try:
return subprocess.check_output(
["git"] + args, text=True, stderr=subprocess.DEVNULL
).strip()
except Exception:
return "unknown"
_CHART_META: dict[str, tuple[str, str]] = {
"latency": ("Scalar baselines", "Latency vs N"),
"throughput": ("Scalar baselines", "Throughput vs N"),
"batch_signal": ("Batch vs scalar (batch=16)", "Throughput vs N"),
"batch_size": ("Batch size sweep (N=4 096 fixed)","Throughput vs batch size"),
"vs_sequential": ("Batch vs sequential", "Batch vs sequential throughput"),
"radix4_outer": ("Radix-4 outer stages — scalar", "Throughput vs N (outer-stage sizes)"),
"radix4_batch_outer": ("Radix-4 outer stages — batch", "Throughput vs N (outer-stage sizes, batch=16)"),
}
_TABLE_SECTIONS: list[tuple[str, frozenset[str]]] = [
("Scalar", frozenset({"fft", "ifft", "roundtrip"})),
("Batch FFT", frozenset({"fft_batch_batch_size", "fft_batch_signal_len",
"fft_batch_vs_sequential"})),
("Batch IFFT", frozenset({"ifft_batch_batch_size", "ifft_batch_signal_len",
"ifft_batch_vs_sequential"})),
("Batch round-trip", frozenset({"roundtrip_batch", "roundtrip_batch_signal_len"})),
("Radix-4 outer — scalar",
frozenset({"fft_radix4_outer", "ifft_radix4_outer",
"roundtrip_radix4_outer"})),
("Radix-4 outer — batch",
frozenset({"fft_batch_radix4_outer", "ifft_batch_radix4_outer",
"roundtrip_batch_radix4_outer"})),
]
def render(rows: list[dict], raw: str, chart_paths: dict[str, Path] | None) -> str:
    """Render the benchmark report as a Markdown document.

    Sections, in order:

    * a header table with the current local time and git commit/branch;
    * "Charts" — only when *chart_paths* is non-empty;
    * "Summary" — one table row per result row;
    * "Raw Output" — a collapsible block with *raw* from the first cargo
      ``Running ... bench`` line onward, only when such a line exists.

    Args:
        rows: Result rows as produced by ``collect_results``.
        raw: Captured ``cargo bench`` console output (``""`` when none).
        chart_paths: Chart key → SVG path, or ``None``/empty to omit charts.

    Returns:
        The Markdown text, terminated by a single newline.
    """
    now = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    commit = _git(["rev-parse", "--short", "HEAD"])
    branch = _git(["rev-parse", "--abbrev-ref", "HEAD"])
    lines = [
        "# Benchmark Results",
        "",
        "| | |",
        "|---|---|",
        f"| **Date** | {now} |",
        f"| **Commit** | `{commit}` ({branch}) |",
        "",
    ]
    if chart_paths:
        lines += _render_charts(chart_paths)
    lines += _render_summary(rows)
    lines += _render_raw(raw)
    return "\n".join(lines) + "\n"


def _render_charts(chart_paths: dict[str, Path]) -> list[str]:
    """Return the Markdown "Charts" section for *chart_paths*.

    When both exist, the latency/throughput pair is placed side by side in a
    table. The other known chart keys follow in a fixed order, then any
    remaining keys in *chart_paths* order. Headings and alt texts come from
    ``_CHART_META``.
    """
    bench_results = ROOT / "bench-results"

    def img(key: str) -> str:
        # Markdown image link relative to bench-results/ (presumably where
        # the report is saved — TODO confirm). as_posix() keeps "/"
        # separators on Windows, where str(Path) would emit "\" and break
        # the link.
        _, alt = _CHART_META.get(key, ("", key))
        rel = chart_paths[key].relative_to(bench_results).as_posix()
        return f"![{alt}]({rel})"

    out = ["## Charts", ""]
    emitted: set[str] = set()
    if "latency" in chart_paths and "throughput" in chart_paths:
        out += [
            "### Scalar baselines",
            "",
            "| Latency | Throughput |",
            "|---------|------------|",
            f"| {img('latency')} | {img('throughput')} |",
            "",
        ]
        emitted |= {"latency", "throughput"}

    ordered = ("batch_signal", "batch_size", "vs_sequential",
               "radix4_outer", "radix4_batch_outer")
    keys = [k for k in ordered if k in chart_paths]
    keys += [k for k in chart_paths if k not in emitted and k not in ordered]
    for key in keys:
        heading, _ = _CHART_META.get(key, (key, key))
        out += [f"### {heading}", "", img(key), ""]
    return out


def _render_summary(rows: list[dict]) -> list[str]:
    """Return the Markdown "Summary" table with one line per result row.

    Groups are emitted section by section following ``_TABLE_SECTIONS``
    (alphabetically within a section), then every remaining group
    alphabetically as its own block. An empty spacer row separates
    consecutive blocks; none precedes the first block. Within a group, rows
    are ordered by ``(sub_series, n)``.
    """
    out = [
        "## Summary",
        "",
        "| Benchmark | Param | Mean | 95% CI | Std dev | Throughput |",
        "|-----------|------:|-----:|--------|--------:|------------|",
    ]
    by_group: dict[str, list[dict]] = {}
    for r in rows:
        by_group.setdefault(r["group"], []).append(r)
    for members in by_group.values():
        members.sort(key=lambda r: (r["sub_series"] or "", r["n"]))

    blocks = [
        [g for g in sorted(groups) if g in by_group]
        for _, groups in _TABLE_SECTIONS
    ]
    listed = {g for block in blocks for g in block}
    blocks += [[g] for g in sorted(by_group) if g not in listed]

    first = True
    for block in blocks:
        if not block:
            continue
        if not first:
            out.append("| | | | | | |")
        first = False
        for g in block:
            for r in by_group[g]:
                mean = _fmt_time(r["mean"])
                ci = f"[{_fmt_time(r['lo'])} … {_fmt_time(r['hi'])}]"
                std = _fmt_time(r["std"])
                param = r["raw_param"]
                if r["sub_series"]:
                    param = f"{r['sub_series']} {param}"
                out.append(
                    f"| {r['group']} | {param:>12} | {mean:>10} | {ci}"
                    f" | {std:>10} | {r['throughput']} |"
                )
    return out


def _render_raw(raw: str) -> list[str]:
    """Return the collapsible "Raw Output" section, or ``[]`` if nothing matches.

    Capture starts at the first line whose text, ignoring leading whitespace,
    begins with ``Running`` and that contains ``bench``. Capture continues to
    the end of *raw*.
    """
    captured: list[str] = []
    capturing = False
    for line in raw.splitlines():
        # Cargo prints its status word right-aligned (e.g. "     Running
        # benches/..."), so match after stripping indentation instead of
        # requiring a fixed number of leading spaces.
        if not capturing and line.lstrip().startswith("Running") and "bench" in line:
            capturing = True
        if capturing:
            captured.append(line)
    if not captured:
        return []
    return [
        "",
        "## Raw Output",
        "",
        "<details>",
        "<summary>expand</summary>",
        "",
        "```",
        *captured,
        "```",
        "",
        "</details>",
    ]


def main() -> None:
    """Script entry point: collect results, draw charts, print Markdown to stdout.

    Exits with a message when ``ROOT/target/criterion`` does not exist.
    Output piped in on stdin (e.g. from ``cargo bench``) is read only when
    stdin is not a TTY, and is passed to ``render``. Chart progress and
    warnings go to stderr.
    """
    criterion_dir = ROOT / "target" / "criterion"
    if not criterion_dir.exists():
        sys.exit(
            "No Criterion results found.\n"
            "Run `cargo bench --features wgpu` first, then re-run this script."
        )

    results = collect_results(criterion_dir)
    piped = sys.stdin.read() if not sys.stdin.isatty() else ""

    charts: dict[str, Path] | None = None
    if not HAS_MATPLOTLIB:
        print(
            "Warning: matplotlib not found — skipping chart generation.",
            file=sys.stderr,
        )
    else:
        charts = generate_charts(results)
        for key, path in charts.items():
            print(f"✓ Chart ({key:<14}) → {path}", file=sys.stderr)

    print(render(results, piped, charts), end="")


if __name__ == "__main__":
    main()