#![allow(clippy::cast_precision_loss)]
#![allow(clippy::cast_possible_truncation)]
#![allow(clippy::cast_sign_loss)]
#![allow(clippy::incompatible_msrv)]
#![allow(clippy::wildcard_imports)]
#![allow(clippy::missing_panics_doc)]
#![allow(clippy::similar_names)]
use crate::simd_4acc_dot_loop;
use crate::simd_4acc_l2_loop;
use super::scalar;
use super::scalar::cosine_finish_fast;
/// Dot product of `a` and `b` using a single 512-bit accumulator.
///
/// Every full group of 16 lanes is loaded unaligned and folded in with one fused
/// multiply-add. The last `a.len() % 16` elements are read with a zero-masking load,
/// so the unused lanes contribute `0.0` and no scalar tail loop is needed.
///
/// # Safety
///
/// * The executing CPU must support `avx512f`.
/// * `b` must contain at least `a.len()` elements: only `a.len()` is consulted and
///   `b` is read through a raw pointer without bounds checks.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
#[inline]
pub(crate) unsafe fn dot_product_avx512(a: &[f32], b: &[f32]) -> f32 {
    use std::arch::x86_64::*;

    let len = a.len();
    let tail = len % 16;
    let body = len - tail;
    let a_ptr = a.as_ptr();
    let b_ptr = b.as_ptr();

    let mut sum = _mm512_setzero_ps();
    let mut offset = 0;
    while offset < body {
        let va = _mm512_loadu_ps(a_ptr.add(offset));
        let vb = _mm512_loadu_ps(b_ptr.add(offset));
        sum = _mm512_fmadd_ps(va, vb, sum);
        offset += 16;
    }

    if tail > 0 {
        // `tail` is `len % 16`, i.e. 1..=15 here, so the low `tail` bits always fit in
        // the 16-bit mask; a "full vector" special case can never be reached.
        let mask: __mmask16 = ((1u32 << tail) - 1) as u16;
        let va = _mm512_maskz_loadu_ps(mask, a_ptr.add(body));
        let vb = _mm512_maskz_loadu_ps(mask, b_ptr.add(body));
        sum = _mm512_fmadd_ps(va, vb, sum);
    }

    _mm512_reduce_add_ps(sum)
}
/// Dot product of `a` and `b` with a four-accumulator AVX-512 main loop.
///
/// The largest multiple of 64 elements is delegated to the `simd_4acc_dot_loop!` macro,
/// which hands back a combined accumulator together with the advanced `a`/`b` cursors.
/// NOTE(review): the macro is defined elsewhere; it is presumably a 4-way unrolled
/// load/FMA loop that stops at `end_main` — confirm against its definition.
/// What is left is consumed one 16-lane vector at a time, and the final `< 16`
/// elements with a zero-masking load.
///
/// # Safety
///
/// * The executing CPU must support `avx512f`.
/// * `b` must contain at least `a.len()` elements: only `a.len()` is consulted and
///   `b` is read through a raw pointer without bounds checks.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
#[inline]
pub(crate) unsafe fn dot_product_avx512_4acc(a: &[f32], b: &[f32]) -> f32 {
    use std::arch::x86_64::*;

    let len = a.len();
    let a_ptr = a.as_ptr();
    let b_ptr = b.as_ptr();
    let end_main = a_ptr.add(len / 64 * 64);
    let end_ptr = a_ptr.add(len);

    let (mut acc, mut a_p, mut b_p) = simd_4acc_dot_loop!(
        a_ptr,
        b_ptr,
        end_main,
        _mm512_setzero_ps(),
        _mm512_loadu_ps,
        _mm512_fmadd_ps,
        _mm512_add_ps,
        16
    );

    // Count the unread elements instead of probing `a_p.add(16)`: forming a pointer more
    // than one element past the end of the slice is outside the contract of `add`.
    let mut remaining = end_ptr.offset_from(a_p) as usize;
    while remaining >= 16 {
        let va = _mm512_loadu_ps(a_p);
        let vb = _mm512_loadu_ps(b_p);
        acc = _mm512_fmadd_ps(va, vb, acc);
        a_p = a_p.add(16);
        b_p = b_p.add(16);
        remaining -= 16;
    }

    if remaining > 0 {
        // 1..=15 elements are left, so the mask never needs a full-vector special case.
        let mask: __mmask16 = ((1u32 << remaining) - 1) as u16;
        let va = _mm512_maskz_loadu_ps(mask, a_p);
        let vb = _mm512_maskz_loadu_ps(mask, b_p);
        acc = _mm512_fmadd_ps(va, vb, acc);
    }

    _mm512_reduce_add_ps(acc)
}
/// Dot product of `a` and `b` with an eight-accumulator AVX-512 main loop.
///
/// Whole 128-element blocks are handled by `dot_8acc_main_loop`; what is left is
/// consumed one 16-lane vector at a time, and the final `< 16` elements with a
/// zero-masking load.
///
/// # Safety
///
/// * The executing CPU must support `avx512f`.
/// * `b` must contain at least `a.len()` elements: only `a.len()` is consulted and
///   `b` is read through a raw pointer without bounds checks.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
#[inline]
pub(crate) unsafe fn dot_product_avx512_8acc(a: &[f32], b: &[f32]) -> f32 {
    use std::arch::x86_64::*;

    let len = a.len();
    let a_ptr = a.as_ptr();
    let b_ptr = b.as_ptr();
    let end_main = a_ptr.add(len / 128 * 128);
    let end_ptr = a_ptr.add(len);

    let (mut acc, mut a_p, mut b_p) = dot_8acc_main_loop(a_ptr, b_ptr, end_main);

    // Count the unread elements instead of probing `a_p.add(16)`: forming a pointer more
    // than one element past the end of the slice is outside the contract of `add`.
    let mut remaining = end_ptr.offset_from(a_p) as usize;
    while remaining >= 16 {
        acc = _mm512_fmadd_ps(_mm512_loadu_ps(a_p), _mm512_loadu_ps(b_p), acc);
        a_p = a_p.add(16);
        b_p = b_p.add(16);
        remaining -= 16;
    }

    if remaining > 0 {
        // 1..=15 elements are left; lanes outside the mask load as zero.
        let mask: __mmask16 = ((1u32 << remaining) - 1) as u16;
        let va = _mm512_maskz_loadu_ps(mask, a_p);
        let vb = _mm512_maskz_loadu_ps(mask, b_p);
        acc = _mm512_fmadd_ps(va, vb, acc);
    }

    _mm512_reduce_add_ps(acc)
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
#[inline]
unsafe fn dot_8acc_main_loop(
a_ptr: *const f32,
b_ptr: *const f32,
end_main: *const f32,
) -> (std::arch::x86_64::__m512, *const f32, *const f32) {
use std::arch::x86_64::*;
let mut s0 = _mm512_setzero_ps();
let mut s1 = _mm512_setzero_ps();
let mut s2 = _mm512_setzero_ps();
let mut s3 = _mm512_setzero_ps();
let mut s4 = _mm512_setzero_ps();
let mut s5 = _mm512_setzero_ps();
let mut s6 = _mm512_setzero_ps();
let mut s7 = _mm512_setzero_ps();
let mut pa = a_ptr;
let mut pb = b_ptr;
while pa < end_main {
s0 = _mm512_fmadd_ps(_mm512_loadu_ps(pa), _mm512_loadu_ps(pb), s0);
s1 = _mm512_fmadd_ps(_mm512_loadu_ps(pa.add(16)), _mm512_loadu_ps(pb.add(16)), s1);
s2 = _mm512_fmadd_ps(_mm512_loadu_ps(pa.add(32)), _mm512_loadu_ps(pb.add(32)), s2);
s3 = _mm512_fmadd_ps(_mm512_loadu_ps(pa.add(48)), _mm512_loadu_ps(pb.add(48)), s3);
s4 = _mm512_fmadd_ps(_mm512_loadu_ps(pa.add(64)), _mm512_loadu_ps(pb.add(64)), s4);
s5 = _mm512_fmadd_ps(_mm512_loadu_ps(pa.add(80)), _mm512_loadu_ps(pb.add(80)), s5);
s6 = _mm512_fmadd_ps(_mm512_loadu_ps(pa.add(96)), _mm512_loadu_ps(pb.add(96)), s6);
s7 = _mm512_fmadd_ps(
_mm512_loadu_ps(pa.add(112)),
_mm512_loadu_ps(pb.add(112)),
s7,
);
pa = pa.add(128);
pb = pb.add(128);
}
s0 = _mm512_add_ps(s0, s4);
s1 = _mm512_add_ps(s1, s5);
s2 = _mm512_add_ps(s2, s6);
s3 = _mm512_add_ps(s3, s7);
let sum01 = _mm512_add_ps(s0, s1);
let sum23 = _mm512_add_ps(s2, s3);
(_mm512_add_ps(sum01, sum23), pa, pb)
}
/// Squared Euclidean distance between `a` and `b` with a four-accumulator AVX-512 main loop.
///
/// The largest multiple of 64 elements is delegated to the `simd_4acc_l2_loop!` macro,
/// which hands back a combined accumulator together with the advanced `a`/`b` cursors.
/// NOTE(review): the macro is defined elsewhere; it is presumably a 4-way unrolled
/// load/subtract/FMA loop that stops at `end_main` — confirm against its definition.
/// What is left is consumed one 16-lane vector at a time, and the final `< 16`
/// elements with a zero-masking load.
///
/// # Safety
///
/// * The executing CPU must support `avx512f`.
/// * `b` must contain at least `a.len()` elements: only `a.len()` is consulted and
///   `b` is read through a raw pointer without bounds checks.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
#[inline]
pub(crate) unsafe fn squared_l2_avx512_4acc(a: &[f32], b: &[f32]) -> f32 {
    use std::arch::x86_64::*;

    let len = a.len();
    let a_ptr = a.as_ptr();
    let b_ptr = b.as_ptr();
    let end_main = a_ptr.add(len / 64 * 64);
    let end_ptr = a_ptr.add(len);

    let (mut acc, mut a_p, mut b_p) = simd_4acc_l2_loop!(
        a_ptr,
        b_ptr,
        end_main,
        _mm512_setzero_ps(),
        _mm512_loadu_ps,
        _mm512_sub_ps,
        _mm512_fmadd_ps,
        _mm512_add_ps,
        16
    );

    // Count the unread elements instead of probing `a_p.add(16)`: forming a pointer more
    // than one element past the end of the slice is outside the contract of `add`.
    let mut remaining = end_ptr.offset_from(a_p) as usize;
    while remaining >= 16 {
        let diff = _mm512_sub_ps(_mm512_loadu_ps(a_p), _mm512_loadu_ps(b_p));
        acc = _mm512_fmadd_ps(diff, diff, acc);
        a_p = a_p.add(16);
        b_p = b_p.add(16);
        remaining -= 16;
    }

    if remaining > 0 {
        // 1..=15 elements are left, so the mask never needs a full-vector special case.
        // Masked-out lanes load as zero in both inputs, giving a zero difference.
        let mask: __mmask16 = ((1u32 << remaining) - 1) as u16;
        let va = _mm512_maskz_loadu_ps(mask, a_p);
        let vb = _mm512_maskz_loadu_ps(mask, b_p);
        let diff = _mm512_sub_ps(va, vb);
        acc = _mm512_fmadd_ps(diff, diff, acc);
    }

    _mm512_reduce_add_ps(acc)
}
/// Squared Euclidean distance between `a` and `b` using a single 512-bit accumulator.
///
/// Every full group of 16 lanes is loaded unaligned, subtracted, and its squared
/// difference folded in with one fused multiply-add. The last `a.len() % 16` elements
/// are read with a zero-masking load, so the unused lanes yield a zero difference.
///
/// # Safety
///
/// * The executing CPU must support `avx512f`.
/// * `b` must contain at least `a.len()` elements: only `a.len()` is consulted and
///   `b` is read through a raw pointer without bounds checks.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
#[inline]
pub(crate) unsafe fn squared_l2_avx512(a: &[f32], b: &[f32]) -> f32 {
    use std::arch::x86_64::*;

    let len = a.len();
    let tail = len % 16;
    let body = len - tail;
    let a_ptr = a.as_ptr();
    let b_ptr = b.as_ptr();

    let mut sum = _mm512_setzero_ps();
    let mut offset = 0;
    while offset < body {
        let va = _mm512_loadu_ps(a_ptr.add(offset));
        let vb = _mm512_loadu_ps(b_ptr.add(offset));
        let diff = _mm512_sub_ps(va, vb);
        sum = _mm512_fmadd_ps(diff, diff, sum);
        offset += 16;
    }

    if tail > 0 {
        // `tail` is `len % 16`, i.e. 1..=15 here, so the low `tail` bits always fit in
        // the 16-bit mask; a "full vector" special case can never be reached.
        let mask: __mmask16 = ((1u32 << tail) - 1) as u16;
        let va = _mm512_maskz_loadu_ps(mask, a_ptr.add(body));
        let vb = _mm512_maskz_loadu_ps(mask, b_ptr.add(body));
        let diff = _mm512_sub_ps(va, vb);
        sum = _mm512_fmadd_ps(diff, diff, sum);
    }

    _mm512_reduce_add_ps(sum)
}
/// Squared Euclidean distance between `a` and `b` with an eight-accumulator AVX-512 main loop.
///
/// Whole 128-element blocks are handled by `l2_8acc_main_loop`; what is left is consumed
/// one 16-lane vector at a time, and the final `< 16` elements with a zero-masking load.
///
/// # Safety
///
/// * The executing CPU must support `avx512f`.
/// * `b` must contain at least `a.len()` elements: only `a.len()` is consulted and
///   `b` is read through a raw pointer without bounds checks.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
#[inline]
pub(crate) unsafe fn squared_l2_avx512_8acc(a: &[f32], b: &[f32]) -> f32 {
    use std::arch::x86_64::*;

    let len = a.len();
    let a_ptr = a.as_ptr();
    let b_ptr = b.as_ptr();
    let end_main = a_ptr.add(len / 128 * 128);
    let end_ptr = a_ptr.add(len);

    let (mut acc, mut a_p, mut b_p) = l2_8acc_main_loop(a_ptr, b_ptr, end_main);

    // Count the unread elements instead of probing `a_p.add(16)`: forming a pointer more
    // than one element past the end of the slice is outside the contract of `add`.
    let mut remaining = end_ptr.offset_from(a_p) as usize;
    while remaining >= 16 {
        let diff = _mm512_sub_ps(_mm512_loadu_ps(a_p), _mm512_loadu_ps(b_p));
        acc = _mm512_fmadd_ps(diff, diff, acc);
        a_p = a_p.add(16);
        b_p = b_p.add(16);
        remaining -= 16;
    }

    if remaining > 0 {
        // 1..=15 elements are left; masked-out lanes load as zero in both inputs.
        let mask: __mmask16 = ((1u32 << remaining) - 1) as u16;
        let va = _mm512_maskz_loadu_ps(mask, a_p);
        let vb = _mm512_maskz_loadu_ps(mask, b_p);
        let diff = _mm512_sub_ps(va, vb);
        acc = _mm512_fmadd_ps(diff, diff, acc);
    }

    _mm512_reduce_add_ps(acc)
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
#[inline]
unsafe fn l2_8acc_main_loop(
a_ptr: *const f32,
b_ptr: *const f32,
end_main: *const f32,
) -> (std::arch::x86_64::__m512, *const f32, *const f32) {
use std::arch::x86_64::*;
let mut s0 = _mm512_setzero_ps();
let mut s1 = _mm512_setzero_ps();
let mut s2 = _mm512_setzero_ps();
let mut s3 = _mm512_setzero_ps();
let mut s4 = _mm512_setzero_ps();
let mut s5 = _mm512_setzero_ps();
let mut s6 = _mm512_setzero_ps();
let mut s7 = _mm512_setzero_ps();
let mut pa = a_ptr;
let mut pb = b_ptr;
while pa < end_main {
let d0 = _mm512_sub_ps(_mm512_loadu_ps(pa), _mm512_loadu_ps(pb));
s0 = _mm512_fmadd_ps(d0, d0, s0);
let d1 = _mm512_sub_ps(_mm512_loadu_ps(pa.add(16)), _mm512_loadu_ps(pb.add(16)));
s1 = _mm512_fmadd_ps(d1, d1, s1);
let d2 = _mm512_sub_ps(_mm512_loadu_ps(pa.add(32)), _mm512_loadu_ps(pb.add(32)));
s2 = _mm512_fmadd_ps(d2, d2, s2);
let d3 = _mm512_sub_ps(_mm512_loadu_ps(pa.add(48)), _mm512_loadu_ps(pb.add(48)));
s3 = _mm512_fmadd_ps(d3, d3, s3);
let d4 = _mm512_sub_ps(_mm512_loadu_ps(pa.add(64)), _mm512_loadu_ps(pb.add(64)));
s4 = _mm512_fmadd_ps(d4, d4, s4);
let d5 = _mm512_sub_ps(_mm512_loadu_ps(pa.add(80)), _mm512_loadu_ps(pb.add(80)));
s5 = _mm512_fmadd_ps(d5, d5, s5);
let d6 = _mm512_sub_ps(_mm512_loadu_ps(pa.add(96)), _mm512_loadu_ps(pb.add(96)));
s6 = _mm512_fmadd_ps(d6, d6, s6);
let d7 = _mm512_sub_ps(_mm512_loadu_ps(pa.add(112)), _mm512_loadu_ps(pb.add(112)));
s7 = _mm512_fmadd_ps(d7, d7, s7);
pa = pa.add(128);
pb = pb.add(128);
}
s0 = _mm512_add_ps(s0, s4);
s1 = _mm512_add_ps(s1, s5);
s2 = _mm512_add_ps(s2, s6);
s3 = _mm512_add_ps(s3, s7);
let sum01 = _mm512_add_ps(s0, s1);
let sum23 = _mm512_add_ps(s2, s3);
(_mm512_add_ps(sum01, sum23), pa, pb)
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
#[inline]
pub(crate) unsafe fn cosine_fused_avx512(a: &[f32], b: &[f32]) -> f32 {
use std::arch::x86_64::*;
let len = a.len();
let simd_chunks = len / 32; let remainder = len % 32;
let mut dot0 = _mm512_setzero_ps();
let mut dot1 = _mm512_setzero_ps();
let mut na0 = _mm512_setzero_ps();
let mut na1 = _mm512_setzero_ps();
let mut nb0 = _mm512_setzero_ps();
let mut nb1 = _mm512_setzero_ps();
let a_ptr = a.as_ptr();
let b_ptr = b.as_ptr();
for i in 0..simd_chunks {
let base = i * 32;
let va0 = _mm512_loadu_ps(a_ptr.add(base));
let vb0 = _mm512_loadu_ps(b_ptr.add(base));
dot0 = _mm512_fmadd_ps(va0, vb0, dot0);
na0 = _mm512_fmadd_ps(va0, va0, na0);
nb0 = _mm512_fmadd_ps(vb0, vb0, nb0);
let va1 = _mm512_loadu_ps(a_ptr.add(base + 16));
let vb1 = _mm512_loadu_ps(b_ptr.add(base + 16));
dot1 = _mm512_fmadd_ps(va1, vb1, dot1);
na1 = _mm512_fmadd_ps(va1, va1, na1);
nb1 = _mm512_fmadd_ps(vb1, vb1, nb1);
}
if remainder > 0 {
let base = simd_chunks * 32;
let rem0 = remainder.min(16);
if rem0 > 0 {
let mask0: __mmask16 = if rem0 == 16 {
!0
} else {
((1u32 << rem0) - 1) as u16
};
let va = _mm512_maskz_loadu_ps(mask0, a_ptr.add(base));
let vb = _mm512_maskz_loadu_ps(mask0, b_ptr.add(base));
dot0 = _mm512_fmadd_ps(va, vb, dot0);
na0 = _mm512_fmadd_ps(va, va, na0);
nb0 = _mm512_fmadd_ps(vb, vb, nb0);
}
let rem1 = remainder.saturating_sub(16);
if rem1 > 0 {
let mask1: __mmask16 = if rem1 == 16 {
!0
} else {
((1u32 << rem1) - 1) as u16
};
let va = _mm512_maskz_loadu_ps(mask1, a_ptr.add(base + 16));
let vb = _mm512_maskz_loadu_ps(mask1, b_ptr.add(base + 16));
dot1 = _mm512_fmadd_ps(va, vb, dot1);
na1 = _mm512_fmadd_ps(va, va, na1);
nb1 = _mm512_fmadd_ps(vb, vb, nb1);
}
}
let dot = _mm512_reduce_add_ps(_mm512_add_ps(dot0, dot1));
let norm_a_sq = _mm512_reduce_add_ps(_mm512_add_ps(na0, na1));
let norm_b_sq = _mm512_reduce_add_ps(_mm512_add_ps(nb0, nb1));
cosine_finish_fast(dot, norm_a_sq, norm_b_sq)
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
#[inline]
pub(crate) unsafe fn cosine_fused_avx512_4acc(a: &[f32], b: &[f32]) -> f32 {
use std::arch::x86_64::*;
let len = a.len();
let a_ptr = a.as_ptr();
let b_ptr = b.as_ptr();
let end_main = a_ptr.add(len / 64 * 64);
let end_ptr = a_ptr.add(len);
let mut dot0 = _mm512_setzero_ps();
let mut dot1 = _mm512_setzero_ps();
let mut dot2 = _mm512_setzero_ps();
let mut dot3 = _mm512_setzero_ps();
let mut na0 = _mm512_setzero_ps();
let mut na1 = _mm512_setzero_ps();
let mut na2 = _mm512_setzero_ps();
let mut na3 = _mm512_setzero_ps();
let mut nb0 = _mm512_setzero_ps();
let mut nb1 = _mm512_setzero_ps();
let mut nb2 = _mm512_setzero_ps();
let mut nb3 = _mm512_setzero_ps();
let mut cur_a = a_ptr;
let mut cur_b = b_ptr;
while cur_a < end_main {
let va0 = _mm512_loadu_ps(cur_a);
let vb0 = _mm512_loadu_ps(cur_b);
dot0 = _mm512_fmadd_ps(va0, vb0, dot0);
na0 = _mm512_fmadd_ps(va0, va0, na0);
nb0 = _mm512_fmadd_ps(vb0, vb0, nb0);
let va1 = _mm512_loadu_ps(cur_a.add(16));
let vb1 = _mm512_loadu_ps(cur_b.add(16));
dot1 = _mm512_fmadd_ps(va1, vb1, dot1);
na1 = _mm512_fmadd_ps(va1, va1, na1);
nb1 = _mm512_fmadd_ps(vb1, vb1, nb1);
let va2 = _mm512_loadu_ps(cur_a.add(32));
let vb2 = _mm512_loadu_ps(cur_b.add(32));
dot2 = _mm512_fmadd_ps(va2, vb2, dot2);
na2 = _mm512_fmadd_ps(va2, va2, na2);
nb2 = _mm512_fmadd_ps(vb2, vb2, nb2);
let va3 = _mm512_loadu_ps(cur_a.add(48));
let vb3 = _mm512_loadu_ps(cur_b.add(48));
dot3 = _mm512_fmadd_ps(va3, vb3, dot3);
na3 = _mm512_fmadd_ps(va3, va3, na3);
nb3 = _mm512_fmadd_ps(vb3, vb3, nb3);
cur_a = cur_a.add(64);
cur_b = cur_b.add(64);
}
let dot_acc = _mm512_add_ps(_mm512_add_ps(dot0, dot1), _mm512_add_ps(dot2, dot3));
let mut na_acc = _mm512_add_ps(_mm512_add_ps(na0, na1), _mm512_add_ps(na2, na3));
let mut nb_acc = _mm512_add_ps(_mm512_add_ps(nb0, nb1), _mm512_add_ps(nb2, nb3));
let mut rem_dot = dot_acc;
while cur_a.add(16) <= end_ptr {
let va = _mm512_loadu_ps(cur_a);
let vb = _mm512_loadu_ps(cur_b);
rem_dot = _mm512_fmadd_ps(va, vb, rem_dot);
na_acc = _mm512_fmadd_ps(va, va, na_acc);
nb_acc = _mm512_fmadd_ps(vb, vb, nb_acc);
cur_a = cur_a.add(16);
cur_b = cur_b.add(16);
}
let remaining = end_ptr.offset_from(cur_a) as usize;
if remaining > 0 {
let mask: __mmask16 = ((1u32 << remaining) - 1) as u16;
let va = _mm512_maskz_loadu_ps(mask, cur_a);
let vb = _mm512_maskz_loadu_ps(mask, cur_b);
rem_dot = _mm512_fmadd_ps(va, vb, rem_dot);
na_acc = _mm512_fmadd_ps(va, va, na_acc);
nb_acc = _mm512_fmadd_ps(vb, vb, nb_acc);
}
let dot = _mm512_reduce_add_ps(rem_dot);
let norm_a_sq = _mm512_reduce_add_ps(na_acc);
let norm_b_sq = _mm512_reduce_add_ps(nb_acc);
cosine_finish_fast(dot, norm_a_sq, norm_b_sq)
}
/// Cosine computation gathering `a·b`, `‖a‖²` and `‖b‖²` in one pass, eight accumulators each.
///
/// Whole 128-element blocks go through `cosine_8acc_main_loop`; the leftover elements
/// are folded into the same three running vectors by `cosine_8acc_remainder`. The three
/// reduced sums are handed to `cosine_finish_fast`, which produces the returned value.
///
/// # Safety
///
/// * The executing CPU must support `avx512f`.
/// * `b` must contain at least `a.len()` elements: only `a.len()` is consulted and
///   `b` is read through a raw pointer without bounds checks.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
#[inline]
pub(crate) unsafe fn cosine_fused_avx512_8acc(a: &[f32], b: &[f32]) -> f32 {
    use std::arch::x86_64::*;

    let count = a.len();
    let a_start = a.as_ptr();
    let b_start = b.as_ptr();
    let blocks_end = a_start.add(count / 128 * 128);
    let a_end = a_start.add(count);

    let (mut dots, mut a_squares, mut b_squares, mut a_cursor, mut b_cursor) =
        cosine_8acc_main_loop(a_start, b_start, blocks_end);

    cosine_8acc_remainder(
        &mut a_cursor,
        &mut b_cursor,
        a_end,
        &mut dots,
        &mut a_squares,
        &mut b_squares,
    );

    cosine_finish_fast(
        _mm512_reduce_add_ps(dots),
        _mm512_reduce_add_ps(a_squares),
        _mm512_reduce_add_ps(b_squares),
    )
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
#[inline]
#[allow(clippy::too_many_lines)] unsafe fn cosine_8acc_main_loop(
a_ptr: *const f32,
b_ptr: *const f32,
end_main: *const f32,
) -> (
std::arch::x86_64::__m512,
std::arch::x86_64::__m512,
std::arch::x86_64::__m512,
*const f32,
*const f32,
) {
use std::arch::x86_64::*;
let z = _mm512_setzero_ps();
let (mut d0, mut d1, mut d2, mut d3) = (z, z, z, z);
let (mut d4, mut d5, mut d6, mut d7) = (z, z, z, z);
let (mut a0, mut a1, mut a2, mut a3) = (z, z, z, z);
let (mut a4, mut a5, mut a6, mut a7) = (z, z, z, z);
let (mut b0, mut b1, mut b2, mut b3) = (z, z, z, z);
let (mut b4, mut b5, mut b6, mut b7) = (z, z, z, z);
let mut pa = a_ptr;
let mut pb = b_ptr;
while pa < end_main {
cosine_8acc_body_lo(
pa, pb, &mut d0, &mut d1, &mut d2, &mut d3, &mut a0, &mut a1, &mut a2, &mut a3,
&mut b0, &mut b1, &mut b2, &mut b3,
);
cosine_8acc_body_hi(
pa, pb, &mut d4, &mut d5, &mut d6, &mut d7, &mut a4, &mut a5, &mut a6, &mut a7,
&mut b4, &mut b5, &mut b6, &mut b7,
);
pa = pa.add(128);
pb = pb.add(128);
}
d0 = _mm512_add_ps(d0, d4);
d1 = _mm512_add_ps(d1, d5);
d2 = _mm512_add_ps(d2, d6);
d3 = _mm512_add_ps(d3, d7);
let dot_acc = _mm512_add_ps(_mm512_add_ps(d0, d1), _mm512_add_ps(d2, d3));
a0 = _mm512_add_ps(a0, a4);
a1 = _mm512_add_ps(a1, a5);
a2 = _mm512_add_ps(a2, a6);
a3 = _mm512_add_ps(a3, a7);
let na_acc = _mm512_add_ps(_mm512_add_ps(a0, a1), _mm512_add_ps(a2, a3));
b0 = _mm512_add_ps(b0, b4);
b1 = _mm512_add_ps(b1, b5);
b2 = _mm512_add_ps(b2, b6);
b3 = _mm512_add_ps(b3, b7);
let nb_acc = _mm512_add_ps(_mm512_add_ps(b0, b1), _mm512_add_ps(b2, b3));
(dot_acc, na_acc, nb_acc, pa, pb)
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
#[inline]
#[allow(clippy::too_many_arguments)] unsafe fn cosine_8acc_body_lo(
pa: *const f32,
pb: *const f32,
d0: &mut std::arch::x86_64::__m512,
d1: &mut std::arch::x86_64::__m512,
d2: &mut std::arch::x86_64::__m512,
d3: &mut std::arch::x86_64::__m512,
a0: &mut std::arch::x86_64::__m512,
a1: &mut std::arch::x86_64::__m512,
a2: &mut std::arch::x86_64::__m512,
a3: &mut std::arch::x86_64::__m512,
b0: &mut std::arch::x86_64::__m512,
b1: &mut std::arch::x86_64::__m512,
b2: &mut std::arch::x86_64::__m512,
b3: &mut std::arch::x86_64::__m512,
) {
use std::arch::x86_64::*;
let va0 = _mm512_loadu_ps(pa);
let vb0 = _mm512_loadu_ps(pb);
*d0 = _mm512_fmadd_ps(va0, vb0, *d0);
*a0 = _mm512_fmadd_ps(va0, va0, *a0);
*b0 = _mm512_fmadd_ps(vb0, vb0, *b0);
let va1 = _mm512_loadu_ps(pa.add(16));
let vb1 = _mm512_loadu_ps(pb.add(16));
*d1 = _mm512_fmadd_ps(va1, vb1, *d1);
*a1 = _mm512_fmadd_ps(va1, va1, *a1);
*b1 = _mm512_fmadd_ps(vb1, vb1, *b1);
let va2 = _mm512_loadu_ps(pa.add(32));
let vb2 = _mm512_loadu_ps(pb.add(32));
*d2 = _mm512_fmadd_ps(va2, vb2, *d2);
*a2 = _mm512_fmadd_ps(va2, va2, *a2);
*b2 = _mm512_fmadd_ps(vb2, vb2, *b2);
let va3 = _mm512_loadu_ps(pa.add(48));
let vb3 = _mm512_loadu_ps(pb.add(48));
*d3 = _mm512_fmadd_ps(va3, vb3, *d3);
*a3 = _mm512_fmadd_ps(va3, va3, *a3);
*b3 = _mm512_fmadd_ps(vb3, vb3, *b3);
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
#[inline]
#[allow(clippy::too_many_arguments)] unsafe fn cosine_8acc_body_hi(
pa: *const f32,
pb: *const f32,
d4: &mut std::arch::x86_64::__m512,
d5: &mut std::arch::x86_64::__m512,
d6: &mut std::arch::x86_64::__m512,
d7: &mut std::arch::x86_64::__m512,
a4: &mut std::arch::x86_64::__m512,
a5: &mut std::arch::x86_64::__m512,
a6: &mut std::arch::x86_64::__m512,
a7: &mut std::arch::x86_64::__m512,
b4: &mut std::arch::x86_64::__m512,
b5: &mut std::arch::x86_64::__m512,
b6: &mut std::arch::x86_64::__m512,
b7: &mut std::arch::x86_64::__m512,
) {
use std::arch::x86_64::*;
let va4 = _mm512_loadu_ps(pa.add(64));
let vb4 = _mm512_loadu_ps(pb.add(64));
*d4 = _mm512_fmadd_ps(va4, vb4, *d4);
*a4 = _mm512_fmadd_ps(va4, va4, *a4);
*b4 = _mm512_fmadd_ps(vb4, vb4, *b4);
let va5 = _mm512_loadu_ps(pa.add(80));
let vb5 = _mm512_loadu_ps(pb.add(80));
*d5 = _mm512_fmadd_ps(va5, vb5, *d5);
*a5 = _mm512_fmadd_ps(va5, va5, *a5);
*b5 = _mm512_fmadd_ps(vb5, vb5, *b5);
let va6 = _mm512_loadu_ps(pa.add(96));
let vb6 = _mm512_loadu_ps(pb.add(96));
*d6 = _mm512_fmadd_ps(va6, vb6, *d6);
*a6 = _mm512_fmadd_ps(va6, va6, *a6);
*b6 = _mm512_fmadd_ps(vb6, vb6, *b6);
let va7 = _mm512_loadu_ps(pa.add(112));
let vb7 = _mm512_loadu_ps(pb.add(112));
*d7 = _mm512_fmadd_ps(va7, vb7, *d7);
*a7 = _mm512_fmadd_ps(va7, va7, *a7);
*b7 = _mm512_fmadd_ps(vb7, vb7, *b7);
}
/// Folds the elements between the cursors and `end_ptr` into the fused-cosine accumulators.
///
/// Full 16-lane vectors are consumed first, advancing `*cur_a` and `*cur_b` by 16 each
/// time. The final `< 16` elements are read with a zero-masking load; the cursors are
/// not advanced for that last partial vector.
///
/// # Safety
///
/// * The executing CPU must support `avx512f`.
/// * `*cur_a` and `end_ptr` must point into (or one past the end of) the same allocation
///   with `*cur_a <= end_ptr`.
/// * `*cur_b` must be valid for reading as many elements as lie between `*cur_a` and
///   `end_ptr`.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
#[inline]
unsafe fn cosine_8acc_remainder(
    cur_a: &mut *const f32,
    cur_b: &mut *const f32,
    end_ptr: *const f32,
    dot_acc: &mut std::arch::x86_64::__m512,
    na_acc: &mut std::arch::x86_64::__m512,
    nb_acc: &mut std::arch::x86_64::__m512,
) {
    use std::arch::x86_64::*;

    // Count the unread elements instead of probing `(*cur_a).add(16)`: forming a pointer
    // more than one element past the end of the buffer is outside the contract of `add`.
    let mut remaining = end_ptr.offset_from(*cur_a) as usize;
    while remaining >= 16 {
        let va = _mm512_loadu_ps(*cur_a);
        let vb = _mm512_loadu_ps(*cur_b);
        *dot_acc = _mm512_fmadd_ps(va, vb, *dot_acc);
        *na_acc = _mm512_fmadd_ps(va, va, *na_acc);
        *nb_acc = _mm512_fmadd_ps(vb, vb, *nb_acc);
        *cur_a = (*cur_a).add(16);
        *cur_b = (*cur_b).add(16);
        remaining -= 16;
    }

    if remaining > 0 {
        // 1..=15 elements are left; lanes outside the mask load as zero.
        let mask: __mmask16 = ((1u32 << remaining) - 1) as u16;
        let va = _mm512_maskz_loadu_ps(mask, *cur_a);
        let vb = _mm512_maskz_loadu_ps(mask, *cur_b);
        *dot_acc = _mm512_fmadd_ps(va, vb, *dot_acc);
        *na_acc = _mm512_fmadd_ps(va, va, *na_acc);
        *nb_acc = _mm512_fmadd_ps(vb, vb, *nb_acc);
    }
}
/// Hamming distance between `a` and `b` after thresholding every element at `0.5`.
///
/// Sixteen elements at a time are compared against `0.5` with an ordered, quiet
/// greater-than; the two resulting bit masks are XOR-ed and the differing lanes counted.
/// The elements past the last full vector are delegated to `scalar::hamming_scalar`,
/// whose result is added to the SIMD count.
///
/// # Safety
///
/// * The executing CPU must support `avx512f`.
/// * `b` must contain at least `a.len()` elements: the vector part reads `b` through a
///   raw pointer without bounds checks.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
pub(crate) unsafe fn hamming_avx512(a: &[f32], b: &[f32]) -> f32 {
    use std::arch::x86_64::*;

    let total_len = a.len();
    let simd_end = total_len - total_len % 16;
    let lhs = a.as_ptr();
    let rhs = b.as_ptr();
    let half = _mm512_set1_ps(0.5);

    let mut differing: u64 = 0;
    let mut offset = 0;
    while offset < simd_end {
        let above_a = _mm512_cmp_ps_mask(_mm512_loadu_ps(lhs.add(offset)), half, _CMP_GT_OQ);
        let above_b = _mm512_cmp_ps_mask(_mm512_loadu_ps(rhs.add(offset)), half, _CMP_GT_OQ);
        differing += u64::from((above_a ^ above_b).count_ones());
        offset += 16;
    }

    differing as f32 + scalar::hamming_scalar(&a[simd_end..], &b[simd_end..])
}
/// Weighted Jaccard similarity: `Σ min(aᵢ, bᵢ) / Σ max(aᵢ, bᵢ)`.
///
/// Sixteen elements at a time are accumulated into one running `min` vector and one
/// running `max` vector. The elements past the last full vector are delegated to
/// `scalar::jaccard_scalar_accum`, whose pair of partial sums is added to the reduced
/// SIMD sums. Returns `1.0` when the union sum is exactly zero.
///
/// # Safety
///
/// * The executing CPU must support `avx512f`.
/// * `b` must contain at least `a.len()` elements: the vector part reads `b` through a
///   raw pointer without bounds checks.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
pub(crate) unsafe fn jaccard_avx512(a: &[f32], b: &[f32]) -> f32 {
    use std::arch::x86_64::*;

    let total_len = a.len();
    let simd_end = total_len - total_len % 16;
    let lhs = a.as_ptr();
    let rhs = b.as_ptr();

    let mut min_sums = _mm512_setzero_ps();
    let mut max_sums = _mm512_setzero_ps();
    let mut offset = 0;
    while offset < simd_end {
        let x = _mm512_loadu_ps(lhs.add(offset));
        let y = _mm512_loadu_ps(rhs.add(offset));
        min_sums = _mm512_add_ps(min_sums, _mm512_min_ps(x, y));
        max_sums = _mm512_add_ps(max_sums, _mm512_max_ps(x, y));
        offset += 16;
    }

    let simd_inter = _mm512_reduce_add_ps(min_sums);
    let simd_union = _mm512_reduce_add_ps(max_sums);
    let (tail_inter, tail_union) = scalar::jaccard_scalar_accum(&a[simd_end..], &b[simd_end..]);

    let intersection = simd_inter + tail_inter;
    let union = simd_union + tail_union;
    if union != 0.0 {
        intersection / union
    } else {
        // Empty union: the inputs are treated as identical.
        1.0
    }
}
/// Hamming distance after thresholding at `0.5`, unrolled over four independent counters.
///
/// The main loop handles 64 elements per iteration, one `hamming_xor_popcount` call per
/// 16-lane vector, each feeding its own counter. Any further full vectors go to the
/// first counter, and the elements past the last full vector are delegated to
/// `scalar::hamming_scalar`.
///
/// # Safety
///
/// * The executing CPU must support `avx512f`.
/// * `b` must contain at least `a.len()` elements: the vector part reads `b` through a
///   raw pointer without bounds checks.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
#[inline]
pub(crate) unsafe fn hamming_avx512_4acc(a: &[f32], b: &[f32]) -> f32 {
    use std::arch::x86_64::*;

    let total_len = a.len();
    let lhs = a.as_ptr();
    let rhs = b.as_ptr();
    let half = _mm512_set1_ps(0.5);

    let mut counts = [0u64; 4];
    let mut offset = 0;

    // Four vectors per iteration, one counter per vector.
    while offset + 64 <= total_len {
        for (lane, count) in counts.iter_mut().enumerate() {
            let at = offset + lane * 16;
            *count += hamming_xor_popcount(lhs.add(at), rhs.add(at), half);
        }
        offset += 64;
    }

    // Up to three further full vectors.
    while offset + 16 <= total_len {
        counts[0] += hamming_xor_popcount(lhs.add(offset), rhs.add(offset), half);
        offset += 16;
    }

    let vector_total = counts.iter().sum::<u64>() as f32;
    vector_total + scalar::hamming_scalar(&a[offset..], &b[offset..])
}
/// Counts the lanes in which exactly one of two 16-element vectors exceeds `threshold`.
///
/// Both vectors are loaded unaligned and compared lane-wise against `threshold` with an
/// ordered, quiet greater-than; the two 16-bit result masks are XOR-ed and the set bits
/// counted.
///
/// # Safety
///
/// * The executing CPU must support `avx512f`.
/// * `a_ptr` and `b_ptr` must each be valid for reading 16 consecutive `f32`s.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
#[inline]
unsafe fn hamming_xor_popcount(
    a_ptr: *const f32,
    b_ptr: *const f32,
    threshold: std::arch::x86_64::__m512,
) -> u64 {
    use std::arch::x86_64::*;

    let above_a = _mm512_cmp_ps_mask(_mm512_loadu_ps(a_ptr), threshold, _CMP_GT_OQ);
    let above_b = _mm512_cmp_ps_mask(_mm512_loadu_ps(b_ptr), threshold, _CMP_GT_OQ);
    let disagreeing = above_a ^ above_b;
    u64::from(disagreeing.count_ones())
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
#[inline]
#[allow(clippy::too_many_lines)] pub(crate) unsafe fn jaccard_avx512_4acc(a: &[f32], b: &[f32]) -> f32 {
use std::arch::x86_64::*;
let len = a.len();
let a_ptr = a.as_ptr();
let b_ptr = b.as_ptr();
let end_main = a_ptr.add(len / 64 * 64);
let end_ptr = a_ptr.add(len);
let zero = _mm512_setzero_ps();
let mut inter0 = zero;
let mut inter1 = zero;
let mut inter2 = zero;
let mut inter3 = zero;
let mut union0 = zero;
let mut union1 = zero;
let mut union2 = zero;
let mut union3 = zero;
let mut cur_a = a_ptr;
let mut cur_b = b_ptr;
while cur_a < end_main {
let va0 = _mm512_loadu_ps(cur_a);
let vb0 = _mm512_loadu_ps(cur_b);
inter0 = _mm512_add_ps(inter0, _mm512_min_ps(va0, vb0));
union0 = _mm512_add_ps(union0, _mm512_max_ps(va0, vb0));
let va1 = _mm512_loadu_ps(cur_a.add(16));
let vb1 = _mm512_loadu_ps(cur_b.add(16));
inter1 = _mm512_add_ps(inter1, _mm512_min_ps(va1, vb1));
union1 = _mm512_add_ps(union1, _mm512_max_ps(va1, vb1));
let va2 = _mm512_loadu_ps(cur_a.add(32));
let vb2 = _mm512_loadu_ps(cur_b.add(32));
inter2 = _mm512_add_ps(inter2, _mm512_min_ps(va2, vb2));
union2 = _mm512_add_ps(union2, _mm512_max_ps(va2, vb2));
let va3 = _mm512_loadu_ps(cur_a.add(48));
let vb3 = _mm512_loadu_ps(cur_b.add(48));
inter3 = _mm512_add_ps(inter3, _mm512_min_ps(va3, vb3));
union3 = _mm512_add_ps(union3, _mm512_max_ps(va3, vb3));
cur_a = cur_a.add(64);
cur_b = cur_b.add(64);
}
let mut inter_acc = _mm512_add_ps(_mm512_add_ps(inter0, inter1), _mm512_add_ps(inter2, inter3));
let mut union_acc = _mm512_add_ps(_mm512_add_ps(union0, union1), _mm512_add_ps(union2, union3));
jaccard_4acc_remainder(cur_a, cur_b, end_ptr, &mut inter_acc, &mut union_acc);
let total_inter = _mm512_reduce_add_ps(inter_acc);
let total_union = _mm512_reduce_add_ps(union_acc);
if total_union == 0.0 {
1.0
} else {
total_inter / total_union
}
}
/// Weighted Jaccard similarity (`Σ min / Σ max`) with an eight-accumulator AVX-512 main loop.
///
/// Whole 128-element blocks go through `jaccard_8acc_main_loop`; the leftover elements
/// are folded into the same two running vectors by `jaccard_8acc_remainder`. The two
/// reduced totals are then divided. Returns `1.0` when the union sum is exactly zero.
///
/// # Safety
///
/// * The executing CPU must support `avx512f`.
/// * `b` must contain at least `a.len()` elements: only `a.len()` is consulted and
///   `b` is read through a raw pointer without bounds checks.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
#[inline]
pub(crate) unsafe fn jaccard_avx512_8acc(a: &[f32], b: &[f32]) -> f32 {
    use std::arch::x86_64::*;

    let count = a.len();
    let a_start = a.as_ptr();
    let b_start = b.as_ptr();
    let blocks_end = a_start.add(count / 128 * 128);
    let a_end = a_start.add(count);

    let (mut inter_sum, mut union_sum, a_cursor, b_cursor) =
        jaccard_8acc_main_loop(a_start, b_start, blocks_end);
    jaccard_8acc_remainder(a_cursor, b_cursor, a_end, &mut inter_sum, &mut union_sum);

    let intersection = _mm512_reduce_add_ps(inter_sum);
    let union = _mm512_reduce_add_ps(union_sum);
    if union != 0.0 {
        intersection / union
    } else {
        // Empty union: the inputs are treated as identical.
        1.0
    }
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
#[inline]
#[allow(clippy::too_many_lines)] unsafe fn jaccard_8acc_main_loop(
a_ptr: *const f32,
b_ptr: *const f32,
end_main: *const f32,
) -> (
std::arch::x86_64::__m512,
std::arch::x86_64::__m512,
*const f32,
*const f32,
) {
use std::arch::x86_64::*;
let z = _mm512_setzero_ps();
let (mut i0, mut i1, mut i2, mut i3) = (z, z, z, z);
let (mut i4, mut i5, mut i6, mut i7) = (z, z, z, z);
let (mut u0, mut u1, mut u2, mut u3) = (z, z, z, z);
let (mut u4, mut u5, mut u6, mut u7) = (z, z, z, z);
let mut pa = a_ptr;
let mut pb = b_ptr;
while pa < end_main {
jaccard_8acc_body_lo(
pa, pb, &mut i0, &mut i1, &mut i2, &mut i3, &mut u0, &mut u1, &mut u2, &mut u3,
);
jaccard_8acc_body_hi(
pa, pb, &mut i4, &mut i5, &mut i6, &mut i7, &mut u4, &mut u5, &mut u6, &mut u7,
);
pa = pa.add(128);
pb = pb.add(128);
}
i0 = _mm512_add_ps(i0, i4);
i1 = _mm512_add_ps(i1, i5);
i2 = _mm512_add_ps(i2, i6);
i3 = _mm512_add_ps(i3, i7);
let inter_acc = _mm512_add_ps(_mm512_add_ps(i0, i1), _mm512_add_ps(i2, i3));
u0 = _mm512_add_ps(u0, u4);
u1 = _mm512_add_ps(u1, u5);
u2 = _mm512_add_ps(u2, u6);
u3 = _mm512_add_ps(u3, u7);
let union_acc = _mm512_add_ps(_mm512_add_ps(u0, u1), _mm512_add_ps(u2, u3));
(inter_acc, union_acc, pa, pb)
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
#[inline]
#[allow(clippy::too_many_arguments)] unsafe fn jaccard_8acc_body_lo(
pa: *const f32,
pb: *const f32,
i0: &mut std::arch::x86_64::__m512,
i1: &mut std::arch::x86_64::__m512,
i2: &mut std::arch::x86_64::__m512,
i3: &mut std::arch::x86_64::__m512,
u0: &mut std::arch::x86_64::__m512,
u1: &mut std::arch::x86_64::__m512,
u2: &mut std::arch::x86_64::__m512,
u3: &mut std::arch::x86_64::__m512,
) {
use std::arch::x86_64::*;
let va0 = _mm512_loadu_ps(pa);
let vb0 = _mm512_loadu_ps(pb);
*i0 = _mm512_add_ps(*i0, _mm512_min_ps(va0, vb0));
*u0 = _mm512_add_ps(*u0, _mm512_max_ps(va0, vb0));
let va1 = _mm512_loadu_ps(pa.add(16));
let vb1 = _mm512_loadu_ps(pb.add(16));
*i1 = _mm512_add_ps(*i1, _mm512_min_ps(va1, vb1));
*u1 = _mm512_add_ps(*u1, _mm512_max_ps(va1, vb1));
let va2 = _mm512_loadu_ps(pa.add(32));
let vb2 = _mm512_loadu_ps(pb.add(32));
*i2 = _mm512_add_ps(*i2, _mm512_min_ps(va2, vb2));
*u2 = _mm512_add_ps(*u2, _mm512_max_ps(va2, vb2));
let va3 = _mm512_loadu_ps(pa.add(48));
let vb3 = _mm512_loadu_ps(pb.add(48));
*i3 = _mm512_add_ps(*i3, _mm512_min_ps(va3, vb3));
*u3 = _mm512_add_ps(*u3, _mm512_max_ps(va3, vb3));
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
#[inline]
#[allow(clippy::too_many_arguments)] unsafe fn jaccard_8acc_body_hi(
pa: *const f32,
pb: *const f32,
i4: &mut std::arch::x86_64::__m512,
i5: &mut std::arch::x86_64::__m512,
i6: &mut std::arch::x86_64::__m512,
i7: &mut std::arch::x86_64::__m512,
u4: &mut std::arch::x86_64::__m512,
u5: &mut std::arch::x86_64::__m512,
u6: &mut std::arch::x86_64::__m512,
u7: &mut std::arch::x86_64::__m512,
) {
use std::arch::x86_64::*;
let va4 = _mm512_loadu_ps(pa.add(64));
let vb4 = _mm512_loadu_ps(pb.add(64));
*i4 = _mm512_add_ps(*i4, _mm512_min_ps(va4, vb4));
*u4 = _mm512_add_ps(*u4, _mm512_max_ps(va4, vb4));
let va5 = _mm512_loadu_ps(pa.add(80));
let vb5 = _mm512_loadu_ps(pb.add(80));
*i5 = _mm512_add_ps(*i5, _mm512_min_ps(va5, vb5));
*u5 = _mm512_add_ps(*u5, _mm512_max_ps(va5, vb5));
let va6 = _mm512_loadu_ps(pa.add(96));
let vb6 = _mm512_loadu_ps(pb.add(96));
*i6 = _mm512_add_ps(*i6, _mm512_min_ps(va6, vb6));
*u6 = _mm512_add_ps(*u6, _mm512_max_ps(va6, vb6));
let va7 = _mm512_loadu_ps(pa.add(112));
let vb7 = _mm512_loadu_ps(pb.add(112));
*i7 = _mm512_add_ps(*i7, _mm512_min_ps(va7, vb7));
*u7 = _mm512_add_ps(*u7, _mm512_max_ps(va7, vb7));
}
/// Folds the elements between the cursors and `end_ptr` into the Jaccard accumulators.
///
/// Full 16-lane vectors are consumed first; the final `< 16` elements are read with a
/// zero-masking load, so the unused lanes add `min(0, 0) = max(0, 0) = 0`.
///
/// # Safety
///
/// * The executing CPU must support `avx512f`.
/// * `cur_a` and `end_ptr` must point into (or one past the end of) the same allocation
///   with `cur_a <= end_ptr`.
/// * `cur_b` must be valid for reading as many elements as lie between `cur_a` and
///   `end_ptr`.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
#[inline]
unsafe fn jaccard_8acc_remainder(
    mut cur_a: *const f32,
    mut cur_b: *const f32,
    end_ptr: *const f32,
    inter_acc: &mut std::arch::x86_64::__m512,
    union_acc: &mut std::arch::x86_64::__m512,
) {
    use std::arch::x86_64::*;

    // Count the unread elements instead of probing `cur_a.add(16)`: forming a pointer
    // more than one element past the end of the buffer is outside the contract of `add`.
    let mut remaining = end_ptr.offset_from(cur_a) as usize;
    while remaining >= 16 {
        let va = _mm512_loadu_ps(cur_a);
        let vb = _mm512_loadu_ps(cur_b);
        *inter_acc = _mm512_add_ps(*inter_acc, _mm512_min_ps(va, vb));
        *union_acc = _mm512_add_ps(*union_acc, _mm512_max_ps(va, vb));
        cur_a = cur_a.add(16);
        cur_b = cur_b.add(16);
        remaining -= 16;
    }

    if remaining > 0 {
        // 1..=15 elements are left; lanes outside the mask load as zero.
        let mask: __mmask16 = ((1u32 << remaining) - 1) as u16;
        let va = _mm512_maskz_loadu_ps(mask, cur_a);
        let vb = _mm512_maskz_loadu_ps(mask, cur_b);
        *inter_acc = _mm512_add_ps(*inter_acc, _mm512_min_ps(va, vb));
        *union_acc = _mm512_add_ps(*union_acc, _mm512_max_ps(va, vb));
    }
}
/// Folds the elements between the cursors and `end_ptr` into the Jaccard accumulators.
///
/// Full 16-lane vectors are consumed first; the final `< 16` elements are read with a
/// zero-masking load, so the unused lanes add `min(0, 0) = max(0, 0) = 0`.
///
/// # Safety
///
/// * The executing CPU must support `avx512f`.
/// * `cur_a` and `end_ptr` must point into (or one past the end of) the same allocation
///   with `cur_a <= end_ptr`.
/// * `cur_b` must be valid for reading as many elements as lie between `cur_a` and
///   `end_ptr`.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
#[inline]
unsafe fn jaccard_4acc_remainder(
    mut cur_a: *const f32,
    mut cur_b: *const f32,
    end_ptr: *const f32,
    inter_acc: &mut std::arch::x86_64::__m512,
    union_acc: &mut std::arch::x86_64::__m512,
) {
    use std::arch::x86_64::*;

    // Count the unread elements instead of probing `cur_a.add(16)`: forming a pointer
    // more than one element past the end of the buffer is outside the contract of `add`.
    let mut remaining = end_ptr.offset_from(cur_a) as usize;
    while remaining >= 16 {
        let va = _mm512_loadu_ps(cur_a);
        let vb = _mm512_loadu_ps(cur_b);
        *inter_acc = _mm512_add_ps(*inter_acc, _mm512_min_ps(va, vb));
        *union_acc = _mm512_add_ps(*union_acc, _mm512_max_ps(va, vb));
        cur_a = cur_a.add(16);
        cur_b = cur_b.add(16);
        remaining -= 16;
    }

    if remaining > 0 {
        // 1..=15 elements are left; lanes outside the mask load as zero.
        let mask: __mmask16 = ((1u32 << remaining) - 1) as u16;
        let va = _mm512_maskz_loadu_ps(mask, cur_a);
        let vb = _mm512_maskz_loadu_ps(mask, cur_b);
        *inter_acc = _mm512_add_ps(*inter_acc, _mm512_min_ps(va, vb));
        *union_acc = _mm512_add_ps(*union_acc, _mm512_max_ps(va, vb));
    }
}
/// Hamming distance between two bit vectors packed into `u64` words.
///
/// Eight words at a time are XOR-ed in a 512-bit register; the result is spilled to a
/// stack array and the set bits of each word counted. The words past the last full
/// group of eight are XOR-ed and counted one by one with bounds-checked indexing.
/// The running `u64` total is narrowed to `u32` with `as` on return.
///
/// # Safety
///
/// * The executing CPU must support `avx512f`.
/// * `b` must contain at least `a.len()` words: the vector part reads `b` through a raw
///   pointer without bounds checks.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
#[inline]
pub(crate) unsafe fn hamming_binary_avx512(a: &[u64], b: &[u64]) -> u32 {
    use std::arch::x86_64::*;

    let words = a.len();
    let simd_end = words - words % 8;
    let lhs = a.as_ptr().cast::<i64>();
    let rhs = b.as_ptr().cast::<i64>();

    let mut bits: u64 = 0;
    let mut offset = 0;
    while offset < simd_end {
        let x = _mm512_loadu_si512(lhs.add(offset).cast());
        let y = _mm512_loadu_si512(rhs.add(offset).cast());
        let mut spilled = [0u64; 8];
        _mm512_storeu_si512(spilled.as_mut_ptr().cast(), _mm512_xor_si512(x, y));
        bits += spilled.iter().map(|word| u64::from(word.count_ones())).sum::<u64>();
        offset += 8;
    }

    // Scalar tail; indexing keeps the bounds check on `b`.
    bits += (simd_end..words)
        .map(|j| u64::from((a[j] ^ b[j]).count_ones()))
        .sum::<u64>();

    bits as u32
}