Skip to main content

engine/
simd_distance.rs

1//! SIMD-accelerated distance functions
2//!
3//! Provides optimized distance calculations using:
4//! - AVX2 on x86_64 (256-bit vectors, 8 floats at a time)
5//! - NEON on aarch64 (128-bit vectors, 4 floats at a time)
6//! - Auto-vectorized fallback for other architectures
7
8#[cfg(target_arch = "x86_64")]
9use std::sync::OnceLock;
10
11use common::DistanceMetric;
12
/// Process-wide cache for the AVX2 + FMA runtime probe.
/// The probe runs the first time the flag is requested; later reads reuse the stored value.
#[cfg(target_arch = "x86_64")]
static AVX2_AVAILABLE: OnceLock<bool> = OnceLock::new();

/// Reports whether the running CPU supports both AVX2 and FMA.
#[cfg(target_arch = "x86_64")]
#[inline(always)]
fn avx2_available() -> bool {
    // Both features are required because the AVX2 kernels are compiled with FMA enabled.
    let probe = || is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma");
    let cached: &bool = AVX2_AVAILABLE.get_or_init(probe);
    *cached
}
24
25/// Calculate distance/similarity using SIMD when available
26/// Returns similarity score (higher = more similar)
27pub fn simd_distance(a: &[f32], b: &[f32], metric: DistanceMetric) -> f32 {
28    match metric {
29        DistanceMetric::Cosine => simd_cosine_similarity(a, b),
30        DistanceMetric::Euclidean => simd_negative_euclidean(a, b),
31        DistanceMetric::DotProduct => simd_dot_product(a, b),
32    }
33}
34
35/// SIMD-accelerated dot product
36#[inline]
37pub fn simd_dot_product(a: &[f32], b: &[f32]) -> f32 {
38    #[cfg(target_arch = "x86_64")]
39    {
40        if avx2_available() {
41            return unsafe { avx2_dot_product(a, b) };
42        }
43    }
44
45    #[cfg(target_arch = "aarch64")]
46    {
47        unsafe { neon_dot_product(a, b) }
48    }
49
50    // Scalar fallback for non-SIMD or x86_64 without AVX2 runtime support
51    #[cfg(not(target_arch = "aarch64"))]
52    {
53        a.iter().zip(b.iter()).map(|(x, y)| x * y).sum()
54    }
55}
56
57/// SIMD-accelerated cosine similarity
58#[inline]
59pub fn simd_cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
60    #[cfg(target_arch = "x86_64")]
61    {
62        if avx2_available() {
63            return unsafe { avx2_cosine_similarity(a, b) };
64        }
65    }
66
67    #[cfg(target_arch = "aarch64")]
68    {
69        unsafe { neon_cosine_similarity(a, b) }
70    }
71
72    // Scalar fallback for non-SIMD or x86_64 without AVX2 runtime support
73    #[cfg(not(target_arch = "aarch64"))]
74    {
75        fallback_cosine_similarity(a, b)
76    }
77}
78
79/// SIMD-accelerated negative euclidean distance
80#[inline]
81pub fn simd_negative_euclidean(a: &[f32], b: &[f32]) -> f32 {
82    #[cfg(target_arch = "x86_64")]
83    {
84        if avx2_available() {
85            return unsafe { avx2_negative_euclidean(a, b) };
86        }
87    }
88
89    #[cfg(target_arch = "aarch64")]
90    {
91        unsafe { neon_negative_euclidean(a, b) }
92    }
93
94    // Scalar fallback for non-SIMD or x86_64 without AVX2 runtime support
95    #[cfg(not(target_arch = "aarch64"))]
96    {
97        let sum: f32 = a.iter().zip(b.iter()).map(|(x, y)| (x - y).powi(2)).sum();
98        -sum.sqrt()
99    }
100}
101
102// ============================================================================
103// Fallback implementations (used when SIMD is not available at runtime)
104// ============================================================================
105
/// Portable cosine similarity used when no SIMD kernel is selected.
///
/// Elements are paired with `zip`, so trailing elements of the longer slice are
/// ignored. Returns `0.0` when either vector has zero magnitude.
#[inline]
#[allow(dead_code)]
fn fallback_cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    // Accumulate the dot product and both squared magnitudes in a single pass.
    let (dot, sq_a, sq_b) = a.iter().zip(b.iter()).fold(
        (0.0f32, 0.0f32, 0.0f32),
        |(dot, sq_a, sq_b), (&x, &y)| (dot + x * y, sq_a + x * x, sq_b + y * y),
    );

    let (len_a, len_b) = (sq_a.sqrt(), sq_b.sqrt());

    // A zero-length vector has no direction; report no similarity instead of dividing by zero.
    if len_a != 0.0 && len_b != 0.0 {
        dot / (len_a * len_b)
    } else {
        0.0
    }
}
129
130// ============================================================================
131// Scalar implementations (auto-vectorization friendly)
132// ============================================================================
133
/// Reference dot product written as a plain iterator chain.
/// Tests compare the SIMD kernels against this implementation.
#[inline]
#[cfg(test)]
fn scalar_dot_product(a: &[f32], b: &[f32]) -> f32 {
    let products = std::iter::zip(a, b).map(|(lhs, rhs)| lhs * rhs);
    products.sum()
}
141
/// Reference cosine similarity computed one element pair at a time.
/// Tests compare the SIMD kernels against this implementation.
#[inline]
#[cfg(test)]
fn scalar_cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    // acc[0] = dot product, acc[1] = |a|^2, acc[2] = |b|^2
    let mut acc = [0.0f32; 3];

    for (&x, &y) in std::iter::zip(a, b) {
        acc[0] += x * y;
        acc[1] += x * x;
        acc[2] += y * y;
    }

    let magnitude_a = acc[1].sqrt();
    let magnitude_b = acc[2].sqrt();

    // A zero-magnitude vector has no direction, so similarity is defined as 0.
    if magnitude_a != 0.0 && magnitude_b != 0.0 {
        acc[0] / (magnitude_a * magnitude_b)
    } else {
        0.0
    }
}
166
/// Reference negated euclidean distance.
/// Tests compare the SIMD kernels against this implementation.
#[inline]
#[cfg(test)]
fn scalar_negative_euclidean(a: &[f32], b: &[f32]) -> f32 {
    let squared_total: f32 = std::iter::zip(a, b)
        .map(|(p, q)| {
            let delta = p - q;
            delta.powi(2)
        })
        .sum();
    let distance = squared_total.sqrt();
    -distance
}
175
176// ============================================================================
177// x86_64 AVX2 implementations
178// ============================================================================
179
180#[cfg(target_arch = "x86_64")]
181#[target_feature(enable = "avx2", enable = "fma")]
182unsafe fn avx2_dot_product(a: &[f32], b: &[f32]) -> f32 {
183    use std::arch::x86_64::*;
184
185    let n = a.len();
186    let chunks = n / 8;
187    let remainder = n % 8;
188
189    let mut sum = _mm256_setzero_ps();
190
191    let a_ptr = a.as_ptr();
192    let b_ptr = b.as_ptr();
193
194    for i in 0..chunks {
195        let offset = i * 8;
196        let va = _mm256_loadu_ps(a_ptr.add(offset));
197        let vb = _mm256_loadu_ps(b_ptr.add(offset));
198        sum = _mm256_fmadd_ps(va, vb, sum);
199    }
200
201    // Horizontal sum
202    let mut result = hsum_avx(sum);
203
204    // Handle remainder
205    let start = chunks * 8;
206    for i in 0..remainder {
207        result += a[start + i] * b[start + i];
208    }
209
210    result
211}
212
213#[cfg(target_arch = "x86_64")]
214#[target_feature(enable = "avx2", enable = "fma")]
215unsafe fn avx2_cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
216    use std::arch::x86_64::*;
217
218    let n = a.len();
219    let chunks = n / 8;
220    let remainder = n % 8;
221
222    let mut dot_sum = _mm256_setzero_ps();
223    let mut norm_a_sum = _mm256_setzero_ps();
224    let mut norm_b_sum = _mm256_setzero_ps();
225
226    let a_ptr = a.as_ptr();
227    let b_ptr = b.as_ptr();
228
229    for i in 0..chunks {
230        let offset = i * 8;
231        let va = _mm256_loadu_ps(a_ptr.add(offset));
232        let vb = _mm256_loadu_ps(b_ptr.add(offset));
233
234        dot_sum = _mm256_fmadd_ps(va, vb, dot_sum);
235        norm_a_sum = _mm256_fmadd_ps(va, va, norm_a_sum);
236        norm_b_sum = _mm256_fmadd_ps(vb, vb, norm_b_sum);
237    }
238
239    let mut dot = hsum_avx(dot_sum);
240    let mut norm_a = hsum_avx(norm_a_sum);
241    let mut norm_b = hsum_avx(norm_b_sum);
242
243    // Handle remainder
244    let start = chunks * 8;
245    for i in 0..remainder {
246        let x = a[start + i];
247        let y = b[start + i];
248        dot += x * y;
249        norm_a += x * x;
250        norm_b += y * y;
251    }
252
253    let norm_a = norm_a.sqrt();
254    let norm_b = norm_b.sqrt();
255
256    if norm_a == 0.0 || norm_b == 0.0 {
257        return 0.0;
258    }
259
260    dot / (norm_a * norm_b)
261}
262
263#[cfg(target_arch = "x86_64")]
264#[target_feature(enable = "avx2", enable = "fma")]
265unsafe fn avx2_negative_euclidean(a: &[f32], b: &[f32]) -> f32 {
266    use std::arch::x86_64::*;
267
268    let n = a.len();
269    let chunks = n / 8;
270    let remainder = n % 8;
271
272    let mut sum = _mm256_setzero_ps();
273
274    let a_ptr = a.as_ptr();
275    let b_ptr = b.as_ptr();
276
277    for i in 0..chunks {
278        let offset = i * 8;
279        let va = _mm256_loadu_ps(a_ptr.add(offset));
280        let vb = _mm256_loadu_ps(b_ptr.add(offset));
281        let diff = _mm256_sub_ps(va, vb);
282        sum = _mm256_fmadd_ps(diff, diff, sum);
283    }
284
285    let mut result = hsum_avx(sum);
286
287    // Handle remainder
288    let start = chunks * 8;
289    for i in 0..remainder {
290        let diff = a[start + i] - b[start + i];
291        result += diff * diff;
292    }
293
294    -result.sqrt()
295}
296
/// Reduces the eight `f32` lanes of a 256-bit vector to their scalar sum.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
#[inline]
unsafe fn hsum_avx(v: std::arch::x86_64::__m256) -> f32 {
    use std::arch::x86_64::{
        _mm256_castps256_ps128, _mm256_extractf128_ps, _mm_add_ps, _mm_add_ss, _mm_cvtss_f32,
        _mm_movehdup_ps, _mm_movehl_ps,
    };

    // 8 lanes -> 4 lanes: fold the upper 128 bits onto the lower 128 bits.
    let quad = _mm_add_ps(_mm256_extractf128_ps(v, 1), _mm256_castps256_ps128(v));

    // 4 lanes -> 2 partial sums: add each odd lane onto its even neighbour.
    let pairs = _mm_add_ps(quad, _mm_movehdup_ps(quad));

    // 2 partial sums -> 1: bring the upper pair down and add it into lane 0.
    let total = _mm_add_ss(pairs, _mm_movehl_ps(pairs, pairs));

    _mm_cvtss_f32(total)
}
317
318// ============================================================================
319// aarch64 NEON implementations
320// ============================================================================
321
/// NEON dot product over the common prefix of `a` and `b`.
///
/// Processes four lanes per iteration with fused multiply-add, then folds the
/// leftover (fewer than four) elements in scalar code. If the slices differ in
/// length, only the first `min(a.len(), b.len())` elements are used.
///
/// # Safety
///
/// Marked `unsafe` because it issues raw-pointer vector loads; every load is
/// bounded by the shorter of the two slices.
#[cfg(target_arch = "aarch64")]
unsafe fn neon_dot_product(a: &[f32], b: &[f32]) -> f32 {
    use std::arch::aarch64::*;

    // Bound every access by the shorter slice so the raw-pointer loads below
    // can never run past the end of either input.
    let n = a.len().min(b.len());
    let (a, b) = (&a[..n], &b[..n]);
    let chunks = n / 4;

    let mut sum = vdupq_n_f32(0.0);

    let a_ptr = a.as_ptr();
    let b_ptr = b.as_ptr();

    for i in 0..chunks {
        let offset = i * 4;
        // In bounds: offset + 4 <= chunks * 4 <= n, and both slices hold n elements.
        let va = vld1q_f32(a_ptr.add(offset));
        let vb = vld1q_f32(b_ptr.add(offset));
        sum = vfmaq_f32(sum, va, vb);
    }

    let mut result = vaddvq_f32(sum);

    // Handle the tail that does not fill a full 4-lane vector
    let tail = chunks * 4;
    for (x, y) in a[tail..].iter().zip(&b[tail..]) {
        result += x * y;
    }

    result
}
352
/// NEON cosine similarity over the common prefix of `a` and `b`.
///
/// Accumulates the dot product and both squared magnitudes four lanes at a
/// time, then folds the leftover (fewer than four) elements in scalar code.
/// Returns `0.0` when either vector has zero magnitude. If the slices differ in
/// length, only the first `min(a.len(), b.len())` elements are used.
///
/// # Safety
///
/// Marked `unsafe` because it issues raw-pointer vector loads; every load is
/// bounded by the shorter of the two slices.
#[cfg(target_arch = "aarch64")]
unsafe fn neon_cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    use std::arch::aarch64::*;

    // Bound every access by the shorter slice so the raw-pointer loads below
    // can never run past the end of either input.
    let n = a.len().min(b.len());
    let (a, b) = (&a[..n], &b[..n]);
    let chunks = n / 4;

    let mut dot_sum = vdupq_n_f32(0.0);
    let mut norm_a_sum = vdupq_n_f32(0.0);
    let mut norm_b_sum = vdupq_n_f32(0.0);

    let a_ptr = a.as_ptr();
    let b_ptr = b.as_ptr();

    for i in 0..chunks {
        let offset = i * 4;
        // In bounds: offset + 4 <= chunks * 4 <= n, and both slices hold n elements.
        let va = vld1q_f32(a_ptr.add(offset));
        let vb = vld1q_f32(b_ptr.add(offset));

        dot_sum = vfmaq_f32(dot_sum, va, vb);
        norm_a_sum = vfmaq_f32(norm_a_sum, va, va);
        norm_b_sum = vfmaq_f32(norm_b_sum, vb, vb);
    }

    let mut dot = vaddvq_f32(dot_sum);
    let mut norm_a = vaddvq_f32(norm_a_sum);
    let mut norm_b = vaddvq_f32(norm_b_sum);

    // Handle the tail that does not fill a full 4-lane vector
    let tail = chunks * 4;
    for (&x, &y) in a[tail..].iter().zip(&b[tail..]) {
        dot += x * y;
        norm_a += x * x;
        norm_b += y * y;
    }

    let norm_a = norm_a.sqrt();
    let norm_b = norm_b.sqrt();

    // A zero-magnitude vector has no direction; avoid dividing by zero.
    if norm_a == 0.0 || norm_b == 0.0 {
        return 0.0;
    }

    dot / (norm_a * norm_b)
}
401
/// NEON negated euclidean distance over the common prefix of `a` and `b`.
///
/// Accumulates squared differences four lanes at a time, folds the leftover
/// (fewer than four) elements in scalar code, and returns the negated square
/// root of the total. If the slices differ in length, only the first
/// `min(a.len(), b.len())` elements are used.
///
/// # Safety
///
/// Marked `unsafe` because it issues raw-pointer vector loads; every load is
/// bounded by the shorter of the two slices.
#[cfg(target_arch = "aarch64")]
unsafe fn neon_negative_euclidean(a: &[f32], b: &[f32]) -> f32 {
    use std::arch::aarch64::*;

    // Bound every access by the shorter slice so the raw-pointer loads below
    // can never run past the end of either input.
    let n = a.len().min(b.len());
    let (a, b) = (&a[..n], &b[..n]);
    let chunks = n / 4;

    let mut sum = vdupq_n_f32(0.0);

    let a_ptr = a.as_ptr();
    let b_ptr = b.as_ptr();

    for i in 0..chunks {
        let offset = i * 4;
        // In bounds: offset + 4 <= chunks * 4 <= n, and both slices hold n elements.
        let va = vld1q_f32(a_ptr.add(offset));
        let vb = vld1q_f32(b_ptr.add(offset));
        let diff = vsubq_f32(va, vb);
        sum = vfmaq_f32(sum, diff, diff);
    }

    let mut result = vaddvq_f32(sum);

    // Handle the tail that does not fill a full 4-lane vector
    let tail = chunks * 4;
    for (x, y) in a[tail..].iter().zip(&b[tail..]) {
        let diff = x - y;
        result += diff * diff;
    }

    -result.sqrt()
}
434
435// ============================================================================
436// Tests
437// ============================================================================
438
/// Unit tests comparing the dispatching `simd_*` functions against known values
/// and against the scalar reference implementations.
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for results that should match almost exactly.
    const EPSILON: f32 = 1e-5;

    /// Returns true when `a` and `b` differ by less than `EPSILON`.
    fn approx_eq(a: f32, b: f32) -> bool {
        (a - b).abs() < EPSILON
    }

    #[test]
    fn test_simd_dot_product() {
        let a = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let b = vec![1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0];
        // 1+2+3+4+5+6+7+8 = 36
        let result = simd_dot_product(&a, &b);
        assert!(approx_eq(result, 36.0), "Expected 36.0, got {}", result);
    }

    #[test]
    fn test_simd_dot_product_large() {
        // Test with 1024 elements (typical embedding size)
        let a: Vec<f32> = (0..1024).map(|i| i as f32 * 0.001).collect();
        let b: Vec<f32> = (0..1024).map(|i| (1024 - i) as f32 * 0.001).collect();

        let simd_result = simd_dot_product(&a, &b);
        let scalar_result = scalar_dot_product(&a, &b);

        // Allow slightly larger tolerance for large vectors due to FP accumulation order
        assert!(
            (simd_result - scalar_result).abs() < 0.01,
            "SIMD: {}, Scalar: {}",
            simd_result,
            scalar_result
        );
    }

    #[test]
    fn test_simd_cosine_identical() {
        let a = vec![1.0, 0.0, 0.0, 0.0];
        let result = simd_cosine_similarity(&a, &a);
        assert!(approx_eq(result, 1.0), "Expected 1.0, got {}", result);
    }

    #[test]
    fn test_simd_cosine_orthogonal() {
        let a = vec![1.0, 0.0, 0.0, 0.0];
        let b = vec![0.0, 1.0, 0.0, 0.0];
        let result = simd_cosine_similarity(&a, &b);
        assert!(approx_eq(result, 0.0), "Expected 0.0, got {}", result);
    }

    #[test]
    fn test_simd_cosine_large() {
        let a: Vec<f32> = (0..1024).map(|i| (i as f32).sin()).collect();
        let b: Vec<f32> = (0..1024).map(|i| (i as f32).cos()).collect();

        let simd_result = simd_cosine_similarity(&a, &b);
        let scalar_result = scalar_cosine_similarity(&a, &b);

        assert!(
            (simd_result - scalar_result).abs() < 1e-4,
            "SIMD: {}, Scalar: {}",
            simd_result,
            scalar_result
        );
    }

    #[test]
    fn test_simd_euclidean_identical() {
        let a = vec![1.0, 2.0, 3.0, 4.0];
        let result = simd_negative_euclidean(&a, &a);
        assert!(approx_eq(result, 0.0), "Expected 0.0, got {}", result);
    }

    #[test]
    fn test_simd_euclidean_known() {
        let a = vec![0.0, 0.0, 0.0, 0.0];
        let b = vec![3.0, 4.0, 0.0, 0.0];
        // Distance = 5, negative = -5
        let result = simd_negative_euclidean(&a, &b);
        assert!(approx_eq(result, -5.0), "Expected -5.0, got {}", result);
    }

    #[test]
    fn test_simd_euclidean_large() {
        let a: Vec<f32> = (0..1024).map(|i| i as f32 * 0.01).collect();
        let b: Vec<f32> = (0..1024).map(|i| (i + 1) as f32 * 0.01).collect();

        let simd_result = simd_negative_euclidean(&a, &b);
        let scalar_result = scalar_negative_euclidean(&a, &b);

        assert!(
            (simd_result - scalar_result).abs() < 1e-3,
            "SIMD: {}, Scalar: {}",
            simd_result,
            scalar_result
        );
    }

    #[test]
    fn test_simd_distance_dispatch() {
        let a = vec![1.0, 0.0, 0.0, 0.0];
        let b = vec![1.0, 0.0, 0.0, 0.0];

        assert!(approx_eq(
            simd_distance(&a, &b, DistanceMetric::Cosine),
            1.0
        ));
        assert!(approx_eq(
            simd_distance(&a, &b, DistanceMetric::Euclidean),
            0.0
        ));
        assert!(approx_eq(
            simd_distance(&a, &b, DistanceMetric::DotProduct),
            1.0
        ));
    }

    #[test]
    fn test_simd_remainder_handling() {
        // Test with sizes that don't divide evenly by SIMD width (4 for NEON, 8 for AVX2).
        // Every kernel has its own tail loop, so each metric is checked, not just the dot product.
        for size in [3, 5, 7, 9, 11, 13, 15, 17] {
            let a: Vec<f32> = (0..size).map(|i| i as f32).collect();
            let b: Vec<f32> = (0..size).map(|i| (i + 1) as f32).collect();

            let simd_dot = simd_dot_product(&a, &b);
            let scalar_dot = scalar_dot_product(&a, &b);
            assert!(
                approx_eq(simd_dot, scalar_dot),
                "Size {}: SIMD {} != Scalar {}",
                size,
                simd_dot,
                scalar_dot
            );

            let simd_cos = simd_cosine_similarity(&a, &b);
            let scalar_cos = scalar_cosine_similarity(&a, &b);
            assert!(
                approx_eq(simd_cos, scalar_cos),
                "Size {}: SIMD cosine {} != Scalar cosine {}",
                size,
                simd_cos,
                scalar_cos
            );

            let simd_euc = simd_negative_euclidean(&a, &b);
            let scalar_euc = scalar_negative_euclidean(&a, &b);
            assert!(
                approx_eq(simd_euc, scalar_euc),
                "Size {}: SIMD euclidean {} != Scalar euclidean {}",
                size,
                simd_euc,
                scalar_euc
            );
        }
    }
}