from __future__ import annotations
import argparse
import json
import math
import os
import shlex
import shutil
import subprocess
import sys
from collections import Counter
from dataclasses import dataclass
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
# Repository root, derived from this script's location (assumed to live in scripts/).
REPO_ROOT = Path(__file__).resolve().parent.parent
# Default prompt suite consumed by the benchmark CLI.
DEFAULT_PROMPTS = REPO_ROOT / "scripts" / "prompts" / "real_model_eval_prompts.jsonl"
# Default output roots for exported ONNX bundles and evaluation reports.
DEFAULT_BUNDLE_ROOT = REPO_ROOT / "artifacts" / "real-model-bundles"
DEFAULT_REPORT_ROOT = REPO_ROOT / "artifacts" / "real-model-evals"
# Preset handed to scripts/export_hf_decoder_onnx.py when no --model is given.
DEFAULT_PRESET = "smollm2-135m-instruct"
# Dedicated virtualenv (and its requirements file) for the HF/Optimum export toolchain.
DEFAULT_EXPORT_VENV = REPO_ROOT / ".venv-real-model-export"
REQUIREMENTS_REAL_MODEL = REPO_ROOT / "scripts" / "requirements-real-model.txt"
# File names accepted as the exported decoder model inside a bundle directory.
MODEL_CANDIDATES = ("model.onnx", "decoder_model_merged.onnx", "decoder_model.onnx")
# Modules that must be importable in the export interpreter for export to work.
EXPORT_REQUIRED_MODULES = (
"optimum.exporters.onnx",
"transformers",
"torch",
"numpy",
"safetensors",
)
# Interpreters probed (in order) when bootstrapping the export virtualenv:
# Homebrew absolute paths first, then bare names resolved via PATH.
PREFERRED_EXPORT_PYTHONS = (
"/opt/homebrew/bin/python3.12",
"/opt/homebrew/bin/python3.11",
"/opt/homebrew/bin/python3.10",
"python3.12",
"python3.11",
"python3.10",
)
@dataclass(frozen=True)
class PromptRecord:
    """One benchmark prompt, optionally tagged with an id and a category."""
    prompt: str  # prompt text passed to the benchmark CLI
    prompt_id: str | None = None  # stable id from the JSONL "id" field, when present
    category: str | None = None  # grouping label from the JSONL "category" field
@dataclass(frozen=True)
class QuantizationRun:
    """A single quantization configuration: key strategy plus key/value bit widths."""

    strategy: str  # key quantization strategy, e.g. "prod" or "mse"
    key_bits: int
    value_bits: int

    @property
    def slug(self) -> str:
        """Filesystem-safe identifier used in raw report file names."""
        return "-".join((self.strategy, f"k{self.key_bits}", f"v{self.value_bits}"))

    @property
    def label(self) -> str:
        """Human-readable name used in tables and log output."""
        return f"{self.strategy} {self.key_bits}/{self.value_bits}-bit"
def parse_args() -> argparse.Namespace:
    """Parse and return the CLI arguments for the real-model evaluation run."""
    # NOTE(review): description=__doc__ — no module docstring is visible above
    # the imports in this chunk, so the help description may be None; confirm.
    parser = argparse.ArgumentParser(description=__doc__)
    # --- model / bundle selection -------------------------------------------
    parser.add_argument(
        "--preset",
        default=DEFAULT_PRESET,
        help=(
            "Named preset to export with scripts/export_hf_decoder_onnx.py. "
            f"Defaults to {DEFAULT_PRESET!r}."
        ),
    )
    parser.add_argument(
        "--model",
        help="Explicit Hugging Face model id. Overrides --preset when exporting.",
    )
    parser.add_argument(
        "--model-dir",
        type=Path,
        help="Use an existing exported ONNX bundle instead of exporting one.",
    )
    parser.add_argument(
        "--bundle-dir",
        type=Path,
        help=(
            "Directory for the exported ONNX bundle. Defaults to "
            "artifacts/real-model-bundles/<preset-or-model>."
        ),
    )
    parser.add_argument(
        "--output-dir",
        type=Path,
        help=(
            "Directory to write raw benchmark JSON and summary files into. "
            "Defaults to artifacts/real-model-evals/<timestamp>-<model>."
        ),
    )
    # --- force / skip toggles ------------------------------------------------
    parser.add_argument(
        "--force-output",
        action="store_true",
        help="Allow replacing an existing output directory.",
    )
    parser.add_argument(
        "--force-export",
        action="store_true",
        help="Re-export the model bundle even if the bundle directory already exists.",
    )
    parser.add_argument(
        "--skip-export",
        action="store_true",
        help="Require the ONNX bundle to already exist instead of exporting it.",
    )
    parser.add_argument(
        "--skip-build",
        action="store_true",
        help="Use an existing benchmark binary instead of building it first.",
    )
    # --- export toolchain bootstrap -----------------------------------------
    parser.add_argument(
        "--export-python",
        type=Path,
        help=(
            "Python interpreter to use for Hugging Face / Optimum export setup. "
            "If omitted, the script prefers a stable 3.10-3.12 interpreter."
        ),
    )
    parser.add_argument(
        "--export-venv",
        type=Path,
        default=DEFAULT_EXPORT_VENV,
        help=(
            "Virtual environment path for the export toolchain. "
            "Defaults to .venv-real-model-export in the repo root."
        ),
    )
    parser.add_argument(
        "--skip-export-bootstrap",
        action="store_true",
        help="Do not create/install the dedicated export virtualenv automatically.",
    )
    parser.add_argument(
        "--cargo-profile",
        choices=["release", "debug"],
        default="release",
        help="Rust profile to use for the benchmark example.",
    )
    # --- benchmark knobs -----------------------------------------------------
    parser.add_argument(
        "--prompts",
        type=Path,
        default=DEFAULT_PROMPTS,
        help="Prompt file consumed by the benchmark CLI. JSONL and plain text are both supported.",
    )
    parser.add_argument(
        "--max-prompts",
        type=int,
        default=6,
        help="Maximum number of prompts to benchmark from the prompt file.",
    )
    parser.add_argument(
        "--max-new-tokens",
        type=int,
        default=24,
        help="Maximum number of decode steps per prompt.",
    )
    parser.add_argument(
        "--top-k",
        type=int,
        default=5,
        help="Top-k used when comparing exact vs quantized logits.",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="Benchmark RNG seed.",
    )
    # --- quantization sweep configuration ------------------------------------
    parser.add_argument(
        "--strategies",
        nargs="+",
        choices=["prod", "mse"],
        default=["prod", "mse"],
        help="Key quantization strategies to compare.",
    )
    parser.add_argument(
        "--bits",
        nargs="+",
        type=int,
        default=[2, 4, 8],
        help="Key bit widths to compare.",
    )
    parser.add_argument(
        "--value-bits",
        type=int,
        help="Value bit width for all runs. Defaults to matching the key bit width.",
    )
    # --- quality gates for recommendations ------------------------------------
    parser.add_argument(
        "--min-token-match",
        type=float,
        default=0.90,
        help="Quality gate used when picking speed/compression recommendations.",
    )
    parser.add_argument(
        "--min-top-k-agreement",
        type=float,
        default=0.80,
        help="Quality gate used when picking speed/compression recommendations.",
    )
    parser.add_argument(
        "--notes",
        help="Optional free-form note to include in the summary report.",
    )
    return parser.parse_args()
def main() -> int:
    """Run the full evaluation pipeline and return the process exit code.

    Stages: validate CLI args, load prompts, prepare the output directory,
    export (or reuse) the ONNX bundle, build the benchmark binary, run the
    exact baseline plus one compare run per (strategy, bits) pair, then write
    summary.json / summary.md.
    """
    args = parse_args()
    validate_args(args)
    prompts = load_prompts(args.prompts, args.max_prompts)
    output_dir = prepare_output_dir(resolve_output_dir(args), args.force_output)
    raw_dir = output_dir / "raw"
    raw_dir.mkdir(parents=True, exist_ok=True)
    copy_prompt_suite(args.prompts, output_dir)
    bundle_dir = resolve_bundle_dir(args)
    reused_existing_bundle = False
    if args.model_dir is None:
        export_python = prepare_export_python(args)
        reused_existing_bundle = ensure_bundle_exported(args, bundle_dir, export_python)
    else:
        # An explicit --model-dir always counts as reusing an existing bundle.
        reused_existing_bundle = True
    validate_bundle_dir(bundle_dir)
    benchmark_binary = build_benchmark_binary(args)
    benchmark_commands: list[list[str]] = []
    # Exact (unquantized) baseline first; every compare run is measured
    # relative to this row.
    exact_command = benchmark_command(
        benchmark_binary,
        bundle_dir,
        args,
        mode="exact",
    )
    benchmark_commands.append(exact_command)
    exact_report = run_json_command(exact_command, raw_dir / "exact.json")
    exact_row = expect_single_row(exact_report, "exact")
    # Cartesian product of strategies x key bit widths; value bits default to
    # matching the key bits unless --value-bits was given.
    compare_runs = [
        QuantizationRun(
            strategy=strategy,
            key_bits=bits,
            value_bits=args.value_bits if args.value_bits is not None else bits,
        )
        for strategy in args.strategies
        for bits in args.bits
    ]
    compare_results: list[dict[str, Any]] = []
    for quantization in compare_runs:
        compare_command = benchmark_command(
            benchmark_binary,
            bundle_dir,
            args,
            mode="compare",
            quantization=quantization,
        )
        benchmark_commands.append(compare_command)
        compare_report = run_json_command(
            compare_command,
            raw_dir / f"compare-{quantization.slug}.json",
        )
        compare_row = expect_single_row(compare_report, quantization.label)
        compare_results.append(
            augment_compare_row(
                compare_row,
                exact_row=exact_row,
                quantization=quantization,
                raw_report_path=raw_dir / f"compare-{quantization.slug}.json",
                min_token_match=args.min_token_match,
                min_top_k_agreement=args.min_top_k_agreement,
            )
        )
    # Best balanced configuration first.
    compare_results.sort(key=lambda row: row["balanced_score"], reverse=True)
    summary = build_summary(
        args=args,
        bundle_dir=bundle_dir,
        output_dir=output_dir,
        prompts=prompts,
        exact_row=exact_row,
        compare_results=compare_results,
        benchmark_binary=benchmark_binary,
        benchmark_commands=benchmark_commands,
        reused_existing_bundle=reused_existing_bundle,
    )
    write_json(output_dir / "summary.json", summary)
    write_text(output_dir / "summary.md", build_markdown_summary(summary))
    print()
    print(f"Summary written to {output_dir / 'summary.md'}")
    print(f"Raw benchmark JSON saved under {raw_dir}")
    if compare_results:
        best = compare_results[0]
        print(
            "Best balanced configuration: "
            f"{best['label']} "
            f"(token_match={format_ratio(best.get('token_match_rate'))}, "
            f"top_k={format_ratio(best.get('top_k_agreement'))}, "
            f"compression={format_float(best.get('compression_ratio'), 2)}x)"
        )
    return 0
def validate_args(args: argparse.Namespace) -> None:
    """Fail fast (SystemExit) on contradictory or out-of-range CLI options.

    Fix: the original ended with a dead ``if args.export_python is not None
    and args.skip_export_bootstrap: pass`` branch — a no-op that implied a
    missing check. The combination is legal (an explicit interpreter simply
    must already carry the export dependencies, which is verified later in
    prepare_export_python), so the dead branch is removed.
    """
    if args.model_dir is not None and args.force_export:
        raise SystemExit("--force-export cannot be used together with --model-dir")
    if args.model_dir is not None and args.skip_export:
        raise SystemExit("--skip-export is redundant with --model-dir; pass only --model-dir")
    # Counters that must be strictly positive (messages match the flag names).
    positive_options = (
        ("--max-prompts", args.max_prompts),
        ("--max-new-tokens", args.max_new_tokens),
        ("--top-k", args.top_k),
    )
    for name, value in positive_options:
        if value <= 0:
            raise SystemExit(f"{name} must be positive")
    if not args.bits:
        raise SystemExit("pass at least one value to --bits")
    if any(bits <= 0 for bits in args.bits):
        raise SystemExit("--bits values must be positive")
    if args.value_bits is not None and args.value_bits <= 0:
        raise SystemExit("--value-bits must be positive")
    # Quality gates are interpreted as rates and must lie inside [0, 1].
    for name, gate in (
        ("--min-token-match", args.min_token_match),
        ("--min-top-k-agreement", args.min_top_k_agreement),
    ):
        if not (0.0 <= gate <= 1.0):
            raise SystemExit(f"{name} must be between 0 and 1")
def load_prompts(path: Path, max_prompts: int) -> list[PromptRecord]:
    """Load up to *max_prompts* prompts from *path*.

    Files with a .jsonl suffix are parsed line by line as JSON objects (the
    prompt text may live in any of prompt/text/input/question); every other
    file is treated as plain text, one prompt per non-blank line. Raises
    SystemExit on a missing file, malformed JSONL, or an empty suite.
    """
    if not path.exists():
        raise SystemExit(f"prompt file does not exist: {path}")
    lines = path.read_text(encoding="utf-8").splitlines()
    records: list[PromptRecord] = []
    if path.suffix == ".jsonl":
        for line_number, raw in enumerate(lines, start=1):
            if not raw.strip():
                continue
            try:
                payload = json.loads(raw)
            except json.JSONDecodeError as exc:
                raise SystemExit(f"{path}:{line_number}: invalid JSONL: {exc}") from exc
            prompt = first_non_empty_string(
                payload,
                ("prompt", "text", "input", "question"),
            )
            if prompt is None:
                raise SystemExit(
                    f"{path}:{line_number}: expected a non-empty prompt/text/input/question field"
                )
            records.append(
                PromptRecord(
                    prompt=prompt,
                    prompt_id=string_or_none(payload.get("id")),
                    category=string_or_none(payload.get("category")),
                )
            )
            if len(records) >= max_prompts:
                break
    else:
        for raw in lines:
            stripped = raw.strip()
            if not stripped:
                continue
            records.append(PromptRecord(prompt=stripped))
            if len(records) >= max_prompts:
                break
    if not records:
        raise SystemExit(f"no prompts found in {path}")
    return records
def first_non_empty_string(payload: dict[str, Any], keys: tuple[str, ...]) -> str | None:
    """Return the first stripped, non-empty string found under *keys*, else None."""
    for candidate in keys:
        raw = payload.get(candidate)
        if not isinstance(raw, str):
            continue
        stripped = raw.strip()
        if stripped:
            return stripped
    return None
def string_or_none(value: Any) -> str | None:
    """Return *value* stripped when it is a non-empty string, otherwise None."""
    if not isinstance(value, str):
        return None
    stripped = value.strip()
    return stripped if stripped else None
def copy_prompt_suite(source: Path, output_dir: Path) -> None:
    """Copy the prompt file (with metadata, via copy2) into *output_dir* as prompts<ext>."""
    target = output_dir / ("prompts" + source.suffix)
    shutil.copy2(source, target)
def resolve_bundle_dir(args: argparse.Namespace) -> Path:
    """Pick the bundle directory: --model-dir, then --bundle-dir, then a default slug path."""
    for explicit in (args.model_dir, args.bundle_dir):
        if explicit is not None:
            return explicit.resolve()
    chosen = args.model or args.preset
    return (DEFAULT_BUNDLE_ROOT / slugify(chosen)).resolve()
def resolve_output_dir(args: argparse.Namespace) -> Path:
    """Pick the report directory: --output-dir, or a timestamped default under the report root."""
    if args.output_dir is not None:
        return args.output_dir.resolve()
    if args.model_dir is not None:
        model_name = args.model_dir.name
    else:
        model_name = slugify(args.model or args.preset)
    stamp = datetime.now(timezone.utc).strftime("%Y%m%d-%H%M%SZ")
    return (DEFAULT_REPORT_ROOT / f"{stamp}-{model_name}").resolve()
def prepare_output_dir(path: Path, force_output: bool) -> Path:
    """Ensure *path* exists as a directory, wiping prior contents only with force_output.

    Raises SystemExit when *path* is a file, or when it is a non-empty
    directory and force_output is False.
    """
    if path.exists():
        if not path.is_dir():
            raise SystemExit(f"{path} exists and is not a directory")
        has_contents = any(path.iterdir())
        if has_contents and not force_output:
            raise SystemExit(
                f"{path} already exists and is not empty; pass --force-output to replace it"
            )
        if has_contents:
            shutil.rmtree(path)
    path.mkdir(parents=True, exist_ok=True)
    return path
def prepare_export_python(args: argparse.Namespace) -> Path:
    """Return a Python interpreter that has (or receives) the export dependencies.

    Resolution order:
      1. An explicit --export-python (installing requirements into it unless
         --skip-export-bootstrap forbids that).
      2. An already-provisioned export virtualenv at --export-venv.
      3. The current interpreter, if it is 3.10-3.12 and has the modules.
      4. A freshly bootstrapped virtualenv built from a preferred interpreter.

    Raises SystemExit when no usable interpreter can be found or bootstrapped.
    """
    requested = args.export_python.resolve() if args.export_python is not None else None
    export_venv = args.export_venv.resolve()
    if requested is not None:
        if not requested.exists():
            raise SystemExit(f"--export-python does not exist: {requested}")
        if python_has_required_modules(requested):
            return requested
        # The interpreter exists but lacks dependencies; install unless forbidden.
        if args.skip_export_bootstrap:
            raise SystemExit(
                "the requested --export-python is missing export dependencies; "
                f"install {REQUIREMENTS_REAL_MODEL} into that interpreter or remove "
                "--skip-export-bootstrap"
            )
        install_requirements(requested, REQUIREMENTS_REAL_MODEL)
        return requested
    venv_python = venv_python_path(export_venv)
    if venv_python.exists() and python_has_required_modules(venv_python):
        print(f"Reusing export virtualenv at {export_venv}")
        return venv_python
    # Fall back to the interpreter running this script when it qualifies.
    current_python = Path(sys.executable).resolve()
    if python_version_tuple(current_python) in {(3, 10), (3, 11), (3, 12)} and python_has_required_modules(current_python):
        return current_python
    if args.skip_export_bootstrap:
        raise SystemExit(
            "no ready export environment found. Pass --export-python with optimum/transformers "
            "installed, or rerun without --skip-export-bootstrap so the script can create "
            f"{export_venv} automatically."
        )
    base_python = find_bootstrap_python()
    if base_python is None:
        raise SystemExit(
            "could not find a suitable Python 3.10-3.12 interpreter for ONNX export bootstrap. "
            "Install one and rerun with --export-python /path/to/python3.10."
        )
    ensure_virtualenv(export_venv, base_python)
    venv_python = venv_python_path(export_venv)
    install_requirements(venv_python, REQUIREMENTS_REAL_MODEL)
    return venv_python
def ensure_bundle_exported(args: argparse.Namespace, bundle_dir: Path, export_python: Path) -> bool:
    """Export the ONNX bundle unless a valid one exists; return True when reused.

    With --skip-export, the bundle must already exist (SystemExit otherwise).
    With --force-export, any existing bundle directory is removed first.
    Returns False when a fresh export was performed.
    """
    if args.skip_export:
        if not bundle_dir.exists():
            raise SystemExit(
                f"--skip-export was set, but the bundle directory does not exist: {bundle_dir}"
            )
        return True
    if bundle_dir.exists() and has_valid_bundle_files(bundle_dir) and not args.force_export:
        print(f"Reusing existing ONNX bundle at {bundle_dir}")
        return True
    if bundle_dir.exists() and args.force_export:
        shutil.rmtree(bundle_dir)
    bundle_dir.parent.mkdir(parents=True, exist_ok=True)
    # Delegate the actual export to the companion script, run under the
    # dedicated export interpreter.
    export_command = [
        str(export_python),
        str(REPO_ROOT / "scripts" / "export_hf_decoder_onnx.py"),
    ]
    if args.model:
        export_command.extend(["--model", args.model])
    else:
        export_command.extend(["--preset", args.preset])
    export_command.extend(["--output-dir", str(bundle_dir), "--force"])
    print(f"Exporting ONNX bundle to {bundle_dir}")
    run_command(export_command)
    return False
def venv_python_path(venv_dir: Path) -> Path:
    """Interpreter path inside a venv (Windows and POSIX layouts differ)."""
    parts = ("Scripts", "python.exe") if os.name == "nt" else ("bin", "python")
    return venv_dir.joinpath(*parts)
def find_bootstrap_python() -> Path | None:
    """Return the first preferred interpreter on PATH reporting Python 3.10-3.12."""
    supported = {(3, 10), (3, 11), (3, 12)}
    for name in PREFERRED_EXPORT_PYTHONS:
        located = shutil.which(name)
        if located is None:
            continue
        candidate = Path(located).resolve()
        if python_version_tuple(candidate) in supported:
            return candidate
    return None
def python_version_tuple(python: Path) -> tuple[int, int] | None:
    """Return (major, minor) for *python*, or None when it cannot be probed.

    Fix: the original used check=True without a try/except, so a missing or
    broken interpreter raised (CalledProcessError/OSError) instead of
    returning the None its annotation promises — crashing
    find_bootstrap_python() mid-scan on the first bad candidate. Probe
    failures now yield None so callers can skip the candidate.
    """
    try:
        completed = subprocess.run(
            [str(python), "-c", "import sys; print(f'{sys.version_info[0]}.{sys.version_info[1]}')"],
            cwd=REPO_ROOT,
            text=True,
            check=True,
            capture_output=True,
        )
    except (OSError, subprocess.CalledProcessError):
        return None
    text = completed.stdout.strip()
    try:
        major, minor = text.split(".", 1)
        return int(major), int(minor)
    except ValueError:
        # Unparseable version output also counts as "could not probe".
        return None
def python_has_required_modules(python: Path) -> bool:
    """Probe *python* for the export dependencies; report missing modules on stderr."""
    # The probe prints missing module names (one per line) and exits non-zero
    # when any are absent.
    probe = (
        "import importlib.util, sys\n"
        f"modules = {EXPORT_REQUIRED_MODULES!r}\n"
        "missing = [name for name in modules if importlib.util.find_spec(name) is None]\n"
        "print('\\n'.join(missing))\n"
        "sys.exit(0 if not missing else 1)\n"
    )
    result = subprocess.run(
        [str(python), "-c", probe],
        cwd=REPO_ROOT,
        text=True,
        capture_output=True,
    )
    if result.returncode == 0:
        return True
    absent = result.stdout.strip().splitlines()
    if absent:
        print(
            f"Export dependencies missing for {python}: {', '.join(absent)}",
            file=sys.stderr,
        )
    return False
def ensure_virtualenv(venv_dir: Path, base_python: Path) -> None:
    """Create *venv_dir* with *base_python* unless its interpreter already exists."""
    if not venv_python_path(venv_dir).exists():
        print(f"Creating export virtualenv at {venv_dir} using {base_python}")
        run_command([str(base_python), "-m", "venv", str(venv_dir)])
def install_requirements(python: Path, requirements_path: Path) -> None:
    """Upgrade pip, then install the export requirements into *python*."""
    print(f"Installing export dependencies from {requirements_path} using {python}")
    pip_invocations = (
        ["install", "--upgrade", "pip"],
        ["install", "-r", str(requirements_path)],
    )
    for pip_args in pip_invocations:
        run_command([str(python), "-m", "pip", *pip_args])
def has_valid_bundle_files(bundle_dir: Path) -> bool:
    """True when *bundle_dir* holds config.json, tokenizer.json, and an ONNX model."""
    if not bundle_dir.is_dir():
        return False
    required = (bundle_dir / "config.json", bundle_dir / "tokenizer.json")
    if not all(item.exists() for item in required):
        return False
    return any((bundle_dir / name).exists() for name in MODEL_CANDIDATES)
def validate_bundle_dir(bundle_dir: Path) -> None:
    """Raise SystemExit unless *bundle_dir* looks like a complete export bundle."""
    if not bundle_dir.exists():
        raise SystemExit(f"ONNX bundle directory does not exist: {bundle_dir}")
    if not bundle_dir.is_dir():
        raise SystemExit(f"ONNX bundle path is not a directory: {bundle_dir}")
    for required in ("config.json", "tokenizer.json"):
        if not (bundle_dir / required).exists():
            raise SystemExit(f"missing {required} in bundle directory: {bundle_dir}")
    if not any((bundle_dir / name).exists() for name in MODEL_CANDIDATES):
        joined = ", ".join(MODEL_CANDIDATES)
        raise SystemExit(f"expected one of {joined} in bundle directory: {bundle_dir}")
def build_benchmark_binary(args: argparse.Namespace) -> Path:
    """Build (unless --skip-build) and return the path to the benchmark example binary.

    Fix: the original constructed the flag as f"--{args.cargo_profile}",
    producing ``cargo build --debug`` — which is not a valid cargo flag
    (debug is cargo's *default* profile and has no flag). Every
    ``--cargo-profile debug`` run therefore failed. Only release builds add
    a profile flag now; the output path (target/<profile>/examples/...) is
    unchanged for both profiles.
    """
    binary = REPO_ROOT / "target" / args.cargo_profile / "examples" / benchmark_binary_name()
    if not args.skip_build:
        build_command = ["cargo", "build", "--example", "benchmark"]
        if args.cargo_profile == "release":
            build_command.insert(2, "--release")
        print(f"Building benchmark binary with cargo profile {args.cargo_profile}")
        run_command(build_command)
    if not binary.exists():
        raise SystemExit(
            f"benchmark binary not found at {binary}; remove --skip-build or build it manually"
        )
    return binary
def benchmark_binary_name() -> str:
    """Platform-appropriate file name for the cargo `benchmark` example."""
    if os.name == "nt":
        return "benchmark.exe"
    return "benchmark"
def benchmark_command(
    benchmark_binary: Path,
    bundle_dir: Path,
    args: argparse.Namespace,
    *,
    mode: str,
    quantization: QuantizationRun | None = None,
) -> list[str]:
    """Assemble the benchmark CLI invocation for an exact or compare run.

    Compare runs additionally carry the bit widths and key strategy of the
    requested quantization configuration.
    """
    command = [str(benchmark_binary), "--workload", "real-model"]
    command += ["--real-model-dir", str(bundle_dir)]
    command += ["--prompts", str(args.prompts.resolve())]
    command += ["--max-prompts", str(args.max_prompts)]
    command += ["--max-new-tokens", str(args.max_new_tokens)]
    command += ["--top-k", str(args.top_k)]
    command += ["--seed", str(args.seed)]
    command += ["--json", "--real-eval-mode", mode]
    if quantization is not None:
        command += ["--bits", str(quantization.key_bits)]
        command += ["--value-bits", str(quantization.value_bits)]
        command += ["--real-key-strategy", quantization.strategy]
    return command
def run_json_command(command: list[str], raw_output_path: Path) -> dict[str, Any]:
    """Run *command*, persist its stdout to *raw_output_path*, and parse it as JSON."""
    completed = run_command(command, capture_output=True)
    raw_output_path.write_text(completed.stdout, encoding="utf-8")
    try:
        parsed = json.loads(completed.stdout)
    except json.JSONDecodeError as exc:
        raise SystemExit(
            f"benchmark output at {raw_output_path} was not valid JSON: {exc}"
        ) from exc
    return parsed
def run_command(
    command: list[str],
    *,
    capture_output: bool = False,
) -> subprocess.CompletedProcess[str]:
    """Echo *command*, then execute it from the repo root, raising on failure."""
    print(f"$ {shlex.join(command)}")
    run_options = {
        "cwd": REPO_ROOT,
        "text": True,
        "check": True,  # a non-zero exit raises CalledProcessError
        "capture_output": capture_output,
    }
    return subprocess.run(command, **run_options)
def expect_single_row(report: dict[str, Any], label: str) -> dict[str, Any]:
    """Return the single row dict from a benchmark report, or exit with an error."""
    rows = report.get("rows")
    if isinstance(rows, list) and len(rows) == 1 and isinstance(rows[0], dict):
        return rows[0]
    raise SystemExit(f"expected exactly one benchmark row in the {label} report")
def augment_compare_row(
    row: dict[str, Any],
    *,
    exact_row: dict[str, Any],
    quantization: QuantizationRun,
    raw_report_path: Path,
    min_token_match: float,
    min_top_k_agreement: float,
) -> dict[str, Any]:
    """Return a copy of *row* enriched with derived metrics and gate status.

    Adds speed/latency ratios against the exact baseline, composite quality
    and balanced scores, the quantization labels, and whether the run cleared
    both quality-gate thresholds.
    """
    baseline_tps = number_or_none(exact_row.get("exact_tokens_per_second"))
    baseline_latency = number_or_none(exact_row.get("exact_latency_seconds"))
    run_tps = number_or_none(row.get("quantized_tokens_per_second"))
    run_latency = number_or_none(row.get("quantized_latency_seconds"))
    token_match = number_or_none(row.get("token_match_rate"))
    top_k_agreement = number_or_none(row.get("top_k_agreement"))
    rmse = number_or_none(row.get("logit_rmse"))
    compression = number_or_none(row.get("compression_ratio"))

    quality = compute_quality_score(
        token_match_rate=token_match,
        top_k_agreement=top_k_agreement,
        logit_rmse=rmse,
    )
    speedup = safe_divide(run_tps, baseline_tps)
    gate_cleared = (
        token_match is not None
        and top_k_agreement is not None
        and token_match >= min_token_match
        and top_k_agreement >= min_top_k_agreement
    )

    enriched = dict(row)
    enriched["strategy"] = quantization.strategy
    enriched["key_bits"] = quantization.key_bits
    enriched["value_bits"] = quantization.value_bits
    enriched["label"] = quantization.label
    enriched["slug"] = quantization.slug
    enriched["quality_score"] = quality
    enriched["balanced_score"] = compute_balanced_score(
        quality_score=quality,
        compression_ratio=compression,
        speedup=speedup,
    )
    enriched["speedup_vs_exact"] = speedup
    enriched["latency_ratio_vs_exact"] = safe_divide(run_latency, baseline_latency)
    enriched["quality_gate_passed"] = gate_cleared
    enriched["raw_report_path"] = str(raw_report_path)
    return enriched
def compute_quality_score(
    *,
    token_match_rate: float | None,
    top_k_agreement: float | None,
    logit_rmse: float | None,
) -> float:
    """Average the available quality signals into a single 0-1 score.

    Rates are clamped to [0, 1]; RMSE is mapped through 1/(1+rmse) (with
    negative RMSE treated as 0) so lower error scores higher. Missing metrics
    are simply left out of the average; all-missing yields 0.0.
    """
    parts: list[float] = []
    for rate in (token_match_rate, top_k_agreement):
        if rate is not None:
            parts.append(min(1.0, max(0.0, rate)))
    if logit_rmse is not None:
        parts.append(1.0 / (1.0 + max(logit_rmse, 0.0)))
    return sum(parts) / len(parts) if parts else 0.0
def compute_balanced_score(
    *,
    quality_score: float,
    compression_ratio: float | None,
    speedup: float | None,
) -> float:
    """Blend quality (55%), compression (25%), and speed (20%) into one score."""

    def saturating(value: float | None, scale: float) -> float:
        # Missing or non-positive metrics contribute nothing; otherwise the
        # component saturates at 1.0 once value reaches *scale*.
        if value is None or value <= 0.0:
            return 0.0
        return min(value / scale, 1.0)

    return (
        0.55 * quality_score
        + 0.25 * saturating(compression_ratio, 8.0)
        + 0.20 * saturating(speedup, 1.5)
    )
def build_summary(
    *,
    args: argparse.Namespace,
    bundle_dir: Path,
    output_dir: Path,
    prompts: list[PromptRecord],
    exact_row: dict[str, Any],
    compare_results: list[dict[str, Any]],
    benchmark_binary: Path,
    benchmark_commands: list[list[str]],
    reused_existing_bundle: bool,
) -> dict[str, Any]:
    """Assemble the machine-readable summary dict written to summary.json.

    Expects *compare_results* already sorted by balanced_score descending, so
    compare_results[0] is the best-balanced run.
    """
    categories = Counter(prompt.category for prompt in prompts if prompt.category)
    # Fall back to the bundle directory name when the row carries no model id.
    model_id = string_or_none(exact_row.get("model")) or bundle_dir.name
    quality_passing_rows = [row for row in compare_results if row["quality_gate_passed"]]
    summary = {
        "generated_at_utc": datetime.now(timezone.utc).isoformat(),
        "git_commit": git_commit(),
        "python_version": sys.version.split()[0],
        "cargo_profile": args.cargo_profile,
        "benchmark_binary": str(benchmark_binary),
        "bundle_dir": str(bundle_dir),
        "model": {
            "resolved_model_id": model_id,
            "requested_preset": None if args.model_dir is not None else args.preset,
            "requested_model_id": args.model,
            "used_existing_bundle": reused_existing_bundle,
        },
        "prompt_suite": {
            "path": str(args.prompts.resolve()),
            "copied_to": str(output_dir / f"prompts{args.prompts.suffix}"),
            "count": len(prompts),
            "categories": dict(categories),
            "prompt_ids": [prompt.prompt_id for prompt in prompts if prompt.prompt_id],
        },
        "generation": {
            "max_new_tokens": args.max_new_tokens,
            "top_k": args.top_k,
            "seed": args.seed,
        },
        "quality_gates": {
            "min_token_match": args.min_token_match,
            "min_top_k_agreement": args.min_top_k_agreement,
        },
        "exact_baseline": exact_row,
        "compare_runs": compare_results,
        "rankings": {
            "best_balanced": compare_results[0]["slug"] if compare_results else None,
            "best_quality": best_slug(compare_results, "quality_score"),
            "best_compression_passing_gate": best_slug(
                quality_passing_rows, "compression_ratio"
            ),
            "fastest_passing_gate": best_slug(quality_passing_rows, "speedup_vs_exact"),
        },
        "commands": [shlex.join(command) for command in benchmark_commands],
        "notes": args.notes,
    }
    return summary
def best_slug(rows: list[dict[str, Any]], metric: str) -> str | None:
    """Slug of the row with the highest *metric*; ties keep the earlier row."""
    if not rows:
        return None

    def sort_key(row: dict[str, Any]) -> float:
        # Missing/non-numeric metrics sink to the bottom. (A 0.0 metric also
        # falls through `or` to -inf, matching the original behavior.)
        value = row.get(metric)
        numeric = float(value) if isinstance(value, (int, float)) else None
        return numeric or float("-inf")

    return max(rows, key=sort_key)["slug"]
def build_markdown_summary(summary: dict[str, Any]) -> str:
    """Render the summary dict as the human-readable summary.md report."""
    exact = summary["exact_baseline"]
    compare_rows = summary["compare_runs"]
    prompt_suite = summary["prompt_suite"]
    quality_gates = summary["quality_gates"]
    rankings = summary["rankings"]
    # Header, run metadata, and the exact-baseline metric table.
    lines = [
        "# TurboQuant Real-Model Evaluation",
        "",
        "## Run Metadata",
        f"- Generated at (UTC): `{summary['generated_at_utc']}`",
        f"- Git commit: `{summary['git_commit']}`",
        f"- Resolved model id: `{summary['model']['resolved_model_id']}`",
        f"- ONNX bundle: `{summary['bundle_dir']}`",
        f"- Prompt suite: `{prompt_suite['path']}` ({prompt_suite['count']} prompts)",
        f"- Cargo profile: `{summary['cargo_profile']}`",
        f"- Benchmark binary: `{summary['benchmark_binary']}`",
        f"- Max new tokens: `{summary['generation']['max_new_tokens']}`",
        f"- Top-k agreement metric: `top-{summary['generation']['top_k']}`",
        "",
        "## Exact Baseline",
        "",
        "| Metric | Value |",
        "| --- | ---: |",
        f"| Prompt count | {exact['samples']} |",
        f"| Prompt tokens | {exact['tokens']} |",
        f"| Generated tokens | {none_to_dash(exact.get('generated_tokens'))} |",
        f"| Exact latency (s) | {format_float(exact.get('exact_latency_seconds'), 3)} |",
        f"| Exact tokens/sec | {format_float(exact.get('exact_tokens_per_second'), 3)} |",
        f"| Exact KV memory (bytes) | {format_int(exact.get('kv_memory_exact_bytes'))} |",
        f"| Self cross-entropy | {format_float(exact.get('cross_entropy_exact'), 4)} |",
        f"| Self perplexity | {format_float(exact.get('perplexity_exact'), 4)} |",
        "",
        "## Compare Results",
        "",
        "| Config | Logit RMSE | Top-k | Token match | Divergence | Compression | Quant tok/s | Speedup | Balanced | Gate |",
        "| --- | ---: | ---: | ---: | ---: | ---: | ---: | ---: | ---: | --- |",
    ]
    # One table row per compare run (already sorted by balanced score).
    for row in compare_rows:
        lines.append(
            "| "
            f"{row['label']} | "
            f"{format_float(row.get('logit_rmse'), 4)} | "
            f"{format_ratio(row.get('top_k_agreement'))} | "
            f"{format_ratio(row.get('token_match_rate'))} | "
            f"{format_ratio(row.get('divergence_rate'))} | "
            f"{format_float(row.get('compression_ratio'), 2)}x | "
            f"{format_float(row.get('quantized_tokens_per_second'), 3)} | "
            f"{format_float(row.get('speedup_vs_exact'), 3)}x | "
            f"{format_float(row.get('balanced_score'), 3)} | "
            f"{'pass' if row['quality_gate_passed'] else 'fail'} |"
        )
    # Recommendations derived from the precomputed rankings.
    lines.extend(
        [
            "",
            "## Recommendations",
            "",
            f"- Best balanced: {describe_rank(summary, rankings.get('best_balanced'))}",
            f"- Best quality: {describe_rank(summary, rankings.get('best_quality'))}",
            (
                "- Best compression that passed the quality gate "
                f"(token_match>={quality_gates['min_token_match']:.2f}, "
                f"top_k>={quality_gates['min_top_k_agreement']:.2f}): "
                f"{describe_rank(summary, rankings.get('best_compression_passing_gate'))}"
            ),
            (
                "- Fastest configuration that passed the quality gate "
                f"(token_match>={quality_gates['min_token_match']:.2f}, "
                f"top_k>={quality_gates['min_top_k_agreement']:.2f}): "
                f"{describe_rank(summary, rankings.get('fastest_passing_gate'))}"
            ),
            "",
            "## Raw Artifacts",
            "",
            "- `summary.json`: machine-readable run summary",
            "- `summary.md`: this report",
            "- `raw/exact.json`: exact baseline benchmark output",
            "- `raw/compare-*.json`: raw compare-mode benchmark outputs",
        ]
    )
    if summary.get("notes"):
        lines.extend(["", "## Notes", "", summary["notes"]])
    lines.extend(["", "## Commands", ""])
    for command in summary["commands"]:
        lines.append(f"- `{command}`")
    return "\n".join(lines) + "\n"
def describe_rank(summary: dict[str, Any], slug: str | None) -> str:
    """Human-readable description of the compare run identified by *slug*.

    Falls back to the bare slug when it matches no run, and to a fixed phrase
    when no run met the ranking filter at all (slug is None).
    """
    if slug is None:
        return "none met the filter"
    matching = (row for row in summary["compare_runs"] if row["slug"] == slug)
    row = next(matching, None)
    if row is None:
        return slug
    return (
        f"`{row['label']}` "
        f"(token_match={format_ratio(row.get('token_match_rate'))}, "
        f"top_k={format_ratio(row.get('top_k_agreement'))}, "
        f"compression={format_float(row.get('compression_ratio'), 2)}x, "
        f"speedup={format_float(row.get('speedup_vs_exact'), 3)}x)"
    )
def git_commit() -> str:
    """Current HEAD commit hash, or "unknown" when git/the repo is unavailable."""
    try:
        probe = subprocess.run(
            ["git", "rev-parse", "HEAD"],
            cwd=REPO_ROOT,
            text=True,
            check=True,
            capture_output=True,
        )
    except (FileNotFoundError, subprocess.CalledProcessError):
        return "unknown"
    else:
        return probe.stdout.strip()
def number_or_none(value: Any) -> float | None:
    """Coerce real numbers to float; return None for anything else.

    Fix: ``isinstance(value, (int, float))`` also accepts bool (a subclass of
    int), so True/False were silently coerced to 1.0/0.0 and could masquerade
    as metric values. Booleans now map to None like other non-numeric input.
    """
    if isinstance(value, bool):
        return None
    if isinstance(value, (int, float)):
        return float(value)
    return None
def safe_divide(numerator: float | None, denominator: float | None) -> float | None:
    """Divide, returning None when either operand is missing or the denominator is zero."""
    if numerator is None or not denominator:
        return None
    return numerator / denominator
def slugify(text: str) -> str:
    """Lowercase *text*, collapsing every non-alphanumeric run into a single dash.

    Leading/trailing dashes are stripped; an all-symbol input falls back to
    the literal "model".
    """
    dashed = "".join(ch if ch.isalnum() else "-" for ch in text.lower())
    pieces = [piece for piece in dashed.split("-") if piece]
    return "-".join(pieces) or "model"
def write_json(path: Path, payload: dict[str, Any]) -> None:
    """Serialize *payload* as stable (sorted-key) pretty JSON with a trailing newline."""
    serialized = json.dumps(payload, indent=2, sort_keys=True)
    path.write_text(serialized + "\n", encoding="utf-8")
def write_text(path: Path, text: str) -> None:
    """Write *text* to *path* as UTF-8."""
    with path.open("w", encoding="utf-8") as handle:
        handle.write(text)
def format_float(value: Any, digits: int) -> str:
    """Format *value* to *digits* decimal places, or "-" when it is not a usable number."""
    if not isinstance(value, (int, float)):
        return "-"
    numeric = float(value)
    if math.isnan(numeric):
        return "-"
    return f"{numeric:.{digits}f}"
def format_ratio(value: Any) -> str:
    """Render a 0-1 ratio as a percentage with one decimal, or "-" when unusable."""
    if not isinstance(value, (int, float)) or math.isnan(value):
        return "-"
    return f"{float(value):.1%}"
def format_int(value: Any) -> str:
    """Format a count with thousands separators; floats truncate toward zero, else "-"."""
    if isinstance(value, int):
        return f"{value:,}"
    if isinstance(value, float) and not math.isnan(value):
        return f"{int(value):,}"
    return "-"
def none_to_dash(value: Any) -> str:
    """Stringify *value*, substituting "-" for None."""
    return "-" if value is None else str(value)
if __name__ == "__main__":
    try:
        raise SystemExit(main())
    except subprocess.CalledProcessError as exc:
        # A child command failed: surface the command line and any captured
        # output on stderr, then propagate the child's exit code instead of a
        # Python traceback.
        print(f"Command failed with exit code {exc.returncode}: {shlex.join(exc.cmd)}", file=sys.stderr)
        if exc.stdout:
            print(exc.stdout, file=sys.stderr, end="" if exc.stdout.endswith("\n") else "\n")
        if exc.stderr:
            print(exc.stderr, file=sys.stderr, end="" if exc.stderr.endswith("\n") else "\n")
        raise SystemExit(exc.returncode)