/// One-bit-per-coordinate sign code for a vector, plus a caller-supplied norm.
///
/// Built by `BinaryCode::encode` and compared via the popcount and
/// squared-distance estimators implemented on this type. The fields are public,
/// so a value built by hand or deserialized is not checked for
/// `words.len() == ceil(dim / 64)`; `masked_xnor_popcount` relies on that
/// invariant when it computes how many bits of the last word are real.
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct BinaryCode {
/// Packed sign bits, most-significant bit first: coordinate `i` is bit
/// `63 - i % 64` of `words[i / 64]`. `encode` leaves padding bits in the last
/// word as zero.
pub words: Vec<u64>,
/// Norm passed to `encode`, stored verbatim; the distance estimators use it as
/// this vector's length. NOTE(review): presumably the norm of the original
/// (pre-normalization) vector — confirm against callers.
pub norm: f32,
/// Number of encoded coordinates (the length of the slice given to `encode`).
pub dim: usize,
}
impl BinaryCode {
    /// Encodes an already-rotated vector into a packed sign code.
    ///
    /// Coordinate `i` sets bit `63 - i % 64` of `words[i / 64]` when
    /// `rotated[i] >= 0.0`, so `-0.0` counts as non-negative and `NaN` leaves
    /// the bit clear. Padding bits in the final word stay zero. `norm` is
    /// stored as given; it is not computed from `rotated`.
    pub fn encode(rotated: &[f32], norm: f32) -> Self {
        let dim = rotated.len();
        let mut words = vec![0u64; (dim + 63) / 64];
        // `chunks(64)` yields exactly `words.len()` chunks, one per output word.
        for (word, chunk) in words.iter_mut().zip(rotated.chunks(64)) {
            for (bit, &v) in chunk.iter().enumerate() {
                if v >= 0.0 {
                    *word |= 1u64 << (63 - bit);
                }
            }
        }
        Self { words, norm, dim }
    }

    /// Counts agreeing bit positions across *all* stored bits, padding included.
    ///
    /// When `dim` is not a multiple of 64, padding bits are compared too. These
    /// bits are zero for codes built by `encode`, so they always agree and are
    /// counted, and the result can exceed `dim`. Use
    /// [`Self::masked_xnor_popcount`] to count only real coordinates.
    ///
    /// In release builds a word-count mismatch is not detected; `zip` stops at
    /// the shorter code.
    #[inline]
    pub fn xnor_popcount(&self, other: &Self) -> u32 {
        debug_assert_eq!(self.words.len(), other.words.len());
        self.words
            .iter()
            .zip(other.words.iter())
            .map(|(&a, &b)| (!(a ^ b)).count_ones())
            .sum()
    }

    /// Counts agreeing bit positions over the first `dim` coordinates only.
    ///
    /// Padding bits in the last word are masked out, so for well-formed codes
    /// the result lies in `0..=dim`. Returns 0 for a code with no words.
    ///
    /// Assumes `words.len() == ceil(dim / 64)` for both codes, as produced by
    /// [`Self::encode`].
    ///
    /// # Panics
    /// Panics if `other` has fewer words than `self`. Debug builds also assert
    /// that the word counts and `dim` of the two codes match.
    #[inline]
    pub fn masked_xnor_popcount(&self, other: &Self) -> u32 {
        debug_assert_eq!(self.words.len(), other.words.len());
        debug_assert_eq!(self.dim, other.dim);
        let n_words = self.words.len();
        if n_words == 0 {
            return 0;
        }
        let full_words = n_words - 1;
        let head: u32 = self.words[..full_words]
            .iter()
            .zip(&other.words[..full_words])
            .map(|(&a, &b)| (!(a ^ b)).count_ones())
            .sum();
        // Real coordinates held by the last word: 1..=64 under the invariant above.
        let valid_bits = self.dim - 64 * full_words;
        // Bits are stored MSB-first, so the real coordinates occupy the high bits.
        let mask: u64 = if valid_bits == 64 {
            !0u64
        } else {
            !0u64 << (64 - valid_bits)
        };
        let tail = !(self.words[full_words] ^ other.words[full_words]) & mask;
        head + tail.count_ones()
    }

    /// Estimates the squared distance `|x - q|^2` from two sign codes.
    ///
    /// The angle is estimated as `pi * (1 - agreement / dim)`, where
    /// `agreement` is [`Self::masked_xnor_popcount`]. The distance then follows
    /// from the law of cosines: `|x|^2 + |q|^2 - 2 |x| |q| cos(angle)`, using
    /// the stored norms. NOTE(review): this angle estimate is only meaningful if
    /// the upstream rotation makes the sign bits behave like random-hyperplane
    /// projections — confirm.
    ///
    /// With `dim == 0` there are no sign bits and no angular information. The
    /// cosine is then taken as 0, giving `|x|^2 + |q|^2`, instead of the
    /// `0 / 0` NaN the plain formula would produce.
    #[inline]
    pub fn estimated_sq_distance(&self, query_code: &Self) -> f32 {
        use std::f32::consts::PI;
        let q_sq = query_code.norm * query_code.norm;
        let self_sq = self.norm * self.norm;
        if self.dim == 0 {
            return q_sq + self_sq;
        }
        let d = self.dim as f32;
        let agreement = self.masked_xnor_popcount(query_code) as f32;
        let est_cos = (PI * (1.0 - agreement / d)).cos();
        let est_ip = self.norm * query_code.norm * est_cos;
        q_sq + self_sq - 2.0 * est_ip
    }

    /// Estimates `|x - q|^2` from this sign code and an unquantized query.
    ///
    /// `q_rotated_unit` is the query coordinates. Per the parameter name it is
    /// presumably rotated like the data and scaled to unit length — NOTE(review):
    /// confirm callers normalize. `q_norm` is the query norm. The unit inner
    /// product is estimated as `sum_i s_i * q_i / sqrt(dim)`, where `s_i = +1`
    /// for a set bit and `-1` otherwise.
    ///
    /// NOTE(review): unlike RaBitQ-style estimators, the result is not divided
    /// by `<x_bar, x_unit>`, which is not stored in the code. The inner product
    /// is therefore likely under-estimated — confirm this is intended.
    ///
    /// With `dim == 0` the inner-product term is taken as 0, giving
    /// `q_norm^2 + |x|^2`, instead of the `0 * inf` NaN of the plain formula.
    ///
    /// # Panics
    /// When `dim > 0`, panics if `q_rotated_unit` is longer than
    /// `64 * words.len()`. Debug builds assert `q_rotated_unit.len() == dim`.
    #[inline]
    pub fn estimated_sq_distance_asymmetric(&self, q_rotated_unit: &[f32], q_norm: f32) -> f32 {
        debug_assert_eq!(q_rotated_unit.len(), self.dim);
        let norms_sq = q_norm * q_norm + self.norm * self.norm;
        if self.dim == 0 {
            return norms_sq;
        }
        let inv_sqrt_d = 1.0 / (self.dim as f32).sqrt();
        let mut ip = 0.0f32;
        for (i, &q_i) in q_rotated_unit.iter().enumerate() {
            let bit_set = (self.words[i / 64] >> (63 - (i % 64))) & 1 == 1;
            ip += if bit_set { q_i } else { -q_i };
        }
        let unit_ip = ip * inv_sqrt_d;
        let est_ip = q_norm * self.norm * unit_ip;
        norms_sq - 2.0 * est_ip
    }
}
/// Packs booleans into 64-bit words, most-significant bit first.
///
/// `bits[i]` becomes bit `63 - i % 64` of word `i / 64`. The output holds
/// `ceil(bits.len() / 64)` words, and unused low bits of the last word are zero.
pub fn pack_bits(bits: &[bool]) -> Vec<u64> {
    bits.chunks(64)
        .map(|group| {
            group
                .iter()
                .enumerate()
                .fold(0u64, |acc, (pos, &set)| {
                    if set {
                        acc | (1u64 << (63 - pos))
                    } else {
                        acc
                    }
                })
        })
        .collect()
}
/// Expands the first `dim` MSB-first bits of `words` back into booleans.
///
/// This is the inverse of `pack_bits` for the same length. Entry `i` is bit
/// `63 - i % 64` of `words[i / 64]`.
///
/// # Panics
/// Panics if `words` holds fewer than `ceil(dim / 64)` words.
pub fn unpack_bits(words: &[u64], dim: usize) -> Vec<bool> {
    let mut flags = Vec::with_capacity(dim);
    for idx in 0..dim {
        let word = words[idx / 64];
        let shift = 63 - idx % 64;
        flags.push((word >> shift) & 1 == 1);
    }
    flags
}
// Unit tests for the bit-packing helpers, the two XNOR popcounts and the
// symmetric / asymmetric squared-distance estimators in the parent module.
#[cfg(test)]
mod tests {
use super::*;
// 130 bits span three words, so a partially filled last word is exercised.
#[test]
fn pack_unpack_roundtrip() {
let bits: Vec<bool> = (0..130).map(|i| i % 3 == 0).collect();
let words = pack_bits(&bits);
let unpacked = unpack_bits(&words, 130);
assert_eq!(bits, unpacked);
}
// Exactly one full word (dim = 64); every bit agrees with itself.
#[test]
fn xnor_self_is_all_ones() {
let v: Vec<f32> = (0..64)
.map(|i| if i % 2 == 0 { 1.0 } else { -1.0 })
.collect();
let code = BinaryCode::encode(&v, 1.0);
let agreement = code.xnor_popcount(&code);
assert_eq!(
agreement, 64,
"self-agreement should be D=64, got {agreement}"
);
}
// `v` contains no zeros, so negating it flips every sign bit and nothing
// agrees (no padding at dim = 64).
#[test]
fn xnor_opposite_is_zero() {
let v: Vec<f32> = (0..64)
.map(|i| if i % 2 == 0 { 1.0 } else { -1.0 })
.collect();
let neg_v: Vec<f32> = v.iter().map(|&x| -x).collect();
let code = BinaryCode::encode(&v, 1.0);
let code_neg = BinaryCode::encode(&neg_v, 1.0);
let agreement = code.xnor_popcount(&code_neg);
assert_eq!(agreement, 0, "opposite vectors should have 0 agreement");
}
// dim = 100 leaves 28 zero padding bits in the second word: the raw popcount
// counts them as matches (pinned here on purpose), the masked one must not.
#[test]
fn masked_popcount_handles_non_aligned_dim() {
let v: Vec<f32> = (0..100)
.map(|i| if i % 2 == 0 { 1.0 } else { -1.0 })
.collect();
let neg_v: Vec<f32> = v.iter().map(|&x| -x).collect();
let code = BinaryCode::encode(&v, 1.0);
let code_neg = BinaryCode::encode(&neg_v, 1.0);
let raw = code.xnor_popcount(&code_neg);
assert_eq!(
raw, 28,
"raw xnor should count padding as matches (bug demo)"
);
let masked = code.masked_xnor_popcount(&code_neg);
assert_eq!(
masked, 0,
"masked xnor must ignore padding bits; got {masked}"
);
let self_masked = code.masked_xnor_popcount(&code);
assert_eq!(self_masked, 100);
}
// dim = 128 has no padding bits, so both popcounts must agree.
#[test]
fn masked_popcount_matches_raw_when_aligned() {
let v: Vec<f32> = (0..128).map(|i| (i as f32 * 0.1).sin()).collect();
let w: Vec<f32> = (0..128).map(|i| (i as f32 * 0.1).cos()).collect();
let ca = BinaryCode::encode(&v, 1.0);
let cb = BinaryCode::encode(&w, 1.0);
assert_eq!(ca.xnor_popcount(&cb), ca.masked_xnor_popcount(&cb));
}
// sin(i / 128) is non-negative for i in 0..128, so every bit is set; comparing
// a code with itself gives agreement = dim, cos = 1 and 1 + 1 - 2 = 0.
#[test]
fn estimated_distance_self_is_near_zero() {
let v: Vec<f32> = (0..128).map(|i| (i as f32 / 128.0).sin()).collect();
let norm: f32 = v.iter().map(|&x| x * x).sum::<f32>().sqrt();
let unit: Vec<f32> = v.iter().map(|&x| x / norm).collect();
let code = BinaryCode::encode(&unit, 1.0);
let est = code.estimated_sq_distance(&code);
assert!(
est.abs() < 1e-5,
"self sq-distance estimate too large: {est}"
);
}
// Despite the name, this checks that BOTH estimators rank a slightly perturbed
// copy of `q` closer than an independent random vector; it does not compare
// the two estimators' values with each other. Vectors are encoded without a
// rotation. Each code stores its pre-normalization norm.
// The fixed seed makes the RNG draws (and thus the outcome) deterministic; the
// order of the `rng` calls below matters.
// NOTE(review): `Rng::gen` is renamed `random` in rand 0.9 — confirm the
// pinned rand version.
#[test]
fn asymmetric_matches_symmetric_in_sign() {
use rand::{Rng as _, SeedableRng as _};
let mut rng = rand::rngs::StdRng::seed_from_u64(11);
let d = 128;
let q: Vec<f32> = (0..d).map(|_| rng.gen::<f32>() * 2.0 - 1.0).collect();
let q_norm: f32 = q.iter().map(|&x| x * x).sum::<f32>().sqrt();
let q_unit: Vec<f32> = q.iter().map(|&x| x / q_norm).collect();
let qc = BinaryCode::encode(&q_unit, q_norm);
// Each coordinate is shifted by a value in [0, 0.1).
let near: Vec<f32> = q.iter().map(|&x| x + rng.gen::<f32>() * 0.1).collect();
let far: Vec<f32> = (0..d).map(|_| rng.gen::<f32>() * 2.0 - 1.0).collect();
// Normalize, encode the unit vector, and keep the original norm.
let encode_one = |v: &[f32]| {
let n: f32 = v.iter().map(|&x| x * x).sum::<f32>().sqrt();
let u: Vec<f32> = v.iter().map(|&x| x / n).collect();
BinaryCode::encode(&u, n)
};
let cn = encode_one(&near);
let cf = encode_one(&far);
let s_near = cn.estimated_sq_distance(&qc);
let s_far = cf.estimated_sq_distance(&qc);
let a_near = cn.estimated_sq_distance_asymmetric(&q_unit, q_norm);
let a_far = cf.estimated_sq_distance_asymmetric(&q_unit, q_norm);
assert!(s_near < s_far, "symmetric ordering wrong");
assert!(a_near < a_far, "asymmetric ordering wrong");
}
}