use prost::Message;
use quantize_rs::calibration::methods::CalibrationMethod;
use quantize_rs::calibration::stats::ActivationStats;
use quantize_rs::onnx_proto::{
tensor_proto, GraphProto, ModelProto, NodeProto, OperatorSetIdProto, TensorProto,
ValueInfoProto,
};
use quantize_rs::onnx_utils::graph_builder::QdqWeightInput;
use quantize_rs::quantization::{QuantConfig, QuantizedTensor, QuantizedTensorInt4, Quantizer};
use quantize_rs::*;
use std::collections::HashMap;
/// Builds the smallest useful ONNX model for quantization tests:
/// `input -> Conv("weight") -> output`, where `weight` is a float
/// initializer with the given data and shape, under opset 13.
fn build_minimal_model(weight_data: &[f32], weight_shape: &[i64]) -> ModelProto {
    ModelProto {
        opset_import: vec![OperatorSetIdProto {
            domain: String::new(), // empty string = default ONNX operator set
            version: 13,
        }],
        graph: Some(GraphProto {
            name: "test_graph".to_string(),
            input: vec![ValueInfoProto {
                name: "input".to_string(),
                ..Default::default()
            }],
            output: vec![ValueInfoProto {
                name: "output".to_string(),
                ..Default::default()
            }],
            // The sole weight tensor the tests extract and quantize.
            initializer: vec![TensorProto {
                name: "weight".to_string(),
                data_type: tensor_proto::DataType::Float as i32,
                dims: weight_shape.to_vec(),
                float_data: weight_data.to_vec(),
                ..Default::default()
            }],
            // Single Conv node consuming the graph input and the weight.
            node: vec![NodeProto {
                op_type: "Conv".to_string(),
                name: "conv0".to_string(),
                input: vec!["input".to_string(), "weight".to_string()],
                output: vec!["output".to_string()],
                ..Default::default()
            }],
            ..Default::default()
        }),
        ..Default::default()
    }
}
/// Builds a two-layer chain `input -> Conv(w1) -> mid -> Conv(w2) -> output`
/// with two float initializers (`w1`, `w2`) of the given data/shapes, so
/// tests can exercise multi-weight and mixed-precision quantization paths.
fn build_two_weight_model(
    w1_data: &[f32],
    w1_shape: &[i64],
    w2_data: &[f32],
    w2_shape: &[i64],
) -> ModelProto {
    ModelProto {
        opset_import: vec![OperatorSetIdProto {
            domain: String::new(), // default ONNX operator set
            version: 13,
        }],
        graph: Some(GraphProto {
            name: "test_two_weight".to_string(),
            input: vec![ValueInfoProto {
                name: "input".to_string(),
                ..Default::default()
            }],
            output: vec![ValueInfoProto {
                name: "output".to_string(),
                ..Default::default()
            }],
            initializer: vec![
                TensorProto {
                    name: "w1".to_string(),
                    data_type: tensor_proto::DataType::Float as i32,
                    dims: w1_shape.to_vec(),
                    float_data: w1_data.to_vec(),
                    ..Default::default()
                },
                TensorProto {
                    name: "w2".to_string(),
                    data_type: tensor_proto::DataType::Float as i32,
                    dims: w2_shape.to_vec(),
                    float_data: w2_data.to_vec(),
                    ..Default::default()
                },
            ],
            // conv1 feeds conv2 through the intermediate value "mid".
            node: vec![
                NodeProto {
                    op_type: "Conv".to_string(),
                    name: "conv1".to_string(),
                    input: vec!["input".to_string(), "w1".to_string()],
                    output: vec!["mid".to_string()],
                    ..Default::default()
                },
                NodeProto {
                    op_type: "Conv".to_string(),
                    name: "conv2".to_string(),
                    input: vec!["mid".to_string(), "w2".to_string()],
                    output: vec!["output".to_string()],
                    ..Default::default()
                },
            ],
            ..Default::default()
        }),
        ..Default::default()
    }
}
/// Serializes `model` with prost and writes it to `dir/name`, returning the
/// full path of the written file. Panics on encode or I/O failure (test-only).
fn write_model_to_tempfile(
    model: &ModelProto,
    dir: &tempfile::TempDir,
    name: &str,
) -> std::path::PathBuf {
    let mut encoded = Vec::new();
    model.encode(&mut encoded).unwrap();
    let target = dir.path().join(name);
    std::fs::write(&target, &encoded).unwrap();
    target
}
#[test]
fn test_quantize_simple_model_int8() {
    // 16 evenly spaced values around zero, laid out as a 4x4 weight.
    let values: Vec<f32> = (0..16).map(|i| (i as f32 - 8.0) * 0.1).collect();
    let proto = build_minimal_model(&values, &[4, 4]);
    let tmp = tempfile::tempdir().unwrap();
    let src_path = write_model_to_tempfile(&proto, &tmp, "model.onnx");
    let mut model = OnnxModel::load(&src_path).unwrap();
    let weights = model.extract_weights();
    assert_eq!(weights.len(), 1);
    assert_eq!(weights[0].name, "weight");
    // Per-tensor INT8 quantization without calibration.
    let config = QuantConfig {
        bits: 8,
        per_channel: false,
        calibration_method: None,
        ..Default::default()
    };
    let quantized = Quantizer::new(config)
        .quantize_tensor(&weights[0].data, weights[0].shape.clone())
        .unwrap();
    assert!(quantized.is_int8());
    let (scales, zero_points) = quantized.get_all_scales_zero_points();
    // Axis is only meaningful for per-channel results.
    let axis = if quantized.is_per_channel() { Some(0) } else { None };
    let qdq_data = vec![QdqWeightInput {
        original_name: weights[0].name.clone(),
        quantized_values: quantized.data(),
        scales,
        zero_points,
        bits: 8,
        axis,
    }];
    let dst_path = tmp.path().join("model_int8.onnx");
    model.save_quantized(&qdq_data, &dst_path).unwrap();
    // Round trip: the saved QDQ model must reload, stay connected, and carry
    // quantization metadata for the single weight.
    let reloaded = OnnxModel::load(&dst_path).unwrap();
    let report = reloaded.validate_connectivity();
    assert!(report.valid, "Connectivity broken: {:?}", report.broken_refs);
    let qinfo = reloaded.load_quantized_info();
    assert_eq!(qinfo.len(), 1);
    assert_eq!(qinfo[0].name, "weight");
    assert_eq!(qinfo[0].bits, 8);
    assert!(qinfo[0].scale > 0.0);
}
#[test]
fn test_quantize_simple_model_int4() {
    // Same 4x4 weight data as the INT8 test, but quantized at 4 bits.
    let values: Vec<f32> = (0..16).map(|i| (i as f32 - 8.0) * 0.1).collect();
    let proto = build_minimal_model(&values, &[4, 4]);
    let tmp = tempfile::tempdir().unwrap();
    let src_path = write_model_to_tempfile(&proto, &tmp, "model.onnx");
    let mut model = OnnxModel::load(&src_path).unwrap();
    let weights = model.extract_weights();
    let config = QuantConfig {
        bits: 4,
        per_channel: false,
        calibration_method: None,
        ..Default::default()
    };
    let quantized = Quantizer::new(config)
        .quantize_tensor(&weights[0].data, weights[0].shape.clone())
        .unwrap();
    assert!(quantized.is_int4());
    assert_eq!(quantized.bits(), 4);
    let (scales, zero_points) = quantized.get_all_scales_zero_points();
    let qdq_data = vec![QdqWeightInput {
        original_name: weights[0].name.clone(),
        quantized_values: quantized.data(),
        scales,
        zero_points,
        bits: 4,
        axis: None,
    }];
    let dst_path = tmp.path().join("model_int4.onnx");
    model.save_quantized(&qdq_data, &dst_path).unwrap();
    // Reload and check graph integrity plus the 4-bit metadata record.
    let reloaded = OnnxModel::load(&dst_path).unwrap();
    let report = reloaded.validate_connectivity();
    assert!(report.valid, "Connectivity broken: {:?}", report.broken_refs);
    let qinfo = reloaded.load_quantized_info();
    assert_eq!(qinfo.len(), 1);
    assert_eq!(qinfo[0].bits, 4);
}
#[test]
fn test_quantize_per_channel() {
    // Each 2x8 weight has two rows (output channels) with very different
    // magnitudes, so per-channel quantization must emit one scale/zero-point
    // pair per row. (Original code had manual push-loops with a `} for`
    // jammed on one line; rewritten with `extend` for clarity.)
    let mut w1_data = Vec::new();
    w1_data.extend((0..8).map(|i| i as f32 * 0.01)); // small-magnitude channel
    w1_data.extend((0..8).map(|i| i as f32 * 1.0)); // large-magnitude channel
    let mut w2_data = Vec::new();
    w2_data.extend((0..8).map(|i| i as f32 * 0.5));
    w2_data.extend((0..8).map(|i| i as f32 * 2.0));
    let model_proto = build_two_weight_model(&w1_data, &[2, 8], &w2_data, &[2, 8]);
    let dir = tempfile::tempdir().unwrap();
    let model_path = write_model_to_tempfile(&model_proto, &dir, "model.onnx");
    let mut model = OnnxModel::load(&model_path).unwrap();
    let weights = model.extract_weights();
    assert_eq!(weights.len(), 2);
    let quantizer = Quantizer::new(QuantConfig {
        bits: 8,
        per_channel: true,
        calibration_method: None,
        ..Default::default()
    });
    let mut qdq_data = Vec::new();
    for w in &weights {
        let quantized = quantizer.quantize_tensor(&w.data, w.shape.clone()).unwrap();
        assert!(quantized.is_per_channel());
        let (scales, zero_points) = quantized.get_all_scales_zero_points();
        // One (scale, zero_point) pair per output channel (axis 0, 2 rows).
        assert_eq!(scales.len(), 2);
        assert_eq!(zero_points.len(), 2);
        qdq_data.push(QdqWeightInput {
            original_name: w.name.clone(),
            quantized_values: quantized.data(),
            scales,
            zero_points,
            bits: 8,
            axis: Some(0),
        });
    }
    let output_path = dir.path().join("model_pc.onnx");
    model.save_quantized(&qdq_data, &output_path).unwrap();
    let reloaded = OnnxModel::load(&output_path).unwrap();
    let report = reloaded.validate_connectivity();
    assert!(
        report.valid,
        "Connectivity broken: {:?}",
        report.broken_refs
    );
}
#[test]
fn test_round_trip_quantization_accuracy() {
    // 1000 samples spanning [-1, 1].
    let samples: Vec<f32> = (0..1000).map(|i| (i as f32 / 999.0) * 2.0 - 1.0).collect();
    let dims = vec![1000];
    let int8 = QuantizedTensor::from_f32(&samples, dims.clone()).unwrap();
    let err8 = int8.quantization_error(&samples);
    assert!(err8 < 1e-4, "INT8 MSE too high: {}", err8);
    let int4 = QuantizedTensorInt4::from_f32(&samples, dims).unwrap();
    let err4 = int4.quantization_error(&samples);
    assert!(err4 < 0.01, "INT4 MSE too high: {}", err4);
    // Fewer bits must cost strictly more reconstruction error.
    assert!(err4 > err8, "INT4 error should exceed INT8 error");
    // Dequantization preserves the element count for both widths.
    assert_eq!(int8.to_f32().len(), 1000);
    assert_eq!(int4.to_f32().len(), 1000);
}
#[test]
fn test_error_variants_are_correct() {
    // Each failure mode must map to its dedicated QuantizeError variant.
    // Empty tensor -> InvalidTensor.
    let result = QuantizedTensor::from_f32(&[], vec![0]);
    assert!(result.is_err());
    assert!(
        matches!(result.unwrap_err(), QuantizeError::InvalidTensor { .. }),
        "expected InvalidTensor for empty tensor"
    );
    // Shape/data length mismatch (2 values, shape says 3) -> InvalidTensor.
    let result = QuantizedTensor::from_f32(&[1.0, 2.0], vec![3]);
    assert!(matches!(
        result.unwrap_err(),
        QuantizeError::InvalidTensor { .. }
    ));
    // Per-channel with an empty shape -> InvalidTensor.
    let result = QuantizedTensor::from_f32_per_channel(&[1.0], vec![]);
    assert!(matches!(
        result.unwrap_err(),
        QuantizeError::InvalidTensor { .. }
    ));
    // Unsupported bit width (3) -> UnsupportedConfig.
    let quantizer = Quantizer::new(QuantConfig {
        bits: 3,
        per_channel: false,
        calibration_method: None,
        ..Default::default()
    });
    let result = quantizer.quantize_tensor(&[1.0, 2.0], vec![2]);
    assert!(matches!(
        result.unwrap_err(),
        QuantizeError::UnsupportedConfig { .. }
    ));
    // Missing file -> ModelLoad.
    let result = OnnxModel::load("/nonexistent/path/model.onnx");
    assert!(matches!(
        result.unwrap_err(),
        QuantizeError::ModelLoad { .. }
    ));
    // Unknown calibration method name -> Config.
    // (Uses the `CalibrationMethod` import from the top of the file instead
    // of re-spelling the full module path.)
    let result: Result<CalibrationMethod, _> = "invalid".parse();
    assert!(matches!(result.unwrap_err(), QuantizeError::Config { .. }));
    // Parsable YAML with an invalid bit width -> Config on validate().
    let cfg = Config::from_yaml("bits: 3").unwrap();
    assert!(matches!(
        cfg.validate().unwrap_err(),
        QuantizeError::Config { .. }
    ));
}
/// End-to-end mixed precision: `w1` is forced to 4 bits via `layer_bits`
/// while `w2` uses the default 8 bits, and both widths must survive the
/// save/reload round trip in the quantization metadata.
#[test]
fn test_mixed_precision_quantization() {
    let w1_data: Vec<f32> = (0..16).map(|i| (i as f32 - 8.0) * 0.1).collect();
    let w2_data: Vec<f32> = (0..16).map(|i| (i as f32 - 8.0) * 0.5).collect();
    let model_proto = build_two_weight_model(&w1_data, &[4, 4], &w2_data, &[4, 4]);
    let dir = tempfile::tempdir().unwrap();
    let model_path = write_model_to_tempfile(&model_proto, &dir, "model.onnx");
    let mut model = OnnxModel::load(&model_path).unwrap();
    let weights = model.extract_weights();
    assert_eq!(weights.len(), 2);
    // Per-layer override: only "w1" drops to 4 bits.
    let mut layer_bits = HashMap::new();
    layer_bits.insert("w1".to_string(), 4u8);
    let config = QuantConfig {
        bits: 8,
        per_channel: false,
        calibration_method: None,
        layer_bits,
        ..Default::default()
    };
    let mut qdq_data = Vec::new();
    for w in &weights {
        // Resolve the effective bit width for this layer, keep the rest
        // of the shared config unchanged.
        let layer_config = QuantConfig {
            bits: config.bits_for_layer(&w.name),
            ..config.clone()
        };
        let quantized = Quantizer::new(layer_config)
            .quantize_tensor(&w.data, w.shape.clone())
            .unwrap();
        let (scales, zero_points) = quantized.get_all_scales_zero_points();
        let bits_used = quantized.bits();
        if w.name == "w1" {
            assert!(quantized.is_int4(), "w1 should be INT4");
            assert_eq!(bits_used, 4);
        } else {
            assert!(quantized.is_int8(), "w2 should be INT8");
            assert_eq!(bits_used, 8);
        }
        qdq_data.push(QdqWeightInput {
            original_name: w.name.clone(),
            quantized_values: quantized.data(),
            scales,
            zero_points,
            bits: bits_used,
            axis: None,
        });
    }
    let output_path = dir.path().join("model_mixed.onnx");
    model.save_quantized(&qdq_data, &output_path).unwrap();
    let reloaded = OnnxModel::load(&output_path).unwrap();
    let report = reloaded.validate_connectivity();
    assert!(
        report.valid,
        "Connectivity broken: {:?}",
        report.broken_refs
    );
    // Both widths must be recorded per weight in the reloaded metadata.
    let qinfo = reloaded.load_quantized_info();
    assert_eq!(qinfo.len(), 2);
    let w1_info = qinfo.iter().find(|q| q.name == "w1").expect("w1 not found");
    let w2_info = qinfo.iter().find(|q| q.name == "w2").expect("w2 not found");
    assert_eq!(w1_info.bits, 4, "w1 bits should be 4 in metadata");
    assert_eq!(w2_info.bits, 8, "w2 bits should be 8 in metadata");
    assert!(w1_info.scale > 0.0);
    assert!(w2_info.scale > 0.0);
}
/// `layer_bits` must parse from both YAML and TOML configuration, and
/// validation must reject per-layer bit widths the quantizer doesn't support.
#[test]
fn test_config_layer_bits() {
    let yaml = r#"
bits: 8
models:
  - input: model.onnx
    output: model_mixed.onnx
    layer_bits:
      conv1.weight: 4
      head.weight: 8
"#;
    let config = Config::from_yaml(yaml).unwrap();
    config.validate().unwrap();
    let model_cfg = &config.models[0];
    let lb = config.get_layer_bits(model_cfg);
    assert_eq!(lb.get("conv1.weight"), Some(&4u8));
    assert_eq!(lb.get("head.weight"), Some(&8u8));
    // Same configuration expressed as TOML must resolve identically.
    let toml_str = r#"
bits = 8
[[models]]
input = "model.onnx"
output = "model_mixed.onnx"
[models.layer_bits]
"conv1.weight" = 4
"head.weight" = 8
"#;
    let config = Config::from_toml(toml_str).unwrap();
    config.validate().unwrap();
    let lb = config.get_layer_bits(&config.models[0]);
    assert_eq!(lb.get("conv1.weight"), Some(&4u8));
    // An unsupported per-layer width (3 bits) parses but fails validate().
    let yaml_bad = r#"
bits: 8
models:
  - input: model.onnx
    output: model_mixed.onnx
    layer_bits:
      conv1.weight: 3
"#;
    let cfg = Config::from_yaml(yaml_bad).unwrap();
    assert!(matches!(
        cfg.validate().unwrap_err(),
        QuantizeError::Config { .. }
    ));
}
/// Builds a three-layer model (Conv -> Conv -> Gemm) with six initializers:
/// three large weight tensors and three small bias vectors. The size split
/// lets tests exercise `min_elements` filtering and exclusion lists.
fn build_multilayer_model() -> ModelProto {
    // Helper: a float tensor named `name` with `dims` and `n` synthetic
    // values spread over roughly [-0.5, 0.5).
    let make_tensor = |name: &str, dims: &[i64], n: usize| TensorProto {
        name: name.to_string(),
        data_type: tensor_proto::DataType::Float as i32,
        dims: dims.to_vec(),
        float_data: (0..n).map(|i| (i as f32) / (n as f32) - 0.5).collect(),
        ..Default::default()
    };
    ModelProto {
        opset_import: vec![OperatorSetIdProto {
            domain: String::new(),
            version: 13,
        }],
        graph: Some(GraphProto {
            name: "multilayer".to_string(),
            input: vec![ValueInfoProto {
                name: "input".to_string(),
                ..Default::default()
            }],
            output: vec![ValueInfoProto {
                name: "output".to_string(),
                ..Default::default()
            }],
            // Weights are >=216 elements; biases are <=16 elements, so a
            // min_elements threshold of 100 keeps only the weights.
            initializer: vec![
                make_tensor("conv1.weight", &[8, 3, 3, 3], 216),
                make_tensor("conv1.bias", &[8], 8),
                make_tensor("conv2.weight", &[16, 8, 3, 3], 1152),
                make_tensor("conv2.bias", &[16], 16),
                make_tensor("fc.weight", &[10, 144], 1440),
                make_tensor("fc.bias", &[10], 10),
            ],
            node: vec![
                NodeProto {
                    op_type: "Conv".to_string(),
                    name: "conv1".to_string(),
                    input: vec![
                        "input".to_string(),
                        "conv1.weight".to_string(),
                        "conv1.bias".to_string(),
                    ],
                    output: vec!["conv1_out".to_string()],
                    ..Default::default()
                },
                NodeProto {
                    op_type: "Conv".to_string(),
                    name: "conv2".to_string(),
                    input: vec![
                        "conv1_out".to_string(),
                        "conv2.weight".to_string(),
                        "conv2.bias".to_string(),
                    ],
                    output: vec!["conv2_out".to_string()],
                    ..Default::default()
                },
                NodeProto {
                    op_type: "Gemm".to_string(),
                    name: "fc".to_string(),
                    input: vec![
                        "conv2_out".to_string(),
                        "fc.weight".to_string(),
                        "fc.bias".to_string(),
                    ],
                    output: vec!["output".to_string()],
                    ..Default::default()
                },
            ],
            ..Default::default()
        }),
        ..Default::default()
    }
}
/// Quantizes every weight that passes the config's filters
/// (`should_quantize`: size threshold, exclusions, ...), honoring per-layer
/// bit-width overrides, and returns the QDQ inputs for `save_quantized`.
fn quantize_weights(
    config: &QuantConfig,
    weights: &[quantize_rs::WeightTensor],
) -> Vec<QdqWeightInput> {
    let mut result = Vec::new();
    for w in weights {
        if !config.should_quantize(&w.name, w.data.len()) {
            continue;
        }
        // Clone the shared config but swap in this layer's effective width.
        let layer_config = QuantConfig {
            bits: config.bits_for_layer(&w.name),
            ..config.clone()
        };
        let quantized = Quantizer::new(layer_config)
            .quantize_tensor(&w.data, w.shape.clone())
            .unwrap();
        let (scales, zero_points) = quantized.get_all_scales_zero_points();
        // Capture metadata before data() takes the quantized values.
        let bits = quantized.bits();
        let per_channel = quantized.is_per_channel();
        result.push(QdqWeightInput {
            original_name: w.name.clone(),
            quantized_values: quantized.data(),
            scales,
            zero_points,
            bits,
            axis: if per_channel { Some(0) } else { None },
        });
    }
    result
}
#[test]
fn test_multilayer_min_elements() {
    let proto = build_multilayer_model();
    let tmp = tempfile::tempdir().unwrap();
    let src = write_model_to_tempfile(&proto, &tmp, "model.onnx");
    let mut model = OnnxModel::load(&src).unwrap();
    let weights = model.extract_weights();
    assert_eq!(weights.len(), 6);
    // min_elements = 100 keeps the three weight tensors (216/1152/1440
    // elements) and drops the three bias vectors (8/16/10 elements).
    let config = QuantConfig {
        bits: 8,
        min_elements: 100,
        ..Default::default()
    };
    let qdq_data = quantize_weights(&config, &weights);
    assert_eq!(qdq_data.len(), 3, "expected 3 large weights quantized");
    let names: Vec<&str> = qdq_data.iter().map(|q| q.original_name.as_str()).collect();
    for kept in ["conv1.weight", "conv2.weight", "fc.weight"] {
        assert!(names.contains(&kept));
    }
    for dropped in ["conv1.bias", "conv2.bias", "fc.bias"] {
        assert!(!names.contains(&dropped));
    }
    let out = tmp.path().join("model_min_elements.onnx");
    model.save_quantized(&qdq_data, &out).unwrap();
    let reloaded = OnnxModel::load(&out).unwrap();
    assert!(reloaded.validate_connectivity().valid);
    assert_eq!(reloaded.load_quantized_info().len(), 3);
}
#[test]
fn test_multilayer_excluded_layers() {
    let proto = build_multilayer_model();
    let tmp = tempfile::tempdir().unwrap();
    let src = write_model_to_tempfile(&proto, &tmp, "model.onnx");
    let mut model = OnnxModel::load(&src).unwrap();
    let weights = model.extract_weights();
    // Explicitly exclude one weight and one bias; the other four remain.
    let config = QuantConfig {
        bits: 8,
        excluded_layers: vec!["conv1.weight".to_string(), "fc.bias".to_string()],
        ..Default::default()
    };
    let qdq_data = quantize_weights(&config, &weights);
    assert_eq!(qdq_data.len(), 4, "expected 4 weights after exclusions");
    let names: Vec<&str> = qdq_data.iter().map(|q| q.original_name.as_str()).collect();
    assert!(
        !names.contains(&"conv1.weight"),
        "conv1.weight should be excluded"
    );
    assert!(!names.contains(&"fc.bias"), "fc.bias should be excluded");
    let out = tmp.path().join("model_excluded.onnx");
    model.save_quantized(&qdq_data, &out).unwrap();
    let reloaded = OnnxModel::load(&out).unwrap();
    assert!(reloaded.validate_connectivity().valid);
    assert_eq!(reloaded.load_quantized_info().len(), 4);
}
/// With no filters set, all six initializers (weights and biases) must be
/// quantized per-channel at 8 bits and survive the save/reload round trip.
#[test]
fn test_multilayer_full_round_trip() {
    let model_proto = build_multilayer_model();
    let dir = tempfile::tempdir().unwrap();
    let model_path = write_model_to_tempfile(&model_proto, &dir, "model.onnx");
    let mut model = OnnxModel::load(&model_path).unwrap();
    let weights = model.extract_weights();
    assert_eq!(weights.len(), 6);
    // Default config except per_channel: no min_elements, no exclusions.
    let config = QuantConfig {
        bits: 8,
        per_channel: true,
        ..Default::default()
    };
    let qdq_data = quantize_weights(&config, &weights);
    assert_eq!(qdq_data.len(), 6, "all 6 weights should be quantized");
    let output_path = dir.path().join("model_full.onnx");
    model.save_quantized(&qdq_data, &output_path).unwrap();
    let reloaded = OnnxModel::load(&output_path).unwrap();
    let report = reloaded.validate_connectivity();
    assert!(
        report.valid,
        "Connectivity broken: {:?}",
        report.broken_refs
    );
    // Every weight must appear in metadata with a positive scale and 8 bits.
    let qinfo = reloaded.load_quantized_info();
    assert_eq!(qinfo.len(), 6, "all 6 weights should appear in metadata");
    for info in &qinfo {
        assert!(info.scale > 0.0, "scale must be positive for {}", info.name);
        assert_eq!(info.bits, 8);
    }
}
#[test]
fn test_multilayer_compression_ratio() {
    // Quantized storage must shrink monotonically: FP32 > INT8 > INT4,
    // with INT8 under half and INT4 under 30% of the original size.
    let model_proto = build_multilayer_model();
    let dir = tempfile::tempdir().unwrap();
    let model_path = write_model_to_tempfile(&model_proto, &dir, "model.onnx");
    let weights = {
        let m = OnnxModel::load(&model_path).unwrap();
        m.extract_weights()
    };
    // FP32 baseline: 4 bytes per element.
    let original_bytes: usize = weights.iter().map(|w| w.data.len() * 4).sum();
    // Construct each quantizer once, outside the per-weight closure, instead
    // of rebuilding it (and cloning the config) for every weight.
    let quantizer8 = Quantizer::new(QuantConfig {
        bits: 8,
        ..Default::default()
    });
    let bytes_int8: usize = weights
        .iter()
        .map(|w| {
            quantizer8
                .quantize_tensor(&w.data, w.shape.clone())
                .unwrap()
                .size_bytes()
        })
        .sum();
    let quantizer4 = Quantizer::new(QuantConfig {
        bits: 4,
        ..Default::default()
    });
    let bytes_int4: usize = weights
        .iter()
        .map(|w| {
            quantizer4
                .quantize_tensor(&w.data, w.shape.clone())
                .unwrap()
                .size_bytes()
        })
        .sum();
    assert!(
        bytes_int8 < original_bytes,
        "INT8 ({bytes_int8} B) should be smaller than FP32 ({original_bytes} B)"
    );
    assert!(
        bytes_int4 < bytes_int8,
        "INT4 ({bytes_int4} B) should be smaller than INT8 ({bytes_int8} B)"
    );
    let ratio8 = bytes_int8 as f64 / original_bytes as f64;
    let ratio4 = bytes_int4 as f64 / original_bytes as f64;
    assert!(ratio8 < 0.5, "INT8 ratio {ratio8:.2} should be < 0.5");
    assert!(ratio4 < 0.3, "INT4 ratio {ratio4:.2} should be < 0.3");
}
/// Some exporters list an initializer in `graph.input` as well. The QDQ
/// transform must remove that duplicate graph input when the weight is
/// replaced, leaving only the real data input.
#[test]
fn test_dual_input_initializer_model() {
    let weight_data: Vec<f32> = (0..16).map(|i| (i as f32 - 8.0) * 0.1).collect();
    // "weight" appears BOTH as a graph input and as an initializer.
    let model_proto = ModelProto {
        opset_import: vec![OperatorSetIdProto {
            domain: String::new(),
            version: 13,
        }],
        graph: Some(GraphProto {
            name: "dual_input".to_string(),
            input: vec![
                ValueInfoProto {
                    name: "input".to_string(),
                    ..Default::default()
                },
                ValueInfoProto {
                    name: "weight".to_string(),
                    ..Default::default()
                },
            ],
            output: vec![ValueInfoProto {
                name: "output".to_string(),
                ..Default::default()
            }],
            initializer: vec![TensorProto {
                name: "weight".to_string(),
                data_type: tensor_proto::DataType::Float as i32,
                dims: vec![4, 4],
                float_data: weight_data.clone(),
                ..Default::default()
            }],
            node: vec![NodeProto {
                op_type: "Conv".to_string(),
                name: "conv0".to_string(),
                input: vec!["input".to_string(), "weight".to_string()],
                output: vec!["output".to_string()],
                ..Default::default()
            }],
            ..Default::default()
        }),
        ..Default::default()
    };
    let dir = tempfile::tempdir().unwrap();
    let model_path = write_model_to_tempfile(&model_proto, &dir, "dual.onnx");
    let mut model = OnnxModel::load(&model_path).unwrap();
    let info = model.info();
    assert_eq!(
        info.inputs.len(),
        2,
        "model should have 2 inputs (data + weight)"
    );
    let weights = model.extract_weights();
    assert_eq!(weights.len(), 1);
    let quantizer = Quantizer::new(QuantConfig {
        bits: 8,
        ..Default::default()
    });
    let quantized = quantizer
        .quantize_tensor(&weights[0].data, weights[0].shape.clone())
        .unwrap();
    let (scales, zero_points) = quantized.get_all_scales_zero_points();
    let qdq_data = vec![QdqWeightInput {
        original_name: weights[0].name.clone(),
        quantized_values: quantized.data(),
        scales,
        zero_points,
        bits: 8,
        axis: None,
    }];
    let output_path = dir.path().join("dual_int8.onnx");
    model.save_quantized(&qdq_data, &output_path).unwrap();
    let reloaded = OnnxModel::load(&output_path).unwrap();
    let report = reloaded.validate_connectivity();
    assert!(
        report.valid,
        "Connectivity broken: {:?}",
        report.broken_refs
    );
    // After the transform, only the real data input remains.
    let reloaded_info = reloaded.info();
    assert_eq!(
        reloaded_info.inputs.len(),
        1,
        "QDQ transform should remove weight from graph.input; got {:?}",
        reloaded_info.inputs
    );
    assert_eq!(reloaded_info.inputs[0], "input");
}
/// Opt-in smoke test against a real ONNX model supplied via the
/// QUANTIZE_RS_TEST_MODEL environment variable (ignored by default).
/// Quantizes every weight with >=128 elements at 8 bits and verifies the
/// round-tripped model's connectivity and metadata.
#[test]
#[ignore = "set QUANTIZE_RS_TEST_MODEL=/path/to/model.onnx to enable"]
fn test_real_model_int8() {
    let path = std::env::var("QUANTIZE_RS_TEST_MODEL")
        .expect("QUANTIZE_RS_TEST_MODEL must point to an ONNX file");
    let mut model = OnnxModel::load(&path).expect("failed to load model");
    let weights = model.extract_weights();
    assert!(!weights.is_empty(), "model has no extractable weights");
    // Skip small tensors (biases, norm params) via the element threshold.
    let config = QuantConfig {
        bits: 8,
        per_channel: false,
        min_elements: 128,
        ..Default::default()
    };
    let qdq_data = quantize_weights(&config, &weights);
    assert!(
        !qdq_data.is_empty(),
        "no weights passed the min_elements=128 filter"
    );
    let dir = tempfile::tempdir().unwrap();
    let output_path = dir.path().join("model_int8.onnx");
    model
        .save_quantized(&qdq_data, &output_path)
        .expect("save failed");
    let reloaded = OnnxModel::load(&output_path).unwrap();
    let report = reloaded.validate_connectivity();
    assert!(
        report.valid,
        "Connectivity broken after INT8 quantization: {:?}",
        report.broken_refs
    );
    // One metadata entry per quantized weight.
    let qinfo = reloaded.load_quantized_info();
    assert_eq!(
        qinfo.len(),
        qdq_data.len(),
        "metadata weight count mismatch"
    );
    for info in &qinfo {
        assert_eq!(info.bits, 8);
        assert!(info.scale > 0.0);
    }
}
/// INT4 counterpart of `test_real_model_int8`: opt-in smoke test against a
/// real ONNX model supplied via QUANTIZE_RS_TEST_MODEL (ignored by default).
#[test]
#[ignore = "set QUANTIZE_RS_TEST_MODEL=/path/to/model.onnx to enable"]
fn test_real_model_int4() {
    let path = std::env::var("QUANTIZE_RS_TEST_MODEL")
        .expect("QUANTIZE_RS_TEST_MODEL must point to an ONNX file");
    let mut model = OnnxModel::load(&path).expect("failed to load model");
    let weights = model.extract_weights();
    let config = QuantConfig {
        bits: 4,
        per_channel: false,
        min_elements: 128,
        ..Default::default()
    };
    let qdq_data = quantize_weights(&config, &weights);
    let dir = tempfile::tempdir().unwrap();
    let output_path = dir.path().join("model_int4.onnx");
    model
        .save_quantized(&qdq_data, &output_path)
        .expect("save failed");
    let reloaded = OnnxModel::load(&output_path).unwrap();
    let report = reloaded.validate_connectivity();
    assert!(
        report.valid,
        "Connectivity broken after INT4 quantization: {:?}",
        report.broken_refs
    );
    let qinfo = reloaded.load_quantized_info();
    assert_eq!(qinfo.len(), qdq_data.len());
    for info in &qinfo {
        assert_eq!(info.bits, 4);
        assert!(info.scale > 0.0);
    }
}
/// Config parsing: valid YAML and TOML load and validate; malformed input
/// and out-of-range/empty fields map to `QuantizeError::Config`.
#[test]
fn test_config_round_trip() {
    let yaml = r#"
bits: 8
per_channel: true
models:
  - input: model.onnx
    output: model_int8.onnx
"#;
    let config = Config::from_yaml(yaml).unwrap();
    assert_eq!(config.bits, 8);
    assert!(config.per_channel);
    config.validate().unwrap();
    let toml_str = r#"
bits = 4
per_channel = false
[[models]]
input = "a.onnx"
output = "b.onnx"
"#;
    let config = Config::from_toml(toml_str).unwrap();
    assert_eq!(config.bits, 4);
    config.validate().unwrap();
    // Syntactically broken documents fail at parse time with Config errors.
    let result = Config::from_yaml("bits: [invalid");
    assert!(matches!(result.unwrap_err(), QuantizeError::Config { .. }));
    let result = Config::from_toml("bits = [invalid");
    assert!(matches!(result.unwrap_err(), QuantizeError::Config { .. }));
    // bits: 16 parses but is rejected by validate().
    let cfg = Config::from_yaml("bits: 16").unwrap();
    assert!(matches!(
        cfg.validate().unwrap_err(),
        QuantizeError::Config { .. }
    ));
    // An empty model input path is likewise rejected by validate().
    let yaml_bad = r#"
bits: 8
models:
  - input: ""
    output: "out.onnx"
"#;
    let cfg = Config::from_yaml(yaml_bad).unwrap();
    assert!(matches!(
        cfg.validate().unwrap_err(),
        QuantizeError::Config { .. }
    ));
}
/// Calibration must actually influence the result: the activation data here
/// spans a narrower range than the weights, so the calibrated scale is
/// expected to come out smaller than the uncalibrated one.
#[test]
fn test_calibrated_quantization_uses_stats() {
    // Weights span roughly [-1, 1).
    let weight_data: Vec<f32> = (0..64).map(|i| (i as f32 - 32.0) / 32.0).collect();
    let shape = vec![4, 16];
    // Baseline: plain quantizer with no calibration stats.
    let config_uncalibrated = QuantConfig {
        bits: 8,
        ..Default::default()
    };
    let q_uncalibrated = Quantizer::new(config_uncalibrated);
    let result_uncalibrated = q_uncalibrated
        .quantize_tensor_with_name("layer0.weight", &weight_data, shape.clone())
        .unwrap();
    // Activations span roughly [-0.5, 0.5) — a tighter range than the weights.
    let activation_data: Vec<f32> = (0..1000).map(|i| (i as f32 - 500.0) / 1000.0).collect();
    let stats = ActivationStats::from_data(&activation_data);
    let mut stats_map = HashMap::new();
    stats_map.insert("layer0.weight".to_string(), stats);
    let config_calibrated = QuantConfig {
        bits: 8,
        ..Default::default()
    };
    let q_calibrated = Quantizer::with_calibration(config_calibrated, stats_map);
    let result_calibrated = q_calibrated
        .quantize_tensor_with_name("layer0.weight", &weight_data, shape.clone())
        .unwrap();
    let (scales_uncal, _) = result_uncalibrated.get_all_scales_zero_points();
    let (scales_cal, _) = result_calibrated.get_all_scales_zero_points();
    // Per-tensor mode: exactly one scale each.
    assert_eq!(scales_uncal.len(), 1);
    assert_eq!(scales_cal.len(), 1);
    assert!(
        scales_cal[0] < scales_uncal[0],
        "Calibrated scale ({}) should be < uncalibrated scale ({})",
        scales_cal[0],
        scales_uncal[0],
    );
}
#[test]
fn test_calibrated_quantization_with_method() {
    // MinMax calibration on activation stats that include extreme outliers;
    // the result must still be a valid INT8 tensor with bounded error.
    let weight_data: Vec<f32> = (0..100).map(|i| (i as f32 - 50.0) * 0.1).collect();
    let shape = vec![10, 10];
    let mut activation_data: Vec<f32> = (0..500).map(|i| (i as f32 - 250.0) / 250.0).collect();
    // Outliers far outside the bulk of the distribution. (Original code had
    // both pushes and the stats call jammed onto one line.)
    activation_data.extend_from_slice(&[10.0, -10.0]);
    let stats = ActivationStats::from_data(&activation_data);
    let mut stats_map = HashMap::new();
    stats_map.insert("conv.weight".to_string(), stats);
    let config = QuantConfig {
        bits: 8,
        calibration_method: Some(CalibrationMethod::MinMax),
        ..Default::default()
    };
    let quantizer = Quantizer::with_calibration(config, stats_map);
    let result = quantizer
        .quantize_tensor_with_name("conv.weight", &weight_data, shape)
        .unwrap();
    assert!(result.is_int8());
    let error = result.quantization_error(&weight_data);
    assert!(error.is_finite());
    assert!(
        error < 10.0,
        "Quantization error unexpectedly high: {}",
        error
    );
}
#[test]
fn test_calibrated_quantization_fallback_no_stats() {
    // Stats exist only for "other_layer.weight"; quantizing "conv.weight"
    // must therefore fall back to plain uncalibrated behavior exactly.
    let weight_data: Vec<f32> = (0..32).map(|i| (i as f32 - 16.0) * 0.5).collect();
    let shape = vec![4, 8];
    let mut stats_map = HashMap::new();
    stats_map.insert(
        "other_layer.weight".to_string(),
        ActivationStats::from_data(&[0.0, 1.0]),
    );
    let config = QuantConfig {
        bits: 8,
        ..Default::default()
    };
    let with_stats = Quantizer::with_calibration(config.clone(), stats_map)
        .quantize_tensor_with_name("conv.weight", &weight_data, shape.clone())
        .unwrap();
    let without_stats = Quantizer::new(config)
        .quantize_tensor_with_name("conv.weight", &weight_data, shape)
        .unwrap();
    let (scales_cal, zps_cal) = with_stats.get_all_scales_zero_points();
    let (scales_uncal, zps_uncal) = without_stats.get_all_scales_zero_points();
    assert_eq!(
        scales_cal, scales_uncal,
        "Fallback should match uncalibrated"
    );
    assert_eq!(zps_cal, zps_uncal, "Fallback should match uncalibrated");
}
#[test]
fn test_calibrated_quantization_int4() {
    // Calibrated quantization must also work at 4 bits.
    let weight_data: Vec<f32> = (0..48).map(|i| (i as f32 - 24.0) / 24.0).collect();
    let activations: Vec<f32> = (0..200).map(|i| (i as f32 - 100.0) / 100.0).collect();
    let mut stats_map = HashMap::new();
    stats_map.insert(
        "fc.weight".to_string(),
        ActivationStats::from_data(&activations),
    );
    let quantizer = Quantizer::with_calibration(
        QuantConfig {
            bits: 4,
            ..Default::default()
        },
        stats_map,
    );
    let result = quantizer
        .quantize_tensor_with_name("fc.weight", &weight_data, vec![6, 8])
        .unwrap();
    assert!(result.is_int4());
    assert_eq!(result.bits(), 4);
    // No exact bound asserted here — the error only needs to be finite.
    let error = result.quantization_error(&weight_data);
    assert!(error.is_finite());
}
/// Full calibrated pipeline: load a model, quantize its weight using
/// activation stats keyed by tensor name, save as QDQ, reload, and verify
/// connectivity plus metadata.
#[test]
fn test_calibrated_full_pipeline() {
    let weight_data: Vec<f32> = (0..16).map(|i| (i as f32 - 8.0) * 0.1).collect();
    let model = build_minimal_model(&weight_data, &[4, 4]);
    let dir = tempfile::tempdir().unwrap();
    let input_path = write_model_to_tempfile(&model, &dir, "calib_input.onnx");
    let output_path = dir.path().join("calib_output.onnx");
    let mut loaded = OnnxModel::load(&input_path).unwrap();
    let weights = loaded.extract_weights();
    assert_eq!(weights.len(), 1);
    // Stats registered under "weight" — the name of the model's initializer.
    let activation_data: Vec<f32> = (0..500).map(|i| (i as f32 - 250.0) / 500.0).collect();
    let stats = ActivationStats::from_data(&activation_data);
    let mut stats_map = HashMap::new();
    stats_map.insert("weight".to_string(), stats);
    let config = QuantConfig {
        bits: 8,
        ..Default::default()
    };
    let quantizer = Quantizer::with_calibration(config, stats_map);
    let mut quantized_data = Vec::new();
    for weight in &weights {
        // Name-aware quantization so the calibrated stats are looked up.
        let quantized = quantizer
            .quantize_tensor_with_name(&weight.name, &weight.data, weight.shape.clone())
            .unwrap();
        let (scales, zero_points) = quantized.get_all_scales_zero_points();
        let is_per_channel = quantized.is_per_channel();
        quantized_data.push(QdqWeightInput {
            original_name: weight.name.clone(),
            quantized_values: quantized.data(),
            scales,
            zero_points,
            bits: quantized.bits(),
            axis: if is_per_channel { Some(0) } else { None },
        });
    }
    loaded
        .save_quantized(&quantized_data, &output_path)
        .unwrap();
    let reloaded = OnnxModel::load(&output_path).unwrap();
    let report = reloaded.validate_connectivity();
    assert!(
        report.valid,
        "Connectivity broken: {:?}",
        report.broken_refs
    );
    let qdq_info = reloaded.load_quantized_info();
    assert_eq!(qdq_info.len(), 1);
    assert_eq!(qdq_info[0].name, "weight");
    assert!(qdq_info[0].scale > 0.0);
}