#[cfg(target_arch = "aarch64")]
use std::arch::aarch64::*;
use super::simd_config;
/// A vector quantized to one bit per dimension.
///
/// Normally built with `BinaryVector::from_f32` or
/// `BinaryVector::from_f32_with_threshold`. The fields are public, so a
/// hand-built value is not guaranteed to keep `data.len()` consistent with `dims`.
#[derive(Debug, Clone)]
pub struct BinaryVector {
    /// Packed bits, MSB-first: dimension `i` is bit `7 - i % 8` of byte `i / 8`.
    /// The constructors allocate `dims.div_ceil(8)` bytes and leave unused
    /// trailing bits of the last byte as zero.
    pub data: Vec<u8>,
    /// Number of logical dimensions (bits) represented in `data`; the
    /// constructors set it to the length of the source `f32` slice.
    pub dims: usize,
    /// L2 norm of the source `f32` vector; the constructors compute it over the
    /// finite components only (NaN and infinities are skipped).
    pub norm: f32,
}
impl BinaryVector {
    /// Quantizes `vector` to one bit per dimension using a threshold of `0.0`.
    ///
    /// Equivalent to [`BinaryVector::from_f32_with_threshold`] with `threshold = 0.0`.
    pub fn from_f32(vector: &[f32]) -> Self {
        Self::from_f32_with_threshold(vector, 0.0)
    }

    /// Quantizes `vector` to one bit per dimension: bit `i` is set when
    /// component `i` is `>= threshold`.
    ///
    /// Bits are packed MSB-first (dimension `i` lives in bit `7 - i % 8` of byte
    /// `i / 8`), and unused trailing bits of the last byte stay zero.
    /// Non-finite components (NaN, ±∞) are treated as `0.0`, both for the bit
    /// decision and for `norm`. As a result, `norm` is the L2 norm over the
    /// finite components only.
    pub fn from_f32_with_threshold(vector: &[f32], threshold: f32) -> Self {
        let dims = vector.len();
        let mut data = vec![0u8; dims.div_ceil(8)];
        let mut norm_sq = 0.0f32;
        for (i, &v) in vector.iter().enumerate() {
            // Sanitize once: a non-finite value adds nothing to the norm and is
            // compared against the threshold as if it were zero.
            let val = if v.is_finite() { v } else { 0.0 };
            norm_sq += val * val;
            if val >= threshold {
                data[i / 8] |= 1 << (7 - i % 8);
            }
        }
        Self {
            data,
            dims,
            norm: norm_sq.sqrt(),
        }
    }

    /// Expands the packed bits into a `dims`-length vector of `+1.0` (bit set)
    /// and `-1.0` (bit clear).
    ///
    /// # Panics
    ///
    /// Panics if `data` holds fewer than `dims.div_ceil(8)` bytes. This can only
    /// happen when the public fields were set by hand.
    pub fn to_f32(&self) -> Vec<f32> {
        (0..self.dims)
            .map(|i| {
                let bit = (self.data[i / 8] >> (7 - i % 8)) & 1;
                if bit == 1 {
                    1.0
                } else {
                    -1.0
                }
            })
            .collect()
    }

    /// Number of differing bits between `self` and `other`.
    ///
    /// Delegates to [`hamming_distance_binary`]. That function returns `u32::MAX`
    /// as a "not comparable" sentinel, for example when `dims` differ.
    #[inline]
    pub fn hamming_distance(&self, other: &BinaryVector) -> u32 {
        hamming_distance_binary(self, other)
    }

    /// Approximate cosine distance: `2 * hamming / dims`.
    ///
    /// This is a linear mapping, so identical bit patterns give `0.0` and fully
    /// opposite patterns give `2.0`. Two zero-dimensional vectors give `0.0`.
    /// Vectors that cannot be compared give `f32::INFINITY`, i.e. "farther than
    /// anything", so they sort last in nearest-neighbour rankings. This covers a
    /// `dims` mismatch (including a zero-dim vector vs a non-empty one) and the
    /// `u32::MAX` sentinel from [`hamming_distance_binary`]. Previously that
    /// sentinel was fed into the float formula.
    #[inline]
    pub fn cosine_distance_approx(&self, other: &BinaryVector) -> f32 {
        if self.dims != other.dims {
            return f32::INFINITY;
        }
        if self.dims == 0 {
            return 0.0;
        }
        match self.hamming_distance(other) {
            u32::MAX => f32::INFINITY,
            hamming => 2.0 * hamming as f32 / self.dims as f32,
        }
    }

    /// Approximate cosine similarity, `1.0 - cosine_distance_approx(other)`.
    ///
    /// The result is in `[-1.0, 1.0]` for comparable vectors, and
    /// `f32::NEG_INFINITY` for vectors that cannot be compared.
    #[inline]
    pub fn cosine_similarity_approx(&self, other: &BinaryVector) -> f32 {
        1.0 - self.cosine_distance_approx(other)
    }
}
/// Hamming distance (number of differing bits) between two binary vectors.
///
/// Uses the NEON kernel on aarch64 when `simd_config()` reports it as enabled,
/// and the portable scalar kernel otherwise.
///
/// Returns `u32::MAX` as a "not comparable" sentinel in two cases: the vectors
/// have different `dims`, or their packed `data` buffers differ in length.
/// The second case is possible because the fields are public. The length
/// mismatch is rejected here in release builds too, rather than only by a
/// `debug_assert!`, so neither kernel is ever handed slices of different
/// lengths. The NEON kernel is an `unsafe fn` whose vector loads must stay in
/// bounds.
#[inline]
pub fn hamming_distance_binary(a: &BinaryVector, b: &BinaryVector) -> u32 {
    if a.dims != b.dims || a.data.len() != b.data.len() {
        return u32::MAX;
    }
    let config = simd_config();
    #[cfg(target_arch = "aarch64")]
    {
        if config.neon_enabled {
            // SAFETY: NEON availability comes from `simd_config()`, and the two
            // buffers were checked above to have identical lengths.
            return unsafe { hamming_distance_neon(&a.data, &b.data) };
        }
    }
    #[cfg(not(target_arch = "aarch64"))]
    {
        // `config` is only consulted on aarch64; silence the unused binding elsewhere.
        let _ = config;
    }
    hamming_distance_scalar(&a.data, &b.data)
}
/// Portable Hamming distance between two packed bit buffers.
///
/// Counts the differing bits over the common prefix of `a` and `b`. Callers
/// normally pass equal-length slices. Excess bytes in the longer slice are
/// ignored instead of causing an out-of-bounds panic.
///
/// Bytes are processed eight at a time as `u64` words so that each word needs
/// only one `count_ones`. The byte order used to assemble a word is irrelevant,
/// because XOR and popcount act bit-by-bit.
fn hamming_distance_scalar(a: &[u8], b: &[u8]) -> u32 {
    let len = a.len().min(b.len());
    let (a, b) = (&a[..len], &b[..len]);

    // Assemble one 8-byte chunk into a word; `chunks_exact(8)` guarantees the size.
    let word = |chunk: &[u8]| {
        let mut buf = [0u8; 8];
        buf.copy_from_slice(chunk);
        u64::from_ne_bytes(buf)
    };

    let a_words = a.chunks_exact(8);
    let b_words = b.chunks_exact(8);

    // Both slices now have the same length, so their leftover tails line up.
    let tail: u32 = a_words
        .remainder()
        .iter()
        .zip(b_words.remainder())
        .map(|(&x, &y)| (x ^ y).count_ones())
        .sum();

    let body: u32 = a_words
        .zip(b_words)
        .map(|(wa, wb)| (word(wa) ^ word(wb)).count_ones())
        .sum();

    body + tail
}
/// NEON-accelerated Hamming distance between two packed bit buffers.
///
/// Each 16-byte step works as follows. The two blocks are XORed. `vcntq_u8`
/// counts the set bits per byte. Pairwise widening adds (u8 → u16 → u32 → u64)
/// then fold the counts into two `u64` lanes, which accumulate across
/// iterations. Bytes after the last full 16-byte chunk are counted with scalar
/// `count_ones`. The final total is truncated to `u32`.
///
/// Only the common prefix of `a` and `b` is compared. Mismatched lengths are a
/// caller bug and trip the debug assertion. In release builds they can no
/// longer cause out-of-bounds reads.
///
/// # Safety
///
/// The caller must ensure NEON is available on the executing CPU. The
/// dispatcher checks `simd_config().neon_enabled`. Every vector load reads from
/// a 16-byte chunk produced by `chunks_exact`, so memory accesses are in bounds
/// by construction.
#[cfg(target_arch = "aarch64")]
#[inline]
unsafe fn hamming_distance_neon(a: &[u8], b: &[u8]) -> u32 {
    debug_assert_eq!(
        a.len(),
        b.len(),
        "hamming_distance_neon: slice lengths differ ({} vs {})",
        a.len(),
        b.len()
    );
    const SIMD_WIDTH: usize = 16;

    // Restrict both slices to their common prefix so that, even with the debug
    // assertion compiled out, no load can run past the shorter slice.
    let len = a.len().min(b.len());
    let (a, b) = (&a[..len], &b[..len]);

    let a_chunks = a.chunks_exact(SIMD_WIDTH);
    let b_chunks = b.chunks_exact(SIMD_WIDTH);

    // Scalar tail: the fewer-than-16 bytes left over after the last full vector.
    let tail: u32 = a_chunks
        .remainder()
        .iter()
        .zip(b_chunks.remainder())
        .map(|(&x, &y)| (x ^ y).count_ones())
        .sum();

    let mut sum_u64 = vdupq_n_u64(0);
    for (ca, cb) in a_chunks.zip(b_chunks) {
        // SAFETY: `chunks_exact(16)` yields slices of exactly 16 bytes, so each
        // 128-bit load reads only memory inside `ca` / `cb`.
        let va = vld1q_u8(ca.as_ptr());
        let vb = vld1q_u8(cb.as_ptr());
        // Per-byte popcount of the differing bits (each lane is at most 8).
        let popcnt = vcntq_u8(veorq_u8(va, vb));
        // Widen pairwise so the running per-lane total lives in u64 lanes.
        let sum_u16 = vpaddlq_u8(popcnt);
        let sum_u32 = vpaddlq_u16(sum_u16);
        sum_u64 = vaddq_u64(sum_u64, vpaddlq_u32(sum_u32));
    }

    let simd_total = vgetq_lane_u64(sum_u64, 0) + vgetq_lane_u64(sum_u64, 1);
    simd_total as u32 + tail
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic pseudo-random vector with components in `[-1.0, 1.0]`.
    ///
    /// Uses a simple LCG mixed with `seed` and `dim`, so the tests are
    /// reproducible without an RNG dependency.
    fn generate_vector(dim: usize, seed: u64) -> Vec<f32> {
        let mut state = seed ^ ((dim as u64).wrapping_mul(0x9E37_79B9_7F4A_7C15));
        (0..dim)
            .map(|i| {
                state = state
                    .wrapping_mul(6364136223846793005)
                    .wrapping_add(1442695040888963407)
                    .wrapping_add(i as u64);
                let unit = ((state >> 32) as u32) as f32 / u32::MAX as f32;
                unit * 2.0 - 1.0
            })
            .collect()
    }

    #[test]
    fn test_binary_quantize_basic() {
        let v = vec![0.5, -0.3, 0.0, -1.0, 1.0, 0.1, -0.1, 0.9];
        let bv = BinaryVector::from_f32(&v);
        assert_eq!(bv.data.len(), 1);
        assert_eq!(bv.dims, 8);
        // Signs (>= 0.0) are 1,0,1,0,1,1,0,1, packed MSB-first -> 0b1010_1101.
        assert_eq!(bv.data[0], 0xAD, "packed bits: {:08b}", bv.data[0]);
    }

    #[test]
    fn test_binary_roundtrip() {
        let v = vec![0.5, -0.3, 0.0, -1.0, 1.0, 0.1, -0.1, 0.9];
        let bv = BinaryVector::from_f32(&v);
        let deq = bv.to_f32();
        assert_eq!(deq, vec![1.0, -1.0, 1.0, -1.0, 1.0, 1.0, -1.0, 1.0]);
    }

    #[test]
    fn test_binary_hamming_distance() {
        let v = generate_vector(384, 42);
        let bv = BinaryVector::from_f32(&v);
        assert_eq!(bv.hamming_distance(&bv), 0);
        let neg_v: Vec<f32> = v.iter().map(|x| -x).collect();
        let neg_bv = BinaryVector::from_f32(&neg_v);
        let hamming = bv.hamming_distance(&neg_bv);
        assert!(hamming > 350, "hamming={hamming}, expected close to 384");
    }

    #[test]
    fn test_binary_cosine_approx_identical() {
        let v = generate_vector(384, 55);
        let bv = BinaryVector::from_f32(&v);
        let cos_dist = bv.cosine_distance_approx(&bv);
        assert!(
            cos_dist.abs() < 1e-5,
            "Identical binary vectors should have 0 cosine distance, got {cos_dist}"
        );
    }

    #[test]
    fn test_binary_cosine_approx_quality() {
        let a = generate_vector(384, 101);
        let b = generate_vector(384, 202);
        let dot: f32 = a.iter().zip(b.iter()).map(|(&x, &y)| x * y).sum();
        let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
        let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
        let f32_cos = dot / (norm_a * norm_b);
        let ba = BinaryVector::from_f32(&a);
        let bb = BinaryVector::from_f32(&b);
        let bin_cos = ba.cosine_similarity_approx(&bb);
        assert!(
            (f32_cos - bin_cos).abs() < 0.35,
            "Binary cosine too far from f32: f32={f32_cos}, binary={bin_cos}"
        );
    }

    #[test]
    fn test_binary_cosine_zero_dims() {
        let empty = BinaryVector::from_f32(&[]);
        assert_eq!(empty.data.len(), 0);
        assert_eq!(empty.cosine_distance_approx(&empty), 0.0);
    }

    #[test]
    fn test_binary_memory_savings() {
        let v = generate_vector(384, 999);
        let bv = BinaryVector::from_f32(&v);
        // 384 dims pack into 48 bytes, versus 1536 bytes as f32.
        assert_eq!(bv.data.len(), 48);
    }

    #[test]
    fn test_binary_non_multiple_of_8_dims() {
        let v = generate_vector(385, 77);
        let bv = BinaryVector::from_f32(&v);
        assert_eq!(bv.data.len(), 49);
        assert_eq!(bv.dims, 385);
        let deq = bv.to_f32();
        assert_eq!(deq.len(), 385);
    }

    #[test]
    fn test_binary_with_threshold() {
        let v = vec![0.5, 0.3, 0.1, -0.1, -0.3, -0.5, 0.7, 0.2];
        let bv = BinaryVector::from_f32_with_threshold(&v, 0.25);
        let deq = bv.to_f32();
        assert_eq!(deq, vec![1.0, 1.0, -1.0, -1.0, -1.0, -1.0, 1.0, -1.0]);
    }

    #[test]
    fn test_binary_nan_inf_handling() {
        let v = vec![
            f32::NAN,
            f32::INFINITY,
            f32::NEG_INFINITY,
            1.0,
            -1.0,
            0.0,
            0.5,
            -0.5,
        ];
        let bv = BinaryVector::from_f32(&v);
        let deq = bv.to_f32();
        assert_eq!(deq.len(), 8);
        for &val in &deq {
            assert!(val == 1.0 || val == -1.0, "Binary should produce +/-1.0");
        }
    }

    #[test]
    fn test_hamming_scalar_vs_neon_parity() {
        // Byte lengths 0, 1, 1, 2, 16, 17, 48, 49, 125: these exercise the SIMD
        // main loop alone, the scalar tail alone, and both together, rather
        // than only an exact multiple of the 16-byte NEON width.
        for dims in [0usize, 1, 8, 9, 128, 129, 384, 385, 1000] {
            let a = generate_vector(dims, 111);
            let b = generate_vector(dims, 222);
            let ba = BinaryVector::from_f32(&a);
            let bb = BinaryVector::from_f32(&b);
            let scalar_result = hamming_distance_scalar(&ba.data, &bb.data);
            let dispatch_result = ba.hamming_distance(&bb);
            assert_eq!(
                scalar_result, dispatch_result,
                "Scalar and dispatched Hamming should match (dims={dims})"
            );
            // Independent reference: count positions whose dequantized signs differ.
            let reference = ba
                .to_f32()
                .iter()
                .zip(bb.to_f32().iter())
                .filter(|(x, y)| x != y)
                .count() as u32;
            assert_eq!(
                scalar_result, reference,
                "Scalar Hamming should match bit-level reference (dims={dims})"
            );
        }
    }

    #[test]
    fn test_hamming_mismatched_dims_sentinel() {
        let a = BinaryVector::from_f32(&generate_vector(16, 1));
        let b = BinaryVector::from_f32(&generate_vector(24, 2));
        assert_eq!(a.hamming_distance(&b), u32::MAX);
        assert_eq!(hamming_distance_binary(&b, &a), u32::MAX);
    }
}