#[allow(unused_imports)] use super::{dot::dot_product_native, simd_level, SimdLevel};
/// Squared Euclidean (L2) distance between `a` and `b`.
///
/// The kernel is picked on every call from the detected [`SimdLevel`] and the
/// vector length:
///
/// * x86_64 AVX-512: `squared_l2_avx512_8acc` for `len >= 1024`,
///   `squared_l2_avx512_4acc` for `len >= 512`, `squared_l2_avx512` otherwise.
/// * x86_64 AVX2: `squared_l2_avx2_4acc` for `len >= 256`, `squared_l2_avx2`
///   for `len >= 64`, `squared_l2_avx2_1acc` for `len >= 8`.
/// * aarch64 NEON: `squared_l2_neon` for `len >= 4`.
/// * Anything else (including AVX2 inputs shorter than 8 and NEON inputs
///   shorter than 4) falls back to the scalar implementation.
///
/// # Panics
///
/// Panics with `"Vector dimensions must match"` if `a.len() != b.len()`.
#[allow(clippy::inline_always)]
#[inline(always)]
#[must_use]
pub fn squared_l2_native(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len(), "Vector dimensions must match");

    // Arms are tried top to bottom, so within each level the larger length
    // thresholds must come first.
    //
    // SAFETY (all x86_64 arms): both slices were checked to have equal length
    // above. This relies on `simd_level()` only reporting `Avx512` / `Avx2`
    // after runtime detection of the matching CPU features — confirm there.
    match (simd_level(), a.len()) {
        #[cfg(target_arch = "x86_64")]
        (SimdLevel::Avx512, len) if len >= 1024 => unsafe {
            crate::simd_native::squared_l2_avx512_8acc(a, b)
        },
        #[cfg(target_arch = "x86_64")]
        (SimdLevel::Avx512, len) if len >= 512 => unsafe {
            crate::simd_native::squared_l2_avx512_4acc(a, b)
        },
        #[cfg(target_arch = "x86_64")]
        (SimdLevel::Avx512, _) => unsafe { crate::simd_native::squared_l2_avx512(a, b) },
        #[cfg(target_arch = "x86_64")]
        (SimdLevel::Avx2, len) if len >= 256 => unsafe {
            crate::simd_native::squared_l2_avx2_4acc(a, b)
        },
        #[cfg(target_arch = "x86_64")]
        (SimdLevel::Avx2, len) if len >= 64 => unsafe {
            crate::simd_native::squared_l2_avx2(a, b)
        },
        #[cfg(target_arch = "x86_64")]
        (SimdLevel::Avx2, len) if len >= 8 => unsafe {
            crate::simd_native::squared_l2_avx2_1acc(a, b)
        },
        #[cfg(target_arch = "aarch64")]
        (SimdLevel::Neon, len) if len >= 4 => crate::simd_native::squared_l2_neon(a, b),
        _ => super::squared_l2_scalar(a, b),
    }
}
/// Euclidean (L2) distance between `a` and `b`.
///
/// Computed as the square root of [`squared_l2_native`], so it uses the same
/// SIMD dispatch and panics under the same condition (mismatched lengths).
#[allow(clippy::inline_always)]
#[inline(always)]
#[must_use]
pub fn euclidean_native(a: &[f32], b: &[f32]) -> f32 {
    let squared_distance = squared_l2_native(a, b);
    f32::sqrt(squared_distance)
}
/// Euclidean (L2) norm of `v`.
///
/// Computed as the square root of `v · v` via the SIMD-dispatched
/// `dot_product_native`.
#[allow(clippy::inline_always)]
#[inline(always)]
#[must_use]
pub fn norm_native(v: &[f32]) -> f32 {
    let sum_of_squares = dot_product_native(v, v);
    f32::sqrt(sum_of_squares)
}
/// Scales `v` in place to unit Euclidean length.
///
/// Vectors whose norm is zero or NaN (empty, all-zero, or containing NaN) are
/// left unchanged.
///
/// If the sum of squares overflows `f32` to `+inf` while every component is
/// still finite (components around `1.8e19` or larger), a naive `1 / inf`
/// scale would be `0` and silently zero the vector. In that case the vector is
/// first divided by its largest absolute component, which bounds every element
/// to roughly `[-1, 1]`, and the norm is recomputed before the final scaling.
#[allow(clippy::inline_always)]
#[inline(always)]
pub fn normalize_inplace_native(v: &mut [f32]) {
    let mut n = norm_native(v);

    if n.is_infinite() {
        // `f32::max` ignores NaN, so this is the largest finite-or-infinite magnitude.
        let max_abs = v.iter().fold(0.0_f32, |m, &x| m.max(x.abs()));
        // An infinite component cannot be rescaled meaningfully; leave that
        // case to the generic path below.
        if max_abs.is_finite() {
            scale_inplace_native(v, 1.0 / max_abs);
            n = norm_native(v);
        }
    }

    // `n > 0.0` is false for both 0 and NaN, so those vectors are untouched.
    if n > 0.0 {
        scale_inplace_native(v, 1.0 / n);
    }
}
/// Multiplies every element of `v` by `factor` in place.
///
/// On x86_64, when the detected level is AVX2 or AVX-512 and `v` holds at
/// least one full 8-lane vector, the AVX2 kernel is used; every other case
/// takes the scalar path.
#[inline]
fn scale_inplace_native(v: &mut [f32], factor: f32) {
    match simd_level() {
        #[cfg(target_arch = "x86_64")]
        SimdLevel::Avx512 | SimdLevel::Avx2 if v.len() >= 8 => {
            // SAFETY: `scale_inplace_avx2` requires AVX2. This assumes
            // `simd_level()` only reports `Avx2`/`Avx512` after detecting AVX2
            // support at runtime (AVX-512 CPUs are expected to also expose AVX2)
            // — confirm in `simd_level`.
            unsafe { scale_inplace_avx2(v, factor) }
        }
        _ => v.iter_mut().for_each(|x| *x *= factor),
    }
}
/// Multiplies every element of `v` by `factor` using 256-bit AVX2 vectors.
///
/// Full 8-element chunks are scaled with `_mm256_mul_ps`. Any trailing
/// `len % 8` elements are scaled with ordinary scalar multiplies. Both paths
/// perform a single IEEE `f32` multiply per element, so the results match the
/// scalar loop exactly.
///
/// # Safety
///
/// The caller must ensure the running CPU supports AVX2 (e.g. checked with
/// `is_x86_feature_detected!("avx2")`). Executing this on a CPU without AVX2
/// is undefined behavior.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
#[inline]
unsafe fn scale_inplace_avx2(v: &mut [f32], factor: f32) {
    use std::arch::x86_64::{_mm256_loadu_ps, _mm256_mul_ps, _mm256_set1_ps, _mm256_storeu_ps};

    let scale = _mm256_set1_ps(factor);

    // `chunks_exact_mut(8)` yields only complete 8-lane chunks, so each
    // unaligned load/store touches exactly that chunk's own elements. No
    // offset arithmetic or unchecked indexing is needed.
    let mut chunks = v.chunks_exact_mut(8);
    for chunk in &mut chunks {
        let p = chunk.as_mut_ptr();
        _mm256_storeu_ps(p, _mm256_mul_ps(_mm256_loadu_ps(p), scale));
    }

    // Scalar tail: fewer than 8 leftover elements.
    for x in chunks.into_remainder() {
        *x *= factor;
    }
}
/// Squared L2 distance from `query` to each of `candidates`.
///
/// Delegates the iteration (and any prefetching it does) to
/// `batch_with_prefetch`, using [`squared_l2_native`] as the per-pair kernel.
/// Presumably the output is index-aligned with `candidates` — confirm against
/// `batch_with_prefetch`.
#[inline]
#[must_use]
pub fn batch_squared_l2_native(candidates: &[&[f32]], query: &[f32]) -> Vec<f32> {
    let kernel = squared_l2_native;
    super::batch_with_prefetch(candidates, query, kernel)
}
/// Euclidean (L2) distance from `query` to each of `candidates`.
///
/// Delegates the iteration (and any prefetching it does) to
/// `batch_with_prefetch`, using [`euclidean_native`] as the per-pair kernel.
/// Presumably the output is index-aligned with `candidates` — confirm against
/// `batch_with_prefetch`.
#[inline]
#[must_use]
pub fn batch_euclidean_native(candidates: &[&[f32]], query: &[f32]) -> Vec<f32> {
    let kernel = euclidean_native;
    super::batch_with_prefetch(candidates, query, kernel)
}
/// Selects a squared-L2 kernel once for a given SIMD `level` and vector
/// length `dim`, so hot loops can call the returned function pointer without
/// re-dispatching on every call.
///
/// The thresholds mirror `squared_l2_native`:
///
/// * x86_64 AVX-512: `squared_l2_avx512_8acc` for `dim >= 1024`,
///   `squared_l2_avx512_4acc` for `dim >= 512`, `squared_l2_avx512` otherwise.
/// * x86_64 AVX2: `squared_l2_avx2_4acc` for `dim >= 256`, `squared_l2_avx2`
///   for `dim >= 64`, `squared_l2_avx2_1acc` for `dim >= 8`.
/// * aarch64 NEON: `squared_l2_neon` for `dim >= 4`.
/// * Otherwise: the scalar implementation.
///
/// The returned pointer is a *safe* `fn`, yet the x86_64 kernels are `unsafe`.
/// Every SIMD kernel therefore re-checks `a.len() == b.len()` and panics with
/// `"Vector dimensions must match"` (same as `squared_l2_native`) instead of
/// handing mismatched slices to the kernel.
///
/// NOTE(review): the kernel is chosen from `dim` but is not told `dim`. Calling
/// it with vectors of another length presumably relies on the kernels' own
/// tail handling — confirm in `crate::simd_native`.
// `dim` is only consulted on x86_64 / aarch64.
#[allow(unused_variables)]
pub(super) fn resolve_squared_l2(level: SimdLevel, dim: usize) -> fn(&[f32], &[f32]) -> f32 {
    // SAFETY (all x86_64 closures): lengths are asserted equal before the call.
    // This assumes `level` came from runtime CPU-feature detection so that the
    // AVX-512 / AVX2 instructions are available — confirm at the call sites.
    match level {
        #[cfg(target_arch = "x86_64")]
        SimdLevel::Avx512 if dim >= 1024 => |a: &[f32], b: &[f32]| {
            assert_eq!(a.len(), b.len(), "Vector dimensions must match");
            unsafe { crate::simd_native::squared_l2_avx512_8acc(a, b) }
        },
        #[cfg(target_arch = "x86_64")]
        SimdLevel::Avx512 if dim >= 512 => |a: &[f32], b: &[f32]| {
            assert_eq!(a.len(), b.len(), "Vector dimensions must match");
            unsafe { crate::simd_native::squared_l2_avx512_4acc(a, b) }
        },
        #[cfg(target_arch = "x86_64")]
        SimdLevel::Avx512 => |a: &[f32], b: &[f32]| {
            assert_eq!(a.len(), b.len(), "Vector dimensions must match");
            unsafe { crate::simd_native::squared_l2_avx512(a, b) }
        },
        #[cfg(target_arch = "x86_64")]
        SimdLevel::Avx2 if dim >= 256 => |a: &[f32], b: &[f32]| {
            assert_eq!(a.len(), b.len(), "Vector dimensions must match");
            unsafe { crate::simd_native::squared_l2_avx2_4acc(a, b) }
        },
        #[cfg(target_arch = "x86_64")]
        SimdLevel::Avx2 if dim >= 64 => |a: &[f32], b: &[f32]| {
            assert_eq!(a.len(), b.len(), "Vector dimensions must match");
            unsafe { crate::simd_native::squared_l2_avx2(a, b) }
        },
        #[cfg(target_arch = "x86_64")]
        SimdLevel::Avx2 if dim >= 8 => |a: &[f32], b: &[f32]| {
            assert_eq!(a.len(), b.len(), "Vector dimensions must match");
            unsafe { crate::simd_native::squared_l2_avx2_1acc(a, b) }
        },
        #[cfg(target_arch = "aarch64")]
        SimdLevel::Neon if dim >= 4 => |a: &[f32], b: &[f32]| {
            assert_eq!(a.len(), b.len(), "Vector dimensions must match");
            crate::simd_native::squared_l2_neon(a, b)
        },
        _ => super::squared_l2_scalar,
    }
}