//! turboquant 0.1.1
//!
//! Implementation of Google's TurboQuant algorithm for vector quantization.
use nalgebra::{DMatrix, DVector};
use rand::SeedableRng;
use rand_distr::{Distribution, Normal};
use serde::{Deserialize, Serialize};

use crate::error::{Result, TurboQuantError};
use crate::utils::{norm, validate_finite_vector};

/// Quantized Johnson-Lindenstrauss (QJL) transform output.
///
/// Produced by `QJL::quantize`; consumed by `QJL::dequantize` and
/// `QJL::estimate_inner_product`.
///
/// Stores:
///   - `signs`: the sign bits of S·x, one per projection dimension
///   - `residual_norm`: ‖x‖₂ (or ‖r‖₂ for residuals), needed for reconstruction
///   - `dim`: original vector dimension
#[derive(Debug, Clone, Serialize, Deserialize)]
#[must_use]
pub struct QJLQuantized {
    /// Sign of each projection: true = +1, false = -1.
    pub signs: Vec<bool>,
    /// L2 norm of the quantized vector (γ = ‖x‖₂).
    /// Expected to be finite and non-negative; `QJL` validates this on use.
    pub residual_norm: f64,
    /// Original vector dimension; must match the length of `signs`.
    pub dim: usize,
}

impl QJLQuantized {
    /// Total sketch size in bits: 32 bits for the stored norm plus one
    /// bit per projection sign.
    /// (NOTE(review): `residual_norm` is held as f64 in memory; the 32-bit
    /// figure presumably reflects the serialized f32 budget — confirm.)
    pub fn bits(&self) -> usize {
        32 + self.signs.len()
    }
}

/// Quantized Johnson-Lindenstrauss (QJL) transform.
///
/// A 1-bit sketch for inner product estimation. Given a Gaussian random
/// matrix S ∈ ℝ^{d×d}, the QJL quantization of x is:
///
///   q(x) = sign(S · x)  ∈ {-1, +1}^d
///
/// Inner products can be estimated as:
///
///   <x, y> ≈ √(π/2) · <q(x), S·y> / d  ·  ‖x‖₂
///
/// (the √(π/2) factor matches the scale actually applied in
/// `dequantize` and `estimate_inner_product`).
///
/// This provides a 1-bit-per-coordinate sketch that is optimal for
/// inner product estimation on the unit sphere.
#[derive(Debug, Serialize, Deserialize)]
pub struct QJL {
    /// The d×d Gaussian projection matrix S.
    projection: DMatrix<f64>,
    /// Transpose Sᵀ (cached for reconstruction efficiency).
    projection_t: DMatrix<f64>,
    /// Vector dimension d (both the input length and the sketch length).
    pub dim: usize,
}

impl QJL {
    /// Validate raw sketch parts against this sketcher, borrowing the sign
    /// slice directly (no temporary `QJLQuantized` allocation).
    ///
    /// # Errors
    /// - `DimensionMismatch` if `dim` differs from the sketcher's dimension
    /// - `LengthMismatch` if `signs.len() != self.dim`
    /// - `InvalidValue` if the stored norm is non-finite or negative
    fn validate_parts(&self, signs: &[bool], residual_norm: f64, dim: usize) -> Result<()> {
        if dim != self.dim {
            return Err(TurboQuantError::DimensionMismatch {
                expected: self.dim,
                got: dim,
            });
        }
        if signs.len() != self.dim {
            return Err(TurboQuantError::LengthMismatch {
                context: "QJL sign sketch".into(),
                expected: self.dim,
                got: signs.len(),
            });
        }
        if !residual_norm.is_finite() || residual_norm < 0.0 {
            return Err(TurboQuantError::InvalidValue {
                context: "QJL residual norm".into(),
                value: residual_norm,
            });
        }

        Ok(())
    }

    /// Shared reconstruction/estimation scale: γ · √(π/2) / d.
    fn scale(&self, residual_norm: f64) -> f64 {
        residual_norm * (std::f64::consts::PI / 2.0).sqrt() / self.dim as f64
    }

    /// Reconstruction from raw sketch parts; see [`Self::dequantize`].
    pub(crate) fn dequantize_parts(
        &self,
        signs: &[bool],
        residual_norm: f64,
        dim: usize,
    ) -> Result<Vec<f64>> {
        // Validate on the borrowed parts; previously the signs were cloned
        // into a temporary `QJLQuantized` just to run validation.
        self.validate_parts(signs, residual_norm, dim)?;

        // z ∈ {-1, +1}^d from the sign bits.
        let zv = DVector::from_iterator(
            self.dim,
            signs.iter().map(|&s| if s { 1.0 } else { -1.0 }),
        );
        let recon = &self.projection_t * zv;

        let scale = self.scale(residual_norm);
        Ok(recon.iter().map(|&v| v * scale).collect())
    }

    /// Inner product estimation from raw sketch parts; see
    /// [`Self::estimate_inner_product`].
    pub(crate) fn estimate_inner_product_parts(
        &self,
        signs: &[bool],
        residual_norm: f64,
        dim: usize,
        y: &[f64],
    ) -> Result<f64> {
        if y.len() != self.dim {
            return Err(TurboQuantError::DimensionMismatch {
                expected: self.dim,
                got: y.len(),
            });
        }
        validate_finite_vector(y, "QJL query")?;
        self.validate_parts(signs, residual_norm, dim)?;

        // ⟨z, S·y⟩ where z_j = ±1 according to the sign bits.
        let yv = DVector::from_vec(y.to_vec());
        let sy = &self.projection * yv;
        let dot: f64 = signs
            .iter()
            .zip(sy.iter())
            .map(|(&sign, &syi)| if sign { syi } else { -syi })
            .sum();

        Ok(self.scale(residual_norm) * dot)
    }

    /// Create a new QJL sketcher for dimension `dim`.
    ///
    /// The projection matrix S has i.i.d. N(0,1) entries (not normalized;
    /// the normalization is absorbed into the estimation formula).
    ///
    /// # Errors
    /// Returns `InvalidDimension` if `dim == 0`.
    pub fn new(dim: usize, seed: u64) -> Result<Self> {
        if dim == 0 {
            return Err(TurboQuantError::InvalidDimension(dim));
        }

        let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
        // Mean 0, std-dev 1 are always valid Normal parameters.
        let normal = Normal::new(0.0, 1.0).expect("standard normal parameters are valid");

        let data: Vec<f64> = (0..dim * dim).map(|_| normal.sample(&mut rng)).collect();
        let projection = DMatrix::from_vec(dim, dim, data);
        let projection_t = projection.transpose();

        Ok(Self {
            projection,
            projection_t,
            dim,
        })
    }

    /// Compute the QJL sketch: q = sign(S · x).
    ///
    /// # Errors
    /// Fails if `x` has the wrong length or contains non-finite values.
    ///
    /// # Returns
    /// `QJLQuantized` containing sign bits and ‖x‖₂.
    pub fn quantize(&self, x: &[f64]) -> Result<QJLQuantized> {
        if x.len() != self.dim {
            return Err(TurboQuantError::DimensionMismatch {
                expected: self.dim,
                got: x.len(),
            });
        }
        validate_finite_vector(x, "QJL input")?;

        let xv = DVector::from_vec(x.to_vec());
        let sx = &self.projection * xv;

        // sign(S·x): true = +1, false = -1 (ties at 0.0 map to +1)
        let signs: Vec<bool> = sx.iter().map(|&v| v >= 0.0).collect();
        let residual_norm = norm(x);

        Ok(QJLQuantized {
            signs,
            residual_norm,
            dim: self.dim,
        })
    }

    /// Approximate reconstruction from the QJL sketch.
    ///
    /// x̃ = γ · √(π/2) / d · Sᵀ · z
    ///
    /// where z_j = +1 if signs[j] else -1, and γ = ‖x‖₂.
    ///
    /// This is not a perfect reconstruction but provides an approximation
    /// useful for inner product estimation.
    ///
    /// # Errors
    /// Fails if the sketch does not match this sketcher's dimension or
    /// carries an invalid norm.
    pub fn dequantize(&self, q: &QJLQuantized) -> Result<Vec<f64>> {
        self.dequantize_parts(&q.signs, q.residual_norm, q.dim)
    }

    /// Estimate ⟨x, y⟩ from the QJL sketch of x and query vector y.
    ///
    /// Estimate = γ · √(π/2) / d · ⟨z, S·y⟩
    ///          = γ · √(π/2) / d · Σ_j z_j · (S·y)_j
    ///
    /// For unit vectors this is unbiased: E[estimate] = ⟨x, y⟩.
    ///
    /// # Errors
    /// Fails if `y` has the wrong length or non-finite entries, or if the
    /// sketch is invalid for this sketcher.
    pub fn estimate_inner_product(&self, q: &QJLQuantized, y: &[f64]) -> Result<f64> {
        self.estimate_inner_product_parts(&q.signs, q.residual_norm, q.dim, y)
    }

    /// Access the projection matrix (for internal use by TurboQuantProd).
    pub fn projection(&self) -> &DMatrix<f64> {
        &self.projection
    }

    /// Access the cached transpose Sᵀ (for internal use by TurboQuantProd).
    pub fn projection_t(&self) -> &DMatrix<f64> {
        &self.projection_t
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::utils::{inner_product, normalize};

    /// Deterministic point on the unit sphere: seeded Gaussian draw, then
    /// normalized to unit length.
    fn random_unit_vector(dim: usize, seed: u64) -> Vec<f64> {
        use rand::SeedableRng;
        use rand_distr::{Distribution, Normal};
        let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
        let gauss = Normal::new(0.0, 1.0).unwrap();
        let raw: Vec<f64> = (0..dim).map(|_| gauss.sample(&mut rng)).collect();
        normalize(&raw).unwrap()
    }

    #[test]
    fn test_quantize_output_shape() {
        let qjl = QJL::new(32, 42).unwrap();
        let sketch = qjl.quantize(&random_unit_vector(32, 1)).unwrap();
        assert_eq!(sketch.signs.len(), 32);
        // Input is unit-norm, so the stored norm must be ~1.
        assert!((sketch.residual_norm - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_inner_product_estimation_unbiased() {
        // At dim = 512, averaged over several trials, the sign-sketch
        // estimate should land close to the exact inner product.
        let dim = 512;
        let qjl = QJL::new(dim, 0).unwrap();
        let n_trials = 10;

        let mut total_err = 0.0;
        for seed in 0..n_trials {
            let x = random_unit_vector(dim, seed);
            let y = random_unit_vector(dim, seed + 100);
            let sketch = qjl.quantize(&x).unwrap();
            let est = qjl.estimate_inner_product(&sketch, &y).unwrap();
            total_err += (inner_product(&x, &y) - est).abs();
        }

        let avg_err = total_err / n_trials as f64;
        assert!(
            avg_err < 0.3,
            "average inner product error {} too large",
            avg_err
        );
    }

    #[test]
    fn test_dequantize_shape() {
        let dim = 16;
        let qjl = QJL::new(dim, 1).unwrap();
        let sketch = qjl.quantize(&random_unit_vector(dim, 5)).unwrap();
        assert_eq!(qjl.dequantize(&sketch).unwrap().len(), dim);
    }

    #[test]
    fn test_dimension_mismatch() {
        // A 32-dim input fed to a 16-dim sketcher must be rejected.
        let qjl = QJL::new(16, 1).unwrap();
        assert!(qjl.quantize(&[0.0; 32]).is_err());
    }

    #[test]
    fn test_invalid_dimension_zero() {
        assert!(QJL::new(0, 1).is_err());
    }

    #[test]
    fn test_dequantize_dimension_mismatch() {
        let qjl = QJL::new(16, 1).unwrap();
        let bad = QJLQuantized {
            signs: vec![true; 32],
            residual_norm: 1.0,
            dim: 32,
        };
        assert!(qjl.dequantize(&bad).is_err());
    }

    #[test]
    fn test_estimate_inner_product_dimension_mismatch() {
        let qjl = QJL::new(16, 1).unwrap();
        let sketch = qjl.quantize(&random_unit_vector(16, 1)).unwrap();

        // Query vector of the wrong length.
        assert!(qjl.estimate_inner_product(&sketch, &[0.0; 32]).is_err());

        // Sketch claiming the wrong dimension.
        let bad = QJLQuantized {
            signs: vec![true; 32],
            residual_norm: 1.0,
            dim: 32,
        };
        let y = random_unit_vector(16, 2);
        assert!(qjl.estimate_inner_product(&bad, &y).is_err());
    }

    #[test]
    fn test_qjl_bits_count() {
        let qjl = QJL::new(64, 1).unwrap();
        let sketch = qjl.quantize(&random_unit_vector(64, 1)).unwrap();
        // 64 sign bits plus 32 bits for the norm.
        assert_eq!(sketch.bits(), 96);
    }

    #[test]
    fn test_dequantize_rejects_wrong_sign_count() {
        // dim field is right but the sign vector is short by one.
        let qjl = QJL::new(8, 1).unwrap();
        let bad = QJLQuantized {
            signs: vec![true; 7],
            residual_norm: 1.0,
            dim: 8,
        };
        assert!(matches!(
            qjl.dequantize(&bad),
            Err(TurboQuantError::LengthMismatch { .. })
        ));
    }
}