use crate::accumulator::BinnedAccumulatorF64;
/// Affine quantization parameters for signed 8-bit values.
///
/// A quantized value `q` stands for the real number `scale * (q - zero_point)`.
#[derive(Clone, Copy, Debug)]
pub struct QuantParamsI8 {
    /// Real-valued size of one quantization step.
    pub scale: f64,
    /// Quantized value that maps to real zero.
    pub zero_point: i8,
}

impl QuantParamsI8 {
    /// Builds parameters from a step size and zero point; no validation is performed.
    pub fn new(scale: f64, zero_point: i8) -> Self {
        Self { zero_point, scale }
    }

    /// Maps one quantized value back to `scale * (v - zero_point)`.
    ///
    /// The offset is computed in `i64`, so it cannot overflow for any pair of `i8` inputs.
    #[inline]
    pub fn dequantize(&self, v: i8) -> f64 {
        let offset = i64::from(v) - i64::from(self.zero_point);
        self.scale * offset as f64
    }

    /// Dequantizes every element of `src`, preserving order.
    pub fn dequantize_slice(&self, src: &[i8]) -> Vec<f64> {
        let mut real = Vec::with_capacity(src.len());
        real.extend(src.iter().copied().map(|q| self.dequantize(q)));
        real
    }
}
/// Affine quantization parameters for signed 4-bit values packed two per byte.
///
/// A quantized value `q` stands for the real number `scale * (q - zero_point)`.
#[derive(Clone, Copy, Debug)]
pub struct QuantParamsI4 {
    /// Real-valued size of one quantization step.
    pub scale: f64,
    /// Quantized value that maps to real zero; `new` requires it to lie in `[-8, 7]`.
    pub zero_point: i8,
}

impl QuantParamsI4 {
    /// Builds parameters for 4-bit quantization.
    ///
    /// # Panics
    ///
    /// Panics if `zero_point` lies outside the signed 4-bit range `[-8, 7]`.
    pub fn new(scale: f64, zero_point: i8) -> Self {
        assert!((-8..=7).contains(&zero_point), "i4 zero_point must be in [-8, 7]");
        Self { zero_point, scale }
    }

    /// Splits a packed byte into its two sign-extended 4-bit values as `(high, low)`.
    #[inline]
    pub fn unpack_byte(byte: u8) -> (i8, i8) {
        // Reinterpreting as `i8` and shifting right is an arithmetic shift, which
        // replicates bit 3 of the selected nibble into the upper bits.
        let high = (byte as i8) >> 4;
        // Move the low nibble into the top half first, then sign-extend it the same way.
        let low = ((byte << 4) as i8) >> 4;
        (high, low)
    }

    /// Maps one unpacked 4-bit value back to `scale * (v - zero_point)`.
    #[inline]
    pub fn dequantize(&self, v: i8) -> f64 {
        let offset = i64::from(v) - i64::from(self.zero_point);
        self.scale * offset as f64
    }
}
/// Multiplies two `i8` values, widening to `i32`.
///
/// The largest magnitude product is `(-128) * (-128) = 16384`, which always fits in
/// `i32`, so no saturation is ever actually needed.
#[inline]
pub fn saturating_mul_i8(a: i8, b: i8) -> i32 {
    i32::from(a) * i32::from(b)
}
/// Integer dot product of two `i8` slices, saturated to the `i32` range.
///
/// Products are accumulated exactly in `i64`, and the final total is clamped once to
/// `[i32::MIN, i32::MAX]`. Clamping only at the end makes the result independent of
/// element order. Per-step `saturating_add` would let a transient excursion past an `i32`
/// bound stick at that bound, even when later terms bring the true total back in range.
///
/// `a` and `b` are expected to have equal lengths, which is checked only in debug builds.
/// In release builds a shorter `b` panics, and any excess elements of `b` are ignored.
#[inline]
pub fn saturating_dot_i8(a: &[i8], b: &[i8]) -> i32 {
    debug_assert_eq!(a.len(), b.len());
    // Mirrors the release-mode behaviour of indexing `b[i]` for `i < a.len()`.
    let b = &b[..a.len()];
    // Each product is at most 16384 in magnitude, so the i64 sum cannot overflow
    // for any slice that fits in memory.
    let total: i64 = a
        .iter()
        .zip(b)
        .map(|(&x, &y)| i64::from(x) * i64::from(y))
        .sum();
    total.clamp(i64::from(i32::MIN), i64::from(i32::MAX)) as i32
}
/// Multiplies two quantized `i8` matrices and writes dequantized results to `out`.
///
/// Layout follows the indexing below:
/// - `a` is row-major `m x k`, read as `a[row * k + p]`;
/// - `b` is row-major `k x n`, read as `b[p * n + col]`;
/// - `out` is row-major `m x n`.
///
/// Each output cell is the `BinnedAccumulatorF64` sum, in `p` order, of
/// `scale_a * scale_b * (a - zp_a) * (b - zp_b)`. Shapes are checked only in debug builds.
pub fn quantized_matmul_i8(
    a: &[i8],
    b: &[i8],
    out: &mut [f64],
    m: usize,
    k: usize,
    n: usize,
    params_a: &QuantParamsI8,
    params_b: &QuantParamsI8,
) {
    debug_assert_eq!(a.len(), m * k);
    debug_assert_eq!(b.len(), k * n);
    debug_assert_eq!(out.len(), m * n);
    let scale = params_a.scale * params_b.scale;
    let (za, zb) = (i64::from(params_a.zero_point), i64::from(params_b.zero_point));
    for row in 0..m {
        let a_base = row * k;
        let out_base = row * n;
        for col in 0..n {
            let mut cell = BinnedAccumulatorF64::new();
            // Zero-point-centred integer product, computed exactly in i64 before scaling.
            (0..k)
                .map(|p| (i64::from(a[a_base + p]) - za) * (i64::from(b[p * n + col]) - zb))
                .for_each(|prod| {
                    cell.add(scale * prod as f64);
                });
            out[out_base + col] = cell.finalize();
        }
    }
}
/// Dot product of two quantized `i8` vectors, returned as a real value.
///
/// Each term `scale_a * scale_b * (a[i] - zp_a) * (b[i] - zp_b)` is added to a
/// `BinnedAccumulatorF64` in index order. Equal lengths are checked only in debug builds.
pub fn quantized_dot_i8(
    a: &[i8],
    b: &[i8],
    params_a: &QuantParamsI8,
    params_b: &QuantParamsI8,
) -> f64 {
    debug_assert_eq!(a.len(), b.len());
    let scale = params_a.scale * params_b.scale;
    let za = i64::from(params_a.zero_point);
    let zb = i64::from(params_b.zero_point);
    let mut sum = BinnedAccumulatorF64::new();
    for (idx, &x) in a.iter().enumerate() {
        let centered = (i64::from(x) - za) * (i64::from(b[idx]) - zb);
        sum.add(scale * centered as f64);
    }
    sum.finalize()
}
/// Sums dequantized `i8` values with the crate's `BinnedAccumulatorF64`.
///
/// Elements are dequantized with `params` and added in slice order. The binned
/// accumulator is presumably used for reproducible, order-insensitive rounding; confirm
/// this against the accumulator module.
pub fn quantized_sum_i8(values: &[i8], params: &QuantParamsI8) -> f64 {
    let mut total = BinnedAccumulatorF64::new();
    values
        .iter()
        .map(|&q| params.dequantize(q))
        .for_each(|x| {
            total.add(x);
        });
    total.finalize()
}
/// Sums the first `count` packed 4-bit values after dequantizing them with `params`.
///
/// Each byte contributes its high nibble first, then its low nibble. If `packed` holds
/// fewer than `count` nibbles, only the available ones are summed and no error is raised.
pub fn quantized_sum_i4(packed: &[u8], count: usize, params: &QuantParamsI4) -> f64 {
    let mut total = BinnedAccumulatorF64::new();
    packed
        .iter()
        .flat_map(|&byte| {
            let (hi, lo) = QuantParamsI4::unpack_byte(byte);
            [hi, lo]
        })
        .take(count)
        .for_each(|nibble| {
            total.add(params.dequantize(nibble));
        });
    total.finalize()
}
#[cfg(test)]
mod tests {
    // Unit tests for the quantization helpers. Every `&params` argument below was
    // previously mangled into an HTML-entity garble, which kept this module from compiling.
    use super::*;

    #[test]
    fn test_dequantize_i8_basic() {
        let params = QuantParamsI8::new(0.1, 0);
        assert_eq!(params.dequantize(10), 1.0);
        assert_eq!(params.dequantize(-10), -1.0);
        assert_eq!(params.dequantize(0), 0.0);
    }

    #[test]
    fn test_dequantize_i8_with_zero_point() {
        let params = QuantParamsI8::new(0.5, 10);
        assert_eq!(params.dequantize(20), 5.0);
        assert_eq!(params.dequantize(10), 0.0);
    }

    #[test]
    fn test_saturating_dot_i8() {
        let a = vec![1i8, 2, 3, 4];
        let b = vec![5i8, 6, 7, 8];
        assert_eq!(saturating_dot_i8(&a, &b), 70);
    }

    #[test]
    fn test_saturating_dot_overflow() {
        // 1000 * 127 * 127 stays well below i32::MAX; this checks wide accumulation only.
        let a = vec![127i8; 1000];
        let b = vec![127i8; 1000];
        let result = saturating_dot_i8(&a, &b);
        assert_eq!(result, 16_129_000);
    }

    #[test]
    fn test_saturating_dot_clamps_at_bounds() {
        // Every partial sum moves in one direction, so both bounds are actually reached.
        let big = vec![-128i8; 131_072];
        assert_eq!(saturating_dot_i8(&big, &big), i32::MAX);
        let a = vec![-128i8; 140_000];
        let b = vec![127i8; 140_000];
        assert_eq!(saturating_dot_i8(&a, &b), i32::MIN);
    }

    #[test]
    fn test_quantized_matmul_identity() {
        let params = QuantParamsI8::new(1.0, 0);
        let a = vec![1i8, 0, 0, 1];
        let b = vec![3i8, 4, 5, 6];
        let mut out = vec![0.0f64; 4];
        quantized_matmul_i8(&a, &b, &mut out, 2, 2, 2, &params, &params);
        assert_eq!(out, vec![3.0, 4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_quantized_matmul_scaling() {
        let params_a = QuantParamsI8::new(0.5, 0);
        let params_b = QuantParamsI8::new(2.0, 0);
        let a = vec![2i8, 3];
        let b = vec![4i8, 5];
        let mut out = vec![0.0f64; 1];
        quantized_matmul_i8(&a, &b, &mut out, 1, 2, 1, &params_a, &params_b);
        assert_eq!(out[0], 23.0);
    }

    #[test]
    fn test_quantized_dot_deterministic() {
        let params = QuantParamsI8::new(0.001, 0);
        let a: Vec<i8> = (0..100).map(|i| (i % 127) as i8).collect();
        let b: Vec<i8> = (0..100).map(|i| ((100 - i) % 127) as i8).collect();
        let r1 = quantized_dot_i8(&a, &b, &params, &params);
        let r2 = quantized_dot_i8(&a, &b, &params, &params);
        assert_eq!(r1.to_bits(), r2.to_bits());
    }

    #[test]
    fn test_i4_unpack() {
        let (hi, lo) = QuantParamsI4::unpack_byte(0x3E);
        assert_eq!(hi, 3);
        assert_eq!(lo, -2);
    }

    #[test]
    fn test_i4_unpack_negatives() {
        let (hi, lo) = QuantParamsI4::unpack_byte(0xF8);
        assert_eq!(hi, -1);
        assert_eq!(lo, -8);
    }

    #[test]
    fn test_quantized_sum_i4() {
        let params = QuantParamsI4::new(1.0, 0);
        let packed = vec![0x23u8, 0x45];
        let result = quantized_sum_i4(&packed, 4, &params);
        // Nibbles 2 + 3 + 4 + 5.
        assert_eq!(result, 14.0);
    }

    #[test]
    fn test_quantized_sum_i8_near_order_invariant() {
        let params = QuantParamsI8::new(0.001, 0);
        let values: Vec<i8> = (0..200).map(|i| ((i as i16 - 100) % 128) as i8).collect();
        let r1 = quantized_sum_i8(&values, &params);
        let mut rev = values.clone();
        rev.reverse();
        let r2 = quantized_sum_i8(&rev, &params);
        let ulps = (r1.to_bits() as i64 - r2.to_bits() as i64).unsigned_abs();
        assert!(
            ulps < 10,
            "Quantized sum should be near-order-invariant: {r1} vs {r2} ({ulps} ULPs)"
        );
    }

    #[test]
    fn test_quantized_sum_i8_merge_order_invariant() {
        let params = QuantParamsI8::new(0.001, 0);
        let values: Vec<i8> = (0..200).map(|i| ((i as i16 - 100) % 128) as i8).collect();
        let mut fwd = BinnedAccumulatorF64::new();
        for chunk in values.chunks(20) {
            let mut c = BinnedAccumulatorF64::new();
            for &v in chunk {
                c.add(params.dequantize(v));
            }
            fwd.merge(&c);
        }
        let chunks: Vec<Vec<i8>> = values.chunks(20).map(|c| c.to_vec()).collect();
        let mut rev = BinnedAccumulatorF64::new();
        for chunk in chunks.iter().rev() {
            let mut c = BinnedAccumulatorF64::new();
            for &v in chunk.iter() {
                c.add(params.dequantize(v));
            }
            rev.merge(&c);
        }
        assert_eq!(
            fwd.finalize().to_bits(),
            rev.finalize().to_bits(),
            "Merge-based quantized sum must be order-invariant"
        );
    }
}