#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;
use super::scalar;
use simsimd::SpatialSimilarity;
/// Runtime check for AVX-512F support on x86_64.
///
/// Always `false` unless the crate is built with the `simd-avx512`
/// feature, so the AVX-512 kernels can never be selected accidentally.
#[cfg(target_arch = "x86_64")]
#[inline]
pub fn is_avx512_available() -> bool {
    if cfg!(feature = "simd-avx512") {
        is_x86_feature_detected!("avx512f")
    } else {
        false
    }
}
/// AVX-512 is an x86-only extension; never available on other targets.
#[cfg(not(target_arch = "x86_64"))]
#[inline]
pub fn is_avx512_available() -> bool {
    false
}
/// Runtime check for AVX2 + FMA on x86_64.
///
/// Both sets are required because most AVX2 kernels below use
/// `_mm256_fmadd_ps` (the Manhattan kernel needs only AVX2).
#[cfg(target_arch = "x86_64")]
#[inline]
pub fn is_avx2_available() -> bool {
    if is_x86_feature_detected!("avx2") {
        is_x86_feature_detected!("fma")
    } else {
        false
    }
}
/// AVX2 is an x86-only extension; never available on other targets.
#[cfg(not(target_arch = "x86_64"))]
#[inline]
pub fn is_avx2_available() -> bool {
    false
}
/// NEON is a mandatory part of the AArch64 baseline, so no runtime
/// probe is required on this architecture.
#[cfg(target_arch = "aarch64")]
#[inline]
pub fn is_neon_available() -> bool {
    true
}
/// NEON is an AArch64 feature; never available on other targets.
#[cfg(not(target_arch = "aarch64"))]
#[inline]
pub fn is_neon_available() -> bool {
    false
}
/// Returns a human-readable name for the best SIMD tier usable at runtime
/// ("AVX-512", "AVX2", "NEON", or "Scalar").
pub fn simd_level() -> &'static str {
    // Exactly one of the cfg blocks below survives compilation.
    #[cfg(target_arch = "x86_64")]
    {
        if is_avx512_available() {
            "AVX-512"
        } else if is_avx2_available() {
            "AVX2"
        } else {
            "Scalar"
        }
    }
    #[cfg(target_arch = "aarch64")]
    {
        // NEON is always present on AArch64.
        "NEON"
    }
    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
    {
        "Scalar"
    }
}
/// Returns `true` when `ptr` is aligned to `align` bytes.
///
/// Plain modulo is used, so `align` does not have to be a power of two.
#[inline]
fn is_aligned_to(ptr: *const f32, align: usize) -> bool {
    let addr = ptr as usize;
    addr % align == 0
}
/// True when both pointers satisfy the 32-byte alignment required by
/// aligned AVX2 loads (`_mm256_load_ps`).
#[inline]
fn is_avx2_aligned(a: *const f32, b: *const f32) -> bool {
    (a as usize) % 32 == 0 && (b as usize) % 32 == 0
}
/// True when both pointers satisfy the 64-byte alignment required by
/// aligned AVX-512 loads. Currently unused: the AVX-512 kernels issue
/// unaligned loads only.
#[inline]
#[allow(dead_code)]
fn is_avx512_aligned(a: *const f32, b: *const f32) -> bool {
    (a as usize) % 64 == 0 && (b as usize) % 64 == 0
}
/// Euclidean (L2) distance over `len` floats using AVX-512F, 16 lanes
/// per iteration, with a scalar tail for `len % 16` leftovers.
///
/// # Safety
/// Caller must guarantee AVX-512F support and that `a` and `b` are valid
/// for reading `len` consecutive `f32`s.
#[cfg(all(target_arch = "x86_64", feature = "simd-avx512"))]
#[target_feature(enable = "avx512f")]
#[inline]
pub unsafe fn l2_distance_ptr_avx512(a: *const f32, b: *const f32, len: usize) -> f32 {
    debug_assert!(!a.is_null() && !b.is_null() && len > 0);
    let mut sum = _mm512_setzero_ps();
    let chunks = len / 16;
    for i in 0..chunks {
        let offset = i * 16;
        // Unaligned loads: the caller's buffers carry no alignment guarantee.
        let va = _mm512_loadu_ps(a.add(offset));
        let vb = _mm512_loadu_ps(b.add(offset));
        let diff = _mm512_sub_ps(va, vb);
        // sum += diff * diff via fused multiply-add.
        sum = _mm512_fmadd_ps(diff, diff, sum);
    }
    let mut result = _mm512_reduce_add_ps(sum);
    // Scalar tail for the remaining elements.
    for i in (chunks * 16)..len {
        let diff = *a.add(i) - *b.add(i);
        result += diff * diff;
    }
    result.sqrt()
}
/// Cosine distance (1 - cosine similarity) over `len` floats using
/// AVX-512F. Returns 1.0 when either vector has zero norm.
///
/// # Safety
/// Caller must guarantee AVX-512F support and that `a` and `b` are valid
/// for reading `len` consecutive `f32`s.
#[cfg(all(target_arch = "x86_64", feature = "simd-avx512"))]
#[target_feature(enable = "avx512f")]
#[inline]
pub unsafe fn cosine_distance_ptr_avx512(a: *const f32, b: *const f32, len: usize) -> f32 {
    debug_assert!(!a.is_null() && !b.is_null() && len > 0);
    // Accumulate dot product and both squared norms in a single pass.
    let mut dot = _mm512_setzero_ps();
    let mut norm_a = _mm512_setzero_ps();
    let mut norm_b = _mm512_setzero_ps();
    let chunks = len / 16;
    for i in 0..chunks {
        let offset = i * 16;
        let va = _mm512_loadu_ps(a.add(offset));
        let vb = _mm512_loadu_ps(b.add(offset));
        dot = _mm512_fmadd_ps(va, vb, dot);
        norm_a = _mm512_fmadd_ps(va, va, norm_a);
        norm_b = _mm512_fmadd_ps(vb, vb, norm_b);
    }
    let mut dot_sum = _mm512_reduce_add_ps(dot);
    let mut norm_a_sum = _mm512_reduce_add_ps(norm_a);
    let mut norm_b_sum = _mm512_reduce_add_ps(norm_b);
    // Scalar tail for the remaining elements.
    for i in (chunks * 16)..len {
        let a_val = *a.add(i);
        let b_val = *b.add(i);
        dot_sum += a_val * b_val;
        norm_a_sum += a_val * a_val;
        norm_b_sum += b_val * b_val;
    }
    let denominator = (norm_a_sum * norm_b_sum).sqrt();
    if denominator == 0.0 {
        // Degenerate (zero-norm) input: treat as maximally distant.
        return 1.0;
    }
    1.0 - (dot_sum / denominator)
}
/// Negated inner product over `len` floats using AVX-512F. The negation
/// turns a similarity into a distance (smaller = more similar).
///
/// # Safety
/// Caller must guarantee AVX-512F support and that `a` and `b` are valid
/// for reading `len` consecutive `f32`s.
#[cfg(all(target_arch = "x86_64", feature = "simd-avx512"))]
#[target_feature(enable = "avx512f")]
#[inline]
pub unsafe fn inner_product_ptr_avx512(a: *const f32, b: *const f32, len: usize) -> f32 {
    debug_assert!(!a.is_null() && !b.is_null() && len > 0);
    let mut sum = _mm512_setzero_ps();
    let chunks = len / 16;
    for i in 0..chunks {
        let offset = i * 16;
        let va = _mm512_loadu_ps(a.add(offset));
        let vb = _mm512_loadu_ps(b.add(offset));
        sum = _mm512_fmadd_ps(va, vb, sum);
    }
    let mut result = _mm512_reduce_add_ps(sum);
    // Scalar tail for the remaining elements.
    for i in (chunks * 16)..len {
        result += *a.add(i) * *b.add(i);
    }
    -result
}
/// Manhattan (L1) distance over `len` floats using AVX-512F.
///
/// # Safety
/// Caller must guarantee AVX-512F support and that `a` and `b` are valid
/// for reading `len` consecutive `f32`s.
#[cfg(all(target_arch = "x86_64", feature = "simd-avx512"))]
#[target_feature(enable = "avx512f")]
#[inline]
pub unsafe fn manhattan_distance_ptr_avx512(a: *const f32, b: *const f32, len: usize) -> f32 {
    debug_assert!(!a.is_null() && !b.is_null() && len > 0);
    let mut sum = _mm512_setzero_ps();
    let chunks = len / 16;
    for i in 0..chunks {
        let offset = i * 16;
        let va = _mm512_loadu_ps(a.add(offset));
        let vb = _mm512_loadu_ps(b.add(offset));
        let diff = _mm512_sub_ps(va, vb);
        let abs_diff = _mm512_abs_ps(diff);
        sum = _mm512_add_ps(sum, abs_diff);
    }
    let mut result = _mm512_reduce_add_ps(sum);
    // Scalar tail for the remaining elements.
    for i in (chunks * 16)..len {
        result += (*a.add(i) - *b.add(i)).abs();
    }
    result
}
/// Cosine distance assuming both inputs are already L2-normalized, so the
/// distance reduces to `1 - dot(a, b)` and the norm computations are
/// skipped. NOTE(review): nothing here checks the unit-norm precondition —
/// callers must guarantee it or the result is not a true cosine distance.
///
/// # Safety
/// Caller must guarantee AVX-512F support and that `a` and `b` are valid
/// for reading `len` consecutive `f32`s.
#[cfg(all(target_arch = "x86_64", feature = "simd-avx512"))]
#[target_feature(enable = "avx512f")]
#[inline]
pub unsafe fn cosine_distance_normalized_avx512(a: *const f32, b: *const f32, len: usize) -> f32 {
    debug_assert!(!a.is_null() && !b.is_null() && len > 0);
    let mut dot = _mm512_setzero_ps();
    let chunks = len / 16;
    for i in 0..chunks {
        let offset = i * 16;
        let va = _mm512_loadu_ps(a.add(offset));
        let vb = _mm512_loadu_ps(b.add(offset));
        dot = _mm512_fmadd_ps(va, vb, dot);
    }
    let mut result = _mm512_reduce_add_ps(dot);
    // Scalar tail for the remaining elements.
    for i in (chunks * 16)..len {
        result += *a.add(i) * *b.add(i);
    }
    1.0 - result
}
/// Slice-based wrapper over the AVX-512 L2 kernel.
///
/// # Safety
/// Caller must guarantee AVX-512F support. Slices must be equal length:
/// the pointer kernel reads `a.len()` elements from both.
#[cfg(all(target_arch = "x86_64", feature = "simd-avx512"))]
#[target_feature(enable = "avx512f")]
#[inline]
unsafe fn euclidean_distance_avx512(a: &[f32], b: &[f32]) -> f32 {
    // Guard against out-of-bounds reads of `b` when it is shorter than `a`.
    debug_assert_eq!(a.len(), b.len());
    l2_distance_ptr_avx512(a.as_ptr(), b.as_ptr(), a.len())
}
/// Slice-based wrapper over the AVX-512 cosine kernel.
///
/// # Safety
/// Caller must guarantee AVX-512F support. Slices must be equal length:
/// the pointer kernel reads `a.len()` elements from both.
#[cfg(all(target_arch = "x86_64", feature = "simd-avx512"))]
#[target_feature(enable = "avx512f")]
#[inline]
unsafe fn cosine_distance_avx512(a: &[f32], b: &[f32]) -> f32 {
    // Guard against out-of-bounds reads of `b` when it is shorter than `a`.
    debug_assert_eq!(a.len(), b.len());
    cosine_distance_ptr_avx512(a.as_ptr(), b.as_ptr(), a.len())
}
/// Slice-based wrapper over the AVX-512 negated-inner-product kernel.
///
/// # Safety
/// Caller must guarantee AVX-512F support. Slices must be equal length:
/// the pointer kernel reads `a.len()` elements from both.
#[cfg(all(target_arch = "x86_64", feature = "simd-avx512"))]
#[target_feature(enable = "avx512f")]
#[inline]
unsafe fn inner_product_avx512(a: &[f32], b: &[f32]) -> f32 {
    // Guard against out-of-bounds reads of `b` when it is shorter than `a`.
    debug_assert_eq!(a.len(), b.len());
    inner_product_ptr_avx512(a.as_ptr(), b.as_ptr(), a.len())
}
/// Slice-based wrapper over the AVX-512 Manhattan kernel.
///
/// # Safety
/// Caller must guarantee AVX-512F support. Slices must be equal length:
/// the pointer kernel reads `a.len()` elements from both.
#[cfg(all(target_arch = "x86_64", feature = "simd-avx512"))]
#[target_feature(enable = "avx512f")]
#[inline]
unsafe fn manhattan_distance_avx512(a: &[f32], b: &[f32]) -> f32 {
    // Guard against out-of-bounds reads of `b` when it is shorter than `a`.
    debug_assert_eq!(a.len(), b.len());
    manhattan_distance_ptr_avx512(a.as_ptr(), b.as_ptr(), a.len())
}
/// Safe L2-distance entry point (AVX-512 build): picks AVX-512, then
/// AVX2+FMA, then scalar, based on runtime CPU detection.
#[cfg(all(target_arch = "x86_64", feature = "simd-avx512"))]
pub fn euclidean_distance_avx512_wrapper(a: &[f32], b: &[f32]) -> f32 {
    if is_x86_feature_detected!("avx512f") {
        // SAFETY: AVX-512F support was just verified.
        unsafe { euclidean_distance_avx512(a, b) }
    } else if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
        // SAFETY: AVX2 and FMA support was just verified.
        unsafe { euclidean_distance_avx2(a, b) }
    } else {
        scalar::euclidean_distance(a, b)
    }
}
/// Safe L2-distance entry point when the `simd-avx512` feature is off:
/// uses AVX2+FMA when detected, otherwise the scalar fallback.
#[cfg(all(target_arch = "x86_64", not(feature = "simd-avx512")))]
pub fn euclidean_distance_avx512_wrapper(a: &[f32], b: &[f32]) -> f32 {
    if !(is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma")) {
        return scalar::euclidean_distance(a, b);
    }
    // SAFETY: AVX2 and FMA support was verified above.
    unsafe { euclidean_distance_avx2(a, b) }
}
/// Non-x86 fallback: always delegates to the scalar implementation.
#[cfg(not(target_arch = "x86_64"))]
pub fn euclidean_distance_avx512_wrapper(a: &[f32], b: &[f32]) -> f32 {
    scalar::euclidean_distance(a, b)
}
/// Safe cosine-distance entry point (AVX-512 build): picks AVX-512, then
/// AVX2+FMA, then scalar, based on runtime CPU detection.
#[cfg(all(target_arch = "x86_64", feature = "simd-avx512"))]
pub fn cosine_distance_avx512_wrapper(a: &[f32], b: &[f32]) -> f32 {
    if is_x86_feature_detected!("avx512f") {
        // SAFETY: AVX-512F support was just verified.
        unsafe { cosine_distance_avx512(a, b) }
    } else if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
        // SAFETY: AVX2 and FMA support was just verified.
        unsafe { cosine_distance_avx2(a, b) }
    } else {
        scalar::cosine_distance(a, b)
    }
}
/// Safe cosine-distance entry point when the `simd-avx512` feature is off:
/// uses AVX2+FMA when detected, otherwise the scalar fallback.
#[cfg(all(target_arch = "x86_64", not(feature = "simd-avx512")))]
pub fn cosine_distance_avx512_wrapper(a: &[f32], b: &[f32]) -> f32 {
    if !(is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma")) {
        return scalar::cosine_distance(a, b);
    }
    // SAFETY: AVX2 and FMA support was verified above.
    unsafe { cosine_distance_avx2(a, b) }
}
/// Non-x86 fallback: always delegates to the scalar implementation.
#[cfg(not(target_arch = "x86_64"))]
pub fn cosine_distance_avx512_wrapper(a: &[f32], b: &[f32]) -> f32 {
    scalar::cosine_distance(a, b)
}
/// Safe negated-inner-product entry point (AVX-512 build): picks AVX-512,
/// then AVX2+FMA, then scalar, based on runtime CPU detection.
#[cfg(all(target_arch = "x86_64", feature = "simd-avx512"))]
pub fn inner_product_avx512_wrapper(a: &[f32], b: &[f32]) -> f32 {
    if is_x86_feature_detected!("avx512f") {
        // SAFETY: AVX-512F support was just verified.
        unsafe { inner_product_avx512(a, b) }
    } else if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
        // SAFETY: AVX2 and FMA support was just verified.
        unsafe { inner_product_avx2(a, b) }
    } else {
        scalar::inner_product_distance(a, b)
    }
}
/// Safe negated-inner-product entry point when the `simd-avx512` feature
/// is off: uses AVX2+FMA when detected, otherwise the scalar fallback.
#[cfg(all(target_arch = "x86_64", not(feature = "simd-avx512")))]
pub fn inner_product_avx512_wrapper(a: &[f32], b: &[f32]) -> f32 {
    if !(is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma")) {
        return scalar::inner_product_distance(a, b);
    }
    // SAFETY: AVX2 and FMA support was verified above.
    unsafe { inner_product_avx2(a, b) }
}
/// Non-x86 fallback: always delegates to the scalar implementation.
#[cfg(not(target_arch = "x86_64"))]
pub fn inner_product_avx512_wrapper(a: &[f32], b: &[f32]) -> f32 {
    scalar::inner_product_distance(a, b)
}
/// Safe Manhattan-distance entry point (AVX-512 build): picks AVX-512,
/// then AVX2, then scalar. Unlike the other wrappers, FMA is not required
/// because the Manhattan kernel uses no fused multiply-adds.
#[cfg(all(target_arch = "x86_64", feature = "simd-avx512"))]
pub fn manhattan_distance_avx512_wrapper(a: &[f32], b: &[f32]) -> f32 {
    if is_x86_feature_detected!("avx512f") {
        // SAFETY: AVX-512F support was just verified.
        unsafe { manhattan_distance_avx512(a, b) }
    } else if is_x86_feature_detected!("avx2") {
        // SAFETY: AVX2 support was just verified.
        unsafe { manhattan_distance_avx2(a, b) }
    } else {
        scalar::manhattan_distance(a, b)
    }
}
/// Safe Manhattan-distance entry point when the `simd-avx512` feature is
/// off: uses AVX2 when detected (FMA not needed by this kernel),
/// otherwise the scalar fallback.
#[cfg(all(target_arch = "x86_64", not(feature = "simd-avx512")))]
pub fn manhattan_distance_avx512_wrapper(a: &[f32], b: &[f32]) -> f32 {
    if !is_x86_feature_detected!("avx2") {
        return scalar::manhattan_distance(a, b);
    }
    // SAFETY: AVX2 support was verified above.
    unsafe { manhattan_distance_avx2(a, b) }
}
/// Non-x86 fallback: always delegates to the scalar implementation.
#[cfg(not(target_arch = "x86_64"))]
pub fn manhattan_distance_avx512_wrapper(a: &[f32], b: &[f32]) -> f32 {
    scalar::manhattan_distance(a, b)
}
/// Euclidean (L2) distance over `len` floats using AVX2+FMA, 8 lanes per
/// iteration. Uses aligned loads when both pointers are 32-byte aligned.
///
/// # Safety
/// Caller must guarantee AVX2 and FMA support and that `a` and `b` are
/// valid for reading `len` consecutive `f32`s.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2", enable = "fma")]
#[inline]
pub unsafe fn l2_distance_ptr_avx2(a: *const f32, b: *const f32, len: usize) -> f32 {
    debug_assert!(!a.is_null() && !b.is_null() && len > 0);
    let mut sum = _mm256_setzero_ps();
    let chunks = len / 8;
    // Pick the aligned-load loop when both buffers are 32-byte aligned.
    let use_aligned = is_avx2_aligned(a, b);
    if use_aligned {
        for i in 0..chunks {
            let offset = i * 8;
            let va = _mm256_load_ps(a.add(offset));
            let vb = _mm256_load_ps(b.add(offset));
            let diff = _mm256_sub_ps(va, vb);
            sum = _mm256_fmadd_ps(diff, diff, sum);
        }
    } else {
        for i in 0..chunks {
            let offset = i * 8;
            let va = _mm256_loadu_ps(a.add(offset));
            let vb = _mm256_loadu_ps(b.add(offset));
            let diff = _mm256_sub_ps(va, vb);
            sum = _mm256_fmadd_ps(diff, diff, sum);
        }
    }
    // Horizontal reduction of the 8 lanes (same steps as horizontal_sum_256).
    let sum_high = _mm256_extractf128_ps(sum, 1);
    let sum_low = _mm256_castps256_ps128(sum);
    let sum128 = _mm_add_ps(sum_high, sum_low);
    let sum64 = _mm_add_ps(sum128, _mm_movehl_ps(sum128, sum128));
    let sum32 = _mm_add_ss(sum64, _mm_shuffle_ps(sum64, sum64, 1));
    let mut result = _mm_cvtss_f32(sum32);
    // Scalar tail for the remaining elements.
    for i in (chunks * 8)..len {
        let diff = *a.add(i) - *b.add(i);
        result += diff * diff;
    }
    result.sqrt()
}
/// Cosine distance over `len` floats using AVX2+FMA. Computes dot product
/// and both squared norms in one pass; returns 1.0 on zero-norm input.
///
/// # Safety
/// Caller must guarantee AVX2 and FMA support and that `a` and `b` are
/// valid for reading `len` consecutive `f32`s.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2", enable = "fma")]
#[inline]
pub unsafe fn cosine_distance_ptr_avx2(a: *const f32, b: *const f32, len: usize) -> f32 {
    debug_assert!(!a.is_null() && !b.is_null() && len > 0);
    let mut dot = _mm256_setzero_ps();
    let mut norm_a = _mm256_setzero_ps();
    let mut norm_b = _mm256_setzero_ps();
    let chunks = len / 8;
    // Pick the aligned-load loop when both buffers are 32-byte aligned.
    let use_aligned = is_avx2_aligned(a, b);
    if use_aligned {
        for i in 0..chunks {
            let offset = i * 8;
            let va = _mm256_load_ps(a.add(offset));
            let vb = _mm256_load_ps(b.add(offset));
            dot = _mm256_fmadd_ps(va, vb, dot);
            norm_a = _mm256_fmadd_ps(va, va, norm_a);
            norm_b = _mm256_fmadd_ps(vb, vb, norm_b);
        }
    } else {
        for i in 0..chunks {
            let offset = i * 8;
            let va = _mm256_loadu_ps(a.add(offset));
            let vb = _mm256_loadu_ps(b.add(offset));
            dot = _mm256_fmadd_ps(va, vb, dot);
            norm_a = _mm256_fmadd_ps(va, va, norm_a);
            norm_b = _mm256_fmadd_ps(vb, vb, norm_b);
        }
    }
    let dot_sum = horizontal_sum_256(dot);
    let norm_a_sum = horizontal_sum_256(norm_a);
    let norm_b_sum = horizontal_sum_256(norm_b);
    let mut dot_total = dot_sum;
    let mut norm_a_total = norm_a_sum;
    let mut norm_b_total = norm_b_sum;
    // Scalar tail for the remaining elements.
    for i in (chunks * 8)..len {
        let a_val = *a.add(i);
        let b_val = *b.add(i);
        dot_total += a_val * b_val;
        norm_a_total += a_val * a_val;
        norm_b_total += b_val * b_val;
    }
    let denominator = (norm_a_total * norm_b_total).sqrt();
    if denominator == 0.0 {
        // Degenerate (zero-norm) input: treat as maximally distant.
        return 1.0;
    }
    1.0 - (dot_total / denominator)
}
/// Negated inner product over `len` floats using AVX2+FMA (distance
/// convention: smaller = more similar).
///
/// # Safety
/// Caller must guarantee AVX2 and FMA support and that `a` and `b` are
/// valid for reading `len` consecutive `f32`s.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2", enable = "fma")]
#[inline]
pub unsafe fn inner_product_ptr_avx2(a: *const f32, b: *const f32, len: usize) -> f32 {
    debug_assert!(!a.is_null() && !b.is_null() && len > 0);
    let mut sum = _mm256_setzero_ps();
    let chunks = len / 8;
    // Pick the aligned-load loop when both buffers are 32-byte aligned.
    let use_aligned = is_avx2_aligned(a, b);
    if use_aligned {
        for i in 0..chunks {
            let offset = i * 8;
            let va = _mm256_load_ps(a.add(offset));
            let vb = _mm256_load_ps(b.add(offset));
            sum = _mm256_fmadd_ps(va, vb, sum);
        }
    } else {
        for i in 0..chunks {
            let offset = i * 8;
            let va = _mm256_loadu_ps(a.add(offset));
            let vb = _mm256_loadu_ps(b.add(offset));
            sum = _mm256_fmadd_ps(va, vb, sum);
        }
    }
    let mut result = horizontal_sum_256(sum);
    // Scalar tail for the remaining elements.
    for i in (chunks * 8)..len {
        result += *a.add(i) * *b.add(i);
    }
    -result
}
/// Manhattan (L1) distance over `len` floats using AVX2. |x| is computed
/// by clearing the sign bit with an andnot mask; FMA is not needed.
///
/// # Safety
/// Caller must guarantee AVX2 support and that `a` and `b` are valid for
/// reading `len` consecutive `f32`s.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
#[inline]
pub unsafe fn manhattan_distance_ptr_avx2(a: *const f32, b: *const f32, len: usize) -> f32 {
    debug_assert!(!a.is_null() && !b.is_null() && len > 0);
    // -0.0 has only the sign bit set; andnot(sign_mask, x) clears it, i.e. |x|.
    let sign_mask = _mm256_set1_ps(-0.0);
    let mut sum = _mm256_setzero_ps();
    let chunks = len / 8;
    // Pick the aligned-load loop when both buffers are 32-byte aligned.
    let use_aligned = is_avx2_aligned(a, b);
    if use_aligned {
        for i in 0..chunks {
            let offset = i * 8;
            let va = _mm256_load_ps(a.add(offset));
            let vb = _mm256_load_ps(b.add(offset));
            let diff = _mm256_sub_ps(va, vb);
            let abs_diff = _mm256_andnot_ps(sign_mask, diff);
            sum = _mm256_add_ps(sum, abs_diff);
        }
    } else {
        for i in 0..chunks {
            let offset = i * 8;
            let va = _mm256_loadu_ps(a.add(offset));
            let vb = _mm256_loadu_ps(b.add(offset));
            let diff = _mm256_sub_ps(va, vb);
            let abs_diff = _mm256_andnot_ps(sign_mask, diff);
            sum = _mm256_add_ps(sum, abs_diff);
        }
    }
    let mut result = horizontal_sum_256(sum);
    // Scalar tail for the remaining elements.
    for i in (chunks * 8)..len {
        result += (*a.add(i) - *b.add(i)).abs();
    }
    result
}
/// Scalar fallback for L2 (Euclidean) distance over raw pointers.
///
/// # Safety
/// `a` and `b` must be valid for reading `len` consecutive `f32`s.
#[inline]
pub unsafe fn l2_distance_ptr_scalar(a: *const f32, b: *const f32, len: usize) -> f32 {
    debug_assert!(!a.is_null() && !b.is_null() && len > 0);
    let xs = std::slice::from_raw_parts(a, len);
    let ys = std::slice::from_raw_parts(b, len);
    // Left fold in element order — same accumulation order as a plain loop.
    let sum: f32 = xs.iter().zip(ys).map(|(x, y)| (x - y) * (x - y)).sum();
    sum.sqrt()
}
/// Scalar fallback for cosine distance (1 - cosine similarity).
/// Returns 1.0 when either vector has zero norm.
///
/// # Safety
/// `a` and `b` must be valid for reading `len` consecutive `f32`s.
#[inline]
pub unsafe fn cosine_distance_ptr_scalar(a: *const f32, b: *const f32, len: usize) -> f32 {
    debug_assert!(!a.is_null() && !b.is_null() && len > 0);
    let xs = std::slice::from_raw_parts(a, len);
    let ys = std::slice::from_raw_parts(b, len);
    let mut dot = 0.0f32;
    let mut norm_x = 0.0f32;
    let mut norm_y = 0.0f32;
    // Single pass: dot product plus both squared norms.
    for (&x, &y) in xs.iter().zip(ys) {
        dot += x * y;
        norm_x += x * x;
        norm_y += y * y;
    }
    let denominator = (norm_x * norm_y).sqrt();
    if denominator == 0.0 {
        // Degenerate (zero-norm) input: treat as maximally distant.
        return 1.0;
    }
    1.0 - (dot / denominator)
}
/// Scalar fallback for the negated inner product (distance convention).
///
/// # Safety
/// `a` and `b` must be valid for reading `len` consecutive `f32`s.
#[inline]
pub unsafe fn inner_product_ptr_scalar(a: *const f32, b: *const f32, len: usize) -> f32 {
    debug_assert!(!a.is_null() && !b.is_null() && len > 0);
    let dot: f32 = std::slice::from_raw_parts(a, len)
        .iter()
        .zip(std::slice::from_raw_parts(b, len))
        .map(|(x, y)| x * y)
        .sum();
    // Negate so that larger similarity maps to smaller distance.
    -dot
}
/// Scalar fallback for Manhattan (L1) distance over raw pointers.
///
/// # Safety
/// `a` and `b` must be valid for reading `len` consecutive `f32`s.
#[inline]
pub unsafe fn manhattan_distance_ptr_scalar(a: *const f32, b: *const f32, len: usize) -> f32 {
    debug_assert!(!a.is_null() && !b.is_null() && len > 0);
    std::slice::from_raw_parts(a, len)
        .iter()
        .zip(std::slice::from_raw_parts(b, len))
        .map(|(x, y)| (x - y).abs())
        .sum()
}
/// L2 distance with runtime dispatch: AVX-512 (when compiled in and
/// detected), then AVX2+FMA, then scalar.
///
/// # Safety
/// `a` and `b` must be valid for reading `len` consecutive `f32`s.
#[inline]
pub unsafe fn l2_distance_ptr(a: *const f32, b: *const f32, len: usize) -> f32 {
    #[cfg(target_arch = "x86_64")]
    {
        // The AVX-512 path exists only when the `simd-avx512` feature is on.
        #[cfg(feature = "simd-avx512")]
        if is_x86_feature_detected!("avx512f") {
            return l2_distance_ptr_avx512(a, b, len);
        }
        if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
            return l2_distance_ptr_avx2(a, b, len);
        }
    }
    l2_distance_ptr_scalar(a, b, len)
}
/// Cosine distance with runtime dispatch: AVX-512 (when compiled in and
/// detected), then AVX2+FMA, then scalar.
///
/// # Safety
/// `a` and `b` must be valid for reading `len` consecutive `f32`s.
#[inline]
pub unsafe fn cosine_distance_ptr(a: *const f32, b: *const f32, len: usize) -> f32 {
    #[cfg(target_arch = "x86_64")]
    {
        // The AVX-512 path exists only when the `simd-avx512` feature is on.
        #[cfg(feature = "simd-avx512")]
        if is_x86_feature_detected!("avx512f") {
            return cosine_distance_ptr_avx512(a, b, len);
        }
        if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
            return cosine_distance_ptr_avx2(a, b, len);
        }
    }
    cosine_distance_ptr_scalar(a, b, len)
}
/// Negated inner product with runtime dispatch: AVX-512 (when compiled in
/// and detected), then AVX2+FMA, then scalar.
///
/// # Safety
/// `a` and `b` must be valid for reading `len` consecutive `f32`s.
#[inline]
pub unsafe fn inner_product_ptr(a: *const f32, b: *const f32, len: usize) -> f32 {
    #[cfg(target_arch = "x86_64")]
    {
        // The AVX-512 path exists only when the `simd-avx512` feature is on.
        #[cfg(feature = "simd-avx512")]
        if is_x86_feature_detected!("avx512f") {
            return inner_product_ptr_avx512(a, b, len);
        }
        if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
            return inner_product_ptr_avx2(a, b, len);
        }
    }
    inner_product_ptr_scalar(a, b, len)
}
/// Manhattan distance with runtime dispatch: AVX-512 / AVX2 on x86_64,
/// NEON on AArch64, scalar elsewhere. NOTE(review): this is the only
/// pointer dispatcher with a NEON path — the l2/cosine/inner dispatchers
/// fall back to scalar on AArch64.
///
/// # Safety
/// `a` and `b` must be valid for reading `len` consecutive `f32`s.
#[inline]
pub unsafe fn manhattan_distance_ptr(a: *const f32, b: *const f32, len: usize) -> f32 {
    #[cfg(target_arch = "x86_64")]
    {
        // The AVX-512 path exists only when the `simd-avx512` feature is on.
        #[cfg(feature = "simd-avx512")]
        if is_x86_feature_detected!("avx512f") {
            return manhattan_distance_ptr_avx512(a, b, len);
        }
        // FMA not required: the Manhattan kernel uses no fused multiply-adds.
        if is_x86_feature_detected!("avx2") {
            return manhattan_distance_ptr_avx2(a, b, len);
        }
        return manhattan_distance_ptr_scalar(a, b, len);
    }
    #[cfg(target_arch = "aarch64")]
    {
        return manhattan_distance_ptr_neon(a, b, len);
    }
    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
    manhattan_distance_ptr_scalar(a, b, len)
}
/// Computes the L2 distance from `query` to each pointer in `vectors`,
/// writing distance `i` into `results[i]`.
///
/// # Safety
/// `query` and every pointer in `vectors` must be valid for reading `len`
/// consecutive `f32`s; `results` must hold at least `vectors.len()` slots.
#[inline]
pub unsafe fn l2_distances_batch(
    query: *const f32,
    vectors: &[*const f32],
    len: usize,
    results: &mut [f32],
) {
    debug_assert!(results.len() >= vectors.len());
    debug_assert!(!query.is_null() && len > 0);
    for idx in 0..vectors.len() {
        results[idx] = l2_distance_ptr(query, vectors[idx], len);
    }
}
/// Computes the cosine distance from `query` to each pointer in `vectors`,
/// writing distance `i` into `results[i]`.
///
/// # Safety
/// `query` and every pointer in `vectors` must be valid for reading `len`
/// consecutive `f32`s; `results` must hold at least `vectors.len()` slots.
#[inline]
pub unsafe fn cosine_distances_batch(
    query: *const f32,
    vectors: &[*const f32],
    len: usize,
    results: &mut [f32],
) {
    debug_assert!(results.len() >= vectors.len());
    debug_assert!(!query.is_null() && len > 0);
    for idx in 0..vectors.len() {
        results[idx] = cosine_distance_ptr(query, vectors[idx], len);
    }
}
/// Computes the negated inner product from `query` to each pointer in
/// `vectors`, writing distance `i` into `results[i]`.
///
/// # Safety
/// `query` and every pointer in `vectors` must be valid for reading `len`
/// consecutive `f32`s; `results` must hold at least `vectors.len()` slots.
#[inline]
pub unsafe fn inner_product_batch(
    query: *const f32,
    vectors: &[*const f32],
    len: usize,
    results: &mut [f32],
) {
    debug_assert!(results.len() >= vectors.len());
    debug_assert!(!query.is_null() && len > 0);
    for idx in 0..vectors.len() {
        results[idx] = inner_product_ptr(query, vectors[idx], len);
    }
}
/// Computes the Manhattan distance from `query` to each pointer in
/// `vectors`, writing distance `i` into `results[i]`.
///
/// # Safety
/// `query` and every pointer in `vectors` must be valid for reading `len`
/// consecutive `f32`s; `results` must hold at least `vectors.len()` slots.
#[inline]
pub unsafe fn manhattan_distances_batch(
    query: *const f32,
    vectors: &[*const f32],
    len: usize,
    results: &mut [f32],
) {
    debug_assert!(results.len() >= vectors.len());
    debug_assert!(!query.is_null() && len > 0);
    for idx in 0..vectors.len() {
        results[idx] = manhattan_distance_ptr(query, vectors[idx], len);
    }
}
/// Batch L2 distances from `query` to each pointer in `vectors`.
/// NOTE(review): despite the `_parallel` suffix this currently runs
/// sequentially, identical to `l2_distances_batch` — presumably a
/// placeholder for a rayon-style implementation; confirm intent.
///
/// # Safety
/// `query` and every pointer in `vectors` must be valid for reading `len`
/// consecutive `f32`s; `results` must hold at least `vectors.len()` slots.
#[inline]
pub unsafe fn l2_distances_batch_parallel(
    query: *const f32,
    vectors: &[*const f32],
    len: usize,
    results: &mut [f32],
) {
    debug_assert!(results.len() >= vectors.len());
    debug_assert!(!query.is_null() && len > 0);
    for (i, &vec_ptr) in vectors.iter().enumerate() {
        results[i] = l2_distance_ptr(query, vec_ptr, len);
    }
}
/// Batch cosine distances from `query` to each pointer in `vectors`.
/// NOTE(review): despite the `_parallel` suffix this currently runs
/// sequentially, identical to `cosine_distances_batch` — presumably a
/// placeholder for a rayon-style implementation; confirm intent.
///
/// # Safety
/// `query` and every pointer in `vectors` must be valid for reading `len`
/// consecutive `f32`s; `results` must hold at least `vectors.len()` slots.
#[inline]
pub unsafe fn cosine_distances_batch_parallel(
    query: *const f32,
    vectors: &[*const f32],
    len: usize,
    results: &mut [f32],
) {
    debug_assert!(results.len() >= vectors.len());
    debug_assert!(!query.is_null() && len > 0);
    for (i, &vec_ptr) in vectors.iter().enumerate() {
        results[i] = cosine_distance_ptr(query, vec_ptr, len);
    }
}
/// Slice-based L2 distance using AVX2+FMA. Iterates `a.len()` elements —
/// `b` must be at least as long as `a`.
///
/// # Safety
/// Caller must guarantee AVX2 and FMA support.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2", enable = "fma")]
#[inline]
unsafe fn euclidean_distance_avx2(a: &[f32], b: &[f32]) -> f32 {
    let n = a.len();
    let mut sum = _mm256_setzero_ps();
    let chunks = n / 8;
    for i in 0..chunks {
        let offset = i * 8;
        let va = _mm256_loadu_ps(a.as_ptr().add(offset));
        let vb = _mm256_loadu_ps(b.as_ptr().add(offset));
        let diff = _mm256_sub_ps(va, vb);
        sum = _mm256_fmadd_ps(diff, diff, sum);
    }
    // Horizontal reduction of the 8 lanes (same steps as horizontal_sum_256).
    let sum_high = _mm256_extractf128_ps(sum, 1);
    let sum_low = _mm256_castps256_ps128(sum);
    let sum128 = _mm_add_ps(sum_high, sum_low);
    let sum64 = _mm_add_ps(sum128, _mm_movehl_ps(sum128, sum128));
    let sum32 = _mm_add_ss(sum64, _mm_shuffle_ps(sum64, sum64, 1));
    let mut result = _mm_cvtss_f32(sum32);
    // Scalar tail for the remaining elements.
    for i in (chunks * 8)..n {
        let diff = a[i] - b[i];
        result += diff * diff;
    }
    result.sqrt()
}
/// Slice-based cosine distance using AVX2+FMA. Iterates `a.len()`
/// elements — `b` must be at least as long as `a`. Returns 1.0 when
/// either vector has zero norm.
///
/// # Safety
/// Caller must guarantee AVX2 and FMA support.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2", enable = "fma")]
#[inline]
unsafe fn cosine_distance_avx2(a: &[f32], b: &[f32]) -> f32 {
    let n = a.len();
    let mut dot = _mm256_setzero_ps();
    let mut norm_a = _mm256_setzero_ps();
    let mut norm_b = _mm256_setzero_ps();
    let chunks = n / 8;
    for i in 0..chunks {
        let offset = i * 8;
        let va = _mm256_loadu_ps(a.as_ptr().add(offset));
        let vb = _mm256_loadu_ps(b.as_ptr().add(offset));
        dot = _mm256_fmadd_ps(va, vb, dot);
        norm_a = _mm256_fmadd_ps(va, va, norm_a);
        norm_b = _mm256_fmadd_ps(vb, vb, norm_b);
    }
    let dot_sum = horizontal_sum_256(dot);
    let norm_a_sum = horizontal_sum_256(norm_a);
    let norm_b_sum = horizontal_sum_256(norm_b);
    let mut dot_total = dot_sum;
    let mut norm_a_total = norm_a_sum;
    let mut norm_b_total = norm_b_sum;
    // Scalar tail for the remaining elements.
    for i in (chunks * 8)..n {
        dot_total += a[i] * b[i];
        norm_a_total += a[i] * a[i];
        norm_b_total += b[i] * b[i];
    }
    let denominator = (norm_a_total * norm_b_total).sqrt();
    if denominator == 0.0 {
        // Degenerate (zero-norm) input: treat as maximally distant.
        return 1.0;
    }
    1.0 - (dot_total / denominator)
}
/// Slice-based negated inner product using AVX2+FMA. Iterates `a.len()`
/// elements — `b` must be at least as long as `a`.
///
/// # Safety
/// Caller must guarantee AVX2 and FMA support.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2", enable = "fma")]
#[inline]
unsafe fn inner_product_avx2(a: &[f32], b: &[f32]) -> f32 {
    let n = a.len();
    let mut sum = _mm256_setzero_ps();
    let chunks = n / 8;
    for i in 0..chunks {
        let offset = i * 8;
        let va = _mm256_loadu_ps(a.as_ptr().add(offset));
        let vb = _mm256_loadu_ps(b.as_ptr().add(offset));
        sum = _mm256_fmadd_ps(va, vb, sum);
    }
    let mut result = horizontal_sum_256(sum);
    // Scalar tail for the remaining elements.
    for i in (chunks * 8)..n {
        result += a[i] * b[i];
    }
    -result
}
/// Slice-based Manhattan (L1) distance using AVX2. Iterates `a.len()`
/// elements — `b` must be at least as long as `a`. FMA is not required.
///
/// # Safety
/// Caller must guarantee AVX2 support.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
#[inline]
unsafe fn manhattan_distance_avx2(a: &[f32], b: &[f32]) -> f32 {
    let n = a.len();
    // -0.0 has only the sign bit set; andnot(sign_mask, x) clears it, i.e. |x|.
    let sign_mask = _mm256_set1_ps(-0.0);
    let mut sum = _mm256_setzero_ps();
    let chunks = n / 8;
    for i in 0..chunks {
        let offset = i * 8;
        let va = _mm256_loadu_ps(a.as_ptr().add(offset));
        let vb = _mm256_loadu_ps(b.as_ptr().add(offset));
        let diff = _mm256_sub_ps(va, vb);
        let abs_diff = _mm256_andnot_ps(sign_mask, diff);
        sum = _mm256_add_ps(sum, abs_diff);
    }
    let mut result = horizontal_sum_256(sum);
    // Scalar tail for the remaining elements.
    for i in (chunks * 8)..n {
        result += (a[i] - b[i]).abs();
    }
    result
}
/// Horizontally sums the 8 `f32` lanes of a 256-bit vector.
///
/// # Safety
/// Caller must guarantee AVX2 support (enforced by callers' own
/// `#[target_feature]` attributes and runtime detection).
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
#[inline]
unsafe fn horizontal_sum_256(v: __m256) -> f32 {
    // Fold upper 128 bits onto lower 128: 8 lanes -> 4.
    let sum_high = _mm256_extractf128_ps(v, 1);
    let sum_low = _mm256_castps256_ps128(v);
    let sum128 = _mm_add_ps(sum_high, sum_low);
    // 4 lanes -> 2 (add high pair onto low pair).
    let sum64 = _mm_add_ps(sum128, _mm_movehl_ps(sum128, sum128));
    // 2 lanes -> 1 (add lane 1 onto lane 0).
    let sum32 = _mm_add_ss(sum64, _mm_shuffle_ps(sum64, sum64, 1));
    _mm_cvtss_f32(sum32)
}
/// L2 distance via the `simsimd` crate, falling back to the scalar
/// implementation if simsimd declines the input.
#[inline]
pub fn l2_distance_simsimd(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());
    // simsimd returns the *squared* Euclidean distance (as f64).
    if let Some(dist_sq) = f32::sqeuclidean(a, b) {
        (dist_sq as f32).sqrt()
    } else {
        scalar::euclidean_distance(a, b)
    }
}
/// Cosine distance via the `simsimd` crate, falling back to the scalar
/// implementation if simsimd declines the input.
#[inline]
pub fn cosine_distance_simsimd(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());
    // simsimd's `cosine` already returns a distance (f64), not a similarity.
    if let Some(dist) = f32::cosine(a, b) {
        dist as f32
    } else {
        scalar::cosine_distance(a, b)
    }
}
/// Dot product via the `simsimd` crate, falling back to the scalar
/// implementation if simsimd declines the input.
#[inline]
pub fn dot_product_simsimd(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());
    if let Some(dot) = f32::dot(a, b) {
        dot as f32
    } else {
        scalar::dot_product(a, b)
    }
}
/// Inner-product distance via simsimd: the negated dot product, so larger
/// similarity maps to smaller distance.
#[inline]
pub fn inner_product_distance_simsimd(a: &[f32], b: &[f32]) -> f32 {
    -dot_product_simsimd(a, b)
}
/// L2 distance choosing the best backend by dimension: simsimd for
/// 384/768/1536/3072 (presumably common embedding sizes — confirm the
/// benchmark rationale), otherwise the unrolled AVX2 kernel for vectors
/// of at least 32 elements, else simsimd again.
#[inline]
pub fn l2_distance_optimized(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());
    let len = a.len();
    match len {
        384 | 768 | 1536 | 3072 => l2_distance_simsimd(a, b),
        _ => {
            #[cfg(target_arch = "x86_64")]
            {
                if is_avx2_available() && len >= 32 {
                    // SAFETY: AVX2+FMA availability was checked above.
                    unsafe { l2_distance_avx2_unrolled(a, b) }
                } else {
                    l2_distance_simsimd(a, b)
                }
            }
            #[cfg(not(target_arch = "x86_64"))]
            {
                l2_distance_simsimd(a, b)
            }
        }
    }
}
/// Cosine distance choosing the best backend by dimension: simsimd for
/// 384/768/1536/3072 (presumably common embedding sizes — confirm the
/// benchmark rationale), otherwise the unrolled AVX2 kernel for vectors
/// of at least 32 elements, else simsimd again.
#[inline]
pub fn cosine_distance_optimized(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());
    let len = a.len();
    match len {
        384 | 768 | 1536 | 3072 => cosine_distance_simsimd(a, b),
        _ => {
            #[cfg(target_arch = "x86_64")]
            {
                if is_avx2_available() && len >= 32 {
                    // SAFETY: AVX2+FMA availability was checked above.
                    unsafe { cosine_distance_avx2_unrolled(a, b) }
                } else {
                    cosine_distance_simsimd(a, b)
                }
            }
            #[cfg(not(target_arch = "x86_64"))]
            {
                cosine_distance_simsimd(a, b)
            }
        }
    }
}
/// Inner-product distance choosing the best backend by dimension: simsimd
/// for 384/768/1536/3072 (presumably common embedding sizes — confirm the
/// benchmark rationale), otherwise the unrolled AVX2 kernel for vectors
/// of at least 32 elements, else simsimd again.
#[inline]
pub fn inner_product_distance_optimized(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());
    let len = a.len();
    match len {
        384 | 768 | 1536 | 3072 => inner_product_distance_simsimd(a, b),
        _ => {
            #[cfg(target_arch = "x86_64")]
            {
                if is_avx2_available() && len >= 32 {
                    // SAFETY: AVX2+FMA availability was checked above.
                    unsafe { inner_product_avx2_unrolled(a, b) }
                } else {
                    inner_product_distance_simsimd(a, b)
                }
            }
            #[cfg(not(target_arch = "x86_64"))]
            {
                inner_product_distance_simsimd(a, b)
            }
        }
    }
}
/// L2 distance using AVX2+FMA with 4-way unrolling (32 floats/iteration,
/// four independent accumulators to hide FMA latency). Tail handling in
/// three stages: leftover 8-float chunks, then scalar remainder.
///
/// # Safety
/// Caller must guarantee AVX2 and FMA support; `b` must be at least as
/// long as `a`.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2", enable = "fma")]
#[inline]
unsafe fn l2_distance_avx2_unrolled(a: &[f32], b: &[f32]) -> f32 {
    let n = a.len();
    let a_ptr = a.as_ptr();
    let b_ptr = b.as_ptr();
    let mut sum0 = _mm256_setzero_ps();
    let mut sum1 = _mm256_setzero_ps();
    let mut sum2 = _mm256_setzero_ps();
    let mut sum3 = _mm256_setzero_ps();
    let chunks_4x = n / 32;
    for i in 0..chunks_4x {
        let offset = i * 32;
        let va0 = _mm256_loadu_ps(a_ptr.add(offset));
        let vb0 = _mm256_loadu_ps(b_ptr.add(offset));
        let va1 = _mm256_loadu_ps(a_ptr.add(offset + 8));
        let vb1 = _mm256_loadu_ps(b_ptr.add(offset + 8));
        let va2 = _mm256_loadu_ps(a_ptr.add(offset + 16));
        let vb2 = _mm256_loadu_ps(b_ptr.add(offset + 16));
        let va3 = _mm256_loadu_ps(a_ptr.add(offset + 24));
        let vb3 = _mm256_loadu_ps(b_ptr.add(offset + 24));
        let diff0 = _mm256_sub_ps(va0, vb0);
        let diff1 = _mm256_sub_ps(va1, vb1);
        let diff2 = _mm256_sub_ps(va2, vb2);
        let diff3 = _mm256_sub_ps(va3, vb3);
        sum0 = _mm256_fmadd_ps(diff0, diff0, sum0);
        sum1 = _mm256_fmadd_ps(diff1, diff1, sum1);
        sum2 = _mm256_fmadd_ps(diff2, diff2, sum2);
        sum3 = _mm256_fmadd_ps(diff3, diff3, sum3);
    }
    // Stage 2: process leftover 8-float chunks on accumulator 0.
    let remaining_start = chunks_4x * 32;
    let remaining_chunks = (n - remaining_start) / 8;
    for i in 0..remaining_chunks {
        let offset = remaining_start + i * 8;
        let va = _mm256_loadu_ps(a_ptr.add(offset));
        let vb = _mm256_loadu_ps(b_ptr.add(offset));
        let diff = _mm256_sub_ps(va, vb);
        sum0 = _mm256_fmadd_ps(diff, diff, sum0);
    }
    // Combine the four accumulators, then reduce horizontally.
    let sum01 = _mm256_add_ps(sum0, sum1);
    let sum23 = _mm256_add_ps(sum2, sum3);
    let sum_all = _mm256_add_ps(sum01, sum23);
    let mut result = horizontal_sum_256(sum_all);
    // Stage 3: scalar tail for the last n % 8 elements.
    let final_start = remaining_start + remaining_chunks * 8;
    for i in final_start..n {
        let diff = *a_ptr.add(i) - *b_ptr.add(i);
        result += diff * diff;
    }
    result.sqrt()
}
/// Cosine distance using AVX2+FMA with 4-way unrolling: twelve independent
/// accumulators (4 each for dot product and the two squared norms) to hide
/// FMA latency. Returns 1.0 when either vector has zero norm.
///
/// # Safety
/// Caller must guarantee AVX2 and FMA support; `b` must be at least as
/// long as `a`.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2", enable = "fma")]
#[inline]
unsafe fn cosine_distance_avx2_unrolled(a: &[f32], b: &[f32]) -> f32 {
    let n = a.len();
    let a_ptr = a.as_ptr();
    let b_ptr = b.as_ptr();
    let mut dot0 = _mm256_setzero_ps();
    let mut dot1 = _mm256_setzero_ps();
    let mut dot2 = _mm256_setzero_ps();
    let mut dot3 = _mm256_setzero_ps();
    let mut norm_a0 = _mm256_setzero_ps();
    let mut norm_a1 = _mm256_setzero_ps();
    let mut norm_a2 = _mm256_setzero_ps();
    let mut norm_a3 = _mm256_setzero_ps();
    let mut norm_b0 = _mm256_setzero_ps();
    let mut norm_b1 = _mm256_setzero_ps();
    let mut norm_b2 = _mm256_setzero_ps();
    let mut norm_b3 = _mm256_setzero_ps();
    let chunks_4x = n / 32;
    for i in 0..chunks_4x {
        let offset = i * 32;
        let va0 = _mm256_loadu_ps(a_ptr.add(offset));
        let vb0 = _mm256_loadu_ps(b_ptr.add(offset));
        let va1 = _mm256_loadu_ps(a_ptr.add(offset + 8));
        let vb1 = _mm256_loadu_ps(b_ptr.add(offset + 8));
        let va2 = _mm256_loadu_ps(a_ptr.add(offset + 16));
        let vb2 = _mm256_loadu_ps(b_ptr.add(offset + 16));
        let va3 = _mm256_loadu_ps(a_ptr.add(offset + 24));
        let vb3 = _mm256_loadu_ps(b_ptr.add(offset + 24));
        dot0 = _mm256_fmadd_ps(va0, vb0, dot0);
        dot1 = _mm256_fmadd_ps(va1, vb1, dot1);
        dot2 = _mm256_fmadd_ps(va2, vb2, dot2);
        dot3 = _mm256_fmadd_ps(va3, vb3, dot3);
        norm_a0 = _mm256_fmadd_ps(va0, va0, norm_a0);
        norm_a1 = _mm256_fmadd_ps(va1, va1, norm_a1);
        norm_a2 = _mm256_fmadd_ps(va2, va2, norm_a2);
        norm_a3 = _mm256_fmadd_ps(va3, va3, norm_a3);
        norm_b0 = _mm256_fmadd_ps(vb0, vb0, norm_b0);
        norm_b1 = _mm256_fmadd_ps(vb1, vb1, norm_b1);
        norm_b2 = _mm256_fmadd_ps(vb2, vb2, norm_b2);
        norm_b3 = _mm256_fmadd_ps(vb3, vb3, norm_b3);
    }
    // Stage 2: process leftover 8-float chunks on accumulator 0.
    let remaining_start = chunks_4x * 32;
    let remaining_chunks = (n - remaining_start) / 8;
    for i in 0..remaining_chunks {
        let offset = remaining_start + i * 8;
        let va = _mm256_loadu_ps(a_ptr.add(offset));
        let vb = _mm256_loadu_ps(b_ptr.add(offset));
        dot0 = _mm256_fmadd_ps(va, vb, dot0);
        norm_a0 = _mm256_fmadd_ps(va, va, norm_a0);
        norm_b0 = _mm256_fmadd_ps(vb, vb, norm_b0);
    }
    // Combine the accumulators, then reduce each horizontally.
    let dot01 = _mm256_add_ps(dot0, dot1);
    let dot23 = _mm256_add_ps(dot2, dot3);
    let dot_all = _mm256_add_ps(dot01, dot23);
    let norm_a01 = _mm256_add_ps(norm_a0, norm_a1);
    let norm_a23 = _mm256_add_ps(norm_a2, norm_a3);
    let norm_a_all = _mm256_add_ps(norm_a01, norm_a23);
    let norm_b01 = _mm256_add_ps(norm_b0, norm_b1);
    let norm_b23 = _mm256_add_ps(norm_b2, norm_b3);
    let norm_b_all = _mm256_add_ps(norm_b01, norm_b23);
    let mut dot_sum = horizontal_sum_256(dot_all);
    let mut norm_a_sum = horizontal_sum_256(norm_a_all);
    let mut norm_b_sum = horizontal_sum_256(norm_b_all);
    // Stage 3: scalar tail for the last n % 8 elements.
    let final_start = remaining_start + remaining_chunks * 8;
    for i in final_start..n {
        let a_val = *a_ptr.add(i);
        let b_val = *b_ptr.add(i);
        dot_sum += a_val * b_val;
        norm_a_sum += a_val * a_val;
        norm_b_sum += b_val * b_val;
    }
    let denominator = (norm_a_sum * norm_b_sum).sqrt();
    if denominator == 0.0 {
        // Degenerate (zero-norm) input: treat as maximally distant.
        return 1.0;
    }
    1.0 - (dot_sum / denominator)
}
/// Negated inner product using AVX2+FMA with 4-way unrolling (32 floats
/// per iteration, four independent accumulators to hide FMA latency).
///
/// # Safety
/// Caller must guarantee AVX2 and FMA support; `b` must be at least as
/// long as `a`.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2", enable = "fma")]
#[inline]
unsafe fn inner_product_avx2_unrolled(a: &[f32], b: &[f32]) -> f32 {
    let n = a.len();
    let a_ptr = a.as_ptr();
    let b_ptr = b.as_ptr();
    let mut sum0 = _mm256_setzero_ps();
    let mut sum1 = _mm256_setzero_ps();
    let mut sum2 = _mm256_setzero_ps();
    let mut sum3 = _mm256_setzero_ps();
    let chunks_4x = n / 32;
    for i in 0..chunks_4x {
        let offset = i * 32;
        let va0 = _mm256_loadu_ps(a_ptr.add(offset));
        let vb0 = _mm256_loadu_ps(b_ptr.add(offset));
        let va1 = _mm256_loadu_ps(a_ptr.add(offset + 8));
        let vb1 = _mm256_loadu_ps(b_ptr.add(offset + 8));
        let va2 = _mm256_loadu_ps(a_ptr.add(offset + 16));
        let vb2 = _mm256_loadu_ps(b_ptr.add(offset + 16));
        let va3 = _mm256_loadu_ps(a_ptr.add(offset + 24));
        let vb3 = _mm256_loadu_ps(b_ptr.add(offset + 24));
        sum0 = _mm256_fmadd_ps(va0, vb0, sum0);
        sum1 = _mm256_fmadd_ps(va1, vb1, sum1);
        sum2 = _mm256_fmadd_ps(va2, vb2, sum2);
        sum3 = _mm256_fmadd_ps(va3, vb3, sum3);
    }
    // Stage 2: process leftover 8-float chunks on accumulator 0.
    let remaining_start = chunks_4x * 32;
    let remaining_chunks = (n - remaining_start) / 8;
    for i in 0..remaining_chunks {
        let offset = remaining_start + i * 8;
        let va = _mm256_loadu_ps(a_ptr.add(offset));
        let vb = _mm256_loadu_ps(b_ptr.add(offset));
        sum0 = _mm256_fmadd_ps(va, vb, sum0);
    }
    // Combine the four accumulators, then reduce horizontally.
    let sum01 = _mm256_add_ps(sum0, sum1);
    let sum23 = _mm256_add_ps(sum2, sum3);
    let sum_all = _mm256_add_ps(sum01, sum23);
    let mut result = horizontal_sum_256(sum_all);
    // Stage 3: scalar tail for the last n % 8 elements.
    let final_start = remaining_start + remaining_chunks * 8;
    for i in final_start..n {
        result += *a_ptr.add(i) * *b_ptr.add(i);
    }
    -result
}
/// Slice-based L2 distance using NEON (4 lanes per iteration). Iterates
/// `a.len()` elements — `b` must be at least as long as `a`.
///
/// # Safety
/// Only compiled on AArch64, where NEON is part of the baseline ISA.
#[cfg(target_arch = "aarch64")]
#[inline]
unsafe fn euclidean_distance_neon(a: &[f32], b: &[f32]) -> f32 {
    use std::arch::aarch64::*;
    let n = a.len();
    let mut sum = vdupq_n_f32(0.0);
    let chunks = n / 4;
    for i in 0..chunks {
        let offset = i * 4;
        let va = vld1q_f32(a.as_ptr().add(offset));
        let vb = vld1q_f32(b.as_ptr().add(offset));
        let diff = vsubq_f32(va, vb);
        // sum += diff * diff via fused multiply-add.
        sum = vfmaq_f32(sum, diff, diff);
    }
    // vaddvq_f32 sums all 4 lanes.
    let mut result = vaddvq_f32(sum);
    // Scalar tail for the remaining elements.
    for i in (chunks * 4)..n {
        let diff = a[i] - b[i];
        result += diff * diff;
    }
    result.sqrt()
}
/// Slice-based cosine distance using NEON. Iterates `a.len()` elements —
/// `b` must be at least as long as `a`. Returns 1.0 on zero-norm input.
///
/// # Safety
/// Only compiled on AArch64, where NEON is part of the baseline ISA.
#[cfg(target_arch = "aarch64")]
#[inline]
unsafe fn cosine_distance_neon(a: &[f32], b: &[f32]) -> f32 {
    use std::arch::aarch64::*;
    let n = a.len();
    // Accumulate dot product and both squared norms in a single pass.
    let mut dot = vdupq_n_f32(0.0);
    let mut norm_a = vdupq_n_f32(0.0);
    let mut norm_b = vdupq_n_f32(0.0);
    let chunks = n / 4;
    for i in 0..chunks {
        let offset = i * 4;
        let va = vld1q_f32(a.as_ptr().add(offset));
        let vb = vld1q_f32(b.as_ptr().add(offset));
        dot = vfmaq_f32(dot, va, vb);
        norm_a = vfmaq_f32(norm_a, va, va);
        norm_b = vfmaq_f32(norm_b, vb, vb);
    }
    let mut dot_sum = vaddvq_f32(dot);
    let mut norm_a_sum = vaddvq_f32(norm_a);
    let mut norm_b_sum = vaddvq_f32(norm_b);
    // Scalar tail for the remaining elements.
    for i in (chunks * 4)..n {
        dot_sum += a[i] * b[i];
        norm_a_sum += a[i] * a[i];
        norm_b_sum += b[i] * b[i];
    }
    let denominator = (norm_a_sum * norm_b_sum).sqrt();
    if denominator == 0.0 {
        // Degenerate (zero-norm) input: treat as maximally distant.
        return 1.0;
    }
    1.0 - (dot_sum / denominator)
}
/// Slice-based negated inner product using NEON. Iterates `a.len()`
/// elements — `b` must be at least as long as `a`.
///
/// # Safety
/// Only compiled on AArch64, where NEON is part of the baseline ISA.
#[cfg(target_arch = "aarch64")]
#[inline]
unsafe fn inner_product_neon(a: &[f32], b: &[f32]) -> f32 {
    use std::arch::aarch64::*;
    let n = a.len();
    let mut sum = vdupq_n_f32(0.0);
    let chunks = n / 4;
    for i in 0..chunks {
        let offset = i * 4;
        let va = vld1q_f32(a.as_ptr().add(offset));
        let vb = vld1q_f32(b.as_ptr().add(offset));
        sum = vfmaq_f32(sum, va, vb);
    }
    let mut result = vaddvq_f32(sum);
    // Scalar tail for the remaining elements.
    for i in (chunks * 4)..n {
        result += a[i] * b[i];
    }
    -result
}
/// Manhattan (L1) distance between `a` and `b` using NEON SIMD.
///
/// Accumulates `|a_i - b_i|` 4 lanes at a time, then finishes the
/// `len % 4` tail scalarly. No FMA is needed for this metric.
///
/// # Safety
/// Must only be invoked on aarch64, where NEON is baseline.
#[cfg(target_arch = "aarch64")]
#[inline]
unsafe fn manhattan_distance_neon(a: &[f32], b: &[f32]) -> f32 {
    use std::arch::aarch64::*;
    // The SIMD loop reads `b` through raw pointers sized by `a.len()`;
    // a length mismatch would read out of bounds in release builds.
    debug_assert_eq!(a.len(), b.len());
    let n = a.len();
    let mut sum = vdupq_n_f32(0.0);
    let chunks = n / 4;
    for i in 0..chunks {
        let offset = i * 4;
        let va = vld1q_f32(a.as_ptr().add(offset));
        let vb = vld1q_f32(b.as_ptr().add(offset));
        // |a - b| per lane, accumulated across the vector register.
        let diff = vsubq_f32(va, vb);
        let abs_diff = vabsq_f32(diff);
        sum = vaddq_f32(sum, abs_diff);
    }
    let mut result = vaddvq_f32(sum);
    for i in (chunks * 4)..n {
        result += (a[i] - b[i]).abs();
    }
    result
}
/// Manhattan (L1) distance over raw pointers using NEON 4-lane f32 SIMD.
///
/// # Safety
/// - `a` and `b` must each point to at least `len` readable `f32` values.
/// - Must only be invoked on aarch64, where NEON is baseline.
#[cfg(target_arch = "aarch64")]
#[inline]
pub unsafe fn manhattan_distance_ptr_neon(a: *const f32, b: *const f32, len: usize) -> f32 {
use std::arch::aarch64::*;
debug_assert!(!a.is_null() && !b.is_null() && len > 0);
let mut sum = vdupq_n_f32(0.0);
let chunks = len / 4;
for i in 0..chunks {
let offset = i * 4;
let va = vld1q_f32(a.add(offset));
let vb = vld1q_f32(b.add(offset));
// |a - b| per lane, accumulated across the vector register.
let diff = vsubq_f32(va, vb);
let abs_diff = vabsq_f32(diff);
sum = vaddq_f32(sum, abs_diff);
}
// Reduce the 4 lanes horizontally, then handle the len % 4 tail scalarly.
let mut result = vaddvq_f32(sum);
for i in (chunks * 4)..len {
result += (*a.add(i) - *b.add(i)).abs();
}
result
}
/// Euclidean distance with runtime dispatch: uses the AVX2 kernel when the
/// CPU reports AVX2 + FMA, otherwise falls back to the scalar reference.
#[cfg(target_arch = "x86_64")]
pub fn euclidean_distance_avx2_wrapper(a: &[f32], b: &[f32]) -> f32 {
    let has_simd = is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma");
    if !has_simd {
        return scalar::euclidean_distance(a, b);
    }
    // SAFETY: the required CPU features were verified just above.
    unsafe { euclidean_distance_avx2(a, b) }
}
/// Non-x86_64 build: always use the scalar reference implementation.
#[cfg(not(target_arch = "x86_64"))]
pub fn euclidean_distance_avx2_wrapper(a: &[f32], b: &[f32]) -> f32 {
    scalar::euclidean_distance(a, b)
}
/// Cosine distance with runtime dispatch: uses the AVX2 kernel when the
/// CPU reports AVX2 + FMA, otherwise falls back to the scalar reference.
#[cfg(target_arch = "x86_64")]
pub fn cosine_distance_avx2_wrapper(a: &[f32], b: &[f32]) -> f32 {
    let has_simd = is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma");
    if !has_simd {
        return scalar::cosine_distance(a, b);
    }
    // SAFETY: the required CPU features were verified just above.
    unsafe { cosine_distance_avx2(a, b) }
}
/// Non-x86_64 build: always use the scalar reference implementation.
#[cfg(not(target_arch = "x86_64"))]
pub fn cosine_distance_avx2_wrapper(a: &[f32], b: &[f32]) -> f32 {
    scalar::cosine_distance(a, b)
}
/// Inner-product "distance" with runtime dispatch: uses the AVX2 kernel
/// when the CPU reports AVX2 + FMA, otherwise the scalar reference.
#[cfg(target_arch = "x86_64")]
pub fn inner_product_avx2_wrapper(a: &[f32], b: &[f32]) -> f32 {
    let has_simd = is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma");
    if !has_simd {
        return scalar::inner_product_distance(a, b);
    }
    // SAFETY: the required CPU features were verified just above.
    unsafe { inner_product_avx2(a, b) }
}
/// Non-x86_64 build: always use the scalar reference implementation.
#[cfg(not(target_arch = "x86_64"))]
pub fn inner_product_avx2_wrapper(a: &[f32], b: &[f32]) -> f32 {
    scalar::inner_product_distance(a, b)
}
/// Manhattan distance with runtime dispatch. Only AVX2 is required here
/// (no FMA check): the L1 kernel uses subtract/abs/add, not fused ops.
#[cfg(target_arch = "x86_64")]
pub fn manhattan_distance_avx2_wrapper(a: &[f32], b: &[f32]) -> f32 {
    if !is_x86_feature_detected!("avx2") {
        return scalar::manhattan_distance(a, b);
    }
    // SAFETY: AVX2 support was verified just above.
    unsafe { manhattan_distance_avx2(a, b) }
}
/// Non-x86_64 build: always use the scalar reference implementation.
#[cfg(not(target_arch = "x86_64"))]
pub fn manhattan_distance_avx2_wrapper(a: &[f32], b: &[f32]) -> f32 {
    scalar::manhattan_distance(a, b)
}
/// Euclidean distance via NEON. NEON is baseline on aarch64, so no runtime
/// feature check is needed before calling the unsafe kernel.
#[cfg(target_arch = "aarch64")]
pub fn euclidean_distance_neon_wrapper(a: &[f32], b: &[f32]) -> f32 {
unsafe { euclidean_distance_neon(a, b) }
}
/// Non-aarch64 build: always use the scalar reference implementation.
#[cfg(not(target_arch = "aarch64"))]
pub fn euclidean_distance_neon_wrapper(a: &[f32], b: &[f32]) -> f32 {
scalar::euclidean_distance(a, b)
}
/// Cosine distance via NEON. NEON is baseline on aarch64, so no runtime
/// feature check is needed before calling the unsafe kernel.
#[cfg(target_arch = "aarch64")]
pub fn cosine_distance_neon_wrapper(a: &[f32], b: &[f32]) -> f32 {
unsafe { cosine_distance_neon(a, b) }
}
/// Non-aarch64 build: always use the scalar reference implementation.
#[cfg(not(target_arch = "aarch64"))]
pub fn cosine_distance_neon_wrapper(a: &[f32], b: &[f32]) -> f32 {
scalar::cosine_distance(a, b)
}
/// Inner-product "distance" via NEON. NEON is baseline on aarch64, so no
/// runtime feature check is needed before calling the unsafe kernel.
#[cfg(target_arch = "aarch64")]
pub fn inner_product_neon_wrapper(a: &[f32], b: &[f32]) -> f32 {
unsafe { inner_product_neon(a, b) }
}
/// Non-aarch64 build: always use the scalar reference implementation.
#[cfg(not(target_arch = "aarch64"))]
pub fn inner_product_neon_wrapper(a: &[f32], b: &[f32]) -> f32 {
scalar::inner_product_distance(a, b)
}
/// Manhattan distance via NEON. NEON is baseline on aarch64, so no runtime
/// feature check is needed before calling the unsafe kernel.
#[cfg(target_arch = "aarch64")]
pub fn manhattan_distance_neon_wrapper(a: &[f32], b: &[f32]) -> f32 {
unsafe { manhattan_distance_neon(a, b) }
}
/// Non-aarch64 build: always use the scalar reference implementation.
#[cfg(not(target_arch = "aarch64"))]
pub fn manhattan_distance_neon_wrapper(a: &[f32], b: &[f32]) -> f32 {
scalar::manhattan_distance(a, b)
}
/// AVX2 dot-product kernel for cosine distance over pre-normalized vectors.
///
/// Assumes both inputs have unit L2 norm, so the distance reduces to
/// `1 - dot(a, b)`. Processes 8 lanes per iteration with FMA, then handles
/// the `len % 8` tail scalarly.
///
/// # Safety
/// - Caller must ensure the CPU supports AVX2 and FMA.
/// - `a` and `b` must each point to at least `len` readable `f32` values.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2", enable = "fma")]
#[inline]
pub unsafe fn cosine_distance_normalized_avx2(a: *const f32, b: *const f32, len: usize) -> f32 {
debug_assert!(!a.is_null() && !b.is_null() && len > 0);
let mut dot = _mm256_setzero_ps();
let chunks = len / 8;
for i in 0..chunks {
let offset = i * 8;
// Unaligned loads: the input buffers carry no 32-byte alignment guarantee.
let va = _mm256_loadu_ps(a.add(offset));
let vb = _mm256_loadu_ps(b.add(offset));
dot = _mm256_fmadd_ps(va, vb, dot);
}
let mut result = horizontal_sum_256(dot);
// Scalar tail for the remaining len % 8 elements.
for i in (chunks * 8)..len {
result += *a.add(i) * *b.add(i);
}
1.0 - result
}
/// Scalar fallback for cosine distance over pre-normalized vectors.
///
/// Assumes both inputs have unit L2 norm, so the distance reduces to
/// `1 - dot(a, b)`.
///
/// # Safety
/// `a` and `b` must each point to at least `len` readable `f32` values.
#[inline]
pub unsafe fn cosine_distance_normalized_scalar(a: *const f32, b: *const f32, len: usize) -> f32 {
    debug_assert!(!a.is_null() && !b.is_null() && len > 0);
    // Sequential left-to-right accumulation, same order as an index loop.
    let dot: f32 = (0..len).map(|i| *a.add(i) * *b.add(i)).sum();
    1.0 - dot
}
/// Runtime-dispatched cosine distance over pre-normalized vectors: AVX2+FMA
/// kernel on x86_64 when detected, scalar fallback everywhere else.
///
/// # Safety
/// `a` and `b` must each point to at least `len` readable `f32` values.
#[inline]
pub unsafe fn cosine_distance_normalized_ptr(a: *const f32, b: *const f32, len: usize) -> f32 {
#[cfg(target_arch = "x86_64")]
{
// Feature detection satisfies the AVX2 kernel's CPU requirement.
if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
return cosine_distance_normalized_avx2(a, b, len);
}
}
cosine_distance_normalized_scalar(a, b, len)
}
/// Safe entry point: cosine distance between two pre-normalized,
/// equal-length slices.
///
/// # Panics
/// Panics if `a.len() != b.len()`.
pub fn cosine_distance_normalized(a: &[f32], b: &[f32]) -> f32 {
    // A hard assert (not debug_assert) is required for soundness: this safe
    // function sizes the raw-pointer kernel by `a.len()`, so a mismatched
    // `b` would be read out of bounds in release builds.
    assert_eq!(a.len(), b.len());
    unsafe { cosine_distance_normalized_ptr(a.as_ptr(), b.as_ptr(), a.len()) }
}
/// Compute the L2 distance from `query` to each vector and return the `k`
/// closest as `(index, distance)` pairs, sorted ascending by distance.
/// If `k > vectors.len()`, all pairs are returned.
///
/// # Safety
/// `query` and every pointer in `vectors` must reference at least `len`
/// readable `f32` values.
#[inline]
pub unsafe fn l2_topk_batch(
    query: *const f32,
    vectors: &[*const f32],
    len: usize,
    k: usize,
) -> Vec<(usize, f32)> {
    let mut results: Vec<(usize, f32)> = vectors
        .iter()
        .enumerate()
        .map(|(i, &ptr)| (i, l2_distance_ptr(query, ptr, len)))
        .collect();
    // total_cmp is a total order even for NaN (sorted to the end); the old
    // partial_cmp + Equal fallback was an inconsistent comparator that
    // violates sort_by's contract whenever a NaN distance appears.
    results.sort_by(|a, b| a.1.total_cmp(&b.1));
    results.truncate(k);
    results
}
/// Compute the normalized cosine distance from `query` to each vector and
/// return the `k` closest as `(index, distance)` pairs, sorted ascending.
/// If `k > vectors.len()`, all pairs are returned.
///
/// # Safety
/// `query` and every pointer in `vectors` must reference at least `len`
/// readable `f32` values; all vectors are assumed pre-normalized.
#[inline]
pub unsafe fn cosine_topk_normalized_batch(
    query: *const f32,
    vectors: &[*const f32],
    len: usize,
    k: usize,
) -> Vec<(usize, f32)> {
    let mut results: Vec<(usize, f32)> = vectors
        .iter()
        .enumerate()
        .map(|(i, &ptr)| (i, cosine_distance_normalized_ptr(query, ptr, len)))
        .collect();
    // total_cmp is a total order even for NaN (sorted to the end); the old
    // partial_cmp + Equal fallback was an inconsistent comparator that
    // violates sort_by's contract whenever a NaN distance appears.
    results.sort_by(|a, b| a.1.total_cmp(&b.1));
    results.truncate(k);
    results
}
// Unit tests: each SIMD path is checked against the scalar reference
// implementations in `super::scalar`, with small absolute tolerances to
// absorb float-reassociation differences between SIMD and scalar sums.
#[cfg(test)]
mod tests {
use super::*;
// AVX2 Euclidean vs scalar reference (falls back to scalar on old CPUs).
#[test]
fn test_avx2_euclidean() {
let a: Vec<f32> = (0..128).map(|i| i as f32).collect();
let b: Vec<f32> = (0..128).map(|i| (i + 1) as f32).collect();
let scalar = scalar::euclidean_distance(&a, &b);
let simd = euclidean_distance_avx2_wrapper(&a, &b);
assert!(
(scalar - simd).abs() < 1e-4,
"scalar={}, simd={}",
scalar,
simd
);
}
// AVX2 cosine vs scalar reference.
#[test]
fn test_avx2_cosine() {
let a: Vec<f32> = (0..128).map(|i| i as f32 * 0.01).collect();
let b: Vec<f32> = (0..128).map(|i| (128 - i) as f32 * 0.01).collect();
let scalar = scalar::cosine_distance(&a, &b);
let simd = cosine_distance_avx2_wrapper(&a, &b);
assert!(
(scalar - simd).abs() < 1e-4,
"scalar={}, simd={}",
scalar,
simd
);
}
// AVX2 inner product vs scalar reference (looser tolerance: larger sums).
#[test]
fn test_avx2_inner_product() {
let a: Vec<f32> = (0..128).map(|i| i as f32 * 0.01).collect();
let b: Vec<f32> = (0..128).map(|i| (128 - i) as f32 * 0.01).collect();
let scalar = scalar::inner_product_distance(&a, &b);
let simd = inner_product_avx2_wrapper(&a, &b);
assert!(
(scalar - simd).abs() < 1e-3,
"scalar={}, simd={}",
scalar,
simd
);
}
// AVX2 Manhattan vs scalar reference.
#[test]
fn test_avx2_manhattan() {
let a: Vec<f32> = (0..128).map(|i| i as f32).collect();
let b: Vec<f32> = (0..128).map(|i| (i + 1) as f32).collect();
let scalar = scalar::manhattan_distance(&a, &b);
let simd = manhattan_distance_avx2_wrapper(&a, &b);
assert!(
(scalar - simd).abs() < 1e-4,
"scalar={}, simd={}",
scalar,
simd
);
}
// Odd sizes exercise the scalar tail loops (len % 8 remainders).
#[test]
fn test_remainder_handling() {
for size in [1, 3, 5, 7, 9, 15, 17, 31, 33, 63, 65, 127, 129] {
let a: Vec<f32> = (0..size).map(|i| i as f32).collect();
let b: Vec<f32> = (0..size).map(|i| (size - i) as f32).collect();
let scalar = scalar::euclidean_distance(&a, &b);
let simd = euclidean_distance_avx2_wrapper(&a, &b);
assert!(
(scalar - simd).abs() < 1e-3,
"size={}, scalar={}, simd={}",
size,
scalar,
simd
);
}
}
// Pointer-based L2: classic 3-4-5 triangle.
#[test]
fn test_ptr_l2_distance() {
let a: Vec<f32> = vec![0.0, 0.0, 0.0];
let b: Vec<f32> = vec![3.0, 4.0, 0.0];
let dist = unsafe { l2_distance_ptr(a.as_ptr(), b.as_ptr(), a.len()) };
assert!((dist - 5.0).abs() < 1e-5, "Expected 5.0, got {}", dist);
}
// Pointer-based cosine: identical vectors have distance ~0.
#[test]
fn test_ptr_cosine_distance() {
let a: Vec<f32> = vec![1.0, 0.0, 0.0];
let b: Vec<f32> = vec![1.0, 0.0, 0.0];
let dist = unsafe { cosine_distance_ptr(a.as_ptr(), b.as_ptr(), a.len()) };
assert!(dist.abs() < 1e-5, "Expected ~0.0, got {}", dist);
}
// Pointer-based inner product: negated dot, so 1*4+2*5+3*6 = 32 -> -32.
#[test]
fn test_ptr_inner_product() {
let a: Vec<f32> = vec![1.0, 2.0, 3.0];
let b: Vec<f32> = vec![4.0, 5.0, 6.0];
let dist = unsafe { inner_product_ptr(a.as_ptr(), b.as_ptr(), a.len()) };
assert!(
(dist - (-32.0)).abs() < 1e-5,
"Expected -32.0, got {}",
dist
);
}
// Pointer-based Manhattan: |1-4| + |2-6| + |3-8| = 12.
#[test]
fn test_ptr_manhattan_distance() {
let a: Vec<f32> = vec![1.0, 2.0, 3.0];
let b: Vec<f32> = vec![4.0, 6.0, 8.0];
let dist = unsafe { manhattan_distance_ptr(a.as_ptr(), b.as_ptr(), a.len()) };
assert!((dist - 12.0).abs() < 1e-5, "Expected 12.0, got {}", dist);
}
// AVX-512 tests are gated on the feature flag and skip at runtime when the
// host CPU lacks avx512f.
#[test]
#[cfg(all(target_arch = "x86_64", feature = "simd-avx512"))]
fn test_avx512_euclidean() {
if !is_avx512_available() {
println!("AVX-512 not available, skipping test");
return;
}
let a: Vec<f32> = (0..256).map(|i| i as f32).collect();
let b: Vec<f32> = (0..256).map(|i| (i + 1) as f32).collect();
let scalar = scalar::euclidean_distance(&a, &b);
let simd = unsafe { l2_distance_ptr_avx512(a.as_ptr(), b.as_ptr(), a.len()) };
assert!(
(scalar - simd).abs() < 1e-3,
"scalar={}, simd={}",
scalar,
simd
);
}
#[test]
#[cfg(all(target_arch = "x86_64", feature = "simd-avx512"))]
fn test_avx512_cosine() {
if !is_avx512_available() {
println!("AVX-512 not available, skipping test");
return;
}
let a: Vec<f32> = (0..256).map(|i| i as f32 * 0.01).collect();
let b: Vec<f32> = (0..256).map(|i| (256 - i) as f32 * 0.01).collect();
let scalar = scalar::cosine_distance(&a, &b);
let simd = unsafe { cosine_distance_ptr_avx512(a.as_ptr(), b.as_ptr(), a.len()) };
assert!(
(scalar - simd).abs() < 1e-4,
"scalar={}, simd={}",
scalar,
simd
);
}
#[test]
#[cfg(all(target_arch = "x86_64", feature = "simd-avx512"))]
fn test_avx512_inner_product() {
if !is_avx512_available() {
println!("AVX-512 not available, skipping test");
return;
}
let a: Vec<f32> = (0..256).map(|i| i as f32 * 0.01).collect();
let b: Vec<f32> = (0..256).map(|i| (256 - i) as f32 * 0.01).collect();
let scalar = scalar::inner_product_distance(&a, &b);
let simd = unsafe { inner_product_ptr_avx512(a.as_ptr(), b.as_ptr(), a.len()) };
assert!(
(scalar - simd).abs() < 1e-2,
"scalar={}, simd={}",
scalar,
simd
);
}
#[test]
#[cfg(all(target_arch = "x86_64", feature = "simd-avx512"))]
fn test_avx512_manhattan() {
if !is_avx512_available() {
println!("AVX-512 not available, skipping test");
return;
}
let a: Vec<f32> = (0..256).map(|i| i as f32).collect();
let b: Vec<f32> = (0..256).map(|i| (i + 1) as f32).collect();
let scalar = scalar::manhattan_distance(&a, &b);
let simd = unsafe { manhattan_distance_ptr_avx512(a.as_ptr(), b.as_ptr(), a.len()) };
assert!(
(scalar - simd).abs() < 1e-4,
"scalar={}, simd={}",
scalar,
simd
);
}
// Odd sizes exercise the AVX-512 tail handling (len % 16 remainders).
#[test]
#[cfg(all(target_arch = "x86_64", feature = "simd-avx512"))]
fn test_avx512_remainder_handling() {
if !is_avx512_available() {
println!("AVX-512 not available, skipping test");
return;
}
for size in [1, 7, 15, 17, 31, 33, 47, 63, 65, 127, 129, 255, 257] {
let a: Vec<f32> = (0..size).map(|i| i as f32).collect();
let b: Vec<f32> = (0..size).map(|i| (size - i) as f32).collect();
let scalar = scalar::euclidean_distance(&a, &b);
let simd = unsafe { l2_distance_ptr_avx512(a.as_ptr(), b.as_ptr(), a.len()) };
assert!(
(scalar - simd).abs() < 1e-2,
"size={}, scalar={}, simd={}",
size,
scalar,
simd
);
}
}
// simd_level() must always report one of the known tier names.
#[test]
fn test_simd_level_detection() {
let level = simd_level();
assert!(
level == "AVX-512" || level == "AVX2" || level == "NEON" || level == "Scalar",
"Unexpected SIMD level: {}",
level
);
println!("Detected SIMD level: {}", level);
}
// Smoke test: the detection helpers must not panic on any target.
#[test]
fn test_feature_detection_functions() {
let _avx512 = is_avx512_available();
let _avx2 = is_avx2_available();
let _neon = is_neon_available();
println!("AVX-512: {}, AVX2: {}, NEON: {}", _avx512, _avx2, _neon);
}
}