// turboquant 0.1.1
//
// Implementation of Google's TurboQuant algorithm for vector quantization.
// Documentation
use nalgebra::{DMatrix, DVector};
use rand::SeedableRng;
use rand_distr::{Distribution, Normal};
use serde::{Deserialize, Serialize};

use crate::error::{Result, TurboQuantError};

/// A random orthogonal matrix Π, generated via QR decomposition of a random
/// Gaussian matrix. An orthogonal transform preserves norms and is the key
/// step in TurboQuant to decorrelate and uniformly distribute coordinates
/// before scalar quantization.
///
/// Both the matrix and its cached transpose are part of the derived
/// `Serialize`/`Deserialize` representation, so a serialized value carries
/// two d×d matrices. The derives do not check that the three fields agree
/// with each other; only [`RandomRotation::new`] is known to build a
/// consistent value.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RandomRotation {
    /// The d×d orthogonal matrix Π (stored column-major via nalgebra).
    matrix: DMatrix<f64>,
    /// Transpose Πᵀ for the inverse rotation (cached for efficiency).
    matrix_t: DMatrix<f64>,
    /// Dimension d: the side length of `matrix` and the slice length that
    /// `rotate` / `rotate_inverse` accept. The field is public, so callers
    /// can change it after construction; doing so is not checked against the
    /// stored matrices.
    pub dim: usize,
}

impl RandomRotation {
    /// Build a seeded random orthogonal matrix Π of size `dim`×`dim`.
    ///
    /// Algorithm:
    /// 1. Fill a d×d matrix G with i.i.d. N(0,1) entries drawn from a
    ///    `StdRng` seeded with `seed`.
    /// 2. Compute the QR decomposition G = QR.
    /// 3. Negate every column of Q whose matching diagonal entry of R is
    ///    negative and use the result as Π.
    ///
    /// Step 3 removes the sign ambiguity of the QR factorisation (it selects
    /// the factorisation whose R has a non-negative diagonal). It does not
    /// constrain det(Π): the result is orthogonal, but it may be either a
    /// proper rotation or a reflection. Both preserve norms.
    ///
    /// # Errors
    ///
    /// Returns [`TurboQuantError::InvalidDimension`] when `dim` is zero or
    /// when `dim * dim` does not fit in a `usize`.
    pub fn new(dim: usize, seed: u64) -> Result<Self> {
        if dim == 0 {
            return Err(TurboQuantError::InvalidDimension(dim));
        }
        // An unchecked `dim * dim` can wrap in release builds and hand
        // nalgebra a buffer whose length disagrees with the requested shape.
        let entry_count = dim
            .checked_mul(dim)
            .ok_or(TurboQuantError::InvalidDimension(dim))?;

        let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
        let normal = Normal::new(0.0, 1.0).expect("a standard deviation of 1.0 is always valid");

        // Fill d×d Gaussian matrix (column-major: data[col*dim + row])
        let data: Vec<f64> = (0..entry_count).map(|_| normal.sample(&mut rng)).collect();
        let g = DMatrix::from_vec(dim, dim, data);

        // QR decomposition via nalgebra's Householder QR
        let qr = g.qr();
        let r = qr.r();
        let mut matrix = qr.q();

        // Right-multiplying Q by diag(±1) only negates selected columns, so
        // do that in place (O(d²)) instead of forming a dense diagonal matrix
        // and paying for a full O(d³) matrix product.
        for col in 0..dim {
            if r[(col, col)] < 0.0 {
                matrix.column_mut(col).neg_mut();
            }
        }
        let matrix_t = matrix.transpose();

        Ok(Self {
            matrix,
            matrix_t,
            dim,
        })
    }

    /// Check that a caller-supplied slice length equals `self.dim`.
    fn ensure_len(&self, got: usize) -> Result<()> {
        if got == self.dim {
            Ok(())
        } else {
            Err(TurboQuantError::DimensionMismatch {
                expected: self.dim,
                got,
            })
        }
    }

    /// Apply the rotation: y = Π · x
    ///
    /// The input slice is copied into a temporary vector before the
    /// matrix-vector product; the result is returned as a fresh `Vec`.
    ///
    /// # Errors
    ///
    /// Returns [`TurboQuantError::DimensionMismatch`] when `x.len()` differs
    /// from `self.dim`.
    pub fn rotate(&self, x: &[f64]) -> Result<Vec<f64>> {
        self.ensure_len(x.len())?;
        let xv = DVector::from_vec(x.to_vec());
        let yv = &self.matrix * xv;
        Ok(yv.data.into())
    }

    /// Apply the inverse rotation: x = Πᵀ · y
    ///
    /// Uses the cached transpose, which is the inverse of Π because Π is
    /// orthogonal.
    ///
    /// # Errors
    ///
    /// Returns [`TurboQuantError::DimensionMismatch`] when `y.len()` differs
    /// from `self.dim`.
    pub fn rotate_inverse(&self, y: &[f64]) -> Result<Vec<f64>> {
        self.ensure_len(y.len())?;
        let yv = DVector::from_vec(y.to_vec());
        let xv = &self.matrix_t * yv;
        Ok(xv.data.into())
    }

    /// Access the raw matrix (for testing / inspection).
    pub fn matrix(&self) -> &DMatrix<f64> {
        &self.matrix
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    /// Π · Πᵀ must match the identity entry by entry (tolerance 1e-10).
    #[test]
    fn test_rotation_orthogonality() {
        let dim = 8;
        let rot = RandomRotation::new(dim, 42).unwrap();
        let gram = &rot.matrix * &rot.matrix_t;
        let cells = (0..dim).flat_map(|row| (0..dim).map(move |col| (row, col)));
        for (row, col) in cells {
            let target = if row == col { 1.0 } else { 0.0 };
            assert_abs_diff_eq!(gram[(row, col)], target, epsilon = 1e-10);
        }
    }

    /// Rotating and then inverse-rotating must give back the input vector.
    #[test]
    fn test_rotation_inverse() {
        let dim = 16;
        let rot = RandomRotation::new(dim, 123).unwrap();
        let original: Vec<f64> = (0..dim).map(|i| i as f64 / dim as f64).collect();
        let round_trip = rot
            .rotate(&original)
            .and_then(|rotated| rot.rotate_inverse(&rotated))
            .unwrap();
        for (want, got) in original.iter().zip(round_trip.iter()) {
            assert_abs_diff_eq!(want, got, epsilon = 1e-10);
        }
    }

    /// The Euclidean norm of a vector is unchanged by the rotation.
    #[test]
    fn test_rotation_preserves_norm() {
        use crate::utils::norm;
        let dim = 32;
        let rot = RandomRotation::new(dim, 7).unwrap();
        let input: Vec<f64> = (0..dim).map(|i| (i as f64).sin()).collect();
        let output = rot.rotate(&input).unwrap();
        assert_abs_diff_eq!(norm(&input), norm(&output), epsilon = 1e-8);
    }

    /// `rotate` rejects a slice whose length differs from the dimension.
    #[test]
    fn test_dimension_mismatch() {
        let rot = RandomRotation::new(4, 1).unwrap();
        // Two entries against a 4-dimensional rotation.
        let outcome = rot.rotate(&[1.0, 2.0]);
        assert!(outcome.is_err());
    }

    /// A zero dimension is refused at construction time.
    #[test]
    fn test_invalid_dimension_zero() {
        let outcome = RandomRotation::new(0, 1);
        assert!(outcome.is_err());
    }

    /// `rotate_inverse` rejects a slice whose length differs from the dimension.
    #[test]
    fn test_rotate_inverse_dimension_mismatch() {
        let rot = RandomRotation::new(4, 1).unwrap();
        // Three entries against a 4-dimensional rotation.
        let outcome = rot.rotate_inverse(&[1.0, 2.0, 3.0]);
        assert!(outcome.is_err());
    }
}