use std::fmt;
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum QuantScheme {
Symmetric,
Asymmetric,
}
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum QuantBits {
Int8,
Int4,
}
#[derive(Debug, Clone)]
pub struct QuantParams {
pub scale: f32,
pub zero_point: i32,
pub bits: QuantBits,
pub scheme: QuantScheme,
pub per_channel_scales: Option<Vec<f32>>,
}
impl QuantParams {
pub fn from_range(min_val: f32, max_val: f32, bits: QuantBits, scheme: QuantScheme) -> Self {
let (qmin, qmax) = match bits {
QuantBits::Int8 => (-128i32, 127i32),
QuantBits::Int4 => (-8i32, 7i32),
};
match scheme {
QuantScheme::Symmetric => {
let abs_max = min_val.abs().max(max_val.abs());
let scale = abs_max / qmax as f32;
Self {
scale: if scale == 0.0 { 1.0 } else { scale },
zero_point: 0,
bits,
scheme,
per_channel_scales: None,
}
}
QuantScheme::Asymmetric => {
let range = max_val - min_val;
let scale = range / (qmax - qmin) as f32;
let zero_point = (qmin as f32 - min_val / scale).round() as i32;
Self {
scale: if scale == 0.0 { 1.0 } else { scale },
zero_point: zero_point.clamp(qmin, qmax),
bits,
scheme,
per_channel_scales: None,
}
}
}
}
pub fn calibrate(data: &[f32], bits: QuantBits, scheme: QuantScheme) -> Self {
if data.is_empty() {
return Self { scale: 1.0, zero_point: 0, bits, scheme, per_channel_scales: None };
}
let min_val = data.iter().cloned().fold(f32::INFINITY, f32::min);
let max_val = data.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
Self::from_range(min_val, max_val, bits, scheme)
}
}
pub fn quantize_int8(data: &[f32], params: &QuantParams) -> Vec<i8> {
data.iter().map(|&x| {
let q = (x / params.scale).round() as i32 + params.zero_point;
q.clamp(-128, 127) as i8
}).collect()
}
pub fn dequantize_int8(data: &[i8], params: &QuantParams) -> Vec<f32> {
data.iter().map(|&q| {
(q as i32 - params.zero_point) as f32 * params.scale
}).collect()
}
pub fn quantize_int4(data: &[f32], params: &QuantParams) -> Vec<u8> {
let mut packed = Vec::with_capacity((data.len() + 1) / 2);
for chunk in data.chunks(2) {
let lo = {
let q = (chunk[0] / params.scale).round() as i32 + params.zero_point;
(q.clamp(-8, 7) & 0x0F) as u8
};
let hi = if chunk.len() > 1 {
let q = (chunk[1] / params.scale).round() as i32 + params.zero_point;
((q.clamp(-8, 7) & 0x0F) as u8) << 4
} else {
0
};
packed.push(lo | hi);
}
packed
}
pub fn dequantize_int4(data: &[u8], count: usize, params: &QuantParams) -> Vec<f32> {
let mut result = Vec::with_capacity(count);
for &byte in data {
if result.len() >= count { break; }
let lo = (byte & 0x0F) as i8;
let lo = if lo & 0x08 != 0 { lo | !0x0F_u8 as i8 } else { lo }; result.push((lo as i32 - params.zero_point) as f32 * params.scale);
if result.len() >= count { break; }
let hi = ((byte >> 4) & 0x0F) as i8;
let hi = if hi & 0x08 != 0 { hi | !0x0F_u8 as i8 } else { hi };
result.push((hi as i32 - params.zero_point) as f32 * params.scale);
}
result
}
pub fn quantized_gemm_int8(
a: &[i8], b: &[i8],
m: usize, k: usize, n: usize,
a_params: &QuantParams, b_params: &QuantParams,
) -> Vec<f32> {
let mut c = vec![0i32; m * n];
for i in 0..m {
for p in 0..k {
let a_val = a[i * k + p] as i32 - a_params.zero_point;
for j in 0..n {
let b_val = b[p * n + j] as i32 - b_params.zero_point;
c[i * n + j] += a_val * b_val;
}
}
}
let output_scale = a_params.scale * b_params.scale;
c.iter().map(|&v| v as f32 * output_scale).collect()
}
pub fn quantization_error(original: &[f32], params: &QuantParams) -> QuantError {
let quantized = quantize_int8(original, params);
let dequantized = dequantize_int8(&quantized, params);
let mse: f64 = original.iter().zip(dequantized.iter())
.map(|(&o, &d)| ((o - d) as f64).powi(2))
.sum::<f64>() / original.len() as f64;
let max_error = original.iter().zip(dequantized.iter())
.map(|(&o, &d)| (o - d).abs())
.fold(0.0f32, f32::max);
let signal_power: f64 = original.iter().map(|&x| (x as f64).powi(2)).sum::<f64>() / original.len() as f64;
let snr = if mse > 0.0 { 10.0 * (signal_power / mse).log10() } else { f64::INFINITY };
QuantError {
mse: mse as f32,
max_error,
snr_db: snr as f32,
compression_ratio: match params.bits {
QuantBits::Int8 => 4.0, QuantBits::Int4 => 8.0, },
}
}
#[derive(Debug, Clone)]
pub struct QuantError {
pub mse: f32,
pub max_error: f32,
pub snr_db: f32,
pub compression_ratio: f32,
}
impl fmt::Display for QuantError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "QuantError: MSE={:.6}, MaxErr={:.4}, SNR={:.1}dB, {}x compression",
self.mse, self.max_error, self.snr_db, self.compression_ratio)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_symmetric_int8_roundtrip() {
let data = vec![-1.0, -0.5, 0.0, 0.5, 1.0];
let params = QuantParams::calibrate(&data, QuantBits::Int8, QuantScheme::Symmetric);
let quantized = quantize_int8(&data, ¶ms);
let dequantized = dequantize_int8(&quantized, ¶ms);
for i in 0..data.len() {
assert!((data[i] - dequantized[i]).abs() < 0.02,
"Mismatch at {}: original={}, dequantized={}", i, data[i], dequantized[i]);
}
}
#[test]
fn test_asymmetric_int8() {
let data = vec![0.0, 0.25, 0.5, 0.75, 1.0];
let params = QuantParams::calibrate(&data, QuantBits::Int8, QuantScheme::Asymmetric);
let quantized = quantize_int8(&data, ¶ms);
let dequantized = dequantize_int8(&quantized, ¶ms);
for i in 0..data.len() {
assert!((data[i] - dequantized[i]).abs() < 0.02,
"Asymmetric mismatch at {}: {} vs {}", i, data[i], dequantized[i]);
}
}
#[test]
fn test_int4_quantization() {
let data = vec![-1.0, -0.5, 0.0, 0.5, 1.0, 1.5];
let params = QuantParams::calibrate(&data, QuantBits::Int4, QuantScheme::Symmetric);
let packed = quantize_int4(&data, ¶ms);
let dequantized = dequantize_int4(&packed, data.len(), ¶ms);
assert_eq!(dequantized.len(), data.len());
for i in 0..data.len() {
assert!((data[i] - dequantized[i]).abs() < 0.5,
"INT4 mismatch at {}: {} vs {}", i, data[i], dequantized[i]);
}
}
#[test]
fn test_int4_packing() {
let data = vec![0.0, 0.0, 0.0, 0.0]; let params = QuantParams::from_range(-1.0, 1.0, QuantBits::Int4, QuantScheme::Symmetric);
let packed = quantize_int4(&data, ¶ms);
assert_eq!(packed.len(), 2);
}
#[test]
fn test_quantized_gemm() {
let a_f32 = vec![1.0f32, 2.0, 3.0, 4.0];
let b_f32 = vec![5.0f32, 6.0, 7.0, 8.0];
let a_params = QuantParams::calibrate(&a_f32, QuantBits::Int8, QuantScheme::Symmetric);
let b_params = QuantParams::calibrate(&b_f32, QuantBits::Int8, QuantScheme::Symmetric);
let a_q = quantize_int8(&a_f32, &a_params);
let b_q = quantize_int8(&b_f32, &b_params);
let c = quantized_gemm_int8(&a_q, &b_q, 2, 2, 2, &a_params, &b_params);
assert!((c[0] - 19.0).abs() < 1.0, "Got {}", c[0]);
assert!((c[1] - 22.0).abs() < 1.0, "Got {}", c[1]);
assert!((c[2] - 43.0).abs() < 1.5, "Got {}", c[2]);
assert!((c[3] - 50.0).abs() < 1.5, "Got {}", c[3]);
}
#[test]
fn test_quantization_error() {
let data: Vec<f32> = (0..100).map(|i| (i as f32 - 50.0) / 50.0).collect();
let params = QuantParams::calibrate(&data, QuantBits::Int8, QuantScheme::Symmetric);
let error = quantization_error(&data, ¶ms);
assert!(error.mse < 0.001, "MSE too high: {}", error.mse);
assert!(error.snr_db > 30.0, "SNR too low: {}dB", error.snr_db);
assert_eq!(error.compression_ratio, 4.0);
}
#[test]
fn test_quantization_error_int4() {
let data: Vec<f32> = (0..100).map(|i| (i as f32 - 50.0) / 50.0).collect();
let params = QuantParams::calibrate(&data, QuantBits::Int4, QuantScheme::Symmetric);
let error = quantization_error(&data, ¶ms);
assert_eq!(error.compression_ratio, 8.0);
}
#[test]
fn test_zero_range_calibration() {
let data = vec![0.0, 0.0, 0.0];
let params = QuantParams::calibrate(&data, QuantBits::Int8, QuantScheme::Symmetric);
assert_eq!(params.scale, 1.0); }
#[test]
fn test_empty_calibration() {
let params = QuantParams::calibrate(&[], QuantBits::Int8, QuantScheme::Symmetric);
assert_eq!(params.scale, 1.0);
}
#[test]
fn test_quant_error_display() {
let error = QuantError { mse: 0.001, max_error: 0.01, snr_db: 40.0, compression_ratio: 4.0 };
let s = format!("{}", error);
assert!(s.contains("MSE"));
assert!(s.contains("4x"));
}
}