use std::arch::x86_64::*;
#[target_feature(enable = "avx")]
fn hsum256_ps(v: __m256) -> f32 {
let lo = _mm256_castps256_ps128(v);
let hi = _mm256_extractf128_ps(v, 1);
let sum = _mm_add_ps(lo, hi);
let shuf = _mm_movehdup_ps(sum);
let sums = _mm_add_ps(sum, shuf);
let shuf2 = _mm_movehl_ps(shuf, sums);
let sums2 = _mm_add_ss(sums, shuf2);
_mm_cvtss_f32(sums2)
}
#[target_feature(enable = "avx2")]
fn hsum256_epi32(v: __m256i) -> i32 {
let lo = _mm256_castsi256_si128(v);
let hi = _mm256_extracti128_si256(v, 1);
let sum = _mm_add_epi32(lo, hi);
let hi64 = _mm_unpackhi_epi64(sum, sum);
let sum64 = _mm_add_epi32(sum, hi64);
let hi32 = _mm_shuffle_epi32(sum64, 0b01);
let sum32 = _mm_add_epi32(sum64, hi32);
_mm_cvtsi128_si32(sum32)
}
#[target_feature(enable = "avx,fma")]
pub(crate) unsafe fn dot_f32(a: &[f32], b: &[f32]) -> f32 {
let n = a.len();
let mut i = 0usize;
let mut acc = _mm256_setzero_ps();
while i + 8 <= n {
unsafe {
let va = _mm256_loadu_ps(a.as_ptr().add(i));
let vb = _mm256_loadu_ps(b.as_ptr().add(i));
acc = _mm256_fmadd_ps(va, vb, acc);
}
i += 8;
}
let mut sum = hsum256_ps(acc);
while i < n {
sum += a[i] * b[i];
i += 1;
}
sum
}
#[target_feature(enable = "avx,fma")]
pub(crate) unsafe fn l2_sq_f32(a: &[f32], b: &[f32]) -> f32 {
let n = a.len();
let mut i = 0usize;
let mut acc = _mm256_setzero_ps();
while i + 8 <= n {
unsafe {
let va = _mm256_loadu_ps(a.as_ptr().add(i));
let vb = _mm256_loadu_ps(b.as_ptr().add(i));
let d = _mm256_sub_ps(va, vb);
acc = _mm256_fmadd_ps(d, d, acc);
}
i += 8;
}
let mut sum = hsum256_ps(acc);
while i < n {
let d = a[i] - b[i];
sum += d * d;
i += 1;
}
sum
}
#[target_feature(enable = "avx,fma")]
pub(crate) unsafe fn cosine_f32(a: &[f32], b: &[f32]) -> f32 {
let n = a.len();
let mut i = 0usize;
let mut dot = _mm256_setzero_ps();
let mut na = _mm256_setzero_ps();
let mut nb = _mm256_setzero_ps();
while i + 8 <= n {
unsafe {
let va = _mm256_loadu_ps(a.as_ptr().add(i));
let vb = _mm256_loadu_ps(b.as_ptr().add(i));
dot = _mm256_fmadd_ps(va, vb, dot);
na = _mm256_fmadd_ps(va, va, na);
nb = _mm256_fmadd_ps(vb, vb, nb);
}
i += 8;
}
let mut d = hsum256_ps(dot);
let mut sa = hsum256_ps(na);
let mut sb = hsum256_ps(nb);
while i < n {
d += a[i] * b[i];
sa += a[i] * a[i];
sb += b[i] * b[i];
i += 1;
}
let denom = sa.sqrt() * sb.sqrt();
if denom > 0.0 { d / denom } else { 0.0 }
}
#[target_feature(enable = "avx2")]
pub(crate) unsafe fn dot_i8(a: &[i8], b: &[i8]) -> i32 {
let n = a.len();
let mut i = 0usize;
let mut acc = _mm256_setzero_si256();
while i + 16 <= n {
unsafe {
let va = _mm_loadu_si128(a.as_ptr().add(i).cast());
let vb = _mm_loadu_si128(b.as_ptr().add(i).cast());
let va16 = _mm256_cvtepi8_epi16(va);
let vb16 = _mm256_cvtepi8_epi16(vb);
let prod = _mm256_madd_epi16(va16, vb16);
acc = _mm256_add_epi32(acc, prod);
}
i += 16;
}
let mut sum = hsum256_epi32(acc);
while i < n {
sum += i32::from(a[i]) * i32::from(b[i]);
i += 1;
}
sum
}
#[target_feature(enable = "avx2")]
fn popcnt256(v: __m256i) -> __m256i {
let lookup = _mm256_setr_epi8(
0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4, 0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3,
3, 4,
);
let low_mask = _mm256_set1_epi8(0x0f);
let lo = _mm256_and_si256(v, low_mask);
let hi = _mm256_and_si256(_mm256_srli_epi16(v, 4), low_mask);
let popcnt_lo = _mm256_shuffle_epi8(lookup, lo);
let popcnt_hi = _mm256_shuffle_epi8(lookup, hi);
let bytes = _mm256_add_epi8(popcnt_lo, popcnt_hi);
_mm256_sad_epu8(bytes, _mm256_setzero_si256())
}
#[target_feature(enable = "avx2")]
pub(crate) unsafe fn hamming_u64(a: &[u64], b: &[u64]) -> u32 {
let n = a.len();
let mut i = 0usize;
let mut acc = _mm256_setzero_si256();
while i + 4 <= n {
unsafe {
let va = _mm256_loadu_si256(a.as_ptr().add(i).cast());
let vb = _mm256_loadu_si256(b.as_ptr().add(i).cast());
let x = _mm256_xor_si256(va, vb);
acc = _mm256_add_epi64(acc, popcnt256(x));
}
i += 4;
}
let lo = _mm256_castsi256_si128(acc);
let hi = _mm256_extracti128_si256(acc, 1);
let sum128 = _mm_add_epi64(lo, hi);
let hi64 = _mm_unpackhi_epi64(sum128, sum128);
let total = _mm_add_epi64(sum128, hi64);
#[allow(clippy::cast_sign_loss)]
let mut sum = _mm_cvtsi128_si64(total) as u64 as u32;
while i < n {
sum += (a[i] ^ b[i]).count_ones();
i += 1;
}
sum
}
#[target_feature(enable = "avx2")]
pub(crate) unsafe fn l2_sq_i8(a: &[i8], b: &[i8]) -> i32 {
let n = a.len();
let mut i = 0usize;
let mut acc = _mm256_setzero_si256();
while i + 16 <= n {
unsafe {
let va = _mm_loadu_si128(a.as_ptr().add(i).cast());
let vb = _mm_loadu_si128(b.as_ptr().add(i).cast());
let va16 = _mm256_cvtepi8_epi16(va);
let vb16 = _mm256_cvtepi8_epi16(vb);
let diff = _mm256_sub_epi16(va16, vb16);
let sq = _mm256_madd_epi16(diff, diff);
acc = _mm256_add_epi32(acc, sq);
}
i += 16;
}
let mut sum = hsum256_epi32(acc);
while i < n {
let d = i32::from(a[i]) - i32::from(b[i]);
sum += d * d;
i += 1;
}
sum
}