#![allow(unreachable_code)]
use crate::types::DistanceMetric;
#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;
#[cfg(target_arch = "aarch64")]
use std::arch::aarch64::*;
/// Reports whether AVX2 can be used on the running CPU.
///
/// On x86_64 the answer comes from runtime feature detection; on every other
/// architecture it is always `false`.
#[inline]
pub fn is_avx2_available() -> bool {
    #[cfg(target_arch = "x86_64")]
    {
        is_x86_feature_detected!("avx2")
    }
    #[cfg(not(target_arch = "x86_64"))]
    {
        false
    }
}
/// Reports whether FMA instructions can be used on the running CPU.
///
/// On x86_64 the answer comes from runtime feature detection; on every other
/// architecture it is always `false`.
#[inline]
pub fn is_fma_available() -> bool {
    #[cfg(target_arch = "x86_64")]
    {
        is_x86_feature_detected!("fma")
    }
    #[cfg(not(target_arch = "x86_64"))]
    {
        false
    }
}
/// Reports whether NEON can be used.
///
/// This is a compile-time answer: `true` when built for aarch64, `false` otherwise.
#[inline]
pub fn is_neon_available() -> bool {
    cfg!(target_arch = "aarch64")
}
/// Reports whether AVX-512 Foundation (`avx512f`) can be used on the running CPU.
///
/// On x86_64 the answer comes from runtime feature detection; on every other
/// architecture it is always `false`.
#[inline]
pub fn is_avx512_available() -> bool {
    #[cfg(target_arch = "x86_64")]
    {
        is_x86_feature_detected!("avx512f")
    }
    #[cfg(not(target_arch = "x86_64"))]
    {
        false
    }
}
#[cfg(target_arch = "x86_64")]
#[inline]
unsafe fn horizontal_sum_avx512(v: __m512) -> f32 {
let low = _mm512_castps512_ps256(v); let high = _mm512_extractf32x8_ps(v, 1);
let sum256 = _mm256_add_ps(low, high);
horizontal_sum_avx2(sum256)
}
/// Dot product of `a` and `b`, processing 16 lanes per step with AVX-512 FMA and
/// finishing the leftover elements with scalar arithmetic.
///
/// # Safety
/// The CPU must support AVX-512F, and `b` must be at least as long as `a`
/// (the vector loop reads `b` through a raw pointer using `a.len()`).
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
#[inline]
unsafe fn dot_product_avx512(a: &[f32], b: &[f32]) -> f32 {
    const LANES: usize = 16;
    let n = a.len();
    let vector_end = n - n % LANES;
    let (pa, pb) = (a.as_ptr(), b.as_ptr());

    let mut acc = _mm512_setzero_ps();
    let mut pos = 0;
    while pos < vector_end {
        let x = _mm512_loadu_ps(pa.add(pos));
        let y = _mm512_loadu_ps(pb.add(pos));
        acc = _mm512_fmadd_ps(x, y, acc);
        pos += LANES;
    }

    // Scalar tail for the elements that do not fill a whole register.
    let mut result = horizontal_sum_avx512(acc);
    while pos < n {
        result += a[pos] * b[pos];
        pos += 1;
    }
    result
}
/// Cosine similarity of `a` and `b` using AVX-512 FMA, 16 lanes per step.
///
/// The dot product and both squared norms are accumulated in one pass; the
/// product of the norms is clamped to at least `1e-10` before dividing.
///
/// # Safety
/// The CPU must support AVX-512F, and `b` must be at least as long as `a`
/// (the vector loop reads `b` through a raw pointer using `a.len()`).
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
#[inline]
unsafe fn cosine_similarity_avx512(a: &[f32], b: &[f32]) -> f32 {
    const LANES: usize = 16;
    let n = a.len();
    let vector_end = n - n % LANES;
    let (pa, pb) = (a.as_ptr(), b.as_ptr());

    let mut ab = _mm512_setzero_ps();
    let mut aa = _mm512_setzero_ps();
    let mut bb = _mm512_setzero_ps();
    let mut pos = 0;
    while pos < vector_end {
        let x = _mm512_loadu_ps(pa.add(pos));
        let y = _mm512_loadu_ps(pb.add(pos));
        ab = _mm512_fmadd_ps(x, y, ab);
        aa = _mm512_fmadd_ps(x, x, aa);
        bb = _mm512_fmadd_ps(y, y, bb);
        pos += LANES;
    }

    let mut dot = horizontal_sum_avx512(ab);
    let mut sq_a = horizontal_sum_avx512(aa);
    let mut sq_b = horizontal_sum_avx512(bb);
    // Scalar tail for the elements that do not fill a whole register.
    while pos < n {
        let (x, y) = (a[pos], b[pos]);
        dot += x * y;
        sq_a += x * x;
        sq_b += y * y;
        pos += 1;
    }

    // Clamp so that zero-length vectors do not divide by zero.
    dot / (sq_a.sqrt() * sq_b.sqrt()).max(1e-10)
}
/// Euclidean (L2) distance between `a` and `b` using AVX-512 FMA, 16 lanes per step.
///
/// # Safety
/// The CPU must support AVX-512F, and `b` must be at least as long as `a`
/// (the vector loop reads `b` through a raw pointer using `a.len()`).
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
#[inline]
unsafe fn euclidean_distance_avx512(a: &[f32], b: &[f32]) -> f32 {
    const LANES: usize = 16;
    let n = a.len();
    let vector_end = n - n % LANES;
    let (pa, pb) = (a.as_ptr(), b.as_ptr());

    let mut acc = _mm512_setzero_ps();
    let mut pos = 0;
    while pos < vector_end {
        let delta = _mm512_sub_ps(_mm512_loadu_ps(pa.add(pos)), _mm512_loadu_ps(pb.add(pos)));
        acc = _mm512_fmadd_ps(delta, delta, acc);
        pos += LANES;
    }

    // Scalar tail for the elements that do not fill a whole register.
    let mut squared = horizontal_sum_avx512(acc);
    for idx in vector_end..n {
        let delta = a[idx] - b[idx];
        squared += delta * delta;
    }
    squared.sqrt()
}
/// Manhattan (L1) distance between `a` and `b` using AVX-512, 16 lanes per step.
///
/// # Safety
/// The CPU must support AVX-512F, and `b` must be at least as long as `a`
/// (the vector loop reads `b` through a raw pointer using `a.len()`).
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
#[inline]
unsafe fn manhattan_distance_avx512(a: &[f32], b: &[f32]) -> f32 {
    const LANES: usize = 16;
    let n = a.len();
    let vector_end = n - n % LANES;
    let (pa, pb) = (a.as_ptr(), b.as_ptr());

    let mut acc = _mm512_setzero_ps();
    let mut pos = 0;
    while pos < vector_end {
        let delta = _mm512_sub_ps(_mm512_loadu_ps(pa.add(pos)), _mm512_loadu_ps(pb.add(pos)));
        acc = _mm512_add_ps(acc, _mm512_abs_ps(delta));
        pos += LANES;
    }

    // Scalar tail for the elements that do not fill a whole register.
    let mut distance = horizontal_sum_avx512(acc);
    for idx in vector_end..n {
        distance += (a[idx] - b[idx]).abs();
    }
    distance
}
/// Adds the four `f32` lanes of `v` together and returns the scalar total.
///
/// # Safety
/// Uses NEON intrinsics; must only be compiled/called where NEON is available.
#[cfg(target_arch = "aarch64")]
#[inline]
unsafe fn horizontal_sum_neon(v: float32x4_t) -> f32 {
    // First pairwise add yields [v0+v1, v2+v3, v0+v1, v2+v3].
    let halves = vpaddq_f32(v, v);
    // Second pairwise add leaves the full total in every lane; read lane 0.
    vgetq_lane_f32::<0>(vpaddq_f32(halves, halves))
}
/// Dot product of `a` and `b`, processing 4 lanes per step with NEON and
/// finishing the leftover elements with scalar arithmetic.
///
/// # Safety
/// `b` must be at least as long as `a` (the vector loop reads `b` through a
/// raw pointer using `a.len()`).
#[cfg(target_arch = "aarch64")]
#[inline]
unsafe fn dot_product_neon(a: &[f32], b: &[f32]) -> f32 {
    const LANES: usize = 4;
    let n = a.len();
    let vector_end = n - n % LANES;
    let (pa, pb) = (a.as_ptr(), b.as_ptr());

    let mut acc = vdupq_n_f32(0.0);
    let mut pos = 0;
    while pos < vector_end {
        let x = vld1q_f32(pa.add(pos));
        let y = vld1q_f32(pb.add(pos));
        acc = vmlaq_f32(acc, x, y);
        pos += LANES;
    }

    // Scalar tail for the elements that do not fill a whole register.
    let mut result = horizontal_sum_neon(acc);
    while pos < n {
        result += a[pos] * b[pos];
        pos += 1;
    }
    result
}
/// Cosine similarity of `a` and `b` using NEON, 4 lanes per step.
///
/// The dot product and both squared norms are accumulated in one pass; the
/// product of the norms is clamped to at least `1e-10` before dividing.
///
/// # Safety
/// `b` must be at least as long as `a` (the vector loop reads `b` through a
/// raw pointer using `a.len()`).
#[cfg(target_arch = "aarch64")]
#[inline]
unsafe fn cosine_similarity_neon(a: &[f32], b: &[f32]) -> f32 {
    const LANES: usize = 4;
    let n = a.len();
    let vector_end = n - n % LANES;
    let (pa, pb) = (a.as_ptr(), b.as_ptr());

    let mut ab = vdupq_n_f32(0.0);
    let mut aa = vdupq_n_f32(0.0);
    let mut bb = vdupq_n_f32(0.0);
    let mut pos = 0;
    while pos < vector_end {
        let x = vld1q_f32(pa.add(pos));
        let y = vld1q_f32(pb.add(pos));
        ab = vmlaq_f32(ab, x, y);
        aa = vmlaq_f32(aa, x, x);
        bb = vmlaq_f32(bb, y, y);
        pos += LANES;
    }

    let mut dot = horizontal_sum_neon(ab);
    let mut sq_a = horizontal_sum_neon(aa);
    let mut sq_b = horizontal_sum_neon(bb);
    // Scalar tail for the elements that do not fill a whole register.
    while pos < n {
        let (x, y) = (a[pos], b[pos]);
        dot += x * y;
        sq_a += x * x;
        sq_b += y * y;
        pos += 1;
    }

    // Clamp so that zero-length vectors do not divide by zero.
    dot / (sq_a.sqrt() * sq_b.sqrt()).max(1e-10)
}
/// Euclidean (L2) distance between `a` and `b` using NEON, 4 lanes per step.
///
/// # Safety
/// `b` must be at least as long as `a` (the vector loop reads `b` through a
/// raw pointer using `a.len()`).
#[cfg(target_arch = "aarch64")]
#[inline]
unsafe fn euclidean_distance_neon(a: &[f32], b: &[f32]) -> f32 {
    const LANES: usize = 4;
    let n = a.len();
    let vector_end = n - n % LANES;
    let (pa, pb) = (a.as_ptr(), b.as_ptr());

    let mut acc = vdupq_n_f32(0.0);
    let mut pos = 0;
    while pos < vector_end {
        let delta = vsubq_f32(vld1q_f32(pa.add(pos)), vld1q_f32(pb.add(pos)));
        acc = vmlaq_f32(acc, delta, delta);
        pos += LANES;
    }

    // Scalar tail for the elements that do not fill a whole register.
    let mut squared = horizontal_sum_neon(acc);
    for idx in vector_end..n {
        let delta = a[idx] - b[idx];
        squared += delta * delta;
    }
    squared.sqrt()
}
/// Manhattan (L1) distance between `a` and `b` using NEON, 4 lanes per step.
///
/// # Safety
/// `b` must be at least as long as `a` (the vector loop reads `b` through a
/// raw pointer using `a.len()`).
#[cfg(target_arch = "aarch64")]
#[inline]
unsafe fn manhattan_distance_neon(a: &[f32], b: &[f32]) -> f32 {
    const LANES: usize = 4;
    let n = a.len();
    let vector_end = n - n % LANES;
    let (pa, pb) = (a.as_ptr(), b.as_ptr());

    let mut acc = vdupq_n_f32(0.0);
    let mut pos = 0;
    while pos < vector_end {
        let delta = vsubq_f32(vld1q_f32(pa.add(pos)), vld1q_f32(pb.add(pos)));
        acc = vaddq_f32(acc, vabsq_f32(delta));
        pos += LANES;
    }

    // Scalar tail for the elements that do not fill a whole register.
    let mut distance = horizontal_sum_neon(acc);
    for idx in vector_end..n {
        distance += (a[idx] - b[idx]).abs();
    }
    distance
}
/// Dot product of `a` and `b`, processing 8 lanes per step with 256-bit FMA and
/// finishing the leftover elements with scalar arithmetic.
///
/// # Safety
/// The CPU must support both AVX2 and FMA, and `b` must be at least as long as
/// `a` (the vector loop reads `b` through a raw pointer using `a.len()`).
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2,fma")]
#[inline]
unsafe fn dot_product_fma(a: &[f32], b: &[f32]) -> f32 {
    const LANES: usize = 8;
    let n = a.len();
    let vector_end = n - n % LANES;
    let (pa, pb) = (a.as_ptr(), b.as_ptr());

    let mut acc = _mm256_setzero_ps();
    let mut pos = 0;
    while pos < vector_end {
        let x = _mm256_loadu_ps(pa.add(pos));
        let y = _mm256_loadu_ps(pb.add(pos));
        acc = _mm256_fmadd_ps(x, y, acc);
        pos += LANES;
    }

    // Scalar tail for the elements that do not fill a whole register.
    let mut result = horizontal_sum_avx2(acc);
    while pos < n {
        result += a[pos] * b[pos];
        pos += 1;
    }
    result
}
/// Cosine similarity of `a` and `b` using 256-bit FMA, 8 lanes per step.
///
/// The dot product and both squared norms are accumulated in one pass; the
/// product of the norms is clamped to at least `1e-10` before dividing.
///
/// # Safety
/// The CPU must support both AVX2 and FMA, and `b` must be at least as long as
/// `a` (the vector loop reads `b` through a raw pointer using `a.len()`).
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2,fma")]
#[inline]
unsafe fn cosine_similarity_fma(a: &[f32], b: &[f32]) -> f32 {
    const LANES: usize = 8;
    let n = a.len();
    let vector_end = n - n % LANES;
    let (pa, pb) = (a.as_ptr(), b.as_ptr());

    let mut ab = _mm256_setzero_ps();
    let mut aa = _mm256_setzero_ps();
    let mut bb = _mm256_setzero_ps();
    let mut pos = 0;
    while pos < vector_end {
        let x = _mm256_loadu_ps(pa.add(pos));
        let y = _mm256_loadu_ps(pb.add(pos));
        ab = _mm256_fmadd_ps(x, y, ab);
        aa = _mm256_fmadd_ps(x, x, aa);
        bb = _mm256_fmadd_ps(y, y, bb);
        pos += LANES;
    }

    let mut dot = horizontal_sum_avx2(ab);
    let mut sq_a = horizontal_sum_avx2(aa);
    let mut sq_b = horizontal_sum_avx2(bb);
    // Scalar tail for the elements that do not fill a whole register.
    while pos < n {
        let (x, y) = (a[pos], b[pos]);
        dot += x * y;
        sq_a += x * x;
        sq_b += y * y;
        pos += 1;
    }

    // Clamp so that zero-length vectors do not divide by zero.
    dot / (sq_a.sqrt() * sq_b.sqrt()).max(1e-10)
}
/// Euclidean (L2) distance between `a` and `b` using 256-bit FMA, 8 lanes per step.
///
/// # Safety
/// The CPU must support both AVX2 and FMA, and `b` must be at least as long as
/// `a` (the vector loop reads `b` through a raw pointer using `a.len()`).
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2,fma")]
#[inline]
unsafe fn euclidean_distance_fma(a: &[f32], b: &[f32]) -> f32 {
    const LANES: usize = 8;
    let n = a.len();
    let vector_end = n - n % LANES;
    let (pa, pb) = (a.as_ptr(), b.as_ptr());

    let mut acc = _mm256_setzero_ps();
    let mut pos = 0;
    while pos < vector_end {
        let delta = _mm256_sub_ps(_mm256_loadu_ps(pa.add(pos)), _mm256_loadu_ps(pb.add(pos)));
        acc = _mm256_fmadd_ps(delta, delta, acc);
        pos += LANES;
    }

    // Scalar tail for the elements that do not fill a whole register.
    let mut squared = horizontal_sum_avx2(acc);
    for idx in vector_end..n {
        let delta = a[idx] - b[idx];
        squared += delta * delta;
    }
    squared.sqrt()
}
/// Adds the eight `f32` lanes of `v` together and returns the scalar total.
///
/// The upper and lower 128-bit halves are added first, then two horizontal
/// adds fold the remaining four lanes into lane 0.
///
/// # Safety
/// The running CPU must support AVX.
#[cfg(target_arch = "x86_64")]
// Without a `target_feature` attribute the AVX intrinsics below cannot be inlined
// into this function and are emitted as out-of-line calls. Every caller in this
// file already requires at least AVX, so enabling it here adds no new requirement.
#[target_feature(enable = "avx")]
#[inline]
unsafe fn horizontal_sum_avx2(v: __m256) -> f32 {
    let hi = _mm256_extractf128_ps(v, 1);
    let lo = _mm256_castps256_ps128(v);
    let sum128 = _mm_add_ps(lo, hi);
    let sum64 = _mm_hadd_ps(sum128, sum128);
    let sum32 = _mm_hadd_ps(sum64, sum64);
    _mm_cvtss_f32(sum32)
}
/// Dot product of `a` and `b`, processing 8 lanes per step with separate
/// multiply and add instructions (no FMA), then a scalar tail.
///
/// # Safety
/// The CPU must support AVX2, and `b` must be at least as long as `a`
/// (the vector loop reads `b` through a raw pointer using `a.len()`).
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
#[inline]
unsafe fn dot_product_avx2(a: &[f32], b: &[f32]) -> f32 {
    const LANES: usize = 8;
    let n = a.len();
    let vector_end = n - n % LANES;
    let (pa, pb) = (a.as_ptr(), b.as_ptr());

    let mut acc = _mm256_setzero_ps();
    let mut pos = 0;
    while pos < vector_end {
        let product = _mm256_mul_ps(_mm256_loadu_ps(pa.add(pos)), _mm256_loadu_ps(pb.add(pos)));
        acc = _mm256_add_ps(acc, product);
        pos += LANES;
    }

    // Scalar tail for the elements that do not fill a whole register.
    let mut result = horizontal_sum_avx2(acc);
    while pos < n {
        result += a[pos] * b[pos];
        pos += 1;
    }
    result
}
/// Cosine similarity of `a` and `b` using AVX2 multiply + add, 8 lanes per step.
///
/// The dot product and both squared norms are accumulated in one pass; the
/// product of the norms is clamped to at least `1e-10` before dividing.
///
/// # Safety
/// The CPU must support AVX2, and `b` must be at least as long as `a`
/// (the vector loop reads `b` through a raw pointer using `a.len()`).
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
#[inline]
unsafe fn cosine_similarity_avx2(a: &[f32], b: &[f32]) -> f32 {
    const LANES: usize = 8;
    let n = a.len();
    let vector_end = n - n % LANES;
    let (pa, pb) = (a.as_ptr(), b.as_ptr());

    let mut ab = _mm256_setzero_ps();
    let mut aa = _mm256_setzero_ps();
    let mut bb = _mm256_setzero_ps();
    let mut pos = 0;
    while pos < vector_end {
        let x = _mm256_loadu_ps(pa.add(pos));
        let y = _mm256_loadu_ps(pb.add(pos));
        ab = _mm256_add_ps(ab, _mm256_mul_ps(x, y));
        aa = _mm256_add_ps(aa, _mm256_mul_ps(x, x));
        bb = _mm256_add_ps(bb, _mm256_mul_ps(y, y));
        pos += LANES;
    }

    let mut dot = horizontal_sum_avx2(ab);
    let mut sq_a = horizontal_sum_avx2(aa);
    let mut sq_b = horizontal_sum_avx2(bb);
    // Scalar tail for the elements that do not fill a whole register.
    while pos < n {
        let (x, y) = (a[pos], b[pos]);
        dot += x * y;
        sq_a += x * x;
        sq_b += y * y;
        pos += 1;
    }

    // Clamp so that zero-length vectors do not divide by zero.
    dot / (sq_a.sqrt() * sq_b.sqrt()).max(1e-10)
}
/// Euclidean (L2) distance between `a` and `b` using AVX2 multiply + add,
/// 8 lanes per step.
///
/// # Safety
/// The CPU must support AVX2, and `b` must be at least as long as `a`
/// (the vector loop reads `b` through a raw pointer using `a.len()`).
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
#[inline]
unsafe fn euclidean_distance_avx2(a: &[f32], b: &[f32]) -> f32 {
    const LANES: usize = 8;
    let n = a.len();
    let vector_end = n - n % LANES;
    let (pa, pb) = (a.as_ptr(), b.as_ptr());

    let mut acc = _mm256_setzero_ps();
    let mut pos = 0;
    while pos < vector_end {
        let delta = _mm256_sub_ps(_mm256_loadu_ps(pa.add(pos)), _mm256_loadu_ps(pb.add(pos)));
        acc = _mm256_add_ps(acc, _mm256_mul_ps(delta, delta));
        pos += LANES;
    }

    // Scalar tail for the elements that do not fill a whole register.
    let mut squared = horizontal_sum_avx2(acc);
    for idx in vector_end..n {
        let delta = a[idx] - b[idx];
        squared += delta * delta;
    }
    squared.sqrt()
}
/// Manhattan (L1) distance between `a` and `b` using AVX2, 8 lanes per step.
///
/// # Safety
/// The CPU must support AVX2, and `b` must be at least as long as `a`
/// (the vector loop reads `b` through a raw pointer using `a.len()`).
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
#[inline]
unsafe fn manhattan_distance_avx2(a: &[f32], b: &[f32]) -> f32 {
    const LANES: usize = 8;
    let n = a.len();
    let vector_end = n - n % LANES;
    let (pa, pb) = (a.as_ptr(), b.as_ptr());

    // -0.0 has only the sign bit set; AND-NOT with it clears the sign of each lane.
    let sign_bit = _mm256_set1_ps(-0.0);
    let mut acc = _mm256_setzero_ps();
    let mut pos = 0;
    while pos < vector_end {
        let delta = _mm256_sub_ps(_mm256_loadu_ps(pa.add(pos)), _mm256_loadu_ps(pb.add(pos)));
        acc = _mm256_add_ps(acc, _mm256_andnot_ps(sign_bit, delta));
        pos += LANES;
    }

    // Scalar tail for the elements that do not fill a whole register.
    let mut distance = horizontal_sum_avx2(acc);
    for idx in vector_end..n {
        distance += (a[idx] - b[idx]).abs();
    }
    distance
}
/// Portable scalar cosine similarity of `a` and `b`.
///
/// Accumulates the dot product and both squared norms in element order, then
/// divides by the product of the norms clamped to at least `1e-10`.
///
/// # Panics
/// Panics if `a` and `b` have different lengths.
#[inline]
#[allow(dead_code)]
fn cosine_similarity_autovec(a: &[f32], b: &[f32]) -> f32 {
    // A `debug_assert` followed by unchecked indexing made this safe function read
    // out of bounds in release builds on mismatched input; check unconditionally
    // and iterate safely instead.
    assert_eq!(a.len(), b.len(), "Vectors must have same dimension");

    let mut dot_product = 0.0f32;
    let mut norm_a = 0.0f32;
    let mut norm_b = 0.0f32;
    for (&a_val, &b_val) in a.iter().zip(b) {
        dot_product += a_val * b_val;
        norm_a += a_val * a_val;
        norm_b += b_val * b_val;
    }

    // Clamp so that zero-length vectors do not divide by zero.
    let denominator = (norm_a.sqrt() * norm_b.sqrt()).max(1e-10);
    dot_product / denominator
}
/// Cosine similarity of `a` and `b`, dispatched to the fastest kernel available:
/// AVX-512, then AVX2+FMA, then AVX2, then the scalar fallback on x86_64; NEON on
/// aarch64; the scalar fallback everywhere else.
///
/// # Panics
/// Panics if `a` and `b` have different lengths.
#[inline]
pub fn cosine_similarity_simd(a: &[f32], b: &[f32]) -> f32 {
    // The SIMD kernels read `b` through raw pointers for `a.len()` elements, so a
    // shorter `b` would be read out of bounds. Reject mismatches up front.
    assert_eq!(a.len(), b.len(), "Vector dimension mismatch");
    #[cfg(target_arch = "x86_64")]
    {
        if is_avx512_available() {
            // SAFETY: AVX-512F was detected at runtime and the lengths are equal.
            unsafe { cosine_similarity_avx512(a, b) }
        } else if is_fma_available() && is_avx2_available() {
            // The FMA kernel is compiled with both `avx2` and `fma` enabled, so both
            // must be present (some CPUs have FMA without AVX2).
            // SAFETY: AVX2 and FMA were detected at runtime and the lengths are equal.
            unsafe { cosine_similarity_fma(a, b) }
        } else if is_avx2_available() {
            // SAFETY: AVX2 was detected at runtime and the lengths are equal.
            unsafe { cosine_similarity_avx2(a, b) }
        } else {
            cosine_similarity_autovec(a, b)
        }
    }
    #[cfg(target_arch = "aarch64")]
    {
        // SAFETY: NEON is always available on aarch64 and the lengths are equal.
        unsafe { cosine_similarity_neon(a, b) }
    }
    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
    {
        cosine_similarity_autovec(a, b)
    }
}
/// Portable scalar Euclidean (L2) distance between `a` and `b`.
///
/// # Panics
/// Panics if `a` and `b` have different lengths.
#[inline]
#[allow(dead_code)]
fn euclidean_distance_autovec(a: &[f32], b: &[f32]) -> f32 {
    // A `debug_assert` followed by unchecked indexing made this safe function read
    // out of bounds in release builds on mismatched input; check unconditionally
    // and iterate safely instead.
    assert_eq!(a.len(), b.len(), "Vectors must have same dimension");

    let mut sum_sq = 0.0f32;
    for (&a_val, &b_val) in a.iter().zip(b) {
        let diff = a_val - b_val;
        sum_sq += diff * diff;
    }
    sum_sq.sqrt()
}
/// Euclidean (L2) distance between `a` and `b`, dispatched to the fastest kernel
/// available: AVX-512, then AVX2+FMA, then AVX2, then the scalar fallback on
/// x86_64; NEON on aarch64; the scalar fallback everywhere else.
///
/// # Panics
/// Panics if `a` and `b` have different lengths.
#[inline]
pub fn euclidean_distance_simd(a: &[f32], b: &[f32]) -> f32 {
    // The SIMD kernels read `b` through raw pointers for `a.len()` elements, so a
    // shorter `b` would be read out of bounds. Reject mismatches up front.
    assert_eq!(a.len(), b.len(), "Vector dimension mismatch");
    #[cfg(target_arch = "x86_64")]
    {
        if is_avx512_available() {
            // SAFETY: AVX-512F was detected at runtime and the lengths are equal.
            unsafe { euclidean_distance_avx512(a, b) }
        } else if is_fma_available() && is_avx2_available() {
            // The FMA kernel is compiled with both `avx2` and `fma` enabled, so both
            // must be present (some CPUs have FMA without AVX2).
            // SAFETY: AVX2 and FMA were detected at runtime and the lengths are equal.
            unsafe { euclidean_distance_fma(a, b) }
        } else if is_avx2_available() {
            // SAFETY: AVX2 was detected at runtime and the lengths are equal.
            unsafe { euclidean_distance_avx2(a, b) }
        } else {
            euclidean_distance_autovec(a, b)
        }
    }
    #[cfg(target_arch = "aarch64")]
    {
        // SAFETY: NEON is always available on aarch64 and the lengths are equal.
        unsafe { euclidean_distance_neon(a, b) }
    }
    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
    {
        euclidean_distance_autovec(a, b)
    }
}
/// Portable scalar dot product of `a` and `b`, accumulated in element order.
///
/// # Panics
/// Panics if `a` and `b` have different lengths.
#[inline]
#[allow(dead_code)]
fn dot_product_autovec(a: &[f32], b: &[f32]) -> f32 {
    // A `debug_assert` followed by unchecked indexing made this safe function read
    // out of bounds in release builds on mismatched input; check unconditionally
    // and iterate safely instead.
    assert_eq!(a.len(), b.len(), "Vectors must have same dimension");

    let mut dot = 0.0f32;
    for (&a_val, &b_val) in a.iter().zip(b) {
        dot += a_val * b_val;
    }
    dot
}
/// Dot product of `a` and `b`, dispatched to the fastest kernel available:
/// AVX-512, then AVX2+FMA, then AVX2, then the scalar fallback on x86_64; NEON on
/// aarch64; the scalar fallback everywhere else.
///
/// # Panics
/// Panics if `a` and `b` have different lengths.
#[inline]
pub fn dot_product_simd(a: &[f32], b: &[f32]) -> f32 {
    // The SIMD kernels read `b` through raw pointers for `a.len()` elements, so a
    // shorter `b` would be read out of bounds. Reject mismatches up front.
    assert_eq!(a.len(), b.len(), "Vector dimension mismatch");
    #[cfg(target_arch = "x86_64")]
    {
        if is_avx512_available() {
            // SAFETY: AVX-512F was detected at runtime and the lengths are equal.
            unsafe { dot_product_avx512(a, b) }
        } else if is_fma_available() && is_avx2_available() {
            // The FMA kernel is compiled with both `avx2` and `fma` enabled, so both
            // must be present (some CPUs have FMA without AVX2).
            // SAFETY: AVX2 and FMA were detected at runtime and the lengths are equal.
            unsafe { dot_product_fma(a, b) }
        } else if is_avx2_available() {
            // SAFETY: AVX2 was detected at runtime and the lengths are equal.
            unsafe { dot_product_avx2(a, b) }
        } else {
            dot_product_autovec(a, b)
        }
    }
    #[cfg(target_arch = "aarch64")]
    {
        // SAFETY: NEON is always available on aarch64 and the lengths are equal.
        unsafe { dot_product_neon(a, b) }
    }
    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
    {
        dot_product_autovec(a, b)
    }
}
/// Portable scalar Manhattan (L1) distance between `a` and `b`.
///
/// # Panics
/// Panics if `a` and `b` have different lengths.
#[inline]
#[allow(dead_code)]
fn manhattan_distance_autovec(a: &[f32], b: &[f32]) -> f32 {
    // A `debug_assert` followed by unchecked indexing made this safe function read
    // out of bounds in release builds on mismatched input; check unconditionally
    // and iterate safely instead.
    assert_eq!(a.len(), b.len(), "Vectors must have same dimension");

    let mut sum = 0.0f32;
    for (&a_val, &b_val) in a.iter().zip(b) {
        sum += (a_val - b_val).abs();
    }
    sum
}
/// Manhattan (L1) distance between `a` and `b`, dispatched to the fastest kernel
/// available: AVX-512, then AVX2, then the scalar fallback on x86_64; NEON on
/// aarch64; the scalar fallback everywhere else.
///
/// # Panics
/// Panics if `a` and `b` have different lengths.
#[inline]
pub fn manhattan_distance_simd(a: &[f32], b: &[f32]) -> f32 {
    // The SIMD kernels read `b` through raw pointers for `a.len()` elements, so a
    // shorter `b` would be read out of bounds. Reject mismatches up front.
    assert_eq!(a.len(), b.len(), "Vector dimension mismatch");
    #[cfg(target_arch = "x86_64")]
    {
        if is_avx512_available() {
            // SAFETY: AVX-512F was detected at runtime and the lengths are equal.
            unsafe { manhattan_distance_avx512(a, b) }
        } else if is_avx2_available() {
            // SAFETY: AVX2 was detected at runtime and the lengths are equal.
            unsafe { manhattan_distance_avx2(a, b) }
        } else {
            manhattan_distance_autovec(a, b)
        }
    }
    #[cfg(target_arch = "aarch64")]
    {
        // SAFETY: NEON is always available on aarch64 and the lengths are equal.
        unsafe { manhattan_distance_neon(a, b) }
    }
    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
    {
        manhattan_distance_autovec(a, b)
    }
}
/// Scores `a` against `b` under `metric` so that a larger value means "closer".
///
/// Cosine and dot product are returned as-is; Euclidean and Manhattan distances
/// are negated.
pub fn compute_distance_simd(metric: DistanceMetric, a: &[f32], b: &[f32]) -> f32 {
    match metric {
        // Similarity-style metrics: already "higher is better".
        DistanceMetric::Cosine => cosine_similarity_simd(a, b),
        DistanceMetric::DotProduct => dot_product_simd(a, b),
        // Distance-style metrics: flip the sign.
        DistanceMetric::Euclidean => {
            let distance = euclidean_distance_simd(a, b);
            -distance
        }
        DistanceMetric::Manhattan => {
            let distance = manhattan_distance_simd(a, b);
            -distance
        }
    }
}
/// Scores `a` against `b` under `metric` so that a smaller value means "closer".
///
/// Euclidean and Manhattan distances are returned as-is; cosine similarity is
/// turned into `1.0 - similarity` and the dot product is negated.
#[inline]
pub fn compute_distance_lower_is_better_simd(metric: DistanceMetric, a: &[f32], b: &[f32]) -> f32 {
    match metric {
        // Distance-style metrics: already "lower is better".
        DistanceMetric::Euclidean => euclidean_distance_simd(a, b),
        DistanceMetric::Manhattan => manhattan_distance_simd(a, b),
        // Similarity-style metrics: invert them.
        DistanceMetric::Cosine => {
            let similarity = cosine_similarity_simd(a, b);
            1.0 - similarity
        }
        DistanceMetric::DotProduct => {
            let dot = dot_product_simd(a, b);
            -dot
        }
    }
}
/// Manhattan (L1) distance between two `u8`-quantized vectors.
///
/// Uses the AVX2 kernel on x86_64 when AVX2 is detected, the NEON kernel on
/// aarch64, and the scalar implementation otherwise.
///
/// # Panics
/// Panics if `a` and `b` have different lengths.
#[inline]
pub fn quantized_manhattan_distance_simd(a: &[u8], b: &[u8]) -> u32 {
    assert_eq!(a.len(), b.len(), "Vector dimension mismatch");

    #[cfg(target_arch = "x86_64")]
    {
        if is_avx2_available() {
            // SAFETY: AVX2 was detected at runtime and the lengths are equal.
            let distance = unsafe { quantized_manhattan_distance_avx2(a, b) };
            return distance;
        }
    }

    #[cfg(target_arch = "aarch64")]
    {
        // SAFETY: NEON is always available on aarch64 and the lengths are equal.
        let distance = unsafe { quantized_manhattan_distance_neon(a, b) };
        return distance;
    }

    // Portable fallback (unreachable on aarch64).
    quantized_manhattan_distance_scalar(a, b)
}
/// Dot product of two `u8`-quantized vectors.
///
/// Uses the AVX2 kernel on x86_64 when AVX2 is detected, the NEON kernel on
/// aarch64, and the scalar implementation otherwise.
///
/// # Panics
/// Panics if `a` and `b` have different lengths.
#[inline]
pub fn quantized_dot_product_simd(a: &[u8], b: &[u8]) -> u32 {
    assert_eq!(a.len(), b.len(), "Vector dimension mismatch");

    #[cfg(target_arch = "x86_64")]
    {
        if is_avx2_available() {
            // SAFETY: AVX2 was detected at runtime and the lengths are equal.
            let dot = unsafe { quantized_dot_product_avx2(a, b) };
            return dot;
        }
    }

    #[cfg(target_arch = "aarch64")]
    {
        // SAFETY: NEON is always available on aarch64 and the lengths are equal.
        let dot = unsafe { quantized_dot_product_neon(a, b) };
        return dot;
    }

    // Portable fallback (unreachable on aarch64).
    quantized_dot_product_scalar(a, b)
}
/// Squared Euclidean distance between two `u8`-quantized vectors.
///
/// Uses the AVX2 kernel on x86_64 when AVX2 is detected, the NEON kernel on
/// aarch64, and the scalar implementation otherwise.
///
/// # Panics
/// Panics if `a` and `b` have different lengths.
#[inline]
pub fn quantized_euclidean_squared_simd(a: &[u8], b: &[u8]) -> u32 {
    assert_eq!(a.len(), b.len(), "Vector dimension mismatch");

    #[cfg(target_arch = "x86_64")]
    {
        if is_avx2_available() {
            // SAFETY: AVX2 was detected at runtime and the lengths are equal.
            let squared = unsafe { quantized_euclidean_squared_avx2(a, b) };
            return squared;
        }
    }

    #[cfg(target_arch = "aarch64")]
    {
        // SAFETY: NEON is always available on aarch64 and the lengths are equal.
        let squared = unsafe { quantized_euclidean_squared_neon(a, b) };
        return squared;
    }

    // Portable fallback (unreachable on aarch64).
    quantized_euclidean_squared_scalar(a, b)
}
/// Scalar Manhattan (L1) distance between two `u8` vectors.
///
/// Pairs are taken element by element; extra elements of the longer slice are ignored.
#[inline]
fn quantized_manhattan_distance_scalar(a: &[u8], b: &[u8]) -> u32 {
    let mut total = 0u32;
    for (&x, &y) in a.iter().zip(b) {
        total += u32::from(x.abs_diff(y));
    }
    total
}
/// Scalar dot product of two `u8` vectors.
///
/// Pairs are taken element by element; extra elements of the longer slice are ignored.
#[inline]
fn quantized_dot_product_scalar(a: &[u8], b: &[u8]) -> u32 {
    let mut total = 0u32;
    for (&x, &y) in a.iter().zip(b) {
        total += u32::from(x) * u32::from(y);
    }
    total
}
/// Scalar squared Euclidean distance between two `u8` vectors.
///
/// Pairs are taken element by element; extra elements of the longer slice are ignored.
#[inline]
fn quantized_euclidean_squared_scalar(a: &[u8], b: &[u8]) -> u32 {
    let mut total = 0u32;
    for (&x, &y) in a.iter().zip(b) {
        let gap = u32::from(x.abs_diff(y));
        total += gap * gap;
    }
    total
}
/// Manhattan (L1) distance between two `u8` vectors using AVX2, 32 bytes per step,
/// with a scalar tail for the remaining elements.
///
/// # Safety
/// The CPU must support AVX2, and `b` must be at least as long as `a`
/// (the vector loop reads `b` through a raw pointer using `a.len()`).
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
#[inline]
unsafe fn quantized_manhattan_distance_avx2(a: &[u8], b: &[u8]) -> u32 {
    let len = a.len();
    let mut acc = _mm256_setzero_si256();
    let mut i = 0;
    while i + 32 <= len {
        let va = _mm256_loadu_si256(a.as_ptr().add(i) as *const __m256i);
        let vb = _mm256_loadu_si256(b.as_ptr().add(i) as *const __m256i);
        // `_mm256_sad_epu8` sums |a - b| over each group of 8 bytes into a 64-bit
        // lane. The previous code accumulated into 16-bit lanes, which gained up to
        // 510 per iteration and silently wrapped for inputs longer than ~4096 bytes.
        acc = _mm256_add_epi64(acc, _mm256_sad_epu8(va, vb));
        i += 32;
    }

    let mut lanes = [0u64; 4];
    _mm256_storeu_si256(lanes.as_mut_ptr() as *mut __m256i, acc);
    // Truncating cast: the function's return type is `u32`; the total only exceeds
    // it for inputs of more than ~16 million elements.
    let mut result = lanes.iter().sum::<u64>() as u32;

    // Scalar tail for the elements that do not fill a whole register.
    while i < len {
        result += u32::from(a[i].abs_diff(b[i]));
        i += 1;
    }
    result
}
/// Dot product of two `u8` vectors using AVX2, 16 bytes per step.
///
/// Each 16-byte group is widened to 16-bit lanes, multiplied pairwise and summed
/// into 32-bit lanes; leftover elements are handled with scalar arithmetic.
///
/// # Safety
/// The CPU must support AVX2, and `b` must be at least as long as `a`
/// (the vector loop reads `b` through a raw pointer using `a.len()`).
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
#[inline]
unsafe fn quantized_dot_product_avx2(a: &[u8], b: &[u8]) -> u32 {
    const STEP: usize = 16;
    let n = a.len();
    let (pa, pb) = (a.as_ptr(), b.as_ptr());

    let mut acc = _mm256_setzero_si256();
    let mut pos = 0;
    while n - pos >= STEP {
        let wide_a = _mm256_cvtepu8_epi16(_mm_loadu_si128(pa.add(pos).cast()));
        let wide_b = _mm256_cvtepu8_epi16(_mm_loadu_si128(pb.add(pos).cast()));
        // Multiply 16-bit lanes and add adjacent products into 32-bit lanes.
        acc = _mm256_add_epi32(acc, _mm256_madd_epi16(wide_a, wide_b));
        pos += STEP;
    }

    let mut lanes = [0u32; 8];
    _mm256_storeu_si256(lanes.as_mut_ptr().cast(), acc);
    let mut total: u32 = lanes.iter().sum();

    // Scalar tail for the elements that do not fill a whole step.
    while pos < n {
        total += u32::from(a[pos]) * u32::from(b[pos]);
        pos += 1;
    }
    total
}
/// Squared Euclidean distance between two `u8` vectors using AVX2, 16 bytes per step.
///
/// Each 16-byte group is widened to 16-bit lanes, subtracted, squared and summed
/// into 32-bit lanes; leftover elements are handled with scalar arithmetic.
///
/// # Safety
/// The CPU must support AVX2, and `b` must be at least as long as `a`
/// (the vector loop reads `b` through a raw pointer using `a.len()`).
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
#[inline]
unsafe fn quantized_euclidean_squared_avx2(a: &[u8], b: &[u8]) -> u32 {
    const STEP: usize = 16;
    let n = a.len();
    let (pa, pb) = (a.as_ptr(), b.as_ptr());

    let mut acc = _mm256_setzero_si256();
    let mut pos = 0;
    while n - pos >= STEP {
        let wide_a = _mm256_cvtepu8_epi16(_mm_loadu_si128(pa.add(pos).cast()));
        let wide_b = _mm256_cvtepu8_epi16(_mm_loadu_si128(pb.add(pos).cast()));
        let gap = _mm256_sub_epi16(wide_a, wide_b);
        // Square 16-bit lanes and add adjacent squares into 32-bit lanes.
        acc = _mm256_add_epi32(acc, _mm256_madd_epi16(gap, gap));
        pos += STEP;
    }

    let mut lanes = [0u32; 8];
    _mm256_storeu_si256(lanes.as_mut_ptr().cast(), acc);
    let mut total: u32 = lanes.iter().sum();

    // Scalar tail for the elements that do not fill a whole step.
    while pos < n {
        let gap = i32::from(a[pos]) - i32::from(b[pos]);
        total += (gap * gap) as u32;
        pos += 1;
    }
    total
}
/// Manhattan (L1) distance between two `u8` vectors using NEON, 16 bytes per step,
/// with a scalar tail for the remaining elements.
///
/// # Safety
/// `b` must be at least as long as `a` (the vector loop reads `b` through a
/// raw pointer using `a.len()`).
#[cfg(target_arch = "aarch64")]
#[inline]
unsafe fn quantized_manhattan_distance_neon(a: &[u8], b: &[u8]) -> u32 {
    const STEP: usize = 16;
    let n = a.len();
    let (pa, pb) = (a.as_ptr(), b.as_ptr());

    let mut acc = vdupq_n_u32(0);
    let mut pos = 0;
    while n - pos >= STEP {
        let gap = vabdq_u8(vld1q_u8(pa.add(pos)), vld1q_u8(pb.add(pos)));
        // Widen the byte differences to 16 bits, then accumulate into 32-bit lanes.
        let lower = vmovl_u8(vget_low_u8(gap));
        let upper = vmovl_u8(vget_high_u8(gap));
        acc = vaddw_u16(acc, vget_low_u16(lower));
        acc = vaddw_u16(acc, vget_high_u16(lower));
        acc = vaddw_u16(acc, vget_low_u16(upper));
        acc = vaddw_u16(acc, vget_high_u16(upper));
        pos += STEP;
    }

    // Scalar tail for the elements that do not fill a whole step.
    let mut total = vaddvq_u32(acc);
    while pos < n {
        total += (i32::from(a[pos]) - i32::from(b[pos])).unsigned_abs();
        pos += 1;
    }
    total
}
/// Dot product of two `u8` vectors using NEON, 8 bytes per step, with a scalar
/// tail for the remaining elements.
///
/// # Safety
/// `b` must be at least as long as `a` (the vector loop reads `b` through a
/// raw pointer using `a.len()`).
#[cfg(target_arch = "aarch64")]
#[inline]
unsafe fn quantized_dot_product_neon(a: &[u8], b: &[u8]) -> u32 {
    const STEP: usize = 8;
    let n = a.len();
    let (pa, pb) = (a.as_ptr(), b.as_ptr());

    let mut acc = vdupq_n_u32(0);
    let mut pos = 0;
    while n - pos >= STEP {
        // Widen both operands to 16 bits so the products fit in 32-bit lanes.
        let wide_a = vmovl_u8(vld1_u8(pa.add(pos)));
        let wide_b = vmovl_u8(vld1_u8(pb.add(pos)));
        acc = vaddq_u32(acc, vmull_u16(vget_low_u16(wide_a), vget_low_u16(wide_b)));
        acc = vaddq_u32(acc, vmull_u16(vget_high_u16(wide_a), vget_high_u16(wide_b)));
        pos += STEP;
    }

    // Scalar tail for the elements that do not fill a whole step.
    let mut total = vaddvq_u32(acc);
    while pos < n {
        total += u32::from(a[pos]) * u32::from(b[pos]);
        pos += 1;
    }
    total
}
/// Squared Euclidean distance between two `u8` vectors using NEON, 8 bytes per
/// step, with a scalar tail for the remaining elements.
///
/// # Safety
/// `b` must be at least as long as `a` (the vector loop reads `b` through a
/// raw pointer using `a.len()`).
#[cfg(target_arch = "aarch64")]
#[inline]
unsafe fn quantized_euclidean_squared_neon(a: &[u8], b: &[u8]) -> u32 {
    const STEP: usize = 8;
    let n = a.len();
    let (pa, pb) = (a.as_ptr(), b.as_ptr());

    let mut acc = vdupq_n_u32(0);
    let mut pos = 0;
    while n - pos >= STEP {
        // Absolute byte difference, widened to 16 bits before squaring.
        let gap = vmovl_u8(vabd_u8(vld1_u8(pa.add(pos)), vld1_u8(pb.add(pos))));
        let gap_lo = vget_low_u16(gap);
        let gap_hi = vget_high_u16(gap);
        acc = vaddq_u32(acc, vmull_u16(gap_lo, gap_lo));
        acc = vaddq_u32(acc, vmull_u16(gap_hi, gap_hi));
        pos += STEP;
    }

    // Scalar tail for the elements that do not fill a whole step.
    let mut total = vaddvq_u32(acc);
    while pos < n {
        let gap = i32::from(a[pos]) - i32::from(b[pos]);
        total += (gap * gap) as u32;
        pos += 1;
    }
    total
}
/// Scales `vec` in place to unit Euclidean length.
///
/// The vector is left untouched unless its norm is strictly greater than `1e-10`
/// (so zero vectors, and vectors whose norm evaluates to NaN, are not modified).
#[inline]
pub fn normalize_vector_simd(vec: &mut [f32]) {
    let norm = dot_product_simd(vec, vec).sqrt();
    // Written as a negated `>` on purpose: a NaN norm must also take the early exit.
    if !(norm > 1e-10) {
        return;
    }
    scale_vector_simd(vec, 1.0 / norm);
}
/// Multiplies every element of `vec` by `scalar` in place.
///
/// Dispatches to AVX-512, then AVX2 on x86_64 when detected, to NEON on aarch64,
/// and to a plain scalar loop otherwise.
#[inline]
pub fn scale_vector_simd(vec: &mut [f32], scalar: f32) {
    #[cfg(target_arch = "x86_64")]
    {
        if is_avx512_available() {
            // SAFETY: AVX-512F was detected at runtime.
            return unsafe { scale_vector_avx512(vec, scalar) };
        }
        if is_avx2_available() {
            // SAFETY: AVX2 was detected at runtime.
            return unsafe { scale_vector_avx2(vec, scalar) };
        }
    }

    #[cfg(target_arch = "aarch64")]
    {
        // SAFETY: NEON is always available on aarch64.
        return unsafe { scale_vector_neon(vec, scalar) };
    }

    // Portable fallback (unreachable on aarch64).
    vec.iter_mut().for_each(|value| *value *= scalar);
}
/// Multiplies every element of `vec` by `scalar` in place, 16 lanes per step
/// with AVX-512, then a scalar loop for the leftover elements.
///
/// # Safety
/// The running CPU must support AVX-512F.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
#[inline]
unsafe fn scale_vector_avx512(vec: &mut [f32], scalar: f32) {
    const LANES: usize = 16;
    let factor = _mm512_set1_ps(scalar);

    let mut blocks = vec.chunks_exact_mut(LANES);
    for block in &mut blocks {
        // Every block is exactly LANES long, so the full-width load/store is in bounds.
        let p = block.as_mut_ptr();
        _mm512_storeu_ps(p, _mm512_mul_ps(_mm512_loadu_ps(p), factor));
    }
    for value in blocks.into_remainder() {
        *value *= scalar;
    }
}
/// Multiplies every element of `vec` by `scalar` in place, 8 lanes per step
/// with AVX2, then a scalar loop for the leftover elements.
///
/// # Safety
/// The running CPU must support AVX2.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
#[inline]
unsafe fn scale_vector_avx2(vec: &mut [f32], scalar: f32) {
    const LANES: usize = 8;
    let factor = _mm256_set1_ps(scalar);

    let mut blocks = vec.chunks_exact_mut(LANES);
    for block in &mut blocks {
        // Every block is exactly LANES long, so the full-width load/store is in bounds.
        let p = block.as_mut_ptr();
        _mm256_storeu_ps(p, _mm256_mul_ps(_mm256_loadu_ps(p), factor));
    }
    for value in blocks.into_remainder() {
        *value *= scalar;
    }
}
/// Multiplies every element of `vec` by `scalar` in place, 4 lanes per step
/// with NEON, then a scalar loop for the leftover elements.
///
/// # Safety
/// Uses NEON intrinsics; must only be compiled/called where NEON is available.
#[cfg(target_arch = "aarch64")]
#[inline]
unsafe fn scale_vector_neon(vec: &mut [f32], scalar: f32) {
    const LANES: usize = 4;
    let factor = vdupq_n_f32(scalar);

    let mut blocks = vec.chunks_exact_mut(LANES);
    for block in &mut blocks {
        // Every block is exactly LANES long, so the full-width load/store is in bounds.
        let p = block.as_mut_ptr();
        vst1q_f32(p, vmulq_f32(vld1q_f32(p), factor));
    }
    for value in blocks.into_remainder() {
        *value *= scalar;
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Relative difference between `a` and `b`, with the denominator clamped to at
    /// least 1.0 so that small magnitudes are compared absolutely.
    fn relative_error(a: f32, b: f32) -> f32 {
        (a - b).abs() / a.max(b).max(1.0)
    }

    /// Checks that every dispatched float kernel agrees with its scalar fallback
    /// on a pair of `dim`-element ramp vectors.
    ///
    /// Shared by the per-ISA correctness tests, which previously repeated this body
    /// verbatim and differed only in the vector length.
    fn assert_simd_matches_autovec(dim: usize) {
        let v1: Vec<f32> = (0..dim).map(|i| (i as f32) * 0.01).collect();
        let v2: Vec<f32> = (0..dim).map(|i| (i as f32) * 0.02).collect();

        let cosine = cosine_similarity_simd(&v1, &v2);
        let euclidean = euclidean_distance_simd(&v1, &v2);
        let dot = dot_product_simd(&v1, &v2);
        let manhattan = manhattan_distance_simd(&v1, &v2);
        assert!(cosine > 0.0 && cosine <= 1.0);
        assert!(euclidean > 0.0);
        assert!(dot > 0.0);
        assert!(manhattan > 0.0);

        let cosine_autovec = cosine_similarity_autovec(&v1, &v2);
        let euclidean_autovec = euclidean_distance_autovec(&v1, &v2);
        let dot_autovec = dot_product_autovec(&v1, &v2);
        let manhattan_autovec = manhattan_distance_autovec(&v1, &v2);
        assert!(relative_error(cosine, cosine_autovec) < 1e-5);
        assert!(relative_error(euclidean, euclidean_autovec) < 1e-5);
        assert!(relative_error(dot, dot_autovec) < 1e-5);
        assert!(relative_error(manhattan, manhattan_autovec) < 1e-5);
    }

    #[test]
    fn test_cosine_similarity_simd() {
        // Identical vectors have similarity 1.
        let v1 = vec![1.0, 0.0, 0.0];
        let v2 = vec![1.0, 0.0, 0.0];
        let sim = cosine_similarity_simd(&v1, &v2);
        assert!((sim - 1.0).abs() < 1e-6);

        // Orthogonal vectors have similarity 0.
        let v1 = vec![1.0, 0.0, 0.0];
        let v2 = vec![0.0, 1.0, 0.0];
        let sim = cosine_similarity_simd(&v1, &v2);
        assert!(sim.abs() < 1e-6);
    }

    #[test]
    fn test_cosine_similarity_simd_large() {
        let v1: Vec<f32> = (0..100).map(|i| i as f32).collect();
        let v2: Vec<f32> = (0..100).map(|i| (i + 1) as f32).collect();
        let sim = cosine_similarity_simd(&v1, &v2);
        assert!(sim > 0.99);
    }

    #[test]
    fn test_euclidean_distance_simd() {
        let v1 = vec![0.0, 0.0, 0.0];
        let v2 = vec![3.0, 4.0, 0.0];
        let dist = euclidean_distance_simd(&v1, &v2);
        assert!((dist - 5.0).abs() < 1e-6);
    }

    #[test]
    fn test_euclidean_distance_simd_large() {
        let v1 = vec![0.0; 100];
        let v2 = vec![1.0; 100];
        let dist = euclidean_distance_simd(&v1, &v2);
        assert!((dist - 10.0).abs() < 1e-6);
    }

    #[test]
    fn test_dot_product_simd() {
        let v1 = vec![1.0, 2.0, 3.0];
        let v2 = vec![4.0, 5.0, 6.0];
        let dot = dot_product_simd(&v1, &v2);
        assert!((dot - 32.0).abs() < 1e-6);
    }

    #[test]
    fn test_dot_product_simd_large() {
        let v1: Vec<f32> = (1..=100).map(|i| i as f32).collect();
        let v2: Vec<f32> = (1..=100).map(|i| i as f32).collect();
        let dot = dot_product_simd(&v1, &v2);
        let expected: f32 = (1..=100).map(|i| (i * i) as f32).sum();
        assert!((dot - expected).abs() < 1e-3);
    }

    #[test]
    fn test_manhattan_distance_simd() {
        let v1 = vec![1.0, 2.0, 3.0];
        let v2 = vec![4.0, 5.0, 6.0];
        let dist = manhattan_distance_simd(&v1, &v2);
        assert!((dist - 9.0).abs() < 1e-6);
    }

    #[test]
    fn test_manhattan_distance_simd_large() {
        let v1 = vec![0.0; 100];
        let v2 = vec![1.0; 100];
        let dist = manhattan_distance_simd(&v1, &v2);
        assert!((dist - 100.0).abs() < 1e-6);
    }

    #[test]
    fn test_compute_distance_simd() {
        let v1 = vec![1.0, 0.0, 0.0];
        let v2 = vec![1.0, 0.0, 0.0];
        let sim = compute_distance_simd(DistanceMetric::Cosine, &v1, &v2);
        assert!((sim - 1.0).abs() < 1e-6);
        let dist = compute_distance_simd(DistanceMetric::Euclidean, &v1, &v2);
        assert!(dist.abs() < 1e-6);
        let dot = compute_distance_simd(DistanceMetric::DotProduct, &v1, &v2);
        assert!((dot - 1.0).abs() < 1e-6);
        let manhattan = compute_distance_simd(DistanceMetric::Manhattan, &v1, &v2);
        assert!(manhattan.abs() < 1e-6);
    }

    #[test]
    fn test_is_avx2_available() {
        // Must not panic on any platform; the value itself is hardware dependent.
        let _available = is_avx2_available();
        #[cfg(not(target_arch = "x86_64"))]
        assert!(!is_avx2_available());
    }

    #[test]
    fn test_is_neon_available() {
        let available = is_neon_available();
        #[cfg(target_arch = "aarch64")]
        assert!(available, "NEON should always be available on aarch64");
        #[cfg(not(target_arch = "aarch64"))]
        assert!(!available, "NEON should not be available on non-aarch64");
    }

    #[test]
    fn test_is_avx512_available() {
        // Must not panic on any platform; the value itself is hardware dependent.
        let _available = is_avx512_available();
        #[cfg(not(target_arch = "x86_64"))]
        assert!(!is_avx512_available());
    }

    #[test]
    fn test_avx2_correctness() {
        assert_simd_matches_autovec(768);
    }

    #[test]
    fn test_neon_correctness() {
        assert_simd_matches_autovec(768);
    }

    #[test]
    fn test_avx512_correctness() {
        assert_simd_matches_autovec(1024);
    }

    #[test]
    fn test_quantized_manhattan_distance() {
        let a = vec![10u8, 20, 30, 40, 50, 60, 70, 80];
        let b = vec![15u8, 25, 35, 45, 55, 65, 75, 85];
        let distance_simd = quantized_manhattan_distance_simd(&a, &b);
        let distance_scalar = quantized_manhattan_distance_scalar(&a, &b);
        assert_eq!(distance_simd, distance_scalar);
        assert_eq!(distance_simd, 40);
    }

    #[test]
    fn test_quantized_manhattan_distance_large() {
        let a: Vec<u8> = (0..768).map(|i| (i % 256) as u8).collect();
        let b: Vec<u8> = (0..768).map(|i| ((i + 10) % 256) as u8).collect();
        let distance_simd = quantized_manhattan_distance_simd(&a, &b);
        let distance_scalar = quantized_manhattan_distance_scalar(&a, &b);
        assert_eq!(distance_simd, distance_scalar);
    }

    #[test]
    fn test_quantized_dot_product() {
        let a = vec![1u8, 2, 3, 4, 5, 6, 7, 8];
        let b = vec![8u8, 7, 6, 5, 4, 3, 2, 1];
        let dot_simd = quantized_dot_product_simd(&a, &b);
        let dot_scalar = quantized_dot_product_scalar(&a, &b);
        assert_eq!(dot_simd, dot_scalar);
        assert_eq!(dot_simd, 120);
    }

    #[test]
    fn test_quantized_dot_product_large() {
        let a: Vec<u8> = (0..768).map(|i| (i % 256) as u8).collect();
        let b: Vec<u8> = (0..768).map(|i| ((255 - i) % 256) as u8).collect();
        let dot_simd = quantized_dot_product_simd(&a, &b);
        let dot_scalar = quantized_dot_product_scalar(&a, &b);
        assert_eq!(dot_simd, dot_scalar);
    }

    #[test]
    fn test_quantized_euclidean_squared() {
        let a = vec![10u8, 20, 30, 40];
        let b = vec![13u8, 24, 27, 45];
        let dist_simd = quantized_euclidean_squared_simd(&a, &b);
        let dist_scalar = quantized_euclidean_squared_scalar(&a, &b);
        assert_eq!(dist_simd, dist_scalar);
        assert_eq!(dist_simd, 59);
    }

    #[test]
    fn test_quantized_euclidean_squared_large() {
        let a: Vec<u8> = (0..768).map(|i| (i % 256) as u8).collect();
        let b: Vec<u8> = (0..768).map(|i| ((i + 5) % 256) as u8).collect();
        let dist_simd = quantized_euclidean_squared_simd(&a, &b);
        let dist_scalar = quantized_euclidean_squared_scalar(&a, &b);
        assert_eq!(dist_simd, dist_scalar);
    }

    #[test]
    fn test_quantized_edge_cases() {
        // Identical inputs: every distance is zero.
        let a = vec![100u8; 100];
        let b = vec![100u8; 100];
        assert_eq!(quantized_manhattan_distance_simd(&a, &b), 0);
        assert_eq!(quantized_euclidean_squared_simd(&a, &b), 0);

        // Maximum per-element difference.
        let c = vec![0u8; 100];
        let d = vec![255u8; 100];
        assert_eq!(quantized_manhattan_distance_simd(&c, &d), 255 * 100);
        assert_eq!(quantized_euclidean_squared_simd(&c, &d), 255 * 255 * 100);
    }

    #[test]
    fn test_quantized_simd_correctness() {
        let a: Vec<u8> = (0..1024).map(|i| ((i * 17 + 42) % 256) as u8).collect();
        let b: Vec<u8> = (0..1024).map(|i| ((i * 23 + 99) % 256) as u8).collect();
        let manhattan_simd = quantized_manhattan_distance_simd(&a, &b);
        let manhattan_scalar = quantized_manhattan_distance_scalar(&a, &b);
        assert_eq!(manhattan_simd, manhattan_scalar);
        let dot_simd = quantized_dot_product_simd(&a, &b);
        let dot_scalar = quantized_dot_product_scalar(&a, &b);
        assert_eq!(dot_simd, dot_scalar);
        let euclidean_simd = quantized_euclidean_squared_simd(&a, &b);
        let euclidean_scalar = quantized_euclidean_squared_scalar(&a, &b);
        assert_eq!(euclidean_simd, euclidean_scalar);
    }

    #[test]
    fn test_normalize_vector_simd() {
        let mut vec = vec![3.0, 4.0, 0.0];
        normalize_vector_simd(&mut vec);
        assert!((vec[0] - 0.6).abs() < 1e-6);
        assert!((vec[1] - 0.8).abs() < 1e-6);
        assert!((vec[2] - 0.0).abs() < 1e-6);
        let norm_squared: f32 = vec.iter().map(|x| x * x).sum();
        assert!((norm_squared - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_normalize_vector_simd_large() {
        let mut vec: Vec<f32> = (0..768).map(|i| (i % 100) as f32).collect();
        normalize_vector_simd(&mut vec);
        let norm_squared: f32 = vec.iter().map(|x| x * x).sum();
        assert!((norm_squared - 1.0).abs() < 1e-5);
    }

    #[test]
    fn test_normalize_vector_simd_zero() {
        // A zero vector must be left untouched rather than divided by zero.
        let mut vec = vec![0.0, 0.0, 0.0];
        normalize_vector_simd(&mut vec);
        assert_eq!(vec, vec![0.0, 0.0, 0.0]);
    }

    #[test]
    fn test_scale_vector_simd() {
        let mut vec = vec![1.0, 2.0, 3.0, 4.0];
        scale_vector_simd(&mut vec, 2.0);
        assert_eq!(vec, vec![2.0, 4.0, 6.0, 8.0]);
    }

    #[test]
    fn test_scale_vector_simd_large() {
        let mut vec: Vec<f32> = (0..1024).map(|i| i as f32).collect();
        scale_vector_simd(&mut vec, 0.5);
        for (i, &value) in vec.iter().enumerate() {
            assert!((value - (i as f32 * 0.5)).abs() < 1e-5);
        }
    }
}