turboquant 0.1.1

Implementation of Google's TurboQuant algorithm for vector quantization
Documentation
use serde::{Deserialize, Serialize};

use crate::error::{Result, TurboQuantError};

/// A bit-packed representation of quantization indices.
///
/// Stores `count` indices of `bit_width` bits each in a compact byte array.
/// For example, 128 indices at 4 bits = 64 bytes instead of 128 bytes.
///
/// Packing order: indices are packed MSB-first within each byte, with the
/// first index occupying the highest bits of the first byte.
///
/// Values built with [`BitPackedVector::pack`] always have `bit_width` in
/// `1..=8` and `data.len() == ceil(count * bit_width / 8)`. The derived
/// `Deserialize` impl does not re-check these invariants, so deserialized
/// values should come from a trusted source.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct BitPackedVector {
    /// Packed byte storage (when built by `pack`, unused trailing bits of the
    /// last byte are zero).
    data: Vec<u8>,
    /// Number of indices stored.
    count: usize,
    /// Bits per index (1-8).
    bit_width: u8,
}

impl BitPackedVector {
    /// Pack a slice of indices into a compact bit-packed representation.
    ///
    /// Index `i` occupies bits `i * bit_width .. (i + 1) * bit_width` of the
    /// output, counting from the most significant bit of the first byte.
    /// Unused bits in the final byte are left as zero.
    ///
    /// # Arguments
    /// * `indices` - Quantization indices (each must fit in `bit_width` bits)
    /// * `bit_width` - Bits per index (1-8)
    ///
    /// # Errors
    /// * `InvalidBitWidth` if `bit_width` is not in `1..=8`.
    /// * `InvalidQuantizationIndex` if an index is greater than
    ///   `2^bit_width - 1`; the first offending index is reported.
    pub fn pack(indices: &[u8], bit_width: u8) -> Result<Self> {
        if !(1..=8).contains(&bit_width) {
            return Err(TurboQuantError::InvalidBitWidth(bit_width));
        }

        let width = usize::from(bit_width);
        let count = indices.len();
        let num_bytes = (count * width).div_ceil(8);
        let mut data = vec![0u8; num_bytes];

        // Largest encodable value; computed in u16 so that `1 << 8` cannot overflow.
        let mask = (1u16 << bit_width) - 1;

        for (i, &idx) in indices.iter().enumerate() {
            let val = u16::from(idx);
            if val > mask {
                return Err(TurboQuantError::InvalidQuantizationIndex {
                    index: idx,
                    // Lossless: bit_width <= 8, so mask <= 255.
                    max: mask as u8,
                    bit_width,
                });
            }

            let bit_offset = i * width;
            let byte_idx = bit_offset / 8;
            let bit_idx = bit_offset % 8;

            // Left-align the value inside a 16-bit big-endian window starting at
            // `byte_idx`. The shift is in 1..=15 because bit_idx <= 7 and width <= 8.
            let [hi, lo] = (val << (16 - width - bit_idx)).to_be_bytes();
            data[byte_idx] |= hi;
            // `lo` is non-zero only when the value spills past this byte, and the
            // spill byte always exists in that case; at the very last byte there
            // is no spill, so a missing next byte is simply skipped.
            if let Some(next) = data.get_mut(byte_idx + 1) {
                *next |= lo;
            }
        }

        Ok(Self {
            data,
            count,
            bit_width,
        })
    }

    /// Unpack indices from the bit-packed representation.
    ///
    /// Returns exactly `count` indices, each in `0..2^bit_width`.
    ///
    /// # Panics
    /// Panics if the stored fields are inconsistent: `bit_width` outside
    /// `1..=8`, or `data` too short to hold `count` indices. Values produced by
    /// [`pack`](Self::pack) never trigger this; it guards against corrupted or
    /// hand-crafted input, which the derived `Deserialize` does not re-validate
    /// (without it, a truncated buffer would silently decode wrong values).
    pub fn unpack(&self) -> Vec<u8> {
        assert!(
            (1..=8).contains(&self.bit_width),
            "BitPackedVector: bit_width {} is outside 1..=8",
            self.bit_width
        );
        let width = usize::from(self.bit_width);
        let needed = self.count.checked_mul(width).map(|bits| bits.div_ceil(8));
        assert!(
            needed.is_some_and(|n| n <= self.data.len()),
            "BitPackedVector: {} bytes cannot hold {} indices of {} bits",
            self.data.len(),
            self.count,
            self.bit_width
        );

        let mask = (1u16 << self.bit_width) - 1;
        (0..self.count)
            .map(|i| {
                let bit_offset = i * width;
                let byte_idx = bit_offset / 8;
                let bit_idx = bit_offset % 8;
                // Read a 16-bit big-endian window; at the final byte the second
                // half lies past the end and contributes zeros.
                let hi = self.data[byte_idx];
                let lo = self.data.get(byte_idx + 1).copied().unwrap_or(0);
                let window = u16::from_be_bytes([hi, lo]);
                // Lossless: the result is masked to at most 8 bits.
                ((window >> (16 - width - bit_idx)) & mask) as u8
            })
            .collect()
    }

    /// Number of bytes used for storage.
    pub fn byte_len(&self) -> usize {
        self.data.len()
    }

    /// Number of indices stored.
    pub fn count(&self) -> usize {
        self.count
    }

    /// Bits per index.
    pub fn bit_width(&self) -> u8 {
        self.bit_width
    }

    /// Raw packed bytes (for serialization or I/O).
    pub fn as_bytes(&self) -> &[u8] {
        &self.data
    }

    /// Compression ratio vs storing each index as a full byte.
    ///
    /// Returns `1.0` when no bytes are stored.
    pub fn compression_ratio_vs_u8(&self) -> f64 {
        if self.data.is_empty() {
            return 1.0;
        }
        self.count as f64 / self.data.len() as f64
    }

    /// Compression ratio vs storing each value as a 4-byte f32.
    ///
    /// Returns `1.0` when no bytes are stored.
    pub fn compression_ratio_vs_f32(&self) -> f64 {
        if self.data.is_empty() {
            return 1.0;
        }
        (self.count as f64 * 4.0) / self.data.len() as f64
    }
}

/// Pack a vector of sign bits (booleans) into bytes.
///
/// Each group of eight flags becomes one byte, with the first flag of the
/// group in the most significant bit (`true` = 1, `false` = 0). A final
/// partial group is left-aligned and zero-padded, so the result is
/// `signs.len().div_ceil(8)` bytes long.
pub fn pack_signs(signs: &[bool]) -> Vec<u8> {
    signs
        .chunks(8)
        .map(|group| {
            group
                .iter()
                .zip((0..8).rev())
                .fold(0u8, |byte, (&set, shift)| byte | (u8::from(set) << shift))
        })
        .collect()
}

/// Unpack sign bits from packed bytes.
///
/// Reads `count` flags MSB-first from `data` (the layout written by
/// `pack_signs`); a set bit yields `true`. Bytes beyond the first
/// `count.div_ceil(8)` are ignored, as are padding bits in the last used byte.
///
/// # Panics
/// Panics if `data` holds fewer than `count.div_ceil(8)` bytes.
pub fn unpack_signs(data: &[u8], count: usize) -> Vec<bool> {
    let needed = count.div_ceil(8);
    assert!(
        data.len() >= needed,
        "unpack_signs: {count} signs require {needed} bytes, but only {} were provided",
        data.len()
    );
    // Reslice once so the per-element indexing below is provably in bounds.
    let data = &data[..needed];
    (0..count)
        .map(|i| ((data[i / 8] >> (7 - i % 8)) & 1) == 1)
        .collect()
}

/// Compute the packed byte size for `count` indices at `bit_width` bits each.
///
/// Equal to `ceil(count * bit_width / 8)`. The total bit count is never formed
/// directly: every whole group of eight indices occupies exactly `bit_width`
/// bytes, so only the remaining (< 8) indices need rounding up. This keeps the
/// computation from overflowing whenever the result itself fits in `usize`
/// (always the case for `bit_width <= 8`).
pub fn packed_byte_size(count: usize, bit_width: u8) -> usize {
    let width = usize::from(bit_width);
    (count / 8) * width + ((count % 8) * width).div_ceil(8)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_pack_unpack_4bit() {
        let indices: Vec<u8> = (0..16).collect();
        let packed = BitPackedVector::pack(&indices, 4).unwrap();
        assert_eq!(packed.byte_len(), 8); // 16 * 4 bits = 64 bits = 8 bytes
        let unpacked = packed.unpack();
        assert_eq!(unpacked, indices);
    }

    #[test]
    fn test_pack_unpack_1bit() {
        let indices = vec![0, 1, 1, 0, 1, 0, 0, 1, 1];
        let packed = BitPackedVector::pack(&indices, 1).unwrap();
        assert_eq!(packed.byte_len(), 2); // 9 bits → 2 bytes
        let unpacked = packed.unpack();
        assert_eq!(unpacked, indices);
    }

    #[test]
    fn test_pack_unpack_2bit() {
        let indices = vec![0, 1, 2, 3, 3, 2, 1, 0];
        let packed = BitPackedVector::pack(&indices, 2).unwrap();
        assert_eq!(packed.byte_len(), 2); // 8 * 2 bits = 16 bits = 2 bytes
        let unpacked = packed.unpack();
        assert_eq!(unpacked, indices);
    }

    #[test]
    fn test_pack_unpack_3bit() {
        let indices = vec![0, 1, 2, 3, 4, 5, 6, 7];
        let packed = BitPackedVector::pack(&indices, 3).unwrap();
        assert_eq!(packed.byte_len(), 3); // 8 * 3 bits = 24 bits = 3 bytes
        let unpacked = packed.unpack();
        assert_eq!(unpacked, indices);
    }

    #[test]
    fn test_pack_unpack_8bit() {
        let indices: Vec<u8> = (0..=255).collect();
        let packed = BitPackedVector::pack(&indices, 8).unwrap();
        assert_eq!(packed.byte_len(), 256);
        let unpacked = packed.unpack();
        assert_eq!(unpacked, indices);
    }

    #[test]
    fn test_pack_unpack_large_4bit() {
        // Simulate a typical quantized vector: 512 dims at 4 bits
        let indices: Vec<u8> = (0..512).map(|i| (i % 16) as u8).collect();
        let packed = BitPackedVector::pack(&indices, 4).unwrap();
        assert_eq!(packed.byte_len(), 256); // 512 * 4 / 8
        let unpacked = packed.unpack();
        assert_eq!(unpacked, indices);
    }

    #[test]
    fn test_pack_layout_msb_first() {
        // 1, 2, 3 at 3 bits -> 001 010 011 -> 0010_1001 1000_0000 (zero-padded)
        let packed = BitPackedVector::pack(&[1, 2, 3], 3).unwrap();
        assert_eq!(packed.as_bytes(), &[0b0010_1001, 0b1000_0000]);
    }

    #[test]
    fn test_pack_unpack_6bit_and_7bit() {
        // Widths not covered above; both produce values that straddle bytes.
        for bit_width in [6u8, 7] {
            let max = (1u16 << bit_width) - 1;
            let indices: Vec<u8> = (0..=max).map(|v| v as u8).collect();
            let packed = BitPackedVector::pack(&indices, bit_width).unwrap();
            assert_eq!(
                packed.byte_len(),
                packed_byte_size(indices.len(), bit_width)
            );
            assert_eq!(packed.unpack(), indices);
        }
    }

    #[test]
    fn test_invalid_bit_width() {
        assert!(BitPackedVector::pack(&[0, 1], 0).is_err());
        assert!(BitPackedVector::pack(&[0, 1], 9).is_err());
    }

    #[test]
    fn test_empty_vector() {
        let packed = BitPackedVector::pack(&[], 4).unwrap();
        assert_eq!(packed.byte_len(), 0);
        assert_eq!(packed.count(), 0);
        assert_eq!(packed.unpack(), Vec::<u8>::new());
    }

    #[test]
    fn test_compression_ratio() {
        let indices: Vec<u8> = vec![0; 128];
        let packed = BitPackedVector::pack(&indices, 4).unwrap();
        // 128 indices * 4 bits = 64 bytes; vs 128 * 4 = 512 bytes f32
        assert!((packed.compression_ratio_vs_f32() - 8.0).abs() < 0.01);
        // vs 128 bytes of u8 indices
        assert!((packed.compression_ratio_vs_u8() - 2.0).abs() < 0.01);
    }

    #[test]
    fn test_pack_unpack_signs() {
        let signs = vec![true, false, true, true, false, false, true, false, true];
        let packed = pack_signs(&signs);
        assert_eq!(packed.len(), 2); // 9 bits → 2 bytes
        let unpacked = unpack_signs(&packed, signs.len());
        assert_eq!(unpacked, signs);
    }

    #[test]
    fn test_pack_signs_empty() {
        let signs: Vec<bool> = vec![];
        let packed = pack_signs(&signs);
        assert!(packed.is_empty());
        let unpacked = unpack_signs(&packed, 0);
        assert!(unpacked.is_empty());
    }

    #[test]
    #[should_panic]
    fn test_unpack_signs_short_buffer_panics() {
        // 9 signs need 2 bytes; only 1 is supplied.
        let _ = unpack_signs(&[0xFF], 9);
    }

    #[test]
    fn test_packed_byte_size() {
        assert_eq!(packed_byte_size(128, 4), 64);
        assert_eq!(packed_byte_size(128, 2), 32);
        assert_eq!(packed_byte_size(128, 1), 16);
        assert_eq!(packed_byte_size(100, 3), 38); // 300 bits → 38 bytes
        assert_eq!(packed_byte_size(0, 4), 0);
    }

    #[test]
    fn test_pack_unpack_5bit() {
        // 5-bit: values 0..31
        let indices: Vec<u8> = (0..32).collect();
        let packed = BitPackedVector::pack(&indices, 5).unwrap();
        assert_eq!(packed.byte_len(), 20); // 32 * 5 = 160 bits = 20 bytes
        let unpacked = packed.unpack();
        assert_eq!(unpacked, indices);
    }

    #[test]
    fn test_pack_rejects_overflow() {
        let indices = vec![255u8]; // 8-bit max, but packing at 4 bits
        assert!(BitPackedVector::pack(&indices, 4).is_err());
    }

    #[test]
    fn test_pack_overflow_error_details() {
        // The first index that does not fit is reported along with the limit.
        let err = BitPackedVector::pack(&[3, 16], 4).unwrap_err();
        assert!(matches!(
            err,
            TurboQuantError::InvalidQuantizationIndex {
                index: 16,
                max: 15,
                bit_width: 4
            }
        ));
    }
}