/// Settings that control group-wise quantization of a tensor.
///
/// Values compare field-by-field, so two configurations are equal exactly
/// when all three settings match.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct QuantConfig {
    /// Bit width of each quantized code; `2^bits - 1` is the largest code
    /// `quantize_tensor` will emit.
    pub bits: u8,
    /// Number of consecutive elements that share one scale and zero point.
    pub group_size: usize,
    /// Flag for double quantization.
    ///
    /// NOTE(review): neither `quantize_tensor` nor `dequantize_tensor` reads
    /// this field; confirm where (or whether) it is consumed.
    pub double_quant: bool,
}
impl QuantConfig {
pub fn int4() -> Self {
Self {
bits: 4,
group_size: 64,
double_quant: true,
}
}
pub fn int8() -> Self {
Self {
bits: 8,
group_size: 128,
double_quant: false,
}
}
}
/// Quantizes `data` group-wise into unsigned integer codes using an affine
/// (scale plus offset) mapping.
///
/// `data` is split into consecutive groups of `config.group_size` elements;
/// the final group may be shorter. For each group the minimum becomes the
/// zero point and `(max - min) / max_code` becomes the scale, where
/// `max_code = 2^bits - 1`. Every element is then encoded as
/// `round((value - zero_point) / scale)`, clamped to `0..=max_code`.
///
/// Returns `(codes, scales, zero_points)`: one `u8` code per input element
/// (codes are stored one per byte, not bit-packed) and one scale and one
/// zero point per group.
///
/// Each code occupies a `u8`, so bit widths above 8 are treated as 8 rather
/// than producing codes that would be truncated on storage. Groups for which
/// the step is not a finite positive number (for example a constant group,
/// or `bits == 0`) use a scale of `1.0` so the division stays well-defined.
///
/// `config.double_quant` is not consulted by this function.
///
/// # Panics
///
/// Panics if `config.group_size` is zero.
pub fn quantize_tensor(data: &[f32], config: &QuantConfig) -> (Vec<u8>, Vec<f32>, Vec<f32>) {
    assert!(
        config.group_size > 0,
        "quantize_tensor: group_size must be non-zero"
    );

    // Codes live in a `u8`, so anything wider than 8 bits cannot be stored.
    // Capping here also keeps the shift below from overflowing for `bits >= 32`.
    let effective_bits = u32::from(config.bits.min(8));
    let max_code = ((1u32 << effective_bits) - 1) as f32;

    let num_groups = data.len().div_ceil(config.group_size);
    let mut quantized = Vec::with_capacity(data.len());
    let mut scales = Vec::with_capacity(num_groups);
    let mut zero_points = Vec::with_capacity(num_groups);

    for group in data.chunks(config.group_size) {
        let min = group.iter().copied().fold(f32::INFINITY, f32::min);
        let max = group.iter().copied().fold(f32::NEG_INFINITY, f32::max);

        // A zero, infinite, or NaN step would make the encode/decode
        // arithmetic produce NaN, so fall back to a neutral scale instead.
        let step = (max - min) / max_code;
        let scale = if step.is_finite() && step > 0.0 {
            step
        } else {
            1.0
        };
        let zero_point = min;

        scales.push(scale);
        zero_points.push(zero_point);
        quantized.extend(
            group
                .iter()
                .map(|&val| ((val - zero_point) / scale).round().clamp(0.0, max_code) as u8),
        );
    }

    (quantized, scales, zero_points)
}
/// Reconstructs `f32` values from group-wise quantized codes.
///
/// `quantized` is read in consecutive groups of `group_size` codes (the last
/// group may be shorter). Every code in group `g` is mapped to
/// `code * scales[g] + zero_points[g]`. The result has one value per code,
/// in the same order.
///
/// Entries of `scales` and `zero_points` beyond the number of groups are
/// ignored.
///
/// # Panics
///
/// Panics if `group_size` is zero, or if `scales` or `zero_points` holds
/// fewer entries than there are groups.
pub fn dequantize_tensor(
    quantized: &[u8],
    scales: &[f32],
    zero_points: &[f32],
    group_size: usize,
) -> Vec<f32> {
    let mut restored = Vec::with_capacity(quantized.len());
    let decoded = quantized
        .chunks(group_size)
        .enumerate()
        .flat_map(|(group, codes)| {
            // Look up both per-group parameters once, before walking the codes.
            let step = scales[group];
            let offset = zero_points[group];
            codes.iter().map(move |&code| f32::from(code) * step + offset)
        });
    restored.extend(decoded);
    restored
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Sends a 256-point ramp over `[0, 1]` through the INT8 preset and back,
    /// requiring every recovered value to be within 0.02 of its source.
    #[test]
    fn test_quantize_dequantize_int8() {
        let config = QuantConfig::int8();
        let data: Vec<f32> = (0..256).map(|i| i as f32 / 255.0).collect();

        let (quantized, scales, zero_points) = quantize_tensor(&data, &config);
        let recovered = dequantize_tensor(&quantized, &scales, &zero_points, config.group_size);

        data.iter()
            .zip(recovered.iter())
            .for_each(|(original, &recovered_val)| {
                let error = (original - recovered_val).abs();
                assert!(
                    error < 0.02,
                    "Original: {}, Recovered: {}",
                    original,
                    recovered_val
                );
            });
    }

    /// Quantizes 64 values with the INT4 preset and checks that every code
    /// fits in four bits and that the input forms a single group.
    #[test]
    fn test_quantize_int4() {
        let config = QuantConfig::int4();
        let data: Vec<f32> = (0..64).map(|i| i as f32 * 0.1).collect();

        let (quantized, scales, _) = quantize_tensor(&data, &config);

        quantized.iter().copied().for_each(|q| {
            assert!(q <= 15, "INT4 value out of range: {}", q);
        });
        assert_eq!(scales.len(), 1);
    }
}