//! TurboQuant inner-product-optimal vector quantization (crate `turboquant`, v0.1.1).
//!
//! Implementation of Google's TurboQuant algorithm for vector quantization.
use serde::{Deserialize, Serialize};

use crate::error::{Result, TurboQuantError};
use crate::qjl::QJL;
use crate::turboquant_mse::TurboQuantMSE;
use crate::utils::{inner_product, norm, validate_unit_vector};

/// Two-stage quantized representation for inner-product-optimal TurboQuant.
///
/// Stage 1: TurboQuantMSE with (b-1) bits → x̃_mse, residual r = x - x̃_mse
/// Stage 2: QJL (1 bit) on the residual r → γ·sign(S·r)
/// Two-stage quantized representation for inner-product-optimal TurboQuant.
///
/// Stage 1: TurboQuantMSE with (b-1) bits → x̃_mse, residual r = x - x̃_mse
/// Stage 2: QJL (1 bit) on the residual r → γ·sign(S·r)
///
/// Produced by [`TurboQuantProd::quantize`]; consumed by
/// [`TurboQuantProd::dequantize`] and [`TurboQuantProd::estimate_inner_product`],
/// which validate the invariants noted on the fields below.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[must_use]
pub struct ProdQuantized {
    /// Codebook indices from Stage 1 (MSE quantization with b-1 bits).
    /// One entry per coordinate; length must equal `dim`.
    pub mse_indices: Vec<u8>,
    /// Sign bits from Stage 2 (QJL quantization of residual).
    /// One entry per coordinate; length must equal `dim`.
    pub qjl_signs: Vec<bool>,
    /// L2 norm of the residual (γ = ‖r‖₂). Must be finite and non-negative.
    pub residual_norm: f64,
    /// Total bit width (Stage 1 uses bit_width-1 bits, Stage 2 uses 1 bit).
    pub bit_width: u8,
    /// Original vector dimension.
    pub dim: usize,
}

impl ProdQuantized {
    /// Total storage in bytes.
    ///
    /// Counts (bit_width-1) bits per coordinate for Stage 1, one sign bit
    /// per coordinate for Stage 2, plus 32 bits for the residual norm
    /// (stored as an f32), then converts the bit total to bytes.
    pub fn bytes(&self) -> f64 {
        let stage1_bits = (self.bit_width - 1) as f64 * self.mse_indices.len() as f64;
        let stage2_bits = self.qjl_signs.len() as f64;
        let norm_bits = 32.0;
        (stage1_bits + stage2_bits + norm_bits) / 8.0
    }

    /// Compression ratio vs f32 storage.
    ///
    /// Baseline is `dim` coordinates at 4 bytes (f32) each.
    pub fn compression_ratio(&self) -> f64 {
        let uncompressed_bytes = 4.0 * self.dim as f64;
        uncompressed_bytes / self.bytes()
    }
}

/// TurboQuant inner-product-optimal quantizer (two-stage).
///
/// Combines TurboQuantMSE (Stage 1, b-1 bits) with QJL (Stage 2, 1 bit)
/// to achieve near-optimal inner product estimation.
///
/// Theoretical distortion bound for inner product estimation:
///   Δ_IP(b, ‖y‖) = √3 · π² · ‖y‖² / (d · 4^b)
///
/// This is tighter than pure MSE quantization for inner product tasks.
#[derive(Debug, Serialize, Deserialize)]
pub struct TurboQuantProd {
    /// Stage 1: MSE-optimal quantizer using (b-1) bits.
    mse_stage: TurboQuantMSE,
    /// Stage 2: QJL 1-bit quantizer for the residual.
    qjl_stage: QJL,
    /// Vector dimension this quantizer accepts and reconstructs.
    pub dim: usize,
    /// Total bits (Stage 1: bit_width-1, Stage 2: 1).
    pub bit_width: u8,
}

impl TurboQuantProd {
    /// Verify that `q` is structurally compatible with this quantizer:
    /// matching dimension and bit width, per-coordinate payload lengths,
    /// and a finite, non-negative residual norm.
    fn validate_quantized(&self, q: &ProdQuantized) -> Result<()> {
        if q.dim != self.dim {
            return Err(TurboQuantError::DimensionMismatch {
                expected: self.dim,
                got: q.dim,
            });
        }
        if q.bit_width != self.bit_width {
            return Err(TurboQuantError::BitWidthMismatch {
                expected: self.bit_width,
                got: q.bit_width,
            });
        }
        if q.mse_indices.len() != self.dim {
            return Err(TurboQuantError::LengthMismatch {
                context: "TurboQuantProd MSE indices".into(),
                expected: self.dim,
                got: q.mse_indices.len(),
            });
        }
        if q.qjl_signs.len() != self.dim {
            return Err(TurboQuantError::LengthMismatch {
                context: "TurboQuantProd QJL signs".into(),
                expected: self.dim,
                got: q.qjl_signs.len(),
            });
        }
        if !q.residual_norm.is_finite() || q.residual_norm < 0.0 {
            return Err(TurboQuantError::InvalidValue {
                context: "TurboQuantProd residual norm".into(),
                value: q.residual_norm,
            });
        }

        Ok(())
    }

    /// Create a new TurboQuant inner-product-optimal quantizer.
    ///
    /// # Arguments
    /// * `dim` - Vector dimension (must be non-zero)
    /// * `bit_width` - Total bits per coordinate (must be in 2..=8; Stage 1 uses bit_width-1, Stage 2 uses 1)
    /// * `seed` - Random seed (Stage 1 uses `seed`, Stage 2 uses `seed + 1`, wrapping on overflow)
    ///
    /// # Errors
    /// * `InvalidDimension` if `dim == 0`
    /// * `InvalidBitWidth` if `bit_width` is outside `2..=8`
    /// * Any error propagated from the stage constructors
    pub fn new(dim: usize, bit_width: u8, seed: u64) -> Result<Self> {
        if dim == 0 {
            return Err(TurboQuantError::InvalidDimension(dim));
        }
        if !(2..=8).contains(&bit_width) {
            return Err(TurboQuantError::InvalidBitWidth(bit_width));
        }

        // Stage 1 gets all but one bit; Stage 2 always uses exactly 1 bit.
        let mse_bits = bit_width - 1;
        let mse_stage = TurboQuantMSE::new(dim, mse_bits, seed)?;
        // wrapping_add avoids a debug-build overflow panic when seed == u64::MAX
        // while preserving the documented "seed + 1" behavior for all other seeds.
        let qjl_stage = QJL::new(dim, seed.wrapping_add(1))?;

        Ok(Self {
            mse_stage,
            qjl_stage,
            dim,
            bit_width,
        })
    }

    /// Quantize a unit vector using the two-stage approach.
    ///
    /// Algorithm:
    ///   1. q_mse = TurboQuantMSE.quantize(x)  [with b-1 bits]
    ///   2. x̃_mse = TurboQuantMSE.dequantize(q_mse)
    ///   3. r = x - x̃_mse  (residual)
    ///   4. γ = ‖r‖₂
    ///   5. q_qjl = sign(S · r)  [QJL 1-bit sketch of residual]
    ///   6. Return (q_mse, q_qjl, γ)
    ///
    /// # Errors
    /// * `DimensionMismatch` if `x.len() != dim`
    /// * An error from `validate_unit_vector` if `x` is not a unit vector
    /// * Any error propagated from the stage quantizers
    pub fn quantize(&self, x: &[f64]) -> Result<ProdQuantized> {
        if x.len() != self.dim {
            return Err(TurboQuantError::DimensionMismatch {
                expected: self.dim,
                got: x.len(),
            });
        }
        validate_unit_vector(x, "TurboQuantProd input")?;

        // Stage 1: MSE quantization
        let q_mse = self.mse_stage.quantize(x)?;
        let x_approx = self.mse_stage.dequantize(&q_mse)?;

        // Residual r = x - x̃_mse
        let residual: Vec<f64> = x
            .iter()
            .zip(x_approx.iter())
            .map(|(xi, xi_a)| xi - xi_a)
            .collect();
        let residual_norm = norm(&residual);

        // Stage 2: QJL 1-bit sketch of the residual
        let q_qjl = self.qjl_stage.quantize(&residual)?;

        Ok(ProdQuantized {
            mse_indices: q_mse.indices,
            qjl_signs: q_qjl.signs,
            residual_norm,
            bit_width: self.bit_width,
            dim: self.dim,
        })
    }

    /// Reconstruct an approximate vector from the two-stage quantized representation.
    ///
    /// Algorithm:
    ///   1. x̃_mse = TurboQuantMSE.dequantize(mse_indices)
    ///   2. x̃_qjl = γ · √(π/2) / d · Sᵀ · z  (QJL reconstruction of residual)
    ///   3. Return x̃_mse + x̃_qjl
    ///
    /// # Errors
    /// Fails with the same structural errors as [`Self::validate_quantized`],
    /// plus anything propagated from the stage dequantizers.
    pub fn dequantize(&self, q: &ProdQuantized) -> Result<Vec<f64>> {
        self.validate_quantized(q)?;

        let x_mse = self
            .mse_stage
            .dequantize_parts(&q.mse_indices, q.bit_width - 1, q.dim)?;
        let x_qjl = self
            .qjl_stage
            .dequantize_parts(&q.qjl_signs, q.residual_norm, q.dim)?;

        // Combine the coarse reconstruction with the residual correction.
        let result: Vec<f64> = x_mse.iter().zip(x_qjl.iter()).map(|(a, b)| a + b).collect();
        Ok(result)
    }

    /// Estimate ⟨x, y⟩ from the quantized x and exact query y.
    ///
    /// Estimate = ⟨x̃_mse, y⟩ + γ · √(π/2) / d · ⟨z, S·y⟩
    ///
    /// The first term handles the bulk of the signal; the second term
    /// corrects for the residual, improving inner product accuracy.
    ///
    /// # Errors
    /// * `DimensionMismatch` if `y.len() != dim`
    /// * The structural errors from [`Self::validate_quantized`]
    /// * Anything propagated from the stage estimators
    pub fn estimate_inner_product(&self, q: &ProdQuantized, y: &[f64]) -> Result<f64> {
        if y.len() != self.dim {
            return Err(TurboQuantError::DimensionMismatch {
                expected: self.dim,
                got: y.len(),
            });
        }
        self.validate_quantized(q)?;

        let x_mse = self
            .mse_stage
            .dequantize_parts(&q.mse_indices, q.bit_width - 1, q.dim)?;
        let ip_mse = inner_product(&x_mse, y);

        let ip_qjl =
            self.qjl_stage
                .estimate_inner_product_parts(&q.qjl_signs, q.residual_norm, q.dim, y)?;

        Ok(ip_mse + ip_qjl)
    }

    /// Theoretical inner product distortion bound.
    ///
    /// E[|⟨x, y⟩ - ⟨x̃, y⟩|²] ≤ √3 · π² · ‖y‖² / (d · 4^b)
    pub fn distortion_bound(&self, y_norm: f64) -> f64 {
        // bit_width ≤ 8 by construction, so the i32 conversion is lossless;
        // powi is exact and cheaper than powf for integer exponents.
        let b = i32::from(self.bit_width);
        let d = self.dim as f64;
        3.0_f64.sqrt() * std::f64::consts::PI.powi(2) * y_norm * y_norm / (d * 4.0_f64.powi(b))
    }

    /// Access the stage-1 MSE quantizer used inside the product quantizer.
    pub fn mse_stage(&self) -> &TurboQuantMSE {
        &self.mse_stage
    }

    /// Access the stage-2 QJL sketch used for residual correction.
    pub fn qjl_stage(&self) -> &QJL {
        &self.qjl_stage
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::utils::{inner_product, normalize};

    /// Draw an i.i.d. Gaussian vector and project it onto the unit sphere.
    fn random_unit_vector(dim: usize, seed: u64) -> Vec<f64> {
        use rand::SeedableRng;
        use rand_distr::{Distribution, Normal};
        let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
        let gauss = Normal::new(0.0, 1.0).unwrap();
        let raw: Vec<f64> = (0..dim).map(|_| gauss.sample(&mut rng)).collect();
        normalize(&raw).unwrap()
    }

    #[test]
    fn test_quantize_dequantize_shape() {
        let quantizer = TurboQuantProd::new(64, 3, 42).unwrap();
        let input = random_unit_vector(64, 1);
        let code = quantizer.quantize(&input).unwrap();
        assert_eq!(code.mse_indices.len(), 64);
        assert_eq!(code.qjl_signs.len(), 64);
        let decoded = quantizer.dequantize(&code).unwrap();
        assert_eq!(decoded.len(), 64);
    }

    #[test]
    fn test_inner_product_estimate_quality() {
        let dim = 256;
        let quantizer = TurboQuantProd::new(dim, 4, 7).unwrap();

        let n: u64 = 10;
        let total_err: f64 = (0..n)
            .map(|seed| {
                let x = random_unit_vector(dim, seed);
                let y = random_unit_vector(dim, seed + 50);
                let exact = inner_product(&x, &y);
                let code = quantizer.quantize(&x).unwrap();
                let approx = quantizer.estimate_inner_product(&code, &y).unwrap();
                (exact - approx).abs()
            })
            .sum();
        let avg_err = total_err / n as f64;
        assert!(
            avg_err < 0.15,
            "avg inner product error {} too large",
            avg_err
        );
    }

    #[test]
    fn test_invalid_bit_width() {
        // Constructor rejects widths outside 2..=8.
        assert!(TurboQuantProd::new(64, 1, 1).is_err()); // must be >= 2
        assert!(TurboQuantProd::new(64, 9, 1).is_err()); // must be <= 8
    }

    #[test]
    fn test_dimension_mismatch_quantize() {
        let quantizer = TurboQuantProd::new(64, 3, 1).unwrap();
        let wrong_dim_input = vec![0.0; 128]; // wrong dim
        assert!(quantizer.quantize(&wrong_dim_input).is_err());
    }

    #[test]
    fn test_dimension_mismatch_estimate_ip() {
        let quantizer = TurboQuantProd::new(64, 3, 1).unwrap();
        let input = random_unit_vector(64, 1);
        let code = quantizer.quantize(&input).unwrap();
        let mismatched_query = vec![0.0; 128];
        assert!(quantizer.estimate_inner_product(&code, &mismatched_query).is_err());
    }

    #[test]
    fn test_distortion_bound() {
        let quantizer = TurboQuantProd::new(128, 4, 1).unwrap();
        let base = quantizer.distortion_bound(1.0);
        assert!(base > 0.0 && base < 1.0, "bound={}", base);
        // Doubling ‖y‖ must quadruple the bound (quadratic scaling).
        let doubled = quantizer.distortion_bound(2.0);
        assert!(
            (doubled / base - 4.0).abs() < 1e-10,
            "bound should scale as y_norm²"
        );
    }

    #[test]
    fn test_dequantize_dimension_mismatch() {
        let quantizer = TurboQuantProd::new(64, 3, 1).unwrap();
        let oversized = ProdQuantized {
            mse_indices: vec![0; 128],
            qjl_signs: vec![true; 128],
            residual_norm: 0.1,
            bit_width: 3,
            dim: 128,
        };
        assert!(quantizer.dequantize(&oversized).is_err());
    }

    #[test]
    fn test_prod_improves_inner_product_estimate() {
        // TurboQuantProd should have better inner product estimates than MSE-only
        let dim = 128;
        let n_trials: u64 = 20;
        let mse_only = TurboQuantMSE::new(dim, 3, 1).unwrap();
        let two_stage = TurboQuantProd::new(dim, 3, 1).unwrap();

        let mut mse_ip_err = 0.0;
        let mut prod_ip_err = 0.0;

        for seed in 0..n_trials {
            let x = random_unit_vector(dim, seed);
            let y = random_unit_vector(dim, seed + 100);
            let true_ip = inner_product(&x, &y);

            // MSE-only inner product via reconstruction
            let mse_code = mse_only.quantize(&x).unwrap();
            let mse_recon = mse_only.dequantize(&mse_code).unwrap();
            mse_ip_err += (true_ip - inner_product(&mse_recon, &y)).abs();

            // Prod two-stage inner product estimate
            let prod_code = two_stage.quantize(&x).unwrap();
            let prod_est = two_stage.estimate_inner_product(&prod_code, &y).unwrap();
            prod_ip_err += (true_ip - prod_est).abs();
        }

        let avg_mse_err = mse_ip_err / n_trials as f64;
        let avg_prod_err = prod_ip_err / n_trials as f64;
        // Both should be reasonable; the key check is they are finite and small
        assert!(avg_mse_err < 0.5, "MSE IP error too large: {}", avg_mse_err);
        assert!(
            avg_prod_err < 0.5,
            "Prod IP error too large: {}",
            avg_prod_err
        );
    }

    #[test]
    fn test_quantize_rejects_non_unit_vector() {
        let quantizer = TurboQuantProd::new(8, 3, 1).unwrap();
        let not_unit = vec![1.0; 8];
        assert!(matches!(
            quantizer.quantize(&not_unit),
            Err(TurboQuantError::NotUnitVector(_))
        ));
    }

    #[test]
    fn test_dequantize_rejects_wrong_sign_count() {
        let quantizer = TurboQuantProd::new(8, 3, 1).unwrap();
        let short_signs = ProdQuantized {
            mse_indices: vec![0; 8],
            qjl_signs: vec![true; 7],
            residual_norm: 0.1,
            bit_width: 3,
            dim: 8,
        };
        assert!(matches!(
            quantizer.dequantize(&short_signs),
            Err(TurboQuantError::LengthMismatch { .. })
        ));
    }
}