import quantize_rs
import numpy as np
import os
def inspect_model(model_path: str):
    """Print a short metadata summary for *model_path* and return the info object."""
    info = quantize_rs.model_info(model_path)
    summary = (
        ("Model", info.name),
        ("Version", info.version),
        ("Nodes", info.num_nodes),
        ("Inputs", info.inputs),
        ("Outputs", info.outputs),
    )
    for label, value in summary:
        print(f"{label}: {value}")
    return info
def basic_int8(input_path: str, output_path: str):
    """Simplest case: per-tensor INT8 quantization with default settings."""
    quantize_rs.quantize(input_path, output_path, bits=8)
    print(f"INT8 quantized: {output_path}")
def basic_int4(input_path: str, output_path: str):
    """Simplest case at 4 bits: smaller output, larger accuracy impact."""
    quantize_rs.quantize(input_path, output_path, bits=4)
    print(f"INT4 quantized: {output_path}")
def per_channel_int8(input_path: str, output_path: str):
    """INT8 quantization with a separate scale per channel."""
    quantize_rs.quantize(input_path, output_path, bits=8, per_channel=True)
    print(f"Per-channel INT8 quantized: {output_path}")
def exclude_sensitive_layers(input_path: str, output_path: str):
    """INT8-quantize while leaving the named layers at full precision."""
    # First conv and final FC are the usual accuracy-sensitive layers.
    sensitive = ["conv1.weight", "fc.weight"]
    quantize_rs.quantize(
        input_path,
        output_path,
        bits=8,
        excluded_layers=sensitive,
    )
    print(f"Quantized with excluded layers: {output_path}")
def skip_small_tensors(input_path: str, output_path: str):
    """INT8-quantize, skipping small tensors (fewer than 512 elements)."""
    quantize_rs.quantize(input_path, output_path, bits=8, min_elements=512)
    print(f"Quantized (small tensors skipped): {output_path}")
def mixed_precision(input_path: str, output_path: str):
    """Quantize to INT4 overall while keeping selected layers at INT8."""
    keep_int8 = ("conv1.weight", "fc.weight", "layer4.1.conv2.weight")
    quantize_rs.quantize(
        input_path,
        output_path,
        bits=4,
        layer_bits={layer: 8 for layer in keep_int8},
    )
    print(f"Mixed precision quantized: {output_path}")
def calibrate_with_real_data(input_path: str, output_path: str, data_path: str):
    """Calibrated per-channel INT8 quantization driven by a real dataset file."""
    options = dict(
        calibration_data=data_path,
        bits=8,
        per_channel=True,
        method="minmax",
    )
    quantize_rs.quantize_with_calibration(input_path, output_path, **options)
    print(f"Calibrated (real data, minmax): {output_path}")
def calibrate_with_random_samples(input_path: str, output_path: str):
    """Calibrated quantization using 50 synthetic samples (no dataset needed)."""
    quantize_rs.quantize_with_calibration(
        input_path, output_path, num_samples=50, method="percentile"
    )
    print(f"Calibrated (random samples, percentile): {output_path}")
def compare_calibration_methods(input_path: str, output_dir: str):
    """Run each supported calibration method and report the resulting file sizes."""
    os.makedirs(output_dir, exist_ok=True)
    methods = ("minmax", "percentile", "entropy", "mse")
    for method in methods:
        output_path = os.path.join(output_dir, f"model_{method}.onnx")
        quantize_rs.quantize_with_calibration(
            input_path,
            output_path,
            num_samples=50,
            method=method,
        )
        size_mb = os.path.getsize(output_path) / (1024 * 1024)
        print(f" {method:12s} -> {size_mb:.2f} MB")
def scenario_vision_classification(model_path: str, output_path: str):
    """Recipe for vision classifiers: per-channel INT8, small tensors skipped."""
    quantize_rs.quantize(
        model_path, output_path, bits=8, per_channel=True, min_elements=256
    )
def scenario_nlp_transformer(model_path: str, output_path: str):
    """Recipe for transformer LMs: INT4 body with embeddings/head kept at INT8."""
    int8_layers = (
        "transformer.wte.weight",
        "transformer.wpe.weight",
        "lm_head.weight",
    )
    quantize_rs.quantize(
        model_path,
        output_path,
        bits=4,
        min_elements=8192,
        layer_bits={layer: 8 for layer in int8_layers},
    )
def scenario_edge_deployment(model_path: str, output_path: str):
    """Recipe for edge devices: per-channel INT4 for the smallest footprint."""
    quantize_rs.quantize(model_path, output_path, bits=4, per_channel=True)
def scenario_high_accuracy(model_path: str, output_path: str, data_path: str):
    """Recipe for accuracy-sensitive use: entropy-calibrated per-channel INT8."""
    opts = {
        "calibration_data": data_path,
        "bits": 8,
        "per_channel": True,
        "method": "entropy",
    }
    quantize_rs.quantize_with_calibration(model_path, output_path, **opts)
def scenario_batch_processing(input_dir: str, output_dir: str):
    """Quantize every .onnx model in *input_dir*, printing per-model ratios."""
    import glob
    os.makedirs(output_dir, exist_ok=True)
    mb = 1024 * 1024
    for model_path in glob.glob(os.path.join(input_dir, "*.onnx")):
        stem = os.path.splitext(os.path.basename(model_path))[0]
        output_path = os.path.join(output_dir, f"{stem}_int8.onnx")
        try:
            info = quantize_rs.model_info(model_path)
            print(f"Quantizing {info.name} ({info.num_nodes} nodes)...")
            quantize_rs.quantize(model_path, output_path, bits=8, per_channel=True)
            orig_size = os.path.getsize(model_path) / mb
            quant_size = os.path.getsize(output_path) / mb
            ratio = orig_size / quant_size if quant_size > 0 else 0
            print(f" {orig_size:.1f} MB -> {quant_size:.1f} MB ({ratio:.1f}x)")
        except Exception as e:
            # Best-effort batch: report the failure and continue with the rest.
            print(f" FAILED: {e}")
def verify_with_onnxruntime(original_path: str, quantized_path: str):
    """Run both models on one random input and report cosine similarity / max error."""
    import onnxruntime as ort
    orig_session = ort.InferenceSession(original_path)
    quant_session = ort.InferenceSession(quantized_path)
    inp = orig_session.get_inputs()[0]
    # Symbolic (non-int) dimensions are replaced with 1 to build a concrete tensor.
    shape = [dim if isinstance(dim, int) else 1 for dim in inp.shape]
    x = np.random.randn(*shape).astype(np.float32)
    feed = {inp.name: x}
    orig_out = orig_session.run(None, feed)[0]
    quant_out = quant_session.run(None, feed)[0]
    # Epsilon in the denominator guards against an all-zero output.
    denom = np.linalg.norm(orig_out) * np.linalg.norm(quant_out) + 1e-8
    cos_sim = np.dot(orig_out.flat, quant_out.flat) / denom
    max_err = np.max(np.abs(orig_out - quant_out))
    print(f"Cosine similarity: {cos_sim:.6f}")
    print(f"Max absolute error: {max_err:.6f}")
    print(f"Result: {'PASS' if cos_sim > 0.95 else 'CHECK ACCURACY'}")
def prepare_imagenet_calibration(image_dir: str, output_path: str, num_samples: int = 100):
    """Build an ImageNet-style calibration batch from images and save it as .npy.

    Each image is resized so the short side is 256 px, center-cropped to
    224x224, normalized with the standard ImageNet mean/std, and converted to
    CHW float32. The stacked (N, 3, 224, 224) array is written to *output_path*.

    Raises:
        ValueError: if no .jpg/.jpeg/.png files are found in *image_dir*.
    """
    from PIL import Image
    mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
    std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
    samples = []
    image_files = [
        f for f in os.listdir(image_dir)
        if f.lower().endswith((".jpg", ".jpeg", ".png"))
    ][:num_samples]
    # Fail early with a clear message instead of np.stack([]) raising later.
    if not image_files:
        raise ValueError(f"no images found in {image_dir}")
    for fname in image_files:
        img = Image.open(os.path.join(image_dir, fname)).convert("RGB")
        # Scale so the shorter side becomes 256 px, preserving aspect ratio.
        ratio = 256.0 / min(img.size)
        img = img.resize((int(img.width * ratio), int(img.height * ratio)))
        # Center-crop the 224x224 window expected by ImageNet models.
        left = (img.width - 224) // 2
        top = (img.height - 224) // 2
        img = img.crop((left, top, left + 224, top + 224))
        arr = np.array(img, dtype=np.float32) / 255.0
        arr = (arr - mean) / std
        # HWC -> CHW. (The original fused these two statements onto one line,
        # which is a syntax error in Python.)
        arr = arr.transpose(2, 0, 1)
        samples.append(arr)
    calibration_data = np.stack(samples)
    np.save(output_path, calibration_data)
    print(f"Saved {len(samples)} samples to {output_path}")
    print(f"Shape: {calibration_data.shape}")
def full_pipeline(model_path: str, output_dir: str):
    """End-to-end demo: inspect the model, quantize four variants, verify each.

    Verification is skipped with an instructional message when onnxruntime is
    not installed.
    """
    os.makedirs(output_dir, exist_ok=True)
    banner = "=" * 60

    print(banner)
    print("Step 1: Model Inspection")
    print(banner)
    info = inspect_model(model_path)
    orig_size = os.path.getsize(model_path) / (1024 * 1024)
    print(f"Size: {orig_size:.2f} MB\n")

    variants = {
        "int8": dict(bits=8),
        "int8_perchannel": dict(bits=8, per_channel=True),
        "int4": dict(bits=4),
        "int4_perchannel": dict(bits=4, per_channel=True),
    }
    print(banner)
    print("Step 2: Quantization")
    print(banner)
    results = {}
    for name, kwargs in variants.items():
        out_path = os.path.join(output_dir, f"{info.name}_{name}.onnx")
        try:
            quantize_rs.quantize(model_path, out_path, **kwargs)
            size = os.path.getsize(out_path) / (1024 * 1024)
            ratio = orig_size / size if size > 0 else 0
            results[name] = (out_path, size, ratio)
            print(f" {name:20s} {size:7.2f} MB ({ratio:.1f}x)")
        except Exception as e:
            # A failed variant is reported but does not stop the pipeline.
            print(f" {name:20s} FAILED: {e}")
    print()

    print(banner)
    print("Step 3: Verification")
    print(banner)
    # Keep only the import in the try: the original wrapped the whole loop,
    # so an ImportError raised inside verification would be misreported as
    # "onnxruntime not installed". (The original also fused the import and
    # the for-statement onto one line — a syntax error.)
    try:
        import onnxruntime  # noqa: F401 -- availability check only
    except ImportError:
        print(" onnxruntime not installed, skipping verification.")
        print(" Install with: pip install onnxruntime")
        return
    for name, (out_path, _, _) in results.items():
        print(f"\n {name}:")
        verify_with_onnxruntime(model_path, out_path)
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(
        description="quantize-rs Python examples",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
python examples.py --model resnet18.onnx
python examples.py --model resnet18.onnx --example calibrate --data samples.npy
python examples.py --model resnet18.onnx --example pipeline --output-dir results/
""",
    )
    parser.add_argument("--model", required=True, help="Path to ONNX model")
    parser.add_argument(
        "--example",
        default="pipeline",
        choices=[
            "inspect", "int8", "int4", "perchannel", "exclude",
            "mixed", "calibrate", "compare", "verify", "pipeline",
        ],
        help="Which example to run (default: pipeline)",
    )
    parser.add_argument("--output", default="model_quantized.onnx", help="Output path")
    parser.add_argument("--output-dir", default="quantized", help="Output directory")
    parser.add_argument("--data", default=None, help="Calibration data .npy path")
    args = parser.parse_args()

    def _run_calibrate():
        # Use the real dataset when provided, otherwise synthetic samples.
        if args.data:
            calibrate_with_real_data(args.model, args.output, args.data)
        else:
            calibrate_with_random_samples(args.model, args.output)

    def _run_verify():
        # Quantize first, then compare original vs quantized outputs.
        basic_int8(args.model, args.output)
        verify_with_onnxruntime(args.model, args.output)

    # Dispatch table replaces the if/elif chain; keys match --example choices.
    dispatch = {
        "inspect": lambda: inspect_model(args.model),
        "int8": lambda: basic_int8(args.model, args.output),
        "int4": lambda: basic_int4(args.model, args.output),
        "perchannel": lambda: per_channel_int8(args.model, args.output),
        "exclude": lambda: exclude_sensitive_layers(args.model, args.output),
        "mixed": lambda: mixed_precision(args.model, args.output),
        "calibrate": _run_calibrate,
        "compare": lambda: compare_calibration_methods(args.model, args.output_dir),
        "verify": _run_verify,
        "pipeline": lambda: full_pipeline(args.model, args.output_dir),
    }
    dispatch[args.example]()