use std::arch::x86_64::{
__m256, _mm256_add_ps, _mm256_andnot_ps, _mm256_castps_si256, _mm256_castsi256_ps,
_mm256_cmpeq_epi32, _mm256_loadu_ps, _mm256_movemask_ps, _mm256_mul_ps, _mm256_set1_ps,
_mm256_setzero_ps, _mm256_storeu_ps, _mm256_sub_ps,
};
/// Number of `f32` lanes held by one 256-bit `__m256` register.
const LANES: usize = std::mem::size_of::<__m256>() / std::mem::size_of::<f32>();
/// Reduces the `f32` lanes of `vec` to a single scalar.
///
/// The register is spilled to a stack array and the lanes are added in
/// ascending lane order.
#[inline]
#[target_feature(enable = "avx2")]
fn horizontal_sum(vec: __m256) -> f32 {
    let mut lanes = [0.0_f32; LANES];
    // SAFETY: `lanes` is a live, writable array of `LANES` f32 values, which is
    // exactly the 256 bits written by the unaligned store.
    unsafe { _mm256_storeu_ps(lanes.as_mut_ptr(), vec) };
    let total: f32 = lanes.iter().copied().sum();
    total
}
/// Computes the dot product of `a` and `b` with 256-bit AVX arithmetic.
///
/// Full groups of `LANES` elements are multiplied and accumulated in a vector
/// register; the leftover elements are folded in with scalar arithmetic.
/// If the slices differ in length, only their common prefix
/// (`min(a.len(), b.len())` elements) contributes to the result.
///
/// # Safety
///
/// The caller must ensure the running CPU supports AVX2, for example by
/// checking `is_x86_feature_detected!("avx2")` before calling.
#[target_feature(enable = "avx2")]
pub(crate) unsafe fn dot(a: &[f32], b: &[f32]) -> f32 {
    // Clamp both inputs to the shared prefix. Without this, slices of different
    // lengths produce remainders that start at different offsets, and the
    // scalar tail would pair elements with mismatched indices.
    let len = a.len().min(b.len());
    let (a, b) = (&a[..len], &b[..len]);

    let mut acc = _mm256_setzero_ps();
    let chunks_a = a.chunks_exact(LANES);
    let chunks_b = b.chunks_exact(LANES);
    let tail_a = chunks_a.remainder();
    let tail_b = chunks_b.remainder();
    for (ca, cb) in chunks_a.zip(chunks_b) {
        // SAFETY: `chunks_exact(LANES)` yields slices of exactly `LANES` f32
        // values, so each unaligned 256-bit load stays in bounds.
        let va = unsafe { _mm256_loadu_ps(ca.as_ptr()) };
        // SAFETY: as above, `cb` is exactly `LANES` f32 values long.
        let vb = unsafe { _mm256_loadu_ps(cb.as_ptr()) };
        acc = _mm256_add_ps(acc, _mm256_mul_ps(va, vb));
    }
    let mut sum = horizontal_sum(acc);
    for (x, y) in tail_a.iter().zip(tail_b.iter()) {
        sum += x * y;
    }
    sum
}
/// Computes the cosine distance `1 - (a . b) / (|a| * |b|)` between `a` and `b`.
///
/// The dot product and both squared norms are accumulated in a single pass:
/// full groups of `LANES` elements in vector registers, leftovers with scalar
/// arithmetic. If the slices differ in length, only their common prefix
/// (`min(a.len(), b.len())` elements) is considered.
///
/// Returns `1.0` when the product of the norms is at or below
/// `f32::MIN_POSITIVE` (for example when either input is all zeros or empty).
///
/// # Safety
///
/// The caller must ensure the running CPU supports AVX2, for example by
/// checking `is_x86_feature_detected!("avx2")` before calling.
#[target_feature(enable = "avx2")]
pub(crate) unsafe fn cosine(a: &[f32], b: &[f32]) -> f32 {
    // Clamp both inputs to the shared prefix. Without this, slices of different
    // lengths produce remainders that start at different offsets, and the
    // scalar tail would pair elements with mismatched indices.
    let len = a.len().min(b.len());
    let (a, b) = (&a[..len], &b[..len]);

    let mut dot_acc = _mm256_setzero_ps();
    let mut na_acc = _mm256_setzero_ps();
    let mut nb_acc = _mm256_setzero_ps();
    let chunks_a = a.chunks_exact(LANES);
    let chunks_b = b.chunks_exact(LANES);
    let tail_a = chunks_a.remainder();
    let tail_b = chunks_b.remainder();
    for (ca, cb) in chunks_a.zip(chunks_b) {
        // SAFETY: `chunks_exact(LANES)` yields slices of exactly `LANES` f32
        // values, so each unaligned 256-bit load stays in bounds.
        let va = unsafe { _mm256_loadu_ps(ca.as_ptr()) };
        // SAFETY: as above, `cb` is exactly `LANES` f32 values long.
        let vb = unsafe { _mm256_loadu_ps(cb.as_ptr()) };
        dot_acc = _mm256_add_ps(dot_acc, _mm256_mul_ps(va, vb));
        na_acc = _mm256_add_ps(na_acc, _mm256_mul_ps(va, va));
        nb_acc = _mm256_add_ps(nb_acc, _mm256_mul_ps(vb, vb));
    }
    let mut dot = horizontal_sum(dot_acc);
    let mut na = horizontal_sum(na_acc);
    let mut nb = horizontal_sum(nb_acc);
    for (x, y) in tail_a.iter().zip(tail_b.iter()) {
        dot += x * y;
        na += x * x;
        nb += y * y;
    }
    let denom = na.sqrt() * nb.sqrt();
    // A vanishing norm product makes the ratio undefined; report a fixed
    // distance of 1.0 instead of dividing by (almost) zero.
    if denom <= f32::MIN_POSITIVE {
        return 1.0;
    }
    1.0 - dot / denom
}
/// Computes the Euclidean (L2) distance between `a` and `b`.
///
/// Squared differences are accumulated for full groups of `LANES` elements in
/// a vector register and for the leftovers with scalar arithmetic; the square
/// root of the total is returned. If the slices differ in length, only their
/// common prefix (`min(a.len(), b.len())` elements) is considered.
///
/// # Safety
///
/// The caller must ensure the running CPU supports AVX2, for example by
/// checking `is_x86_feature_detected!("avx2")` before calling.
#[target_feature(enable = "avx2")]
pub(crate) unsafe fn euclidean(a: &[f32], b: &[f32]) -> f32 {
    // Clamp both inputs to the shared prefix. Without this, slices of different
    // lengths produce remainders that start at different offsets, and the
    // scalar tail would pair elements with mismatched indices.
    let len = a.len().min(b.len());
    let (a, b) = (&a[..len], &b[..len]);

    let mut acc = _mm256_setzero_ps();
    let chunks_a = a.chunks_exact(LANES);
    let chunks_b = b.chunks_exact(LANES);
    let tail_a = chunks_a.remainder();
    let tail_b = chunks_b.remainder();
    for (ca, cb) in chunks_a.zip(chunks_b) {
        // SAFETY: `chunks_exact(LANES)` yields slices of exactly `LANES` f32
        // values, so each unaligned 256-bit load stays in bounds.
        let va = unsafe { _mm256_loadu_ps(ca.as_ptr()) };
        // SAFETY: as above, `cb` is exactly `LANES` f32 values long.
        let vb = unsafe { _mm256_loadu_ps(cb.as_ptr()) };
        let diff = _mm256_sub_ps(va, vb);
        acc = _mm256_add_ps(acc, _mm256_mul_ps(diff, diff));
    }
    let mut sum = horizontal_sum(acc);
    for (x, y) in tail_a.iter().zip(tail_b.iter()) {
        let d = x - y;
        sum += d * d;
    }
    sum.sqrt()
}
/// Computes the Manhattan (L1) distance between `a` and `b`.
///
/// Absolute differences are accumulated for full groups of `LANES` elements in
/// a vector register and for the leftovers with scalar arithmetic. If the
/// slices differ in length, only their common prefix
/// (`min(a.len(), b.len())` elements) is considered.
///
/// # Safety
///
/// The caller must ensure the running CPU supports AVX2, for example by
/// checking `is_x86_feature_detected!("avx2")` before calling.
#[target_feature(enable = "avx2")]
pub(crate) unsafe fn manhattan(a: &[f32], b: &[f32]) -> f32 {
    // Clamp both inputs to the shared prefix. Without this, slices of different
    // lengths produce remainders that start at different offsets, and the
    // scalar tail would pair elements with mismatched indices.
    let len = a.len().min(b.len());
    let (a, b) = (&a[..len], &b[..len]);

    // `-0.0` has only the sign bit set; and-not with it clears the sign bit of
    // every lane, which is a branch-free absolute value.
    let sign_mask = _mm256_set1_ps(-0.0_f32);
    let mut acc = _mm256_setzero_ps();
    let chunks_a = a.chunks_exact(LANES);
    let chunks_b = b.chunks_exact(LANES);
    let tail_a = chunks_a.remainder();
    let tail_b = chunks_b.remainder();
    for (ca, cb) in chunks_a.zip(chunks_b) {
        // SAFETY: `chunks_exact(LANES)` yields slices of exactly `LANES` f32
        // values, so each unaligned 256-bit load stays in bounds.
        let va = unsafe { _mm256_loadu_ps(ca.as_ptr()) };
        // SAFETY: as above, `cb` is exactly `LANES` f32 values long.
        let vb = unsafe { _mm256_loadu_ps(cb.as_ptr()) };
        let diff = _mm256_sub_ps(va, vb);
        let abs = _mm256_andnot_ps(sign_mask, diff);
        acc = _mm256_add_ps(acc, abs);
    }
    let mut sum = horizontal_sum(acc);
    for (x, y) in tail_a.iter().zip(tail_b.iter()) {
        sum += (x - y).abs();
    }
    sum
}
/// Counts the positions at which `a` and `b` differ, returned as an `f32`.
///
/// Elements are compared by bit pattern, not by floating-point equality:
/// `0.0` and `-0.0` count as different, and two NaNs with identical bits count
/// as equal. Full groups of `LANES` elements are compared with a 32-bit
/// integer vector compare; leftovers are compared via `to_bits`. If the slices
/// differ in length, only their common prefix (`min(a.len(), b.len())`
/// elements) is considered.
///
/// # Safety
///
/// The caller must ensure the running CPU supports AVX2, for example by
/// checking `is_x86_feature_detected!("avx2")` before calling.
#[target_feature(enable = "avx2")]
pub(crate) unsafe fn hamming(a: &[f32], b: &[f32]) -> f32 {
    // Clamp both inputs to the shared prefix. Without this, slices of different
    // lengths produce remainders that start at different offsets, and the
    // scalar tail would pair elements with mismatched indices.
    let len = a.len().min(b.len());
    let (a, b) = (&a[..len], &b[..len]);

    let mut diff_count: u64 = 0;
    let chunks_a = a.chunks_exact(LANES);
    let chunks_b = b.chunks_exact(LANES);
    let tail_a = chunks_a.remainder();
    let tail_b = chunks_b.remainder();
    for (ca, cb) in chunks_a.zip(chunks_b) {
        // SAFETY: `chunks_exact(LANES)` yields slices of exactly `LANES` f32
        // values, so each unaligned 256-bit load stays in bounds.
        let va = unsafe { _mm256_loadu_ps(ca.as_ptr()) };
        // SAFETY: as above, `cb` is exactly `LANES` f32 values long.
        let vb = unsafe { _mm256_loadu_ps(cb.as_ptr()) };
        // Reinterpret the lanes as 32-bit integers so the compare is bitwise.
        let ia = _mm256_castps_si256(va);
        let ib = _mm256_castps_si256(vb);
        let eq = _mm256_cmpeq_epi32(ia, ib);
        // Equal lanes are all-ones; movemask gathers each lane's top bit into
        // one bit per lane, so the popcount is the number of equal lanes.
        let eq_mask = _mm256_movemask_ps(_mm256_castsi256_ps(eq)) as u32;
        let equal = eq_mask.count_ones() as u64;
        diff_count += LANES as u64 - equal;
    }
    for (x, y) in tail_a.iter().zip(tail_b.iter()) {
        if x.to_bits() != y.to_bits() {
            diff_count += 1;
        }
    }
    diff_count as f32
}