use super::distance::DistanceMetric;
/// SIMD capability tiers that the distance kernels in this module dispatch on.
///
/// Variants are listed from least to most capable.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SimdLevel {
    /// No explicit SIMD intrinsics; plain Rust loops are used.
    Scalar,
    /// 128-bit SSE kernels (four `f32` lanes per operation).
    Sse,
    /// 256-bit AVX kernels (eight `f32` lanes per operation).
    Avx,
    /// AVX kernels that additionally use fused multiply-add.
    AvxFma,
}

impl SimdLevel {
    /// Probes the running CPU and returns the most capable level it supports.
    ///
    /// On targets other than `x86_64` this always returns [`SimdLevel::Scalar`].
    #[inline]
    pub fn detect() -> Self {
        #[cfg(target_arch = "x86_64")]
        {
            if is_x86_feature_detected!("avx") {
                // FMA is only ever used together with AVX, so it is probed second.
                return if is_x86_feature_detected!("fma") {
                    SimdLevel::AvxFma
                } else {
                    SimdLevel::Avx
                };
            }
            if is_x86_feature_detected!("sse") {
                return SimdLevel::Sse;
            }
        }
        SimdLevel::Scalar
    }
}
/// Process-wide cache of the detected SIMD level, filled on first use.
static SIMD_LEVEL: std::sync::OnceLock<SimdLevel> = std::sync::OnceLock::new();

/// Returns the SIMD level of the running CPU.
///
/// Detection runs at most once; later calls return the cached value.
#[inline]
pub fn simd_level() -> SimdLevel {
    let cached: &SimdLevel = SIMD_LEVEL.get_or_init(SimdLevel::detect);
    *cached
}
/// Squared Euclidean (L2) distance between `a` and `b`, computed with the
/// kernel that matches the CPU level reported by [`simd_level`].
///
/// # Panics
///
/// Panics if the two slices differ in length.
#[inline]
pub fn l2_squared_simd(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len(), "Vector dimensions must match");

    #[cfg(target_arch = "x86_64")]
    {
        // SAFETY: each kernel is chosen only when runtime detection reported the
        // level it is compiled for, and the assertion above guarantees that both
        // slices have the same length.
        match simd_level() {
            SimdLevel::AvxFma => return unsafe { l2_squared_avx_fma(a, b) },
            SimdLevel::Avx => return unsafe { l2_squared_avx(a, b) },
            SimdLevel::Sse => return unsafe { l2_squared_sse(a, b) },
            SimdLevel::Scalar => {}
        }
    }

    l2_squared_scalar(a, b)
}
/// Portable implementation of the squared L2 distance.
///
/// Processes `a.len()` element pairs in order; `b` must be at least that long,
/// otherwise the function panics.
#[inline]
fn l2_squared_scalar(a: &[f32], b: &[f32]) -> f32 {
    // Slice `b` up front so a too-short `b` panics instead of being truncated by `zip`.
    let b = &b[..a.len()];
    a.iter().zip(b).fold(0.0f32, |acc, (&x, &y)| {
        let delta = x - y;
        acc + delta * delta
    })
}
/// Squared L2 distance using 128-bit SSE arithmetic, four `f32` lanes at a time.
///
/// # Safety
///
/// The CPU must support SSE, and `b` must contain at least `a.len()` elements
/// because the vector loop reads `b` through unchecked raw pointers.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "sse")]
unsafe fn l2_squared_sse(a: &[f32], b: &[f32]) -> f32 {
    use std::arch::x86_64::*;

    const LANES: usize = 4;
    let n = a.len();
    // Index of the first element that does not fit into a full 4-lane group.
    let vector_end = n - n % LANES;
    let (pa, pb) = (a.as_ptr(), b.as_ptr());

    let mut acc = _mm_setzero_ps();
    let mut offset = 0;
    while offset < vector_end {
        // Unaligned loads: slices carry no 16-byte alignment guarantee.
        let delta = _mm_sub_ps(_mm_loadu_ps(pa.add(offset)), _mm_loadu_ps(pb.add(offset)));
        acc = _mm_add_ps(acc, _mm_mul_ps(delta, delta));
        offset += LANES;
    }

    // Collapse the four partial sums, then add the leftover elements one by one.
    let mut total = horizontal_sum_sse(acc);
    for (&x, &y) in a[vector_end..].iter().zip(&b[vector_end..n]) {
        let delta = x - y;
        total += delta * delta;
    }
    total
}
/// Adds the four `f32` lanes of `v` and returns the scalar total.
///
/// Lanes are combined pairwise as `(v0 + v1) + (v2 + v3)`.
///
/// # Safety
///
/// The CPU must support SSE.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "sse")]
#[inline]
unsafe fn horizontal_sum_sse(v: std::arch::x86_64::__m128) -> f32 {
    use std::arch::x86_64::*;

    // Swap neighbouring lanes: [v1, v0, v3, v2]. The SSE1 shuffle is used on
    // purpose instead of `_mm_movehdup_ps`: that intrinsic is an SSE3
    // instruction, which the `sse` feature this function is compiled for (and
    // dispatched on) does not guarantee.
    let swapped = _mm_shuffle_ps::<0b10_11_00_01>(v, v);
    // pair_sums = [v0+v1, v1+v0, v2+v3, v3+v2]
    let pair_sums = _mm_add_ps(v, swapped);
    // Move the upper pair sum (lane 2) down to lane 0.
    let upper = _mm_movehl_ps(pair_sums, pair_sums);
    // Only lane 0 matters from here on: (v0+v1) + (v2+v3).
    let total = _mm_add_ss(pair_sums, upper);
    _mm_cvtss_f32(total)
}
/// Squared L2 distance using 256-bit AVX arithmetic, eight `f32` lanes at a time.
///
/// # Safety
///
/// The CPU must support AVX, and `b` must contain at least `a.len()` elements
/// because the vector loop reads `b` through unchecked raw pointers.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx")]
unsafe fn l2_squared_avx(a: &[f32], b: &[f32]) -> f32 {
    use std::arch::x86_64::*;

    const LANES: usize = 8;
    let n = a.len();
    // Index of the first element that does not fit into a full 8-lane group.
    let vector_end = n - n % LANES;
    let (pa, pb) = (a.as_ptr(), b.as_ptr());

    let mut acc = _mm256_setzero_ps();
    let mut offset = 0;
    while offset < vector_end {
        // Unaligned loads: slices carry no 32-byte alignment guarantee.
        let delta = _mm256_sub_ps(
            _mm256_loadu_ps(pa.add(offset)),
            _mm256_loadu_ps(pb.add(offset)),
        );
        acc = _mm256_add_ps(acc, _mm256_mul_ps(delta, delta));
        offset += LANES;
    }

    // Collapse the eight partial sums, then add the leftover elements one by one.
    let mut total = horizontal_sum_avx(acc);
    for (&x, &y) in a[vector_end..].iter().zip(&b[vector_end..n]) {
        let delta = x - y;
        total += delta * delta;
    }
    total
}
/// Adds the eight `f32` lanes of `v` and returns the scalar total.
///
/// # Safety
///
/// The CPU must support AVX.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx")]
#[inline]
unsafe fn horizontal_sum_avx(v: std::arch::x86_64::__m256) -> f32 {
    use std::arch::x86_64::*;

    // Fold the upper 128-bit half onto the lower one, then reuse the 4-lane reduction.
    let upper_half = _mm256_extractf128_ps::<1>(v);
    let lower_half = _mm256_castps256_ps128(v);
    horizontal_sum_sse(_mm_add_ps(upper_half, lower_half))
}
/// Squared L2 distance using 256-bit AVX arithmetic with fused multiply-add.
///
/// # Safety
///
/// The CPU must support AVX and FMA, and `b` must contain at least `a.len()`
/// elements because the vector loop reads `b` through unchecked raw pointers.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx", enable = "fma")]
unsafe fn l2_squared_avx_fma(a: &[f32], b: &[f32]) -> f32 {
    use std::arch::x86_64::*;

    const LANES: usize = 8;
    let n = a.len();
    // Index of the first element that does not fit into a full 8-lane group.
    let vector_end = n - n % LANES;
    let (pa, pb) = (a.as_ptr(), b.as_ptr());

    let mut acc = _mm256_setzero_ps();
    let mut offset = 0;
    while offset < vector_end {
        let delta = _mm256_sub_ps(
            _mm256_loadu_ps(pa.add(offset)),
            _mm256_loadu_ps(pb.add(offset)),
        );
        // acc = delta * delta + acc in a single fused operation.
        acc = _mm256_fmadd_ps(delta, delta, acc);
        offset += LANES;
    }

    // Collapse the eight partial sums, then add the leftover elements one by one.
    let mut total = horizontal_sum_avx(acc);
    for (&x, &y) in a[vector_end..].iter().zip(&b[vector_end..n]) {
        let delta = x - y;
        total += delta * delta;
    }
    total
}
/// Dot product of `a` and `b`, computed with the kernel that matches the CPU
/// level reported by [`simd_level`].
///
/// # Panics
///
/// Panics if the two slices differ in length.
#[inline]
pub fn dot_product_simd(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len(), "Vector dimensions must match");

    #[cfg(target_arch = "x86_64")]
    {
        // SAFETY: each kernel is chosen only when runtime detection reported the
        // level it is compiled for, and the assertion above guarantees that both
        // slices have the same length.
        match simd_level() {
            SimdLevel::AvxFma => return unsafe { dot_product_avx_fma(a, b) },
            SimdLevel::Avx => return unsafe { dot_product_avx(a, b) },
            SimdLevel::Sse => return unsafe { dot_product_sse(a, b) },
            SimdLevel::Scalar => {}
        }
    }

    dot_product_scalar(a, b)
}
/// Portable implementation of the dot product.
///
/// Processes `a.len()` element pairs in order; `b` must be at least that long,
/// otherwise the function panics.
#[inline]
fn dot_product_scalar(a: &[f32], b: &[f32]) -> f32 {
    // Slice `b` up front so a too-short `b` panics instead of being truncated by `zip`.
    let b = &b[..a.len()];
    a.iter().zip(b).fold(0.0f32, |acc, (&x, &y)| acc + x * y)
}
/// Dot product using 128-bit SSE arithmetic, four `f32` lanes at a time.
///
/// # Safety
///
/// The CPU must support SSE, and `b` must contain at least `a.len()` elements
/// because the vector loop reads `b` through unchecked raw pointers.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "sse")]
unsafe fn dot_product_sse(a: &[f32], b: &[f32]) -> f32 {
    use std::arch::x86_64::*;

    const LANES: usize = 4;
    let n = a.len();
    // Index of the first element that does not fit into a full 4-lane group.
    let vector_end = n - n % LANES;
    let (pa, pb) = (a.as_ptr(), b.as_ptr());

    let mut acc = _mm_setzero_ps();
    let mut offset = 0;
    while offset < vector_end {
        // Unaligned loads: slices carry no 16-byte alignment guarantee.
        let products = _mm_mul_ps(_mm_loadu_ps(pa.add(offset)), _mm_loadu_ps(pb.add(offset)));
        acc = _mm_add_ps(acc, products);
        offset += LANES;
    }

    // Collapse the four partial sums, then add the leftover elements one by one.
    let mut total = horizontal_sum_sse(acc);
    for (&x, &y) in a[vector_end..].iter().zip(&b[vector_end..n]) {
        total += x * y;
    }
    total
}
/// Dot product using 256-bit AVX arithmetic, eight `f32` lanes at a time.
///
/// # Safety
///
/// The CPU must support AVX, and `b` must contain at least `a.len()` elements
/// because the vector loop reads `b` through unchecked raw pointers.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx")]
unsafe fn dot_product_avx(a: &[f32], b: &[f32]) -> f32 {
    use std::arch::x86_64::*;

    const LANES: usize = 8;
    let n = a.len();
    // Index of the first element that does not fit into a full 8-lane group.
    let vector_end = n - n % LANES;
    let (pa, pb) = (a.as_ptr(), b.as_ptr());

    let mut acc = _mm256_setzero_ps();
    let mut offset = 0;
    while offset < vector_end {
        // Unaligned loads: slices carry no 32-byte alignment guarantee.
        let products = _mm256_mul_ps(
            _mm256_loadu_ps(pa.add(offset)),
            _mm256_loadu_ps(pb.add(offset)),
        );
        acc = _mm256_add_ps(acc, products);
        offset += LANES;
    }

    // Collapse the eight partial sums, then add the leftover elements one by one.
    let mut total = horizontal_sum_avx(acc);
    for (&x, &y) in a[vector_end..].iter().zip(&b[vector_end..n]) {
        total += x * y;
    }
    total
}
/// Dot product using 256-bit AVX arithmetic with fused multiply-add.
///
/// # Safety
///
/// The CPU must support AVX and FMA, and `b` must contain at least `a.len()`
/// elements because the vector loop reads `b` through unchecked raw pointers.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx", enable = "fma")]
unsafe fn dot_product_avx_fma(a: &[f32], b: &[f32]) -> f32 {
    use std::arch::x86_64::*;

    const LANES: usize = 8;
    let n = a.len();
    // Index of the first element that does not fit into a full 8-lane group.
    let vector_end = n - n % LANES;
    let (pa, pb) = (a.as_ptr(), b.as_ptr());

    let mut acc = _mm256_setzero_ps();
    let mut offset = 0;
    while offset < vector_end {
        // acc = a * b + acc in a single fused operation.
        acc = _mm256_fmadd_ps(
            _mm256_loadu_ps(pa.add(offset)),
            _mm256_loadu_ps(pb.add(offset)),
            acc,
        );
        offset += LANES;
    }

    // Collapse the eight partial sums, then add the leftover elements one by one.
    let mut total = horizontal_sum_avx(acc);
    for (&x, &y) in a[vector_end..].iter().zip(&b[vector_end..n]) {
        total += x * y;
    }
    total
}
/// Euclidean (L2) norm of `v`: the square root of its dot product with itself.
#[inline]
pub fn l2_norm_simd(v: &[f32]) -> f32 {
    let squared_norm = dot_product_simd(v, v);
    squared_norm.sqrt()
}
/// Cosine distance between `a` and `b`, defined as `1 - cosine similarity`.
///
/// The similarity is clamped to `[-1, 1]` before it is subtracted. If either
/// vector has a norm of exactly zero, the distance is `1.0`.
///
/// # Panics
///
/// Panics if the two slices differ in length.
#[inline]
pub fn cosine_distance_simd(a: &[f32], b: &[f32]) -> f32 {
    // Computed first so that mismatched lengths are rejected before anything else.
    let dot = dot_product_simd(a, b);
    let magnitude_a = dot_product_simd(a, a).sqrt();
    let magnitude_b = dot_product_simd(b, b).sqrt();

    if magnitude_a != 0.0 && magnitude_b != 0.0 {
        // Clamping absorbs rounding that would push the ratio just outside [-1, 1].
        let cosine = (dot / (magnitude_a * magnitude_b)).clamp(-1.0, 1.0);
        1.0 - cosine
    } else {
        // The angle is undefined for a zero vector.
        1.0
    }
}
/// Inner-product distance: the negated dot product, so that a larger
/// similarity yields a smaller distance.
///
/// # Panics
///
/// Panics if the two slices differ in length.
#[inline]
pub fn inner_product_distance_simd(a: &[f32], b: &[f32]) -> f32 {
    let similarity = dot_product_simd(a, b);
    -similarity
}
/// Computes the distance between `a` and `b` under the requested `metric`.
///
/// # Panics
///
/// Panics if the two slices differ in length (every metric below ends up in a
/// kernel dispatcher that asserts equal lengths).
#[inline]
pub fn distance_simd(a: &[f32], b: &[f32], metric: DistanceMetric) -> f32 {
match metric {
// Squared Euclidean distance; no square root is taken.
DistanceMetric::L2 => l2_squared_simd(a, b),
// 1 - cosine similarity.
DistanceMetric::Cosine => cosine_distance_simd(a, b),
// Negated dot product.
DistanceMetric::InnerProduct => inner_product_distance_simd(a, b),
}
}
/// Returns the `top_k` targets closest to `query`, nearest first.
///
/// Each entry is `(index into targets, distance)`. Entries are ordered by
/// ascending distance; equal distances are ordered by ascending index, and NaN
/// distances sort after every number. When `top_k` is at least
/// `targets.len()`, every target is returned.
///
/// # Panics
///
/// Panics if any target's length differs from `query.len()`.
pub fn batch_distances(
    query: &[f32],
    targets: &[Vec<f32>],
    metric: DistanceMetric,
    top_k: usize,
) -> Vec<(usize, f32)> {
    /// Strict total order on `(index, distance)` pairs: by distance, NaN last,
    /// ties broken by index.
    ///
    /// Treating a NaN comparison as "equal" is not a total order (it can
    /// produce cycles once the index tie-break kicks in), and the slice
    /// selection/sorting routines may panic or return an unspecified order when
    /// handed such a comparator.
    fn by_distance_then_index(a: &(usize, f32), b: &(usize, f32)) -> std::cmp::Ordering {
        a.1.partial_cmp(&b.1)
            // `partial_cmp` is `None` only when at least one side is NaN.
            .unwrap_or_else(|| a.1.is_nan().cmp(&b.1.is_nan()))
            .then_with(|| a.0.cmp(&b.0))
    }

    let mut results: Vec<(usize, f32)> = targets
        .iter()
        .enumerate()
        .map(|(i, target)| (i, distance_simd(query, target, metric)))
        .collect();

    if top_k < results.len() {
        // Partition so the `top_k` smallest entries occupy the front, then drop
        // the rest; this avoids sorting the whole candidate list.
        results.select_nth_unstable_by(top_k, by_distance_then_index);
        results.truncate(top_k);
    }

    // Indices are unique, so no two entries compare equal and an unstable sort
    // yields exactly the same order as a stable one without allocating.
    results.sort_unstable_by(by_distance_then_index);
    results
}
// Unit tests for the SIMD distance kernels. Every test goes through the public
// dispatchers, so the kernel actually exercised depends on the CPU running the
// tests.
#[cfg(test)]
mod tests {
use super::*;
// Smoke test: detection completes and yields one of the known levels.
#[test]
fn test_simd_level_detection() {
let level = simd_level();
println!("Detected SIMD level: {:?}", level);
assert!(matches!(
level,
SimdLevel::Scalar | SimdLevel::Sse | SimdLevel::Avx | SimdLevel::AvxFma
));
}
// Identical vectors are at distance zero.
#[test]
fn test_l2_squared_simd_identical() {
let a = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
let b = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
assert!((l2_squared_simd(&a, &b) - 0.0).abs() < 1e-6);
}
// A unit step along one axis gives a squared distance of 1.
#[test]
fn test_l2_squared_simd_simple() {
let a = vec![0.0, 0.0, 0.0, 0.0];
let b = vec![1.0, 0.0, 0.0, 0.0];
assert!((l2_squared_simd(&a, &b) - 1.0).abs() < 1e-6);
}
// The dispatched kernel must agree with the scalar implementation on 256 elements.
#[test]
fn test_l2_squared_simd_vs_scalar() {
let a: Vec<f32> = (0..256).map(|i| i as f32 * 0.1).collect();
let b: Vec<f32> = (0..256).map(|i| (i + 1) as f32 * 0.1).collect();
let simd_result = l2_squared_simd(&a, &b);
let scalar_result = l2_squared_scalar(&a, &b);
assert!(
(simd_result - scalar_result).abs() < 1e-3,
"SIMD: {}, Scalar: {}",
simd_result,
scalar_result
);
}
// Dot product with an all-ones vector is the sum of the elements: 1 + 2 + ... + 8 = 36.
#[test]
fn test_dot_product_simd() {
let a = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
let b = vec![1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0];
let result = dot_product_simd(&a, &b);
assert!((result - 36.0).abs() < 1e-6); }
// The dispatched kernel must agree with the scalar implementation; the
// tolerance is absolute, not relative.
#[test]
fn test_dot_product_simd_vs_scalar() {
let a: Vec<f32> = (0..256).map(|i| i as f32 * 0.1).collect();
let b: Vec<f32> = (0..256).map(|i| (i + 1) as f32 * 0.1).collect();
let simd_result = dot_product_simd(&a, &b);
let scalar_result = dot_product_scalar(&a, &b);
assert!(
(simd_result - scalar_result).abs() < 1.0, "SIMD: {}, Scalar: {}",
simd_result,
scalar_result
);
}
// 3-4-5 triangle: the norm of (3, 4, 0, 0) is 5.
#[test]
fn test_l2_norm_simd() {
let v = vec![3.0, 4.0, 0.0, 0.0];
assert!((l2_norm_simd(&v) - 5.0).abs() < 1e-6);
}
// Identical directions have cosine distance 0.
#[test]
fn test_cosine_distance_simd_identical() {
let a = vec![1.0, 0.0, 0.0, 0.0];
let b = vec![1.0, 0.0, 0.0, 0.0];
assert!((cosine_distance_simd(&a, &b) - 0.0).abs() < 1e-6);
}
// Orthogonal directions have cosine distance 1.
#[test]
fn test_cosine_distance_simd_orthogonal() {
let a = vec![1.0, 0.0, 0.0, 0.0];
let b = vec![0.0, 1.0, 0.0, 0.0];
assert!((cosine_distance_simd(&a, &b) - 1.0).abs() < 1e-6);
}
// Top-2 selection: results must come back nearest first.
#[test]
fn test_batch_distances() {
let query = vec![0.0, 0.0, 0.0, 0.0];
// Squared distances from the query, in order: 1.0, 4.0, 0.25.
let targets = vec![
vec![1.0, 0.0, 0.0, 0.0], vec![2.0, 0.0, 0.0, 0.0], vec![0.5, 0.0, 0.0, 0.0], ];
let results = batch_distances(&query, &targets, DistanceMetric::L2, 2);
assert_eq!(results.len(), 2);
// Index 2 (0.25) is nearest, followed by index 0 (1.0).
assert_eq!(results[0].0, 2); assert_eq!(results[1].0, 0); }
// Five elements is not a multiple of any lane width, so the leftover-element
// path of the SIMD kernels is exercised.
#[test]
fn test_odd_length_vectors() {
let a = vec![1.0, 2.0, 3.0, 4.0, 5.0]; let b = vec![5.0, 4.0, 3.0, 2.0, 1.0];
let simd_result = l2_squared_simd(&a, &b);
// Squared differences: (-4)^2, (-2)^2, 0^2, 2^2, 4^2.
let expected = 16.0 + 4.0 + 0.0 + 4.0 + 16.0; assert!((simd_result - expected).abs() < 1e-6);
}
// 1536-element vectors: the relative error against the scalar implementation must stay small.
#[test]
fn test_large_vectors() {
let a: Vec<f32> = (0..1536).map(|i| (i as f32).sin()).collect();
let b: Vec<f32> = (0..1536).map(|i| (i as f32).cos()).collect();
let simd_result = l2_squared_simd(&a, &b);
let scalar_result = l2_squared_scalar(&a, &b);
assert!(
(simd_result - scalar_result).abs() / scalar_result.abs() < 1e-5,
"Relative error too large: SIMD={}, Scalar={}",
simd_result,
scalar_result
);
}
}