from __future__ import annotations
import argparse
import os
import platform
import subprocess
import sys
import tempfile
import time
import urllib.request
from dataclasses import dataclass
from pathlib import Path
import numpy as np
try:
import onnxruntime as ort
except ImportError:
print("ERROR: onnxruntime is required. pip install onnxruntime")
sys.exit(1)
@dataclass
class ModelSpec:
    """Metadata for one reference ONNX model used in validation runs."""

    name: str  # human-readable model name used in console output
    url: str   # direct download URL; the cached filename is derived from it
    # Maps ONNX graph input names to tensor shapes (NCHW for these models).
    input_shape: dict[str, tuple[int, ...]]
    # Minimum acceptable cosine similarity vs the FP32 output, per bit width.
    min_cosine_int8: float = 0.95
    min_cosine_int4: float = 0.30
# Registry of reference models to validate, keyed by the CLI-facing short name
# accepted by --model.
MODELS: dict[str, ModelSpec] = {
    "resnet18": ModelSpec(
        name="ResNet-18",
        url="https://github.com/onnx/models/raw/main/validated/vision/classification/resnet/model/resnet18-v1-7.onnx",
        input_shape={"data": (1, 3, 224, 224)},
    ),
    "mobilenetv2": ModelSpec(
        name="MobileNetV2",
        url="https://github.com/onnx/models/raw/main/validated/vision/classification/mobilenet/model/mobilenetv2-7.onnx",
        input_shape={"input": (1, 3, 224, 224)},
    ),
    "squeezenet": ModelSpec(
        name="SqueezeNet-1.0",
        url="https://github.com/onnx/models/raw/main/validated/vision/classification/squeezenet/model/squeezenet1.0-7.onnx",
        input_shape={"data_0": (1, 3, 224, 224)},
    ),
}
# Repository root (this script appears to live one directory below it).
PROJECT_ROOT = Path(__file__).resolve().parent.parent
# Downloaded models are cached here so repeat runs can skip the network.
MODEL_CACHE = PROJECT_ROOT / "eval" / "models"
def default_binary() -> str:
    """Locate the quantize-rs executable under target/, preferring release.

    Falls back to the bare executable name (to be resolved via PATH) when
    neither a release nor a debug build artifact exists.
    """
    suffix = ".exe" if platform.system() == "Windows" else ""
    exe_name = f"quantize-rs{suffix}"
    candidates = (
        PROJECT_ROOT / "target" / profile / exe_name
        for profile in ("release", "debug")
    )
    found = next((c for c in candidates if c.exists()), None)
    return str(found) if found is not None else exe_name
def download_model(spec: ModelSpec) -> Path:
    """Download *spec*'s ONNX file into MODEL_CACHE, reusing a cached copy.

    Fix: the original downloaded straight to the final cache path, so an
    interrupted transfer left a truncated file that later runs treated as a
    valid cached model. We now download to a temporary ".part" file and
    rename it into place only after the transfer completes.

    Returns:
        Path to the cached model file.

    Raises:
        Whatever urllib raises on network/HTTP failure (caller handles it).
    """
    MODEL_CACHE.mkdir(parents=True, exist_ok=True)
    filename = spec.url.split("/")[-1]
    dest = MODEL_CACHE / filename
    if dest.exists():
        print(f" [cached] {dest.name} ({dest.stat().st_size / 1e6:.1f} MB)")
        return dest
    print(f" Downloading {spec.name}...")
    partial = dest.with_suffix(dest.suffix + ".part")
    try:
        urllib.request.urlretrieve(spec.url, partial)
        # Rename is atomic on the same filesystem, so `dest` is either
        # absent or complete — never truncated.
        partial.replace(dest)
    finally:
        # Remove leftovers from a failed transfer; no-op after a successful
        # replace() because the .part file no longer exists.
        partial.unlink(missing_ok=True)
    print(f" Saved {dest.name} ({dest.stat().st_size / 1e6:.1f} MB)")
    return dest
def quantize_model(
    binary: str, input_path: Path, output_path: Path,
    bits: int = 8, per_channel: bool = False, min_elements: int = 0,
) -> bool:
    """Invoke the quantize-rs CLI and report whether it exited successfully.

    On failure, prints the exit code plus the first 5 stderr lines and the
    last 5 stdout lines as a diagnostic, then returns False.
    """
    args = [
        binary, "quantize",
        str(input_path),
        "-o", str(output_path),
        "--bits", str(bits),
    ]
    if per_channel:
        args += ["--per-channel"]
    if min_elements > 0:
        args += ["--min-elements", str(min_elements)]
    result = subprocess.run(args, capture_output=True, text=True)
    if result.returncode == 0:
        return True
    print(f" QUANTIZE FAILED (exit {result.returncode})")
    if result.stderr:
        for line in result.stderr.strip().splitlines()[:5]:
            print(f" stderr: {line}")
    if result.stdout:
        for line in result.stdout.strip().splitlines()[-5:]:
            print(f" stdout: {line}")
    return False
def run_inference(model_path: Path, inputs: dict[str, np.ndarray]) -> list[np.ndarray]:
    """Run the ONNX model at *model_path* on CPU and return all graph outputs."""
    session = ort.InferenceSession(str(model_path), providers=["CPUExecutionProvider"])
    return session.run(None, inputs)
def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
    """Cosine similarity of two arrays, flattened to 1-D float64.

    Near-zero-norm inputs would make the quotient meaningless, so they
    compare as 1.0 when element-wise close and 0.0 otherwise.
    """
    u = np.ravel(a).astype(np.float64)
    v = np.ravel(b).astype(np.float64)
    norm_u = np.linalg.norm(u)
    norm_v = np.linalg.norm(v)
    if min(norm_u, norm_v) < 1e-12:
        return 1.0 if np.allclose(u, v) else 0.0
    return float((u @ v) / (norm_u * norm_v))
def max_abs_error(a: np.ndarray, b: np.ndarray) -> float:
    """Largest element-wise absolute difference between *a* and *b* (float64)."""
    diff = a.astype(np.float64) - b.astype(np.float64)
    return float(np.abs(diff).max())
def top_k_match(a: np.ndarray, b: np.ndarray, k: int = 5) -> bool:
    """Check whether the top-*k* index sets of *a* and *b* overlap at all.

    Arrays too small to hold k elements pass trivially. Only a single shared
    index between the two top-k sets is required, so this is a loose sanity
    check rather than strict top-k agreement.
    """
    if a.ndim < 1 or a.size < k:
        return True
    top_idx_a = np.argsort(a.flatten())[-k:]
    top_idx_b = np.argsort(b.flatten())[-k:]
    return not set(top_idx_a).isdisjoint(top_idx_b)
@dataclass
class ValidationResult:
    """Outcome of validating one model under one quantization config."""

    model: str        # human-readable model name
    config: str       # config label, e.g. "INT8 per-ch"
    success: bool     # True when no failure condition was recorded
    ort_loads: bool   # quantized model loaded into onnxruntime successfully
    cosine: float     # cosine similarity of first outputs (FP32 vs quantized)
    max_error: float  # max abs element-wise error of first outputs
    top5_match: bool  # loose top-5 index overlap between FP32 and quantized
    compression: float  # original file size / quantized file size
    error_msg: str = ""  # "; "-joined failure reasons; empty on success
def validate_config(
    spec: ModelSpec,
    model_path: Path,
    binary: str,
    bits: int,
    per_channel: bool,
    min_elements: int,
) -> ValidationResult:
    """Quantize *model_path* with one CLI config and compare against FP32.

    Pipeline: run the quantize-rs CLI into a temp dir, load the quantized
    model in onnxruntime, run both models on the same fixed random input,
    and score the first output. Each stage that fails short-circuits with a
    failed ValidationResult; success additionally requires no NaN/Inf in any
    quantized output and cosine similarity at or above the bit-width-specific
    threshold from *spec*.
    """
    config_str = f"INT{bits}" + (" per-ch" if per_channel else "")
    # Accuracy bar depends on bit width (INT4 is far lossier than INT8).
    min_cosine = spec.min_cosine_int8 if bits == 8 else spec.min_cosine_int4
    # Everything below stays inside the temp dir's lifetime: the quantized
    # file is read (inference, stat) before the directory is deleted.
    with tempfile.TemporaryDirectory() as tmpdir:
        output_path = Path(tmpdir) / f"quantized_int{bits}.onnx"
        ok = quantize_model(binary, model_path, output_path, bits, per_channel, min_elements)
        if not ok:
            return ValidationResult(
                spec.name, config_str, False, False,
                0.0, 0.0, False, 0.0, "Quantization CLI failed",
            )
        # Loading in ORT is itself part of the validation: a structurally
        # broken model fails here before any inference is attempted.
        try:
            sess_q = ort.InferenceSession(str(output_path), providers=["CPUExecutionProvider"])
        except Exception as e:
            return ValidationResult(
                spec.name, config_str, False, False,
                0.0, 0.0, False, 0.0, f"ORT failed to load: {e}",
            )
        # Fixed seed so every model/config pair sees identical test inputs.
        np.random.seed(42)
        inputs = {}
        for name, shape in spec.input_shape.items():
            inputs[name] = np.random.randn(*shape).astype(np.float32)
        try:
            fp32_out = run_inference(model_path, inputs)
        except Exception as e:
            # ort_loads=True here: the quantized session did load above.
            return ValidationResult(
                spec.name, config_str, False, True,
                0.0, 0.0, False, 0.0, f"FP32 inference failed: {e}",
            )
        try:
            # Feed only inputs the quantized graph actually declares, in
            # case quantization changed the graph's input set.
            q_inputs = {}
            for inp in sess_q.get_inputs():
                if inp.name in inputs:
                    q_inputs[inp.name] = inputs[inp.name]
            quant_out = sess_q.run(None, q_inputs)
        except Exception as e:
            return ValidationResult(
                spec.name, config_str, False, True,
                0.0, 0.0, False, 0.0, f"Quantized inference failed: {e}",
            )
        # Similarity metrics use only the first output; NaN/Inf checks cover
        # every output tensor.
        cos = cosine_similarity(fp32_out[0], quant_out[0])
        mae = max_abs_error(fp32_out[0], quant_out[0])
        top5 = top_k_match(fp32_out[0], quant_out[0])
        has_nan = any(np.isnan(o).any() for o in quant_out)
        has_inf = any(np.isinf(o).any() for o in quant_out)
        orig_size = model_path.stat().st_size
        quant_size = output_path.stat().st_size
        compression = orig_size / max(quant_size, 1)  # guard division by zero
        errors = []
        if has_nan:
            errors.append("Output contains NaN")
        if has_inf:
            errors.append("Output contains Inf")
        if cos < min_cosine:
            errors.append(f"Cosine {cos:.4f} < threshold {min_cosine}")
        return ValidationResult(
            model=spec.name,
            config=config_str,
            success=len(errors) == 0,
            ort_loads=True,
            cosine=cos,
            max_error=mae,
            top5_match=top5,
            compression=compression,
            error_msg="; ".join(errors),
        )
def main():
    """CLI entry point.

    Resolves models (download or cache), quantizes each under the selected
    configs, compares quantized vs FP32 outputs, prints a summary, and exits
    non-zero on any failure.

    Fix: previously a run that validated zero models (e.g. --no-download
    with an empty cache, or all downloads failing) printed
    "ALL VALIDATIONS PASSED" and exited 0; it now fails explicitly.
    """
    parser = argparse.ArgumentParser(description="Validate quantize-rs on real ONNX models")
    parser.add_argument("--model", choices=list(MODELS.keys()), help="Run a single model only")
    parser.add_argument("--no-download", action="store_true", help="Skip downloads, use cached models")
    parser.add_argument("--binary", default=default_binary(), help="Path to quantize-rs binary")
    parser.add_argument("--bits", type=int, choices=[4, 8], help="Test a single bit width only")
    args = parser.parse_args()

    models = {args.model: MODELS[args.model]} if args.model else MODELS

    print("=" * 70)
    print("quantize-rs Real-World Model Validation")
    print("=" * 70)
    print(f"Binary: {args.binary}")
    print(f"Models: {', '.join(m.name for m in models.values())}")
    print(f"Cache: {MODEL_CACHE}")
    print()

    # Resolve each model to a local file, downloading unless --no-download.
    if not args.no_download:
        print("Downloading models...")
        model_paths: dict[str, Path] = {}
        for key, spec in models.items():
            try:
                model_paths[key] = download_model(spec)
            except Exception as e:
                # Best-effort: a failed download skips that model, not the run.
                print(f" FAILED to download {spec.name}: {e}")
        print()
    else:
        model_paths = {}
        for key, spec in models.items():
            filename = spec.url.split("/")[-1]
            p = MODEL_CACHE / filename
            if p.exists():
                model_paths[key] = p
            else:
                print(f" SKIP {spec.name} (not cached)")

    # Build the quantization configs to exercise, honoring --bits.
    configs = []
    if args.bits is None or args.bits == 8:
        configs.append({"bits": 8, "per_channel": False, "min_elements": 0})
        configs.append({"bits": 8, "per_channel": True, "min_elements": 0})
    if args.bits is None or args.bits == 4:
        # INT4 only quantizes tensors above a size floor to limit accuracy loss.
        configs.append({"bits": 4, "per_channel": False, "min_elements": 128})

    results: list[ValidationResult] = []
    for key, spec in models.items():
        if key not in model_paths:
            continue
        model_path = model_paths[key]
        print(f"--- {spec.name} ({model_path.stat().st_size / 1e6:.1f} MB) ---")
        for cfg in configs:
            label = f"INT{cfg['bits']}" + (" per-ch" if cfg["per_channel"] else "")
            print(f" [{label}] ", end="", flush=True)
            t0 = time.time()
            result = validate_config(spec, model_path, args.binary, **cfg)
            elapsed = time.time() - t0
            status = "PASS" if result.success else "FAIL"
            print(
                f"{status} cosine={result.cosine:.4f} "
                f"max_err={result.max_error:.4f} "
                f"top5={result.top5_match} "
                f"compress={result.compression:.2f}x "
                f"({elapsed:.1f}s)"
            )
            if result.error_msg:
                print(f" {result.error_msg}")
            results.append(result)
        print()

    print("=" * 70)
    print("SUMMARY")
    print("=" * 70)
    # Guard: zero validated configurations must not count as success.
    if not results:
        print("No models were validated (nothing downloaded or cached)")
        sys.exit(1)
    passed = sum(1 for r in results if r.success)
    failed = len(results) - passed
    for r in results:
        icon = "PASS" if r.success else "FAIL"
        print(f" [{icon}] {r.model:20s} {r.config:12s} cosine={r.cosine:.4f} compress={r.compression:.2f}x")
        if r.error_msg:
            print(f" {r.error_msg}")
    print()
    print(f"Total: {passed} passed, {failed} failed out of {len(results)} configurations")
    print()
    if failed > 0:
        print("VALIDATION FAILED")
        sys.exit(1)
    print("ALL VALIDATIONS PASSED")
if __name__ == "__main__":
main()