turboquant 0.1.1

Implementation of Google's TurboQuant algorithm for vector quantization
Documentation
use crate::backend::{sum_squared_error, ExecutionBackend};
use serde::{Deserialize, Serialize};

use crate::codebook::{generate_codebook, Codebook};
use crate::error::{Result, TurboQuantError};
use crate::rotation::RandomRotation;
use crate::scalar_quant::ScalarQuantizer;
use crate::utils::validate_unit_vector;

/// Compressed form of a vector, as produced by [`TurboQuantMSE::quantize`].
///
/// Holds one codebook index per rotated coordinate, together with the bit
/// width and original dimension needed to validate and decode it.
#[must_use]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantizedVector {
    /// Per-coordinate codebook indices, in rotated-coordinate order.
    pub indices: Vec<u8>,
    /// Number of bits each index occupies when bit-packed.
    pub bit_width: u8,
    /// Dimension of the vector before quantization.
    pub dim: usize,
}

impl QuantizedVector {
    /// Size of the index payload in bytes if packed at `bit_width` bits per index.
    ///
    /// Fractional byte counts are returned as-is (no rounding up to whole bytes).
    pub fn bytes(&self) -> f64 {
        let index_count = self.indices.len() as f64;
        let bits_per_index = f64::from(self.bit_width);
        index_count * bits_per_index / 8.0
    }

    /// Ratio of the uncompressed `f32` footprint (`dim * 4` bytes) to [`Self::bytes`].
    ///
    /// The result is infinite or NaN when `bytes()` is zero.
    pub fn compression_ratio(&self) -> f64 {
        let f32_footprint = self.dim as f64 * 4.0;
        f32_footprint / self.bytes()
    }
}

/// TurboQuant MSE-optimal quantizer.
///
/// This is Stage 1 of TurboQuant. It achieves the optimal rate-distortion
/// trade-off for MSE by:
///   1. Applying a random orthogonal rotation Π to decorrelate the vector.
///   2. Applying a Lloyd-Max optimal scalar quantizer to each rotated coordinate.
///
/// The rotation ensures that the coordinates of unit-sphere vectors are i.i.d.
/// under the Beta distribution, making scalar quantization near-optimal.
///
/// Theoretical MSE bound: Δ_MSE(b) ≈ √3·π/2 · 1/4^b
///
/// NOTE(review): in the TurboQuant paper this bound is stated for the total
/// squared error E‖x − x̃‖² of a unit-norm `x`, not a per-coordinate error.
/// Confirm before comparing it with per-coordinate measurements.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TurboQuantMSE {
    /// Rotation Π applied in `quantize` and inverted in `dequantize`.
    rotation: RandomRotation,
    /// Per-coordinate scalar quantizer built from the codebook.
    quantizer: ScalarQuantizer,
    /// Vector dimension this quantizer accepts.
    ///
    /// NOTE(review): public and mutable, but the rotation and codebook were built
    /// for the construction-time value; changing it afterwards presumably breaks
    /// quantize/dequantize — confirm whether this should be read-only.
    pub dim: usize,
    /// Bits per coordinate (taken from the codebook).
    pub bit_width: u8,
}

impl TurboQuantMSE {
    /// Validate construction parameters shared by [`Self::new`] and
    /// [`Self::with_codebook`].
    ///
    /// # Errors
    /// * [`TurboQuantError::InvalidDimension`] if `dim` is zero.
    /// * [`TurboQuantError::InvalidBitWidth`] if `bit_width` is outside `1..=8`.
    fn validate_config(dim: usize, bit_width: u8) -> Result<()> {
        if dim == 0 {
            return Err(TurboQuantError::InvalidDimension(dim));
        }
        if !(1..=8).contains(&bit_width) {
            return Err(TurboQuantError::InvalidBitWidth(bit_width));
        }
        Ok(())
    }

    /// Check that the components of a quantized vector match this quantizer.
    ///
    /// Operates on borrowed parts, so validating never copies the index slice.
    ///
    /// # Errors
    /// * [`TurboQuantError::DimensionMismatch`] if `dim != self.dim`.
    /// * [`TurboQuantError::BitWidthMismatch`] if `bit_width != self.bit_width`.
    /// * [`TurboQuantError::LengthMismatch`] if `indices.len() != self.dim`.
    /// * Errors from `ScalarQuantizer::validate_indices` (e.g. an index outside
    ///   the codebook).
    fn validate_parts(&self, indices: &[u8], bit_width: u8, dim: usize) -> Result<()> {
        if dim != self.dim {
            return Err(TurboQuantError::DimensionMismatch {
                expected: self.dim,
                got: dim,
            });
        }
        if bit_width != self.bit_width {
            return Err(TurboQuantError::BitWidthMismatch {
                expected: self.bit_width,
                got: bit_width,
            });
        }
        if indices.len() != self.dim {
            return Err(TurboQuantError::LengthMismatch {
                context: "TurboQuantMSE quantized vector indices".into(),
                expected: self.dim,
                got: indices.len(),
            });
        }

        self.quantizer.validate_indices(indices)
    }

    /// Reconstruct a vector from the raw parts of a quantized representation.
    ///
    /// # Errors
    /// Any validation error listed on `validate_parts`, or an error from the
    /// inverse rotation.
    pub(crate) fn dequantize_parts(
        &self,
        indices: &[u8],
        bit_width: u8,
        dim: usize,
    ) -> Result<Vec<f64>> {
        self.validate_parts(indices, bit_width, dim)?;

        let y_approx = self.quantizer.dequantize_batch(indices);
        self.rotation.rotate_inverse(&y_approx)
    }

    /// Create a new TurboQuant MSE quantizer.
    ///
    /// # Arguments
    /// * `dim` - Vector dimension (must be non-zero)
    /// * `bit_width` - Bits per coordinate (1-8)
    /// * `seed` - Random seed for the rotation matrix
    ///
    /// # Errors
    /// * [`TurboQuantError::InvalidDimension`] if `dim` is zero.
    /// * [`TurboQuantError::InvalidBitWidth`] if `bit_width` is outside `1..=8`.
    /// * Errors from building the rotation or generating the codebook.
    pub fn new(dim: usize, bit_width: u8, seed: u64) -> Result<Self> {
        Self::validate_config(dim, bit_width)?;

        let rotation = RandomRotation::new(dim, seed)?;
        // NOTE(review): `100` is presumably the Lloyd-Max iteration budget — confirm.
        let codebook = generate_codebook(dim, bit_width, 100)?;
        let quantizer = ScalarQuantizer::from_codebook(codebook);

        Ok(Self {
            rotation,
            quantizer,
            dim,
            bit_width,
        })
    }

    /// Create from a pre-computed codebook (useful when sharing codebooks).
    ///
    /// The quantizer's bit width is taken from `codebook.bit_width`.
    ///
    /// NOTE(review): the codebook is not checked against `dim`. Since
    /// `generate_codebook` takes a dimension, a codebook built for another
    /// dimension is presumably a mismatch — confirm.
    ///
    /// # Errors
    /// * [`TurboQuantError::InvalidDimension`] if `dim` is zero.
    /// * [`TurboQuantError::InvalidBitWidth`] if the codebook's bit width is
    ///   outside `1..=8`.
    /// * Errors from building the rotation.
    pub fn with_codebook(dim: usize, codebook: Codebook, seed: u64) -> Result<Self> {
        let bit_width = codebook.bit_width;
        // Apply the same parameter checks as `new`, so both constructors
        // reject the same invalid configurations.
        Self::validate_config(dim, bit_width)?;

        let rotation = RandomRotation::new(dim, seed)?;
        let quantizer = ScalarQuantizer::from_codebook(codebook);
        Ok(Self {
            rotation,
            quantizer,
            dim,
            bit_width,
        })
    }

    /// Quantize a vector to its compressed representation.
    ///
    /// Algorithm:
    ///   1. y = Π · x  (random orthogonal rotation)
    ///   2. For each y_j: idx_j = argmin_k |y_j - c_k|  (scalar quantize)
    ///
    /// # Errors
    /// * [`TurboQuantError::DimensionMismatch`] if `x.len() != self.dim`.
    /// * An error from `validate_unit_vector` if `x` is not unit-norm.
    /// * Errors from the rotation.
    pub fn quantize(&self, x: &[f64]) -> Result<QuantizedVector> {
        if x.len() != self.dim {
            return Err(TurboQuantError::DimensionMismatch {
                expected: self.dim,
                got: x.len(),
            });
        }
        validate_unit_vector(x, "TurboQuantMSE input")?;

        // Step 1: Rotate
        let y = self.rotation.rotate(x)?;

        // Step 2: Scalar quantize each coordinate
        let indices = self.quantizer.quantize_batch(&y);

        Ok(QuantizedVector {
            indices,
            bit_width: self.bit_width,
            dim: self.dim,
        })
    }

    /// Reconstruct an approximate vector from its quantized representation.
    ///
    /// Algorithm:
    ///   1. For each idx_j: ỹ_j = c_{idx_j}  (centroid lookup)
    ///   2. x̃ = Πᵀ · ỹ  (inverse rotation)
    ///
    /// # Errors
    /// Fails if `q`'s dimension, bit width, index count, or index values do not
    /// match this quantizer, or if the inverse rotation fails.
    pub fn dequantize(&self, q: &QuantizedVector) -> Result<Vec<f64>> {
        self.dequantize_parts(&q.indices, q.bit_width, q.dim)
    }

    /// Theoretical distortion bound from the TurboQuant paper.
    ///
    /// Δ_MSE(b) = √3 · π/2 · (1/4)^b, where b = `bit_width`.
    ///
    /// NOTE(review): the paper states this as an upper bound on the *total*
    /// squared error E‖x − x̃‖² for a unit-norm `x`, not on a per-coordinate
    /// error. [`Self::actual_mse`] divides by `dim`, so the comparable quantity
    /// is `actual_mse(x) * dim as f64`. Confirm against the paper before relying
    /// on this comparison.
    pub fn distortion_bound(&self) -> f64 {
        3.0_f64.sqrt() * std::f64::consts::PI / 2.0 * 0.25_f64.powi(i32::from(self.bit_width))
    }

    /// Measure the actual per-coordinate MSE, ‖x − x̃‖² / dim, on a single vector
    /// (for benchmarking).
    ///
    /// # Errors
    /// Any error from [`Self::quantize`] or [`Self::dequantize`].
    pub fn actual_mse(&self, x: &[f64]) -> Result<f64> {
        let q = self.quantize(x)?;
        let x_approx = self.dequantize(&q)?;
        let mse = sum_squared_error(ExecutionBackend::default(), x, &x_approx) / self.dim as f64;
        Ok(mse)
    }

    /// Access the codebook (for sharing with other components).
    pub fn codebook(&self) -> &Codebook {
        &self.quantizer.codebook
    }

    /// Access the rotation. Only read outside this module by GPU code paths;
    /// the lint allowance keeps non-GPU builds warning-free.
    #[cfg_attr(not(feature = "gpu"), allow(dead_code))]
    pub(crate) fn rotation(&self) -> &RandomRotation {
        &self.rotation
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::utils::normalize;

    /// Deterministic pseudo-random unit vector: i.i.d. N(0, 1) draws from an RNG
    /// seeded with `seed`, then normalized.
    fn random_unit_vector(dim: usize, seed: u64) -> Vec<f64> {
        use rand::SeedableRng;
        use rand_distr::{Distribution, Normal};
        let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
        let normal = Normal::new(0.0, 1.0).unwrap();
        let x: Vec<f64> = (0..dim).map(|_| normal.sample(&mut rng)).collect();
        normalize(&x).unwrap()
    }

    #[test]
    fn test_quantize_dequantize_shape() {
        let tq = TurboQuantMSE::new(64, 4, 42).unwrap();
        let x = random_unit_vector(64, 1);
        let q = tq.quantize(&x).unwrap();
        assert_eq!(q.indices.len(), 64);
        assert_eq!(q.bit_width, 4);
        assert_eq!(q.dim, 64);
        let x_approx = tq.dequantize(&q).unwrap();
        assert_eq!(x_approx.len(), 64);
    }

    #[test]
    fn test_quantize_is_deterministic() {
        let tq = TurboQuantMSE::new(32, 3, 7).unwrap();
        let x = random_unit_vector(32, 5);
        let first = tq.quantize(&x).unwrap();
        let second = tq.quantize(&x).unwrap();
        assert_eq!(first.indices, second.indices);
    }

    #[test]
    fn test_mse_decreases_with_bits() {
        let dim = 256;
        let x = random_unit_vector(dim, 999);
        let mse2 = TurboQuantMSE::new(dim, 2, 42)
            .unwrap()
            .actual_mse(&x)
            .unwrap();
        let mse4 = TurboQuantMSE::new(dim, 4, 42)
            .unwrap()
            .actual_mse(&x)
            .unwrap();
        assert!(
            mse4 < mse2,
            "4-bit MSE {} should < 2-bit MSE {}",
            mse4,
            mse2
        );
    }

    #[test]
    fn test_distortion_bound() {
        let tq = TurboQuantMSE::new(512, 4, 42).unwrap();
        let bound = tq.distortion_bound();
        // For b=4: √3·π/2 / 4^4 = 2.72070.../256 ≈ 0.010628
        let expected = 3.0_f64.sqrt() * std::f64::consts::PI / 2.0 / 256.0;
        assert!(
            (bound - expected).abs() < 1e-12,
            "bound={} expected={}",
            bound,
            expected
        );
        assert!((bound - 0.010628).abs() < 1e-5, "bound={}", bound);
    }

    #[test]
    fn test_dimension_mismatch() {
        let tq = TurboQuantMSE::new(64, 2, 1).unwrap();
        let x = vec![0.0; 128];
        assert!(matches!(
            tq.quantize(&x),
            Err(TurboQuantError::DimensionMismatch {
                expected: 64,
                got: 128
            })
        ));
    }

    #[test]
    fn test_compression_ratio() {
        let tq = TurboQuantMSE::new(128, 4, 1).unwrap();
        let x = random_unit_vector(128, 2);
        let q = tq.quantize(&x).unwrap();
        // 128 indices × 4 bits = 64 bytes.
        assert!((q.bytes() - 64.0).abs() < 1e-12, "bytes={}", q.bytes());
        // 4-bit quantization → 8x compression over f32
        assert!((q.compression_ratio() - 8.0).abs() < 0.01);
    }

    #[test]
    fn test_with_codebook() {
        let dim = 64;
        let tq_orig = TurboQuantMSE::new(dim, 4, 42).unwrap();
        let codebook = tq_orig.codebook().clone();

        // Create a new quantizer sharing the same codebook but different seed
        let tq_shared = TurboQuantMSE::with_codebook(dim, codebook, 99).unwrap();
        assert_eq!(tq_shared.dim, dim);
        assert_eq!(tq_shared.bit_width, 4);

        // Should produce valid quantization
        let x = random_unit_vector(dim, 1);
        let q = tq_shared.quantize(&x).unwrap();
        assert_eq!(q.indices.len(), dim);
        let recon = tq_shared.dequantize(&q).unwrap();
        assert_eq!(recon.len(), dim);
    }

    #[test]
    fn test_dequantize_dimension_mismatch() {
        let tq = TurboQuantMSE::new(64, 4, 42).unwrap();
        let wrong_q = QuantizedVector {
            indices: vec![0; 128],
            bit_width: 4,
            dim: 128,
        };
        assert!(matches!(
            tq.dequantize(&wrong_q),
            Err(TurboQuantError::DimensionMismatch {
                expected: 64,
                got: 128
            })
        ));
    }

    #[test]
    fn test_dequantize_rejects_bit_width_mismatch() {
        let tq = TurboQuantMSE::new(8, 2, 1).unwrap();
        let q = QuantizedVector {
            indices: vec![0; 8],
            bit_width: 3,
            dim: 8,
        };
        assert!(matches!(
            tq.dequantize(&q),
            Err(TurboQuantError::BitWidthMismatch {
                expected: 2,
                got: 3
            })
        ));
    }

    #[test]
    fn test_dequantize_rejects_wrong_index_count() {
        let tq = TurboQuantMSE::new(8, 2, 1).unwrap();
        let q = QuantizedVector {
            indices: vec![0; 4],
            bit_width: 2,
            dim: 8,
        };
        assert!(matches!(
            tq.dequantize(&q),
            Err(TurboQuantError::LengthMismatch {
                expected: 8,
                got: 4,
                ..
            })
        ));
    }

    #[test]
    fn test_invalid_dimension_zero() {
        assert!(matches!(
            TurboQuantMSE::new(0, 4, 1),
            Err(TurboQuantError::InvalidDimension(0))
        ));
    }

    #[test]
    fn test_invalid_bit_width() {
        assert!(matches!(
            TurboQuantMSE::new(64, 0, 1),
            Err(TurboQuantError::InvalidBitWidth(0))
        ));
        assert!(matches!(
            TurboQuantMSE::new(64, 9, 1),
            Err(TurboQuantError::InvalidBitWidth(9))
        ));
    }

    #[test]
    fn test_quantize_rejects_non_unit_vector() {
        let tq = TurboQuantMSE::new(8, 4, 1).unwrap();
        let x = vec![1.0; 8];
        assert!(matches!(
            tq.quantize(&x),
            Err(TurboQuantError::NotUnitVector(_))
        ));
    }

    #[test]
    fn test_dequantize_rejects_invalid_index() {
        let tq = TurboQuantMSE::new(8, 2, 1).unwrap();
        let q = QuantizedVector {
            indices: vec![4; 8],
            bit_width: 2,
            dim: 8,
        };
        assert!(matches!(
            tq.dequantize(&q),
            Err(TurboQuantError::InvalidQuantizationIndex { .. })
        ));
    }
}