use crate::error::{validate_finite, Result, TurboQuantError};
use crate::traits::SerializableCode;
/// One-bit sketch of a vector: one sign bit per random projection plus the
/// vector's Euclidean norm.
///
/// `PartialEq` compares the packed sign bytes, the projection count and the
/// norm (with `f32` semantics, so a NaN norm never compares equal).
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(
    feature = "serde-support",
    derive(serde::Serialize, serde::Deserialize)
)]
pub struct QjlSketch {
    /// Packed sign bits, least-significant bit first within each byte; bit `p`
    /// set means projection `p` was non-negative. Expected to hold
    /// `num_projections.div_ceil(8)` bytes.
    pub(crate) signs: Vec<u8>,
    /// Number of projections, i.e. the number of meaningful bits in `signs`.
    pub(crate) num_projections: usize,
    /// Euclidean norm of the sketched vector.
    pub(crate) norm: f32,
}
impl QjlSketch {
    /// Reported storage footprint: the packed sign bytes plus 2 bytes for the
    /// projection count and 4 bytes for the `f32` norm.
    ///
    /// NOTE(review): this is one byte less than `to_compact_bytes().len()`,
    /// because the compact encoding also starts with a format-version byte.
    pub fn size_bytes(&self) -> usize {
        const PROJECTION_COUNT_BYTES: usize = 2;
        const NORM_BYTES: usize = 4;
        PROJECTION_COUNT_BYTES + NORM_BYTES + self.signs.len()
    }

    /// Number of random projections (sign bits) in this sketch.
    pub fn num_projections(&self) -> usize {
        self.num_projections
    }

    /// Euclidean norm stored alongside the sign bits.
    pub fn norm(&self) -> f32 {
        self.norm
    }

    /// Sign of projection `p` as a float: `1.0` when bit `p` is set
    /// (LSB-first within each byte), `-1.0` otherwise.
    ///
    /// Panics if `p / 8` is not a valid index into `signs`.
    #[inline]
    fn get_sign(&self, p: usize) -> f32 {
        let packed = self.signs[p / 8];
        let mask = 1u8 << (p % 8);
        if (packed & mask) != 0 {
            1.0
        } else {
            -1.0
        }
    }
}
impl QjlSketch {
    /// Encodes the sketch in the compact binary format:
    /// `[version: u8][num_projections: u16 LE][norm: f32 LE][sign bytes…]`.
    ///
    /// # Panics
    ///
    /// Panics if `num_projections` does not fit in a `u16`.
    pub fn to_compact_bytes(&self) -> Vec<u8> {
        let count = u16::try_from(self.num_projections).expect(
            "QjlSketch num_projections exceeds u16::MAX; too many projections for compact format",
        );
        let mut encoded = Vec::with_capacity(1 + 2 + 4 + self.num_projections.div_ceil(8));
        encoded.push(crate::COMPACT_FORMAT_VERSION);
        encoded.extend(count.to_le_bytes());
        encoded.extend(self.norm.to_le_bytes());
        encoded.extend_from_slice(&self.signs);
        encoded
    }

    /// Decodes a sketch produced by [`QjlSketch::to_compact_bytes`].
    ///
    /// Bytes after the expected sign payload are ignored.
    ///
    /// # Errors
    ///
    /// Returns [`TurboQuantError::DeserializationError`] when:
    /// - the buffer is shorter than the 7-byte header,
    /// - the version byte is not `crate::COMPACT_FORMAT_VERSION`,
    /// - the projection count is zero,
    /// - the norm is NaN, infinite or negative, or
    /// - the buffer is shorter than the header plus the packed sign bytes.
    pub fn from_compact_bytes(bytes: &[u8]) -> Result<Self> {
        // version (1) + num_projections (2) + norm (4)
        const HEADER_LEN: usize = 7;
        let fail = |reason: &str| TurboQuantError::DeserializationError {
            reason: reason.to_owned(),
        };

        let Some(header) = bytes.get(..HEADER_LEN) else {
            return Err(fail("buffer too short: need at least 7 bytes"));
        };

        let version = header[0];
        if version != crate::COMPACT_FORMAT_VERSION {
            return Err(fail(&format!(
                "unsupported version 0x{:02X}, expected 0x{:02X}",
                version,
                crate::COMPACT_FORMAT_VERSION
            )));
        }

        let num_projections = usize::from(u16::from_le_bytes([header[1], header[2]]));
        if num_projections == 0 {
            return Err(fail("num_projections must be > 0"));
        }

        let norm = f32::from_le_bytes([header[3], header[4], header[5], header[6]]);
        let norm_problem = if norm.is_nan() {
            Some("norm is NaN")
        } else if norm.is_infinite() {
            Some("norm is infinite")
        } else if norm < 0.0 {
            Some("norm is negative")
        } else {
            None
        };
        if let Some(reason) = norm_problem {
            return Err(fail(reason));
        }

        let total_len = HEADER_LEN + num_projections.div_ceil(8);
        let Some(sign_bytes) = bytes.get(HEADER_LEN..total_len) else {
            return Err(fail(&format!(
                "buffer too short: need {total_len}, got {}",
                bytes.len()
            )));
        };

        Ok(Self {
            signs: sign_bytes.to_vec(),
            num_projections,
            norm,
        })
    }
}
impl SerializableCode for QjlSketch {
    /// Forwards to the inherent compact encoder.
    #[inline]
    fn to_compact_bytes(&self) -> Vec<u8> {
        // Inherent associated items take precedence over trait items in path
        // resolution, so this calls the inherent method, not itself.
        Self::to_compact_bytes(self)
    }

    /// Forwards to the inherent compact decoder.
    #[inline]
    fn from_compact_bytes(bytes: &[u8]) -> Result<Self> {
        Self::from_compact_bytes(bytes)
    }
}
/// Produces one-bit sign sketches of fixed-dimension vectors using a
/// projection matrix of standard-normal samples generated from a seed.
#[derive(Debug, Clone)]
pub struct QjlQuantizer {
/// Length of the vectors this quantizer accepts.
dim: usize,
/// Number of projection rows, i.e. sign bits per sketch.
num_projections: usize,
/// Seed the projection matrix was generated from.
seed: u64,
/// `num_projections * dim` samples; row `p` occupies
/// `[p * dim, (p + 1) * dim)`.
projection_matrix: Vec<f32>,
}
impl QjlQuantizer {
    /// Creates a quantizer whose `projections × dim` projection matrix is filled
    /// with standard-normal samples drawn from a ChaCha8 RNG seeded with `seed`.
    ///
    /// # Errors
    ///
    /// - [`TurboQuantError::ZeroDimension`] if `dim == 0`.
    /// - [`TurboQuantError::ZeroProjections`] if `projections == 0`.
    /// - [`TurboQuantError::DimensionTooLarge`] if `projections > u16::MAX` (the
    ///   compact sketch format stores the count as a `u16`), or if
    ///   `projections × dim` `f32`s would exceed the maximum allocation size.
    pub fn new(dim: usize, projections: usize, seed: u64) -> Result<Self> {
        if dim == 0 {
            return Err(TurboQuantError::ZeroDimension);
        }
        if projections == 0 {
            return Err(TurboQuantError::ZeroProjections);
        }
        if projections > u16::MAX as usize {
            return Err(TurboQuantError::DimensionTooLarge(projections, u16::MAX as usize));
        }
        // Without this bound `projections * dim` can wrap in release builds,
        // leaving a matrix too small for the row slicing done later.
        let max_dim = (isize::MAX as usize / core::mem::size_of::<f32>()) / projections;
        if dim > max_dim {
            return Err(TurboQuantError::DimensionTooLarge(dim, max_dim));
        }
        use rand::SeedableRng;
        use rand_chacha::ChaCha8Rng;
        use rand_distr::Distribution;
        use rand_distr::StandardNormal;
        let mut rng = ChaCha8Rng::seed_from_u64(seed);
        let projection_matrix: Vec<f32> = (0..projections * dim)
            .map(|_| {
                <StandardNormal as Distribution<f64>>::sample(&StandardNormal, &mut rng) as f32
            })
            .collect();
        Ok(Self {
            dim,
            num_projections: projections,
            seed,
            projection_matrix,
        })
    }

    /// Length of the vectors this quantizer accepts.
    pub fn dim(&self) -> usize {
        self.dim
    }

    /// Number of projections (sign bits) per sketch.
    pub fn projections(&self) -> usize {
        self.num_projections
    }

    /// Seed used to generate the projection matrix.
    pub fn seed(&self) -> u64 {
        self.seed
    }

    /// Bytes occupied by the projection matrix's elements.
    pub fn matrix_size_bytes(&self) -> usize {
        self.projection_matrix.len() * core::mem::size_of::<f32>()
    }

    /// Sketches `vector` into one sign bit per projection plus its Euclidean norm.
    ///
    /// Bit `p` is set when the dot product of `vector` with row `p` of the
    /// projection matrix is `>= 0`.
    ///
    /// # Errors
    ///
    /// - [`TurboQuantError::DimensionMismatch`] if `vector.len() != self.dim()`.
    /// - The error from `validate_finite` (e.g. [`TurboQuantError::NonFiniteInput`])
    ///   if `vector` contains NaN or infinity.
    #[inline]
    #[cfg_attr(
        feature = "tracing-support",
        tracing::instrument(
            name = "bitpolar::qjl::sketch",
            skip(self, vector),
            fields(dim = self.dim, projections = self.num_projections)
        )
    )]
    pub fn sketch(&self, vector: &[f32]) -> Result<QjlSketch> {
        if vector.len() != self.dim {
            return Err(TurboQuantError::DimensionMismatch {
                expected: self.dim,
                actual: vector.len(),
            });
        }
        validate_finite(vector)?;
        let norm = Self::l2_norm(vector);
        let mut signs = vec![0u8; self.num_projections.div_ceil(8)];
        for (p, row) in self.projection_matrix.chunks_exact(self.dim).enumerate() {
            let dot: f32 = row.iter().zip(vector).map(|(m, v)| m * v).sum();
            if dot >= 0.0 {
                signs[p / 8] |= 1 << (p % 8);
            }
        }
        Ok(QjlSketch {
            signs,
            num_projections: self.num_projections,
            norm,
        })
    }

    /// Estimates `⟨x, query⟩` from the sketch of `x` and the exact `query`.
    ///
    /// Computes `‖x‖ · √(π/2) / m · Σ_p sign_p · ⟨s_p, query⟩`, where `m` is
    /// [`Self::projections`], `s_p` is projection row `p`, and `‖x‖` is the
    /// norm stored in the sketch. Returns `0.0` when that norm is below `1e-30`.
    ///
    /// # Errors
    ///
    /// - [`TurboQuantError::DimensionMismatch`] if `query.len() != self.dim()`, or
    ///   if the sketch's projection count differs from [`Self::projections`]
    ///   (i.e. it came from a differently configured quantizer).
    /// - The error from `validate_finite` if `query` contains NaN or infinity.
    /// - [`TurboQuantError::DeserializationError`] if the sketch's sign buffer
    ///   is too short for its projection count (e.g. a malformed deserialized sketch).
    #[inline]
    #[cfg_attr(
        feature = "tracing-support",
        tracing::instrument(
            name = "bitpolar::qjl::ip_estimate",
            skip(self, sketch, query),
            fields(dim = self.dim, projections = self.num_projections)
        )
    )]
    pub fn inner_product_estimate(
        &self,
        sketch: &QjlSketch,
        query: &[f32],
    ) -> Result<f32> {
        self.check_estimate_inputs(sketch, query)?;
        if sketch.norm < 1e-30 {
            return Ok(0.0);
        }
        Ok(self.scaled_estimate(sketch, query, sketch.norm))
    }

    /// Same as [`Self::inner_product_estimate`], but uses `norm` instead of the
    /// norm stored in the sketch.
    ///
    /// Returns `0.0` when `norm` is NaN, infinite, negative or below `1e-30`.
    ///
    /// # Errors
    ///
    /// Same as [`Self::inner_product_estimate`].
    pub fn inner_product_estimate_with_norm(
        &self,
        sketch: &QjlSketch,
        query: &[f32],
        norm: f32,
    ) -> Result<f32> {
        self.check_estimate_inputs(sketch, query)?;
        // Negative values are also `< 1e-30`, so this covers every invalid norm.
        if !norm.is_finite() || norm < 1e-30 {
            return Ok(0.0);
        }
        Ok(self.scaled_estimate(sketch, query, norm))
    }

    /// Validates the query and checks that `sketch` matches this quantizer's
    /// configuration, so that `get_sign` cannot index past the sign buffer.
    fn check_estimate_inputs(&self, sketch: &QjlSketch, query: &[f32]) -> Result<()> {
        if query.len() != self.dim {
            return Err(TurboQuantError::DimensionMismatch {
                expected: self.dim,
                actual: query.len(),
            });
        }
        validate_finite(query)?;
        if sketch.num_projections != self.num_projections {
            return Err(TurboQuantError::DimensionMismatch {
                expected: self.num_projections,
                actual: sketch.num_projections,
            });
        }
        let needed = self.num_projections.div_ceil(8);
        if sketch.signs.len() < needed {
            return Err(TurboQuantError::DeserializationError {
                reason: format!(
                    "sketch has {} sign bytes, need {needed}",
                    sketch.signs.len()
                ),
            });
        }
        Ok(())
    }

    /// Computes `norm · √(π/2) / m · Σ_p sign_p · ⟨s_p, query⟩`.
    ///
    /// Callers must have run `check_estimate_inputs` first.
    fn scaled_estimate(&self, sketch: &QjlSketch, query: &[f32], norm: f32) -> f32 {
        let scale = norm * crate::compat::math::sqrtf(core::f32::consts::FRAC_PI_2)
            / self.num_projections as f32;
        let sum = self
            .projection_matrix
            .chunks_exact(self.dim)
            .enumerate()
            .fold(0.0f32, |acc, (p, row)| {
                let dot: f32 = row.iter().zip(query).map(|(m, q)| m * q).sum();
                acc + sketch.get_sign(p) * dot
            });
        scale * sum
    }

    /// Euclidean norm of `v`, which must contain only finite values.
    ///
    /// When the plain sum of squares is finite, the result matches the naive
    /// `sqrt(Σ x²)` exactly. If the sum overflows (elements above ~1.8e19),
    /// the values are rescaled by the largest magnitude instead of producing an
    /// infinite norm, which `QjlSketch::from_compact_bytes` would reject. A
    /// true norm above `f32::MAX` still saturates to infinity.
    fn l2_norm(v: &[f32]) -> f32 {
        let norm_sq: f32 = v.iter().map(|x| x * x).sum::<f32>();
        if norm_sq.is_finite() {
            return crate::compat::math::sqrtf(norm_sq);
        }
        // Inputs are finite, so an overflowing sum implies max_abs > 0.
        let max_abs = v
            .iter()
            .fold(0.0f32, |acc, &x| acc.max(if x < 0.0 { -x } else { x }));
        let scaled_sq: f32 = v
            .iter()
            .map(|&x| {
                let s = x / max_abs;
                s * s
            })
            .sum();
        max_abs * crate::compat::math::sqrtf(scaled_sq)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Hand-assembles a compact-format buffer so individual header fields can
    /// be set to invalid values.
    fn compact_bytes(num_projections: u16, norm: f32, signs: &[u8]) -> Vec<u8> {
        let mut bytes = vec![crate::COMPACT_FORMAT_VERSION];
        bytes.extend_from_slice(&num_projections.to_le_bytes());
        bytes.extend_from_slice(&norm.to_le_bytes());
        bytes.extend_from_slice(signs);
        bytes
    }

    #[test]
    fn test_zero_dimension_error() {
        assert!(matches!(
            QjlQuantizer::new(0, 16, 42),
            Err(TurboQuantError::ZeroDimension)
        ));
    }

    #[test]
    fn test_zero_projections_error() {
        assert!(matches!(
            QjlQuantizer::new(8, 0, 42),
            Err(TurboQuantError::ZeroProjections)
        ));
    }

    #[test]
    fn test_too_many_projections_error() {
        assert!(matches!(
            QjlQuantizer::new(8, u16::MAX as usize + 1, 42),
            Err(TurboQuantError::DimensionTooLarge(..))
        ));
    }

    #[test]
    fn test_sketch_shape() {
        let q = QjlQuantizer::new(8, 16, 42).unwrap();
        let v = vec![0.5_f32; 8];
        let sketch = q.sketch(&v).unwrap();
        assert_eq!(sketch.signs.len(), 2);
        assert_eq!(sketch.num_projections(), 16);
    }

    #[test]
    fn test_sketch_odd_projections() {
        let q = QjlQuantizer::new(8, 13, 42).unwrap();
        let v = vec![0.5_f32; 8];
        let sketch = q.sketch(&v).unwrap();
        assert_eq!(sketch.signs.len(), 2);
    }

    #[test]
    fn test_dimension_mismatch() {
        let q = QjlQuantizer::new(8, 16, 42).unwrap();
        let v = vec![0.0_f32; 4];
        assert!(matches!(
            q.sketch(&v),
            Err(TurboQuantError::DimensionMismatch { .. })
        ));
    }

    #[test]
    fn test_non_finite_error() {
        let q = QjlQuantizer::new(4, 8, 42).unwrap();
        let v = vec![1.0_f32, f32::INFINITY, 0.0, 0.0];
        assert!(matches!(
            q.sketch(&v),
            Err(TurboQuantError::NonFiniteInput { .. })
        ));
    }

    #[test]
    fn test_inner_product_self_positive() {
        let q = QjlQuantizer::new(64, 128, 42).unwrap();
        let v: Vec<f32> = (0..64).map(|i| (i as f32 + 1.0) * 0.1).collect();
        let exact: f32 = v.iter().map(|x| x * x).sum();
        let sketch = q.sketch(&v).unwrap();
        let est = q.inner_product_estimate(&sketch, &v).unwrap();
        assert!(
            est > 0.0,
            "Self IP should be positive: estimate={}, exact={}",
            est,
            exact
        );
    }

    #[test]
    fn test_inner_product_unbiased() {
        let dim = 64;
        let q = QjlQuantizer::new(dim, 128, 42).unwrap();
        let mut total_error = 0.0f64;
        let trials = 200;
        for i in 0..trials {
            let v: Vec<f32> = (0..dim)
                .map(|j| ((i * dim + j) as f32 * 0.7).sin())
                .collect();
            let query: Vec<f32> = (0..dim)
                .map(|j| ((i * dim + j) as f32 * 1.3).cos())
                .collect();
            let exact: f32 = v.iter().zip(query.iter()).map(|(a, b)| a * b).sum();
            let sketch = q.sketch(&v).unwrap();
            let estimated = q.inner_product_estimate(&sketch, &query).unwrap();
            total_error += (estimated - exact) as f64;
        }
        let mean_error = total_error / trials as f64;
        assert!(
            mean_error.abs() < 2.0,
            "QJL should be approximately unbiased, mean error = {:.4}",
            mean_error
        );
    }

    #[test]
    fn test_zero_vector() {
        let q = QjlQuantizer::new(8, 16, 42).unwrap();
        let v = vec![0.0_f32; 8];
        let sketch = q.sketch(&v).unwrap();
        assert!(sketch.norm < 1e-10);
        let query = vec![1.0_f32; 8];
        let est = q.inner_product_estimate(&sketch, &query).unwrap();
        assert!(est.abs() < 1e-10, "Zero vector IP should be ~0, got {}", est);
    }

    #[test]
    fn test_deterministic() {
        let q = QjlQuantizer::new(16, 32, 99).unwrap();
        let v = vec![1.0_f32; 16];
        let s1 = q.sketch(&v).unwrap();
        let s2 = q.sketch(&v).unwrap();
        assert_eq!(s1.signs, s2.signs);
        assert_eq!(s1.norm, s2.norm);
    }

    #[test]
    fn test_different_seeds() {
        let q1 = QjlQuantizer::new(16, 32, 1).unwrap();
        let q2 = QjlQuantizer::new(16, 32, 2).unwrap();
        let v = vec![1.0_f32; 16];
        let s1 = q1.sketch(&v).unwrap();
        let s2 = q2.sketch(&v).unwrap();
        assert_ne!(s1.signs, s2.signs);
    }

    #[test]
    fn test_sketch_size_bytes() {
        let q = QjlQuantizer::new(1536, 384, 42).unwrap();
        let v = vec![0.1_f32; 1536];
        let sketch = q.sketch(&v).unwrap();
        assert_eq!(sketch.size_bytes(), 54);
    }

    #[test]
    fn test_matrix_size_bytes() {
        let q = QjlQuantizer::new(1536, 384, 42).unwrap();
        assert_eq!(q.matrix_size_bytes(), 384 * 1536 * 4);
    }

    #[test]
    fn test_qjl_sketch_roundtrip() {
        let q = QjlQuantizer::new(16, 32, 42).unwrap();
        let v: Vec<f32> = (0..16).map(|i| (i as f32 * 0.3).sin()).collect();
        let sketch = q.sketch(&v).unwrap();
        let bytes = sketch.to_compact_bytes();
        let decoded = QjlSketch::from_compact_bytes(&bytes).unwrap();
        assert_eq!(decoded.signs, sketch.signs);
        assert_eq!(decoded.num_projections, sketch.num_projections);
        assert_eq!(decoded.norm, sketch.norm);
    }

    #[test]
    fn test_qjl_sketch_odd_projections_roundtrip() {
        let q = QjlQuantizer::new(8, 13, 99).unwrap();
        let v = vec![1.0_f32; 8];
        let sketch = q.sketch(&v).unwrap();
        let bytes = sketch.to_compact_bytes();
        let back = QjlSketch::from_compact_bytes(&bytes).unwrap();
        assert_eq!(back.num_projections, 13);
        assert_eq!(back.signs.len(), 2);
    }

    #[test]
    fn test_qjl_sketch_decodes_hand_built_buffer() {
        let bytes = compact_bytes(8, 2.5, &[0b1010_0101u8]);
        let sketch = QjlSketch::from_compact_bytes(&bytes).unwrap();
        assert_eq!(sketch.num_projections(), 8);
        assert_eq!(sketch.norm(), 2.5);
        assert_eq!(sketch.signs, vec![0b1010_0101u8]);
        assert_eq!(sketch.get_sign(0), 1.0);
        assert_eq!(sketch.get_sign(1), -1.0);
        assert_eq!(sketch.to_compact_bytes(), bytes);
    }

    #[test]
    fn test_qjl_sketch_wrong_version() {
        let q = QjlQuantizer::new(8, 16, 42).unwrap();
        let v = vec![0.5_f32; 8];
        let sketch = q.sketch(&v).unwrap();
        let mut bytes = sketch.to_compact_bytes();
        bytes[0] = 0xBB;
        assert!(matches!(
            QjlSketch::from_compact_bytes(&bytes),
            Err(TurboQuantError::DeserializationError { .. })
        ));
    }

    #[test]
    fn test_qjl_sketch_rejects_zero_projections() {
        let bytes = compact_bytes(0, 1.0, &[]);
        assert!(matches!(
            QjlSketch::from_compact_bytes(&bytes),
            Err(TurboQuantError::DeserializationError { .. })
        ));
    }

    #[test]
    fn test_qjl_sketch_rejects_invalid_norm() {
        for norm in [f32::NAN, f32::INFINITY, f32::NEG_INFINITY, -1.0] {
            let bytes = compact_bytes(8, norm, &[0xFF]);
            assert!(
                matches!(
                    QjlSketch::from_compact_bytes(&bytes),
                    Err(TurboQuantError::DeserializationError { .. })
                ),
                "norm {norm} should be rejected"
            );
        }
    }

    #[test]
    fn test_qjl_sketch_truncated() {
        let q = QjlQuantizer::new(8, 16, 42).unwrap();
        let v = vec![0.5_f32; 8];
        let sketch = q.sketch(&v).unwrap();
        let bytes = sketch.to_compact_bytes();
        let truncated = &bytes[..7];
        assert!(matches!(
            QjlSketch::from_compact_bytes(truncated),
            Err(TurboQuantError::DeserializationError { .. })
        ));
    }

    #[test]
    fn test_qjl_sketch_empty_buffer() {
        assert!(matches!(
            QjlSketch::from_compact_bytes(&[]),
            Err(TurboQuantError::DeserializationError { .. })
        ));
    }

    #[test]
    fn test_serializable_code_trait_qjl() {
        use crate::traits::SerializableCode;
        let q = QjlQuantizer::new(16, 32, 42).unwrap();
        let v = vec![0.4_f32; 16];
        let sketch = q.sketch(&v).unwrap();
        let bytes = <QjlSketch as SerializableCode>::to_compact_bytes(&sketch);
        let back = <QjlSketch as SerializableCode>::from_compact_bytes(&bytes).unwrap();
        assert_eq!(back.norm, sketch.norm);
    }
}