use ndarray::{Array1, Array2, Axis};
use rand::{Rng, SeedableRng};
use rand_chacha::ChaCha8Rng;
use rand_distr::StandardNormal;
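/// sqrt(pi / 2): for g ~ N(0, I), E[|<g, x>|] = sqrt(2 / pi) * ||x||, so this
/// factor undoes the shrinkage introduced by keeping only the sign.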
const QJL_CONST: f64 = 1.2533141373155003;
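/// Quantized Johnson-Lindenstrauss (QJL) sketch: a residual vector is passed
/// through a random Gaussian projection and stored as one sign bit per
/// coordinate plus its Euclidean norm; dequantization back-projects the signs
/// and rescales them to an unbiased estimate of the original vector.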
pub struct QJL {
    /// Dimension of the residual vectors; the projection matrix is `d x d`.
    pub d: usize,
    /// Random Gaussian projection matrix with i.i.d. N(0, 1) entries.
    s: Array2<f64>,
}
impl QJL {
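    /// Builds a `d x d` projection with i.i.d. standard normal entries,
    /// deterministically derived from `seed`.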
    pub fn new(d: usize, seed: u64) -> Self {
        // Seeded ChaCha8 makes the projection reproducible across runs, so
        // quantized data can be dequantized by any QJL built from the same seed.
        let mut rng = ChaCha8Rng::seed_from_u64(seed);
        let s = Array2::from_shape_fn((d, d), |_| rng.sample::<f64, _>(StandardNormal));
        Self { d, s }
    }
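    /// Quantizes a single residual to sign bits plus its norm.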
    pub fn quantize(&self, r: &Array1<f64>) -> (Array1<i8>, f64) {
        // Treat the vector as a batch of one and reuse the batch path.
        let (signs, norms) = self.quantize_batch(&r.clone().insert_axis(Axis(0)));
        (signs.row(0).to_owned(), norms[0])
    }
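    /// Quantizes a batch of residuals (one per row) with a single projection.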
    pub fn quantize_batch(&self, r: &Array2<f64>) -> (Array2<i8>, Array1<f64>) {
        // Per-row Euclidean norms, stored alongside the sign bits so the
        // magnitude discarded by 1-bit quantization can be restored later.
        let norms = r.map_axis(Axis(1), |row| row.dot(&row).sqrt());
        // Project every residual through S and keep only the sign of each
        // coordinate; exact zeros map to +1.
        let projected = r.dot(&self.s.t());
        let signs = projected.mapv(|v| if v >= 0.0 { 1i8 } else { -1i8 });
        (signs, norms)
    }
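    /// Reconstructs a single residual from its sign bits and stored norm.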
    pub fn dequantize(&self, signs: &Array1<i8>, norm: f64) -> Array1<f64> {
        // Batch-of-one wrapper around dequantize_batch.
        let signs_2d = signs.clone().insert_axis(Axis(0));
        let norms = Array1::from_vec(vec![norm]);
        self.dequantize_batch(&signs_2d, &norms).row(0).to_owned()
    }
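    /// Reconstructs a batch of residuals by back-projecting the sign bits
    /// through `S` and rescaling each row with its stored norm.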
    pub fn dequantize_batch(&self, signs: &Array2<i8>, norms: &Array1<f64>) -> Array2<f64> {
        // Back-project the sign bits: row i becomes S^T sign(S r_i).
        let mut reconstructed = signs.mapv(|s| s as f64).dot(&self.s);
        // E[S^T sign(S r)] = d * sqrt(2/pi) * r / ||r||, so scaling each row by
        // (sqrt(pi/2) / d) * ||r|| yields an unbiased estimate of r.
        for (i, mut row) in reconstructed.outer_iter_mut().enumerate() {
            let scale = (QJL_CONST / self.d as f64) * norms[i];
            row.mapv_inplace(|v| v * scale);
        }
        reconstructed
    }
}
#[cfg(test)]
mod tests {
    use super::*;
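    /// Deterministic, smoothly varying batch used as test input.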
    fn synthetic_batch(batch: usize, d: usize) -> Array2<f64> {
        Array2::from_shape_fn((batch, d), |(i, j)| {
            let t = (i * 17 + j * 5) as f64;
            (t / d as f64).sin() + 0.2 * (t / 9.0).cos()
        })
    }
    #[test]
    fn test_sign_quantization() {
        let qjl = QJL::new(32, 42);
        let r = Array1::from_shape_fn(32, |i| (i as f64 - 16.0) / 10.0);
        let (signs, norm) = qjl.quantize(&r);
        for &s in signs.iter() {
            assert!(s == 1 || s == -1, "Sign should be +1 or -1, got {s}");
        }
        assert!(norm > 0.0);
    }
    #[test]
    fn test_zero_residual() {
        let qjl = QJL::new(16, 42);
        // A zero residual has norm 0, so dequantization must return exactly
        // zero regardless of which signs the projection produced.
        let r = Array1::zeros(16);
        let (signs, norm) = qjl.quantize(&r);
        let r_hat = qjl.dequantize(&signs, norm);
        let error: f64 = r_hat.mapv(|v| v * v).sum();
        assert!(error < 1e-20, "Zero residual reconstruction error: {error}");
    }
    #[test]
    fn test_inner_product_preservation() {
        let d = 256;
        let qjl = QJL::new(d, 42);
        let batch = synthetic_batch(64, d);
        let (signs, norms) = qjl.quantize_batch(&batch);
        let batch_hat = qjl.dequantize_batch(&signs, &norms);
        let mut total_rel_err = 0.0;
        let mut n_pairs = 0usize;
        for i in 0..32 {
            let j = (i * 7 + 3) % 64;
            let x = batch.row(i);
            let y = batch.row(j);
            let x_hat = batch_hat.row(i);
            let y_hat = batch_hat.row(j);
            let ip_original = x.dot(&y);
            let ip_approx = x_hat.dot(&y_hat);
            // Guard against division by a near-zero true inner product.
            let denom = ip_original.abs().max(1e-8);
            total_rel_err += (ip_original - ip_approx).abs() / denom;
            n_pairs += 1;
        }
        let relative_error = total_rel_err / n_pairs as f64;
        // The tolerance is deliberately loose: 1-bit sign quantization is a
        // coarse code, and pairs with near-zero true inner products inflate
        // the relative error.
        assert!(
            relative_error < 2.5,
            "Mean inner product relative error too large: {relative_error}"
        );
    }
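    // Calibration check for the estimator: E[<r_hat, r>] = ||r||^2, so the
    // ratio <r_hat, r> / ||r||^2 should sit near 1. The deviation concentrates
    // at O(1/sqrt(d)), so the 0.5 tolerance leaves a wide margin.
    #[test]
    fn test_unbiased_reconstruction() {
        let d = 256;
        let qjl = QJL::new(d, 7);
        let r = Array1::from_shape_fn(d, |i| (i as f64 * 0.1).sin() + 0.5);
        let (signs, norm) = qjl.quantize(&r);
        let r_hat = qjl.dequantize(&signs, norm);
        let ratio = r_hat.dot(&r) / r.dot(&r);
        assert!(
            (ratio - 1.0).abs() < 0.5,
            "Reconstruction calibration off: ratio = {ratio}"
        );
    }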
}