#![allow(unsafe_code)]
#![allow(clippy::incompatible_msrv)]
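/// Dot product of `a` and `b` over the first `min(a.len(), b.len())`
/// elements, computed with AVX-512F: four 16-lane FMA accumulators cover
/// 64-element blocks, a 16-lane loop covers the remainder, and a masked
/// load handles the final partial vector.
///
/// # Safety
///
/// The caller must ensure the CPU supports AVX-512F.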
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
pub unsafe fn dot_avx512(a: &[f32], b: &[f32]) -> f32 {
use std::arch::x86_64::{
__m512, __mmask16, _mm512_add_ps, _mm512_fmadd_ps, _mm512_loadu_ps, _mm512_maskz_loadu_ps,
_mm512_reduce_add_ps, _mm512_setzero_ps,
};
let n = a.len().min(b.len());
if n == 0 {
return 0.0;
}
let a_ptr = a.as_ptr();
let b_ptr = b.as_ptr();
let chunks_64 = n / 64;
let mut sum0: __m512 = _mm512_setzero_ps();
let mut sum1: __m512 = _mm512_setzero_ps();
let mut sum2: __m512 = _mm512_setzero_ps();
let mut sum3: __m512 = _mm512_setzero_ps();
for i in 0..chunks_64 {
let base = i * 64;
let va0 = _mm512_loadu_ps(a_ptr.add(base));
let vb0 = _mm512_loadu_ps(b_ptr.add(base));
let va1 = _mm512_loadu_ps(a_ptr.add(base + 16));
let vb1 = _mm512_loadu_ps(b_ptr.add(base + 16));
let va2 = _mm512_loadu_ps(a_ptr.add(base + 32));
let vb2 = _mm512_loadu_ps(b_ptr.add(base + 32));
let va3 = _mm512_loadu_ps(a_ptr.add(base + 48));
let vb3 = _mm512_loadu_ps(b_ptr.add(base + 48));
sum0 = _mm512_fmadd_ps(va0, vb0, sum0);
sum1 = _mm512_fmadd_ps(va1, vb1, sum1);
sum2 = _mm512_fmadd_ps(va2, vb2, sum2);
sum3 = _mm512_fmadd_ps(va3, vb3, sum3);
}
let sum01 = _mm512_add_ps(sum0, sum1);
let sum23 = _mm512_add_ps(sum2, sum3);
let sum_all = _mm512_add_ps(sum01, sum23);
let mut result = _mm512_reduce_add_ps(sum_all);
let remaining_start = chunks_64 * 64;
let remaining = n - remaining_start;
if remaining > 0 {
let chunks_16 = remaining / 16;
let mut sum_rem: __m512 = _mm512_setzero_ps();
for i in 0..chunks_16 {
let offset = remaining_start + i * 16;
let va = _mm512_loadu_ps(a_ptr.add(offset));
let vb = _mm512_loadu_ps(b_ptr.add(offset));
sum_rem = _mm512_fmadd_ps(va, vb, sum_rem);
}
let tail_count = remaining % 16;
if tail_count > 0 {
let tail_offset = remaining_start + chunks_16 * 16;
let mask: __mmask16 = ((1u32 << tail_count) - 1) as __mmask16;
let va = _mm512_maskz_loadu_ps(mask, a_ptr.add(tail_offset));
let vb = _mm512_maskz_loadu_ps(mask, b_ptr.add(tail_offset));
sum_rem = _mm512_fmadd_ps(va, vb, sum_rem);
}
result += _mm512_reduce_add_ps(sum_rem);
}
result
}
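/// MaxSim score: for each query token embedding, take the maximum dot
/// product over all document token embeddings, then sum those maxima.
/// Returns 0.0 when `doc_tokens` is empty.
///
/// # Safety
///
/// The caller must ensure the CPU supports AVX-512F.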
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
pub unsafe fn maxsim_avx512(query_tokens: &[&[f32]], doc_tokens: &[&[f32]]) -> f32 {
    // An empty document would otherwise contribute f32::NEG_INFINITY for
    // every query token below.
    if doc_tokens.is_empty() {
        return 0.0;
    }
    let mut total_score = 0.0;
for q in query_tokens {
let mut max_score = f32::NEG_INFINITY;
for d in doc_tokens {
let score = dot_avx512(q, d);
if score > max_score {
max_score = score;
}
}
total_score += max_score;
}
total_score
}
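/// MaxSim score (see `maxsim_avx512`), built on the AVX2 dot-product
/// kernel. Returns 0.0 when `doc_tokens` is empty.
///
/// # Safety
///
/// The caller must ensure the CPU supports AVX2 and FMA.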
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2", enable = "fma")]
pub unsafe fn maxsim_avx2(query_tokens: &[&[f32]], doc_tokens: &[&[f32]]) -> f32 {
    // An empty document would otherwise contribute f32::NEG_INFINITY for
    // every query token below.
    if doc_tokens.is_empty() {
        return 0.0;
    }
    let mut total_score = 0.0;
for q in query_tokens {
let mut max_score = f32::NEG_INFINITY;
for d in doc_tokens {
let score = dot_avx2(q, d);
if score > max_score {
max_score = score;
}
}
total_score += max_score;
}
total_score
}
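/// Dot product using AVX2 + FMA: four 8-lane FMA accumulators cover
/// 32-element blocks, an 8-lane loop covers the remainder, and a scalar
/// loop handles the tail.
///
/// # Safety
///
/// The caller must ensure the CPU supports AVX2 and FMA.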
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2", enable = "fma")]
pub unsafe fn dot_avx2(a: &[f32], b: &[f32]) -> f32 {
use std::arch::x86_64::{
__m256, _mm256_add_ps, _mm256_castps256_ps128, _mm256_extractf128_ps, _mm256_fmadd_ps,
_mm256_loadu_ps, _mm256_setzero_ps, _mm_add_ps, _mm_add_ss, _mm_cvtss_f32, _mm_movehl_ps,
_mm_shuffle_ps,
};
let n = a.len().min(b.len());
if n == 0 {
return 0.0;
}
let a_ptr = a.as_ptr();
let b_ptr = b.as_ptr();
let chunks_32 = n / 32;
let mut sum0: __m256 = _mm256_setzero_ps();
let mut sum1: __m256 = _mm256_setzero_ps();
let mut sum2: __m256 = _mm256_setzero_ps();
let mut sum3: __m256 = _mm256_setzero_ps();
for i in 0..chunks_32 {
let base = i * 32;
let va0 = _mm256_loadu_ps(a_ptr.add(base));
let vb0 = _mm256_loadu_ps(b_ptr.add(base));
let va1 = _mm256_loadu_ps(a_ptr.add(base + 8));
let vb1 = _mm256_loadu_ps(b_ptr.add(base + 8));
let va2 = _mm256_loadu_ps(a_ptr.add(base + 16));
let vb2 = _mm256_loadu_ps(b_ptr.add(base + 16));
let va3 = _mm256_loadu_ps(a_ptr.add(base + 24));
let vb3 = _mm256_loadu_ps(b_ptr.add(base + 24));
sum0 = _mm256_fmadd_ps(va0, vb0, sum0);
sum1 = _mm256_fmadd_ps(va1, vb1, sum1);
sum2 = _mm256_fmadd_ps(va2, vb2, sum2);
sum3 = _mm256_fmadd_ps(va3, vb3, sum3);
}
let sum01 = _mm256_add_ps(sum0, sum1);
let sum23 = _mm256_add_ps(sum2, sum3);
let sum_all = _mm256_add_ps(sum01, sum23);
let hi = _mm256_extractf128_ps(sum_all, 1);
let lo = _mm256_castps256_ps128(sum_all);
let sum128 = _mm_add_ps(lo, hi);
let sum64 = _mm_add_ps(sum128, _mm_movehl_ps(sum128, sum128));
let sum32 = _mm_add_ss(sum64, _mm_shuffle_ps(sum64, sum64, 1));
let mut result = _mm_cvtss_f32(sum32);
let remaining_start = chunks_32 * 32;
let remaining = n - remaining_start;
let chunks_8 = remaining / 8;
let mut sum: __m256 = _mm256_setzero_ps();
for i in 0..chunks_8 {
let offset = remaining_start + i * 8;
let va = _mm256_loadu_ps(a_ptr.add(offset));
let vb = _mm256_loadu_ps(b_ptr.add(offset));
sum = _mm256_fmadd_ps(va, vb, sum);
}
let hi = _mm256_extractf128_ps(sum, 1);
let lo = _mm256_castps256_ps128(sum);
let sum128 = _mm_add_ps(lo, hi);
let sum64 = _mm_add_ps(sum128, _mm_movehl_ps(sum128, sum128));
let sum32 = _mm_add_ss(sum64, _mm_shuffle_ps(sum64, sum64, 1));
result += _mm_cvtss_f32(sum32);
let tail_start = remaining_start + chunks_8 * 8;
for i in tail_start..n {
result += *a.get_unchecked(i) * *b.get_unchecked(i);
}
result
}
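/// Squared Euclidean (L2) distance over the first `min(a.len(), b.len())`
/// elements, using AVX-512F with a masked load for the final partial
/// vector.
///
/// # Safety
///
/// The caller must ensure the CPU supports AVX-512F.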
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
pub unsafe fn l2_squared_avx512(a: &[f32], b: &[f32]) -> f32 {
use std::arch::x86_64::{
__m512, __mmask16, _mm512_add_ps, _mm512_fmadd_ps, _mm512_loadu_ps, _mm512_maskz_loadu_ps,
_mm512_reduce_add_ps, _mm512_setzero_ps, _mm512_sub_ps,
};
let n = a.len().min(b.len());
if n == 0 {
return 0.0;
}
let a_ptr = a.as_ptr();
let b_ptr = b.as_ptr();
let chunks_64 = n / 64;
let mut sum0: __m512 = _mm512_setzero_ps();
let mut sum1: __m512 = _mm512_setzero_ps();
let mut sum2: __m512 = _mm512_setzero_ps();
let mut sum3: __m512 = _mm512_setzero_ps();
for i in 0..chunks_64 {
let base = i * 64;
let va0 = _mm512_loadu_ps(a_ptr.add(base));
let vb0 = _mm512_loadu_ps(b_ptr.add(base));
let d0 = _mm512_sub_ps(va0, vb0);
let va1 = _mm512_loadu_ps(a_ptr.add(base + 16));
let vb1 = _mm512_loadu_ps(b_ptr.add(base + 16));
let d1 = _mm512_sub_ps(va1, vb1);
let va2 = _mm512_loadu_ps(a_ptr.add(base + 32));
let vb2 = _mm512_loadu_ps(b_ptr.add(base + 32));
let d2 = _mm512_sub_ps(va2, vb2);
let va3 = _mm512_loadu_ps(a_ptr.add(base + 48));
let vb3 = _mm512_loadu_ps(b_ptr.add(base + 48));
let d3 = _mm512_sub_ps(va3, vb3);
sum0 = _mm512_fmadd_ps(d0, d0, sum0);
sum1 = _mm512_fmadd_ps(d1, d1, sum1);
sum2 = _mm512_fmadd_ps(d2, d2, sum2);
sum3 = _mm512_fmadd_ps(d3, d3, sum3);
}
let sum01 = _mm512_add_ps(sum0, sum1);
let sum23 = _mm512_add_ps(sum2, sum3);
let sum_all = _mm512_add_ps(sum01, sum23);
let mut result = _mm512_reduce_add_ps(sum_all);
let remaining_start = chunks_64 * 64;
let remaining = n - remaining_start;
if remaining > 0 {
let chunks_16 = remaining / 16;
let mut sum_rem: __m512 = _mm512_setzero_ps();
for i in 0..chunks_16 {
let offset = remaining_start + i * 16;
let va = _mm512_loadu_ps(a_ptr.add(offset));
let vb = _mm512_loadu_ps(b_ptr.add(offset));
let d = _mm512_sub_ps(va, vb);
sum_rem = _mm512_fmadd_ps(d, d, sum_rem);
}
let tail_count = remaining % 16;
if tail_count > 0 {
let tail_offset = remaining_start + chunks_16 * 16;
let mask: __mmask16 = ((1u32 << tail_count) - 1) as __mmask16;
let va = _mm512_maskz_loadu_ps(mask, a_ptr.add(tail_offset));
let vb = _mm512_maskz_loadu_ps(mask, b_ptr.add(tail_offset));
let d = _mm512_sub_ps(va, vb);
sum_rem = _mm512_fmadd_ps(d, d, sum_rem);
}
result += _mm512_reduce_add_ps(sum_rem);
}
result
}
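/// Squared Euclidean (L2) distance using AVX2 + FMA, with a scalar tail.
///
/// # Safety
///
/// The caller must ensure the CPU supports AVX2 and FMA.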
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2", enable = "fma")]
pub unsafe fn l2_squared_avx2(a: &[f32], b: &[f32]) -> f32 {
use std::arch::x86_64::{
__m256, _mm256_add_ps, _mm256_castps256_ps128, _mm256_extractf128_ps, _mm256_fmadd_ps,
_mm256_loadu_ps, _mm256_setzero_ps, _mm256_sub_ps, _mm_add_ps, _mm_add_ss, _mm_cvtss_f32,
_mm_movehl_ps, _mm_shuffle_ps,
};
let n = a.len().min(b.len());
if n == 0 {
return 0.0;
}
let a_ptr = a.as_ptr();
let b_ptr = b.as_ptr();
let chunks_32 = n / 32;
let mut sum0: __m256 = _mm256_setzero_ps();
let mut sum1: __m256 = _mm256_setzero_ps();
let mut sum2: __m256 = _mm256_setzero_ps();
let mut sum3: __m256 = _mm256_setzero_ps();
for i in 0..chunks_32 {
let base = i * 32;
let va0 = _mm256_loadu_ps(a_ptr.add(base));
let vb0 = _mm256_loadu_ps(b_ptr.add(base));
let d0 = _mm256_sub_ps(va0, vb0);
let va1 = _mm256_loadu_ps(a_ptr.add(base + 8));
let vb1 = _mm256_loadu_ps(b_ptr.add(base + 8));
let d1 = _mm256_sub_ps(va1, vb1);
let va2 = _mm256_loadu_ps(a_ptr.add(base + 16));
let vb2 = _mm256_loadu_ps(b_ptr.add(base + 16));
let d2 = _mm256_sub_ps(va2, vb2);
let va3 = _mm256_loadu_ps(a_ptr.add(base + 24));
let vb3 = _mm256_loadu_ps(b_ptr.add(base + 24));
let d3 = _mm256_sub_ps(va3, vb3);
sum0 = _mm256_fmadd_ps(d0, d0, sum0);
sum1 = _mm256_fmadd_ps(d1, d1, sum1);
sum2 = _mm256_fmadd_ps(d2, d2, sum2);
sum3 = _mm256_fmadd_ps(d3, d3, sum3);
}
let sum01 = _mm256_add_ps(sum0, sum1);
let sum23 = _mm256_add_ps(sum2, sum3);
let sum_all = _mm256_add_ps(sum01, sum23);
let hi = _mm256_extractf128_ps(sum_all, 1);
let lo = _mm256_castps256_ps128(sum_all);
let sum128 = _mm_add_ps(lo, hi);
let sum64 = _mm_add_ps(sum128, _mm_movehl_ps(sum128, sum128));
let sum32 = _mm_add_ss(sum64, _mm_shuffle_ps(sum64, sum64, 1));
let mut result = _mm_cvtss_f32(sum32);
let remaining_start = chunks_32 * 32;
let remaining = n - remaining_start;
let chunks_8 = remaining / 8;
let mut sum: __m256 = _mm256_setzero_ps();
for i in 0..chunks_8 {
let offset = remaining_start + i * 8;
let va = _mm256_loadu_ps(a_ptr.add(offset));
let vb = _mm256_loadu_ps(b_ptr.add(offset));
let d = _mm256_sub_ps(va, vb);
sum = _mm256_fmadd_ps(d, d, sum);
}
let hi = _mm256_extractf128_ps(sum, 1);
let lo = _mm256_castps256_ps128(sum);
let sum128 = _mm_add_ps(lo, hi);
let sum64 = _mm_add_ps(sum128, _mm_movehl_ps(sum128, sum128));
let sum32 = _mm_add_ss(sum64, _mm_shuffle_ps(sum64, sum64, 1));
result += _mm_cvtss_f32(sum32);
let tail_start = remaining_start + chunks_8 * 8;
for i in tail_start..n {
let d = *a.get_unchecked(i) - *b.get_unchecked(i);
result += d * d;
}
result
}
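/// L1 (Manhattan) distance using AVX-512F; absolute values come from
/// `_mm512_abs_ps`, and the final partial vector uses a masked load.
///
/// # Safety
///
/// The caller must ensure the CPU supports AVX-512F.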
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
pub unsafe fn l1_avx512(a: &[f32], b: &[f32]) -> f32 {
use std::arch::x86_64::{
__m512, __mmask16, _mm512_abs_ps, _mm512_add_ps, _mm512_loadu_ps, _mm512_maskz_loadu_ps,
_mm512_reduce_add_ps, _mm512_setzero_ps, _mm512_sub_ps,
};
let n = a.len().min(b.len());
if n == 0 {
return 0.0;
}
let a_ptr = a.as_ptr();
let b_ptr = b.as_ptr();
let chunks_64 = n / 64;
let mut sum0: __m512 = _mm512_setzero_ps();
let mut sum1: __m512 = _mm512_setzero_ps();
let mut sum2: __m512 = _mm512_setzero_ps();
let mut sum3: __m512 = _mm512_setzero_ps();
for i in 0..chunks_64 {
let base = i * 64;
let d0 = _mm512_abs_ps(_mm512_sub_ps(
_mm512_loadu_ps(a_ptr.add(base)),
_mm512_loadu_ps(b_ptr.add(base)),
));
let d1 = _mm512_abs_ps(_mm512_sub_ps(
_mm512_loadu_ps(a_ptr.add(base + 16)),
_mm512_loadu_ps(b_ptr.add(base + 16)),
));
let d2 = _mm512_abs_ps(_mm512_sub_ps(
_mm512_loadu_ps(a_ptr.add(base + 32)),
_mm512_loadu_ps(b_ptr.add(base + 32)),
));
let d3 = _mm512_abs_ps(_mm512_sub_ps(
_mm512_loadu_ps(a_ptr.add(base + 48)),
_mm512_loadu_ps(b_ptr.add(base + 48)),
));
sum0 = _mm512_add_ps(sum0, d0);
sum1 = _mm512_add_ps(sum1, d1);
sum2 = _mm512_add_ps(sum2, d2);
sum3 = _mm512_add_ps(sum3, d3);
}
let sum_all = _mm512_add_ps(_mm512_add_ps(sum0, sum1), _mm512_add_ps(sum2, sum3));
let mut result = _mm512_reduce_add_ps(sum_all);
let remaining_start = chunks_64 * 64;
let remaining = n - remaining_start;
if remaining > 0 {
let chunks_16 = remaining / 16;
let mut sum_rem: __m512 = _mm512_setzero_ps();
for i in 0..chunks_16 {
let offset = remaining_start + i * 16;
let d = _mm512_abs_ps(_mm512_sub_ps(
_mm512_loadu_ps(a_ptr.add(offset)),
_mm512_loadu_ps(b_ptr.add(offset)),
));
sum_rem = _mm512_add_ps(sum_rem, d);
}
let tail_count = remaining % 16;
if tail_count > 0 {
let tail_offset = remaining_start + chunks_16 * 16;
let mask: __mmask16 = ((1u32 << tail_count) - 1) as __mmask16;
let va = _mm512_maskz_loadu_ps(mask, a_ptr.add(tail_offset));
let vb = _mm512_maskz_loadu_ps(mask, b_ptr.add(tail_offset));
let d = _mm512_abs_ps(_mm512_sub_ps(va, vb));
sum_rem = _mm512_add_ps(sum_rem, d);
}
result += _mm512_reduce_add_ps(sum_rem);
}
result
}
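/// L1 (Manhattan) distance using AVX2; absolute values are taken by
/// clearing the sign bit with `_mm256_andnot_ps`, so FMA is not required.
///
/// # Safety
///
/// The caller must ensure the CPU supports AVX2.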
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
pub unsafe fn l1_avx2(a: &[f32], b: &[f32]) -> f32 {
use std::arch::x86_64::{
__m256, _mm256_add_ps, _mm256_andnot_ps, _mm256_castps256_ps128, _mm256_extractf128_ps,
_mm256_loadu_ps, _mm256_set1_ps, _mm256_setzero_ps, _mm256_sub_ps, _mm_add_ps, _mm_add_ss,
_mm_cvtss_f32, _mm_movehl_ps, _mm_shuffle_ps,
};
let n = a.len().min(b.len());
if n == 0 {
return 0.0;
}
let a_ptr = a.as_ptr();
let b_ptr = b.as_ptr();
let sign_mask = _mm256_set1_ps(f32::from_bits(0x8000_0000));
let chunks_32 = n / 32;
let mut sum0: __m256 = _mm256_setzero_ps();
let mut sum1: __m256 = _mm256_setzero_ps();
let mut sum2: __m256 = _mm256_setzero_ps();
let mut sum3: __m256 = _mm256_setzero_ps();
for i in 0..chunks_32 {
let base = i * 32;
let d0 = _mm256_andnot_ps(
sign_mask,
_mm256_sub_ps(
_mm256_loadu_ps(a_ptr.add(base)),
_mm256_loadu_ps(b_ptr.add(base)),
),
);
let d1 = _mm256_andnot_ps(
sign_mask,
_mm256_sub_ps(
_mm256_loadu_ps(a_ptr.add(base + 8)),
_mm256_loadu_ps(b_ptr.add(base + 8)),
),
);
let d2 = _mm256_andnot_ps(
sign_mask,
_mm256_sub_ps(
_mm256_loadu_ps(a_ptr.add(base + 16)),
_mm256_loadu_ps(b_ptr.add(base + 16)),
),
);
let d3 = _mm256_andnot_ps(
sign_mask,
_mm256_sub_ps(
_mm256_loadu_ps(a_ptr.add(base + 24)),
_mm256_loadu_ps(b_ptr.add(base + 24)),
),
);
sum0 = _mm256_add_ps(sum0, d0);
sum1 = _mm256_add_ps(sum1, d1);
sum2 = _mm256_add_ps(sum2, d2);
sum3 = _mm256_add_ps(sum3, d3);
}
let sum_all = _mm256_add_ps(_mm256_add_ps(sum0, sum1), _mm256_add_ps(sum2, sum3));
let hi = _mm256_extractf128_ps(sum_all, 1);
let lo = _mm256_castps256_ps128(sum_all);
let sum128 = _mm_add_ps(lo, hi);
let sum64 = _mm_add_ps(sum128, _mm_movehl_ps(sum128, sum128));
let sum32 = _mm_add_ss(sum64, _mm_shuffle_ps(sum64, sum64, 1));
let mut result = _mm_cvtss_f32(sum32);
let remaining_start = chunks_32 * 32;
let remaining = n - remaining_start;
let chunks_8 = remaining / 8;
let mut sum: __m256 = _mm256_setzero_ps();
for i in 0..chunks_8 {
let offset = remaining_start + i * 8;
let d = _mm256_andnot_ps(
sign_mask,
_mm256_sub_ps(
_mm256_loadu_ps(a_ptr.add(offset)),
_mm256_loadu_ps(b_ptr.add(offset)),
),
);
sum = _mm256_add_ps(sum, d);
}
let hi = _mm256_extractf128_ps(sum, 1);
let lo = _mm256_castps256_ps128(sum);
let sum128 = _mm_add_ps(lo, hi);
let sum64 = _mm_add_ps(sum128, _mm_movehl_ps(sum128, sum128));
let sum32 = _mm_add_ss(sum64, _mm_shuffle_ps(sum64, sum64, 1));
result += _mm_cvtss_f32(sum32);
let tail_start = remaining_start + chunks_8 * 8;
for i in tail_start..n {
result += (*a.get_unchecked(i) - *b.get_unchecked(i)).abs();
}
result
}
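/// Cosine similarity using AVX-512F, accumulating the dot product and both
/// squared norms in a single pass. Returns 0.0 when either squared norm is
/// at most `NORM_EPSILON_SQ`.
///
/// # Safety
///
/// The caller must ensure the CPU supports AVX-512F.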
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
pub unsafe fn cosine_avx512(a: &[f32], b: &[f32]) -> f32 {
use std::arch::x86_64::{
__m512, __mmask16, _mm512_add_ps, _mm512_fmadd_ps, _mm512_loadu_ps, _mm512_maskz_loadu_ps,
_mm512_reduce_add_ps, _mm512_setzero_ps,
};
let n = a.len().min(b.len());
if n == 0 {
return 0.0;
}
let a_ptr = a.as_ptr();
let b_ptr = b.as_ptr();
let chunks_64 = n / 64;
let mut ab0: __m512 = _mm512_setzero_ps();
let mut ab1: __m512 = _mm512_setzero_ps();
let mut ab2: __m512 = _mm512_setzero_ps();
let mut ab3: __m512 = _mm512_setzero_ps();
let mut aa0: __m512 = _mm512_setzero_ps();
let mut aa1: __m512 = _mm512_setzero_ps();
let mut aa2: __m512 = _mm512_setzero_ps();
let mut aa3: __m512 = _mm512_setzero_ps();
let mut bb0: __m512 = _mm512_setzero_ps();
let mut bb1: __m512 = _mm512_setzero_ps();
let mut bb2: __m512 = _mm512_setzero_ps();
let mut bb3: __m512 = _mm512_setzero_ps();
for i in 0..chunks_64 {
let base = i * 64;
let va0 = _mm512_loadu_ps(a_ptr.add(base));
let vb0 = _mm512_loadu_ps(b_ptr.add(base));
let va1 = _mm512_loadu_ps(a_ptr.add(base + 16));
let vb1 = _mm512_loadu_ps(b_ptr.add(base + 16));
let va2 = _mm512_loadu_ps(a_ptr.add(base + 32));
let vb2 = _mm512_loadu_ps(b_ptr.add(base + 32));
let va3 = _mm512_loadu_ps(a_ptr.add(base + 48));
let vb3 = _mm512_loadu_ps(b_ptr.add(base + 48));
ab0 = _mm512_fmadd_ps(va0, vb0, ab0);
ab1 = _mm512_fmadd_ps(va1, vb1, ab1);
ab2 = _mm512_fmadd_ps(va2, vb2, ab2);
ab3 = _mm512_fmadd_ps(va3, vb3, ab3);
aa0 = _mm512_fmadd_ps(va0, va0, aa0);
aa1 = _mm512_fmadd_ps(va1, va1, aa1);
aa2 = _mm512_fmadd_ps(va2, va2, aa2);
aa3 = _mm512_fmadd_ps(va3, va3, aa3);
bb0 = _mm512_fmadd_ps(vb0, vb0, bb0);
bb1 = _mm512_fmadd_ps(vb1, vb1, bb1);
bb2 = _mm512_fmadd_ps(vb2, vb2, bb2);
bb3 = _mm512_fmadd_ps(vb3, vb3, bb3);
}
let ab_all = _mm512_add_ps(_mm512_add_ps(ab0, ab1), _mm512_add_ps(ab2, ab3));
let aa_all = _mm512_add_ps(_mm512_add_ps(aa0, aa1), _mm512_add_ps(aa2, aa3));
let bb_all = _mm512_add_ps(_mm512_add_ps(bb0, bb1), _mm512_add_ps(bb2, bb3));
let mut ab = _mm512_reduce_add_ps(ab_all);
let mut aa = _mm512_reduce_add_ps(aa_all);
let mut bb = _mm512_reduce_add_ps(bb_all);
let remaining_start = chunks_64 * 64;
let remaining = n - remaining_start;
if remaining > 0 {
let chunks_16 = remaining / 16;
let mut ab_rem: __m512 = _mm512_setzero_ps();
let mut aa_rem: __m512 = _mm512_setzero_ps();
let mut bb_rem: __m512 = _mm512_setzero_ps();
for i in 0..chunks_16 {
let offset = remaining_start + i * 16;
let va = _mm512_loadu_ps(a_ptr.add(offset));
let vb = _mm512_loadu_ps(b_ptr.add(offset));
ab_rem = _mm512_fmadd_ps(va, vb, ab_rem);
aa_rem = _mm512_fmadd_ps(va, va, aa_rem);
bb_rem = _mm512_fmadd_ps(vb, vb, bb_rem);
}
let tail_count = remaining % 16;
if tail_count > 0 {
let tail_offset = remaining_start + chunks_16 * 16;
let mask: __mmask16 = ((1u32 << tail_count) - 1) as __mmask16;
let va = _mm512_maskz_loadu_ps(mask, a_ptr.add(tail_offset));
let vb = _mm512_maskz_loadu_ps(mask, b_ptr.add(tail_offset));
ab_rem = _mm512_fmadd_ps(va, vb, ab_rem);
aa_rem = _mm512_fmadd_ps(va, va, aa_rem);
bb_rem = _mm512_fmadd_ps(vb, vb, bb_rem);
}
ab += _mm512_reduce_add_ps(ab_rem);
aa += _mm512_reduce_add_ps(aa_rem);
bb += _mm512_reduce_add_ps(bb_rem);
}
if aa > crate::NORM_EPSILON_SQ && bb > crate::NORM_EPSILON_SQ {
ab / (aa.sqrt() * bb.sqrt())
} else {
0.0
}
}
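/// Cosine similarity using AVX2 + FMA, accumulating the dot product and
/// both squared norms in a single pass. Returns 0.0 when either squared
/// norm is at most `NORM_EPSILON_SQ`.
///
/// # Safety
///
/// The caller must ensure the CPU supports AVX2 and FMA.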
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2", enable = "fma")]
pub unsafe fn cosine_avx2(a: &[f32], b: &[f32]) -> f32 {
use std::arch::x86_64::{
__m256, _mm256_add_ps, _mm256_castps256_ps128, _mm256_extractf128_ps, _mm256_fmadd_ps,
_mm256_loadu_ps, _mm256_setzero_ps, _mm_add_ps, _mm_add_ss, _mm_cvtss_f32, _mm_movehl_ps,
_mm_shuffle_ps,
};
let n = a.len().min(b.len());
if n == 0 {
return 0.0;
}
let a_ptr = a.as_ptr();
let b_ptr = b.as_ptr();
let chunks_32 = n / 32;
let mut ab0: __m256 = _mm256_setzero_ps();
let mut ab1: __m256 = _mm256_setzero_ps();
let mut ab2: __m256 = _mm256_setzero_ps();
let mut ab3: __m256 = _mm256_setzero_ps();
let mut aa0: __m256 = _mm256_setzero_ps();
let mut aa1: __m256 = _mm256_setzero_ps();
let mut aa2: __m256 = _mm256_setzero_ps();
let mut aa3: __m256 = _mm256_setzero_ps();
let mut bb0: __m256 = _mm256_setzero_ps();
let mut bb1: __m256 = _mm256_setzero_ps();
let mut bb2: __m256 = _mm256_setzero_ps();
let mut bb3: __m256 = _mm256_setzero_ps();
for i in 0..chunks_32 {
let base = i * 32;
let va0 = _mm256_loadu_ps(a_ptr.add(base));
let vb0 = _mm256_loadu_ps(b_ptr.add(base));
let va1 = _mm256_loadu_ps(a_ptr.add(base + 8));
let vb1 = _mm256_loadu_ps(b_ptr.add(base + 8));
let va2 = _mm256_loadu_ps(a_ptr.add(base + 16));
let vb2 = _mm256_loadu_ps(b_ptr.add(base + 16));
let va3 = _mm256_loadu_ps(a_ptr.add(base + 24));
let vb3 = _mm256_loadu_ps(b_ptr.add(base + 24));
ab0 = _mm256_fmadd_ps(va0, vb0, ab0);
ab1 = _mm256_fmadd_ps(va1, vb1, ab1);
ab2 = _mm256_fmadd_ps(va2, vb2, ab2);
ab3 = _mm256_fmadd_ps(va3, vb3, ab3);
aa0 = _mm256_fmadd_ps(va0, va0, aa0);
aa1 = _mm256_fmadd_ps(va1, va1, aa1);
aa2 = _mm256_fmadd_ps(va2, va2, aa2);
aa3 = _mm256_fmadd_ps(va3, va3, aa3);
bb0 = _mm256_fmadd_ps(vb0, vb0, bb0);
bb1 = _mm256_fmadd_ps(vb1, vb1, bb1);
bb2 = _mm256_fmadd_ps(vb2, vb2, bb2);
bb3 = _mm256_fmadd_ps(vb3, vb3, bb3);
}
#[inline(always)]
unsafe fn hsum256(v: __m256) -> f32 {
let hi = _mm256_extractf128_ps(v, 1);
let lo = _mm256_castps256_ps128(v);
let sum128 = _mm_add_ps(lo, hi);
let sum64 = _mm_add_ps(sum128, _mm_movehl_ps(sum128, sum128));
let sum32 = _mm_add_ss(sum64, _mm_shuffle_ps(sum64, sum64, 1));
_mm_cvtss_f32(sum32)
}
let ab_all = _mm256_add_ps(_mm256_add_ps(ab0, ab1), _mm256_add_ps(ab2, ab3));
let aa_all = _mm256_add_ps(_mm256_add_ps(aa0, aa1), _mm256_add_ps(aa2, aa3));
let bb_all = _mm256_add_ps(_mm256_add_ps(bb0, bb1), _mm256_add_ps(bb2, bb3));
let mut ab = hsum256(ab_all);
let mut aa = hsum256(aa_all);
let mut bb = hsum256(bb_all);
let remaining_start = chunks_32 * 32;
let remaining = n - remaining_start;
let chunks_8 = remaining / 8;
let mut ab_rem: __m256 = _mm256_setzero_ps();
let mut aa_rem: __m256 = _mm256_setzero_ps();
let mut bb_rem: __m256 = _mm256_setzero_ps();
for i in 0..chunks_8 {
let offset = remaining_start + i * 8;
let va = _mm256_loadu_ps(a_ptr.add(offset));
let vb = _mm256_loadu_ps(b_ptr.add(offset));
ab_rem = _mm256_fmadd_ps(va, vb, ab_rem);
aa_rem = _mm256_fmadd_ps(va, va, aa_rem);
bb_rem = _mm256_fmadd_ps(vb, vb, bb_rem);
}
ab += hsum256(ab_rem);
aa += hsum256(aa_rem);
bb += hsum256(bb_rem);
let tail_start = remaining_start + chunks_8 * 8;
for i in tail_start..n {
let ai = *a.get_unchecked(i);
let bi = *b.get_unchecked(i);
ab += ai * bi;
aa += ai * ai;
bb += bi * bi;
}
if aa > crate::NORM_EPSILON_SQ && bb > crate::NORM_EPSILON_SQ {
ab / (aa.sqrt() * bb.sqrt())
} else {
0.0
}
}
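/// Dot product of f32 weights against u8 values: each group of eight bytes
/// is widened with `_mm256_cvtepu8_epi32` and converted to f32 with
/// `_mm256_cvtepi32_ps` before the FMA.
///
/// # Safety
///
/// The caller must ensure the CPU supports AVX2 and FMA.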
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2", enable = "fma")]
pub unsafe fn dot_u8_f32_avx2(a: &[f32], b: &[u8]) -> f32 {
use std::arch::x86_64::{
__m256, _mm256_add_ps, _mm256_castps256_ps128, _mm256_cvtepi32_ps, _mm256_cvtepu8_epi32,
_mm256_extractf128_ps, _mm256_fmadd_ps, _mm256_loadu_ps, _mm256_setzero_ps, _mm_add_ps,
_mm_add_ss, _mm_cvtss_f32, _mm_loadl_epi64, _mm_movehl_ps, _mm_shuffle_ps,
};
let n = a.len().min(b.len());
if n == 0 {
return 0.0;
}
let a_ptr = a.as_ptr();
let b_ptr = b.as_ptr();
let chunks_32 = n / 32;
let mut sum0: __m256 = _mm256_setzero_ps();
let mut sum1: __m256 = _mm256_setzero_ps();
let mut sum2: __m256 = _mm256_setzero_ps();
let mut sum3: __m256 = _mm256_setzero_ps();
for i in 0..chunks_32 {
let base = i * 32;
let b0 = _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64(
b_ptr.add(base) as *const _
)));
let b1 = _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64(
b_ptr.add(base + 8) as *const _
)));
let b2 = _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64(
b_ptr.add(base + 16) as *const _
)));
let b3 = _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64(
b_ptr.add(base + 24) as *const _
)));
let a0 = _mm256_loadu_ps(a_ptr.add(base));
let a1 = _mm256_loadu_ps(a_ptr.add(base + 8));
let a2 = _mm256_loadu_ps(a_ptr.add(base + 16));
let a3 = _mm256_loadu_ps(a_ptr.add(base + 24));
sum0 = _mm256_fmadd_ps(a0, b0, sum0);
sum1 = _mm256_fmadd_ps(a1, b1, sum1);
sum2 = _mm256_fmadd_ps(a2, b2, sum2);
sum3 = _mm256_fmadd_ps(a3, b3, sum3);
}
let sum_all = _mm256_add_ps(_mm256_add_ps(sum0, sum1), _mm256_add_ps(sum2, sum3));
let hi = _mm256_extractf128_ps(sum_all, 1);
let lo = _mm256_castps256_ps128(sum_all);
let sum128 = _mm_add_ps(lo, hi);
let sum64 = _mm_add_ps(sum128, _mm_movehl_ps(sum128, sum128));
let sum32 = _mm_add_ss(sum64, _mm_shuffle_ps(sum64, sum64, 1));
let mut result = _mm_cvtss_f32(sum32);
let remaining_start = chunks_32 * 32;
let remaining = n - remaining_start;
let chunks_8 = remaining / 8;
let mut sum: __m256 = _mm256_setzero_ps();
for i in 0..chunks_8 {
let offset = remaining_start + i * 8;
let b_f32 = _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64(
b_ptr.add(offset) as *const _
)));
let a_f32 = _mm256_loadu_ps(a_ptr.add(offset));
sum = _mm256_fmadd_ps(a_f32, b_f32, sum);
}
let hi = _mm256_extractf128_ps(sum, 1);
let lo = _mm256_castps256_ps128(sum);
let sum128 = _mm_add_ps(lo, hi);
let sum64 = _mm_add_ps(sum128, _mm_movehl_ps(sum128, sum128));
let sum32 = _mm_add_ss(sum64, _mm_shuffle_ps(sum64, sum64, 1));
result += _mm_cvtss_f32(sum32);
let tail_start = remaining_start + chunks_8 * 8;
for i in tail_start..n {
result += *a.get_unchecked(i) * (*b.get_unchecked(i) as f32);
}
result
}
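/// Dot product of two u8 slices using AVX2: bytes are zero-extended to u16
/// and multiply-accumulated into i32 lanes with `_mm256_madd_epi16`. All
/// final adds wrap, so the result is exact whenever the true sum fits in
/// `u32`.
///
/// # Safety
///
/// The caller must ensure the CPU supports AVX2.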
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
pub unsafe fn dot_u8_avx2(a: &[u8], b: &[u8]) -> u32 {
use std::arch::x86_64::{
__m256i, _mm256_add_epi32, _mm256_castsi256_si128, _mm256_cvtepu8_epi16,
_mm256_extract_epi32, _mm256_extracti128_si256, _mm256_lddqu_si256, _mm256_madd_epi16,
_mm256_permute2x128_si256, _mm256_setzero_si256,
};
let n = a.len().min(b.len());
if n == 0 {
return 0;
}
let a_ptr = a.as_ptr();
let b_ptr = b.as_ptr();
let mut acc32: __m256i = _mm256_setzero_si256();
let chunks_32 = n / 32;
for i in 0..chunks_32 {
let base = i * 32;
let va = _mm256_lddqu_si256(a_ptr.add(base) as *const __m256i);
let vb = _mm256_lddqu_si256(b_ptr.add(base) as *const __m256i);
let va_lo = _mm256_cvtepu8_epi16(_mm256_castsi256_si128(va));
let vb_lo = _mm256_cvtepu8_epi16(_mm256_castsi256_si128(vb));
let va_hi = _mm256_cvtepu8_epi16(_mm256_extracti128_si256(va, 1));
let vb_hi = _mm256_cvtepu8_epi16(_mm256_extracti128_si256(vb, 1));
let prod_lo = _mm256_madd_epi16(va_lo, vb_lo);
let prod_hi = _mm256_madd_epi16(va_hi, vb_hi);
acc32 = _mm256_add_epi32(acc32, prod_lo);
acc32 = _mm256_add_epi32(acc32, prod_hi);
}
let hi128 = _mm256_permute2x128_si256(acc32, acc32, 0x01);
let sum128 = _mm256_add_epi32(acc32, hi128);
let result: u32 = (_mm256_extract_epi32(sum128, 0) as u32)
.wrapping_add(_mm256_extract_epi32(sum128, 1) as u32)
.wrapping_add(_mm256_extract_epi32(sum128, 2) as u32)
.wrapping_add(_mm256_extract_epi32(sum128, 3) as u32);
let tail_start = chunks_32 * 32;
let tail: u32 = (tail_start..n)
.map(|i| *a.get_unchecked(i) as u32 * *b.get_unchecked(i) as u32)
.sum();
result.wrapping_add(tail)
}
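/// Dot product of two u8 slices using AVX-512: each 64-byte block is split
/// into 32-byte halves, zero-extended to u16, and multiply-accumulated into
/// i32 lanes with `_mm512_madd_epi16`. Exact whenever the true sum fits in
/// `u32`.
///
/// # Safety
///
/// The caller must ensure the CPU supports AVX-512F and AVX-512BW.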
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f", enable = "avx512bw")]
pub unsafe fn dot_u8_avx512(a: &[u8], b: &[u8]) -> u32 {
use std::arch::x86_64::{
__m512i, _mm512_add_epi32, _mm512_cvtepu8_epi16, _mm512_extracti64x4_epi64,
_mm512_loadu_si512, _mm512_madd_epi16, _mm512_reduce_add_epi32, _mm512_setzero_si512,
};
let n = a.len().min(b.len());
if n == 0 {
return 0;
}
let a_ptr = a.as_ptr();
let b_ptr = b.as_ptr();
let mut acc32: __m512i = _mm512_setzero_si512();
let chunks_64 = n / 64;
for i in 0..chunks_64 {
let base = i * 64;
let va = _mm512_loadu_si512(a_ptr.add(base) as *const _);
let vb = _mm512_loadu_si512(b_ptr.add(base) as *const _);
let va_lo16 = _mm512_cvtepu8_epi16(_mm512_extracti64x4_epi64(va, 0));
let vb_lo16 = _mm512_cvtepu8_epi16(_mm512_extracti64x4_epi64(vb, 0));
let va_hi16 = _mm512_cvtepu8_epi16(_mm512_extracti64x4_epi64(va, 1));
let vb_hi16 = _mm512_cvtepu8_epi16(_mm512_extracti64x4_epi64(vb, 1));
let prod_lo = _mm512_madd_epi16(va_lo16, vb_lo16);
let prod_hi = _mm512_madd_epi16(va_hi16, vb_hi16);
acc32 = _mm512_add_epi32(acc32, prod_lo);
acc32 = _mm512_add_epi32(acc32, prod_hi);
}
let mut result = _mm512_reduce_add_epi32(acc32) as u32;
let tail_start = chunks_64 * 64;
for i in tail_start..n {
result = result.wrapping_add(*a.get_unchecked(i) as u32 * *b.get_unchecked(i) as u32);
}
result
}
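/// Hamming distance (popcount of the bytewise XOR) using the AVX2
/// nibble-LUT technique: `_mm256_shuffle_epi8` looks up per-nibble bit
/// counts and `_mm256_sad_epu8` folds the byte counts into 64-bit lanes.
///
/// # Safety
///
/// The caller must ensure the CPU supports AVX2.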
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
pub unsafe fn hamming_avx2(a: &[u8], b: &[u8]) -> u32 {
use std::arch::x86_64::{
__m256i, _mm256_add_epi64, _mm256_and_si256, _mm256_extract_epi64, _mm256_lddqu_si256,
_mm256_sad_epu8, _mm256_set1_epi8, _mm256_setr_epi8, _mm256_setzero_si256,
_mm256_shuffle_epi8, _mm256_srli_epi16, _mm256_xor_si256,
};
let n = a.len().min(b.len());
if n == 0 {
return 0;
}
let a_ptr = a.as_ptr();
let b_ptr = b.as_ptr();
    // Popcount of each 4-bit nibble (0..=15), repeated across both 128-bit
    // lanes for use with `_mm256_shuffle_epi8`.
    let lut = _mm256_setr_epi8(
        0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4,
        0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4,
    );
let lo_mask = _mm256_set1_epi8(0x0F_u8 as i8);
let chunks_32 = n / 32;
let mut acc: __m256i = _mm256_setzero_si256();
for i in 0..chunks_32 {
let base = i * 32;
let va = _mm256_lddqu_si256(a_ptr.add(base) as *const __m256i);
let vb = _mm256_lddqu_si256(b_ptr.add(base) as *const __m256i);
let xored = _mm256_xor_si256(va, vb);
let lo_nibbles = _mm256_and_si256(xored, lo_mask);
let hi_nibbles = _mm256_and_si256(_mm256_srli_epi16(xored, 4), lo_mask);
let lo_cnt = _mm256_shuffle_epi8(lut, lo_nibbles);
let hi_cnt = _mm256_shuffle_epi8(lut, hi_nibbles);
let byte_cnt = _mm256_add_epi64(
_mm256_sad_epu8(lo_cnt, _mm256_setzero_si256()),
_mm256_sad_epu8(hi_cnt, _mm256_setzero_si256()),
);
acc = _mm256_add_epi64(acc, byte_cnt);
}
let mut result: u32 = (_mm256_extract_epi64(acc, 0) as u32)
.wrapping_add(_mm256_extract_epi64(acc, 1) as u32)
.wrapping_add(_mm256_extract_epi64(acc, 2) as u32)
.wrapping_add(_mm256_extract_epi64(acc, 3) as u32);
for i in (chunks_32 * 32)..n {
result += (*a.get_unchecked(i) ^ *b.get_unchecked(i)).count_ones();
}
result
}
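/// Hamming distance using AVX-512 VPOPCNTDQ, which provides a native
/// 64-bit-lane popcount (`_mm512_popcnt_epi64`).
///
/// # Safety
///
/// The caller must ensure the CPU supports AVX-512F and AVX-512VPOPCNTDQ.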
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f", enable = "avx512vpopcntdq")]
pub unsafe fn hamming_avx512(a: &[u8], b: &[u8]) -> u32 {
use std::arch::x86_64::{
__m512i, _mm512_add_epi64, _mm512_loadu_si512, _mm512_popcnt_epi64,
_mm512_reduce_add_epi64, _mm512_setzero_si512, _mm512_xor_si512,
};
let n = a.len().min(b.len());
if n == 0 {
return 0;
}
let a_ptr = a.as_ptr();
let b_ptr = b.as_ptr();
let chunks_64 = n / 64;
let mut acc: __m512i = _mm512_setzero_si512();
for i in 0..chunks_64 {
let base = i * 64;
let va = _mm512_loadu_si512(a_ptr.add(base) as *const _);
let vb = _mm512_loadu_si512(b_ptr.add(base) as *const _);
let xored = _mm512_xor_si512(va, vb);
let cnt = _mm512_popcnt_epi64(xored);
acc = _mm512_add_epi64(acc, cnt);
}
let mut result = _mm512_reduce_add_epi64(acc) as u32;
for i in (chunks_64 * 64)..n {
result += (*a.get_unchecked(i) ^ *b.get_unchecked(i)).count_ones();
}
result
}
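// A minimal sketch of runtime dispatch over the kernels above. The function
// name `dot` and the scalar fallback are illustrative assumptions, not part
// of the original API; callers may prefer to hoist the feature checks out of
// hot loops.
#[cfg(target_arch = "x86_64")]
pub fn dot(a: &[f32], b: &[f32]) -> f32 {
    if is_x86_feature_detected!("avx512f") {
        // SAFETY: AVX-512F support was verified immediately above.
        unsafe { dot_avx512(a, b) }
    } else if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
        // SAFETY: AVX2 and FMA support were verified immediately above.
        unsafe { dot_avx2(a, b) }
    } else {
        // Scalar fallback over the common prefix, matching the SIMD kernels'
        // `min(a.len(), b.len())` convention.
        a.iter().zip(b.iter()).map(|(x, y)| x * y).sum()
    }
}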
#[cfg(test)]
mod tests {
use super::*;
#[test]
#[cfg(target_arch = "x86_64")]
fn test_dot_avx512_correctness() {
if !is_x86_feature_detected!("avx512f") {
eprintln!("AVX-512F not available, skipping test");
return;
}
for size in [1, 15, 16, 17, 31, 32, 63, 64, 65, 127, 128, 256, 512, 1024] {
let a: Vec<f32> = (0..size).map(|i| (i as f32) * 0.1).collect();
let b: Vec<f32> = (0..size).map(|i| (i as f32) * 0.2).collect();
let expected: f32 = a.iter().zip(&b).map(|(x, y)| x * y).sum();
let actual = unsafe { dot_avx512(&a, &b) };
let rel_error = if expected.abs() > 1e-6 {
(actual - expected).abs() / expected.abs()
} else {
(actual - expected).abs()
};
assert!(
rel_error < 1e-4, "AVX-512 size={}: expected={}, actual={}, rel_error={}",
size,
expected,
actual,
rel_error
);
}
}
#[test]
#[cfg(target_arch = "x86_64")]
fn test_dot_avx2_correctness() {
if !is_x86_feature_detected!("avx2") || !is_x86_feature_detected!("fma") {
eprintln!("AVX2+FMA not available, skipping test");
return;
}
for size in [1, 7, 8, 9, 15, 16, 17, 31, 32, 33, 64, 128, 256, 512] {
let a: Vec<f32> = (0..size).map(|i| (i as f32) * 0.1).collect();
let b: Vec<f32> = (0..size).map(|i| (i as f32) * 0.2).collect();
let expected: f32 = a.iter().zip(&b).map(|(x, y)| x * y).sum();
let actual = unsafe { dot_avx2(&a, &b) };
let rel_error = if expected.abs() > 1e-6 {
(actual - expected).abs() / expected.abs()
} else {
(actual - expected).abs()
};
assert!(
rel_error < 1e-5,
"AVX2 size={}: expected={}, actual={}, rel_error={}",
size,
expected,
actual,
rel_error
);
}
}
#[test]
#[cfg(target_arch = "x86_64")]
fn test_l2_squared_avx512_correctness() {
if !is_x86_feature_detected!("avx512f") {
eprintln!("AVX-512F not available, skipping test");
return;
}
for size in [1, 15, 16, 17, 31, 32, 63, 64, 65, 127, 128, 256, 512, 1024] {
let a: Vec<f32> = (0..size).map(|i| (i as f32) * 0.1).collect();
let b: Vec<f32> = (0..size).map(|i| (i as f32) * 0.2).collect();
let expected: f32 = a.iter().zip(&b).map(|(x, y)| (x - y) * (x - y)).sum();
let actual = unsafe { l2_squared_avx512(&a, &b) };
let rel_error = if expected.abs() > 1e-6 {
(actual - expected).abs() / expected.abs()
} else {
(actual - expected).abs()
};
assert!(
rel_error < 1e-4,
"AVX-512 L2 size={}: expected={}, actual={}, rel_error={}",
size,
expected,
actual,
rel_error
);
}
}
#[test]
#[cfg(target_arch = "x86_64")]
fn test_l2_squared_avx2_correctness() {
if !is_x86_feature_detected!("avx2") || !is_x86_feature_detected!("fma") {
eprintln!("AVX2+FMA not available, skipping test");
return;
}
for size in [1, 7, 8, 9, 15, 16, 17, 31, 32, 33, 64, 128, 256, 512] {
let a: Vec<f32> = (0..size).map(|i| (i as f32) * 0.1).collect();
let b: Vec<f32> = (0..size).map(|i| (i as f32) * 0.2).collect();
let expected: f32 = a.iter().zip(&b).map(|(x, y)| (x - y) * (x - y)).sum();
let actual = unsafe { l2_squared_avx2(&a, &b) };
let rel_error = if expected.abs() > 1e-6 {
(actual - expected).abs() / expected.abs()
} else {
(actual - expected).abs()
};
assert!(
rel_error < 1e-5,
"AVX2 L2 size={}: expected={}, actual={}, rel_error={}",
size,
expected,
actual,
rel_error
);
}
}
#[test]
#[cfg(target_arch = "x86_64")]
fn test_cosine_avx512_correctness() {
if !is_x86_feature_detected!("avx512f") {
eprintln!("AVX-512F not available, skipping test");
return;
}
for size in [1, 15, 16, 17, 31, 32, 63, 64, 65, 127, 128, 256, 512, 1024] {
let a: Vec<f32> = (0..size).map(|i| ((i * 7) as f32).sin()).collect();
let b: Vec<f32> = (0..size).map(|i| ((i * 11) as f32).cos()).collect();
let ab: f32 = a.iter().zip(&b).map(|(x, y)| x * y).sum();
let aa: f32 = a.iter().map(|x| x * x).sum();
let bb: f32 = b.iter().map(|x| x * x).sum();
let expected = if aa > crate::NORM_EPSILON_SQ && bb > crate::NORM_EPSILON_SQ {
ab / (aa.sqrt() * bb.sqrt())
} else {
0.0
};
let actual = unsafe { cosine_avx512(&a, &b) };
let diff = (actual - expected).abs();
assert!(
diff < 1e-4,
"AVX-512 cosine size={}: expected={}, actual={}, diff={}",
size,
expected,
actual,
diff
);
}
}
#[test]
#[cfg(target_arch = "x86_64")]
fn test_cosine_avx2_correctness() {
if !is_x86_feature_detected!("avx2") || !is_x86_feature_detected!("fma") {
eprintln!("AVX2+FMA not available, skipping test");
return;
}
for size in [1, 7, 8, 9, 15, 16, 17, 31, 32, 33, 64, 128, 256, 512] {
let a: Vec<f32> = (0..size).map(|i| ((i * 7) as f32).sin()).collect();
let b: Vec<f32> = (0..size).map(|i| ((i * 11) as f32).cos()).collect();
let ab: f32 = a.iter().zip(&b).map(|(x, y)| x * y).sum();
let aa: f32 = a.iter().map(|x| x * x).sum();
let bb: f32 = b.iter().map(|x| x * x).sum();
let expected = if aa > crate::NORM_EPSILON_SQ && bb > crate::NORM_EPSILON_SQ {
ab / (aa.sqrt() * bb.sqrt())
} else {
0.0
};
let actual = unsafe { cosine_avx2(&a, &b) };
let diff = (actual - expected).abs();
assert!(
diff < 1e-5,
"AVX2 cosine size={}: expected={}, actual={}, diff={}",
size,
expected,
actual,
diff
);
}
}
#[test]
#[cfg(target_arch = "x86_64")]
fn test_avx2_vs_avx512_consistency() {
if !is_x86_feature_detected!("avx2")
|| !is_x86_feature_detected!("fma")
|| !is_x86_feature_detected!("avx512f")
{
eprintln!("Need both AVX2+FMA and AVX-512F, skipping");
return;
}
for size in [64, 128, 256, 512, 1024] {
let a: Vec<f32> = (0..size).map(|i| ((i * 7) as f32).sin()).collect();
let b: Vec<f32> = (0..size).map(|i| ((i * 11) as f32).cos()).collect();
let avx2_result = unsafe { dot_avx2(&a, &b) };
let avx512_result = unsafe { dot_avx512(&a, &b) };
let diff = (avx2_result - avx512_result).abs();
let max_val = avx2_result.abs().max(avx512_result.abs()).max(1e-6);
let rel_diff = diff / max_val;
assert!(
rel_diff < 1e-5,
"AVX2 vs AVX-512 mismatch at size={}: avx2={}, avx512={}, rel_diff={}",
size,
avx2_result,
avx512_result,
rel_diff
);
}
}
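    // Hedged additions mirroring the tests above: cross-check the L1 and
    // Hamming kernels against straightforward scalar references.
    #[test]
    #[cfg(target_arch = "x86_64")]
    fn test_l1_avx2_correctness() {
        if !is_x86_feature_detected!("avx2") {
            eprintln!("AVX2 not available, skipping test");
            return;
        }
        for size in [1, 7, 8, 9, 31, 32, 33, 64, 128, 256, 512] {
            let a: Vec<f32> = (0..size).map(|i| (i as f32) * 0.1).collect();
            let b: Vec<f32> = (0..size).map(|i| (i as f32) * 0.2).collect();
            let expected: f32 = a.iter().zip(&b).map(|(x, y)| (x - y).abs()).sum();
            let actual = unsafe { l1_avx2(&a, &b) };
            let rel_error = if expected.abs() > 1e-6 {
                (actual - expected).abs() / expected.abs()
            } else {
                (actual - expected).abs()
            };
            assert!(
                rel_error < 1e-4,
                "AVX2 L1 size={}: expected={}, actual={}, rel_error={}",
                size,
                expected,
                actual,
                rel_error
            );
        }
    }
    #[test]
    #[cfg(target_arch = "x86_64")]
    fn test_hamming_avx2_correctness() {
        if !is_x86_feature_detected!("avx2") {
            eprintln!("AVX2 not available, skipping test");
            return;
        }
        for size in [1, 31, 32, 33, 64, 100, 256] {
            let a: Vec<u8> = (0..size).map(|i| (i * 37 % 256) as u8).collect();
            let b: Vec<u8> = (0..size).map(|i| (i * 101 % 256) as u8).collect();
            let expected: u32 = a.iter().zip(&b).map(|(x, y)| (x ^ y).count_ones()).sum();
            let actual = unsafe { hamming_avx2(&a, &b) };
            assert_eq!(actual, expected, "AVX2 Hamming size={}", size);
        }
    }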
}