#[cfg(feature = "simd")]
use wide::f32x8;
#[cfg(feature = "simd")]
#[must_use]
pub fn l2_distance_squared_simd(a: &[f32], b: &[f32]) -> f32 {
debug_assert_eq!(a.len(), b.len(), "vectors must have same length");
let len = a.len();
let chunks = len / 8;
let remainder = len % 8;
let mut sum = f32x8::ZERO;
for i in 0..chunks {
let offset = i * 8;
let a_chunk = f32x8::new([
a[offset],
a[offset + 1],
a[offset + 2],
a[offset + 3],
a[offset + 4],
a[offset + 5],
a[offset + 6],
a[offset + 7],
]);
let b_chunk = f32x8::new([
b[offset],
b[offset + 1],
b[offset + 2],
b[offset + 3],
b[offset + 4],
b[offset + 5],
b[offset + 6],
b[offset + 7],
]);
let diff = a_chunk - b_chunk;
sum += diff * diff;
}
let sum_array: [f32; 8] = sum.into();
let mut total: f32 = sum_array.iter().sum();
let offset = chunks * 8;
for i in 0..remainder {
let diff = a[offset + i] - b[offset + i];
total += diff * diff;
}
total
}
#[cfg(feature = "simd")]
#[must_use]
pub fn l2_distance_simd(a: &[f32], b: &[f32]) -> f32 {
l2_distance_squared_simd(a, b).sqrt()
}
#[cfg(not(feature = "simd"))]
pub fn l2_distance_squared_simd(a: &[f32], b: &[f32]) -> f32 {
a.iter()
.zip(b.iter())
.map(|(x, y)| {
let diff = x - y;
diff * diff
})
.sum()
}
#[cfg(not(feature = "simd"))]
pub fn l2_distance_simd(a: &[f32], b: &[f32]) -> f32 {
l2_distance_squared_simd(a, b).sqrt()
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_l2_distance_squared_basic() {
let a = [0.0, 0.0, 0.0];
let b = [3.0, 4.0, 0.0];
let dist_sq = l2_distance_squared_simd(&a, &b);
assert!(
(dist_sq - 25.0).abs() < 1e-6,
"expected 25.0, got {}",
dist_sq
);
}
#[test]
fn test_l2_distance_basic() {
let a = [0.0, 0.0];
let b = [3.0, 4.0];
let dist = l2_distance_simd(&a, &b);
assert!((dist - 5.0).abs() < 1e-6, "expected 5.0, got {}", dist);
}
#[test]
fn test_l2_distance_384_dims() {
let a: Vec<f32> = (0..384).map(|i| i as f32 * 0.01).collect();
let b: Vec<f32> = (0..384).map(|i| (i + 1) as f32 * 0.01).collect();
let dist_simd = l2_distance_simd(&a, &b);
let dist_scalar: f32 = a
.iter()
.zip(b.iter())
.map(|(x, y)| (x - y).powi(2))
.sum::<f32>()
.sqrt();
assert!(
(dist_simd - dist_scalar).abs() < 1e-4,
"SIMD {} vs Scalar {}",
dist_simd,
dist_scalar
);
}
}