#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;
#[cfg(target_arch = "aarch64")]
use std::arch::aarch64::*;
use super::simd_config;
#[cfg(target_arch = "x86_64")]
use super::dot_product::{horizontal_sum_avx2, horizontal_sum_avx512};
#[cfg(target_arch = "aarch64")]
use super::dot_product::horizontal_sum_neon;
#[inline(always)]
fn dispatch_squared(a: &[f32], b: &[f32]) -> f32 {
let config = simd_config();
#[cfg(target_arch = "x86_64")]
{
if config.avx512f_enabled {
return unsafe { squared_euclidean_distance_avx512_unrolled(a, b) };
}
if config.avx2_enabled && config.fma_enabled {
return unsafe { squared_euclidean_distance_avx2_unrolled(a, b) };
}
}
#[cfg(target_arch = "aarch64")]
{
if config.neon_enabled {
return unsafe { squared_euclidean_distance_neon_unrolled(a, b) };
}
}
squared_euclidean_distance_scalar(a, b)
}
#[inline]
pub fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
if a.len() != b.len() {
return f32::MAX;
}
debug_assert_eq!(a.len(), b.len());
dispatch_squared(a, b).sqrt()
}
#[inline]
pub fn squared_euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
if a.len() != b.len() {
return f32::MAX;
}
debug_assert_eq!(a.len(), b.len());
dispatch_squared(a, b)
}
pub(crate) fn squared_euclidean_distance_scalar(a: &[f32], b: &[f32]) -> f32 {
a.iter()
.zip(b.iter())
.map(|(x, y)| {
let d = x - y;
d * d
})
.sum::<f32>()
}
#[cfg(test)]
pub(crate) fn euclidean_distance_scalar(a: &[f32], b: &[f32]) -> f32 {
squared_euclidean_distance_scalar(a, b).sqrt()
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
unsafe fn squared_euclidean_distance_avx512_unrolled(a: &[f32], b: &[f32]) -> f32 {
const SIMD_WIDTH: usize = 16;
const UNROLL: usize = 4;
const CHUNK_SIZE: usize = SIMD_WIDTH * UNROLL;
let n = a.len();
debug_assert_eq!(n, b.len());
let chunks = n / CHUNK_SIZE;
let mut sum0 = _mm512_setzero_ps();
let mut sum1 = _mm512_setzero_ps();
let mut sum2 = _mm512_setzero_ps();
let mut sum3 = _mm512_setzero_ps();
for i in 0..chunks {
let base = i * CHUNK_SIZE;
let a0 = _mm512_loadu_ps(a.as_ptr().add(base));
let b0 = _mm512_loadu_ps(b.as_ptr().add(base));
let diff0 = _mm512_sub_ps(a0, b0);
sum0 = _mm512_fmadd_ps(diff0, diff0, sum0);
let a1 = _mm512_loadu_ps(a.as_ptr().add(base + SIMD_WIDTH));
let b1 = _mm512_loadu_ps(b.as_ptr().add(base + SIMD_WIDTH));
let diff1 = _mm512_sub_ps(a1, b1);
sum1 = _mm512_fmadd_ps(diff1, diff1, sum1);
let a2 = _mm512_loadu_ps(a.as_ptr().add(base + SIMD_WIDTH * 2));
let b2 = _mm512_loadu_ps(b.as_ptr().add(base + SIMD_WIDTH * 2));
let diff2 = _mm512_sub_ps(a2, b2);
sum2 = _mm512_fmadd_ps(diff2, diff2, sum2);
let a3 = _mm512_loadu_ps(a.as_ptr().add(base + SIMD_WIDTH * 3));
let b3 = _mm512_loadu_ps(b.as_ptr().add(base + SIMD_WIDTH * 3));
let diff3 = _mm512_sub_ps(a3, b3);
sum3 = _mm512_fmadd_ps(diff3, diff3, sum3);
}
let sum_vec = _mm512_add_ps(_mm512_add_ps(sum0, sum1), _mm512_add_ps(sum2, sum3));
let main_sum = horizontal_sum_avx512(sum_vec);
let main_processed = chunks * CHUNK_SIZE;
let remaining = n - main_processed;
let remaining_chunks = remaining / SIMD_WIDTH;
let mut remainder_sum = _mm512_setzero_ps();
for i in 0..remaining_chunks {
let offset = main_processed + i * SIMD_WIDTH;
let a_vec = _mm512_loadu_ps(a.as_ptr().add(offset));
let b_vec = _mm512_loadu_ps(b.as_ptr().add(offset));
let diff = _mm512_sub_ps(a_vec, b_vec);
remainder_sum = _mm512_fmadd_ps(diff, diff, remainder_sum);
}
let mut sum = main_sum + horizontal_sum_avx512(remainder_sum);
let scalar_start = main_processed + remaining_chunks * SIMD_WIDTH;
for i in scalar_start..n {
let diff = a[i] - b[i];
sum += diff * diff;
}
sum
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2", enable = "fma")]
unsafe fn squared_euclidean_distance_avx2_unrolled(a: &[f32], b: &[f32]) -> f32 {
const SIMD_WIDTH: usize = 8;
const UNROLL: usize = 4;
const CHUNK_SIZE: usize = SIMD_WIDTH * UNROLL;
let n = a.len();
debug_assert_eq!(n, b.len());
let chunks = n / CHUNK_SIZE;
let mut sum0 = _mm256_setzero_ps();
let mut sum1 = _mm256_setzero_ps();
let mut sum2 = _mm256_setzero_ps();
let mut sum3 = _mm256_setzero_ps();
for i in 0..chunks {
let base = i * CHUNK_SIZE;
let a0 = _mm256_loadu_ps(a.as_ptr().add(base));
let b0 = _mm256_loadu_ps(b.as_ptr().add(base));
let diff0 = _mm256_sub_ps(a0, b0);
sum0 = _mm256_fmadd_ps(diff0, diff0, sum0);
let a1 = _mm256_loadu_ps(a.as_ptr().add(base + SIMD_WIDTH));
let b1 = _mm256_loadu_ps(b.as_ptr().add(base + SIMD_WIDTH));
let diff1 = _mm256_sub_ps(a1, b1);
sum1 = _mm256_fmadd_ps(diff1, diff1, sum1);
let a2 = _mm256_loadu_ps(a.as_ptr().add(base + SIMD_WIDTH * 2));
let b2 = _mm256_loadu_ps(b.as_ptr().add(base + SIMD_WIDTH * 2));
let diff2 = _mm256_sub_ps(a2, b2);
sum2 = _mm256_fmadd_ps(diff2, diff2, sum2);
let a3 = _mm256_loadu_ps(a.as_ptr().add(base + SIMD_WIDTH * 3));
let b3 = _mm256_loadu_ps(b.as_ptr().add(base + SIMD_WIDTH * 3));
let diff3 = _mm256_sub_ps(a3, b3);
sum3 = _mm256_fmadd_ps(diff3, diff3, sum3);
}
let sum_vec = _mm256_add_ps(_mm256_add_ps(sum0, sum1), _mm256_add_ps(sum2, sum3));
let mut sum = horizontal_sum_avx2(sum_vec);
for i in (chunks * CHUNK_SIZE)..n {
let diff = a[i] - b[i];
sum += diff * diff;
}
sum
}
#[cfg(target_arch = "aarch64")]
#[inline]
unsafe fn squared_euclidean_distance_neon_unrolled(a: &[f32], b: &[f32]) -> f32 {
const SIMD_WIDTH: usize = 4;
const UNROLL: usize = 4;
const CHUNK_SIZE: usize = SIMD_WIDTH * UNROLL;
let n = a.len();
debug_assert_eq!(n, b.len());
let chunks = n / CHUNK_SIZE;
let mut sum0 = vdupq_n_f32(0.0);
let mut sum1 = vdupq_n_f32(0.0);
let mut sum2 = vdupq_n_f32(0.0);
let mut sum3 = vdupq_n_f32(0.0);
for i in 0..chunks {
let base = i * CHUNK_SIZE;
let a0 = vld1q_f32(a.as_ptr().add(base));
let b0 = vld1q_f32(b.as_ptr().add(base));
let diff0 = vsubq_f32(a0, b0);
sum0 = vfmaq_f32(sum0, diff0, diff0);
let a1 = vld1q_f32(a.as_ptr().add(base + SIMD_WIDTH));
let b1 = vld1q_f32(b.as_ptr().add(base + SIMD_WIDTH));
let diff1 = vsubq_f32(a1, b1);
sum1 = vfmaq_f32(sum1, diff1, diff1);
let a2 = vld1q_f32(a.as_ptr().add(base + SIMD_WIDTH * 2));
let b2 = vld1q_f32(b.as_ptr().add(base + SIMD_WIDTH * 2));
let diff2 = vsubq_f32(a2, b2);
sum2 = vfmaq_f32(sum2, diff2, diff2);
let a3 = vld1q_f32(a.as_ptr().add(base + SIMD_WIDTH * 3));
let b3 = vld1q_f32(b.as_ptr().add(base + SIMD_WIDTH * 3));
let diff3 = vsubq_f32(a3, b3);
sum3 = vfmaq_f32(sum3, diff3, diff3);
}
let sum_vec = vaddq_f32(vaddq_f32(sum0, sum1), vaddq_f32(sum2, sum3));
let mut sum = horizontal_sum_neon(sum_vec);
for i in (chunks * CHUNK_SIZE)..n {
let diff = a[i] - b[i];
sum += diff * diff;
}
sum
}