use crate::VQuantError;
/// Sign-based binary quantizer backed by a seeded random projection.
///
/// An input vector of length `dim` is multiplied by a row-major
/// `projected_dim x dim` matrix (rows built by Gram-Schmidt, intended to be
/// orthonormal). Each projected coordinate is stored as one bit: bit `i % 8` of
/// byte `i / 8` is set when coordinate `i` is `>= 0.0`.
#[derive(Debug, Clone)]
pub struct BinaryQuantizer {
    /// Row-major `projected_dim x dim` projection matrix.
    rotation: Vec<f32>,
    /// Length of the input vectors this quantizer accepts.
    dim: usize,
    /// Number of projected coordinates, i.e. bits per code (clamped to `<= dim` by `new`).
    projected_dim: usize,
}
impl BinaryQuantizer {
    /// Creates a quantizer for `dim`-dimensional vectors producing `projected_dim`-bit codes.
    ///
    /// `projected_dim` is clamped to `dim`. The projection matrix is generated
    /// deterministically from `seed`.
    ///
    /// No validation is performed: a zero `dim` or `projected_dim` yields a quantizer
    /// whose codes are empty. Use [`BinaryQuantizer::try_new`] to reject such
    /// configurations.
    #[must_use]
    pub fn new(dim: usize, projected_dim: usize, seed: u64) -> Self {
        let projected_dim = projected_dim.min(dim);
        let rotation = generate_rotation(dim, projected_dim, seed);
        Self {
            rotation,
            dim,
            projected_dim,
        }
    }

    /// Validating counterpart of [`BinaryQuantizer::new`].
    ///
    /// A `projected_dim` larger than `dim` is still clamped, as in `new`.
    ///
    /// # Errors
    ///
    /// Returns [`VQuantError::InvalidConfig`] when `dim` or `projected_dim` is zero.
    pub fn try_new(dim: usize, projected_dim: usize, seed: u64) -> crate::Result<Self> {
        if dim == 0 {
            return Err(VQuantError::InvalidConfig {
                field: "dim",
                reason: "must be > 0",
            });
        }
        if projected_dim == 0 {
            return Err(VQuantError::InvalidConfig {
                field: "projected_dim",
                reason: "must be > 0",
            });
        }
        Ok(Self::new(dim, projected_dim, seed))
    }

    /// Number of bytes in every code produced by [`Self::quantize`]: `ceil(projected_dim / 8)`.
    #[must_use]
    pub fn code_len(&self) -> usize {
        self.projected_dim.div_ceil(8)
    }

    /// Length of the input vectors this quantizer accepts.
    #[must_use]
    pub fn dim(&self) -> usize {
        self.dim
    }

    /// Number of bits per code (after clamping to `dim`).
    #[must_use]
    pub fn projected_dim(&self) -> usize {
        self.projected_dim
    }

    /// Projects `vector` and packs the sign of each projected coordinate into a code
    /// of [`Self::code_len`] bytes.
    ///
    /// # Errors
    ///
    /// Returns [`VQuantError::DimensionMismatch`] if `vector.len() != dim`.
    pub fn quantize(&self, vector: &[f32]) -> crate::Result<Vec<u8>> {
        if vector.len() != self.dim {
            return Err(VQuantError::DimensionMismatch {
                expected: self.dim,
                got: vector.len(),
            });
        }
        Ok(rotate_and_pack(
            &self.rotation,
            vector,
            self.dim,
            self.projected_dim,
        ))
    }

    /// Quantizes every vector in `vectors`, preserving order.
    ///
    /// # Errors
    ///
    /// Stops at the first vector whose length is not `dim` and returns its
    /// [`VQuantError::DimensionMismatch`].
    pub fn quantize_batch(&self, vectors: &[&[f32]]) -> crate::Result<Vec<Vec<u8>>> {
        vectors.iter().map(|v| self.quantize(v)).collect()
    }

    /// Estimates the distance between an unquantized `query` and a binary `code`.
    ///
    /// The query is projected with the same matrix as [`Self::quantize`]. Each projected
    /// coordinate is multiplied by `+1` or `-1` according to the matching bit of `code`,
    /// and the result is the negated mean of those signed values. Smaller values mean
    /// the query agrees more with the code's signs. Bytes of `code` beyond
    /// [`Self::code_len`] are ignored.
    ///
    /// When `projected_dim` is zero the codes carry no information and `0.0` is
    /// returned (rather than the `NaN` that `0.0 / 0.0` would produce).
    ///
    /// # Errors
    ///
    /// Returns [`VQuantError::DimensionMismatch`] if `query.len() != dim`, or if `code`
    /// is shorter than [`Self::code_len`] (then `expected`/`got` are byte counts).
    pub fn asymmetric_distance(&self, query: &[f32], code: &[u8]) -> crate::Result<f32> {
        if query.len() != self.dim {
            return Err(VQuantError::DimensionMismatch {
                expected: self.dim,
                got: query.len(),
            });
        }
        let required = self.code_len();
        if code.len() < required {
            return Err(VQuantError::DimensionMismatch {
                expected: required,
                got: code.len(),
            });
        }
        if self.projected_dim == 0 {
            // Nothing to compare. Dividing by zero below would yield NaN, which breaks
            // any downstream sorting or comparison of distances.
            return Ok(0.0);
        }
        let rotated = apply_rotation_rect(&self.rotation, query, self.dim, self.projected_dim);
        let mut ip = 0.0f32;
        for (i, &rq) in rotated.iter().enumerate() {
            // Bit `i` lives at position `i % 8` of byte `i / 8`, matching `quantize`.
            let bit_set = ((code[i / 8] >> (i % 8)) & 1) == 1;
            ip += if bit_set { rq } else { -rq };
        }
        Ok(-ip / self.projected_dim as f32)
    }

    /// Hamming distance between two packed codes, computed by
    /// `crate::simd_ops::hamming_distance`.
    ///
    /// NOTE(review): behavior for codes of different lengths is defined by
    /// `hamming_distance` — confirm before relying on it.
    #[must_use]
    pub fn symmetric_distance(code_a: &[u8], code_b: &[u8]) -> u32 {
        crate::simd_ops::hamming_distance(code_a, code_b)
    }
}
/// Builds a row-major `projected_dim x dim` projection matrix with orthonormal rows.
///
/// Each row starts as a vector of standard-normal samples (Box-Muller over a
/// SplitMix64 stream seeded with `seed`). It is then orthonormalized against the
/// previous rows with modified Gram-Schmidt. Arithmetic is done in `f64` and rounded
/// to `f32` only at the end, so rows are orthonormal to within `f32` rounding.
///
/// SplitMix64 replaces `std`'s `DefaultHasher`, whose algorithm is explicitly
/// unspecified and may change between Rust releases. That would silently change the
/// matrix, and every code built from it, after a toolchain upgrade. Matrices produced
/// here therefore differ from those of the earlier hasher-based implementation for the
/// same `seed`.
///
/// Expects `projected_dim <= dim` (guaranteed by `BinaryQuantizer::new`). Beyond that
/// no orthogonal direction remains, and surplus rows are plain unit coordinate vectors.
fn generate_rotation(dim: usize, projected_dim: usize, seed: u64) -> Vec<f32> {
    // A Gaussian draw whose residual (after removing earlier rows) keeps less than this
    // fraction of its length is numerically unreliable; a coordinate axis is used instead.
    const MIN_RESIDUAL_FRACTION: f64 = 1e-6;

    let mut state = seed;
    let mut next_gaussian = || -> f64 {
        // 53 random mantissa bits mapped to (0, 1] for `u1` (so `ln` is finite) and to
        // [0, 1) for `u2`.
        const UNIT: f64 = 1.0 / (1u64 << 53) as f64;
        let u1 = ((splitmix64(&mut state) >> 11) + 1) as f64 * UNIT;
        let u2 = (splitmix64(&mut state) >> 11) as f64 * UNIT;
        (-2.0 * u1.ln()).sqrt() * (std::f64::consts::TAU * u2).cos()
    };

    let mut basis: Vec<Vec<f64>> = Vec::with_capacity(projected_dim);
    for row in 0..projected_dim {
        let mut v: Vec<f64> = (0..dim).map(|_| next_gaussian()).collect();
        let raw_norm = l2_norm(&v);
        orthogonalize_against(&mut v, &basis);
        let norm = l2_norm(&v);
        let next_row: Vec<f64> = if norm > 0.0 && norm > MIN_RESIDUAL_FRACTION * raw_norm {
            v.iter().map(|x| x / norm).collect()
        } else {
            fallback_direction(&basis, dim, row)
        };
        basis.push(next_row);
    }

    // Every row has exactly `dim` entries, so this flattens to `projected_dim * dim` values.
    basis.iter().flatten().map(|&x| x as f32).collect()
}

/// Advances a SplitMix64 generator in place and returns its next 64-bit output.
fn splitmix64(state: &mut u64) -> u64 {
    *state = state.wrapping_add(0x9E37_79B9_7F4A_7C15);
    let mut z = *state;
    z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
    z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
    z ^ (z >> 31)
}

/// Euclidean norm of `v`.
fn l2_norm(v: &[f64]) -> f64 {
    v.iter().map(|x| x * x).sum::<f64>().sqrt()
}

/// Subtracts from `v` its component along each unit-length row of `basis`, one row at a
/// time using the already-updated `v` (modified Gram-Schmidt).
fn orthogonalize_against(v: &mut [f64], basis: &[Vec<f64>]) {
    for b in basis {
        let dot: f64 = v.iter().zip(b).map(|(x, y)| x * y).sum();
        for (x, y) in v.iter_mut().zip(b) {
            *x -= dot * y;
        }
    }
}

/// Returns a unit vector orthogonal to every row of `basis`. Used when a Gaussian draw
/// is numerically inside the span of the rows already chosen.
///
/// Each coordinate axis is orthogonalized against `basis`, and the axis with the largest
/// residual is kept. While `basis` holds fewer than `dim` orthonormal rows, some axis
/// keeps a residual of length at least `1 / sqrt(dim)`, so the result is well defined.
/// If no direction remains (`basis` already spans the space, or `dim == 0`), the axis
/// `row % dim` is returned as-is (an empty vector when `dim == 0`).
fn fallback_direction(basis: &[Vec<f64>], dim: usize, row: usize) -> Vec<f64> {
    const MIN_RESIDUAL: f64 = 1e-6;
    let mut best: Option<(f64, Vec<f64>)> = None;
    for axis in 0..dim {
        let mut e = vec![0.0; dim];
        e[axis] = 1.0;
        orthogonalize_against(&mut e, basis);
        let norm = l2_norm(&e);
        if best.as_ref().map_or(true, |(best_norm, _)| norm > *best_norm) {
            best = Some((norm, e));
        }
    }
    match best {
        Some((norm, e)) if norm > MIN_RESIDUAL => e.iter().map(|x| x / norm).collect(),
        _ => {
            let mut e = vec![0.0; dim];
            if dim > 0 {
                e[row % dim] = 1.0;
            }
            e
        }
    }
}
/// Multiplies `vector` by the first `projected_dim` rows of the row-major `rotation`
/// matrix (row stride `dim`) and returns those `projected_dim` dot products.
///
/// Each dot product is accumulated left to right starting from `0.0`. Panics if a row
/// extends past the end of `rotation`, or if `vector` is shorter than `dim` while at
/// least one row is computed.
fn apply_rotation_rect(
    rotation: &[f32],
    vector: &[f32],
    dim: usize,
    projected_dim: usize,
) -> Vec<f32> {
    (0..projected_dim)
        .map(|row| {
            let start = row * dim;
            rotation[start..start + dim]
                .iter()
                .zip(&vector[..dim])
                .fold(0.0f32, |acc, (&weight, &x)| acc + weight * x)
        })
        .collect()
}
/// Projects `vector` through the row-major `rotation` matrix (row stride `dim`) and
/// packs the sign of each of the `projected_dim` outputs into `ceil(projected_dim / 8)`
/// bytes.
///
/// Bit `i % 8` of byte `i / 8` is set when output `i` is `>= 0.0`; a `NaN` output
/// leaves its bit clear.
fn rotate_and_pack(rotation: &[f32], vector: &[f32], dim: usize, projected_dim: usize) -> Vec<u8> {
    let mut code = vec![0u8; projected_dim.div_ceil(8)];
    for bit in 0..projected_dim {
        let offset = bit * dim;
        let projection = rotation[offset..offset + dim]
            .iter()
            .zip(&vector[..dim])
            .fold(0.0f32, |acc, (&w, &x)| acc + w * x);
        let non_negative = projection >= 0.0;
        code[bit >> 3] |= u8::from(non_negative) << (bit & 7);
    }
    code
}
#[cfg(test)]
mod tests {
    //! Unit tests for `BinaryQuantizer`: configuration validation, code layout,
    //! distance ordering, and properties of the generated projection matrix.
    use super::*;

    #[test]
    fn try_new_zero_dim_rejected() {
        assert!(BinaryQuantizer::try_new(0, 4, 42).is_err());
    }

    #[test]
    fn try_new_zero_projected_rejected() {
        assert!(BinaryQuantizer::try_new(8, 0, 42).is_err());
    }

    #[test]
    fn try_new_valid_config_matches_new() {
        let a = BinaryQuantizer::try_new(16, 8, 5).unwrap();
        let b = BinaryQuantizer::new(16, 8, 5);
        assert_eq!(a.dim(), 16);
        assert_eq!(a.projected_dim(), 8);
        assert_eq!(a.rotation, b.rotation);
    }

    #[test]
    fn try_new_clamps_projected_dim() {
        let q = BinaryQuantizer::try_new(4, 10, 1).unwrap();
        assert_eq!(q.dim(), 4);
        assert_eq!(q.projected_dim(), 4);
    }

    #[test]
    fn projected_dim_clamped_to_dim() {
        let q = BinaryQuantizer::new(8, 100, 42);
        assert_eq!(q.projected_dim(), 8);
    }

    #[test]
    fn code_len_exact_multiple() {
        let q = BinaryQuantizer::new(16, 16, 0);
        assert_eq!(q.code_len(), 2);
    }

    #[test]
    fn code_len_non_multiple() {
        let q = BinaryQuantizer::new(16, 9, 0);
        assert_eq!(q.code_len(), 2);
    }

    #[test]
    fn zero_projected_dim_yields_empty_codes() {
        let q = BinaryQuantizer::new(8, 0, 42);
        assert_eq!(q.code_len(), 0);
        assert!(q.quantize(&[1.0f32; 8]).unwrap().is_empty());
    }

    #[test]
    fn quantize_output_length() {
        let dim = 32;
        let q = BinaryQuantizer::new(dim, dim, 42);
        let v: Vec<f32> = (0..dim).map(|i| i as f32).collect();
        let code = q.quantize(&v).unwrap();
        assert_eq!(code.len(), dim.div_ceil(8));
    }

    #[test]
    fn quantize_projected_output_length() {
        let dim = 32;
        let proj = 12;
        let q = BinaryQuantizer::new(dim, proj, 7);
        let v: Vec<f32> = (0..dim).map(|i| (i as f32).sin()).collect();
        let code = q.quantize(&v).unwrap();
        assert_eq!(code.len(), proj.div_ceil(8));
    }

    #[test]
    fn quantize_dimension_mismatch() {
        let q = BinaryQuantizer::new(16, 16, 0);
        assert!(q.quantize(&[1.0f32; 8]).is_err());
    }

    #[test]
    fn quantize_is_deterministic_for_same_seed() {
        let dim = 24;
        let v: Vec<f32> = (0..dim).map(|i| (i as f32 * 0.7).cos()).collect();
        let a = BinaryQuantizer::new(dim, 16, 11).quantize(&v).unwrap();
        let b = BinaryQuantizer::new(dim, 16, 11).quantize(&v).unwrap();
        assert_eq!(a, b);
    }

    #[test]
    fn quantize_is_invariant_to_positive_scaling() {
        // Doubling is exact in floating point, so every projected sign is unchanged.
        let dim = 32;
        let q = BinaryQuantizer::new(dim, 20, 3);
        let v: Vec<f32> = (0..dim).map(|i| (i as f32).sin()).collect();
        let scaled: Vec<f32> = v.iter().map(|x| x * 2.0).collect();
        assert_eq!(q.quantize(&v).unwrap(), q.quantize(&scaled).unwrap());
    }

    #[test]
    fn quantize_batch_matches_individual() {
        let dim = 16;
        let q = BinaryQuantizer::new(dim, dim, 99);
        let v1: Vec<f32> = (0..dim).map(|i| i as f32).collect();
        let v2: Vec<f32> = (0..dim).map(|i| -(i as f32)).collect();
        let batch = q.quantize_batch(&[&v1, &v2]).unwrap();
        let single1 = q.quantize(&v1).unwrap();
        let single2 = q.quantize(&v2).unwrap();
        assert_eq!(batch[0], single1);
        assert_eq!(batch[1], single2);
    }

    #[test]
    fn symmetric_distance_identical_codes_zero() {
        let code = vec![0b10101010u8, 0b11001100];
        assert_eq!(BinaryQuantizer::symmetric_distance(&code, &code), 0);
    }

    #[test]
    fn symmetric_distance_all_flipped() {
        let a = vec![0u8; 2];
        let b = vec![0xFFu8; 2];
        assert_eq!(BinaryQuantizer::symmetric_distance(&a, &b), 16);
    }

    #[test]
    fn symmetric_matches_manual_popcount() {
        let a = vec![0b00001111u8];
        let b = vec![0b11110000u8];
        assert_eq!(BinaryQuantizer::symmetric_distance(&a, &b), 8);
    }

    #[test]
    fn asymmetric_distance_dimension_mismatch() {
        let q = BinaryQuantizer::new(16, 16, 0);
        let code = vec![0u8; q.code_len()];
        assert!(q.asymmetric_distance(&[1.0f32; 8], &code).is_err());
    }

    #[test]
    fn asymmetric_distance_code_too_short() {
        let q = BinaryQuantizer::new(16, 16, 0);
        let query = vec![0.0f32; 16];
        assert!(q.asymmetric_distance(&query, &[0u8; 1]).is_err());
    }

    #[test]
    fn asymmetric_distance_finite() {
        let dim = 32;
        let q = BinaryQuantizer::new(dim, dim, 42);
        let v: Vec<f32> = (0..dim).map(|i| (i as f32).sin()).collect();
        let code = q.quantize(&v).unwrap();
        let dist = q.asymmetric_distance(&v, &code).unwrap();
        assert!(dist.is_finite());
    }

    #[test]
    fn asymmetric_distance_preserves_relative_ordering() {
        let dim = 64;
        let q = BinaryQuantizer::new(dim, dim, 1337);
        let query: Vec<f32> = vec![1.0f32 / (dim as f32).sqrt(); dim];
        let close: Vec<f32> = (0..dim)
            .map(|i| 1.0f32 / (dim as f32).sqrt() + (i as f32) * 1e-4)
            .collect();
        let far: Vec<f32> = vec![-1.0f32 / (dim as f32).sqrt(); dim];
        let code_close = q.quantize(&close).unwrap();
        let code_far = q.quantize(&far).unwrap();
        let dist_close = q.asymmetric_distance(&query, &code_close).unwrap();
        let dist_far = q.asymmetric_distance(&query, &code_far).unwrap();
        assert!(
            dist_close < dist_far,
            "close distance {dist_close} should be less than far distance {dist_far}"
        );
    }

    #[test]
    fn symmetric_distance_ordering_consistent_with_asymmetric() {
        let dim = 64;
        let q = BinaryQuantizer::new(dim, dim, 2024);
        let query: Vec<f32> = vec![1.0f32 / (dim as f32).sqrt(); dim];
        let close: Vec<f32> = vec![0.9f32 / (dim as f32).sqrt(); dim];
        let far: Vec<f32> = vec![-1.0f32 / (dim as f32).sqrt(); dim];
        let code_q = q.quantize(&query).unwrap();
        let code_close = q.quantize(&close).unwrap();
        let code_far = q.quantize(&far).unwrap();
        let sym_close = BinaryQuantizer::symmetric_distance(&code_q, &code_close);
        let sym_far = BinaryQuantizer::symmetric_distance(&code_q, &code_far);
        assert!(
            sym_close < sym_far,
            "symmetric close {sym_close} should be less than far {sym_far}"
        );
    }

    #[test]
    fn rotation_rows_are_unit_vectors() {
        let dim = 8;
        let q = BinaryQuantizer::new(dim, dim, 0);
        for row in 0..dim {
            let start = row * dim;
            let norm_sq: f32 = q.rotation[start..start + dim].iter().map(|x| x * x).sum();
            assert!(
                (norm_sq - 1.0).abs() < 1e-5,
                "row {row} norm^2 = {norm_sq}, expected ~1.0"
            );
        }
    }

    #[test]
    fn rotation_rows_are_orthogonal() {
        let dim = 8;
        let q = BinaryQuantizer::new(dim, dim, 0);
        for i in 0..dim {
            for j in (i + 1)..dim {
                let ri = &q.rotation[i * dim..(i + 1) * dim];
                let rj = &q.rotation[j * dim..(j + 1) * dim];
                let dot: f32 = ri.iter().zip(rj.iter()).map(|(a, b)| a * b).sum();
                assert!(
                    dot.abs() < 1e-5,
                    "rows {i} and {j} dot product = {dot}, expected ~0"
                );
            }
        }
    }

    #[test]
    fn rectangular_rotation_rows_are_orthonormal() {
        // Covers the projected_dim < dim case, which the square tests above do not.
        let (dim, proj) = (32, 12);
        let q = BinaryQuantizer::new(dim, proj, 7);
        assert_eq!(q.rotation.len(), dim * proj);
        for i in 0..proj {
            for j in i..proj {
                let ri = &q.rotation[i * dim..(i + 1) * dim];
                let rj = &q.rotation[j * dim..(j + 1) * dim];
                let dot: f32 = ri.iter().zip(rj).map(|(a, b)| a * b).sum();
                let expected = if i == j { 1.0 } else { 0.0 };
                assert!(
                    (dot - expected).abs() < 1e-5,
                    "rows {i} and {j} dot product = {dot}, expected ~{expected}"
                );
            }
        }
    }
}