#[derive(Debug, Clone, PartialEq)]
pub enum QuantError {
IndexOutOfBounds { idx: usize, max: usize },
ValueOutOfRange { val: i64, min: i64, max: i64 },
InvalidGroupSize,
EmptyTensor,
}
impl std::fmt::Display for QuantError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
QuantError::IndexOutOfBounds { idx, max } => {
write!(f, "index {idx} out of bounds (max {max})")
}
QuantError::ValueOutOfRange { val, min, max } => {
write!(f, "value {val} out of range [{min}, {max}]")
}
QuantError::InvalidGroupSize => write!(f, "invalid group size"),
QuantError::EmptyTensor => write!(f, "tensor is empty"),
}
}
}
impl std::error::Error for QuantError {}
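/// Quantization bit widths supported by `PackedBuffer`. Each width covers the
/// signed range `[min_val(), max_val()]`; packed storage adds `offset()` so the
/// stored values are non-negative.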
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BitWidth {
Two,
Three,
Four,
Eight,
}
impl BitWidth {
pub fn bits(self) -> u32 {
match self {
BitWidth::Two => 2,
BitWidth::Three => 3,
BitWidth::Four => 4,
BitWidth::Eight => 8,
}
}
pub fn max_val(self) -> i64 {
match self {
BitWidth::Two => 1,
BitWidth::Three => 3,
BitWidth::Four => 7,
BitWidth::Eight => 127,
}
}
pub fn min_val(self) -> i64 {
match self {
BitWidth::Two => -2,
BitWidth::Three => -4,
BitWidth::Four => -8,
BitWidth::Eight => -128,
}
}
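    // Bias added to a signed value before packing so the stored value is
    // non-negative (e.g. +2 for 2-bit, +128 for 8-bit).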
fn offset(self) -> i64 {
-self.min_val()
}
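    // Mask covering the low `bits()` bits of a packed value.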
fn mask(self) -> u64 {
(1u64 << self.bits()) - 1
}
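    /// Number of whole values that fit in a single byte. 3-bit values additionally
    /// spill across byte boundaries, so their effective density is 8 values per 3 bytes.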
pub fn values_per_byte(self) -> usize {
match self {
BitWidth::Two => 4,
            BitWidth::Three => 2,
            BitWidth::Four => 2,
BitWidth::Eight => 1,
}
}
}
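/// Bit-packed storage for quantized signed integers. Values are biased by
/// `BitWidth::offset` to be non-negative and packed LSB-first into the byte
/// stream; 3-bit values may straddle a byte boundary.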
#[derive(Debug, Clone, PartialEq)]
pub struct PackedBuffer {
pub data: Vec<u8>,
pub bit_width: BitWidth,
pub num_elements: usize,
}
impl PackedBuffer {
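    /// Allocates a zero-filled buffer with room for `num_elements` values at `bit_width`.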
pub fn new(num_elements: usize, bit_width: BitWidth) -> Self {
let total_bits = num_elements * bit_width.bits() as usize;
let num_bytes = total_bits.div_ceil(8);
let data = vec![0u8; num_bytes];
Self { data, bit_width, num_elements }
}
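    /// Range-checks `values` against the signed range of `bit_width`, then packs them.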
pub fn pack(values: &[i64], bit_width: BitWidth) -> Result<Self, QuantError> {
if values.is_empty() {
return Err(QuantError::EmptyTensor);
}
let min = bit_width.min_val();
let max = bit_width.max_val();
for &v in values {
if v < min || v > max {
return Err(QuantError::ValueOutOfRange { val: v, min, max });
}
}
let mut buf = Self::new(values.len(), bit_width);
for (idx, &v) in values.iter().enumerate() {
buf.set_unchecked(idx, v);
}
Ok(buf)
}
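    /// Unpacks every element back into a `Vec<i64>`.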
pub fn unpack(&self) -> Vec<i64> {
(0..self.num_elements)
.map(|i| self.get_unchecked(i))
.collect()
}
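    /// Bounds-checked read of the element at `idx`.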
pub fn get(&self, idx: usize) -> Result<i64, QuantError> {
if idx >= self.num_elements {
return Err(QuantError::IndexOutOfBounds { idx, max: self.num_elements.saturating_sub(1) });
}
Ok(self.get_unchecked(idx))
}
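    /// Bounds- and range-checked write of `val` at `idx`.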
pub fn set(&mut self, idx: usize, val: i64) -> Result<(), QuantError> {
if idx >= self.num_elements {
return Err(QuantError::IndexOutOfBounds { idx, max: self.num_elements.saturating_sub(1) });
}
let min = self.bit_width.min_val();
let max = self.bit_width.max_val();
if val < min || val > max {
return Err(QuantError::ValueOutOfRange { val, min, max });
}
self.set_unchecked(idx, val);
Ok(())
}
pub fn size_bytes(&self) -> usize {
self.data.len()
}
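    /// Ratio of packed size to `f32` size (e.g. 0.0625 for 2-bit). Ignores byte
    /// padding and any per-group scale/zero-point metadata.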
pub fn compression_ratio_vs_f32(&self) -> f32 {
self.bit_width.bits() as f32 / 32.0_f32
}
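    // Reads the biased value at `idx` and removes the bias. Assumes `idx` is in bounds.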
fn get_unchecked(&self, idx: usize) -> i64 {
let offset = self.bit_width.offset();
let mask = self.bit_width.mask();
let bits = self.bit_width.bits() as usize;
let bit_pos = idx * bits;
let byte_idx = bit_pos / 8;
let bit_off = bit_pos % 8;
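        // Read two bytes when available so a value straddling a byte boundary is covered.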
let raw = if byte_idx + 1 < self.data.len() {
(self.data[byte_idx] as u16) | ((self.data[byte_idx + 1] as u16) << 8)
} else {
self.data[byte_idx] as u16
};
let unsigned = ((raw as u64) >> bit_off) & mask;
unsigned as i64 - offset
}
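    // Writes the biased value at `idx`, clearing only the target bits. Assumes `idx`
    // is in bounds and `val` is within range for the bit width.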
fn set_unchecked(&mut self, idx: usize, val: i64) {
let offset = self.bit_width.offset();
let mask = self.bit_width.mask();
let bits = self.bit_width.bits() as usize;
let unsigned = (val + offset) as u64 & mask;
let bit_pos = idx * bits;
let byte_idx = bit_pos / 8;
let bit_off = bit_pos % 8;
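        // Clear the target bits in the first byte, then write the low part of the value.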
self.data[byte_idx] &= !((mask << bit_off) as u8);
self.data[byte_idx] |= ((unsigned << bit_off) & 0xFF) as u8;
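        // If the value straddles a byte boundary, write the remaining high bits
        // into the low bits of the next byte.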
if bit_off + bits > 8 && byte_idx + 1 < self.data.len() {
let spill_bits = bit_off + bits - 8;
let spill_mask = ((1u64 << spill_bits) - 1) as u8;
self.data[byte_idx + 1] &= !spill_mask;
self.data[byte_idx + 1] |= (unsigned >> (bits - spill_bits)) as u8 & spill_mask;
}
}
}
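/// Configuration for group-wise 2-bit quantization: the tensor is split into groups
/// of `group_size` elements, each quantized with its own scale and zero point; with
/// `symmetric` the zero point is fixed at 0.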
#[derive(Debug, Clone)]
pub struct Int2QuantConfig {
pub group_size: usize,
pub symmetric: bool,
}
impl Default for Int2QuantConfig {
fn default() -> Self {
Self { group_size: 128, symmetric: true }
}
}
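/// Quantizes `tensor` to 2-bit integers in groups of `config.group_size`, returning
/// the packed values together with one scale and one zero point per group.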
pub fn quantize_int2(
tensor: &[f32],
config: &Int2QuantConfig,
) -> Result<(PackedBuffer, Vec<f32>, Vec<f32>), QuantError> {
if tensor.is_empty() {
return Err(QuantError::EmptyTensor);
}
if config.group_size == 0 {
return Err(QuantError::InvalidGroupSize);
}
let num_groups = tensor.len().div_ceil(config.group_size);
let mut scales = Vec::with_capacity(num_groups);
let mut zero_points = Vec::with_capacity(num_groups);
let mut quantized = Vec::with_capacity(tensor.len());
    let min_q = BitWidth::Two.min_val();
    let max_q = BitWidth::Two.max_val();
for group_idx in 0..num_groups {
let start = group_idx * config.group_size;
let end = (start + config.group_size).min(tensor.len());
let group = &tensor[start..end];
let (scale, zero_point) = compute_scale_zero_point(group, min_q, max_q, config.symmetric);
scales.push(scale);
zero_points.push(zero_point);
for &w in group {
let q = quantize_single(w, scale, zero_point, min_q, max_q);
quantized.push(q);
}
}
let packed = PackedBuffer::pack(&quantized, BitWidth::Two)?;
Ok((packed, scales, zero_points))
}
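/// Reconstructs approximate `f32` values from a 2-bit packed buffer using the
/// per-group scales and zero points produced by `quantize_int2`.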
pub fn dequantize_int2(
packed: &PackedBuffer,
scales: &[f32],
zero_points: &[f32],
config: &Int2QuantConfig,
) -> Result<Vec<f32>, QuantError> {
if config.group_size == 0 {
return Err(QuantError::InvalidGroupSize);
}
let values = packed.unpack();
let mut output = Vec::with_capacity(values.len());
for (i, &q) in values.iter().enumerate() {
let group_idx = i / config.group_size;
if group_idx >= scales.len() {
return Err(QuantError::IndexOutOfBounds { idx: group_idx, max: scales.len().saturating_sub(1) });
}
let scale = scales[group_idx];
let zp = zero_points[group_idx];
output.push(dequantize_single(q, scale, zp));
}
Ok(output)
}
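/// Quantizes `tensor` to 3-bit integers with symmetric per-group scaling, returning
/// the packed values together with one scale and one zero point per group.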
pub fn quantize_int3(
tensor: &[f32],
group_size: usize,
) -> Result<(PackedBuffer, Vec<f32>, Vec<f32>), QuantError> {
if tensor.is_empty() {
return Err(QuantError::EmptyTensor);
}
if group_size == 0 {
return Err(QuantError::InvalidGroupSize);
}
let num_groups = tensor.len().div_ceil(group_size);
let mut scales = Vec::with_capacity(num_groups);
let mut zero_points = Vec::with_capacity(num_groups);
let mut quantized = Vec::with_capacity(tensor.len());
    let min_q = BitWidth::Three.min_val();
    let max_q = BitWidth::Three.max_val();
for group_idx in 0..num_groups {
let start = group_idx * group_size;
let end = (start + group_size).min(tensor.len());
let group = &tensor[start..end];
let (scale, zero_point) = compute_scale_zero_point(group, min_q, max_q, true);
scales.push(scale);
zero_points.push(zero_point);
for &w in group {
let q = quantize_single(w, scale, zero_point, min_q, max_q);
quantized.push(q);
}
}
let packed = PackedBuffer::pack(&quantized, BitWidth::Three)?;
Ok((packed, scales, zero_points))
}
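/// Reconstructs approximate `f32` values from a 3-bit packed buffer using the
/// per-group scales and zero points produced by `quantize_int3`.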
pub fn dequantize_int3(
packed: &PackedBuffer,
scales: &[f32],
zero_points: &[f32],
group_size: usize,
) -> Result<Vec<f32>, QuantError> {
if group_size == 0 {
return Err(QuantError::InvalidGroupSize);
}
let values = packed.unpack();
let mut output = Vec::with_capacity(values.len());
for (i, &q) in values.iter().enumerate() {
let group_idx = i / group_size;
if group_idx >= scales.len() {
return Err(QuantError::IndexOutOfBounds { idx: group_idx, max: scales.len().saturating_sub(1) });
}
let scale = scales[group_idx];
let zp = zero_points[group_idx];
output.push(dequantize_single(q, scale, zp));
}
Ok(output)
}
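// Per-group affine parameters. Symmetric: scale = max|x| / max_q with zero point 0.
// Asymmetric: scale = (fmax - fmin) / (max_q - min_q), with the zero point chosen so
// that min_q dequantizes back to fmin.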
fn compute_scale_zero_point(
group: &[f32],
min_q: i64,
max_q: i64,
symmetric: bool,
) -> (f32, f32) {
let fmin = group.iter().cloned().fold(f32::INFINITY, f32::min);
let fmax = group.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
if symmetric {
let max_abs = fmin.abs().max(fmax.abs());
if max_abs < f32::EPSILON {
return (1.0_f32, 0.0_f32);
}
let scale = max_abs / (max_q as f32);
(scale, 0.0_f32)
} else {
let q_range = (max_q - min_q) as f32;
let f_range = fmax - fmin;
if f_range < f32::EPSILON {
return (1.0_f32, fmin);
}
let scale = f_range / q_range;
        // Choose the zero point so that q = min_q dequantizes back to fmin; using fmin
        // directly would shift the mapping into [0, max_q - min_q] and clamp away the
        // largest values in the group.
        let zero_point = fmin - (min_q as f32) * scale;
(scale, zero_point)
}
}
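// q = clamp(round((val - zero_point) / scale), min_q, max_q); degenerate scales map to 0.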
fn quantize_single(val: f32, scale: f32, zero_point: f32, min_q: i64, max_q: i64) -> i64 {
if scale.abs() < f32::EPSILON {
return 0;
}
let q = ((val - zero_point) / scale).round() as i64;
q.clamp(min_q, max_q)
}
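// Reconstruction: val is approximately q * scale + zero_point.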
fn dequantize_single(q: i64, scale: f32, zero_point: f32) -> f32 {
q as f32 * scale + zero_point
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_bitwidth_properties() {
assert_eq!(BitWidth::Two.bits(), 2);
assert_eq!(BitWidth::Three.bits(), 3);
assert_eq!(BitWidth::Four.bits(), 4);
assert_eq!(BitWidth::Eight.bits(), 8);
assert_eq!(BitWidth::Two.max_val(), 1);
assert_eq!(BitWidth::Three.max_val(), 3);
assert_eq!(BitWidth::Four.max_val(), 7);
assert_eq!(BitWidth::Eight.max_val(), 127);
assert_eq!(BitWidth::Two.min_val(), -2);
assert_eq!(BitWidth::Three.min_val(), -4);
assert_eq!(BitWidth::Four.min_val(), -8);
assert_eq!(BitWidth::Eight.min_val(), -128);
}
#[test]
fn test_values_per_byte() {
assert_eq!(BitWidth::Two.values_per_byte(), 4);
assert_eq!(BitWidth::Four.values_per_byte(), 2);
assert_eq!(BitWidth::Eight.values_per_byte(), 1);
}
#[test]
fn test_2bit_pack_unpack_round_trip() {
let values: Vec<i64> = vec![-2, -1, 0, 1, -2, 1, 0, -1];
let buf = PackedBuffer::pack(&values, BitWidth::Two).expect("pack failed");
let unpacked = buf.unpack();
assert_eq!(unpacked, values, "2-bit round-trip mismatch");
}
#[test]
fn test_2bit_boundary_values() {
let values = vec![-2, 1];
let buf = PackedBuffer::pack(&values, BitWidth::Two).expect("pack failed");
assert_eq!(buf.get(0).expect("get 0 failed"), -2);
assert_eq!(buf.get(1).expect("get 1 failed"), 1);
}
#[test]
fn test_3bit_pack_unpack_round_trip() {
let values: Vec<i64> = vec![-4, -3, -2, -1, 0, 1, 2, 3];
let buf = PackedBuffer::pack(&values, BitWidth::Three).expect("pack failed");
let unpacked = buf.unpack();
assert_eq!(unpacked, values, "3-bit round-trip mismatch");
}
#[test]
fn test_3bit_cross_byte_boundary() {
let values: Vec<i64> = vec![-4, 3, -3, 2, -2, 1, -1, 0];
let buf = PackedBuffer::pack(&values, BitWidth::Three).expect("pack failed");
let unpacked = buf.unpack();
assert_eq!(unpacked, values, "3-bit cross-byte round-trip mismatch");
}
#[test]
fn test_4bit_pack_unpack_round_trip() {
let values: Vec<i64> = vec![-8, -5, 0, 3, 7, -1, -8, 7];
let buf = PackedBuffer::pack(&values, BitWidth::Four).expect("pack failed");
let unpacked = buf.unpack();
assert_eq!(unpacked, values, "4-bit round-trip mismatch");
}
#[test]
fn test_get_set_single_element() {
let mut buf = PackedBuffer::new(16, BitWidth::Four);
buf.set(5, 7).expect("set failed");
assert_eq!(buf.get(5).expect("get failed"), 7);
buf.set(5, -8).expect("set failed");
assert_eq!(buf.get(5).expect("get failed"), -8);
}
#[test]
fn test_set_does_not_corrupt_neighbours() {
let values: Vec<i64> = vec![1, 2, 3, 4, 5, 6, 7, -8];
let mut buf = PackedBuffer::pack(&values, BitWidth::Four).expect("pack failed");
buf.set(3, -1).expect("set failed");
assert_eq!(buf.get(2).expect("get failed"), 3, "neighbour corrupted");
assert_eq!(buf.get(3).expect("get failed"), -1, "value not set");
assert_eq!(buf.get(4).expect("get failed"), 5, "neighbour corrupted");
}
#[test]
fn test_compression_ratio() {
let buf2 = PackedBuffer::new(1, BitWidth::Two);
let buf4 = PackedBuffer::new(1, BitWidth::Four);
let buf8 = PackedBuffer::new(1, BitWidth::Eight);
assert!((buf2.compression_ratio_vs_f32() - 0.0625_f32).abs() < 1e-6);
assert!((buf4.compression_ratio_vs_f32() - 0.125_f32).abs() < 1e-6);
assert!((buf8.compression_ratio_vs_f32() - 0.25_f32).abs() < 1e-6);
}
#[test]
fn test_error_index_out_of_bounds() {
let buf = PackedBuffer::new(4, BitWidth::Four);
assert!(matches!(
buf.get(4),
Err(QuantError::IndexOutOfBounds { idx: 4, .. })
));
}
#[test]
fn test_error_value_out_of_range() {
let result = PackedBuffer::pack(&[8], BitWidth::Four);
assert!(matches!(
result,
Err(QuantError::ValueOutOfRange { val: 8, .. })
));
}
#[test]
fn test_error_empty_tensor() {
let result = PackedBuffer::pack(&[], BitWidth::Two);
assert_eq!(result, Err(QuantError::EmptyTensor));
}
#[test]
fn test_int2_quantization_round_trip() {
let tensor: Vec<f32> = (0..256).map(|i| (i as f32 - 128.0) / 128.0).collect();
let config = Int2QuantConfig { group_size: 64, symmetric: true };
let (packed, scales, zero_points) =
quantize_int2(&tensor, &config).expect("quantize_int2 failed");
let reconstructed =
dequantize_int2(&packed, &scales, &zero_points, &config).expect("dequantize_int2 failed");
assert_eq!(reconstructed.len(), tensor.len());
for v in &reconstructed {
assert!(v.is_finite(), "non-finite value in reconstruction");
}
}
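    // Exercises the asymmetric (non-symmetric) zero-point path: values well above the
    // group minimum must still round-trip instead of being clamped to the top code.
    #[test]
    fn test_int2_asymmetric_round_trip() {
        let tensor: Vec<f32> = (0..64).map(|i| i as f32 / 63.0).collect();
        let config = Int2QuantConfig { group_size: 16, symmetric: false };
        let (packed, scales, zero_points) =
            quantize_int2(&tensor, &config).expect("quantize_int2 failed");
        let reconstructed =
            dequantize_int2(&packed, &scales, &zero_points, &config).expect("dequantize_int2 failed");
        assert_eq!(reconstructed.len(), tensor.len());
        let max_err = tensor
            .iter()
            .zip(reconstructed.iter())
            .map(|(a, b)| (a - b).abs())
            .fold(0.0_f32, f32::max);
        assert!(max_err < 0.1_f32, "asymmetric INT2 max error {max_err} exceeds threshold");
    }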
#[test]
fn test_int3_quantization_round_trip() {
let tensor: Vec<f32> = (0..128).map(|i| (i as f32 - 64.0) / 64.0).collect();
let group_size = 32;
let (packed, scales, zero_points) =
quantize_int3(&tensor, group_size).expect("quantize_int3 failed");
let reconstructed =
dequantize_int3(&packed, &scales, &zero_points, group_size).expect("dequantize_int3 failed");
assert_eq!(reconstructed.len(), tensor.len());
let max_err = tensor
.iter()
.zip(reconstructed.iter())
.map(|(a, b)| (a - b).abs())
.fold(0.0_f32, f32::max);
assert!(max_err < 0.3_f32, "INT3 max error {max_err} exceeds threshold");
}
#[test]
fn test_size_bytes_2bit() {
let buf = PackedBuffer::new(8, BitWidth::Two);
assert_eq!(buf.size_bytes(), 2);
}
#[test]
fn test_size_bytes_3bit_cross_boundary() {
let buf = PackedBuffer::new(8, BitWidth::Three);
assert_eq!(buf.size_bytes(), 3);
let buf9 = PackedBuffer::new(9, BitWidth::Three);
assert_eq!(buf9.size_bytes(), 4);
}
#[test]
fn test_int2_invalid_group_size() {
let tensor = vec![0.5_f32; 16];
let config = Int2QuantConfig { group_size: 0, symmetric: true };
assert_eq!(quantize_int2(&tensor, &config), Err(QuantError::InvalidGroupSize));
}
}