use crate::QuantError;
/// Packs small integer indices into a dense byte buffer, `bits` bits per index.
///
/// Indices are laid out back to back starting at bit 0 of byte 0, least
/// significant bit first. An index that does not fit in the remaining bits of
/// a byte continues in the low bits of the following byte. The output is
/// `ceil(indices.len() * bits / 8)` bytes long; unused trailing bits are zero.
///
/// # Errors
///
/// Returns [`QuantError::UnsupportedBitWidth`] when `bits` is `0` or greater
/// than `4`.
///
/// # Panics
///
/// With debug assertions enabled, panics if any index is not below
/// `1 << bits`. Without them the index is not checked, and its excess high
/// bits are OR-ed into the positions that belong to neighbouring indices.
pub fn pack_indices(indices: &[u8], bits: u8) -> Result<Vec<u8>, QuantError> {
    if !(1..=4).contains(&bits) {
        return Err(QuantError::UnsupportedBitWidth(bits));
    }
    let width = bits as usize;
    let mut out = vec![0u8; (indices.len() * width).div_ceil(8)];
    for (slot, &idx) in indices.iter().enumerate() {
        debug_assert!(idx < (1 << bits), "index {idx} exceeds {bits}-bit range");
        let start = slot * width;
        let (byte, shift) = (start / 8, start % 8);
        // Low part of the index lands in the current byte.
        out[byte] |= idx << shift;
        // Anything that did not fit spills into the low bits of the next byte.
        if shift + width > 8 {
            out[byte + 1] |= idx >> (8 - shift);
        }
    }
    Ok(out)
}
/// Unpacks `count` indices of `bits` bits each from a densely packed buffer.
///
/// Reads the layout in which indices are stored back to back starting at
/// bit 0 of byte 0, least significant bit first, with an index that crosses a
/// byte boundary continuing in the low bits of the next byte. Bytes beyond
/// the `ceil(count * bits / 8)` that are needed are ignored.
///
/// # Errors
///
/// Returns [`QuantError::UnsupportedBitWidth`] when `bits` is `0` or greater
/// than `4`.
///
/// # Panics
///
/// Panics if `packed` holds fewer than `ceil(count * bits / 8)` bytes, or if
/// `count * bits` overflows `usize`. The check runs before any allocation so
/// that a bogus `count` fails cleanly instead of first reserving `count`
/// bytes of output.
pub fn unpack_indices(packed: &[u8], bits: u8, count: usize) -> Result<Vec<u8>, QuantError> {
    if bits == 0 || bits > 4 {
        return Err(QuantError::UnsupportedBitWidth(bits));
    }
    let width = usize::from(bits);

    // Validate the buffer length once, up front; `None` means the bit count
    // itself does not fit in a usize, which no real buffer could satisfy.
    let needed_bytes = count
        .checked_mul(width)
        .map(|total_bits| total_bits.div_ceil(8));
    assert!(
        needed_bytes.is_some_and(|needed| needed <= packed.len()),
        "packed buffer too short: {} bytes cannot hold {count} indices of {bits} bits each",
        packed.len()
    );

    let mask = (1u8 << bits) - 1;
    let indices = (0..count)
        .map(|i| {
            let bit_offset = i * width;
            let byte_pos = bit_offset / 8;
            let bit_pos = bit_offset % 8;
            let mut val = packed[byte_pos] >> bit_pos;
            // Pull in the high part of an index that straddles two bytes.
            if bit_pos + width > 8 {
                val |= packed[byte_pos + 1] << (8 - bit_pos);
            }
            val & mask
        })
        .collect();
    Ok(indices)
}
/// Returns how many bytes are needed to hold `count` values of `bits` bits
/// each when stored back to back, i.e. `ceil(count * bits / 8)`.
///
/// `bits` is not range-checked here.
pub fn packed_byte_size(count: usize, bits: u8) -> usize {
    let total_bits = count * bits as usize;
    total_bits.div_ceil(8)
}
/// Packs a slice of signs into a bitmap, one bit per sign.
///
/// Sign `i` is stored in bit `i % 8` of byte `i / 8`. The bit is set when the
/// sign is strictly positive and cleared otherwise, so zero and negative
/// values are stored identically. The result has `ceil(signs.len() / 8)`
/// bytes; unused trailing bits are zero.
pub fn pack_signs(signs: &[i8]) -> Vec<u8> {
    signs
        .chunks(8)
        .map(|group| {
            // Each group of up to eight signs collapses into a single byte.
            group
                .iter()
                .enumerate()
                .fold(0u8, |byte, (bit, &sign)| byte | (u8::from(sign > 0) << bit))
        })
        .collect()
}
/// Expands a sign bitmap back into `count` signs.
///
/// Sign `i` is read from bit `i % 8` of byte `i / 8`; a set bit becomes `1`
/// and a clear bit becomes `-1`. Bytes beyond the `ceil(count / 8)` that are
/// needed are ignored.
///
/// # Panics
///
/// Panics if `packed` holds fewer than `ceil(count / 8)` bytes. The check
/// runs before any allocation so that a bogus `count` fails cleanly instead
/// of first reserving `count` bytes of output.
pub fn unpack_signs(packed: &[u8], count: usize) -> Vec<i8> {
    assert!(
        count.div_ceil(8) <= packed.len(),
        "packed buffer too short: {} bytes cannot hold {count} sign bits",
        packed.len()
    );
    (0..count)
        .map(|i| {
            let bit = (packed[i / 8] >> (i % 8)) & 1;
            if bit == 1 {
                1
            } else {
                -1
            }
        })
        .collect()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Packs `indices` at `bits` bits each, unpacks the same number of values
    /// again, checks they come back unchanged, and returns the packed bytes so
    /// the caller can check their length.
    fn round_trip(indices: &[u8], bits: u8) -> Vec<u8> {
        let packed = pack_indices(indices, bits).unwrap();
        let restored = unpack_indices(&packed, bits, indices.len()).unwrap();
        assert_eq!(indices, restored.as_slice());
        packed
    }

    #[test]
    fn pack_unpack_1bit() {
        round_trip(&[0, 1, 1, 0, 1, 0, 0, 1], 1);
    }

    #[test]
    fn pack_unpack_2bit() {
        // 8 indices * 2 bits = 16 bits = 2 bytes.
        let packed = round_trip(&[0, 1, 2, 3, 3, 2, 1, 0], 2);
        assert_eq!(packed.len(), 2);
    }

    #[test]
    fn pack_unpack_3bit() {
        // 8 indices * 3 bits = 24 bits = 3 bytes.
        let packed = round_trip(&[0, 1, 2, 3, 4, 5, 6, 7], 3);
        assert_eq!(packed.len(), 3);
    }

    #[test]
    fn pack_unpack_4bit() {
        // 8 indices * 4 bits = 32 bits = 4 bytes.
        let packed = round_trip(&[0, 5, 10, 15, 3, 7, 12, 1], 4);
        assert_eq!(packed.len(), 4);
    }

    #[test]
    fn pack_unpack_odd_count() {
        // 5 indices * 3 bits = 15 bits, rounded up to 2 bytes.
        let packed = round_trip(&[7, 0, 3, 5, 2], 3);
        assert_eq!(packed.len(), 2);
    }

    #[test]
    fn packed_byte_size_correct() {
        // (count, bits, expected byte length)
        let cases = [
            (8, 1, 1),
            (8, 2, 2),
            (8, 3, 3),
            (8, 4, 4),
            (5, 3, 2),
            (384, 2, 96),
        ];
        for (count, bits, expected) in cases {
            assert_eq!(packed_byte_size(count, bits), expected);
        }
    }

    #[test]
    fn sign_pack_unpack() {
        let signs = [1i8, -1, -1, 1, 1, 1, -1, -1, 1];
        let restored = unpack_signs(&pack_signs(&signs), signs.len());
        assert_eq!(restored, signs);
    }
}