use prost::Message;
use quantize_rs::calibration::methods::CalibrationMethod;
use quantize_rs::calibration::stats::ActivationStats;
use quantize_rs::onnx_proto::{
tensor_proto, GraphProto, ModelProto, NodeProto, OperatorSetIdProto, TensorProto,
ValueInfoProto,
};
use quantize_rs::onnx_utils::graph_builder::QdqWeightInput;
use quantize_rs::onnx_utils::SaveOptions;
use quantize_rs::quantization::{QuantConfig, QuantizedTensor, QuantizedTensorInt4, Quantizer};
use quantize_rs::*;
use std::collections::HashMap;
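
/// Builds a one-node Conv model (opset 13) with a single float initializer
/// named "weight" of the given shape.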
fn build_minimal_model(weight_data: &[f32], weight_shape: &[i64]) -> ModelProto {
ModelProto {
opset_import: vec![OperatorSetIdProto {
domain: String::new(),
version: 13,
}],
graph: Some(GraphProto {
name: "test_graph".to_string(),
input: vec![ValueInfoProto {
name: "input".to_string(),
..Default::default()
}],
output: vec![ValueInfoProto {
name: "output".to_string(),
..Default::default()
}],
initializer: vec![TensorProto {
name: "weight".to_string(),
data_type: tensor_proto::DataType::Float as i32,
dims: weight_shape.to_vec(),
float_data: weight_data.to_vec(),
..Default::default()
}],
node: vec![NodeProto {
op_type: "Conv".to_string(),
name: "conv0".to_string(),
input: vec!["input".to_string(), "weight".to_string()],
output: vec!["output".to_string()],
..Default::default()
}],
..Default::default()
}),
..Default::default()
}
}
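
/// Builds a two-node Conv chain (input -> conv1 -> conv2 -> output) with
/// float initializers "w1" and "w2".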
fn build_two_weight_model(
w1_data: &[f32],
w1_shape: &[i64],
w2_data: &[f32],
w2_shape: &[i64],
) -> ModelProto {
ModelProto {
opset_import: vec![OperatorSetIdProto {
domain: String::new(),
version: 13,
}],
graph: Some(GraphProto {
name: "test_two_weight".to_string(),
input: vec![ValueInfoProto {
name: "input".to_string(),
..Default::default()
}],
output: vec![ValueInfoProto {
name: "output".to_string(),
..Default::default()
}],
initializer: vec![
TensorProto {
name: "w1".to_string(),
data_type: tensor_proto::DataType::Float as i32,
dims: w1_shape.to_vec(),
float_data: w1_data.to_vec(),
..Default::default()
},
TensorProto {
name: "w2".to_string(),
data_type: tensor_proto::DataType::Float as i32,
dims: w2_shape.to_vec(),
float_data: w2_data.to_vec(),
..Default::default()
},
],
node: vec![
NodeProto {
op_type: "Conv".to_string(),
name: "conv1".to_string(),
input: vec!["input".to_string(), "w1".to_string()],
output: vec!["mid".to_string()],
..Default::default()
},
NodeProto {
op_type: "Conv".to_string(),
name: "conv2".to_string(),
input: vec!["mid".to_string(), "w2".to_string()],
output: vec!["output".to_string()],
..Default::default()
},
],
..Default::default()
}),
..Default::default()
}
}
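
/// Encodes the model with prost and writes it into the temp dir, returning
/// the file path.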
fn write_model_to_tempfile(
model: &ModelProto,
dir: &tempfile::TempDir,
name: &str,
) -> std::path::PathBuf {
let path = dir.path().join(name);
let mut buf = Vec::new();
model.encode(&mut buf).unwrap();
std::fs::write(&path, buf).unwrap();
path
}
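
/// INT8 per-tensor end to end: quantize the lone Conv weight, save as QDQ,
/// reload, and check graph connectivity plus the recorded scale metadata.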
#[test]
fn test_quantize_simple_model_int8() {
let weight_data: Vec<f32> = (0..16).map(|i| (i as f32 - 8.0) * 0.1).collect();
let model_proto = build_minimal_model(&weight_data, &[4, 4]);
let dir = tempfile::tempdir().unwrap();
let model_path = write_model_to_tempfile(&model_proto, &dir, "model.onnx");
let mut model = OnnxModel::load(&model_path).unwrap();
let weights = model.extract_weights();
assert_eq!(weights.len(), 1);
assert_eq!(weights[0].name, "weight");
let quantizer = Quantizer::new(QuantConfig {
bits: 8,
per_channel: false,
calibration_method: None,
..Default::default()
});
let quantized = quantizer
.quantize_tensor(&weights[0].data, weights[0].shape.clone())
.unwrap();
assert!(quantized.is_int8());
let (scales, zero_points) = quantized.get_all_scales_zero_points();
let is_pc = quantized.is_per_channel();
let qdq_data = vec![QdqWeightInput {
original_name: weights[0].name.clone(),
quantized_values: quantized.data(),
scales,
zero_points,
bits: 8,
axis: if is_pc { Some(0) } else { None },
}];
let output_path = dir.path().join("model_int8.onnx");
model.save_quantized(&qdq_data, &output_path).unwrap();
let reloaded = OnnxModel::load(&output_path).unwrap();
let report = reloaded.validate_connectivity();
assert!(
report.valid,
"Connectivity broken: {:?}",
report.broken_refs
);
let qinfo = reloaded.load_quantized_info();
assert_eq!(qinfo.len(), 1);
assert_eq!(qinfo[0].name, "weight");
assert_eq!(qinfo[0].bits, 8);
assert!(qinfo[0].scale() > 0.0);
assert_eq!(qinfo[0].scales.len(), 1, "per-tensor should have 1 scale");
}
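
/// Same pipeline at 4 bits; the reloaded metadata must report bits = 4.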
#[test]
fn test_quantize_simple_model_int4() {
let weight_data: Vec<f32> = (0..16).map(|i| (i as f32 - 8.0) * 0.1).collect();
let model_proto = build_minimal_model(&weight_data, &[4, 4]);
let dir = tempfile::tempdir().unwrap();
let model_path = write_model_to_tempfile(&model_proto, &dir, "model.onnx");
let mut model = OnnxModel::load(&model_path).unwrap();
let weights = model.extract_weights();
let quantizer = Quantizer::new(QuantConfig {
bits: 4,
per_channel: false,
calibration_method: None,
..Default::default()
});
let quantized = quantizer
.quantize_tensor(&weights[0].data, weights[0].shape.clone())
.unwrap();
assert!(quantized.is_int4());
assert_eq!(quantized.bits(), 4);
let (scales, zero_points) = quantized.get_all_scales_zero_points();
let qdq_data = vec![QdqWeightInput {
original_name: weights[0].name.clone(),
quantized_values: quantized.data(),
scales,
zero_points,
bits: 4,
axis: None,
}];
let output_path = dir.path().join("model_int4.onnx");
model.save_quantized(&qdq_data, &output_path).unwrap();
let reloaded = OnnxModel::load(&output_path).unwrap();
let report = reloaded.validate_connectivity();
assert!(
report.valid,
"Connectivity broken: {:?}",
report.broken_refs
);
let qinfo = reloaded.load_quantized_info();
assert_eq!(qinfo.len(), 1);
assert_eq!(qinfo[0].bits, 4);
}
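
/// Per-channel INT8 over two Conv weights whose rows have very different
/// ranges; each weight must carry one scale/zero-point per output channel.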
#[test]
fn test_quantize_per_channel() {
let mut w1_data = Vec::new();
for i in 0..8 {
w1_data.push(i as f32 * 0.01);
    }
    for i in 0..8 {
w1_data.push(i as f32 * 1.0);
}
let mut w2_data = Vec::new();
for i in 0..8 {
w2_data.push(i as f32 * 0.5);
}
for i in 0..8 {
w2_data.push(i as f32 * 2.0);
}
let model_proto = build_two_weight_model(&w1_data, &[2, 8], &w2_data, &[2, 8]);
let dir = tempfile::tempdir().unwrap();
let model_path = write_model_to_tempfile(&model_proto, &dir, "model.onnx");
let mut model = OnnxModel::load(&model_path).unwrap();
let weights = model.extract_weights();
assert_eq!(weights.len(), 2);
let quantizer = Quantizer::new(QuantConfig {
bits: 8,
per_channel: true,
calibration_method: None,
..Default::default()
});
let mut qdq_data = Vec::new();
for w in &weights {
let quantized = quantizer.quantize_tensor(&w.data, w.shape.clone()).unwrap();
assert!(quantized.is_per_channel());
let (scales, zero_points) = quantized.get_all_scales_zero_points();
assert_eq!(scales.len(), 2);
assert_eq!(zero_points.len(), 2);
qdq_data.push(QdqWeightInput {
original_name: w.name.clone(),
quantized_values: quantized.data(),
scales,
zero_points,
bits: 8,
axis: Some(0),
});
}
let output_path = dir.path().join("model_pc.onnx");
model.save_quantized(&qdq_data, &output_path).unwrap();
let reloaded = OnnxModel::load(&output_path).unwrap();
let report = reloaded.validate_connectivity();
assert!(
report.valid,
"Connectivity broken: {:?}",
report.broken_refs
);
}
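
/// Round-trip accuracy bounds: INT8 MSE below 1e-4, INT4 below 1e-2, and
/// INT4 error strictly above INT8 error on the same data.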
#[test]
fn test_round_trip_quantization_accuracy() {
let data: Vec<f32> = (0..1000).map(|i| (i as f32 / 999.0) * 2.0 - 1.0).collect();
let shape = vec![1000];
let q8 = QuantizedTensor::from_f32(&data, shape.clone()).unwrap();
let mse_8 = q8.quantization_error(&data);
assert!(mse_8 < 1e-4, "INT8 MSE too high: {}", mse_8);
let q4 = QuantizedTensorInt4::from_f32(&data, shape).unwrap();
let mse_4 = q4.quantization_error(&data);
assert!(mse_4 < 0.01, "INT4 MSE too high: {}", mse_4);
assert!(mse_4 > mse_8, "INT4 error should exceed INT8 error");
assert_eq!(q8.to_f32().len(), 1000);
assert_eq!(q4.to_f32().len(), 1000);
}
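
/// Each failure mode maps to its QuantizeError variant: InvalidTensor,
/// UnsupportedConfig, ModelLoad, and Config.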
#[test]
fn test_error_variants_are_correct() {
let result = QuantizedTensor::from_f32(&[], vec![0]);
assert!(result.is_err());
assert!(
matches!(result.unwrap_err(), QuantizeError::InvalidTensor { .. }),
"expected InvalidTensor for empty tensor"
);
let result = QuantizedTensor::from_f32(&[1.0, 2.0], vec![3]);
assert!(matches!(
result.unwrap_err(),
QuantizeError::InvalidTensor { .. }
));
let result = QuantizedTensor::from_f32_per_channel(&[1.0], vec![]);
assert!(matches!(
result.unwrap_err(),
QuantizeError::InvalidTensor { .. }
));
let quantizer = Quantizer::new(QuantConfig {
bits: 3,
per_channel: false,
calibration_method: None,
..Default::default()
});
let result = quantizer.quantize_tensor(&[1.0, 2.0], vec![2]);
assert!(matches!(
result.unwrap_err(),
QuantizeError::UnsupportedConfig { .. }
));
let result = OnnxModel::load("/nonexistent/path/model.onnx");
assert!(matches!(
result.unwrap_err(),
QuantizeError::ModelLoad { .. }
));
    let result: Result<CalibrationMethod, _> = "invalid".parse();
assert!(matches!(result.unwrap_err(), QuantizeError::Config { .. }));
let cfg = Config::from_yaml("bits: 3").unwrap();
assert!(matches!(
cfg.validate().unwrap_err(),
QuantizeError::Config { .. }
));
}
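
/// Mixed precision via QuantConfig::layer_bits: "w1" is forced to INT4
/// while "w2" falls back to the global 8-bit default.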
#[test]
fn test_mixed_precision_quantization() {
let w1_data: Vec<f32> = (0..16).map(|i| (i as f32 - 8.0) * 0.1).collect();
let w2_data: Vec<f32> = (0..16).map(|i| (i as f32 - 8.0) * 0.5).collect();
let model_proto = build_two_weight_model(&w1_data, &[4, 4], &w2_data, &[4, 4]);
let dir = tempfile::tempdir().unwrap();
let model_path = write_model_to_tempfile(&model_proto, &dir, "model.onnx");
let mut model = OnnxModel::load(&model_path).unwrap();
let weights = model.extract_weights();
assert_eq!(weights.len(), 2);
let mut layer_bits = HashMap::new();
layer_bits.insert("w1".to_string(), 4u8);
let config = QuantConfig {
bits: 8,
per_channel: false,
calibration_method: None,
layer_bits,
..Default::default()
};
let mut qdq_data = Vec::new();
for w in &weights {
let layer_config = QuantConfig {
bits: config.bits_for_layer(&w.name),
..config.clone()
};
let quantized = Quantizer::new(layer_config)
.quantize_tensor(&w.data, w.shape.clone())
.unwrap();
let (scales, zero_points) = quantized.get_all_scales_zero_points();
let bits_used = quantized.bits();
if w.name == "w1" {
assert!(quantized.is_int4(), "w1 should be INT4");
assert_eq!(bits_used, 4);
} else {
assert!(quantized.is_int8(), "w2 should be INT8");
assert_eq!(bits_used, 8);
}
qdq_data.push(QdqWeightInput {
original_name: w.name.clone(),
quantized_values: quantized.data(),
scales,
zero_points,
bits: bits_used,
axis: None,
});
}
let output_path = dir.path().join("model_mixed.onnx");
model.save_quantized(&qdq_data, &output_path).unwrap();
let reloaded = OnnxModel::load(&output_path).unwrap();
let report = reloaded.validate_connectivity();
assert!(
report.valid,
"Connectivity broken: {:?}",
report.broken_refs
);
let qinfo = reloaded.load_quantized_info();
assert_eq!(qinfo.len(), 2);
let w1_info = qinfo.iter().find(|q| q.name == "w1").expect("w1 not found");
let w2_info = qinfo.iter().find(|q| q.name == "w2").expect("w2 not found");
assert_eq!(w1_info.bits, 4, "w1 bits should be 4 in metadata");
assert_eq!(w2_info.bits, 8, "w2 bits should be 8 in metadata");
assert!(w1_info.scale() > 0.0);
assert!(w2_info.scale() > 0.0);
}
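
/// layer_bits parsing from YAML and TOML, plus rejection of an unsupported
/// width (3 bits) at validation time.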
#[test]
fn test_config_layer_bits() {
let yaml = r#"
bits: 8
models:
- input: model.onnx
output: model_mixed.onnx
layer_bits:
conv1.weight: 4
head.weight: 8
"#;
let config = Config::from_yaml(yaml).unwrap();
config.validate().unwrap();
let model_cfg = &config.models[0];
let lb = config.get_layer_bits(model_cfg);
assert_eq!(lb.get("conv1.weight"), Some(&4u8));
assert_eq!(lb.get("head.weight"), Some(&8u8));
let toml_str = r#"
bits = 8
[[models]]
input = "model.onnx"
output = "model_mixed.onnx"
[models.layer_bits]
"conv1.weight" = 4
"head.weight" = 8
"#;
let config = Config::from_toml(toml_str).unwrap();
config.validate().unwrap();
let lb = config.get_layer_bits(&config.models[0]);
assert_eq!(lb.get("conv1.weight"), Some(&4u8));
let yaml_bad = r#"
bits: 8
models:
- input: model.onnx
output: model_mixed.onnx
layer_bits:
conv1.weight: 3
"#;
let cfg = Config::from_yaml(yaml_bad).unwrap();
assert!(matches!(
cfg.validate().unwrap_err(),
QuantizeError::Config { .. }
));
}
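
/// Builds a Conv -> Conv -> Gemm model with six initializers (three weights,
/// three small biases) to exercise the filtering options.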
fn build_multilayer_model() -> ModelProto {
let make_tensor = |name: &str, dims: &[i64], n: usize| TensorProto {
name: name.to_string(),
data_type: tensor_proto::DataType::Float as i32,
dims: dims.to_vec(),
float_data: (0..n).map(|i| (i as f32) / (n as f32) - 0.5).collect(),
..Default::default()
};
ModelProto {
opset_import: vec![OperatorSetIdProto {
domain: String::new(),
version: 13,
}],
graph: Some(GraphProto {
name: "multilayer".to_string(),
input: vec![ValueInfoProto {
name: "input".to_string(),
..Default::default()
}],
output: vec![ValueInfoProto {
name: "output".to_string(),
..Default::default()
}],
initializer: vec![
make_tensor("conv1.weight", &[8, 3, 3, 3], 216),
make_tensor("conv1.bias", &[8], 8),
make_tensor("conv2.weight", &[16, 8, 3, 3], 1152),
make_tensor("conv2.bias", &[16], 16),
make_tensor("fc.weight", &[10, 144], 1440),
make_tensor("fc.bias", &[10], 10),
],
node: vec![
NodeProto {
op_type: "Conv".to_string(),
name: "conv1".to_string(),
input: vec![
"input".to_string(),
"conv1.weight".to_string(),
"conv1.bias".to_string(),
],
output: vec!["conv1_out".to_string()],
..Default::default()
},
NodeProto {
op_type: "Conv".to_string(),
name: "conv2".to_string(),
input: vec![
"conv1_out".to_string(),
"conv2.weight".to_string(),
"conv2.bias".to_string(),
],
output: vec!["conv2_out".to_string()],
..Default::default()
},
NodeProto {
op_type: "Gemm".to_string(),
name: "fc".to_string(),
input: vec![
"conv2_out".to_string(),
"fc.weight".to_string(),
"fc.bias".to_string(),
],
output: vec!["output".to_string()],
..Default::default()
},
],
..Default::default()
}),
..Default::default()
}
}
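
/// Quantizes every weight that passes config.should_quantize, honoring
/// per-layer bit overrides, and collects the QDQ inputs for saving.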
fn quantize_weights(
config: &QuantConfig,
weights: &[quantize_rs::WeightTensor],
) -> Vec<QdqWeightInput> {
weights
.iter()
.filter(|w| config.should_quantize(&w.name, w.data.len()))
.map(|w| {
let layer_config = QuantConfig {
bits: config.bits_for_layer(&w.name),
..config.clone()
};
let quantized = Quantizer::new(layer_config)
.quantize_tensor(&w.data, w.shape.clone())
.unwrap();
let (scales, zero_points) = quantized.get_all_scales_zero_points();
let bits = quantized.bits();
let is_pc = quantized.is_per_channel();
QdqWeightInput {
original_name: w.name.clone(),
quantized_values: quantized.data(),
scales,
zero_points,
bits,
axis: if is_pc { Some(0) } else { None },
}
})
.collect()
}
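
/// min_elements = 100 keeps the three large weights and skips the three
/// biases (8, 16, and 10 elements).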
#[test]
fn test_multilayer_min_elements() {
let model_proto = build_multilayer_model();
let dir = tempfile::tempdir().unwrap();
let model_path = write_model_to_tempfile(&model_proto, &dir, "model.onnx");
let mut model = OnnxModel::load(&model_path).unwrap();
let weights = model.extract_weights();
assert_eq!(weights.len(), 6);
let config = QuantConfig {
bits: 8,
min_elements: 100,
..Default::default()
};
let qdq_data = quantize_weights(&config, &weights);
assert_eq!(qdq_data.len(), 3, "expected 3 large weights quantized");
let names: Vec<&str> = qdq_data.iter().map(|q| q.original_name.as_str()).collect();
assert!(names.contains(&"conv1.weight"));
assert!(names.contains(&"conv2.weight"));
assert!(names.contains(&"fc.weight"));
assert!(!names.contains(&"conv1.bias"));
assert!(!names.contains(&"conv2.bias"));
assert!(!names.contains(&"fc.bias"));
let output_path = dir.path().join("model_min_elements.onnx");
model.save_quantized(&qdq_data, &output_path).unwrap();
let reloaded = OnnxModel::load(&output_path).unwrap();
assert!(reloaded.validate_connectivity().valid);
assert_eq!(reloaded.load_quantized_info().len(), 3);
}
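
/// excluded_layers drops "conv1.weight" and "fc.bias" by name, leaving
/// four quantized tensors.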
#[test]
fn test_multilayer_excluded_layers() {
let model_proto = build_multilayer_model();
let dir = tempfile::tempdir().unwrap();
let model_path = write_model_to_tempfile(&model_proto, &dir, "model.onnx");
let mut model = OnnxModel::load(&model_path).unwrap();
let weights = model.extract_weights();
let config = QuantConfig {
bits: 8,
excluded_layers: vec!["conv1.weight".to_string(), "fc.bias".to_string()],
..Default::default()
};
let qdq_data = quantize_weights(&config, &weights);
assert_eq!(qdq_data.len(), 4, "expected 4 weights after exclusions");
let names: Vec<&str> = qdq_data.iter().map(|q| q.original_name.as_str()).collect();
assert!(
!names.contains(&"conv1.weight"),
"conv1.weight should be excluded"
);
assert!(!names.contains(&"fc.bias"), "fc.bias should be excluded");
let output_path = dir.path().join("model_excluded.onnx");
model.save_quantized(&qdq_data, &output_path).unwrap();
let reloaded = OnnxModel::load(&output_path).unwrap();
assert!(reloaded.validate_connectivity().valid);
assert_eq!(reloaded.load_quantized_info().len(), 4);
}
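
/// All six initializers quantized per-channel at 8 bits; connectivity and
/// per-weight metadata must survive the save/load round trip.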
#[test]
fn test_multilayer_full_round_trip() {
let model_proto = build_multilayer_model();
let dir = tempfile::tempdir().unwrap();
let model_path = write_model_to_tempfile(&model_proto, &dir, "model.onnx");
let mut model = OnnxModel::load(&model_path).unwrap();
let weights = model.extract_weights();
assert_eq!(weights.len(), 6);
let config = QuantConfig {
bits: 8,
per_channel: true,
..Default::default()
};
let qdq_data = quantize_weights(&config, &weights);
assert_eq!(qdq_data.len(), 6, "all 6 weights should be quantized");
let output_path = dir.path().join("model_full.onnx");
model.save_quantized(&qdq_data, &output_path).unwrap();
let reloaded = OnnxModel::load(&output_path).unwrap();
let report = reloaded.validate_connectivity();
assert!(
report.valid,
"Connectivity broken: {:?}",
report.broken_refs
);
let qinfo = reloaded.load_quantized_info();
assert_eq!(qinfo.len(), 6, "all 6 weights should appear in metadata");
for info in &qinfo {
assert!(
info.scale() > 0.0,
"scale must be positive for {}",
info.name
);
assert_eq!(info.bits, 8);
}
}
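
/// Compression accounting: INT8 payloads must shrink below 0.5x of FP32,
/// and INT4 below 0.3x (and below INT8).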
#[test]
fn test_multilayer_compression_ratio() {
let model_proto = build_multilayer_model();
let dir = tempfile::tempdir().unwrap();
let model_path = write_model_to_tempfile(&model_proto, &dir, "model.onnx");
let weights = {
let m = OnnxModel::load(&model_path).unwrap();
m.extract_weights()
};
let original_bytes: usize = weights.iter().map(|w| w.data.len() * 4).sum();
let cfg8 = QuantConfig {
bits: 8,
..Default::default()
};
let bytes_int8: usize = weights
.iter()
.map(|w| {
Quantizer::new(cfg8.clone())
.quantize_tensor(&w.data, w.shape.clone())
.unwrap()
.size_bytes()
})
.sum();
let cfg4 = QuantConfig {
bits: 4,
..Default::default()
};
let bytes_int4: usize = weights
.iter()
.map(|w| {
Quantizer::new(cfg4.clone())
.quantize_tensor(&w.data, w.shape.clone())
.unwrap()
.size_bytes()
})
.sum();
assert!(
bytes_int8 < original_bytes,
"INT8 ({bytes_int8} B) should be smaller than FP32 ({original_bytes} B)"
);
assert!(
bytes_int4 < bytes_int8,
"INT4 ({bytes_int4} B) should be smaller than INT8 ({bytes_int8} B)"
);
let ratio8 = bytes_int8 as f64 / original_bytes as f64;
let ratio4 = bytes_int4 as f64 / original_bytes as f64;
assert!(ratio8 < 0.5, "INT8 ratio {ratio8:.2} should be < 0.5");
assert!(ratio4 < 0.3, "INT4 ratio {ratio4:.2} should be < 0.3");
}
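
/// A weight listed in both graph.input and graph.initializer: after the QDQ
/// transform, the weight must no longer appear in graph.input.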
#[test]
fn test_dual_input_initializer_model() {
let weight_data: Vec<f32> = (0..16).map(|i| (i as f32 - 8.0) * 0.1).collect();
let model_proto = ModelProto {
opset_import: vec![OperatorSetIdProto {
domain: String::new(),
version: 13,
}],
graph: Some(GraphProto {
name: "dual_input".to_string(),
input: vec![
ValueInfoProto {
name: "input".to_string(),
..Default::default()
},
ValueInfoProto {
name: "weight".to_string(),
..Default::default()
},
],
output: vec![ValueInfoProto {
name: "output".to_string(),
..Default::default()
}],
initializer: vec![TensorProto {
name: "weight".to_string(),
data_type: tensor_proto::DataType::Float as i32,
dims: vec![4, 4],
float_data: weight_data.clone(),
..Default::default()
}],
node: vec![NodeProto {
op_type: "Conv".to_string(),
name: "conv0".to_string(),
input: vec!["input".to_string(), "weight".to_string()],
output: vec!["output".to_string()],
..Default::default()
}],
..Default::default()
}),
..Default::default()
};
let dir = tempfile::tempdir().unwrap();
let model_path = write_model_to_tempfile(&model_proto, &dir, "dual.onnx");
let mut model = OnnxModel::load(&model_path).unwrap();
let info = model.info();
assert_eq!(
info.inputs.len(),
2,
"model should have 2 inputs (data + weight)"
);
let weights = model.extract_weights();
assert_eq!(weights.len(), 1);
let quantizer = Quantizer::new(QuantConfig {
bits: 8,
..Default::default()
});
let quantized = quantizer
.quantize_tensor(&weights[0].data, weights[0].shape.clone())
.unwrap();
let (scales, zero_points) = quantized.get_all_scales_zero_points();
let qdq_data = vec![QdqWeightInput {
original_name: weights[0].name.clone(),
quantized_values: quantized.data(),
scales,
zero_points,
bits: 8,
axis: None,
}];
let output_path = dir.path().join("dual_int8.onnx");
model.save_quantized(&qdq_data, &output_path).unwrap();
let reloaded = OnnxModel::load(&output_path).unwrap();
let report = reloaded.validate_connectivity();
assert!(
report.valid,
"Connectivity broken: {:?}",
report.broken_refs
);
let reloaded_info = reloaded.info();
assert_eq!(
reloaded_info.inputs.len(),
1,
"QDQ transform should remove weight from graph.input; got {:?}",
reloaded_info.inputs
);
assert_eq!(reloaded_info.inputs[0], "input");
}
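
/// Opt-in smoke test against a real ONNX model supplied via the
/// QUANTIZE_RS_TEST_MODEL environment variable (INT8 path).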
#[test]
#[ignore = "set QUANTIZE_RS_TEST_MODEL=/path/to/model.onnx to enable"]
fn test_real_model_int8() {
let path = std::env::var("QUANTIZE_RS_TEST_MODEL")
.expect("QUANTIZE_RS_TEST_MODEL must point to an ONNX file");
let mut model = OnnxModel::load(&path).expect("failed to load model");
let weights = model.extract_weights();
assert!(!weights.is_empty(), "model has no extractable weights");
let config = QuantConfig {
bits: 8,
per_channel: false,
min_elements: 128,
..Default::default()
};
let qdq_data = quantize_weights(&config, &weights);
assert!(
!qdq_data.is_empty(),
"no weights passed the min_elements=128 filter"
);
let dir = tempfile::tempdir().unwrap();
let output_path = dir.path().join("model_int8.onnx");
model
.save_quantized(&qdq_data, &output_path)
.expect("save failed");
let reloaded = OnnxModel::load(&output_path).unwrap();
let report = reloaded.validate_connectivity();
assert!(
report.valid,
"Connectivity broken after INT8 quantization: {:?}",
report.broken_refs
);
let qinfo = reloaded.load_quantized_info();
assert_eq!(
qinfo.len(),
qdq_data.len(),
"metadata weight count mismatch"
);
for info in &qinfo {
assert_eq!(info.bits, 8);
assert!(info.scale() > 0.0);
}
}
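
/// Opt-in INT4 counterpart of the real-model smoke test above.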
#[test]
#[ignore = "set QUANTIZE_RS_TEST_MODEL=/path/to/model.onnx to enable"]
fn test_real_model_int4() {
let path = std::env::var("QUANTIZE_RS_TEST_MODEL")
.expect("QUANTIZE_RS_TEST_MODEL must point to an ONNX file");
let mut model = OnnxModel::load(&path).expect("failed to load model");
let weights = model.extract_weights();
let config = QuantConfig {
bits: 4,
per_channel: false,
min_elements: 128,
..Default::default()
};
let qdq_data = quantize_weights(&config, &weights);
let dir = tempfile::tempdir().unwrap();
let output_path = dir.path().join("model_int4.onnx");
model
.save_quantized(&qdq_data, &output_path)
.expect("save failed");
let reloaded = OnnxModel::load(&output_path).unwrap();
let report = reloaded.validate_connectivity();
assert!(
report.valid,
"Connectivity broken after INT4 quantization: {:?}",
report.broken_refs
);
let qinfo = reloaded.load_quantized_info();
assert_eq!(qinfo.len(), qdq_data.len());
for info in &qinfo {
assert_eq!(info.bits, 4);
assert!(info.scale() > 0.0);
}
}
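
/// Config parsing from YAML and TOML, plus the rejection paths: malformed
/// input, an unsupported bit width (16), and an empty model path.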
#[test]
fn test_config_round_trip() {
let yaml = r#"
bits: 8
per_channel: true
models:
- input: model.onnx
output: model_int8.onnx
"#;
let config = Config::from_yaml(yaml).unwrap();
assert_eq!(config.bits, 8);
assert!(config.per_channel);
config.validate().unwrap();
let toml_str = r#"
bits = 4
per_channel = false
[[models]]
input = "a.onnx"
output = "b.onnx"
"#;
let config = Config::from_toml(toml_str).unwrap();
assert_eq!(config.bits, 4);
config.validate().unwrap();
let result = Config::from_yaml("bits: [invalid");
assert!(matches!(result.unwrap_err(), QuantizeError::Config { .. }));
let result = Config::from_toml("bits = [invalid");
assert!(matches!(result.unwrap_err(), QuantizeError::Config { .. }));
let cfg = Config::from_yaml("bits: 16").unwrap();
assert!(matches!(
cfg.validate().unwrap_err(),
QuantizeError::Config { .. }
));
let yaml_bad = r#"
bits: 8
models:
- input: ""
output: "out.onnx"
"#;
let cfg = Config::from_yaml(yaml_bad).unwrap();
assert!(matches!(
cfg.validate().unwrap_err(),
QuantizeError::Config { .. }
));
}
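
/// Activation stats narrower than the weight range must shrink the scale
/// relative to plain min/max quantization of the same tensor.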
#[test]
fn test_calibrated_quantization_uses_stats() {
let weight_data: Vec<f32> = (0..64).map(|i| (i as f32 - 32.0) / 32.0).collect();
let shape = vec![4, 16];
let config_uncalibrated = QuantConfig {
bits: 8,
..Default::default()
};
let q_uncalibrated = Quantizer::new(config_uncalibrated);
let result_uncalibrated = q_uncalibrated
.quantize_tensor_with_name("layer0.weight", &weight_data, shape.clone())
.unwrap();
let activation_data: Vec<f32> = (0..1000).map(|i| (i as f32 - 500.0) / 1000.0).collect();
let stats = ActivationStats::from_data(&activation_data);
let mut stats_map = HashMap::new();
stats_map.insert("layer0.weight".to_string(), stats);
let config_calibrated = QuantConfig {
bits: 8,
..Default::default()
};
let q_calibrated = Quantizer::with_calibration(config_calibrated, stats_map);
let result_calibrated = q_calibrated
.quantize_tensor_with_name("layer0.weight", &weight_data, shape.clone())
.unwrap();
let (scales_uncal, _) = result_uncalibrated.get_all_scales_zero_points();
let (scales_cal, _) = result_calibrated.get_all_scales_zero_points();
assert_eq!(scales_uncal.len(), 1);
assert_eq!(scales_cal.len(), 1);
assert!(
scales_cal[0] < scales_uncal[0],
"Calibrated scale ({}) should be < uncalibrated scale ({})",
scales_cal[0],
scales_uncal[0],
);
}
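
/// Explicit MinMax calibration with injected outliers must still produce a
/// finite, bounded quantization error.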
#[test]
fn test_calibrated_quantization_with_method() {
let weight_data: Vec<f32> = (0..100).map(|i| (i as f32 - 50.0) * 0.1).collect();
let shape = vec![10, 10];
let mut activation_data: Vec<f32> = (0..500).map(|i| (i as f32 - 250.0) / 250.0).collect();
    // Outliers widen the observed range seen by MinMax calibration.
    activation_data.push(10.0);
    activation_data.push(-10.0);
    let stats = ActivationStats::from_data(&activation_data);
let mut stats_map = HashMap::new();
stats_map.insert("conv.weight".to_string(), stats);
let config = QuantConfig {
bits: 8,
calibration_method: Some(CalibrationMethod::MinMax),
..Default::default()
};
let quantizer = Quantizer::with_calibration(config, stats_map);
let result = quantizer
.quantize_tensor_with_name("conv.weight", &weight_data, shape)
.unwrap();
assert!(result.is_int8());
let error = result.quantization_error(&weight_data);
assert!(error.is_finite());
assert!(
error < 10.0,
"Quantization error unexpectedly high: {}",
error
);
}
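
/// Stats registered under a different layer name are ignored: the calibrated
/// quantizer must fall back to plain min/max exactly.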
#[test]
fn test_calibrated_quantization_fallback_no_stats() {
let weight_data: Vec<f32> = (0..32).map(|i| (i as f32 - 16.0) * 0.5).collect();
let shape = vec![4, 8];
let stats = ActivationStats::from_data(&[0.0, 1.0]);
let mut stats_map = HashMap::new();
stats_map.insert("other_layer.weight".to_string(), stats);
let config = QuantConfig {
bits: 8,
..Default::default()
};
let q_calibrated = Quantizer::with_calibration(config.clone(), stats_map);
let result_calibrated = q_calibrated
.quantize_tensor_with_name("conv.weight", &weight_data, shape.clone())
.unwrap();
let q_uncalibrated = Quantizer::new(config);
let result_uncalibrated = q_uncalibrated
.quantize_tensor_with_name("conv.weight", &weight_data, shape)
.unwrap();
let (scales_cal, zps_cal) = result_calibrated.get_all_scales_zero_points();
let (scales_uncal, zps_uncal) = result_uncalibrated.get_all_scales_zero_points();
assert_eq!(
scales_cal, scales_uncal,
"Fallback should match uncalibrated"
);
assert_eq!(zps_cal, zps_uncal, "Fallback should match uncalibrated");
}
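
/// Calibration combined with 4-bit quantization yields an INT4 tensor with
/// a finite reconstruction error.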
#[test]
fn test_calibrated_quantization_int4() {
let weight_data: Vec<f32> = (0..48).map(|i| (i as f32 - 24.0) / 24.0).collect();
let shape = vec![6, 8];
let activation_data: Vec<f32> = (0..200).map(|i| (i as f32 - 100.0) / 100.0).collect();
let stats = ActivationStats::from_data(&activation_data);
let mut stats_map = HashMap::new();
stats_map.insert("fc.weight".to_string(), stats);
let config = QuantConfig {
bits: 4,
..Default::default()
};
let quantizer = Quantizer::with_calibration(config, stats_map);
let result = quantizer
.quantize_tensor_with_name("fc.weight", &weight_data, shape)
.unwrap();
assert!(result.is_int4());
assert_eq!(result.bits(), 4);
let error = result.quantization_error(&weight_data);
assert!(error.is_finite());
}
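
/// Calibrated quantization through the full save/reload pipeline, ending
/// with valid connectivity and a single metadata entry.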
#[test]
fn test_calibrated_full_pipeline() {
let weight_data: Vec<f32> = (0..16).map(|i| (i as f32 - 8.0) * 0.1).collect();
let model = build_minimal_model(&weight_data, &[4, 4]);
let dir = tempfile::tempdir().unwrap();
let input_path = write_model_to_tempfile(&model, &dir, "calib_input.onnx");
let output_path = dir.path().join("calib_output.onnx");
let mut loaded = OnnxModel::load(&input_path).unwrap();
let weights = loaded.extract_weights();
assert_eq!(weights.len(), 1);
let activation_data: Vec<f32> = (0..500).map(|i| (i as f32 - 250.0) / 500.0).collect();
let stats = ActivationStats::from_data(&activation_data);
let mut stats_map = HashMap::new();
stats_map.insert("weight".to_string(), stats);
let config = QuantConfig {
bits: 8,
..Default::default()
};
let quantizer = Quantizer::with_calibration(config, stats_map);
let mut quantized_data = Vec::new();
for weight in &weights {
let quantized = quantizer
.quantize_tensor_with_name(&weight.name, &weight.data, weight.shape.clone())
.unwrap();
let (scales, zero_points) = quantized.get_all_scales_zero_points();
let is_per_channel = quantized.is_per_channel();
quantized_data.push(QdqWeightInput {
original_name: weight.name.clone(),
quantized_values: quantized.data(),
scales,
zero_points,
bits: quantized.bits(),
axis: if is_per_channel { Some(0) } else { None },
});
}
loaded
.save_quantized(&quantized_data, &output_path)
.unwrap();
let reloaded = OnnxModel::load(&output_path).unwrap();
let report = reloaded.validate_connectivity();
assert!(
report.valid,
"Connectivity broken: {:?}",
report.broken_refs
);
let qdq_info = reloaded.load_quantized_info();
assert_eq!(qdq_info.len(), 1);
assert_eq!(qdq_info[0].name, "weight");
assert!(qdq_info[0].scale() > 0.0);
}
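
/// Decodes a saved file back into a raw ModelProto for proto-level checks.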
fn reload_proto(path: &std::path::Path) -> ModelProto {
let buf = std::fs::read(path).unwrap();
ModelProto::decode(&buf[..]).unwrap()
}
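
/// Returns the version of the default (empty-domain) opset import, or 0.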
fn get_default_opset(proto: &ModelProto) -> i64 {
proto
.opset_import
.iter()
.find(|o| o.domain.is_empty())
.map(|o| o.version)
.unwrap_or(0)
}
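
/// Native INT4 save path: the quantized tensor must use DataType::Int4 with
/// nibble-packed raw_data, and the default opset must be bumped to >= 21.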
#[test]
fn test_native_int4_uses_int4_data_type_and_packs_raw_data() {
let weight_data: Vec<f32> = (0..16).map(|i| (i as f32 - 8.0) * 0.1).collect();
let model_proto = build_minimal_model(&weight_data, &[4, 4]);
let dir = tempfile::tempdir().unwrap();
let model_path = write_model_to_tempfile(&model_proto, &dir, "model.onnx");
let mut model = OnnxModel::load(&model_path).unwrap();
let weights = model.extract_weights();
let quantizer = Quantizer::new(QuantConfig {
bits: 4,
..Default::default()
});
let quantized = quantizer
.quantize_tensor(&weights[0].data, weights[0].shape.clone())
.unwrap();
let (scales, zero_points) = quantized.get_all_scales_zero_points();
let qdq_data = vec![QdqWeightInput {
original_name: weights[0].name.clone(),
quantized_values: quantized.data(),
scales,
zero_points,
bits: 4,
axis: None,
}];
let output_path = dir.path().join("model_native_int4.onnx");
model
.save_quantized_with_options(
&qdq_data,
&output_path,
SaveOptions::default().with_native_int4(true),
)
.unwrap();
let proto = reload_proto(&output_path);
assert!(
get_default_opset(&proto) >= 21,
"native INT4 should bump opset to >= 21"
);
let graph = proto.graph.as_ref().unwrap();
let q_init = graph
.initializer
.iter()
.find(|i| i.name == "weight_quantized")
.expect("weight_quantized initializer missing");
assert_eq!(
q_init.data_type,
tensor_proto::DataType::Int4 as i32,
"weight tensor should be DataType::Int4"
);
    assert_eq!(
        q_init.dims,
        vec![4, 4],
        "dims should be the logical shape, not the packed byte count"
    );
assert_eq!(q_init.raw_data.len(), 8, "16 values → 8 packed bytes");
let zp_init = graph
.initializer
.iter()
.find(|i| i.name == "weight_zp")
.expect("weight_zp initializer missing");
assert_eq!(
zp_init.data_type,
tensor_proto::DataType::Int4 as i32,
"zero-point should also be Int4"
);
let reloaded = OnnxModel::load(&output_path).unwrap();
let info = reloaded.load_quantized_info();
assert_eq!(info.len(), 1);
assert_eq!(info[0].bits, 4);
assert_eq!(info[0].scales.len(), 1, "per-tensor → 1 scale");
assert_eq!(info[0].zero_points.len(), 1);
assert!(info[0].scale() > 0.0);
let report = reloaded.validate_connectivity();
assert!(
report.valid,
"Connectivity broken: {:?}",
report.broken_refs
);
}
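
/// Per-channel native INT4: every per-channel scale and zero-point must be
/// written on save and recovered on reload.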
#[test]
fn test_native_int4_per_channel_packs_all_zero_points() {
let mut data = Vec::new();
for i in 0..16 {
data.push((i as f32) * 0.01);
}
for i in 0..16 {
data.push((i as f32) * 0.5);
}
let model_proto = build_minimal_model(&data, &[2, 16]);
let dir = tempfile::tempdir().unwrap();
let model_path = write_model_to_tempfile(&model_proto, &dir, "model.onnx");
let mut model = OnnxModel::load(&model_path).unwrap();
let weights = model.extract_weights();
let quantizer = Quantizer::new(QuantConfig {
bits: 4,
per_channel: true,
..Default::default()
});
let quantized = quantizer
.quantize_tensor(&weights[0].data, weights[0].shape.clone())
.unwrap();
let (scales, zero_points) = quantized.get_all_scales_zero_points();
assert_eq!(scales.len(), 2);
assert_eq!(zero_points.len(), 2);
let qdq_data = vec![QdqWeightInput {
original_name: weights[0].name.clone(),
quantized_values: quantized.data(),
scales: scales.clone(),
zero_points: zero_points.clone(),
bits: 4,
axis: Some(0),
}];
let output_path = dir.path().join("model_native_pc.onnx");
model
.save_quantized_with_options(
&qdq_data,
&output_path,
SaveOptions::default().with_native_int4(true),
)
.unwrap();
let reloaded = OnnxModel::load(&output_path).unwrap();
let info = reloaded.load_quantized_info();
assert_eq!(info.len(), 1);
let info = &info[0];
assert!(info.is_per_channel(), "should be per-channel");
assert_eq!(
info.scales.len(),
2,
"per-channel: expected 2 scales, got {}",
info.scales.len()
);
assert_eq!(info.zero_points.len(), 2);
for (i, (&s_out, &s_in)) in info.scales.iter().zip(scales.iter()).enumerate() {
assert!(
(s_out - s_in).abs() < 1e-6,
"scale[{}] mismatch: reloaded {} vs saved {}",
i,
s_out,
s_in
);
}
for (i, (&zp_out, &zp_in)) in info.zero_points.iter().zip(zero_points.iter()).enumerate() {
assert_eq!(
zp_out, zp_in,
"zero_point[{}] mismatch: reloaded {} vs saved {}",
i, zp_out, zp_in
);
}
}
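
/// The default QDQ path must likewise preserve all per-channel scales and
/// zero-points through load_quantized_info.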
#[test]
fn test_load_quantized_info_preserves_all_per_channel_values() {
let mut data = Vec::new();
data.extend(std::iter::repeat_n(0.01_f32, 8));
data.extend(std::iter::repeat_n(5.0_f32, 8));
let model_proto = build_minimal_model(&data, &[2, 8]);
let dir = tempfile::tempdir().unwrap();
let model_path = write_model_to_tempfile(&model_proto, &dir, "model.onnx");
let mut model = OnnxModel::load(&model_path).unwrap();
let weights = model.extract_weights();
let quantizer = Quantizer::new(QuantConfig {
bits: 8,
per_channel: true,
..Default::default()
});
let quantized = quantizer
.quantize_tensor(&weights[0].data, weights[0].shape.clone())
.unwrap();
let (scales, zero_points) = quantized.get_all_scales_zero_points();
assert_eq!(scales.len(), 2);
assert!(
scales[0] < scales[1],
"test precondition: scales[0] ({}) should be < scales[1] ({})",
scales[0],
scales[1]
);
let qdq_data = vec![QdqWeightInput {
original_name: weights[0].name.clone(),
quantized_values: quantized.data(),
scales: scales.clone(),
zero_points: zero_points.clone(),
bits: 8,
axis: Some(0),
}];
let output_path = dir.path().join("model_pc_check.onnx");
model.save_quantized(&qdq_data, &output_path).unwrap();
let reloaded = OnnxModel::load(&output_path).unwrap();
let info = reloaded.load_quantized_info();
assert_eq!(info.len(), 1);
let info = &info[0];
assert_eq!(info.scales.len(), 2, "both per-channel scales must load");
assert_eq!(info.zero_points.len(), 2);
assert!(info.is_per_channel());
for (a, b) in info.scales.iter().zip(scales.iter()) {
assert!((a - b).abs() < 1e-6);
}
assert_eq!(info.zero_points, zero_points);
}
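
/// Without the native flag, 4-bit values are widened to one INT8 byte per
/// element so the saved model does not require opset 21.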
#[test]
fn test_int4_without_native_flag_still_widens_to_int8() {
let weight_data: Vec<f32> = (0..16).map(|i| (i as f32 - 8.0) * 0.1).collect();
let model_proto = build_minimal_model(&weight_data, &[4, 4]);
let dir = tempfile::tempdir().unwrap();
let model_path = write_model_to_tempfile(&model_proto, &dir, "model.onnx");
let mut model = OnnxModel::load(&model_path).unwrap();
let weights = model.extract_weights();
let quantized = Quantizer::new(QuantConfig {
bits: 4,
..Default::default()
})
.quantize_tensor(&weights[0].data, weights[0].shape.clone())
.unwrap();
let (scales, zero_points) = quantized.get_all_scales_zero_points();
let qdq_data = vec![QdqWeightInput {
original_name: weights[0].name.clone(),
quantized_values: quantized.data(),
scales,
zero_points,
bits: 4,
axis: None,
}];
let output_path = dir.path().join("model_int4_widened.onnx");
model.save_quantized(&qdq_data, &output_path).unwrap();
let proto = reload_proto(&output_path);
let graph = proto.graph.as_ref().unwrap();
let q_init = graph
.initializer
.iter()
.find(|i| i.name == "weight_quantized")
.unwrap();
assert_eq!(
q_init.data_type,
tensor_proto::DataType::Int8 as i32,
"default path keeps INT8 widening for backward compat"
);
assert_eq!(q_init.raw_data.len(), 16, "widened: 1 byte per element");
assert!(
get_default_opset(&proto) < 21,
"default INT4 path must not require opset 21"
);
}
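
/// Symmetric per-channel INT8 must reload with every zero-point equal to 0
/// and strictly positive, finite scales.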
#[test]
fn test_symmetric_per_channel_reloaded_zero_points_all_zero() {
let mut data = Vec::new();
    for i in 0..32 {
        data.push((i as f32 - 8.0) * 0.05);
    }
    for i in 0..32 {
        data.push((i as f32 - 5.0) * 0.2);
    }
    for i in 0..32 {
        data.push((i as f32 - 30.0) * 0.01);
    }
let model_proto = build_minimal_model(&data, &[3, 32]);
let dir = tempfile::tempdir().unwrap();
let model_path = write_model_to_tempfile(&model_proto, &dir, "model.onnx");
let mut model = OnnxModel::load(&model_path).unwrap();
let weights = model.extract_weights();
let quantizer = Quantizer::new(QuantConfig {
bits: 8,
per_channel: true,
symmetric: true,
..Default::default()
});
let quantized = quantizer
.quantize_tensor(&weights[0].data, weights[0].shape.clone())
.unwrap();
let (scales, zero_points) = quantized.get_all_scales_zero_points();
assert_eq!(scales.len(), 3, "expected one scale per channel");
assert_eq!(zero_points, vec![0, 0, 0], "symmetric: all zp must be 0");
let qdq_data = vec![QdqWeightInput {
original_name: weights[0].name.clone(),
quantized_values: quantized.data(),
scales,
zero_points,
bits: 8,
axis: Some(0),
}];
let output_path = dir.path().join("model_symmetric_pc.onnx");
model.save_quantized(&qdq_data, &output_path).unwrap();
let reloaded = OnnxModel::load(&output_path).unwrap();
let info = reloaded.load_quantized_info();
assert_eq!(info.len(), 1);
let info = &info[0];
assert!(info.is_per_channel());
assert_eq!(info.zero_points.len(), 3);
for (i, &zp) in info.zero_points.iter().enumerate() {
assert_eq!(
zp, 0,
"channel {} zero-point must be 0 for ORT INT8 kernels; got {}",
i, zp
);
}
for (i, &s) in info.scales.iter().enumerate() {
assert!(
s > 0.0 && s.is_finite(),
"channel {} scale invalid: {}",
i,
s
);
}
}
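
/// Asymmetric per-channel on offset channels (constant -0.5 and 2.0) should
/// yield at least one non-zero zero-point.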
#[test]
fn test_asymmetric_per_channel_produces_nonzero_zp_on_skewed_data() {
let mut data = Vec::new();
data.extend(std::iter::repeat_n(-0.5_f32, 16));
data.extend(std::iter::repeat_n(2.0_f32, 16));
let model_proto = build_minimal_model(&data, &[2, 16]);
let dir = tempfile::tempdir().unwrap();
let model_path = write_model_to_tempfile(&model_proto, &dir, "model.onnx");
let mut model = OnnxModel::load(&model_path).unwrap();
let weights = model.extract_weights();
let quantizer = Quantizer::new(QuantConfig {
bits: 8,
per_channel: true,
symmetric: false,
..Default::default()
});
let quantized = quantizer
.quantize_tensor(&weights[0].data, weights[0].shape.clone())
.unwrap();
let (_, zps) = quantized.get_all_scales_zero_points();
assert!(
zps.iter().any(|&z| z != 0),
"asymmetric per-channel on skewed data should yield a non-zero zp; got {:?}",
zps
);
}
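
/// load and load_mmap must agree on model info and extracted weights
/// (only built with the "mmap" feature).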
#[cfg(feature = "mmap")]
#[test]
fn test_load_mmap_produces_same_model_info_as_load() {
let weight_data: Vec<f32> = (0..16).map(|i| (i as f32 - 8.0) * 0.1).collect();
let model_proto = build_minimal_model(&weight_data, &[4, 4]);
let dir = tempfile::tempdir().unwrap();
let path = write_model_to_tempfile(&model_proto, &dir, "mmap_eq.onnx");
let via_load = OnnxModel::load(&path).unwrap();
let via_mmap = OnnxModel::load_mmap(&path).unwrap();
let a = via_load.info();
let b = via_mmap.info();
assert_eq!(a.name, b.name);
assert_eq!(a.num_nodes, b.num_nodes);
assert_eq!(a.inputs, b.inputs);
assert_eq!(a.outputs, b.outputs);
let wa = via_load.extract_weights();
let wb = via_mmap.extract_weights();
assert_eq!(wa.len(), wb.len());
assert_eq!(wa[0].name, wb[0].name);
assert_eq!(wa[0].data, wb[0].data);
}