Skip to main content

memvid_core/
simd.rs

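//! SIMD-accelerated distance calculations for vector search.
//!
//! This module provides optimized L2 (Euclidean) distance functions using
//! the `wide` crate for portable SIMD across `x86_64` and `aarch64`. When the
//! `simd` feature is disabled, scalar fallbacks with the same names and
//! signatures are compiled instead.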
5
6#[cfg(feature = "simd")]
7use wide::f32x8;
8
/// Compute squared L2 distance between two f32 slices using SIMD.
///
/// The inputs are processed in 8-wide [`f32x8`] lanes. Which native
/// instructions back those lanes depends on the target features enabled at
/// compile time (see the `wide` crate docs). The trailing `len % 8` elements
/// are folded in with scalar code.
///
/// The slices are expected to have the same length; this is checked with
/// `debug_assert_eq!`. In release builds only the common prefix
/// (`min(a.len(), b.len())` elements) is compared. This matches the scalar
/// fallback, which zips the two slices.
#[cfg(feature = "simd")]
#[must_use]
pub fn l2_distance_squared_simd(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len(), "vectors must have same length");

    // Truncate both inputs to the common length. This keeps the chunked pass
    // and the remainder pass aligned, and rules out out-of-bounds indexing,
    // even when the lengths differ in a release build.
    let len = a.len().min(b.len());
    let (a, b) = (&a[..len], &b[..len]);

    let a_chunks = a.chunks_exact(8);
    let b_chunks = b.chunks_exact(8);
    // Grab the tails before the chunk iterators are consumed by the loop.
    let (a_tail, b_tail) = (a_chunks.remainder(), b_chunks.remainder());

    let mut sum = f32x8::ZERO;
    let mut a_lanes = [0.0f32; 8];
    let mut b_lanes = [0.0f32; 8];

    // Process 8 elements at a time; every chunk is exactly 8 long.
    for (ca, cb) in a_chunks.zip(b_chunks) {
        a_lanes.copy_from_slice(ca);
        b_lanes.copy_from_slice(cb);
        let diff = f32x8::new(a_lanes) - f32x8::new(b_lanes);
        sum += diff * diff;
    }

    // Horizontal sum of the SIMD accumulator, in lane order.
    let sum_array: [f32; 8] = sum.into();
    let mut total: f32 = sum_array.iter().sum();

    // Handle remainder elements with scalar code.
    for (x, y) in a_tail.iter().zip(b_tail) {
        let diff = x - y;
        total += diff * diff;
    }

    total
}
64
/// Compute the L2 (Euclidean) distance between `a` and `b` using SIMD.
///
/// This is the square root of [`l2_distance_squared_simd`]. Callers that only
/// need to rank or compare distances can use the squared form and skip the
/// `sqrt`.
#[cfg(feature = "simd")]
#[must_use]
pub fn l2_distance_simd(a: &[f32], b: &[f32]) -> f32 {
    let squared = l2_distance_squared_simd(a, b);
    f32::sqrt(squared)
}
71
72// Scalar fallbacks when SIMD feature is disabled
73
/// Compute squared L2 distance using scalar math.
///
/// This is the fallback compiled when the `simd` feature is disabled. It has
/// the same name and signature as the SIMD version, so callers do not need to
/// know which one is in use.
///
/// The slices are expected to have the same length; this is checked with
/// `debug_assert_eq!`, exactly as in the SIMD version. In release builds only
/// the common prefix (`min(a.len(), b.len())` elements) is compared.
#[cfg(not(feature = "simd"))]
#[must_use]
pub fn l2_distance_squared_simd(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len(), "vectors must have same length");

    a.iter()
        .zip(b)
        .map(|(&x, &y)| {
            let diff = x - y;
            diff * diff
        })
        .sum()
}
85
86/// Compute L2 distance using scalar math.
87#[cfg(not(feature = "simd"))]
88pub fn l2_distance_simd(a: &[f32], b: &[f32]) -> f32 {
89    l2_distance_squared_simd(a, b).sqrt()
90}
91
#[cfg(test)]
mod tests {
    use super::*;

    /// Straightforward reference implementation used to cross-check the
    /// function under test, whichever variant (SIMD or scalar) is compiled.
    fn reference_l2_squared(a: &[f32], b: &[f32]) -> f32 {
        a.iter().zip(b).map(|(x, y)| (x - y).powi(2)).sum()
    }

    #[test]
    fn test_l2_distance_squared_basic() {
        let a = [0.0, 0.0, 0.0];
        let b = [3.0, 4.0, 0.0];
        let dist_sq = l2_distance_squared_simd(&a, &b);
        assert!(
            (dist_sq - 25.0).abs() < 1e-6,
            "expected 25.0, got {dist_sq}"
        );
    }

    #[test]
    fn test_l2_distance_basic() {
        let a = [0.0, 0.0];
        let b = [3.0, 4.0];
        let dist = l2_distance_simd(&a, &b);
        assert!((dist - 5.0).abs() < 1e-6, "expected 5.0, got {dist}");
    }

    #[test]
    fn test_l2_distance_empty() {
        assert_eq!(l2_distance_squared_simd(&[], &[]), 0.0);
        assert_eq!(l2_distance_simd(&[], &[]), 0.0);
    }

    #[test]
    fn test_l2_distance_squared_chunk_and_remainder_lengths() {
        // Lengths 0..=24 cover three cases: remainder only (< 8), exact
        // multiples of 8, and full 8-lane chunks followed by a non-empty
        // remainder.
        for len in 0..=24 {
            let a: Vec<f32> = (0..len).map(|i| i as f32 * 0.5).collect();
            let b: Vec<f32> = (0..len).map(|i| (i * i) as f32 * 0.25).collect();
            let got = l2_distance_squared_simd(&a, &b);
            let want = reference_l2_squared(&a, &b);
            assert!(
                (got - want).abs() <= 1e-3 * want.max(1.0),
                "len {len}: got {got}, want {want}"
            );
        }
    }

    #[test]
    fn test_l2_distance_384_dims() {
        // Test with realistic 384-dim vectors.
        let a: Vec<f32> = (0..384).map(|i| i as f32 * 0.01).collect();
        let b: Vec<f32> = (0..384).map(|i| (i + 1) as f32 * 0.01).collect();

        let dist_simd = l2_distance_simd(&a, &b);
        let dist_scalar = reference_l2_squared(&a, &b).sqrt();

        assert!(
            (dist_simd - dist_scalar).abs() < 1e-4,
            "SIMD {dist_simd} vs Scalar {dist_scalar}"
        );
    }
}