use rand::rngs::StdRng;
use rand::SeedableRng;
use rand_distr::{Distribution, StandardNormal};
/// Quantized Johnson-Lindenstrauss (QJL) projector for 1-bit key compression.
///
/// A key `k` is stored as the signs of `S k` (one bit per row, packed into
/// `u64` words) plus `||k||_2`, where `S` is a `dim x dim` matrix of i.i.d.
/// standard-normal entries. An inner product `<q, k>` is then estimated as
/// `sqrt(pi/2) / dim * ||k|| * sum_i (S q)_i * sign((S k)_i)`.
///
/// `S` is never stored: [`compress`](Self::compress),
/// [`inner_product`](Self::inner_product) and
/// [`project_query`](Self::project_query) all regenerate it from `seed` via
/// `project_query`, so they always agree on the matrix.
///
/// NOTE(review): the `rand` docs state that `StdRng` output is not guaranteed
/// to be reproducible across `rand` releases. If compressed keys are persisted
/// beyond one build, confirm the `rand` version is pinned.
#[derive(Debug, Clone)]
pub struct QjlProjector {
    /// Seed from which the projection matrix is regenerated on each call.
    seed: u64,
    /// Input dimensionality; also the number of projected rows (and therefore
    /// sign bits) per compressed key.
    dim: usize,
}
impl QjlProjector {
    /// Creates a projector for `dim`-dimensional vectors whose projection
    /// matrix is derived from `seed`.
    pub fn new(seed: u64, dim: usize) -> Self {
        Self { seed, dim }
    }

    /// Compresses `x` into the packed sign bits of `S x` plus `||x||_2`.
    ///
    /// Bit `i` (bit `i % 64` of word `i / 64`) is set when `(S x)_i >= 0`.
    /// Padding bits past `dim` in the last word are left zero.
    ///
    /// # Panics
    ///
    /// Panics if `x.len() != self.dim()`.
    pub fn compress(&self, x: &[f32]) -> (Vec<u64>, f32) {
        assert_eq!(x.len(), self.dim);
        let mut bits = vec![0u64; self.packed_words()];
        // Reuse `project_query` so keys and queries are always projected by
        // exactly the same matrix, drawn in exactly the same order.
        for (i, &row) in self.project_query(x).iter().enumerate() {
            if row >= 0.0 {
                bits[i / 64] |= 1u64 << (i % 64);
            }
        }
        (bits, l2_norm(x))
    }

    /// Estimates `<query, key>` from a key previously produced by
    /// [`compress`](Self::compress), regenerating the projection of `query`.
    ///
    /// Returns `0.0` when `dim == 0` (the empty inner product). Without this
    /// guard the estimator scale would be infinite and the result `NaN`.
    ///
    /// # Panics
    ///
    /// Panics if `query.len() != self.dim()` or if `key_bits` holds fewer than
    /// [`packed_words`](Self::packed_words) words.
    pub fn inner_product(&self, query: &[f32], key_bits: &[u64], key_norm: f32) -> f32 {
        assert_eq!(query.len(), self.dim);
        self.assert_key_bits_len(key_bits);
        if self.dim == 0 {
            return 0.0;
        }
        // Accumulate (S q)_i * (+1 / -1) sequentially, row 0 first.
        let signed_sum = self
            .project_query(query)
            .iter()
            .enumerate()
            .fold(0.0f32, |acc, (i, &row)| {
                acc + if Self::sign_bit(key_bits, i) { row } else { -row }
            });
        self.scale_estimate(signed_sum, key_norm)
    }

    /// Projects `query` through `S`, returning the `dim` values `(S query)_i`.
    ///
    /// The result can be handed to
    /// [`inner_product_fast`](Self::inner_product_fast), which does not
    /// regenerate `S`, so one projection can be reused across many keys.
    ///
    /// # Panics
    ///
    /// Panics if `query.len() != self.dim()`.
    pub fn project_query(&self, query: &[f32]) -> Vec<f32> {
        assert_eq!(query.len(), self.dim);
        let mut rng = StdRng::seed_from_u64(self.seed);
        // Row-major draw order: all `dim` entries of row 0, then row 1, ...
        // Changing this order changes the matrix and invalidates stored keys.
        (0..self.dim)
            .map(|_| {
                query.iter().fold(0.0f32, |acc, &q_j| {
                    let s_ij: f32 = StandardNormal.sample(&mut rng);
                    acc + s_ij * q_j
                })
            })
            .collect()
    }

    /// Like [`inner_product`](Self::inner_product), but takes a query already
    /// projected by [`project_query`](Self::project_query).
    ///
    /// The signed sum is delegated to
    /// `crate::backend::cpu::simd::dot_with_sign_bits_fast`. That kernel is
    /// expected to match the scalar sum used by `inner_product` (the
    /// `test_fast_matches_slow` unit test checks this). Returns `0.0` when
    /// `dim == 0`.
    ///
    /// # Panics
    ///
    /// Panics if `projected_query.len() != self.dim()` or if `key_bits` holds
    /// fewer than [`packed_words`](Self::packed_words) words.
    pub fn inner_product_fast(
        &self,
        projected_query: &[f32],
        key_bits: &[u64],
        key_norm: f32,
    ) -> f32 {
        assert_eq!(projected_query.len(), self.dim);
        // Checked here because the kernel's own bounds handling is not
        // visible from this module.
        self.assert_key_bits_len(key_bits);
        if self.dim == 0 {
            return 0.0;
        }
        let signed_sum =
            crate::backend::cpu::simd::dot_with_sign_bits_fast(projected_query, key_bits, self.dim);
        self.scale_estimate(signed_sum, key_norm)
    }

    /// Number of `u64` words needed to hold `dim` sign bits (`ceil(dim / 64)`).
    pub fn packed_words(&self) -> usize {
        (self.dim + 63) / 64
    }

    /// Input dimensionality this projector was built for.
    #[inline]
    pub fn dim(&self) -> usize {
        self.dim
    }

    /// Seed from which the projection matrix is regenerated.
    #[inline]
    pub fn seed(&self) -> u64 {
        self.seed
    }

    /// Applies the QJL scale `sqrt(pi/2) / dim * ||k||` to a signed sum.
    ///
    /// Callers must handle `dim == 0` first, because the scale is infinite there.
    fn scale_estimate(&self, signed_sum: f32, key_norm: f32) -> f32 {
        let coeff = std::f32::consts::FRAC_PI_2.sqrt() / (self.dim as f32);
        coeff * key_norm * signed_sum
    }

    /// Returns whether packed sign bit `i` (bit `i % 64` of word `i / 64`) is set.
    #[inline]
    fn sign_bit(bits: &[u64], i: usize) -> bool {
        (bits[i / 64] >> (i % 64)) & 1 == 1
    }

    /// Panics with a descriptive message if `key_bits` is too short for `dim`.
    fn assert_key_bits_len(&self, key_bits: &[u64]) {
        assert!(
            key_bits.len() >= self.packed_words(),
            "key_bits has {} words but dim {} needs {}",
            key_bits.len(),
            self.dim,
            self.packed_words()
        );
    }
}
/// Signed sum of the first `count` entries of `values`, where each entry is
/// added if its sign bit is set and subtracted otherwise.
///
/// Bits are read least-significant first: entry `i` uses bit `i % 64` of
/// `bits[i / 64]`. Entries are accumulated sequentially from index 0 upward.
///
/// # Panics
///
/// Panics if `values.len() < count`, or if `bits` holds fewer than
/// `ceil(count / 64)` words.
#[inline]
pub fn dot_with_sign_bits(values: &[f32], bits: &[u64], count: usize) -> f32 {
    values[..count]
        .iter()
        .enumerate()
        .fold(0.0f32, |acc, (idx, &value)| {
            let positive = (bits[idx / 64] >> (idx % 64)) & 1 == 1;
            acc + if positive { value } else { -value }
        })
}
/// Euclidean (L2) norm of `x`: the square root of the sum of squared entries.
///
/// No rescaling is applied, so entries whose magnitude exceeds
/// `sqrt(f32::MAX)` (about 1.8e19) overflow the sum to infinity.
fn l2_norm(x: &[f32]) -> f32 {
    let sum_of_squares: f32 = x.iter().map(|v| v * v).sum();
    sum_of_squares.sqrt()
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_compress_produces_correct_shape() {
        let proj = QjlProjector::new(42, 128);
        let x: Vec<f32> = (0..128).map(|i| (i as f32 - 64.0) * 0.01).collect();
        let (bits, norm) = proj.compress(&x);
        assert_eq!(bits.len(), 2);
        assert_eq!(bits.len(), proj.packed_words());
        assert!(norm > 0.0);
    }

    #[test]
    fn test_compress_is_deterministic_and_leaves_padding_clear() {
        // dim 70 = one full word plus 6 bits, so bits 6..64 of word 1 are padding.
        let proj = QjlProjector::new(7, 70);
        let x: Vec<f32> = (0..70).map(|i| ((i * 37) % 11) as f32 - 5.0).collect();
        let first = proj.compress(&x);
        let second = proj.compress(&x);
        assert_eq!(first, second, "same seed and input must give the same key");
        assert_eq!(first.0.len(), 2);
        assert_eq!(first.0[1] >> 6, 0, "bits past dim must stay zero");
    }

    #[test]
    fn test_inner_product_single_estimate_in_ballpark() {
        let dim = 64;
        let proj = QjlProjector::new(42, dim);
        let x: Vec<f32> = (0..dim).map(|i| (i as f32 - 32.0) * 0.01).collect();
        let y: Vec<f32> = (0..dim).map(|i| (i as f32) * 0.01).collect();
        let true_dot: f32 = x.iter().zip(y.iter()).map(|(a, b)| a * b).sum();
        let (bits, norm) = proj.compress(&x);
        let est = proj.inner_product(&y, &bits, norm);
        let rel_err = (est - true_dot).abs() / true_dot.abs().max(1e-6);
        assert!(
            rel_err < 2.0,
            "estimate should be in the right ballpark: est={est}, true={true_dot}"
        );
    }

    #[test]
    fn test_inner_product_unbiased_over_seeds() {
        // A single QJL estimate is noisy (std ~0.6 here). Averaging over
        // independent projection matrices, one per seed, should converge on the
        // true inner product if the estimator is unbiased.
        let dim = 64;
        let x: Vec<f32> = (0..dim).map(|i| (i as f32 - 32.0) * 0.01).collect();
        let y: Vec<f32> = (0..dim).map(|i| (i as f32) * 0.01).collect();
        let true_dot: f32 = x.iter().zip(&y).map(|(a, b)| a * b).sum();
        let trials = 128u64;
        let mean = (0..trials)
            .map(|seed| {
                let proj = QjlProjector::new(seed, dim);
                let (bits, norm) = proj.compress(&x);
                proj.inner_product(&y, &bits, norm)
            })
            .sum::<f32>()
            / trials as f32;
        let rel_err = (mean - true_dot).abs() / true_dot.abs();
        assert!(
            rel_err < 0.25,
            "mean estimate should approach the true value: mean={mean}, true={true_dot}"
        );
    }

    #[test]
    fn test_zero_key_gives_zero_estimate() {
        let proj = QjlProjector::new(3, 16);
        let key = vec![0.0f32; 16];
        let query: Vec<f32> = (0..16).map(|i| i as f32).collect();
        let (bits, norm) = proj.compress(&key);
        assert_eq!(norm, 0.0);
        assert_eq!(proj.inner_product(&query, &bits, norm), 0.0);
    }

    #[test]
    fn test_fast_matches_slow() {
        let dim = 32;
        let proj = QjlProjector::new(99, dim);
        let key: Vec<f32> = (0..dim).map(|i| (i as f32) * 0.02).collect();
        let query: Vec<f32> = (0..dim).map(|i| (i as f32 - 16.0) * 0.01).collect();
        let (bits, norm) = proj.compress(&key);
        let slow = proj.inner_product(&query, &bits, norm);
        let projected_q = proj.project_query(&query);
        let fast = proj.inner_product_fast(&projected_q, &bits, norm);
        assert!(
            (slow - fast).abs() < 1e-4,
            "fast and slow should match: {slow} vs {fast}"
        );
    }

    #[test]
    fn test_dot_with_sign_bits_all_positive() {
        let values = [1.0f32, 2.0, 3.0, 4.0];
        let bits = [0xF_u64];
        let result = dot_with_sign_bits(&values, &bits, 4);
        assert!((result - 10.0).abs() < 1e-6);
    }

    #[test]
    fn test_dot_with_sign_bits_alternating() {
        let values = [1.0f32, 2.0, 3.0, 4.0];
        let bits = [0b0101_u64];
        let result = dot_with_sign_bits(&values, &bits, 4);
        let expected = 1.0 - 2.0 + 3.0 - 4.0;
        assert!((result - expected).abs() < 1e-6);
    }

    #[test]
    fn test_dot_with_sign_bits_spans_multiple_words() {
        // Exercises the full-word path (64 bits, all set) and the remainder
        // path (6 bits, only bit 0 set) together.
        let values: Vec<f32> = (0..70).map(|i| i as f32).collect();
        let bits = [u64::MAX, 0b1_u64];
        let result = dot_with_sign_bits(&values, &bits, 70);
        // (0 + ... + 63) + 64 - (65 + 66 + 67 + 68 + 69)
        let expected = 2016.0 + 64.0 - 335.0;
        assert!((result - expected).abs() < 1e-3);
    }
}