use nalgebra::{DMatrix, DVector};
use rand::SeedableRng;
use rand_distr::{Distribution, Normal};
use serde::{Deserialize, Serialize};
use crate::error::{Result, TurboQuantError};
use crate::utils::{norm, validate_finite_vector};
/// One-bit sign sketch of a vector, as produced by [`QJL::quantize`].
///
/// Stores one sign per row of the projection together with the norm of the
/// original input, which is needed to rescale reconstructions and
/// inner-product estimates.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[must_use]
pub struct QJLQuantized {
    /// `true` where the corresponding projected coordinate was `>= 0.0`,
    /// `false` where it was negative.
    pub signs: Vec<bool>,
    /// Norm of the quantized input, as returned by `crate::utils::norm`
    /// (presumably Euclidean — confirm). Must be finite and non-negative to be
    /// accepted by [`QJL::dequantize`] / [`QJL::estimate_inner_product`].
    pub residual_norm: f64,
    /// Dimension of the [`QJL`] projector that produced this sketch.
    pub dim: usize,
}
impl QJLQuantized {
    /// Storage cost of this sketch in bits.
    ///
    /// Counts one bit per sign plus a fixed 32-bit budget for the residual norm.
    ///
    /// NOTE(review): `residual_norm` is an `f64` in memory; the 32-bit charge
    /// presumably assumes it is serialized as `f32` — confirm against the
    /// on-disk format.
    pub fn bits(&self) -> usize {
        const NORM_BITS: usize = 32;
        let sign_bits = self.signs.len();
        sign_bits + NORM_BITS
    }
}
/// Quantized Johnson–Lindenstrauss (QJL) sketcher with a fixed random projection.
///
/// [`QJL::new`] draws a `dim × dim` matrix `S` of i.i.d. standard-normal
/// entries from a seeded RNG. A vector `x` is encoded as `sign(S x)` plus its
/// norm.
///
/// NOTE(review): the derived `Deserialize` does not verify that `projection`
/// is `dim × dim` or that `projection_t` is its transpose; deserialized
/// instances are trusted as-is.
#[derive(Debug, Serialize, Deserialize)]
pub struct QJL {
    /// Projection matrix `S`, applied to inputs in `quantize` and to queries
    /// in `estimate_inner_product`.
    projection: DMatrix<f64>,
    /// Cached transpose `Sᵀ`, used to map sign vectors back in `dequantize`.
    projection_t: DMatrix<f64>,
    /// Input dimension; also the number of sign bits per sketch.
    pub dim: usize,
}
impl QJL {
    /// Factor `residual_norm · √(π/2) / dim` applied to sign sums.
    ///
    /// For a row `s` with i.i.d. standard-normal entries,
    /// `E[sign(⟨s, x⟩) · ⟨s, y⟩] = √(2/π) · ⟨x, y⟩ / ‖x‖₂`. Averaging over the
    /// `dim` rows and multiplying by `‖x‖₂ · √(π/2)` therefore gives an
    /// estimate of `⟨x, y⟩` whose expectation over the projection is exact,
    /// provided `residual_norm` is the Euclidean norm of `x`.
    fn estimator_scale(&self, residual_norm: f64) -> f64 {
        residual_norm * (std::f64::consts::PI / 2.0).sqrt() / self.dim as f64
    }

    /// Checks that a sketch, given as borrowed parts, is compatible with this
    /// projector. No allocation is performed.
    ///
    /// # Errors
    ///
    /// * [`TurboQuantError::DimensionMismatch`] if `dim != self.dim`.
    /// * [`TurboQuantError::LengthMismatch`] if `signs.len() != self.dim`.
    /// * [`TurboQuantError::InvalidValue`] if `residual_norm` is negative,
    ///   NaN, or infinite.
    fn validate_parts(&self, signs: &[bool], residual_norm: f64, dim: usize) -> Result<()> {
        if dim != self.dim {
            return Err(TurboQuantError::DimensionMismatch {
                expected: self.dim,
                got: dim,
            });
        }
        if signs.len() != self.dim {
            return Err(TurboQuantError::LengthMismatch {
                context: "QJL sign sketch".into(),
                expected: self.dim,
                got: signs.len(),
            });
        }
        if !residual_norm.is_finite() || residual_norm < 0.0 {
            return Err(TurboQuantError::InvalidValue {
                context: "QJL residual norm".into(),
                value: residual_norm,
            });
        }
        Ok(())
    }

    /// Reconstructs an approximation of the original vector from sketch parts.
    ///
    /// Computes `estimator_scale(residual_norm) · Sᵀ z`, where `z` is `+1.0`
    /// for `true` signs and `-1.0` for `false` signs.
    ///
    /// # Errors
    ///
    /// Returns the validation errors described on [`Self::validate_parts`].
    pub(crate) fn dequantize_parts(
        &self,
        signs: &[bool],
        residual_norm: f64,
        dim: usize,
    ) -> Result<Vec<f64>> {
        self.validate_parts(signs, residual_norm, dim)?;
        let z = DVector::from_iterator(
            self.dim,
            signs.iter().map(|&s| if s { 1.0 } else { -1.0 }),
        );
        let recon = &self.projection_t * z;
        let scale = self.estimator_scale(residual_norm);
        Ok(recon.iter().map(|&v| v * scale).collect())
    }

    /// Estimates `⟨x, y⟩` from the sketch parts of `x` and the full vector `y`.
    ///
    /// Computes `estimator_scale(residual_norm) · Σᵢ ±(S y)ᵢ`, with the sign
    /// taken from `signs[i]`. No dense reconstruction of `x` is built.
    ///
    /// # Errors
    ///
    /// * [`TurboQuantError::DimensionMismatch`] if `y.len() != self.dim`.
    /// * Whatever `validate_finite_vector` reports for a non-finite `y`.
    /// * The validation errors described on [`Self::validate_parts`].
    pub(crate) fn estimate_inner_product_parts(
        &self,
        signs: &[bool],
        residual_norm: f64,
        dim: usize,
        y: &[f64],
    ) -> Result<f64> {
        if y.len() != self.dim {
            return Err(TurboQuantError::DimensionMismatch {
                expected: self.dim,
                got: y.len(),
            });
        }
        validate_finite_vector(y, "QJL query")?;
        self.validate_parts(signs, residual_norm, dim)?;
        let sy = &self.projection * DVector::from_vec(y.to_vec());
        let dot: f64 = signs
            .iter()
            .zip(sy.iter())
            .map(|(&sign, &syi)| if sign { syi } else { -syi })
            .sum();
        Ok(self.estimator_scale(residual_norm) * dot)
    }

    /// Creates a projector for `dim`-dimensional vectors.
    ///
    /// The `dim × dim` projection has i.i.d. standard-normal entries drawn
    /// from a `StdRng` seeded with `seed`, so equal `(dim, seed)` pairs yield
    /// equal projections (for a fixed `rand` version).
    ///
    /// # Errors
    ///
    /// Returns [`TurboQuantError::InvalidDimension`] if `dim` is zero, or if
    /// `dim * dim` does not fit in `usize`.
    pub fn new(dim: usize, seed: u64) -> Result<Self> {
        if dim == 0 {
            return Err(TurboQuantError::InvalidDimension(dim));
        }
        // An unchecked `dim * dim` wraps in release builds. nalgebra's own
        // `nrows * ncols == len` check wraps the same way, so it would not
        // catch the mismatch between storage length and matrix shape.
        let len = dim
            .checked_mul(dim)
            .ok_or(TurboQuantError::InvalidDimension(dim))?;
        let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
        let normal =
            Normal::new(0.0, 1.0).expect("N(0, 1) has a finite, positive standard deviation");
        let data: Vec<f64> = (0..len).map(|_| normal.sample(&mut rng)).collect();
        let projection = DMatrix::from_vec(dim, dim, data);
        let projection_t = projection.transpose();
        Ok(Self {
            projection,
            projection_t,
            dim,
        })
    }

    /// Encodes `x` as `sign(S x)` plus `norm(x)`.
    ///
    /// A projected coordinate that is exactly `0.0` is recorded as `true`
    /// (treated as `+1`).
    ///
    /// # Errors
    ///
    /// * [`TurboQuantError::DimensionMismatch`] if `x.len() != self.dim`.
    /// * Whatever `validate_finite_vector` reports for a non-finite `x`.
    pub fn quantize(&self, x: &[f64]) -> Result<QJLQuantized> {
        if x.len() != self.dim {
            return Err(TurboQuantError::DimensionMismatch {
                expected: self.dim,
                got: x.len(),
            });
        }
        validate_finite_vector(x, "QJL input")?;
        let sx = &self.projection * DVector::from_vec(x.to_vec());
        let signs: Vec<bool> = sx.iter().map(|&v| v >= 0.0).collect();
        Ok(QJLQuantized {
            signs,
            residual_norm: norm(x),
            dim: self.dim,
        })
    }

    /// Reconstructs an approximation of the vector that produced `q`.
    ///
    /// # Errors
    ///
    /// Fails if `q` is incompatible with this projector (wrong `dim`, wrong
    /// sign count, or an invalid residual norm).
    pub fn dequantize(&self, q: &QJLQuantized) -> Result<Vec<f64>> {
        self.dequantize_parts(&q.signs, q.residual_norm, q.dim)
    }

    /// Estimates the inner product between the vector sketched in `q` and `y`.
    ///
    /// # Errors
    ///
    /// Fails if `y` has the wrong length or is non-finite, or if `q` is
    /// incompatible with this projector.
    pub fn estimate_inner_product(&self, q: &QJLQuantized, y: &[f64]) -> Result<f64> {
        self.estimate_inner_product_parts(&q.signs, q.residual_norm, q.dim, y)
    }

    /// The projection matrix `S`.
    pub fn projection(&self) -> &DMatrix<f64> {
        &self.projection
    }

    /// The cached transpose `Sᵀ` of the projection matrix.
    pub fn projection_t(&self) -> &DMatrix<f64> {
        &self.projection_t
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::utils::{inner_product, normalize};

    /// Deterministic pseudo-random unit vector of length `dim`.
    fn random_unit_vector(dim: usize, seed: u64) -> Vec<f64> {
        use rand::SeedableRng;
        use rand_distr::{Distribution, Normal};
        let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
        let normal = Normal::new(0.0, 1.0).unwrap();
        let x: Vec<f64> = (0..dim).map(|_| normal.sample(&mut rng)).collect();
        normalize(&x).unwrap()
    }

    #[test]
    fn test_quantize_output_shape() {
        let qjl = QJL::new(32, 42).unwrap();
        let x = random_unit_vector(32, 1);
        let q = qjl.quantize(&x).unwrap();
        assert_eq!(q.signs.len(), 32);
        assert!((q.residual_norm - 1.0).abs() < 1e-10);
    }

    /// Bounded-error sanity check: the mean absolute estimation error over a
    /// few random unit-vector pairs stays small. (This does not test bias.)
    #[test]
    fn test_inner_product_estimation_error_small() {
        let dim = 512;
        let qjl = QJL::new(dim, 0).unwrap();
        let mut total_err = 0.0;
        let n_trials = 10;
        for seed in 0..n_trials {
            let x = random_unit_vector(dim, seed);
            let y = random_unit_vector(dim, seed + 100);
            let true_ip = inner_product(&x, &y);
            let q = qjl.quantize(&x).unwrap();
            let est_ip = qjl.estimate_inner_product(&q, &y).unwrap();
            total_err += (true_ip - est_ip).abs();
        }
        let avg_err = total_err / n_trials as f64;
        assert!(
            avg_err < 0.3,
            "average inner product error {} too large",
            avg_err
        );
    }

    #[test]
    fn test_estimate_matches_dequantized_inner_product() {
        // ⟨c·Sᵀz, y⟩ = c·zᵀ(S y), so both paths must agree up to rounding.
        let dim = 64;
        let qjl = QJL::new(dim, 7).unwrap();
        let x = random_unit_vector(dim, 11);
        let y = random_unit_vector(dim, 12);
        let q = qjl.quantize(&x).unwrap();
        let via_estimate = qjl.estimate_inner_product(&q, &y).unwrap();
        let via_recon = inner_product(&qjl.dequantize(&q).unwrap(), &y);
        assert!(
            (via_estimate - via_recon).abs() < 1e-9,
            "estimate {} vs reconstruction {}",
            via_estimate,
            via_recon
        );
    }

    #[test]
    fn test_dequantize_shape() {
        let dim = 16;
        let qjl = QJL::new(dim, 1).unwrap();
        let x = random_unit_vector(dim, 5);
        let q = qjl.quantize(&x).unwrap();
        let recon = qjl.dequantize(&q).unwrap();
        assert_eq!(recon.len(), dim);
    }

    #[test]
    fn test_zero_vector_round_trip() {
        let qjl = QJL::new(8, 3).unwrap();
        let q = qjl.quantize(&[0.0; 8]).unwrap();
        assert_eq!(q.residual_norm, 0.0);
        assert!(qjl.dequantize(&q).unwrap().iter().all(|&v| v == 0.0));
    }

    #[test]
    fn test_same_seed_same_projection() {
        let a = QJL::new(8, 5).unwrap();
        let b = QJL::new(8, 5).unwrap();
        assert_eq!(a.projection(), b.projection());
        assert_eq!(a.projection_t(), &a.projection().transpose());
    }

    #[test]
    fn test_dimension_mismatch() {
        let qjl = QJL::new(16, 1).unwrap();
        let x = vec![0.0; 32];
        assert!(qjl.quantize(&x).is_err());
    }

    #[test]
    fn test_invalid_dimension_zero() {
        assert!(QJL::new(0, 1).is_err());
    }

    #[test]
    fn test_dequantize_dimension_mismatch() {
        let qjl = QJL::new(16, 1).unwrap();
        let bad_q = QJLQuantized {
            signs: vec![true; 32],
            residual_norm: 1.0,
            dim: 32,
        };
        assert!(qjl.dequantize(&bad_q).is_err());
    }

    #[test]
    fn test_estimate_inner_product_dimension_mismatch() {
        let qjl = QJL::new(16, 1).unwrap();
        let x = random_unit_vector(16, 1);
        let q = qjl.quantize(&x).unwrap();
        let bad_y = vec![0.0; 32];
        assert!(qjl.estimate_inner_product(&q, &bad_y).is_err());
        let bad_q = QJLQuantized {
            signs: vec![true; 32],
            residual_norm: 1.0,
            dim: 32,
        };
        let y = random_unit_vector(16, 2);
        assert!(qjl.estimate_inner_product(&bad_q, &y).is_err());
    }

    #[test]
    fn test_qjl_bits_count() {
        let qjl = QJL::new(64, 1).unwrap();
        let x = random_unit_vector(64, 1);
        let q = qjl.quantize(&x).unwrap();
        assert_eq!(q.bits(), 96);
    }

    #[test]
    fn test_dequantize_rejects_wrong_sign_count() {
        let qjl = QJL::new(8, 1).unwrap();
        let q = QJLQuantized {
            signs: vec![true; 7],
            residual_norm: 1.0,
            dim: 8,
        };
        assert!(matches!(
            qjl.dequantize(&q),
            Err(TurboQuantError::LengthMismatch { .. })
        ));
    }

    #[test]
    fn test_rejects_invalid_residual_norm() {
        let qjl = QJL::new(8, 1).unwrap();
        let y = random_unit_vector(8, 3);
        for bad in [-1.0, f64::NAN, f64::INFINITY] {
            let q = QJLQuantized {
                signs: vec![true; 8],
                residual_norm: bad,
                dim: 8,
            };
            assert!(matches!(
                qjl.dequantize(&q),
                Err(TurboQuantError::InvalidValue { .. })
            ));
            assert!(matches!(
                qjl.estimate_inner_product(&q, &y),
                Err(TurboQuantError::InvalidValue { .. })
            ));
        }
    }
}