use anyhow::{Context, Result};
use onnx_pb::{ModelProto, NodeProto, TensorProto};
use prost::Message;
use std::collections::{HashMap, HashSet};
use std::path::Path;
const FLOAT: i32 = 1;
const INT8: i32 = 3;
const QUANTIZABLE_OPS: &[&str] = &["MatMul", "Conv", "Gemm"];
const MIN_ELEMENTS: usize = 1024;
pub fn quantize_model(input: &Path, output: &Path) -> Result<()> {
let model_bytes = std::fs::read(input).context("Failed to read ONNX model")?;
let mut model =
ModelProto::decode(&model_bytes[..]).context("Failed to decode ONNX protobuf")?;
let graph = model.graph.as_mut().context("Model has no graph")?;
let init_map: HashMap<String, usize> = graph
.initializer
.iter()
.enumerate()
.map(|(i, t)| (t.name.clone(), i))
.collect();
let mut targets = Vec::new();
for (ni, node) in graph.node.iter().enumerate() {
if !QUANTIZABLE_OPS.contains(&node.op_type.as_str()) {
continue;
}
for (ii, input_name) in node.input.iter().enumerate() {
if ii == 0 {
continue; }
if let Some(&init_idx) = init_map.get(input_name) {
let init = &graph.initializer[init_idx];
if init.data_type != FLOAT {
continue;
}
let num_elements: i64 = init.dims.iter().product();
if num_elements > 0 && num_elements as usize >= MIN_ELEMENTS {
targets.push((ni, ii, input_name.clone(), init_idx));
}
}
}
}
tracing::info!(
"Found {} quantizable weight tensors in {} nodes",
targets.len(),
graph.node.len()
);
let mut new_nodes = Vec::new();
let mut new_initializers = Vec::new();
let mut quantized_names: HashSet<String> = HashSet::new();
for (_node_idx, _input_idx, weight_name, init_idx) in &targets {
if !quantized_names.insert(weight_name.clone()) {
continue;
}
let init = &graph.initializer[*init_idx];
let float_data = extract_float_data(init)?;
let dims = &init.dims;
if dims.is_empty() {
continue;
}
let channels = dims[0] as usize;
if channels == 0 {
continue;
}
let expected_elements: usize = dims.iter().map(|&d| d.max(0) as usize).product();
if expected_elements != float_data.len() {
tracing::warn!(
"Skipping tensor '{}': shape mismatch (dims={:?}, data={})",
init.name,
dims,
float_data.len()
);
continue;
}
let channel_size = float_data.len() / channels;
let mut quantized_data = Vec::with_capacity(float_data.len());
let mut scales = Vec::with_capacity(channels);
let zero_points = vec![0i8; channels];
for ch in 0..channels {
let start = ch * channel_size;
let end = start + channel_size;
let channel_data = &float_data[start..end];
let abs_max = channel_data.iter().fold(0.0f32, |m, &v| m.max(v.abs()));
let scale = if abs_max == 0.0 { 1.0 } else { abs_max / 127.0 };
scales.push(scale);
for &val in channel_data {
let q = (val / scale).round().clamp(-128.0, 127.0) as i8;
quantized_data.push(q);
}
}
let q_name = format!("{weight_name}_quantized");
let s_name = format!("{weight_name}_scale");
let zp_name = format!("{weight_name}_zero_point");
let dq_output = format!("{weight_name}_dequantized");
new_initializers.push(TensorProto {
name: q_name.clone(),
dims: dims.clone(),
data_type: INT8,
raw_data: quantized_data.iter().map(|&v| v as u8).collect(),
..Default::default()
});
new_initializers.push(TensorProto {
name: s_name.clone(),
dims: vec![channels as i64],
data_type: FLOAT,
float_data: scales,
..Default::default()
});
new_initializers.push(TensorProto {
name: zp_name.clone(),
dims: vec![channels as i64],
data_type: INT8,
raw_data: zero_points.iter().map(|&v| v as u8).collect(),
..Default::default()
});
new_nodes.push(NodeProto {
op_type: "DequantizeLinear".into(),
input: vec![q_name, s_name, zp_name],
output: vec![dq_output.clone()],
name: format!("dequant_{}", weight_name),
attribute: vec![onnx_pb::AttributeProto {
name: "axis".into(),
i: 0, r#type: 2, ..Default::default()
}],
..Default::default()
});
quantized_names.insert(weight_name.clone());
}
for (node_idx, input_idx, weight_name, _) in &targets {
let dq_output = format!("{weight_name}_dequantized");
graph.node[*node_idx].input[*input_idx] = dq_output;
}
graph
.initializer
.retain(|t| !quantized_names.contains(&t.name));
graph.initializer.extend(new_initializers);
let mut all_nodes = new_nodes;
all_nodes.append(&mut graph.node);
graph.node = all_nodes;
let mut output_bytes = Vec::new();
model
.encode(&mut output_bytes)
.context("Failed to encode quantized model")?;
let tmp = output.with_extension("onnx.tmp");
std::fs::write(&tmp, &output_bytes).context("Failed to write quantized model")?;
std::fs::rename(&tmp, output).context("Failed to finalize quantized model")?;
let in_mb = model_bytes.len() as f64 / (1024.0 * 1024.0);
let out_mb = output_bytes.len() as f64 / (1024.0 * 1024.0);
tracing::info!(
"Quantized: {in_mb:.0}MB → {out_mb:.0}MB ({:.1}x smaller)",
in_mb / out_mb
);
Ok(())
}
fn extract_float_data(tensor: &TensorProto) -> Result<Vec<f32>> {
if !tensor.float_data.is_empty() {
return Ok(tensor.float_data.clone());
}
if !tensor.raw_data.is_empty() {
anyhow::ensure!(
tensor.raw_data.len().is_multiple_of(4),
"Tensor '{}' raw_data length {} is not aligned to 4 bytes",
tensor.name,
tensor.raw_data.len()
);
let num_floats = tensor.raw_data.len() / 4;
let mut data = Vec::with_capacity(num_floats);
for chunk in tensor.raw_data.chunks_exact(4) {
data.push(f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]));
}
return Ok(data);
}
anyhow::bail!("Tensor '{}' has no float data", tensor.name);
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_extract_float_data_from_float_data_field() {
let tensor = TensorProto {
name: "test".into(),
float_data: vec![1.0, 2.0, 3.0],
..Default::default()
};
let data = extract_float_data(&tensor).unwrap();
assert_eq!(data, vec![1.0, 2.0, 3.0]);
}
#[test]
fn test_extract_float_data_from_raw_data() {
let mut raw = Vec::new();
raw.extend_from_slice(&1.0f32.to_le_bytes());
raw.extend_from_slice(&(-2.5f32).to_le_bytes());
let tensor = TensorProto {
name: "test".into(),
raw_data: raw,
..Default::default()
};
let data = extract_float_data(&tensor).unwrap();
assert_eq!(data, vec![1.0, -2.5]);
}
#[test]
fn test_extract_float_data_empty() {
let tensor = TensorProto {
name: "empty".into(),
..Default::default()
};
assert!(extract_float_data(&tensor).is_err());
}
#[test]
fn test_symmetric_quantization_values() {
let val = 1.27f32;
let scale = val.abs() / 127.0; let q = (val / scale).round().clamp(-128.0, 127.0) as i8;
assert_eq!(q, 127);
let val2 = -1.27f32;
let q2 = (val2 / scale).round().clamp(-128.0, 127.0) as i8;
assert_eq!(q2, -127);
}
#[test]
fn test_zero_scale_handling() {
let data = vec![0.0f32; 100];
let abs_max = data.iter().fold(0.0f32, |m, &v| m.max(v.abs()));
let scale = if abs_max == 0.0 { 1.0 } else { abs_max / 127.0 };
assert_eq!(scale, 1.0);
}
}