use crate::VectorType;
use bytemuck::{Pod, Zeroable};
/// One-bit-per-dimension sign signature of a vector.
///
/// Stores 32 × 64 = 2048 bits. `#[repr(C)]` together with the `bytemuck` `Pod`/`Zeroable`
/// derives lets the value be reinterpreted as plain bytes via `bytemuck`.
#[repr(C)]
#[derive(Copy, Clone, Debug, PartialEq, Eq, Pod, Zeroable, Default)]
pub struct BqSignature {
/// Packed bits: bit `k % 64` of word `k / 64` belongs to dimension `k`
/// (as laid out by `from_vector`).
pub data: [u64; 32],
}
impl BqSignature {
    /// Returns a signature with every bit cleared.
    pub fn empty() -> Self {
        Self { data: [0u64; 32] }
    }

    /// Encodes the sign of each component of `vec` as a single bit.
    ///
    /// Bit `k % 64` of word `k / 64` is set exactly when `vec[k].to_f32() > 0.0`.
    /// Zero, negative and NaN components leave their bit cleared. Only the first
    /// 32 * 64 = 2048 components are encoded; any further components are ignored.
    pub fn from_vector<T: VectorType>(vec: &[T]) -> Self {
        const CAPACITY: usize = 32 * 64;
        let mut signature = Self::empty();
        for (k, component) in vec.iter().take(CAPACITY).enumerate() {
            if component.to_f32() > 0.0 {
                signature.data[k / 64] |= 1u64 << (k % 64);
            }
        }
        signature
    }

    /// Number of bit positions at which `self` and `other` differ (0..=2048).
    #[inline]
    pub fn hamming_distance(&self, other: &Self) -> u32 {
        let mut differing = 0u32;
        for (lhs, rhs) in self.data.iter().zip(&other.data) {
            differing += (lhs ^ rhs).count_ones();
        }
        differing
    }
}
/// Two-bit-per-dimension signature: a sign plane and a magnitude ("strong") plane.
///
/// Each plane stores 32 × 64 = 2048 bits, with bit `k % 64` of word `k / 64` belonging to
/// dimension `k`. `#[repr(C)]` together with the `bytemuck` `Pod`/`Zeroable` derives lets
/// the value be reinterpreted as plain bytes via `bytemuck`.
#[repr(C)]
#[derive(Copy, Clone, Debug, PartialEq, Eq, Pod, Zeroable, Default)]
pub struct Bq2Signature {
/// Sign plane; `from_vector` sets a bit when the component is strictly positive.
pub pos: [u64; 32],
/// Magnitude plane; `from_vector` sets a bit when the component's absolute value is
/// strictly greater than the mean absolute value of the input vector.
pub strong: [u64; 32],
}
impl Bq2Signature {
    /// Number of dimensions a signature can represent (32 words of 64 bits).
    const MAX_DIM: usize = 32 * 64;

    /// Returns a signature with every `pos` and `strong` bit cleared.
    pub fn empty() -> Self {
        Self::default()
    }

    /// Builds a two-bit-per-dimension signature from `vec`.
    ///
    /// For each of the first 2048 components `v`:
    /// - the `pos` bit is set when `v > 0.0`;
    /// - the `strong` bit is set when `|v| > alpha`. Here `alpha` is the mean absolute
    ///   value of *all* of `vec`; components past 2048 are not encoded but still
    ///   contribute to `alpha`.
    ///
    /// An empty `vec` yields an all-zero signature. If any component is NaN, `alpha` is
    /// NaN and no `strong` bit is set.
    pub fn from_vector<T: crate::VectorType>(vec: &[T]) -> Self {
        let mut pos = [0u64; 32];
        let mut strong = [0u64; 32];
        let mut sum_abs = 0.0f32;
        for v in vec {
            sum_abs += v.to_f32().abs();
        }
        let alpha = if vec.is_empty() {
            0.0
        } else {
            sum_abs / vec.len() as f32
        };
        for (idx, v) in vec.iter().take(Self::MAX_DIM).enumerate() {
            let val = v.to_f32();
            let bit = 1u64 << (idx % 64);
            if val > 0.0 {
                pos[idx / 64] |= bit;
            }
            if val.abs() > alpha {
                strong[idx / 64] |= bit;
            }
        }
        Self { pos, strong }
    }

    /// Weighted disagreement between two signatures over their first `dim` dimensions.
    ///
    /// **Per-dimension weight:**
    /// - 4 when both sides are strong;
    /// - 2 when exactly one side is strong;
    /// - 1 when neither side is strong.
    ///
    /// The weight is added to a running `dot` when the `pos` bits agree and subtracted
    /// when they differ. The result is `4 * dim - dot`, which lies in `0..=8 * dim`. It is
    /// 0 only when every dimension agrees in sign and is strong on both sides.
    ///
    /// `dim` is clamped to the 2048-dimension capacity. Bits at positions `>= dim` are
    /// ignored.
    ///
    /// Uses an AVX-512 kernel when the CPU supports `avx512f` and `avx512vpopcntdq` at
    /// runtime, and a portable scalar loop otherwise. Both produce the same value.
    #[inline]
    pub fn distance(&self, other: &Self, dim: usize) -> u32 {
        #[cfg(target_arch = "x86_64")]
        {
            if is_x86_feature_detected!("avx512vpopcntdq") && is_x86_feature_detected!("avx512f") {
                // SAFETY: the required CPU features were detected at runtime just above.
                // `distance_avx512` clamps `dim` itself, so its loads stay in bounds.
                return unsafe { self.distance_avx512(other, dim) };
            }
        }
        self.distance_scalar(other, dim)
    }

    /// Mask selecting the bits of word `i` that lie below `dim`.
    /// Returns all ones for a full word and zero for a word entirely past `dim`.
    #[inline]
    fn chunk_mask(i: usize, dim: usize) -> u64 {
        let valid = dim.saturating_sub(i * 64);
        if valid >= 64 {
            !0u64
        } else {
            (1u64 << valid) - 1
        }
    }

    /// Signed weighted agreement (the `dot` contribution) of word `i`.
    /// Only the bits selected by `mask` are considered.
    #[inline]
    fn chunk_dot(&self, other: &Self, i: usize, mask: u64) -> i64 {
        let a1 = self.pos[i] & mask;
        let b1 = self.strong[i] & mask;
        let a2 = other.pos[i] & mask;
        let b2 = other.strong[i] & mask;
        let same = !(a1 ^ a2) & mask;
        let diff = (a1 ^ a2) & mask;
        let both_strong = b1 & b2;
        let one_strong = b1 ^ b2;
        // Masking is required here: `!(b1 | b2)` sets every bit past `dim`.
        let both_weak = !(b1 | b2) & mask;
        let score = |w: u64| i64::from((w & same).count_ones()) - i64::from((w & diff).count_ones());
        4 * score(both_strong) + 2 * score(one_strong) + score(both_weak)
    }

    /// Converts an accumulated `dot` over `dim` dimensions into the final distance.
    #[inline]
    fn distance_from_dot(dim: usize, dot: i64) -> u32 {
        // `dim <= MAX_DIM` and `|dot| <= 4 * dim`, so the value is in 0..=16384.
        (4 * dim as i64 - dot) as u32
    }

    /// Portable implementation of `distance`.
    #[inline]
    fn distance_scalar(&self, other: &Self, dim: usize) -> u32 {
        let dim = dim.min(Self::MAX_DIM);
        let dot: i64 = (0..dim.div_ceil(64))
            .map(|i| self.chunk_dot(other, i, Self::chunk_mask(i, dim)))
            .sum();
        Self::distance_from_dot(dim, dot)
    }

    /// AVX-512 implementation of `distance`; returns the same value as `distance_scalar`.
    ///
    /// # Safety
    ///
    /// The running CPU must support the `avx512f` and `avx512vpopcntdq` features.
    /// Memory safety does not depend on `dim`: it is clamped to 2048, so every 512-bit
    /// load stays inside the 32-word arrays.
    #[cfg(target_arch = "x86_64")]
    #[target_feature(enable = "avx512f,avx512vpopcntdq")]
    unsafe fn distance_avx512(&self, other: &Self, dim: usize) -> u32 {
        unsafe {
            use std::arch::x86_64::*;
            let dim = dim.min(Self::MAX_DIM);
            let chunks = dim.div_ceil(64);
            // Only words whose 64 bits all lie below `dim` may use the unmasked vector
            // loop. A trailing partial word must go through the masked scalar tail, even
            // when the total word count is a multiple of 8. Otherwise its padding bits
            // would be counted as agreeing weak dimensions.
            let full_rounds = (dim / 64) / 8;
            let mut acc_w4_pos = _mm512_setzero_si512();
            let mut acc_w4_neg = _mm512_setzero_si512();
            let mut acc_w2_pos = _mm512_setzero_si512();
            let mut acc_w2_neg = _mm512_setzero_si512();
            let mut acc_w1_pos = _mm512_setzero_si512();
            let mut acc_w1_neg = _mm512_setzero_si512();
            let ones_mask = _mm512_set1_epi64(-1i64);
            for r in 0..full_rounds {
                let off = r * 8;
                // SAFETY: off + 8 <= full_rounds * 8 <= dim / 64 <= 32, so each unaligned
                // 64-byte load reads only elements of the 32-element arrays.
                let a1 = _mm512_loadu_si512(self.pos.as_ptr().add(off) as *const _);
                let b1 = _mm512_loadu_si512(self.strong.as_ptr().add(off) as *const _);
                let a2 = _mm512_loadu_si512(other.pos.as_ptr().add(off) as *const _);
                let b2 = _mm512_loadu_si512(other.strong.as_ptr().add(off) as *const _);
                let xor_ab = _mm512_xor_si512(a1, a2);
                let same = _mm512_xor_si512(xor_ab, ones_mask);
                let diff = xor_ab;
                let both_strong = _mm512_and_si512(b1, b2);
                let one_strong = _mm512_xor_si512(b1, b2);
                let both_weak = _mm512_xor_si512(_mm512_or_si512(b1, b2), ones_mask);
                acc_w4_pos = _mm512_add_epi64(
                    acc_w4_pos,
                    _mm512_popcnt_epi64(_mm512_and_si512(same, both_strong)),
                );
                acc_w4_neg = _mm512_add_epi64(
                    acc_w4_neg,
                    _mm512_popcnt_epi64(_mm512_and_si512(diff, both_strong)),
                );
                acc_w2_pos = _mm512_add_epi64(
                    acc_w2_pos,
                    _mm512_popcnt_epi64(_mm512_and_si512(same, one_strong)),
                );
                acc_w2_neg = _mm512_add_epi64(
                    acc_w2_neg,
                    _mm512_popcnt_epi64(_mm512_and_si512(diff, one_strong)),
                );
                acc_w1_pos = _mm512_add_epi64(
                    acc_w1_pos,
                    _mm512_popcnt_epi64(_mm512_and_si512(same, both_weak)),
                );
                acc_w1_neg = _mm512_add_epi64(
                    acc_w1_neg,
                    _mm512_popcnt_epi64(_mm512_and_si512(diff, both_weak)),
                );
            }
            let mut dot = 4 * (_mm512_reduce_add_epi64(acc_w4_pos)
                - _mm512_reduce_add_epi64(acc_w4_neg))
                + 2 * (_mm512_reduce_add_epi64(acc_w2_pos) - _mm512_reduce_add_epi64(acc_w2_neg))
                + (_mm512_reduce_add_epi64(acc_w1_pos) - _mm512_reduce_add_epi64(acc_w1_neg));
            // Scalar tail: leftover full words that do not fill a round of 8, plus the
            // masked partial word, if any.
            for i in full_rounds * 8..chunks {
                dot += self.chunk_dot(other, i, Self::chunk_mask(i, dim));
            }
            Self::distance_from_dot(dim, dot)
        }
    }
}