use half::f16;
use crate::error::{require, Result, TurboQuantError};
pub mod indices;
pub use indices::{
pack_indices_2bit, pack_indices_3bit, pack_indices_4bit, unpack_indices_2bit,
unpack_indices_3bit, unpack_indices_4bit,
};
#[allow(unused_imports)]
use indices::{
chunk_to_2bit_array, chunk_to_3bit_array, chunk_to_4bit_array, chunk_to_packed_3bit_array,
has_2bit_remainder, has_3bit_remainder, has_4bit_remainder, num_2bit_groups, num_3bit_groups,
num_4bit_pairs, packed_2bit_capacity, packed_3bit_capacity, packed_4bit_capacity,
pad_remainder_2bit, pad_remainder_3bit, trailing_4bit_pair,
};
// Supported TurboQuant bit-widths (TQ2 / TQ3 / TQ4 formats).
pub(crate) const BITS_TQ2: u8 = 2;
pub(crate) const BITS_TQ3: u8 = 3;
pub(crate) const BITS_TQ4: u8 = 4;
// Group sizes: how many indices one packed unit holds for each bit-width.
const PACK_2BIT_GROUP_SIZE: usize = 4; // four 2-bit indices per byte
const PACK_3BIT_GROUP_SIZE: usize = 8; // eight 3-bit indices per 3-byte group
const PACK_3BIT_BYTES: usize = 3; // bytes produced by one full 3-bit group
const PACK_4BIT_GROUP_SIZE: usize = 2; // two 4-bit nibbles per byte
// Masks for extracting n-bit fields from a byte.
const MASK_3BIT: u8 = 0x7;
const MASK_2BIT: u8 = 0x3;
const MASK_1BIT: u8 = 0x1;
const MASK_4BIT: u8 = 0xF;
// Named shift amounts used by the bit-packing routines below.
const SHIFT_3: u32 = 3;
const SHIFT_4: u32 = 4;
const SHIFT_5: u32 = 5;
const SHIFT_6: u32 = 6;
const SHIFT_7: u32 = 7;
const SHIFT_1: u32 = 1;
const SHIFT_2: u32 = 2;
// Serialized footprint of the f16 scale that accompanies each block.
const SCALE_SIZE_BYTES: usize = 2;
/// Configuration for a TurboQuant quantizer: bit-width, vector dimension,
/// and the seed used for the (presumably random) rotation — see `with_seed`.
#[derive(Clone, Copy)]
pub struct TurboQuantConfig {
    // Quantization bit-width; validated to be 2, 3, or 4 by `new`.
    pub(crate) bits: u8,
    // Vector dimension; validated to be a non-zero power of two by `new`.
    pub(crate) dim: usize,
    // Seed for the rotation step; defaults to 0 in `new`.
    pub(crate) rotation_seed: u64,
}
pub(crate) fn is_valid_bits(bits: u8) -> bool {
bits == BITS_TQ2 || bits == BITS_TQ3 || bits == BITS_TQ4
}
/// Returns `true` when `dim` is a valid vector dimension: a power of two.
///
/// `usize::is_power_of_two` already returns `false` for 0, so no separate
/// zero guard is needed.
pub(crate) fn is_valid_dim(dim: usize) -> bool {
    dim.is_power_of_two()
}
impl TurboQuantConfig {
    /// Creates a validated configuration with a default rotation seed of 0.
    ///
    /// # Errors
    /// Returns `TurboQuantError::UnsupportedBits` unless `bits` is 2, 3, or 4,
    /// and `TurboQuantError::InvalidDimension` unless `dim` is a non-zero
    /// power of two.
    pub fn new(bits: u8, dim: usize) -> Result<Self> {
        require(is_valid_bits(bits), TurboQuantError::UnsupportedBits(bits))?;
        require(is_valid_dim(dim), TurboQuantError::InvalidDimension(dim))?;
        let config = Self {
            bits,
            dim,
            rotation_seed: 0,
        };
        Ok(config)
    }

    /// Builder-style setter: returns a copy of `self` with the rotation seed
    /// replaced by `seed`.
    pub fn with_seed(self, seed: u64) -> Self {
        Self {
            rotation_seed: seed,
            ..self
        }
    }
}
/// A block of quantized indices bit-packed at 2, 3, or 4 bits per index,
/// together with the half-precision scale stored alongside them.
pub struct PackedBlock {
    // Bit-width the payload was packed at (2, 3, or 4).
    pub bits: u8,
    // Per-block scale factor; serialized as 2 bytes (see SCALE_SIZE_BYTES).
    pub scale: f16,
    // Bit-packed index payload; layout depends on `bits`.
    pub packed_indices: Vec<u8>,
}
impl PackedBlock {
    /// Packs `indices` at the given bit-width and wraps them with `scale`.
    ///
    /// # Panics
    /// Panics if `bits` is not 2, 3, or 4; callers are expected to have
    /// validated it (see `is_valid_bits`).
    pub fn new(bits: u8, scale: f16, indices: &[u8]) -> Self {
        // Direct match instead of an immediately-invoked closure: the
        // indirection added nothing.
        let packed_indices = match bits {
            BITS_TQ2 => pack_indices_2bit(indices),
            BITS_TQ3 => pack_indices_3bit(indices),
            BITS_TQ4 => pack_indices_4bit(indices),
            _ => unreachable!("bits validated to be 2, 3, or 4"),
        };
        Self {
            bits,
            scale,
            packed_indices,
        }
    }

    /// Serialized size in bytes: the f16 scale plus the packed payload.
    pub fn size_bytes(&self) -> usize {
        SCALE_SIZE_BYTES + self.packed_indices.len()
    }

    /// Reassembles a block from already-packed bytes (e.g. read back from
    /// storage). No validation is performed on `packed_indices`.
    pub fn from_raw(bits: u8, scale: f16, packed_indices: Vec<u8>) -> Self {
        Self {
            bits,
            scale,
            packed_indices,
        }
    }

    /// Unpacks the first `count` indices into `buf`, reusing its allocation.
    ///
    /// Delegates to [`Self::unpack`] so the 2/3/4-bit dispatch lives in one
    /// place. The explicit `truncate` is kept in case the unpackers return
    /// padded group-aligned output — TODO confirm against `indices` module.
    pub fn unpack_into(&self, count: usize, buf: &mut Vec<u8>) {
        buf.clear();
        buf.extend_from_slice(&self.unpack(count));
        buf.truncate(count);
    }

    /// Unpacks and returns the first `count` indices.
    ///
    /// # Panics
    /// Panics if `self.bits` is not 2, 3, or 4.
    pub fn unpack(&self, count: usize) -> Vec<u8> {
        match self.bits {
            BITS_TQ2 => unpack_indices_2bit(&self.packed_indices, count),
            BITS_TQ3 => unpack_indices_3bit(&self.packed_indices, count),
            BITS_TQ4 => unpack_indices_4bit(&self.packed_indices, count),
            _ => unreachable!("bits validated"),
        }
    }
}
/// Packs four 2-bit indices into one byte, little-endian within the byte:
/// element `i` occupies bits `2*i..2*i+2`. High bits of each value are
/// masked off. (4 == PACK_2BIT_GROUP_SIZE.)
pub fn pack_2bit(values: &[u8; 4]) -> u8 {
    values
        .iter()
        .enumerate()
        .fold(0u8, |byte, (i, &v)| byte | ((v & 0x3) << (2 * i)))
}
/// Splits one byte into four 2-bit indices; inverse of `pack_2bit`.
/// Element `i` is taken from bits `2*i..2*i+2`.
pub fn unpack_2bit(packed: u8) -> [u8; 4] {
    let mut values = [0u8; 4];
    for (i, slot) in values.iter_mut().enumerate() {
        *slot = (packed >> (2 * i)) & 0x3;
    }
    values
}
/// Packs eight 3-bit indices into three bytes as one contiguous
/// little-endian bit stream: element `i` occupies stream bits
/// `3*i..3*i+3`. High bits of each value are masked off.
/// (8 == PACK_3BIT_GROUP_SIZE, 3 == PACK_3BIT_BYTES.)
pub fn pack_3bit(values: &[u8; 8]) -> [u8; 3] {
    // Accumulate the 24-bit stream in a u32, then split it into bytes.
    let mut word: u32 = 0;
    for (i, &v) in values.iter().enumerate() {
        word |= u32::from(v & 0x7) << (3 * i);
    }
    [word as u8, (word >> 8) as u8, (word >> 16) as u8]
}
/// Splits three bytes into eight 3-bit indices; inverse of `pack_3bit`.
/// Element `i` is taken from little-endian stream bits `3*i..3*i+3`.
pub fn unpack_3bit(packed: &[u8; 3]) -> [u8; 8] {
    // Rebuild the 24-bit little-endian stream, then carve out 3-bit fields.
    let word = u32::from(packed[0])
        | (u32::from(packed[1]) << 8)
        | (u32::from(packed[2]) << 16);
    let mut values = [0u8; 8];
    for (i, slot) in values.iter_mut().enumerate() {
        *slot = ((word >> (3 * i)) & 0x7) as u8;
    }
    values
}
/// Packs two 4-bit indices into one byte: element 0 in the low nibble,
/// element 1 in the high nibble. High bits of each value are masked off.
pub fn pack_4bit(values: &[u8; 2]) -> u8 {
    let lo = values[0] & 0x0F;
    let hi = values[1] & 0x0F;
    (hi << 4) | lo
}
/// Splits one byte into two 4-bit indices; inverse of `pack_4bit`.
/// The high nibble needs no mask: `>> 4` on a `u8` already leaves 4 bits.
pub fn unpack_4bit(packed: u8) -> [u8; 2] {
    [packed & 0x0F, packed >> 4]
}
#[cfg(test)]
mod tests {
    use super::*;

    // Fixture constants. Counts are chosen to hit both the exact-multiple
    // and the remainder/padding paths of each packing routine.
    const TEST_BLOCK_SIZE: usize = 32;
    const TEST_DIM_128: usize = 128;
    const TEST_3BIT_GROUPS: usize = 4;
    const TEST_4BIT_PAIRS: usize = 5;
    const MAX_3BIT_VALUE: u8 = 7;
    const MAX_4BIT_VALUE: u8 = 15;
    const TEST_TRAILING_VALUE: u8 = 9;
    const TEST_3BIT_EXACT_COUNT: usize = 16;
    const TEST_3BIT_REMAINDER_COUNT: usize = 11;
    const TEST_4BIT_EVEN_COUNT: usize = 10;
    const TEST_4BIT_ODD_COUNT: usize = 7;
    const TEST_4BIT_LEVELS: u8 = 16;
    const TEST_3BIT_LEVELS: usize = 8;
    const TEST_SCALE: f32 = 1.5;
    const TEST_SCALE_HALF: f32 = 0.5;
    const MAX_2BIT_VALUE: u8 = 3;
    const TEST_2BIT_EXACT_COUNT: usize = 12;
    const TEST_2BIT_REMAINDER_COUNT: usize = 7;

    // --- validation predicates ---------------------------------------------

    #[test]
    fn is_valid_bits_accepts_2_3_and_4() {
        assert!(is_valid_bits(BITS_TQ2));
        assert!(is_valid_bits(BITS_TQ3));
        assert!(is_valid_bits(BITS_TQ4));
    }

    #[test]
    fn is_valid_bits_rejects_others() {
        assert!(!is_valid_bits(0));
        assert!(!is_valid_bits(1));
        assert!(!is_valid_bits(5));
    }

    #[test]
    fn is_valid_dim_accepts_powers_of_two() {
        assert!(is_valid_dim(TEST_DIM_128 / 2));
        assert!(is_valid_dim(TEST_DIM_128));
    }

    #[test]
    fn is_valid_dim_rejects_invalid() {
        assert!(!is_valid_dim(0));
        assert!(!is_valid_dim(3));
        assert!(!is_valid_dim(100));
    }

    // --- capacity helpers from the `indices` module ------------------------

    #[test]
    fn packed_3bit_capacity_no_remainder() {
        assert_eq!(
            packed_3bit_capacity(TEST_3BIT_GROUPS, false),
            TEST_3BIT_GROUPS * PACK_3BIT_BYTES
        );
    }

    #[test]
    fn packed_3bit_capacity_with_remainder() {
        // A trailing partial group still costs a full 3-byte unit.
        assert_eq!(
            packed_3bit_capacity(TEST_3BIT_GROUPS, true),
            TEST_3BIT_GROUPS * PACK_3BIT_BYTES + PACK_3BIT_BYTES
        );
    }

    #[test]
    fn packed_3bit_capacity_zero_groups() {
        assert_eq!(packed_3bit_capacity(0, false), 0);
        assert_eq!(packed_3bit_capacity(0, true), 3);
    }

    #[test]
    fn packed_4bit_capacity_no_remainder() {
        assert_eq!(
            packed_4bit_capacity(TEST_4BIT_PAIRS, false),
            TEST_4BIT_PAIRS
        );
    }

    #[test]
    fn packed_4bit_capacity_with_remainder() {
        // An odd trailing index costs one extra byte.
        assert_eq!(
            packed_4bit_capacity(TEST_4BIT_PAIRS, true),
            TEST_4BIT_PAIRS + 1
        );
    }

    // --- chunk/pad helpers from the `indices` module -----------------------

    #[test]
    fn chunk_to_3bit_array_preserves_values() {
        let input: Vec<u8> = vec![0, 1, 2, 3, 4, 5, 6, 7];
        let arr = chunk_to_3bit_array(&input);
        assert_eq!(arr, [0, 1, 2, 3, 4, 5, 6, 7]);
    }

    #[test]
    fn chunk_to_4bit_array_preserves_values() {
        let input: Vec<u8> = vec![10, 15];
        let arr = chunk_to_4bit_array(&input);
        assert_eq!(arr, [10, 15]);
    }

    #[test]
    fn pad_remainder_3bit_pads_correctly() {
        // Partial tails are zero-padded up to a full 8-element group.
        let tail: Vec<u8> = vec![1, 2, 3];
        let padded = pad_remainder_3bit(&tail);
        assert_eq!(padded, [1, 2, 3, 0, 0, 0, 0, 0]);
    }

    #[test]
    fn pad_remainder_3bit_single_element() {
        let tail: Vec<u8> = vec![5];
        let padded = pad_remainder_3bit(&tail);
        assert_eq!(padded, [5, 0, 0, 0, 0, 0, 0, 0]);
    }

    #[test]
    fn trailing_4bit_pair_handles_single_element() {
        let pair = trailing_4bit_pair(TEST_TRAILING_VALUE);
        assert_eq!(pair, [TEST_TRAILING_VALUE, 0]);
    }

    #[test]
    fn chunk_to_packed_3bit_array_preserves_values() {
        let input: Vec<u8> = vec![0xAB, 0xCD, 0xEF];
        let arr = chunk_to_packed_3bit_array(&input);
        assert_eq!(arr, [0xAB, 0xCD, 0xEF]);
    }

    // --- single-group pack/unpack roundtrips -------------------------------

    #[test]
    fn pack_unpack_3bit_identity() {
        let values: [u8; PACK_3BIT_GROUP_SIZE] = [0, 1, 2, 3, 4, 5, 6, MAX_3BIT_VALUE];
        let packed = pack_3bit(&values);
        let unpacked = unpack_3bit(&packed);
        assert_eq!(values, unpacked);
    }

    #[test]
    fn pack_unpack_3bit_zeros() {
        let values = [0u8; PACK_3BIT_GROUP_SIZE];
        assert_eq!(unpack_3bit(&pack_3bit(&values)), values);
    }

    #[test]
    fn pack_unpack_3bit_max() {
        let values = [MAX_3BIT_VALUE; PACK_3BIT_GROUP_SIZE];
        assert_eq!(unpack_3bit(&pack_3bit(&values)), values);
    }

    #[test]
    fn pack_unpack_4bit_identity() {
        let values: [u8; PACK_4BIT_GROUP_SIZE] = [0, MAX_4BIT_VALUE];
        let packed = pack_4bit(&values);
        let unpacked = unpack_4bit(packed);
        assert_eq!(values, unpacked);
    }

    #[test]
    fn pack_unpack_4bit_zeros() {
        let values = [0u8; PACK_4BIT_GROUP_SIZE];
        assert_eq!(unpack_4bit(pack_4bit(&values)), values);
    }

    #[test]
    fn pack_unpack_4bit_max() {
        let values = [MAX_4BIT_VALUE; PACK_4BIT_GROUP_SIZE];
        assert_eq!(unpack_4bit(pack_4bit(&values)), values);
    }

    // --- whole-stream pack/unpack roundtrips -------------------------------

    #[test]
    fn roundtrip_3bit_exact_multiple() {
        let indices: Vec<u8> = (0..TEST_3BIT_EXACT_COUNT as u8)
            .map(|i| i % (MAX_3BIT_VALUE + 1))
            .collect();
        let packed = pack_indices_3bit(&indices);
        let unpacked = unpack_indices_3bit(&packed, indices.len());
        assert_eq!(indices, unpacked);
    }

    #[test]
    fn roundtrip_3bit_with_remainder() {
        // 11 is not a multiple of 8, so the padded-tail path is exercised.
        let indices: Vec<u8> = (0..TEST_3BIT_REMAINDER_COUNT as u8)
            .map(|i| i % (MAX_3BIT_VALUE + 1))
            .collect();
        let packed = pack_indices_3bit(&indices);
        let unpacked = unpack_indices_3bit(&packed, indices.len());
        assert_eq!(indices, unpacked);
    }

    #[test]
    fn roundtrip_4bit_even_count() {
        let indices: Vec<u8> = (0..TEST_4BIT_EVEN_COUNT as u8)
            .map(|i| i % TEST_4BIT_LEVELS)
            .collect();
        let packed = pack_indices_4bit(&indices);
        let unpacked = unpack_indices_4bit(&packed, indices.len());
        assert_eq!(indices, unpacked);
    }

    #[test]
    fn roundtrip_4bit_odd_count() {
        // Odd count exercises the trailing half-byte path.
        let indices: Vec<u8> = (0..TEST_4BIT_ODD_COUNT as u8)
            .map(|i| i % TEST_4BIT_LEVELS)
            .collect();
        let packed = pack_indices_4bit(&indices);
        let unpacked = unpack_indices_4bit(&packed, indices.len());
        assert_eq!(indices, unpacked);
    }

    // --- TurboQuantConfig validation ---------------------------------------

    #[test]
    fn config_rejects_invalid_bits() {
        assert!(TurboQuantConfig::new(1, TEST_BLOCK_SIZE).is_err());
        assert!(TurboQuantConfig::new(5, TEST_BLOCK_SIZE).is_err());
    }

    #[test]
    fn config_rejects_non_power_of_two() {
        assert!(TurboQuantConfig::new(BITS_TQ3, 33).is_err());
        assert!(TurboQuantConfig::new(BITS_TQ4, 0).is_err());
    }

    #[test]
    fn config_accepts_valid() {
        assert!(TurboQuantConfig::new(BITS_TQ2, TEST_BLOCK_SIZE).is_ok());
        assert!(TurboQuantConfig::new(BITS_TQ3, TEST_BLOCK_SIZE).is_ok());
        assert!(TurboQuantConfig::new(BITS_TQ4, TEST_DIM_128).is_ok());
    }

    // --- PackedBlock -------------------------------------------------------

    // Expected serialized sizes for a 32-index block:
    // 3-bit: 32/8 groups * 3 bytes = 12; 4-bit: 32/2 = 16 (plus the scale).
    const TQ3_D32_EXPECTED_SIZE: usize = SCALE_SIZE_BYTES + 12;
    const TQ4_D32_EXPECTED_SIZE: usize = SCALE_SIZE_BYTES + 16;

    #[test]
    fn packed_block_tq3_size_bytes() {
        let indices = vec![0u8; TEST_BLOCK_SIZE];
        let block = PackedBlock::new(BITS_TQ3, f16::from_f32(1.0), &indices);
        assert_eq!(block.size_bytes(), TQ3_D32_EXPECTED_SIZE);
    }

    #[test]
    fn packed_block_tq4_size_bytes() {
        let indices = vec![0u8; TEST_BLOCK_SIZE];
        let block = PackedBlock::new(BITS_TQ4, f16::from_f32(1.0), &indices);
        assert_eq!(block.size_bytes(), TQ4_D32_EXPECTED_SIZE);
    }

    #[test]
    fn pack_unpack_2bit_identity() {
        let values: [u8; PACK_2BIT_GROUP_SIZE] = [0, 1, 2, MAX_2BIT_VALUE];
        let packed = pack_2bit(&values);
        let unpacked = unpack_2bit(packed);
        assert_eq!(values, unpacked);
    }

    #[test]
    fn pack_unpack_2bit_zeros() {
        let values = [0u8; PACK_2BIT_GROUP_SIZE];
        assert_eq!(unpack_2bit(pack_2bit(&values)), values);
    }

    #[test]
    fn pack_unpack_2bit_max() {
        let values = [MAX_2BIT_VALUE; PACK_2BIT_GROUP_SIZE];
        assert_eq!(unpack_2bit(pack_2bit(&values)), values);
    }

    #[test]
    fn roundtrip_2bit_exact_multiple() {
        let indices: Vec<u8> = (0..TEST_2BIT_EXACT_COUNT as u8)
            .map(|i| i % (MAX_2BIT_VALUE + 1))
            .collect();
        let packed = pack_indices_2bit(&indices);
        let unpacked = unpack_indices_2bit(&packed, indices.len());
        assert_eq!(indices, unpacked);
    }

    #[test]
    fn roundtrip_2bit_with_remainder() {
        // 7 is not a multiple of 4, so the padded-tail path is exercised.
        let indices: Vec<u8> = (0..TEST_2BIT_REMAINDER_COUNT as u8)
            .map(|i| i % (MAX_2BIT_VALUE + 1))
            .collect();
        let packed = pack_indices_2bit(&indices);
        let unpacked = unpack_indices_2bit(&packed, indices.len());
        assert_eq!(indices, unpacked);
    }

    #[test]
    fn packed_block_tq2_size_bytes() {
        // 32 indices at 2 bits = 8 payload bytes + 2 scale bytes.
        let indices = vec![0u8; TEST_BLOCK_SIZE];
        let block = PackedBlock::new(BITS_TQ2, f16::from_f32(1.0), &indices);
        assert_eq!(block.size_bytes(), 10);
    }

    #[test]
    fn packed_indices_returns_raw_bytes() {
        // Packing is deterministic: two blocks built from the same input
        // carry identical raw bytes.
        let indices = vec![1u8, 2, 3, 0, 1, 2, 3, 0];
        let block = PackedBlock::new(BITS_TQ2, f16::from_f32(TEST_SCALE), &indices);
        let raw = block.packed_indices;
        assert_eq!(raw.len(), 2);
        let block2 = PackedBlock::new(BITS_TQ2, f16::from_f32(TEST_SCALE), &indices);
        assert_eq!(raw, block2.packed_indices);
    }

    #[test]
    fn packed_indices_3bit_length() {
        // 128 indices at 3 bits = 16 groups * 3 bytes = 48 payload bytes.
        let indices = vec![0u8; TEST_DIM_128];
        let block = PackedBlock::new(BITS_TQ3, f16::from_f32(1.0), &indices);
        assert_eq!(block.packed_indices.len(), 48);
    }

    #[test]
    fn from_raw_roundtrip() {
        let indices = vec![3u8, 1, 0, 2, 3, 1, 0, 2];
        let original = PackedBlock::new(BITS_TQ2, f16::from_f32(2.0), &indices);
        let reconstructed = PackedBlock::from_raw(
            original.bits,
            original.scale,
            original.packed_indices.to_vec(),
        );
        assert_eq!(reconstructed.bits, original.bits);
        assert_eq!(reconstructed.scale, original.scale);
        assert_eq!(reconstructed.packed_indices, original.packed_indices);
        assert_eq!(reconstructed.unpack(indices.len()), indices);
    }

    #[test]
    fn from_raw_3bit_roundtrip() {
        let indices: Vec<u8> = (0..TEST_DIM_128)
            .map(|i| (i % TEST_3BIT_LEVELS) as u8)
            .collect();
        let original = PackedBlock::new(BITS_TQ3, f16::from_f32(TEST_SCALE_HALF), &indices);
        let reconstructed =
            PackedBlock::from_raw(BITS_TQ3, original.scale, original.packed_indices.to_vec());
        assert_eq!(reconstructed.unpack(TEST_DIM_128), indices);
    }
}