#![allow(clippy::cast_precision_loss)]
#![allow(clippy::cast_possible_truncation)]
#![allow(clippy::cast_sign_loss)]
#![allow(clippy::incompatible_msrv)]
#![allow(clippy::wildcard_imports)]
#![allow(clippy::missing_panics_doc)]
#![allow(clippy::similar_names)]
use crate::simd_4acc_l2_loop;
use crate::simd_native::reduction::hsum_avx256;
use crate::sum_squared_remainder_unrolled_8;
/// Squared Euclidean (L2) distance between `a` and `b` using AVX2 + FMA.
///
/// The main loop consumes 16 floats per iteration into two independent
/// 8-lane accumulators (presumably to hide FMA latency). The 0..=15 leftover
/// elements are handled by at most one extra 8-wide vector step, followed by
/// the `sum_squared_remainder_unrolled_8!` tail macro.
///
/// Only the first `a.len()` elements of each slice are read.
///
/// # Safety
///
/// - The executing CPU must support the `avx2` and `fma` target features.
/// - `b.len()` must be at least `a.len()`. `b` is read through a raw pointer
///   using `a`'s length, so a shorter `b` causes out-of-bounds reads.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2", enable = "fma")]
#[inline]
#[allow(clippy::too_many_lines)]
pub(crate) unsafe fn squared_l2_avx2(a: &[f32], b: &[f32]) -> f32 {
    use std::arch::x86_64::*;
    // Catch violations of the length contract in debug builds; the raw
    // pointer loads below are only in bounds when `b` is at least as long as `a`.
    debug_assert!(
        b.len() >= a.len(),
        "squared_l2_avx2: b.len() ({}) < a.len() ({})",
        b.len(),
        a.len()
    );
    let len = a.len();
    // Number of full 16-float (two-vector) iterations.
    let simd_len = len / 16;
    let mut sum0 = _mm256_setzero_ps();
    let mut sum1 = _mm256_setzero_ps();
    let a_ptr = a.as_ptr();
    let b_ptr = b.as_ptr();
    for i in 0..simd_len {
        let offset = i * 16;
        // SAFETY: offset + 16 <= simd_len * 16 <= len, and b.len() >= len.
        let va0 = _mm256_loadu_ps(a_ptr.add(offset));
        let vb0 = _mm256_loadu_ps(b_ptr.add(offset));
        let diff0 = _mm256_sub_ps(va0, vb0);
        sum0 = _mm256_fmadd_ps(diff0, diff0, sum0);
        let va1 = _mm256_loadu_ps(a_ptr.add(offset + 8));
        let vb1 = _mm256_loadu_ps(b_ptr.add(offset + 8));
        let diff1 = _mm256_sub_ps(va1, vb1);
        sum1 = _mm256_fmadd_ps(diff1, diff1, sum1);
    }
    let combined = _mm256_add_ps(sum0, sum1);
    let mut result = hsum_avx256(combined);
    let base = simd_len * 16;
    let remainder = len - base;
    if remainder >= 8 {
        // SAFETY: base + 8 <= len because remainder >= 8.
        let va = _mm256_loadu_ps(a_ptr.add(base));
        let vb = _mm256_loadu_ps(b_ptr.add(base));
        let diff = _mm256_sub_ps(va, vb);
        let tmp_sum = _mm256_fmadd_ps(diff, diff, _mm256_setzero_ps());
        result += hsum_avx256(tmp_sum);
        if remainder > 8 {
            // 1..=7 scalar elements remain after the extra vector step.
            let rbase = base + 8;
            let r = remainder - 8;
            sum_squared_remainder_unrolled_8!(a, b, rbase, r, result);
        }
    } else if remainder > 0 {
        sum_squared_remainder_unrolled_8!(a, b, base, remainder, result);
    }
    result
}
/// Squared Euclidean (L2) distance between `a` and `b` using AVX2 + FMA with
/// a single 8-lane accumulator.
///
/// Processes 8 floats per iteration. The 0..=7 leftover elements are passed
/// to the `sum_squared_remainder_unrolled_8!` tail macro.
///
/// Only the first `a.len()` elements of each slice are read.
///
/// # Safety
///
/// - The executing CPU must support the `avx2` and `fma` target features.
/// - `b.len()` must be at least `a.len()`. `b` is read through a raw pointer
///   using `a`'s length, so a shorter `b` causes out-of-bounds reads.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2", enable = "fma")]
#[inline]
pub(crate) unsafe fn squared_l2_avx2_1acc(a: &[f32], b: &[f32]) -> f32 {
    use std::arch::x86_64::*;
    // Catch violations of the length contract in debug builds; the raw
    // pointer loads below are only in bounds when `b` is at least as long as `a`.
    debug_assert!(
        b.len() >= a.len(),
        "squared_l2_avx2_1acc: b.len() ({}) < a.len() ({})",
        b.len(),
        a.len()
    );
    let len = a.len();
    // Number of full 8-float vector iterations.
    let simd_len = len / 8;
    let mut sum = _mm256_setzero_ps();
    let a_ptr = a.as_ptr();
    let b_ptr = b.as_ptr();
    for i in 0..simd_len {
        let offset = i * 8;
        // SAFETY: offset + 8 <= simd_len * 8 <= len, and b.len() >= len.
        let va = _mm256_loadu_ps(a_ptr.add(offset));
        let vb = _mm256_loadu_ps(b_ptr.add(offset));
        let diff = _mm256_sub_ps(va, vb);
        sum = _mm256_fmadd_ps(diff, diff, sum);
    }
    let mut result = hsum_avx256(sum);
    let base = simd_len * 8;
    let remainder = len - base;
    sum_squared_remainder_unrolled_8!(a, b, base, remainder, result);
    result
}
/// Squared Euclidean (L2) distance between `a` and `b` using AVX2 + FMA with
/// the `simd_4acc_l2_loop!` main loop.
///
/// The main loop runs over the first `len / 32 * 32` elements. The macro is
/// presumably four 8-lane accumulators per 32-float step; confirm against
/// its definition. It yields the combined accumulator plus the advanced
/// `a`/`b` pointers. The remaining 0..=31 elements are summed with a scalar
/// pointer loop.
///
/// Only the first `a.len()` elements of each slice are read.
///
/// # Safety
///
/// - The executing CPU must support the `avx2` and `fma` target features.
/// - `b.len()` must be at least `a.len()`. `b` is read through a raw pointer
///   using `a`'s length, so a shorter `b` causes out-of-bounds reads.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2", enable = "fma")]
#[inline]
pub(crate) unsafe fn squared_l2_avx2_4acc(a: &[f32], b: &[f32]) -> f32 {
    use std::arch::x86_64::*;
    // Catch violations of the length contract in debug builds; the raw
    // pointer reads below are only in bounds when `b` is at least as long as `a`.
    debug_assert!(
        b.len() >= a.len(),
        "squared_l2_avx2_4acc: b.len() ({}) < a.len() ({})",
        b.len(),
        a.len()
    );
    let len = a.len();
    let a_ptr = a.as_ptr();
    let b_ptr = b.as_ptr();
    // End of the region covered by whole 32-float steps, and end of `a`.
    let end_main = a_ptr.add(len / 32 * 32);
    let end_ptr = a_ptr.add(len);
    let (combined, mut a_p, mut b_p) = simd_4acc_l2_loop!(
        a_ptr,
        b_ptr,
        end_main,
        _mm256_setzero_ps(),
        _mm256_loadu_ps,
        _mm256_sub_ps,
        _mm256_fmadd_ps,
        _mm256_add_ps,
        8
    );
    let mut result = hsum_avx256(combined);
    // Scalar tail: `b_p` advances in lockstep with `a_p`, so it stays within
    // the first `len` elements of `b`.
    while a_p < end_ptr {
        let d = *a_p - *b_p;
        result += d * d;
        a_p = a_p.add(1);
        b_p = b_p.add(1);
    }
    result
}