// reddb_server/storage/engine/simd_distance.rs

//! SIMD-Optimized Distance Functions
//!
//! Provides hardware-accelerated distance computations using x86_64 SIMD intrinsics.
//! Falls back to a scalar implementation when SIMD is not available.
//!
//! # Supported Instructions
//!
//! - **SSE**: 128-bit vectors, 4 f32 operations in parallel
//! - **AVX**: 256-bit vectors, 8 f32 operations in parallel
//! - **FMA**: Fused multiply-add for better precision and performance
//!
//! # Runtime Detection
//!
//! Uses `is_x86_feature_detected!` to select the best available implementation at runtime.
use super::distance::DistanceMetric;
17
/// SIMD capability level detected at runtime.
///
/// Variants are listed from least to most capable; [`SimdLevel::detect`]
/// returns the best level the current CPU supports.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SimdLevel {
    /// No SIMD available; use the scalar fallback paths.
    Scalar,
    /// SSE 128-bit SIMD (4 x f32 per operation).
    Sse,
    /// AVX 256-bit SIMD (8 x f32 per operation).
    Avx,
    /// AVX plus FMA (fused multiply-add) for fewer rounding steps.
    AvxFma,
}
30
31impl SimdLevel {
32    /// Detect the best available SIMD level at runtime
33    #[inline]
34    pub fn detect() -> Self {
35        #[cfg(target_arch = "x86_64")]
36        {
37            if is_x86_feature_detected!("avx") && is_x86_feature_detected!("fma") {
38                SimdLevel::AvxFma
39            } else if is_x86_feature_detected!("avx") {
40                SimdLevel::Avx
41            } else if is_x86_feature_detected!("sse") {
42                SimdLevel::Sse
43            } else {
44                SimdLevel::Scalar
45            }
46        }
47        #[cfg(not(target_arch = "x86_64"))]
48        {
49            SimdLevel::Scalar
50        }
51    }
52}
53
/// Global SIMD level, detected once on first access and cached for the
/// lifetime of the process.
static SIMD_LEVEL: std::sync::OnceLock<SimdLevel> = std::sync::OnceLock::new();

/// Get the detected SIMD level.
///
/// The first call runs [`SimdLevel::detect`]; all later calls return the
/// cached value. `OnceLock` makes the one-time initialization thread-safe.
#[inline]
pub fn simd_level() -> SimdLevel {
    *SIMD_LEVEL.get_or_init(SimdLevel::detect)
}
62
63// ============================================================================
64// L2 Squared Distance
65// ============================================================================
66
67/// Compute L2 squared distance using the best available SIMD
68#[inline]
69pub fn l2_squared_simd(a: &[f32], b: &[f32]) -> f32 {
70    assert_eq!(a.len(), b.len(), "Vector dimensions must match");
71
72    match simd_level() {
73        #[cfg(target_arch = "x86_64")]
74        SimdLevel::AvxFma => unsafe { l2_squared_avx_fma(a, b) },
75        #[cfg(target_arch = "x86_64")]
76        SimdLevel::Avx => unsafe { l2_squared_avx(a, b) },
77        #[cfg(target_arch = "x86_64")]
78        SimdLevel::Sse => unsafe { l2_squared_sse(a, b) },
79        _ => l2_squared_scalar(a, b),
80    }
81}
82
/// Portable scalar reference implementation of squared L2 distance.
///
/// Accumulates left-to-right, matching the order of an index-based loop.
#[inline]
fn l2_squared_scalar(a: &[f32], b: &[f32]) -> f32 {
    a.iter()
        .zip(b.iter())
        .map(|(&x, &y)| (x - y) * (x - y))
        .sum()
}
93
/// Squared L2 distance using 128-bit SSE, 4 `f32` lanes per iteration.
///
/// # Safety
/// Caller must ensure the CPU supports SSE (the `simd_level()` dispatcher
/// checks this) and that `a.len() == b.len()` — the public wrapper asserts
/// it, and `b` is read through raw pointers at indices derived from `a`.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "sse")]
unsafe fn l2_squared_sse(a: &[f32], b: &[f32]) -> f32 {
    use std::arch::x86_64::*;

    let len = a.len();
    let mut sum = _mm_setzero_ps(); // 4 x f32 accumulator, all lanes 0.0

    // Process 4 elements at a time
    let chunks = len / 4;
    for i in 0..chunks {
        let idx = i * 4;
        // SAFETY: idx + 3 < len because i < len / 4; `loadu` tolerates
        // unaligned addresses, so no alignment requirement on the slices.
        let va = _mm_loadu_ps(a.as_ptr().add(idx));
        let vb = _mm_loadu_ps(b.as_ptr().add(idx));
        let diff = _mm_sub_ps(va, vb);
        let sq = _mm_mul_ps(diff, diff);
        sum = _mm_add_ps(sum, sq);
    }

    // Collapse the 4 partial sums in the vector register into one scalar.
    let mut result = horizontal_sum_sse(sum);

    // Scalar tail for the 0..=3 elements that didn't fill a full chunk.
    for i in (chunks * 4)..len {
        let d = a[i] - b[i];
        result += d * d;
    }

    result
}
124
/// Horizontal sum of the four `f32` lanes of an SSE register.
///
/// # Safety
/// Caller must ensure the CPU supports SSE.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "sse")]
#[inline]
unsafe fn horizontal_sum_sse(v: std::arch::x86_64::__m128) -> f32 {
    use std::arch::x86_64::*;

    // v = [a, b, c, d]
    // BUGFIX: the previous code used `_mm_movehdup_ps`, which is an SSE3
    // instruction — but this function is declared (and runtime-dispatched)
    // with only plain SSE, so it could fault on an SSE-without-SSE3 CPU.
    // `_mm_shuffle_ps` is SSE1 and produces the identical lane pattern.
    let shuf = _mm_shuffle_ps(v, v, 0b11_11_01_01); // [b, b, d, d]
    let sums = _mm_add_ps(v, shuf); // [a+b, b+b, c+d, d+d]
    let shuf2 = _mm_movehl_ps(sums, sums); // [c+d, d+d, c+d, d+d]
    let sums2 = _mm_add_ss(sums, shuf2); // lane 0 = a+b+c+d
    _mm_cvtss_f32(sums2)
}
139
/// Squared L2 distance using 256-bit AVX, 8 `f32` lanes per iteration.
///
/// # Safety
/// Caller must ensure the CPU supports AVX (checked by the `simd_level()`
/// dispatcher) and that `a.len() == b.len()` — the public wrapper asserts
/// it, and `b` is read through raw pointers at indices derived from `a`.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx")]
unsafe fn l2_squared_avx(a: &[f32], b: &[f32]) -> f32 {
    use std::arch::x86_64::*;

    let len = a.len();
    let mut sum = _mm256_setzero_ps(); // 8 x f32 accumulator, all lanes 0.0

    // Process 8 elements at a time
    let chunks = len / 8;
    for i in 0..chunks {
        let idx = i * 8;
        // SAFETY: idx + 7 < len because i < len / 8; `loadu` has no
        // alignment requirement.
        let va = _mm256_loadu_ps(a.as_ptr().add(idx));
        let vb = _mm256_loadu_ps(b.as_ptr().add(idx));
        let diff = _mm256_sub_ps(va, vb);
        let sq = _mm256_mul_ps(diff, diff);
        sum = _mm256_add_ps(sum, sq);
    }

    // Collapse the 8 partial sums into one scalar.
    let mut result = horizontal_sum_avx(sum);

    // Scalar tail for the 0..=7 leftover elements.
    for i in (chunks * 8)..len {
        let d = a[i] - b[i];
        result += d * d;
    }

    result
}
170
171#[cfg(target_arch = "x86_64")]
172#[target_feature(enable = "avx")]
173#[inline]
174unsafe fn horizontal_sum_avx(v: std::arch::x86_64::__m256) -> f32 {
175    use std::arch::x86_64::*;
176
177    // Extract high and low 128-bit halves and add them
178    let high = _mm256_extractf128_ps(v, 1);
179    let low = _mm256_castps256_ps128(v);
180    let sum128 = _mm_add_ps(high, low);
181
182    // Now do SSE horizontal sum
183    horizontal_sum_sse(sum128)
184}
185
/// Squared L2 distance using 256-bit AVX with fused multiply-add.
///
/// FMA evaluates `diff * diff + sum` in one instruction with a single
/// rounding step — faster and slightly more accurate than separate
/// multiply and add.
///
/// # Safety
/// Caller must ensure the CPU supports AVX and FMA (checked by the
/// `simd_level()` dispatcher) and that `a.len() == b.len()`.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx", enable = "fma")]
unsafe fn l2_squared_avx_fma(a: &[f32], b: &[f32]) -> f32 {
    use std::arch::x86_64::*;

    let len = a.len();
    let mut sum = _mm256_setzero_ps();

    // Process 8 elements at a time using FMA
    let chunks = len / 8;
    for i in 0..chunks {
        let idx = i * 8;
        // SAFETY: idx + 7 < len because i < len / 8.
        let va = _mm256_loadu_ps(a.as_ptr().add(idx));
        let vb = _mm256_loadu_ps(b.as_ptr().add(idx));
        let diff = _mm256_sub_ps(va, vb);
        // FMA: sum = diff * diff + sum (fused, single rounding)
        sum = _mm256_fmadd_ps(diff, diff, sum);
    }

    let mut result = horizontal_sum_avx(sum);

    // Scalar tail for the 0..=7 leftover elements.
    for i in (chunks * 8)..len {
        let d = a[i] - b[i];
        result += d * d;
    }

    result
}
215
216// ============================================================================
217// Dot Product
218// ============================================================================
219
220/// Compute dot product using the best available SIMD
221#[inline]
222pub fn dot_product_simd(a: &[f32], b: &[f32]) -> f32 {
223    assert_eq!(a.len(), b.len(), "Vector dimensions must match");
224
225    match simd_level() {
226        #[cfg(target_arch = "x86_64")]
227        SimdLevel::AvxFma => unsafe { dot_product_avx_fma(a, b) },
228        #[cfg(target_arch = "x86_64")]
229        SimdLevel::Avx => unsafe { dot_product_avx(a, b) },
230        #[cfg(target_arch = "x86_64")]
231        SimdLevel::Sse => unsafe { dot_product_sse(a, b) },
232        _ => dot_product_scalar(a, b),
233    }
234}
235
/// Portable scalar reference implementation of the dot product.
///
/// Accumulates left-to-right, matching the order of an index-based loop.
#[inline]
fn dot_product_scalar(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b.iter()).map(|(&x, &y)| x * y).sum()
}
244
/// Dot product using 128-bit SSE, 4 `f32` lanes per iteration.
///
/// # Safety
/// Caller must ensure the CPU supports SSE (checked by the `simd_level()`
/// dispatcher) and that `a.len() == b.len()` — `b` is read through raw
/// pointers at indices derived from `a`.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "sse")]
unsafe fn dot_product_sse(a: &[f32], b: &[f32]) -> f32 {
    use std::arch::x86_64::*;

    let len = a.len();
    let mut sum = _mm_setzero_ps(); // 4 x f32 accumulator

    let chunks = len / 4;
    for i in 0..chunks {
        let idx = i * 4;
        // SAFETY: idx + 3 < len because i < len / 4; `loadu` is unaligned.
        let va = _mm_loadu_ps(a.as_ptr().add(idx));
        let vb = _mm_loadu_ps(b.as_ptr().add(idx));
        let prod = _mm_mul_ps(va, vb);
        sum = _mm_add_ps(sum, prod);
    }

    // Collapse the 4 partial sums, then handle the scalar tail.
    let mut result = horizontal_sum_sse(sum);

    for i in (chunks * 4)..len {
        result += a[i] * b[i];
    }

    result
}
270
/// Dot product using 256-bit AVX, 8 `f32` lanes per iteration.
///
/// # Safety
/// Caller must ensure the CPU supports AVX (checked by the `simd_level()`
/// dispatcher) and that `a.len() == b.len()`.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx")]
unsafe fn dot_product_avx(a: &[f32], b: &[f32]) -> f32 {
    use std::arch::x86_64::*;

    let len = a.len();
    let mut sum = _mm256_setzero_ps(); // 8 x f32 accumulator

    let chunks = len / 8;
    for i in 0..chunks {
        let idx = i * 8;
        // SAFETY: idx + 7 < len because i < len / 8; `loadu` is unaligned.
        let va = _mm256_loadu_ps(a.as_ptr().add(idx));
        let vb = _mm256_loadu_ps(b.as_ptr().add(idx));
        let prod = _mm256_mul_ps(va, vb);
        sum = _mm256_add_ps(sum, prod);
    }

    // Collapse the 8 partial sums, then handle the scalar tail.
    let mut result = horizontal_sum_avx(sum);

    for i in (chunks * 8)..len {
        result += a[i] * b[i];
    }

    result
}
296
/// Dot product using 256-bit AVX with fused multiply-add.
///
/// FMA evaluates `va * vb + sum` in one instruction with a single
/// rounding step.
///
/// # Safety
/// Caller must ensure the CPU supports AVX and FMA (checked by the
/// `simd_level()` dispatcher) and that `a.len() == b.len()`.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx", enable = "fma")]
unsafe fn dot_product_avx_fma(a: &[f32], b: &[f32]) -> f32 {
    use std::arch::x86_64::*;

    let len = a.len();
    let mut sum = _mm256_setzero_ps();

    let chunks = len / 8;
    for i in 0..chunks {
        let idx = i * 8;
        // SAFETY: idx + 7 < len because i < len / 8.
        let va = _mm256_loadu_ps(a.as_ptr().add(idx));
        let vb = _mm256_loadu_ps(b.as_ptr().add(idx));
        // FMA: sum = va * vb + sum
        sum = _mm256_fmadd_ps(va, vb, sum);
    }

    let mut result = horizontal_sum_avx(sum);

    // Scalar tail for the 0..=7 leftover elements.
    for i in (chunks * 8)..len {
        result += a[i] * b[i];
    }

    result
}
322
323// ============================================================================
324// L2 Norm (magnitude)
325// ============================================================================
326
327/// Compute L2 norm using the best available SIMD
328#[inline]
329pub fn l2_norm_simd(v: &[f32]) -> f32 {
330    dot_product_simd(v, v).sqrt()
331}
332
333// ============================================================================
334// Cosine Distance
335// ============================================================================
336
337/// Compute cosine distance using SIMD
338#[inline]
339pub fn cosine_distance_simd(a: &[f32], b: &[f32]) -> f32 {
340    let dot = dot_product_simd(a, b);
341    let norm_a = l2_norm_simd(a);
342    let norm_b = l2_norm_simd(b);
343
344    if norm_a == 0.0 || norm_b == 0.0 {
345        return 1.0;
346    }
347
348    let similarity = (dot / (norm_a * norm_b)).clamp(-1.0, 1.0);
349    1.0 - similarity
350}
351
352/// Compute inner product distance using SIMD
353#[inline]
354pub fn inner_product_distance_simd(a: &[f32], b: &[f32]) -> f32 {
355    -dot_product_simd(a, b)
356}
357
358// ============================================================================
359// Unified Distance Function
360// ============================================================================
361
362/// Compute distance using SIMD with the specified metric
363#[inline]
364pub fn distance_simd(a: &[f32], b: &[f32], metric: DistanceMetric) -> f32 {
365    match metric {
366        DistanceMetric::L2 => l2_squared_simd(a, b),
367        DistanceMetric::Cosine => cosine_distance_simd(a, b),
368        DistanceMetric::InnerProduct => inner_product_distance_simd(a, b),
369    }
370}
371
372// ============================================================================
373// Batch Operations (for processing multiple vectors efficiently)
374// ============================================================================
375
376/// Compute distances from a query vector to multiple target vectors
377/// Returns (index, distance) pairs sorted by distance
378pub fn batch_distances(
379    query: &[f32],
380    targets: &[Vec<f32>],
381    metric: DistanceMetric,
382    top_k: usize,
383) -> Vec<(usize, f32)> {
384    let mut results: Vec<(usize, f32)> = targets
385        .iter()
386        .enumerate()
387        .map(|(i, target)| (i, distance_simd(query, target, metric)))
388        .collect();
389
390    // Partial sort for top-k (more efficient than full sort)
391    if top_k < results.len() {
392        results.select_nth_unstable_by(top_k, |a, b| {
393            a.1.partial_cmp(&b.1)
394                .unwrap_or(std::cmp::Ordering::Equal)
395                .then_with(|| a.0.cmp(&b.0))
396        });
397        results.truncate(top_k);
398    }
399
400    results.sort_by(|a, b| {
401        a.1.partial_cmp(&b.1)
402            .unwrap_or(std::cmp::Ordering::Equal)
403            .then_with(|| a.0.cmp(&b.0))
404    });
405    results
406}
407
// ============================================================================
// Tests
// ============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    // Note: these tests exercise whichever kernel (scalar/SSE/AVX/FMA) the
    // host CPU selects at runtime, mostly by comparing against the scalar
    // reference implementation.

    #[test]
    fn test_simd_level_detection() {
        let level = simd_level();
        println!("Detected SIMD level: {:?}", level);
        // Should detect at least scalar on any platform
        assert!(matches!(
            level,
            SimdLevel::Scalar | SimdLevel::Sse | SimdLevel::Avx | SimdLevel::AvxFma
        ));
    }

    #[test]
    fn test_l2_squared_simd_identical() {
        let a = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let b = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        assert!((l2_squared_simd(&a, &b) - 0.0).abs() < 1e-6);
    }

    #[test]
    fn test_l2_squared_simd_simple() {
        let a = vec![0.0, 0.0, 0.0, 0.0];
        let b = vec![1.0, 0.0, 0.0, 0.0];
        assert!((l2_squared_simd(&a, &b) - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_l2_squared_simd_vs_scalar() {
        let a: Vec<f32> = (0..256).map(|i| i as f32 * 0.1).collect();
        let b: Vec<f32> = (0..256).map(|i| (i + 1) as f32 * 0.1).collect();

        let simd_result = l2_squared_simd(&a, &b);
        let scalar_result = l2_squared_scalar(&a, &b);

        // Loose tolerance: SIMD accumulates in a different order.
        assert!(
            (simd_result - scalar_result).abs() < 1e-3,
            "SIMD: {}, Scalar: {}",
            simd_result,
            scalar_result
        );
    }

    #[test]
    fn test_dot_product_simd() {
        let a = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let b = vec![1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0];
        let result = dot_product_simd(&a, &b);
        assert!((result - 36.0).abs() < 1e-6); // 1+2+3+4+5+6+7+8 = 36
    }

    #[test]
    fn test_dot_product_simd_vs_scalar() {
        let a: Vec<f32> = (0..256).map(|i| i as f32 * 0.1).collect();
        let b: Vec<f32> = (0..256).map(|i| (i + 1) as f32 * 0.1).collect();

        let simd_result = dot_product_simd(&a, &b);
        let scalar_result = dot_product_scalar(&a, &b);

        assert!(
            (simd_result - scalar_result).abs() < 1.0, // Larger tolerance for accumulated FP
            "SIMD: {}, Scalar: {}",
            simd_result,
            scalar_result
        );
    }

    #[test]
    fn test_l2_norm_simd() {
        // Classic 3-4-5 triangle.
        let v = vec![3.0, 4.0, 0.0, 0.0];
        assert!((l2_norm_simd(&v) - 5.0).abs() < 1e-6);
    }

    #[test]
    fn test_cosine_distance_simd_identical() {
        let a = vec![1.0, 0.0, 0.0, 0.0];
        let b = vec![1.0, 0.0, 0.0, 0.0];
        assert!((cosine_distance_simd(&a, &b) - 0.0).abs() < 1e-6);
    }

    #[test]
    fn test_cosine_distance_simd_orthogonal() {
        let a = vec![1.0, 0.0, 0.0, 0.0];
        let b = vec![0.0, 1.0, 0.0, 0.0];
        assert!((cosine_distance_simd(&a, &b) - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_batch_distances() {
        let query = vec![0.0, 0.0, 0.0, 0.0];
        let targets = vec![
            vec![1.0, 0.0, 0.0, 0.0], // distance = 1.0
            vec![2.0, 0.0, 0.0, 0.0], // distance = 4.0
            vec![0.5, 0.0, 0.0, 0.0], // distance = 0.25
        ];

        let results = batch_distances(&query, &targets, DistanceMetric::L2, 2);
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].0, 2); // Closest is index 2 (distance 0.25)
        assert_eq!(results[1].0, 0); // Second is index 0 (distance 1.0)
    }

    #[test]
    fn test_odd_length_vectors() {
        // Test vectors that don't align to SIMD width (exercises the
        // scalar-tail loops in the kernels).
        let a = vec![1.0, 2.0, 3.0, 4.0, 5.0]; // 5 elements
        let b = vec![5.0, 4.0, 3.0, 2.0, 1.0];

        let simd_result = l2_squared_simd(&a, &b);
        let expected = 16.0 + 4.0 + 0.0 + 4.0 + 16.0; // = 40.0
        assert!((simd_result - expected).abs() < 1e-6);
    }

    #[test]
    fn test_large_vectors() {
        // Test with large vectors (1536 dimensions like text-embedding-3-large)
        let a: Vec<f32> = (0..1536).map(|i| (i as f32).sin()).collect();
        let b: Vec<f32> = (0..1536).map(|i| (i as f32).cos()).collect();

        let simd_result = l2_squared_simd(&a, &b);
        let scalar_result = l2_squared_scalar(&a, &b);

        assert!(
            (simd_result - scalar_result).abs() / scalar_result.abs() < 1e-5,
            "Relative error too large: SIMD={}, Scalar={}",
            simd_result,
            scalar_result
        );
    }
}