use ndarray::Array2;
use tensorlogic_scirs_backend::quantization::{
QuantizationGranularity, QuantizationParams, QuantizationScheme, QuantizationType,
};
use crate::quantization::{calibrate_linear, QuantizedLinear};
/// Builds the deterministic 4x8 weight fixture used by the tests in this file.
///
/// Element `(row, col)` equals `(row * 8 + col) * 0.5 - 14.0`, so the values
/// climb in steps of 0.5 from -14.0 at `(0, 0)` to 1.5 at `(3, 7)`.
fn sample_weight_4x8() -> Array2<f64> {
    Array2::from_shape_fn((4, 8), |(row, col)| {
        // Row-major position of the element within the 4x8 grid.
        let flat_index = row * 8 + col;
        flat_index as f64 * 0.5 - 14.0
    })
}
/// Quantizes the sample weight to Int8 with per-tensor granularity, then
/// dequantizes it and checks that no element moved by 2.0 or more.
#[test]
fn test_roundtrip_per_tensor() {
    let weight = sample_weight_4x8();
    let params = calibrate_linear(
        &weight,
        QuantizationType::Int8,
        QuantizationGranularity::PerTensor,
    );
    let qlinear = QuantizedLinear::from_fp(&weight, &params).expect("from_fp per-tensor");
    let restored = qlinear.dequantize();

    // Largest element-wise deviation between the original and the round-tripped weight.
    let max_abs_err = weight
        .iter()
        .zip(restored.iter())
        .map(|(original, roundtripped)| (original - roundtripped).abs())
        .fold(0.0_f64, |worst, err| worst.max(err));

    assert!(
        max_abs_err < 2.0,
        "per-tensor round-trip max error={max_abs_err} >= 2.0"
    );
}
/// Quantizes the sample weight to Int8 with per-channel granularity, then
/// dequantizes it and checks that no element moved by 2.0 or more.
#[test]
fn test_roundtrip_per_channel() {
    let weight = sample_weight_4x8();
    let params = calibrate_linear(
        &weight,
        QuantizationType::Int8,
        QuantizationGranularity::PerChannel,
    );
    let qlinear = QuantizedLinear::from_fp(&weight, &params).expect("from_fp per-channel");
    let restored = qlinear.dequantize();

    // Largest element-wise deviation between the original and the round-tripped weight.
    let max_abs_err = weight
        .iter()
        .zip(restored.iter())
        .map(|(original, roundtripped)| (original - roundtripped).abs())
        .fold(0.0_f64, |worst, err| worst.max(err));

    assert!(
        max_abs_err < 2.0,
        "per-channel round-trip max error={max_abs_err} >= 2.0"
    );
}
/// Checks that `forward` agrees with a plain matrix product against the
/// dequantized weight.
///
/// Note that the reference output is built from `dequantize()`, not from the
/// original floating-point weight, so this verifies that `forward` is
/// consistent with `dequantize` (to within 1e-12) rather than measuring
/// quantization error.
#[test]
fn test_forward_matches_fp() {
    let weight = sample_weight_4x8();
    let params = calibrate_linear(
        &weight,
        QuantizationType::Int8,
        QuantizationGranularity::PerChannel,
    );
    let qlinear = QuantizedLinear::from_fp(&weight, &params).expect("from_fp forward test");

    // Input of 2 rows by 8 columns, matching the 8 columns of the weight.
    let x = Array2::from_shape_fn((2, 8), |(i, j)| (i + j) as f64 * 0.1);

    let out_q = qlinear.forward(&x);

    // Reference: x * W_deq^T using the dequantized weight.
    let weight_fp = qlinear.dequantize();
    let out_ref = x.dot(&weight_fp.t());

    assert_eq!(out_q.shape(), &[2, 4]);
    for (a, b) in out_q.iter().zip(out_ref.iter()) {
        assert!((a - b).abs() < 1e-12, "forward mismatch: {a} vs {b}");
    }
}
/// Sanity-checks the per-tensor Int8 calibration of the sample weight: the
/// first scale must be positive, finite and below 1.0, and the first zero
/// point must be 0.
#[test]
fn test_calibration_sanity() {
    let weight = sample_weight_4x8();
    let params = calibrate_linear(
        &weight,
        QuantizationType::Int8,
        QuantizationGranularity::PerTensor,
    );

    // Only the first entry is inspected by this test.
    let scale = params.scale[0];

    assert!(scale > 0.0, "scale[0]={}", scale);
    assert!(scale.is_finite(), "scale not finite");
    // NOTE(review): the message implies calibration is symmetric — confirm against calibrate_linear.
    assert_eq!(
        params.zero_point[0], 0,
        "symmetric should have zero_point==0"
    );
    assert!(scale < 1.0, "scale[0]={} unreasonably large", scale);
}
/// Calibrates a 2x4 weight whose rows differ in magnitude by 100x and checks
/// that per-channel calibration yields two scales, the first more than ten
/// times the second.
#[test]
fn test_per_channel_uses_different_scales() {
    // Row 0 spans +/-100, row 1 spans +/-1.
    let rows: [[f64; 4]; 2] = [
        [100.0, -100.0, 50.0, -50.0],
        [1.0, -1.0, 0.5, -0.5],
    ];
    let weight = Array2::from_shape_vec((2, 4), rows.concat()).expect("build test weight");

    let params = calibrate_linear(
        &weight,
        QuantizationType::Int8,
        QuantizationGranularity::PerChannel,
    );

    assert_eq!(params.scale.len(), 2, "PerChannel needs 2 scales");

    let ratio = params.scale[0] / params.scale[1];
    assert!(
        ratio > 10.0,
        "scale[0]={} scale[1]={} ratio={} (expected >10)",
        params.scale[0],
        params.scale[1],
        ratio
    );
}
/// Checks that `with_bias` returns an error when given a bias of length 3 for
/// a layer built from the 4x8 sample weight.
#[test]
fn test_bias_shapes_checked() {
    let weight = sample_weight_4x8();
    let params = calibrate_linear(
        &weight,
        QuantizationType::Int8,
        QuantizationGranularity::PerTensor,
    );
    let qlinear = QuantizedLinear::from_fp(&weight, &params).expect("from_fp bias test");

    // The weight has 4 rows, so a 3-element bias is deliberately mismatched.
    let bad_bias = ndarray::Array1::zeros(3_usize);
    assert!(
        qlinear.with_bias(bad_bias).is_err(),
        "a bias of length 3 should be rejected for a 4x8 weight"
    );
}
/// Checks that attaching a bias shifts every `forward` output by exactly the
/// bias entry of its output channel (to within 1e-12).
///
/// The bias-free reference output comes from a second layer built from the
/// same weight and calibration parameters.
#[test]
fn test_bias_forward_correct() {
    let weight = sample_weight_4x8();
    let params = calibrate_linear(
        &weight,
        QuantizationType::Int8,
        QuantizationGranularity::PerTensor,
    );

    let bias = ndarray::Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0]);
    // `bias` is cloned because it is read again when computing the expected values.
    let biased_layer = QuantizedLinear::from_fp(&weight, &params)
        .expect("from_fp bias forward")
        .with_bias(bias.clone())
        .expect("with_bias");

    let x = Array2::from_shape_fn((2, 8), |(i, j)| (i + j) as f64 * 0.1);

    let out_no_bias = QuantizedLinear::from_fp(&weight, &params)
        .expect("from_fp no-bias")
        .forward(&x);
    let out_with_bias = biased_layer.forward(&x);

    for batch in 0..2 {
        for ch in 0..4 {
            let expected = out_no_bias[[batch, ch]] + bias[ch];
            let got = out_with_bias[[batch, ch]];
            assert!(
                (expected - got).abs() < 1e-12,
                "bias mismatch at [{batch},{ch}]: expected={expected} got={got}"
            );
        }
    }
}
/// Checks that `QuantizedLinear::from_fp` returns an error when the supplied
/// parameters request `QuantizationType::Int4`.
#[test]
fn test_invalid_qtype_returns_error() {
    let weight = sample_weight_4x8();

    // Hand-built parameters: everything is per-tensor/symmetric with unit scale;
    // the Int4 qtype is the field this test expects to be rejected.
    let int4_params = QuantizationParams {
        qtype: QuantizationType::Int4,
        scheme: QuantizationScheme::Symmetric,
        granularity: QuantizationGranularity::PerTensor,
        scale: vec![1.0],
        zero_point: vec![0],
        min_val: vec![-1.0],
        max_val: vec![1.0],
    };

    let result = QuantizedLinear::from_fp(&weight, &int4_params);
    assert!(result.is_err(), "Int4 should be rejected");
}