use serde::{Deserialize, Serialize};
use crate::error::{Result, TurboQuantError};
/// A compact, serializable container of fixed-width quantization indices.
///
/// Stores `count` values of `bit_width` bits each, packed contiguously
/// MSB-first into `data`; an entry may straddle a byte boundary.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct BitPackedVector {
    // Packed payload; length is ceil(count * bit_width / 8).
    data: Vec<u8>,
    // Number of logical entries encoded in `data`.
    count: usize,
    // Bits per entry; `pack` enforces the range 1..=8.
    bit_width: u8,
}
impl BitPackedVector {
    /// Packs `indices` into a contiguous bit stream, `bit_width` bits per entry.
    ///
    /// Entries are laid out MSB-first: the first index occupies the highest
    /// bits of the first byte, and a single entry may cross a byte boundary.
    ///
    /// # Errors
    /// * `InvalidBitWidth` when `bit_width` is outside `1..=8`.
    /// * `InvalidQuantizationIndex` when an index does not fit in `bit_width` bits.
    pub fn pack(indices: &[u8], bit_width: u8) -> Result<Self> {
        if bit_width < 1 || bit_width > 8 {
            return Err(TurboQuantError::InvalidBitWidth(bit_width));
        }
        let count = indices.len();
        let width = usize::from(bit_width);
        let num_bytes = (count * width).div_ceil(8);
        let mut data = vec![0u8; num_bytes];
        let mask = (1u16 << bit_width) - 1;
        for (pos, &code) in indices.iter().enumerate() {
            if u16::from(code) > mask {
                return Err(TurboQuantError::InvalidQuantizationIndex {
                    index: code,
                    max: mask as u8,
                    bit_width,
                });
            }
            // Stage the value in a 16-bit window so an entry that crosses a
            // byte boundary can be committed with two OR writes.
            let start = pos * width;
            let byte_idx = start / 8;
            let bit_idx = start % 8;
            let aligned =
                (u16::from(code) & mask) << (16 - u16::from(bit_width) - bit_idx as u16);
            data[byte_idx] |= (aligned >> 8) as u8;
            if byte_idx + 1 < num_bytes {
                // The low half is non-zero only when the entry spills over,
                // in which case the next byte is guaranteed to exist.
                data[byte_idx + 1] |= (aligned & 0xFF) as u8;
            }
        }
        Ok(Self {
            data,
            count,
            bit_width,
        })
    }

    /// Decodes the packed stream back into one byte per stored index.
    pub fn unpack(&self) -> Vec<u8> {
        let width = usize::from(self.bit_width);
        let mask = (1u16 << self.bit_width) - 1;
        (0..self.count)
            .map(|pos| {
                let start = pos * width;
                let byte_idx = start / 8;
                let bit_idx = start % 8;
                // Mirror of `pack`: rebuild the 16-bit window, shift the entry
                // down to the low bits, and mask off neighbouring entries.
                let upper = u16::from(self.data[byte_idx]) << 8;
                let lower = self.data.get(byte_idx + 1).copied().map_or(0, u16::from);
                let shift = 16 - u16::from(self.bit_width) - bit_idx as u16;
                (((upper | lower) >> shift) & mask) as u8
            })
            .collect()
    }

    /// Number of bytes in the packed representation.
    pub fn byte_len(&self) -> usize {
        self.data.len()
    }

    /// Number of logical entries stored.
    pub fn count(&self) -> usize {
        self.count
    }

    /// Bits used per entry.
    pub fn bit_width(&self) -> u8 {
        self.bit_width
    }

    /// Raw packed bytes.
    pub fn as_bytes(&self) -> &[u8] {
        &self.data
    }

    /// Size ratio versus storing one `u8` per entry (1.0 when empty).
    pub fn compression_ratio_vs_u8(&self) -> f64 {
        match self.data.len() {
            0 => 1.0,
            n => self.count as f64 / n as f64,
        }
    }

    /// Size ratio versus storing one `f32` per entry (1.0 when empty).
    pub fn compression_ratio_vs_f32(&self) -> f64 {
        match self.data.len() {
            0 => 1.0,
            n => (self.count as f64 * 4.0) / n as f64,
        }
    }
}
/// Packs booleans into a bitmap, MSB-first within each byte
/// (`signs[0]` maps to bit 7 of byte 0); `true` sets the bit.
pub fn pack_signs(signs: &[bool]) -> Vec<u8> {
    let mut packed = vec![0u8; signs.len().div_ceil(8)];
    for (pos, &flag) in signs.iter().enumerate() {
        if flag {
            packed[pos / 8] |= 0x80 >> (pos % 8);
        }
    }
    packed
}
/// Reads `count` booleans from an MSB-first bitmap produced by `pack_signs`.
///
/// Panics if `data` is too short to hold `count` bits.
pub fn unpack_signs(data: &[u8], count: usize) -> Vec<bool> {
    (0..count)
        .map(|pos| (data[pos / 8] & (0x80 >> (pos % 8))) != 0)
        .collect()
}
/// Bytes required to store `count` entries of `bit_width` bits each,
/// rounded up to a whole byte.
pub fn packed_byte_size(count: usize, bit_width: u8) -> usize {
    let total_bits = count * usize::from(bit_width);
    total_bits.div_ceil(8)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Packs `indices`, checks the packed size, and verifies a lossless round trip.
    fn roundtrip(indices: &[u8], bit_width: u8, expected_bytes: usize) {
        let packed = BitPackedVector::pack(indices, bit_width).unwrap();
        assert_eq!(packed.byte_len(), expected_bytes);
        assert_eq!(packed.unpack(), indices);
    }

    #[test]
    fn test_pack_unpack_4bit() {
        let indices: Vec<u8> = (0..16).collect();
        roundtrip(&indices, 4, 8);
    }

    #[test]
    fn test_pack_unpack_1bit() {
        roundtrip(&[0, 1, 1, 0, 1, 0, 0, 1, 1], 1, 2);
    }

    #[test]
    fn test_pack_unpack_2bit() {
        roundtrip(&[0, 1, 2, 3, 3, 2, 1, 0], 2, 2);
    }

    #[test]
    fn test_pack_unpack_3bit() {
        roundtrip(&[0, 1, 2, 3, 4, 5, 6, 7], 3, 3);
    }

    #[test]
    fn test_pack_unpack_8bit() {
        let indices: Vec<u8> = (0..=255).collect();
        roundtrip(&indices, 8, 256);
    }

    #[test]
    fn test_pack_unpack_large_4bit() {
        let indices: Vec<u8> = (0..512).map(|i| (i % 16) as u8).collect();
        roundtrip(&indices, 4, 256);
    }

    #[test]
    fn test_invalid_bit_width() {
        assert!(BitPackedVector::pack(&[0, 1], 0).is_err());
        assert!(BitPackedVector::pack(&[0, 1], 9).is_err());
    }

    #[test]
    fn test_empty_vector() {
        let packed = BitPackedVector::pack(&[], 4).unwrap();
        assert_eq!(packed.byte_len(), 0);
        assert_eq!(packed.count(), 0);
        assert_eq!(packed.unpack(), Vec::<u8>::new());
    }

    #[test]
    fn test_compression_ratio() {
        let packed = BitPackedVector::pack(&vec![0; 128], 4).unwrap();
        assert!((packed.compression_ratio_vs_f32() - 8.0).abs() < 0.01);
    }

    #[test]
    fn test_pack_unpack_signs() {
        let signs = vec![true, false, true, true, false, false, true, false, true];
        let packed = pack_signs(&signs);
        assert_eq!(packed.len(), 2);
        assert_eq!(unpack_signs(&packed, signs.len()), signs);
    }

    #[test]
    fn test_pack_signs_empty() {
        let signs: Vec<bool> = vec![];
        let packed = pack_signs(&signs);
        assert!(packed.is_empty());
        assert!(unpack_signs(&packed, 0).is_empty());
    }

    #[test]
    fn test_packed_byte_size() {
        assert_eq!(packed_byte_size(128, 4), 64);
        assert_eq!(packed_byte_size(128, 2), 32);
        assert_eq!(packed_byte_size(128, 1), 16);
        assert_eq!(packed_byte_size(100, 3), 38);
        assert_eq!(packed_byte_size(0, 4), 0);
    }

    #[test]
    fn test_pack_unpack_5bit() {
        let indices: Vec<u8> = (0..32).collect();
        roundtrip(&indices, 5, 20);
    }

    #[test]
    fn test_pack_rejects_overflow() {
        assert!(BitPackedVector::pack(&[255u8], 4).is_err());
    }
}