#![allow(clippy::cast_precision_loss)]
#![allow(clippy::cast_possible_truncation)]
#![allow(clippy::cast_sign_loss)]
#![allow(clippy::incompatible_msrv)]
#![allow(clippy::wildcard_imports)]
#![allow(clippy::missing_panics_doc)]
#![allow(clippy::similar_names)]
use super::reduction::hsum_avx256;
use super::scalar::cosine_finish_fast;
/// Finishes a fused cosine kernel over whatever an unrolled main loop left behind.
///
/// Remaining full groups of eight `f32` are folded into the incoming
/// accumulators with AVX2/FMA, the accumulators are reduced with
/// `hsum_avx256`, the last (fewer than eight) elements are added one at a
/// time, and the three sums (`a·b`, `a·a`, `b·b`) are passed to
/// `cosine_finish_fast`, whose result is returned.
///
/// # Safety
///
/// - The CPU must support AVX2 and FMA.
/// - `a_ptr` and `end_ptr` must be derived from the same allocation, and every
///   `f32` in `a_ptr..end_ptr` must be readable.
/// - `b_ptr` must be readable for the same number of `f32` elements.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2", enable = "fma")]
#[inline]
unsafe fn cosine_avx2_remainder(
    mut a_ptr: *const f32,
    mut b_ptr: *const f32,
    end_ptr: *const f32,
    mut dot_acc: std::arch::x86_64::__m256,
    mut na_acc: std::arch::x86_64::__m256,
    mut nb_acc: std::arch::x86_64::__m256,
) -> f32 {
    use std::arch::x86_64::*;

    // Count the elements that are left instead of testing
    // `a_ptr.add(8) <= end_ptr`: with fewer than eight elements remaining that
    // expression forms a pointer more than one past the end of the allocation,
    // which `pointer::add` does not allow.
    let mut remaining = end_ptr.offset_from(a_ptr);

    // Full 8-lane groups.
    while remaining >= 8 {
        let va = _mm256_loadu_ps(a_ptr);
        let vb = _mm256_loadu_ps(b_ptr);
        dot_acc = _mm256_fmadd_ps(va, vb, dot_acc);
        na_acc = _mm256_fmadd_ps(va, va, na_acc);
        nb_acc = _mm256_fmadd_ps(vb, vb, nb_acc);
        a_ptr = a_ptr.add(8);
        b_ptr = b_ptr.add(8);
        remaining -= 8;
    }

    let mut dot = hsum_avx256(dot_acc);
    let mut norm_a_sq = hsum_avx256(na_acc);
    let mut norm_b_sq = hsum_avx256(nb_acc);

    // Scalar tail: at most seven elements.
    while remaining > 0 {
        let x = *a_ptr;
        let y = *b_ptr;
        dot += x * y;
        norm_a_sq += x * x;
        norm_b_sq += y * y;
        a_ptr = a_ptr.add(1);
        b_ptr = b_ptr.add(1);
        remaining -= 1;
    }

    cosine_finish_fast(dot, norm_a_sq, norm_b_sq)
}
/// Fused cosine kernel using two independent AVX2/FMA accumulator sets.
///
/// The main loop consumes 16 `f32` per iteration (two 8-lane groups), keeping
/// separate accumulators for `a·b`, `a·a` and `b·b`. The two sets are then
/// added together and handed, with the advanced pointers, to
/// `cosine_avx2_remainder`, which processes the remaining `len % 16` elements
/// and produces the final value.
///
/// The number of elements processed is `a.len()`; anything in `b` beyond that
/// is ignored.
///
/// # Safety
///
/// - The CPU must support AVX2 and FMA.
/// - `b` must contain at least `a.len()` elements; all reads from `b` are
///   unchecked raw-pointer loads.
///
/// # Panics
///
/// In builds with debug assertions enabled, panics if `b` is shorter than `a`.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2", enable = "fma")]
#[inline]
pub(crate) unsafe fn cosine_fused_avx2_2acc(a: &[f32], b: &[f32]) -> f32 {
    use std::arch::x86_64::*;

    // Every load from `b` is driven by `a.len()`, so a shorter `b` would be
    // read out of bounds. Catch that in debug/test builds.
    debug_assert!(b.len() >= a.len(), "`b` must be at least as long as `a`");

    let len = a.len();
    let mut a_ptr = a.as_ptr();
    let mut b_ptr = b.as_ptr();
    // End of the region covered by whole 16-element blocks.
    let end_main = a.as_ptr().add(len / 16 * 16);
    let end_ptr = a.as_ptr().add(len);

    let mut dot0 = _mm256_setzero_ps();
    let mut dot1 = _mm256_setzero_ps();
    let mut na0 = _mm256_setzero_ps();
    let mut na1 = _mm256_setzero_ps();
    let mut nb0 = _mm256_setzero_ps();
    let mut nb1 = _mm256_setzero_ps();

    while a_ptr < end_main {
        // First 8 lanes of the block.
        let va0 = _mm256_loadu_ps(a_ptr);
        let vb0 = _mm256_loadu_ps(b_ptr);
        dot0 = _mm256_fmadd_ps(va0, vb0, dot0);
        na0 = _mm256_fmadd_ps(va0, va0, na0);
        nb0 = _mm256_fmadd_ps(vb0, vb0, nb0);

        // Second 8 lanes of the block.
        let va1 = _mm256_loadu_ps(a_ptr.add(8));
        let vb1 = _mm256_loadu_ps(b_ptr.add(8));
        dot1 = _mm256_fmadd_ps(va1, vb1, dot1);
        na1 = _mm256_fmadd_ps(va1, va1, na1);
        nb1 = _mm256_fmadd_ps(vb1, vb1, nb1);

        a_ptr = a_ptr.add(16);
        b_ptr = b_ptr.add(16);
    }

    let dot_acc = _mm256_add_ps(dot0, dot1);
    let na_acc = _mm256_add_ps(na0, na1);
    let nb_acc = _mm256_add_ps(nb0, nb1);

    cosine_avx2_remainder(a_ptr, b_ptr, end_ptr, dot_acc, na_acc, nb_acc)
}
/// Runs the 4-way unrolled AVX2/FMA accumulation for the fused cosine kernel.
///
/// While `a_ptr < end_main`, each iteration reads 32 `f32` from both inputs as
/// four 8-lane groups. Group `k` feeds its own accumulators for `a·b`, `a·a`
/// and `b·b`, so the four FMA chains are independent of each other.
///
/// Returns `(dot, norm_a_sq, norm_b_sq, a_ptr, b_ptr)`: the three accumulators
/// (each still eight lanes wide, combined as `(acc0 + acc1) + (acc2 + acc3)`)
/// followed by the two pointers advanced past the processed elements.
///
/// # Safety
///
/// - The CPU must support AVX2 and FMA.
/// - The distance from `a_ptr` to `end_main` must be a multiple of 32 elements,
///   and that whole range must be readable through both `a_ptr` and `b_ptr`.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2", enable = "fma")]
#[inline]
unsafe fn cosine_4acc_main_loop(
    mut a_ptr: *const f32,
    mut b_ptr: *const f32,
    end_main: *const f32,
) -> (
    std::arch::x86_64::__m256,
    std::arch::x86_64::__m256,
    std::arch::x86_64::__m256,
    *const f32,
    *const f32,
) {
    use std::arch::x86_64::*;

    /// `f32` lanes in one 256-bit register.
    const LANES: usize = 8;
    /// Independent accumulator sets per iteration.
    const UNROLL: usize = 4;

    let zero = _mm256_setzero_ps();
    let mut dot = [zero; UNROLL];
    let mut norm_a = [zero; UNROLL];
    let mut norm_b = [zero; UNROLL];

    while a_ptr < end_main {
        // Fixed trip count; each slot only ever touches its own accumulators.
        for k in 0..UNROLL {
            let va = _mm256_loadu_ps(a_ptr.add(k * LANES));
            let vb = _mm256_loadu_ps(b_ptr.add(k * LANES));
            dot[k] = _mm256_fmadd_ps(va, vb, dot[k]);
            norm_a[k] = _mm256_fmadd_ps(va, va, norm_a[k]);
            norm_b[k] = _mm256_fmadd_ps(vb, vb, norm_b[k]);
        }
        a_ptr = a_ptr.add(UNROLL * LANES);
        b_ptr = b_ptr.add(UNROLL * LANES);
    }

    // Pairwise reduction: (0 + 1) + (2 + 3).
    let dot_acc = _mm256_add_ps(
        _mm256_add_ps(dot[0], dot[1]),
        _mm256_add_ps(dot[2], dot[3]),
    );
    let na_acc = _mm256_add_ps(
        _mm256_add_ps(norm_a[0], norm_a[1]),
        _mm256_add_ps(norm_a[2], norm_a[3]),
    );
    let nb_acc = _mm256_add_ps(
        _mm256_add_ps(norm_b[0], norm_b[1]),
        _mm256_add_ps(norm_b[2], norm_b[3]),
    );

    (dot_acc, na_acc, nb_acc, a_ptr, b_ptr)
}
/// Fused cosine kernel using four independent AVX2/FMA accumulator sets.
///
/// Whole 32-element blocks are processed by `cosine_4acc_main_loop`; the
/// resulting accumulators and advanced pointers are then passed to
/// `cosine_avx2_remainder`, which handles the remaining `len % 32` elements
/// and produces the final value.
///
/// The number of elements processed is `a.len()`; anything in `b` beyond that
/// is ignored.
///
/// # Safety
///
/// - The CPU must support AVX2 and FMA.
/// - `b` must contain at least `a.len()` elements; all reads from `b` are
///   unchecked raw-pointer loads.
///
/// # Panics
///
/// In builds with debug assertions enabled, panics if `b` is shorter than `a`.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2", enable = "fma")]
#[inline]
pub(crate) unsafe fn cosine_fused_avx2(a: &[f32], b: &[f32]) -> f32 {
    // Every load from `b` is driven by `a.len()`, so a shorter `b` would be
    // read out of bounds. Catch that in debug/test builds.
    debug_assert!(b.len() >= a.len(), "`b` must be at least as long as `a`");

    let len = a.len();
    let a_start = a.as_ptr();
    // End of the region covered by whole 32-element blocks.
    let end_main = a_start.add(len / 32 * 32);
    let end_ptr = a_start.add(len);

    let (dot_acc, na_acc, nb_acc, a_rest, b_rest) =
        cosine_4acc_main_loop(a_start, b.as_ptr(), end_main);
    cosine_avx2_remainder(a_rest, b_rest, end_ptr, dot_acc, na_acc, nb_acc)
}
/// Counts the positions where exactly one of `a[i]`, `b[i]` is greater than `0.5`.
///
/// The count is returned as an `f32`. Whole 32-element blocks are processed
/// with four independent AVX2 accumulators, remaining full groups of eight
/// with a single accumulator, and the last (fewer than eight) elements with a
/// scalar comparison.
///
/// The number of elements processed is `a.len()`; anything in `b` beyond that
/// is ignored.
///
/// # Safety
///
/// - The CPU must support AVX2.
/// - `b` must contain at least `a.len()` elements; all reads from `b` are
///   unchecked raw-pointer loads.
///
/// # Panics
///
/// In builds with debug assertions enabled, panics if `b` is shorter than `a`.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
pub(crate) unsafe fn hamming_avx2(a: &[f32], b: &[f32]) -> f32 {
    use std::arch::x86_64::*;

    // Every load from `b` is driven by `a.len()`, so a shorter `b` would be
    // read out of bounds. Catch that in debug/test builds.
    debug_assert!(b.len() >= a.len(), "`b` must be at least as long as `a`");

    let len = a.len();
    let mut a_ptr = a.as_ptr();
    let mut b_ptr = b.as_ptr();
    // End of the region covered by whole 32-element blocks.
    let end_main = a.as_ptr().add(len / 32 * 32);

    let threshold = _mm256_set1_ps(0.5);
    let one_vec = _mm256_set1_ps(1.0);
    let mut acc0 = _mm256_setzero_ps();
    let mut acc1 = _mm256_setzero_ps();
    let mut acc2 = _mm256_setzero_ps();
    let mut acc3 = _mm256_setzero_ps();

    while a_ptr < end_main {
        acc0 = hamming_avx2_fp_acc(a_ptr, b_ptr, threshold, one_vec, acc0);
        acc1 = hamming_avx2_fp_acc(a_ptr.add(8), b_ptr.add(8), threshold, one_vec, acc1);
        acc2 = hamming_avx2_fp_acc(a_ptr.add(16), b_ptr.add(16), threshold, one_vec, acc2);
        acc3 = hamming_avx2_fp_acc(a_ptr.add(24), b_ptr.add(24), threshold, one_vec, acc3);
        a_ptr = a_ptr.add(32);
        b_ptr = b_ptr.add(32);
    }

    let acc01 = _mm256_add_ps(acc0, acc1);
    let acc23 = _mm256_add_ps(acc2, acc3);
    let mut acc = _mm256_add_ps(acc01, acc23);

    // Track the leftover element count instead of testing
    // `a_ptr.add(8) <= end_ptr`: with fewer than eight elements left that
    // expression forms a pointer more than one past the end of the slice,
    // which `pointer::add` does not allow.
    let mut remaining = len % 32;

    while remaining >= 8 {
        acc = hamming_avx2_fp_acc(a_ptr, b_ptr, threshold, one_vec, acc);
        a_ptr = a_ptr.add(8);
        b_ptr = b_ptr.add(8);
        remaining -= 8;
    }

    let mut diff_count = hsum_avx256(acc);

    // Scalar tail: at most seven elements.
    while remaining > 0 {
        if (*a_ptr > 0.5) != (*b_ptr > 0.5) {
            diff_count += 1.0;
        }
        a_ptr = a_ptr.add(1);
        b_ptr = b_ptr.add(1);
        remaining -= 1;
    }

    diff_count
}
/// Adds one 8-lane step of the thresholded Hamming comparison to `acc`.
///
/// Loads eight `f32` from each pointer and compares both against `threshold`
/// with an ordered greater-than (`_CMP_GT_OQ`). For every lane where exactly
/// one of the two comparisons is true, the matching lane of `one_vec` is added
/// to `acc`; the other lanes add `0.0`.
///
/// # Safety
///
/// - The CPU must support AVX2.
/// - `a_ptr` and `b_ptr` must each be readable for eight `f32` values.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
#[inline]
unsafe fn hamming_avx2_fp_acc(
    a_ptr: *const f32,
    b_ptr: *const f32,
    threshold: std::arch::x86_64::__m256,
    one_vec: std::arch::x86_64::__m256,
    acc: std::arch::x86_64::__m256,
) -> std::arch::x86_64::__m256 {
    use std::arch::x86_64::*;

    let lanes_a = _mm256_loadu_ps(a_ptr);
    let lanes_b = _mm256_loadu_ps(b_ptr);

    let a_above = _mm256_cmp_ps(lanes_a, threshold, _CMP_GT_OQ);
    let b_above = _mm256_cmp_ps(lanes_b, threshold, _CMP_GT_OQ);

    // XOR keeps the lanes where the two comparison masks disagree; AND-ing
    // with `one_vec` turns those lanes into the value to add.
    let increment = _mm256_and_ps(_mm256_xor_ps(a_above, b_above), one_vec);

    _mm256_add_ps(acc, increment)
}
/// Hamming distance between two bit vectors packed into `u64` words.
///
/// Returns the number of differing bits over the first `a.len()` words.
/// Words are XOR-ed four at a time with AVX2 and each 64-bit lane is
/// pop-counted; the last `a.len() % 4` words are handled one by one.
/// Anything in `b` beyond `a.len()` words is ignored.
///
/// # Safety
///
/// The CPU must support AVX2.
///
/// # Panics
///
/// Panics if `b` is shorter than `a`.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
pub(crate) unsafe fn hamming_binary_avx2(a: &[u64], b: &[u64]) -> u32 {
    use std::arch::x86_64::*;

    // One bounds check up front. Previously the SIMD loop read `b` through a
    // raw pointer using `a.len()`, so a shorter `b` was an out-of-bounds read
    // whenever at least four words were processed; only the scalar tail was
    // bounds-checked.
    let b = &b[..a.len()];

    let a_blocks = a.chunks_exact(4);
    let b_blocks = b.chunks_exact(4);
    let a_tail = a_blocks.remainder();
    let b_tail = b_blocks.remainder();

    let mut total: u32 = 0;

    for (wa, wb) in a_blocks.zip(b_blocks) {
        // Each chunk is exactly four `u64`, i.e. 32 bytes: one unaligned load.
        let va = _mm256_loadu_si256(wa.as_ptr().cast());
        let vb = _mm256_loadu_si256(wb.as_ptr().cast());
        let xor = _mm256_xor_si256(va, vb);
        let x0 = _mm256_extract_epi64(xor, 0) as u64;
        let x1 = _mm256_extract_epi64(xor, 1) as u64;
        let x2 = _mm256_extract_epi64(xor, 2) as u64;
        let x3 = _mm256_extract_epi64(xor, 3) as u64;
        total += x0.count_ones() + x1.count_ones() + x2.count_ones() + x3.count_ones();
    }

    for (&x, &y) in a_tail.iter().zip(b_tail) {
        total += (x ^ y).count_ones();
    }

    total
}
/// Finishes the min/max Jaccard kernel over whatever the unrolled main loop left behind.
///
/// Remaining full groups of eight `f32` add their element-wise minimum to
/// `acc_inter` and their element-wise maximum to `acc_union`. Both
/// accumulators are then reduced with `hsum_avx256` and the last (fewer than
/// eight) elements are folded in with scalar `min`/`max`.
///
/// Returns `inter_sum / union_sum`, or `1.0` when `union_sum` is exactly `0.0`.
///
/// # Safety
///
/// - The CPU must support AVX2.
/// - `a_ptr` and `end_ptr` must be derived from the same allocation, and every
///   `f32` in `a_ptr..end_ptr` must be readable.
/// - `b_ptr` must be readable for the same number of `f32` elements.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
#[inline]
unsafe fn jaccard_avx2_remainder(
    mut a_ptr: *const f32,
    mut b_ptr: *const f32,
    end_ptr: *const f32,
    mut acc_inter: std::arch::x86_64::__m256,
    mut acc_union: std::arch::x86_64::__m256,
) -> f32 {
    use std::arch::x86_64::*;

    // Count the elements that are left instead of testing
    // `a_ptr.add(8) <= end_ptr`: with fewer than eight elements remaining that
    // expression forms a pointer more than one past the end of the allocation,
    // which `pointer::add` does not allow.
    let mut remaining = end_ptr.offset_from(a_ptr);

    // Full 8-lane groups.
    while remaining >= 8 {
        let va = _mm256_loadu_ps(a_ptr);
        let vb = _mm256_loadu_ps(b_ptr);
        acc_inter = _mm256_add_ps(acc_inter, _mm256_min_ps(va, vb));
        acc_union = _mm256_add_ps(acc_union, _mm256_max_ps(va, vb));
        a_ptr = a_ptr.add(8);
        b_ptr = b_ptr.add(8);
        remaining -= 8;
    }

    let mut inter_sum = hsum_avx256(acc_inter);
    let mut union_sum = hsum_avx256(acc_union);

    // Scalar tail: at most seven elements.
    while remaining > 0 {
        let x = *a_ptr;
        let y = *b_ptr;
        inter_sum += x.min(y);
        union_sum += x.max(y);
        a_ptr = a_ptr.add(1);
        b_ptr = b_ptr.add(1);
        remaining -= 1;
    }

    // A zero union sum would make the ratio 0/0; report 1.0 for that case.
    if union_sum == 0.0 {
        1.0
    } else {
        inter_sum / union_sum
    }
}
/// Min/max Jaccard kernel: `sum(min(a[i], b[i])) / sum(max(a[i], b[i]))`.
///
/// Whole 32-element blocks are processed with four independent AVX2
/// accumulator pairs (one for the minimum sum, one for the maximum sum). The
/// pairs are combined as `(acc0 + acc1) + (acc2 + acc3)` and passed, with the
/// advanced pointers, to `jaccard_avx2_remainder`, which handles the remaining
/// `len % 32` elements and computes the final ratio.
///
/// The number of elements processed is `a.len()`; anything in `b` beyond that
/// is ignored.
///
/// # Safety
///
/// - The CPU must support AVX2.
/// - `b` must contain at least `a.len()` elements; all reads from `b` are
///   unchecked raw-pointer loads.
///
/// # Panics
///
/// In builds with debug assertions enabled, panics if `b` is shorter than `a`.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
pub(crate) unsafe fn jaccard_avx2(a: &[f32], b: &[f32]) -> f32 {
    use std::arch::x86_64::*;

    /// `f32` lanes in one 256-bit register.
    const LANES: usize = 8;
    /// Independent accumulator pairs per iteration.
    const UNROLL: usize = 4;

    // Every load from `b` is driven by `a.len()`, so a shorter `b` would be
    // read out of bounds. Catch that in debug/test builds.
    debug_assert!(b.len() >= a.len(), "`b` must be at least as long as `a`");

    let len = a.len();
    let mut a_ptr = a.as_ptr();
    let mut b_ptr = b.as_ptr();
    // End of the region covered by whole 32-element blocks.
    let end_main = a.as_ptr().add(len / 32 * 32);
    let end_ptr = a.as_ptr().add(len);

    let zero = _mm256_setzero_ps();
    let mut inter_acc = [zero; UNROLL];
    let mut union_acc = [zero; UNROLL];

    while a_ptr < end_main {
        // Fixed trip count; each slot only ever touches its own accumulators.
        for k in 0..UNROLL {
            let va = _mm256_loadu_ps(a_ptr.add(k * LANES));
            let vb = _mm256_loadu_ps(b_ptr.add(k * LANES));
            inter_acc[k] = _mm256_add_ps(inter_acc[k], _mm256_min_ps(va, vb));
            union_acc[k] = _mm256_add_ps(union_acc[k], _mm256_max_ps(va, vb));
        }
        a_ptr = a_ptr.add(UNROLL * LANES);
        b_ptr = b_ptr.add(UNROLL * LANES);
    }

    // Pairwise reduction: (0 + 1) + (2 + 3).
    let acc_inter = _mm256_add_ps(
        _mm256_add_ps(inter_acc[0], inter_acc[1]),
        _mm256_add_ps(inter_acc[2], inter_acc[3]),
    );
    let acc_union = _mm256_add_ps(
        _mm256_add_ps(union_acc[0], union_acc[1]),
        _mm256_add_ps(union_acc[2], union_acc[3]),
    );

    jaccard_avx2_remainder(a_ptr, b_ptr, end_ptr, acc_inter, acc_union)
}