use half::f16;
use crate::error::{require, Result, TurboQuantError};
// Bit widths accepted by `is_valid_bits`.
pub(crate) const BITS_TQ2: u8 = 2;
pub(crate) const BITS_TQ3: u8 = 3;
pub(crate) const BITS_TQ4: u8 = 4;
// Number of indices packed together as one group at each bit width.
const PACK_2BIT_GROUP_SIZE: usize = 4;
const PACK_3BIT_GROUP_SIZE: usize = 8;
// Bytes produced by packing one 3-bit group (8 indices x 3 bits = 24 bits).
const PACK_3BIT_BYTES: usize = 3;
const PACK_4BIT_GROUP_SIZE: usize = 2;
// Masks selecting the low 3, 2, 1 and 4 bits of a byte.
const MASK_3BIT: u8 = 0x7;
const MASK_2BIT: u8 = 0x3;
const MASK_1BIT: u8 = 0x1;
const MASK_4BIT: u8 = 0xF;
// Shift amounts (in bits) used by the pack/unpack routines below.
const SHIFT_3: u32 = 3;
const SHIFT_4: u32 = 4;
const SHIFT_5: u32 = 5;
const SHIFT_6: u32 = 6;
const SHIFT_7: u32 = 7;
const SHIFT_1: u32 = 1;
const SHIFT_2: u32 = 2;
// Bytes that `PackedBlock::size_bytes` attributes to the block's `f16` scale.
const SCALE_SIZE_BYTES: usize = 2;
/// Quantizer settings: a bit width, a dimension and a rotation seed.
///
/// `TurboQuantConfig::new` validates the bit width and dimension; the seed is
/// set separately through `with_seed`.
#[derive(Clone, Copy)]
pub struct TurboQuantConfig {
/// Bit width; `new` only accepts 2, 3 or 4.
pub(crate) bits: u8,
/// Dimension; `new` only accepts non-zero powers of two.
pub(crate) dim: usize,
/// Seed value: 0 after `new`, replaced by `with_seed`.
/// NOTE(review): how the seed is consumed is not visible in this file.
pub(crate) rotation_seed: u64,
}
pub(crate) fn is_valid_bits(bits: u8) -> bool {
bits == BITS_TQ2 || bits == BITS_TQ3 || bits == BITS_TQ4
}
/// Returns `true` when `dim` is a non-zero power of two.
pub(crate) fn is_valid_dim(dim: usize) -> bool {
    // Exactly one set bit is equivalent to "non-zero power of two".
    dim.count_ones() == 1
}
impl TurboQuantConfig {
    /// Builds a configuration with a rotation seed of 0.
    ///
    /// # Errors
    ///
    /// Returns `TurboQuantError::UnsupportedBits` when `bits` fails
    /// `is_valid_bits`, and `TurboQuantError::InvalidDimension` when `dim`
    /// fails `is_valid_dim`. The bit width is checked first.
    pub fn new(bits: u8, dim: usize) -> Result<Self> {
        let bits_ok = is_valid_bits(bits);
        require(bits_ok, TurboQuantError::UnsupportedBits(bits))?;

        let dim_ok = is_valid_dim(dim);
        require(dim_ok, TurboQuantError::InvalidDimension(dim))?;

        let config = Self {
            bits,
            dim,
            rotation_seed: 0,
        };
        Ok(config)
    }

    /// Returns a copy of this configuration with `rotation_seed` set to `seed`.
    pub fn with_seed(self, seed: u64) -> Self {
        Self {
            rotation_seed: seed,
            ..self
        }
    }
}
/// A scale plus a run of indices stored in bit-packed form.
pub struct PackedBlock {
// Bit width used to pack `packed_indices` (2, 3 or 4).
bits: u8,
// Scale value supplied at construction and returned unchanged by `scale()`.
scale: f16,
// Output of the `pack_indices_*` routine matching `bits`.
packed_indices: Vec<u8>,
}
impl PackedBlock {
pub fn new(bits: u8, scale: f16, indices: &[u8]) -> Self {
let pack = |indices: &[u8]| -> Vec<u8> {
match bits {
BITS_TQ2 => pack_indices_2bit(indices),
BITS_TQ3 => pack_indices_3bit(indices),
BITS_TQ4 => pack_indices_4bit(indices),
_ => unreachable!("bits validated to be 2, 3, or 4"),
}
};
Self {
bits,
scale,
packed_indices: pack(indices),
}
}
pub fn scale(&self) -> f16 {
self.scale
}
pub fn bits(&self) -> u8 {
self.bits
}
pub fn size_bytes(&self) -> usize {
SCALE_SIZE_BYTES + self.packed_indices.len()
}
pub fn unpack_into(&self, count: usize, buf: &mut Vec<u8>) {
buf.clear();
let do_unpack = |packed: &[u8], out: &mut Vec<u8>| match self.bits {
BITS_TQ2 => out.extend_from_slice(&unpack_indices_2bit(packed, count)),
BITS_TQ3 => out.extend_from_slice(&unpack_indices_3bit(packed, count)),
BITS_TQ4 => out.extend_from_slice(&unpack_indices_4bit(packed, count)),
_ => unreachable!("bits validated"),
};
do_unpack(&self.packed_indices, buf);
buf.truncate(count);
}
pub fn unpack(&self, count: usize) -> Vec<u8> {
let do_unpack = |packed: &[u8]| match self.bits {
BITS_TQ2 => unpack_indices_2bit(packed, count),
BITS_TQ3 => unpack_indices_3bit(packed, count),
BITS_TQ4 => unpack_indices_4bit(packed, count),
_ => unreachable!("bits validated"),
};
do_unpack(&self.packed_indices)
}
}
/// Packs four 2-bit values into one byte, `values[0]` in the lowest two bits
/// and `values[3]` in the highest two.
///
/// Only the low two bits of each value are kept.
pub fn pack_2bit(values: &[u8; PACK_2BIT_GROUP_SIZE]) -> u8 {
    let shifts = [0, SHIFT_2, SHIFT_4, SHIFT_6];
    values
        .iter()
        .zip(shifts)
        .fold(0u8, |byte, (&value, shift)| {
            byte | ((value & MASK_2BIT) << shift)
        })
}
/// Splits one byte into its four 2-bit fields, lowest field first.
pub fn unpack_2bit(packed: u8) -> [u8; PACK_2BIT_GROUP_SIZE] {
    [0, SHIFT_2, SHIFT_4, SHIFT_6].map(|shift| (packed >> shift) & MASK_2BIT)
}
/// Number of complete 2-bit groups contained in `len` indices.
fn num_2bit_groups(len: usize) -> usize {
    let full_groups = len / PACK_2BIT_GROUP_SIZE;
    full_groups
}
/// Whether `len` indices leave a partial 2-bit group after the full ones.
fn has_2bit_remainder(len: usize) -> bool {
    let leftover = len % PACK_2BIT_GROUP_SIZE;
    leftover > 0
}
/// Bytes needed for `num_groups` full 2-bit groups, plus one more byte when a
/// partial group follows.
fn packed_2bit_capacity(num_groups: usize, has_remainder: bool) -> usize {
    if has_remainder {
        num_groups + 1
    } else {
        num_groups
    }
}
fn chunk_to_2bit_array(chunk: &[u8]) -> [u8; PACK_2BIT_GROUP_SIZE] {
chunk.try_into().expect("chunk size matches group size")
}
/// Copies `tail` into the front of a zero-filled 2-bit group.
///
/// # Panics
///
/// Panics if `tail` is longer than `PACK_2BIT_GROUP_SIZE`.
fn pad_remainder_2bit(tail: &[u8]) -> [u8; PACK_2BIT_GROUP_SIZE] {
    let mut group = [0u8; PACK_2BIT_GROUP_SIZE];
    let filled = &mut group[..tail.len()];
    for (slot, &value) in filled.iter_mut().zip(tail) {
        *slot = value;
    }
    group
}
/// Packs `indices` four per byte using [`pack_2bit`].
///
/// A trailing partial group is zero-padded before packing, so the output has
/// one byte per started group of four.
pub fn pack_indices_2bit(indices: &[u8]) -> Vec<u8> {
    let len = indices.len();
    let capacity = packed_2bit_capacity(num_2bit_groups(len), has_2bit_remainder(len));

    let full_group =
        |group: &[u8], out: &mut Vec<u8>| out.push(pack_2bit(&chunk_to_2bit_array(group)));
    let partial_group =
        |rest: &[u8], out: &mut Vec<u8>| out.push(pack_2bit(&pad_remainder_2bit(rest)));

    pack_indices_chunked(
        indices,
        PACK_2BIT_GROUP_SIZE,
        capacity,
        full_group,
        partial_group,
    )
}
pub fn unpack_indices_2bit(packed: &[u8], count: usize) -> Vec<u8> {
let mut result = Vec::with_capacity(count);
for &byte in packed {
let vals = unpack_2bit(byte);
result.extend_from_slice(&vals);
}
result.truncate(count);
result
}
/// Packs eight 3-bit values into three bytes.
///
/// The values form one 24-bit little-endian word: `values[i]` occupies bits
/// `3 * i ..= 3 * i + 2`, so `values[2]` and `values[5]` straddle a byte
/// boundary. Only the low three bits of each value are kept.
pub fn pack_3bit(values: &[u8; PACK_3BIT_GROUP_SIZE]) -> [u8; PACK_3BIT_BYTES] {
    let mut word: u32 = 0;
    let mut offset: u32 = 0;
    for &value in values {
        word |= u32::from(value & MASK_3BIT) << offset;
        offset += SHIFT_3;
    }

    // The 24 payload bits live in the three low-order bytes of the word.
    let mut packed = [0u8; PACK_3BIT_BYTES];
    packed.copy_from_slice(&word.to_le_bytes()[..PACK_3BIT_BYTES]);
    packed
}
/// Unpacks three bytes into eight 3-bit values (inverse of [`pack_3bit`]).
///
/// The bytes are read as one 24-bit little-endian word and value `i` is taken
/// from bits `3 * i ..= 3 * i + 2`.
pub fn unpack_3bit(packed: &[u8; PACK_3BIT_BYTES]) -> [u8; PACK_3BIT_GROUP_SIZE] {
    let word = u32::from(packed[0]) | (u32::from(packed[1]) << 8) | (u32::from(packed[2]) << 16);

    let mut values = [0u8; PACK_3BIT_GROUP_SIZE];
    let mut offset: u32 = 0;
    for slot in values.iter_mut() {
        // Truncating cast is intentional: only the low three bits survive the mask.
        *slot = ((word >> offset) as u8) & MASK_3BIT;
        offset += SHIFT_3;
    }
    values
}
/// Packs two 4-bit values into one byte: `values[0]` in the low nibble and
/// `values[1]` in the high nibble.
///
/// Only the low four bits of each value are kept.
pub fn pack_4bit(values: &[u8; 2]) -> u8 {
    let [low, high] = *values;
    let low_nibble = low & MASK_4BIT;
    let high_nibble = (high & MASK_4BIT) << SHIFT_4;
    high_nibble | low_nibble
}
/// Splits one byte into its low and high nibbles, low nibble first.
pub fn unpack_4bit(packed: u8) -> [u8; 2] {
    let low = packed & MASK_4BIT;
    let high = (packed >> SHIFT_4) & MASK_4BIT;
    [low, high]
}
/// Number of complete 3-bit groups contained in `len` indices.
fn num_3bit_groups(len: usize) -> usize {
    let full_groups = len / PACK_3BIT_GROUP_SIZE;
    full_groups
}
/// Whether `len` indices leave a partial 3-bit group after the full ones.
fn has_3bit_remainder(len: usize) -> bool {
    let leftover = len % PACK_3BIT_GROUP_SIZE;
    leftover > 0
}
/// Bytes needed for `num_groups` full 3-bit groups, plus one more whole group
/// (`PACK_3BIT_BYTES` bytes) when a partial group follows.
fn packed_3bit_capacity(num_groups: usize, has_remainder: bool) -> usize {
    let full_bytes = num_groups * PACK_3BIT_BYTES;
    if has_remainder {
        full_bytes + PACK_3BIT_BYTES
    } else {
        full_bytes
    }
}
fn chunk_to_3bit_array(chunk: &[u8]) -> [u8; PACK_3BIT_GROUP_SIZE] {
chunk.try_into().expect("chunk size matches group size")
}
/// Copies `tail` into the front of a zero-filled 3-bit group.
///
/// # Panics
///
/// Panics if `tail` is longer than `PACK_3BIT_GROUP_SIZE`.
fn pad_remainder_3bit(tail: &[u8]) -> [u8; PACK_3BIT_GROUP_SIZE] {
    let mut group = [0u8; PACK_3BIT_GROUP_SIZE];
    let filled = &mut group[..tail.len()];
    for (slot, &value) in filled.iter_mut().zip(tail) {
        *slot = value;
    }
    group
}
fn chunk_to_packed_3bit_array(chunk: &[u8]) -> [u8; PACK_3BIT_BYTES] {
chunk.try_into().expect("chunk size matches group size")
}
/// Number of complete 4-bit pairs contained in `len` indices.
fn num_4bit_pairs(len: usize) -> usize {
    let full_pairs = len / PACK_4BIT_GROUP_SIZE;
    full_pairs
}
/// Whether `len` indices leave an unpaired value after the full pairs.
fn has_4bit_remainder(len: usize) -> bool {
    let leftover = len % PACK_4BIT_GROUP_SIZE;
    leftover > 0
}
/// Bytes needed for `num_pairs` full 4-bit pairs, plus one more byte when an
/// unpaired value follows.
fn packed_4bit_capacity(num_pairs: usize, has_remainder: bool) -> usize {
    if has_remainder {
        num_pairs + 1
    } else {
        num_pairs
    }
}
fn chunk_to_4bit_array(pair: &[u8]) -> [u8; PACK_4BIT_GROUP_SIZE] {
pair.try_into().expect("chunk size matches group size")
}
fn trailing_4bit_pair(last: u8) -> [u8; PACK_4BIT_GROUP_SIZE] {
[last, 0]
}
/// Shared driver for the `pack_indices_*` functions.
///
/// Splits `indices` into groups of `group_size`, hands every complete group
/// to `pack_group` in order, then hands a non-empty leftover to
/// `pack_remainder`. Both callbacks append their bytes to the output vector,
/// which starts with room for `capacity` bytes.
///
/// # Panics
///
/// Panics if `group_size` is 0.
fn pack_indices_chunked<F, R>(
    indices: &[u8],
    group_size: usize,
    capacity: usize,
    mut pack_group: F,
    mut pack_remainder: R,
) -> Vec<u8>
where
    F: FnMut(&[u8], &mut Vec<u8>),
    R: FnMut(&[u8], &mut Vec<u8>),
{
    let mut out = Vec::with_capacity(capacity);

    let mut groups = indices.chunks_exact(group_size);
    for group in groups.by_ref() {
        pack_group(group, &mut out);
    }

    let leftover = groups.remainder();
    if !leftover.is_empty() {
        pack_remainder(leftover, &mut out);
    }

    out
}
/// Packs `indices` eight per three bytes using [`pack_3bit`].
///
/// A trailing partial group is zero-padded before packing, so the output
/// length is always a multiple of `PACK_3BIT_BYTES`.
pub fn pack_indices_3bit(indices: &[u8]) -> Vec<u8> {
    let len = indices.len();
    let capacity = packed_3bit_capacity(num_3bit_groups(len), has_3bit_remainder(len));

    let full_group = |group: &[u8], out: &mut Vec<u8>| {
        out.extend_from_slice(&pack_3bit(&chunk_to_3bit_array(group)))
    };
    let partial_group = |rest: &[u8], out: &mut Vec<u8>| {
        out.extend_from_slice(&pack_3bit(&pad_remainder_3bit(rest)))
    };

    pack_indices_chunked(
        indices,
        PACK_3BIT_GROUP_SIZE,
        capacity,
        full_group,
        partial_group,
    )
}
pub fn unpack_indices_3bit(packed: &[u8], count: usize) -> Vec<u8> {
let mut result = Vec::with_capacity(count);
for chunk in packed.chunks_exact(PACK_3BIT_BYTES) {
let arr = chunk_to_packed_3bit_array(chunk);
let vals = unpack_3bit(&arr);
result.extend_from_slice(&vals);
}
result.truncate(count);
result
}
/// Packs `indices` two per byte using [`pack_4bit`].
///
/// An unpaired final index is packed together with a zero in the high nibble.
pub fn pack_indices_4bit(indices: &[u8]) -> Vec<u8> {
    let len = indices.len();
    let capacity = packed_4bit_capacity(num_4bit_pairs(len), has_4bit_remainder(len));

    let full_pair =
        |pair: &[u8], out: &mut Vec<u8>| out.push(pack_4bit(&chunk_to_4bit_array(pair)));
    // The leftover of a two-element grouping is exactly one value.
    let odd_one_out =
        |rest: &[u8], out: &mut Vec<u8>| out.push(pack_4bit(&trailing_4bit_pair(rest[0])));

    pack_indices_chunked(
        indices,
        PACK_4BIT_GROUP_SIZE,
        capacity,
        full_pair,
        odd_one_out,
    )
}
pub fn unpack_indices_4bit(packed: &[u8], count: usize) -> Vec<u8> {
let mut result = Vec::with_capacity(count);
for &byte in packed {
let vals = unpack_4bit(byte);
result.extend_from_slice(&vals);
}
result.truncate(count);
result
}
// Unit tests for config validation, the capacity/array helpers, the
// per-group pack/unpack primitives, slice-level round trips, and
// `PackedBlock::size_bytes`.
#[cfg(test)]
mod tests {
use super::*;
// Shared fixture sizes and value limits used by the tests below.
const TEST_BLOCK_SIZE: usize = 32;
const TEST_DIM_128: usize = 128;
const TEST_3BIT_GROUPS: usize = 4;
const TEST_4BIT_PAIRS: usize = 5;
const MAX_3BIT_VALUE: u8 = 7;
const MAX_4BIT_VALUE: u8 = 15;
const TEST_TRAILING_VALUE: u8 = 9;
// 16 is a multiple of the 3-bit group size (8); 11 leaves a partial group.
const TEST_3BIT_EXACT_COUNT: usize = 16;
const TEST_3BIT_REMAINDER_COUNT: usize = 11;
const TEST_4BIT_EVEN_COUNT: usize = 10;
const TEST_4BIT_ODD_COUNT: usize = 7;
const TEST_4BIT_LEVELS: u8 = 16;
const MAX_2BIT_VALUE: u8 = 3;
// 12 is a multiple of the 2-bit group size (4); 7 leaves a partial group.
const TEST_2BIT_EXACT_COUNT: usize = 12;
const TEST_2BIT_REMAINDER_COUNT: usize = 7;
// --- Validation predicates ---
#[test]
fn is_valid_bits_accepts_2_3_and_4() {
assert!(is_valid_bits(BITS_TQ2));
assert!(is_valid_bits(BITS_TQ3));
assert!(is_valid_bits(BITS_TQ4));
}
#[test]
fn is_valid_bits_rejects_others() {
assert!(!is_valid_bits(0));
assert!(!is_valid_bits(1));
assert!(!is_valid_bits(5));
}
#[test]
fn is_valid_dim_accepts_powers_of_two() {
assert!(is_valid_dim(TEST_DIM_128 / 2));
assert!(is_valid_dim(TEST_DIM_128));
}
#[test]
fn is_valid_dim_rejects_invalid() {
assert!(!is_valid_dim(0));
assert!(!is_valid_dim(3));
assert!(!is_valid_dim(100));
}
// --- Capacity helpers ---
#[test]
fn packed_3bit_capacity_no_remainder() {
assert_eq!(
packed_3bit_capacity(TEST_3BIT_GROUPS, false),
TEST_3BIT_GROUPS * PACK_3BIT_BYTES
);
}
#[test]
fn packed_3bit_capacity_with_remainder() {
assert_eq!(
packed_3bit_capacity(TEST_3BIT_GROUPS, true),
TEST_3BIT_GROUPS * PACK_3BIT_BYTES + PACK_3BIT_BYTES
);
}
#[test]
fn packed_3bit_capacity_zero_groups() {
assert_eq!(packed_3bit_capacity(0, false), 0);
assert_eq!(packed_3bit_capacity(0, true), 3);
}
#[test]
fn packed_4bit_capacity_no_remainder() {
assert_eq!(
packed_4bit_capacity(TEST_4BIT_PAIRS, false),
TEST_4BIT_PAIRS
);
}
#[test]
fn packed_4bit_capacity_with_remainder() {
assert_eq!(
packed_4bit_capacity(TEST_4BIT_PAIRS, true),
TEST_4BIT_PAIRS + 1
);
}
// --- Slice-to-array and padding helpers ---
#[test]
fn chunk_to_3bit_array_preserves_values() {
let input: Vec<u8> = vec![0, 1, 2, 3, 4, 5, 6, 7];
let arr = chunk_to_3bit_array(&input);
assert_eq!(arr, [0, 1, 2, 3, 4, 5, 6, 7]);
}
#[test]
fn chunk_to_4bit_array_preserves_values() {
let input: Vec<u8> = vec![10, 15];
let arr = chunk_to_4bit_array(&input);
assert_eq!(arr, [10, 15]);
}
#[test]
fn pad_remainder_3bit_pads_correctly() {
let tail: Vec<u8> = vec![1, 2, 3];
let padded = pad_remainder_3bit(&tail);
assert_eq!(padded, [1, 2, 3, 0, 0, 0, 0, 0]);
}
#[test]
fn pad_remainder_3bit_single_element() {
let tail: Vec<u8> = vec![5];
let padded = pad_remainder_3bit(&tail);
assert_eq!(padded, [5, 0, 0, 0, 0, 0, 0, 0]);
}
#[test]
fn trailing_4bit_pair_handles_single_element() {
let pair = trailing_4bit_pair(TEST_TRAILING_VALUE);
assert_eq!(pair, [TEST_TRAILING_VALUE, 0]);
}
#[test]
fn chunk_to_packed_3bit_array_preserves_values() {
let input: Vec<u8> = vec![0xAB, 0xCD, 0xEF];
let arr = chunk_to_packed_3bit_array(&input);
assert_eq!(arr, [0xAB, 0xCD, 0xEF]);
}
// --- Single-group pack/unpack round trips (3-bit and 4-bit) ---
#[test]
fn pack_unpack_3bit_identity() {
let values: [u8; PACK_3BIT_GROUP_SIZE] = [0, 1, 2, 3, 4, 5, 6, MAX_3BIT_VALUE];
let packed = pack_3bit(&values);
let unpacked = unpack_3bit(&packed);
assert_eq!(values, unpacked);
}
#[test]
fn pack_unpack_3bit_zeros() {
let values = [0u8; PACK_3BIT_GROUP_SIZE];
assert_eq!(unpack_3bit(&pack_3bit(&values)), values);
}
#[test]
fn pack_unpack_3bit_max() {
let values = [MAX_3BIT_VALUE; PACK_3BIT_GROUP_SIZE];
assert_eq!(unpack_3bit(&pack_3bit(&values)), values);
}
#[test]
fn pack_unpack_4bit_identity() {
let values: [u8; PACK_4BIT_GROUP_SIZE] = [0, MAX_4BIT_VALUE];
let packed = pack_4bit(&values);
let unpacked = unpack_4bit(packed);
assert_eq!(values, unpacked);
}
#[test]
fn pack_unpack_4bit_zeros() {
let values = [0u8; PACK_4BIT_GROUP_SIZE];
assert_eq!(unpack_4bit(pack_4bit(&values)), values);
}
#[test]
fn pack_unpack_4bit_max() {
let values = [MAX_4BIT_VALUE; PACK_4BIT_GROUP_SIZE];
assert_eq!(unpack_4bit(pack_4bit(&values)), values);
}
// --- Slice-level round trips, with and without a partial final group ---
#[test]
fn roundtrip_3bit_exact_multiple() {
let indices: Vec<u8> = (0..TEST_3BIT_EXACT_COUNT as u8)
.map(|i| i % (MAX_3BIT_VALUE + 1))
.collect();
let packed = pack_indices_3bit(&indices);
let unpacked = unpack_indices_3bit(&packed, indices.len());
assert_eq!(indices, unpacked);
}
#[test]
fn roundtrip_3bit_with_remainder() {
let indices: Vec<u8> = (0..TEST_3BIT_REMAINDER_COUNT as u8)
.map(|i| i % (MAX_3BIT_VALUE + 1))
.collect();
let packed = pack_indices_3bit(&indices);
let unpacked = unpack_indices_3bit(&packed, indices.len());
assert_eq!(indices, unpacked);
}
#[test]
fn roundtrip_4bit_even_count() {
let indices: Vec<u8> = (0..TEST_4BIT_EVEN_COUNT as u8)
.map(|i| i % TEST_4BIT_LEVELS)
.collect();
let packed = pack_indices_4bit(&indices);
let unpacked = unpack_indices_4bit(&packed, indices.len());
assert_eq!(indices, unpacked);
}
#[test]
fn roundtrip_4bit_odd_count() {
let indices: Vec<u8> = (0..TEST_4BIT_ODD_COUNT as u8)
.map(|i| i % TEST_4BIT_LEVELS)
.collect();
let packed = pack_indices_4bit(&indices);
let unpacked = unpack_indices_4bit(&packed, indices.len());
assert_eq!(indices, unpacked);
}
// --- TurboQuantConfig::new validation ---
#[test]
fn config_rejects_invalid_bits() {
assert!(TurboQuantConfig::new(1, TEST_BLOCK_SIZE).is_err());
assert!(TurboQuantConfig::new(5, TEST_BLOCK_SIZE).is_err());
}
#[test]
fn config_rejects_non_power_of_two() {
assert!(TurboQuantConfig::new(BITS_TQ3, 33).is_err());
assert!(TurboQuantConfig::new(BITS_TQ4, 0).is_err());
}
#[test]
fn config_accepts_valid() {
assert!(TurboQuantConfig::new(BITS_TQ2, TEST_BLOCK_SIZE).is_ok());
assert!(TurboQuantConfig::new(BITS_TQ3, TEST_BLOCK_SIZE).is_ok());
assert!(TurboQuantConfig::new(BITS_TQ4, TEST_DIM_128).is_ok());
}
// --- PackedBlock::size_bytes for 32 indices ---
// 3-bit: 4 groups x 3 bytes = 12 packed bytes; 4-bit: 16 pairs = 16 packed bytes.
const TQ3_D32_EXPECTED_SIZE: usize = SCALE_SIZE_BYTES + 12;
const TQ4_D32_EXPECTED_SIZE: usize = SCALE_SIZE_BYTES + 16;
#[test]
fn packed_block_tq3_size_bytes() {
let indices = vec![0u8; TEST_BLOCK_SIZE];
let block = PackedBlock::new(BITS_TQ3, f16::from_f32(1.0), &indices);
assert_eq!(block.size_bytes(), TQ3_D32_EXPECTED_SIZE);
}
#[test]
fn packed_block_tq4_size_bytes() {
let indices = vec![0u8; TEST_BLOCK_SIZE];
let block = PackedBlock::new(BITS_TQ4, f16::from_f32(1.0), &indices);
assert_eq!(block.size_bytes(), TQ4_D32_EXPECTED_SIZE);
}
// --- 2-bit primitives, round trips and block size ---
#[test]
fn pack_unpack_2bit_identity() {
let values: [u8; PACK_2BIT_GROUP_SIZE] = [0, 1, 2, MAX_2BIT_VALUE];
let packed = pack_2bit(&values);
let unpacked = unpack_2bit(packed);
assert_eq!(values, unpacked);
}
#[test]
fn pack_unpack_2bit_zeros() {
let values = [0u8; PACK_2BIT_GROUP_SIZE];
assert_eq!(unpack_2bit(pack_2bit(&values)), values);
}
#[test]
fn pack_unpack_2bit_max() {
let values = [MAX_2BIT_VALUE; PACK_2BIT_GROUP_SIZE];
assert_eq!(unpack_2bit(pack_2bit(&values)), values);
}
#[test]
fn roundtrip_2bit_exact_multiple() {
let indices: Vec<u8> = (0..TEST_2BIT_EXACT_COUNT as u8)
.map(|i| i % (MAX_2BIT_VALUE + 1))
.collect();
let packed = pack_indices_2bit(&indices);
let unpacked = unpack_indices_2bit(&packed, indices.len());
assert_eq!(indices, unpacked);
}
#[test]
fn roundtrip_2bit_with_remainder() {
let indices: Vec<u8> = (0..TEST_2BIT_REMAINDER_COUNT as u8)
.map(|i| i % (MAX_2BIT_VALUE + 1))
.collect();
let packed = pack_indices_2bit(&indices);
let unpacked = unpack_indices_2bit(&packed, indices.len());
assert_eq!(indices, unpacked);
}
#[test]
fn packed_block_tq2_size_bytes() {
let indices = vec![0u8; TEST_BLOCK_SIZE];
let block = PackedBlock::new(BITS_TQ2, f16::from_f32(1.0), &indices);
// 32 indices / 4 per byte = 8 packed bytes, plus the 2-byte scale.
assert_eq!(block.size_bytes(), 10);
}
}