use clap::Parser;
use quantize_rs::onnx_utils::graph_builder::QdqWeightInput;
use quantize_rs::{OnnxModel, QuantConfig, Quantizer};
use std::collections::HashMap;
/// Command-line arguments for the quantize-and-validate tool.
///
/// Doc comments on fields double as `--help` text via clap's derive macro.
#[derive(Parser)]
#[command(
    name = "validate_real_model",
    about = "Quantize and validate a real ONNX model",
    long_about = None,
)]
struct Args {
    /// Path to the input ONNX model file
    input: String,
    /// Quantization bit width (must be 4 or 8)
    #[arg(long, default_value_t = 8)]
    bits: u8,
    /// Use per-channel (axis 0) scales/zero-points instead of per-tensor
    #[arg(long, default_value_t = false)]
    per_channel: bool,
    /// Skip tensors with fewer elements than this threshold
    #[arg(long, default_value_t = 128)]
    min_elements: usize,
    /// Where to save the quantized model; omit to skip saving
    #[arg(long)]
    output: Option<String>,
}
/// Entry point: load an ONNX model, quantize eligible weight tensors,
/// print a per-tensor report plus summary statistics, and optionally save
/// a quantized (QDQ) model and validate its graph connectivity.
fn main() -> anyhow::Result<()> {
let args = Args::parse();
// Only 4- and 8-bit quantization are supported downstream.
if args.bits != 4 && args.bits != 8 {
anyhow::bail!("--bits must be 4 or 8, got {}", args.bits);
}
println!("Loading model: {}", args.input);
let mut model = OnnxModel::load(&args.input)?;
let info = model.info();
// On-disk size of the whole input file (graph structure + weights).
let file_bytes = std::fs::metadata(&args.input)?.len() as usize;
println!(
" graph: \"{}\" nodes: {} inputs: {:?} outputs: {:?}",
info.name, info.num_nodes, info.inputs, info.outputs
);
println!(" file size: {}", fmt_bytes(file_bytes));
let weights = model.extract_weights();
// Preview of quantized/skipped counts before doing any work.
// NOTE(review): this preview filters on element count only, while the loop
// below uses config.should_quantize(&name, len) — if should_quantize also
// filters by name (or other config fields), these counts can disagree with
// the actual loop. Confirm against QuantConfig::should_quantize.
println!(
"\nFound {} weight tensors ({} will be quantized, {} skipped by min_elements={})\n",
weights.len(),
weights
.iter()
.filter(|w| w.data.len() >= args.min_elements)
.count(),
weights
.iter()
.filter(|w| w.data.len() < args.min_elements)
.count(),
args.min_elements,
);
// In 4-bit mode the on-disk ONNX column is labeled differently because
// INT4 values are still stored one-per-byte (see note printed below).
let int4_mode = args.bits == 4;
let onnx_col = if int4_mode { "ONNX bytes" } else { "Quantized" };
println!(
"{:<40} {:>10} {:>10} {:>10} {:>10} {:>8}",
"Tensor name", "Elements", "FP32", onnx_col, "MAE", "Bits"
);
println!("{}", "-".repeat(100));
let config = QuantConfig {
bits: args.bits,
per_channel: args.per_channel,
min_elements: args.min_elements,
..Default::default()
};
let quantizer = Quantizer::new(config.clone());
// Per-tensor quantized payloads collected for the optional save step.
let mut qdq_data: Vec<QdqWeightInput> = Vec::new();
let mut total_fp32_bytes: usize = 0;
let mut total_onnx_bytes: usize = 0;
let mut total_packed_bytes: usize = 0;
let mut total_elements: usize = 0;
let mut skipped: usize = 0;
for w in &weights {
// FP32 storage: 4 bytes per element.
let fp32_bytes = w.data.len() * 4;
total_fp32_bytes += fp32_bytes;
total_elements += w.data.len();
// Tensors below the size threshold (or otherwise excluded by config)
// stay in FP32 and are reported as skipped.
if !config.should_quantize(&w.name, w.data.len()) {
println!(
"{:<40} {:>10} {:>10} {:>10} {:>10} {:>8}",
truncate(&w.name, 40),
fmt_count(w.data.len()),
fmt_bytes(fp32_bytes),
"skipped",
"-",
"-",
);
skipped += 1;
continue;
}
let quantized = quantizer.quantize_tensor(&w.data, w.shape.clone())?;
// Actual ONNX storage: 1 byte per element (INT4 is stored as INT8 too).
let onnx_bytes = w.data.len();
// Theoretical packed size — presumably elements * bits / 8; confirm
// against QuantizedTensor::size_bytes.
let packed_bytes = quantized.size_bytes();
// Mean absolute error of dequantized values vs. the FP32 originals.
let mae = quantized.quantization_error(&w.data);
let bits_used = quantized.bits();
println!(
"{:<40} {:>10} {:>10} {:>10} {:>10.2e} {:>8}",
truncate(&w.name, 40),
fmt_count(w.data.len()),
fmt_bytes(fp32_bytes),
fmt_bytes(onnx_bytes),
mae,
bits_used,
);
let (scales, zero_points) = quantized.get_all_scales_zero_points();
let is_pc = quantized.is_per_channel();
qdq_data.push(QdqWeightInput {
original_name: w.name.clone(),
quantized_values: quantized.data(),
scales,
zero_points,
bits: bits_used,
// Per-channel quantization is along axis 0; per-tensor has no axis.
axis: if is_pc { Some(0) } else { None },
});
total_onnx_bytes += onnx_bytes;
total_packed_bytes += packed_bytes;
}
println!("{}", "-".repeat(100));
println!("\nSummary");
println!(
" Total tensors : {} ({} quantized, {} skipped)",
weights.len(),
qdq_data.len(),
skipped
);
println!(" Total elements: {}", fmt_count(total_elements));
println!(" FP32 weight bytes : {}", fmt_bytes(total_fp32_bytes));
println!(" ONNX storage (actual): {}", fmt_bytes(total_onnx_bytes));
// NOTE(review): total_fp32_bytes includes skipped tensors while
// total_onnx_bytes does not, so this ratio overstates compression whenever
// tensors are skipped — confirm that is the intended reading.
if total_fp32_bytes > 0 {
let ratio = total_onnx_bytes as f64 / total_fp32_bytes as f64;
println!(
" Compression ratio : {:.1}x ({:.1}% of original)",
1.0 / ratio,
ratio * 100.0
);
}
// Extra line for 4-bit mode: what the size would be with true 4-bit packing.
// (The branch condition implies total_onnx_bytes > 0, hence
// total_fp32_bytes > 0, so the division below cannot be by zero.)
if int4_mode && total_packed_bytes < total_onnx_bytes {
let ratio = total_packed_bytes as f64 / total_fp32_bytes as f64;
println!(
" Theoretical INT4 packed: {} ({:.1}x / {:.1}% of original)",
fmt_bytes(total_packed_bytes),
1.0 / ratio,
ratio * 100.0,
);
println!(" (INT4 values stored as INT8 in ONNX — opset 21 required for true 8x)");
}
if let Some(ref out_path) = args.output {
if qdq_data.is_empty() {
println!("\nNo tensors were quantized; skipping save.");
} else {
// Round-trip check: save, reload, and verify no graph references dangle.
model.save_quantized(&qdq_data, out_path)?;
let reloaded = OnnxModel::load(out_path)?;
let report = reloaded.validate_connectivity();
let out_bytes = std::fs::metadata(out_path)?.len() as usize;
println!("\nSaved to: {out_path}");
println!(" Output file size: {}", fmt_bytes(out_bytes));
if report.valid {
println!(" Connectivity: OK");
} else {
println!(
" Connectivity: BROKEN ({} broken refs)",
report.broken_refs.len()
);
for r in &report.broken_refs {
println!(" - {r}");
}
}
}
}
Ok(())
}
/// Render a byte count with a human-readable binary-unit suffix
/// (B, KB, MB, or GB), matching the precision used elsewhere in the report.
fn fmt_bytes(n: usize) -> String {
    const KIB: usize = 1 << 10;
    const MIB: usize = 1 << 20;
    const GIB: usize = 1 << 30;
    match n {
        _ if n >= GIB => format!("{:.2} GB", n as f64 / GIB as f64),
        _ if n >= MIB => format!("{:.2} MB", n as f64 / MIB as f64),
        _ if n >= KIB => format!("{:.1} KB", n as f64 / KIB as f64),
        _ => format!("{n} B"),
    }
}
/// Render an element count in short form: plain below 1K, "x.xK" below 1M,
/// and "x.xxM" at or above one million.
fn fmt_count(n: usize) -> String {
    if n < 1_000 {
        return n.to_string();
    }
    if n < 1_000_000 {
        return format!("{:.1}K", n as f64 / 1e3);
    }
    format!("{:.2}M", n as f64 / 1e6)
}
/// Truncate `s` to at most `max` characters, appending `…` when shortened.
///
/// The original implementation sliced by BYTE offset (`&s[..max - 1]`),
/// which panics when `max - 1` falls inside a multi-byte UTF-8 character
/// (e.g. non-ASCII tensor names) and underflows when `max == 0`. This
/// version counts and cuts by `char`, so it is safe for any UTF-8 input.
fn truncate(s: &str, max: usize) -> String {
    if s.chars().count() <= max {
        s.to_string()
    } else {
        // Keep max-1 characters so the ellipsis fits in the budget;
        // saturating_sub guards max == 0.
        let head: String = s.chars().take(max.saturating_sub(1)).collect();
        format!("{head}…")
    }
}
/// Build a lookup map from per-layer (name, bit-width) pairs.
/// Later duplicates of the same name overwrite earlier entries,
/// matching `FromIterator` semantics for `HashMap`.
#[allow(dead_code)]
fn parse_layer_bits_map(pairs: &[(String, u8)]) -> HashMap<String, u8> {
    let mut map = HashMap::with_capacity(pairs.len());
    for (name, bits) in pairs {
        map.insert(name.clone(), *bits);
    }
    map
}