use proptest::prelude::*;
use quantize_rs::quantization::{
QuantParams, QuantParamsInt4, QuantizedTensor, QuantizedTensorInt4,
};
use quantize_rs::{pack_int4, unpack_int4};
/// Strategy yielding a `(min, max)` pair that always spans zero:
/// `min` is drawn from `[-1e6, -1e-3]` and `max` from `[1e-3, 1e6]`,
/// so `min < 0 < max` and the range is never degenerate.
fn spanning_range() -> impl Strategy<Value = (f32, f32)> {
(
(-1e6_f32..=-1e-3_f32), (1e-3_f32..=1e6_f32), )
}
/// Strategy covering every representable signed 4-bit value, `[-8, 7]`.
fn int4_value() -> impl Strategy<Value = i8> {
-8_i8..=7_i8
}
/// Strategy for non-empty vectors (1..=512 elements) of INT4 values.
fn int4_vec() -> impl Strategy<Value = Vec<i8>> {
prop::collection::vec(int4_value(), 1..=512)
}
/// Strategy for non-empty vectors (1..=512 elements) of finite f32 values
/// in `[-1e6, 1e6]`; the bounded range keeps NaN/inf out of tensor inputs.
fn finite_f32_vec() -> impl Strategy<Value = Vec<f32>> {
prop::collection::vec(-1e6_f32..=1e6_f32, 1..=512)
}
proptest! {
    // Quantizing any finite value with params derived from a zero-spanning
    // range must land inside the INT8 representable range [-128, 127].
    #[test]
    fn prop_int8_quantize_always_in_range(
        (min, max) in spanning_range(),
        v in -1e6_f32..=1e6_f32,
    ) {
        let params = QuantParams::from_range(min, max);
        let q = params.quantize(v);
        // Check both ends of the range, mirroring the INT4 test below.
        // (The original checked only the lower bound and left a stray
        // `let _ = q;` behind.)
        prop_assert!(q >= -128, "INT8 result {} < -128 (min={}, max={}, v={})", q, min, max, v);
        prop_assert!(q <= 127, "INT8 result {} > 127 (min={}, max={}, v={})", q, min, max, v);
    }

    // Same property for INT4: results must stay within [-8, 7].
    #[test]
    fn prop_int4_quantize_always_in_range(
        (min, max) in spanning_range(),
        v in -1e6_f32..=1e6_f32,
    ) {
        let params = QuantParamsInt4::from_range(min, max);
        let q = params.quantize(v);
        prop_assert!(q >= -8, "INT4 result {} < -8 (min={}, max={}, v={})", q, min, max, v);
        prop_assert!(q <= 7, "INT4 result {} > 7 (min={}, max={}, v={})", q, min, max, v);
    }
}
proptest! {
    // A point interpolated inside [min, max] must survive an INT8
    // quantize/dequantize round trip to within one scale step (plus a
    // small tolerance for float arithmetic).
    #[test]
    fn prop_int8_round_trip_error_bounded(
        (min, max) in spanning_range(),
        t in 0.0_f32..=1.0_f32,
    ) {
        let value = min + t * (max - min);
        let params = QuantParams::from_range(min, max);
        let code = params.quantize(value);
        let decoded = params.dequantize(code);
        let round_trip_err = (value - decoded).abs();
        let tolerance = params.scale() + 1e-4;
        prop_assert!(
            round_trip_err <= tolerance,
            "INT8 round-trip error {:.9} > bound {:.9} \
             (min={}, max={}, v={}, q={}, dq={}, scale={})",
            round_trip_err, tolerance, min, max, value, code, decoded, params.scale()
        );
    }

    // Same round-trip property for INT4; its coarser grid still has to
    // stay within one (larger) scale step.
    #[test]
    fn prop_int4_round_trip_error_bounded(
        (min, max) in spanning_range(),
        t in 0.0_f32..=1.0_f32,
    ) {
        let value = min + t * (max - min);
        let params = QuantParamsInt4::from_range(min, max);
        let code = params.quantize(value);
        let decoded = params.dequantize(code);
        let round_trip_err = (value - decoded).abs();
        let tolerance = params.scale() + 1e-4;
        prop_assert!(
            round_trip_err <= tolerance,
            "INT4 round-trip error {:.9} > bound {:.9} \
             (min={}, max={}, v={}, q={}, dq={}, scale={})",
            round_trip_err, tolerance, min, max, value, code, decoded, params.scale()
        );
    }

    // With fewer levels over the same range, INT4's step size can never
    // be finer than INT8's.
    #[test]
    fn prop_int4_scale_larger_than_int8(
        (min, max) in spanning_range(),
    ) {
        let int8_params = QuantParams::from_range(min, max);
        let int4_params = QuantParamsInt4::from_range(min, max);
        prop_assert!(
            int4_params.scale() >= int8_params.scale(),
            "INT4 scale {} < INT8 scale {} for range [{}, {}]",
            int4_params.scale(), int8_params.scale(), min, max
        );
    }
}
proptest! {
    // Packing nibbles into bytes and unpacking again is lossless.
    #[test]
    fn prop_pack_unpack_identity(values in int4_vec()) {
        let bytes = pack_int4(&values);
        let restored = unpack_int4(&bytes, values.len());
        prop_assert_eq!(
            &values, &restored,
            "pack/unpack mismatch for {} values", values.len()
        );
    }

    // Two nibbles share a byte, so the packed length is ceil(n / 2).
    #[test]
    fn prop_packed_size_is_ceil_half(values in int4_vec()) {
        let want = values.len().div_ceil(2);
        let bytes = pack_int4(&values);
        prop_assert_eq!(
            bytes.len(), want,
            "packed size {} != ceil({}/2) = {}",
            bytes.len(), values.len(), want
        );
    }

    // The degenerate one-element case occupies exactly one byte and
    // still round-trips.
    #[test]
    fn prop_single_value_packs_to_one_byte(nibble in int4_value()) {
        let bytes = pack_int4(&[nibble]);
        prop_assert_eq!(bytes.len(), 1);
        prop_assert_eq!(unpack_int4(&bytes, 1), vec![nibble]);
    }
}
proptest! {
    // Quantizing finite data into a 1-D INT8 tensor never fails.
    #[test]
    fn prop_int8_from_f32_no_panic(data in finite_f32_vec()) {
        let len = data.len();
        let result = QuantizedTensor::from_f32(&data, vec![len]);
        prop_assert!(result.is_ok(), "from_f32 unexpectedly failed: {:?}", result.err());
    }

    // Same for INT4.
    #[test]
    fn prop_int4_from_f32_no_panic(data in finite_f32_vec()) {
        let len = data.len();
        let result = QuantizedTensorInt4::from_f32(&data, vec![len]);
        prop_assert!(result.is_ok(), "INT4 from_f32 unexpectedly failed: {:?}", result.err());
    }

    // An explicitly supplied zero-spanning range is also always accepted.
    #[test]
    fn prop_int8_from_f32_with_range_no_panic(
        data in finite_f32_vec(),
        (min, max) in spanning_range(),
    ) {
        let len = data.len();
        let result = QuantizedTensor::from_f32_with_range(&data, vec![len], min, max);
        prop_assert!(result.is_ok(), "from_f32_with_range failed: {:?}", result.err());
    }

    #[test]
    fn prop_int4_from_f32_with_range_no_panic(
        data in finite_f32_vec(),
        (min, max) in spanning_range(),
    ) {
        let len = data.len();
        let result = QuantizedTensorInt4::from_f32_with_range(&data, vec![len], min, max);
        prop_assert!(result.is_ok(), "INT4 from_f32_with_range failed: {:?}", result.err());
    }

    // Per-channel quantization of a [channels, per_channel] tensor with
    // monotone data in [0, 1) never fails.
    #[test]
    fn prop_int8_per_channel_no_panic(
        channels in 1_usize..=8,
        per_channel in 1_usize..=64,
    ) {
        let data: Vec<f32> = (0..channels * per_channel)
            .map(|i| (i as f32) / ((channels * per_channel) as f32))
            .collect();
        let result = QuantizedTensor::from_f32_per_channel(&data, vec![channels, per_channel]);
        prop_assert!(result.is_ok(), "INT8 per_channel failed: {:?}", result.err());
    }

    #[test]
    fn prop_int4_per_channel_no_panic(
        channels in 1_usize..=8,
        per_channel in 1_usize..=64,
    ) {
        let data: Vec<f32> = (0..channels * per_channel)
            .map(|i| (i as f32) / ((channels * per_channel) as f32))
            .collect();
        let result = QuantizedTensorInt4::from_f32_per_channel(&data, vec![channels, per_channel]);
        prop_assert!(result.is_ok(), "INT4 per_channel failed: {:?}", result.err());
    }

    // Per-channel INT8 round trip: every element's reconstruction error
    // must stay within that channel's own quantization step.
    #[test]
    fn prop_int8_per_channel_round_trip(
        channels in 1_usize..=8,
        per_channel in 1_usize..=64,
    ) {
        // Values centered around zero so channels span both signs.
        let data: Vec<f32> = (0..channels * per_channel)
            .map(|i| ((i as f32) - (channels * per_channel / 2) as f32) * 0.01)
            .collect();
        let shape = vec![channels, per_channel];
        // Construction must succeed; the previous `if let Ok(..)` made the
        // round-trip property vacuously pass when construction failed.
        let result = QuantizedTensor::from_f32_per_channel(&data, shape);
        prop_assert!(
            result.is_ok(),
            "INT8 per_channel construction failed: {:?}",
            result.as_ref().err()
        );
        let tensor = result.unwrap();
        let dq = tensor.to_f32();
        for ch in 0..channels {
            let start = ch * per_channel;
            let end = start + per_channel;
            let ch_min = data[start..end].iter().copied().fold(f32::INFINITY, f32::min);
            let ch_max = data[start..end].iter().copied().fold(f32::NEG_INFINITY, f32::max);
            // The expected range is widened to include zero — presumably
            // matching the library's range-adjustment convention; confirm
            // against `from_f32_per_channel`'s implementation.
            let adj_min = ch_min.min(0.0);
            let adj_max = ch_max.max(0.0);
            let range = (adj_max - adj_min).max(1e-10);
            // 255 INT8 steps across the range; 1e-4 absorbs float noise.
            let bound = range / 255.0 + 1e-4;
            for j in 0..per_channel {
                let idx = start + j;
                let error = (data[idx] - dq[idx]).abs();
                prop_assert!(
                    error <= bound,
                    "per-channel INT8 error {:.9} > bound {:.9} (ch={}, idx={})",
                    error, bound, ch, idx
                );
            }
        }
    }

    // Per-channel INT4 round trip with the coarser 15-step bound.
    #[test]
    fn prop_int4_per_channel_round_trip(
        channels in 1_usize..=8,
        per_channel in 1_usize..=64,
    ) {
        let data: Vec<f32> = (0..channels * per_channel)
            .map(|i| ((i as f32) - (channels * per_channel / 2) as f32) * 0.01)
            .collect();
        let shape = vec![channels, per_channel];
        // As above: fail loudly instead of silently skipping the property.
        let result = QuantizedTensorInt4::from_f32_per_channel(&data, shape);
        prop_assert!(
            result.is_ok(),
            "INT4 per_channel construction failed: {:?}",
            result.as_ref().err()
        );
        let tensor = result.unwrap();
        let dq = tensor.to_f32();
        for ch in 0..channels {
            let start = ch * per_channel;
            let end = start + per_channel;
            let ch_min = data[start..end].iter().copied().fold(f32::INFINITY, f32::min);
            let ch_max = data[start..end].iter().copied().fold(f32::NEG_INFINITY, f32::max);
            let adj_min = ch_min.min(0.0);
            let adj_max = ch_max.max(0.0);
            let range = (adj_max - adj_min).max(1e-10);
            // 15 INT4 steps across the range; 1e-4 absorbs float noise.
            let bound = range / 15.0 + 1e-4;
            for j in 0..per_channel {
                let idx = start + j;
                let error = (data[idx] - dq[idx]).abs();
                prop_assert!(
                    error <= bound,
                    "per-channel INT4 error {:.9} > bound {:.9} (ch={}, idx={})",
                    error, bound, ch, idx
                );
            }
        }
    }

    // Dequantization preserves element count for both widths. Construction
    // is asserted (not `if let`-skipped) so failures cannot hide.
    #[test]
    fn prop_to_f32_no_panic(data in finite_f32_vec()) {
        let len = data.len();
        let int8 = QuantizedTensor::from_f32(&data, vec![len]);
        prop_assert!(int8.is_ok(), "from_f32 failed: {:?}", int8.as_ref().err());
        prop_assert_eq!(int8.unwrap().to_f32().len(), len, "dequantized length mismatch");
        let int4 = QuantizedTensorInt4::from_f32(&data, vec![len]);
        prop_assert!(int4.is_ok(), "INT4 from_f32 failed: {:?}", int4.as_ref().err());
        prop_assert_eq!(int4.unwrap().to_f32().len(), len, "INT4 dequantized length mismatch");
    }
}