diskann_rs/simd.rs

//! # SIMD-Accelerated Distance Functions
//!
//! Optimized distance calculations using SIMD instructions.
//! Falls back to scalar implementations when SIMD is not available.
//!
//! ## Supported Architectures
//!
//! - **x86_64**: AVX2 + FMA, SSE4.1 (auto-detected at runtime)
//! - **aarch64**: NEON (mandatory on aarch64, including Apple Silicon)
//! - **Fallback**: Portable scalar implementation
//!
//! ## Performance
//!
//! SIMD acceleration typically provides a 2-8x speedup for distance calculations:
//! - L2 (Euclidean): processes 8 floats per iteration (AVX2) or 4 (SSE/NEON)
//! - Dot product: same vectorization approach
//! - Cosine: computed as 1 - dot(a,b) / (||a|| * ||b||)
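//!
//! ## Example
//!
//! A minimal usage sketch of the dispatch functions below (assuming this
//! module is exposed as `diskann_rs::simd`; adjust the path to the actual
//! crate layout):
//!
//! ```ignore
//! use diskann_rs::simd::{cosine_distance, l2_squared, simd_info};
//!
//! let a = vec![1.0_f32, 2.0, 3.0, 4.0];
//! let b = vec![4.0_f32, 3.0, 2.0, 1.0];
//!
//! // Each call picks AVX2/SSE4.1/NEON at runtime and falls back to scalar.
//! let d2 = l2_squared(&a, &b);
//! let cos = cosine_distance(&a, &b);
//! println!("{}: l2_squared={d2}, cosine={cos}", simd_info());
//! ```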

use anndists::prelude::Distance;

/// SIMD-accelerated L2 (Euclidean squared) distance
#[derive(Clone, Copy, Debug, Default)]
pub struct SimdL2;

/// SIMD-accelerated dot product distance (for normalized vectors)
#[derive(Clone, Copy, Debug, Default)]
pub struct SimdDot;

/// SIMD-accelerated cosine distance
#[derive(Clone, Copy, Debug, Default)]
pub struct SimdCosine;

// =============================================================================
// Portable scalar implementations (fallback)
// =============================================================================

#[inline]
fn l2_squared_scalar(a: &[f32], b: &[f32]) -> f32 {
    a.iter()
        .zip(b.iter())
        .map(|(x, y)| {
            let d = x - y;
            d * d
        })
        .sum()
}

#[inline]
fn dot_product_scalar(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b.iter()).map(|(x, y)| x * y).sum()
}

#[inline]
fn norm_squared_scalar(a: &[f32]) -> f32 {
    a.iter().map(|x| x * x).sum()
}

// =============================================================================
// x86_64 AVX2 implementations
// =============================================================================

#[cfg(target_arch = "x86_64")]
mod x86_simd {
    use std::arch::x86_64::*;

    /// Check if AVX2 (and FMA) are available at runtime.
    ///
    /// FMA is checked alongside AVX2 because the AVX2 kernels below use the
    /// fused multiply-add intrinsics (`_mm256_fmadd_ps`).
    #[inline]
    pub fn has_avx2() -> bool {
        is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma")
    }

    /// Check if SSE4.1 is available at runtime
    #[inline]
    pub fn has_sse41() -> bool {
        is_x86_feature_detected!("sse4.1")
    }

    /// L2 squared distance using AVX2 + FMA (8 floats at a time)
    #[target_feature(enable = "avx2,fma")]
    #[inline]
    pub unsafe fn l2_squared_avx2(a: &[f32], b: &[f32]) -> f32 {
        debug_assert_eq!(a.len(), b.len());
        let n = a.len();

        let mut sum = _mm256_setzero_ps();
        let mut i = 0;

        // Process 8 elements at a time
        while i + 8 <= n {
            let va = _mm256_loadu_ps(a.as_ptr().add(i));
            let vb = _mm256_loadu_ps(b.as_ptr().add(i));
            let diff = _mm256_sub_ps(va, vb);
            sum = _mm256_fmadd_ps(diff, diff, sum);
            i += 8;
        }

        // Horizontal sum of 256-bit vector
        let high = _mm256_extractf128_ps(sum, 1);
        let low = _mm256_castps256_ps128(sum);
        let sum128 = _mm_add_ps(high, low);
        let shuf = _mm_movehdup_ps(sum128);
        let sums = _mm_add_ps(sum128, shuf);
        let shuf2 = _mm_movehl_ps(sums, sums);
        let final_sum = _mm_add_ss(sums, shuf2);
        let mut result = _mm_cvtss_f32(final_sum);

        // Handle remaining elements
        while i < n {
            let d = a[i] - b[i];
            result += d * d;
            i += 1;
        }

        result
    }

    /// Dot product using AVX2 + FMA
    #[target_feature(enable = "avx2,fma")]
    #[inline]
    pub unsafe fn dot_product_avx2(a: &[f32], b: &[f32]) -> f32 {
        debug_assert_eq!(a.len(), b.len());
        let n = a.len();

        let mut sum = _mm256_setzero_ps();
        let mut i = 0;

        while i + 8 <= n {
            let va = _mm256_loadu_ps(a.as_ptr().add(i));
            let vb = _mm256_loadu_ps(b.as_ptr().add(i));
            sum = _mm256_fmadd_ps(va, vb, sum);
            i += 8;
        }

        // Horizontal sum
        let high = _mm256_extractf128_ps(sum, 1);
        let low = _mm256_castps256_ps128(sum);
        let sum128 = _mm_add_ps(high, low);
        let shuf = _mm_movehdup_ps(sum128);
        let sums = _mm_add_ps(sum128, shuf);
        let shuf2 = _mm_movehl_ps(sums, sums);
        let final_sum = _mm_add_ss(sums, shuf2);
        let mut result = _mm_cvtss_f32(final_sum);

        while i < n {
            result += a[i] * b[i];
            i += 1;
        }

        result
    }

    /// Norm squared using AVX2 + FMA
    #[target_feature(enable = "avx2,fma")]
    #[inline]
    pub unsafe fn norm_squared_avx2(a: &[f32]) -> f32 {
        let n = a.len();
        let mut sum = _mm256_setzero_ps();
        let mut i = 0;

        while i + 8 <= n {
            let va = _mm256_loadu_ps(a.as_ptr().add(i));
            sum = _mm256_fmadd_ps(va, va, sum);
            i += 8;
        }

        let high = _mm256_extractf128_ps(sum, 1);
        let low = _mm256_castps256_ps128(sum);
        let sum128 = _mm_add_ps(high, low);
        let shuf = _mm_movehdup_ps(sum128);
        let sums = _mm_add_ps(sum128, shuf);
        let shuf2 = _mm_movehl_ps(sums, sums);
        let final_sum = _mm_add_ss(sums, shuf2);
        let mut result = _mm_cvtss_f32(final_sum);

        while i < n {
            result += a[i] * a[i];
            i += 1;
        }

        result
    }

    /// L2 squared using SSE4.1 (4 floats at a time)
    #[target_feature(enable = "sse4.1")]
    #[inline]
    pub unsafe fn l2_squared_sse41(a: &[f32], b: &[f32]) -> f32 {
        debug_assert_eq!(a.len(), b.len());
        let n = a.len();

        let mut sum = _mm_setzero_ps();
        let mut i = 0;

        while i + 4 <= n {
            let va = _mm_loadu_ps(a.as_ptr().add(i));
            let vb = _mm_loadu_ps(b.as_ptr().add(i));
            let diff = _mm_sub_ps(va, vb);
            let sq = _mm_mul_ps(diff, diff);
            sum = _mm_add_ps(sum, sq);
            i += 4;
        }

        // Horizontal sum
        let shuf = _mm_movehdup_ps(sum);
        let sums = _mm_add_ps(sum, shuf);
        let shuf2 = _mm_movehl_ps(sums, sums);
        let final_sum = _mm_add_ss(sums, shuf2);
        let mut result = _mm_cvtss_f32(final_sum);

        while i < n {
            let d = a[i] - b[i];
            result += d * d;
            i += 1;
        }

        result
    }

    /// Dot product using SSE4.1
    #[target_feature(enable = "sse4.1")]
    #[inline]
    pub unsafe fn dot_product_sse41(a: &[f32], b: &[f32]) -> f32 {
        debug_assert_eq!(a.len(), b.len());
        let n = a.len();

        let mut sum = _mm_setzero_ps();
        let mut i = 0;

        while i + 4 <= n {
            let va = _mm_loadu_ps(a.as_ptr().add(i));
            let vb = _mm_loadu_ps(b.as_ptr().add(i));
            let prod = _mm_mul_ps(va, vb);
            sum = _mm_add_ps(sum, prod);
            i += 4;
        }

        let shuf = _mm_movehdup_ps(sum);
        let sums = _mm_add_ps(sum, shuf);
        let shuf2 = _mm_movehl_ps(sums, sums);
        let final_sum = _mm_add_ss(sums, shuf2);
        let mut result = _mm_cvtss_f32(final_sum);

        while i < n {
            result += a[i] * b[i];
            i += 1;
        }

        result
    }
}

// =============================================================================
// aarch64 NEON implementations
// =============================================================================

#[cfg(target_arch = "aarch64")]
mod neon_simd {
    use std::arch::aarch64::*;

    /// L2 squared distance using NEON (4 floats at a time)
    #[inline]
    pub fn l2_squared_neon(a: &[f32], b: &[f32]) -> f32 {
        debug_assert_eq!(a.len(), b.len());
        let n = a.len();

        // SAFETY: NEON is always available on aarch64
        unsafe {
            let mut sum = vdupq_n_f32(0.0);
            let mut i = 0;

            while i + 4 <= n {
                let va = vld1q_f32(a.as_ptr().add(i));
                let vb = vld1q_f32(b.as_ptr().add(i));
                let diff = vsubq_f32(va, vb);
                sum = vfmaq_f32(sum, diff, diff);
                i += 4;
            }

            // Horizontal sum
            let mut result = vaddvq_f32(sum);

            while i < n {
                let d = a[i] - b[i];
                result += d * d;
                i += 1;
            }

            result
        }
    }

    /// Dot product using NEON
    #[inline]
    pub fn dot_product_neon(a: &[f32], b: &[f32]) -> f32 {
        debug_assert_eq!(a.len(), b.len());
        let n = a.len();

        // SAFETY: NEON is always available on aarch64
        unsafe {
            let mut sum = vdupq_n_f32(0.0);
            let mut i = 0;

            while i + 4 <= n {
                let va = vld1q_f32(a.as_ptr().add(i));
                let vb = vld1q_f32(b.as_ptr().add(i));
                sum = vfmaq_f32(sum, va, vb);
                i += 4;
            }

            let mut result = vaddvq_f32(sum);

            while i < n {
                result += a[i] * b[i];
                i += 1;
            }

            result
        }
    }

    /// Norm squared using NEON
    #[inline]
    pub fn norm_squared_neon(a: &[f32]) -> f32 {
        let n = a.len();

        // SAFETY: NEON is always available on aarch64
        unsafe {
            let mut sum = vdupq_n_f32(0.0);
            let mut i = 0;

            while i + 4 <= n {
                let va = vld1q_f32(a.as_ptr().add(i));
                sum = vfmaq_f32(sum, va, va);
                i += 4;
            }

            let mut result = vaddvq_f32(sum);

            while i < n {
                result += a[i] * a[i];
                i += 1;
            }

            result
        }
    }
}

// =============================================================================
// Unified dispatch functions
// =============================================================================

/// Compute L2 squared distance with best available SIMD
#[inline]
pub fn l2_squared(a: &[f32], b: &[f32]) -> f32 {
    #[cfg(target_arch = "x86_64")]
    {
        if x86_simd::has_avx2() {
            return unsafe { x86_simd::l2_squared_avx2(a, b) };
        }
        if x86_simd::has_sse41() {
            return unsafe { x86_simd::l2_squared_sse41(a, b) };
        }
    }

    #[cfg(target_arch = "aarch64")]
    {
        return neon_simd::l2_squared_neon(a, b);
    }

    #[allow(unreachable_code)]
    l2_squared_scalar(a, b)
}

/// Compute dot product with best available SIMD
#[inline]
pub fn dot_product(a: &[f32], b: &[f32]) -> f32 {
    #[cfg(target_arch = "x86_64")]
    {
        if x86_simd::has_avx2() {
            return unsafe { x86_simd::dot_product_avx2(a, b) };
        }
        if x86_simd::has_sse41() {
            return unsafe { x86_simd::dot_product_sse41(a, b) };
        }
    }

    #[cfg(target_arch = "aarch64")]
    {
        return neon_simd::dot_product_neon(a, b);
    }

    #[allow(unreachable_code)]
    dot_product_scalar(a, b)
}

/// Compute squared norm with best available SIMD
#[inline]
pub fn norm_squared(a: &[f32]) -> f32 {
    #[cfg(target_arch = "x86_64")]
    {
        if x86_simd::has_avx2() {
            return unsafe { x86_simd::norm_squared_avx2(a) };
        }
    }

    #[cfg(target_arch = "aarch64")]
    {
        return neon_simd::norm_squared_neon(a);
    }

    #[allow(unreachable_code)]
    norm_squared_scalar(a)
}

/// Compute cosine distance with best available SIMD.
/// Returns 1 - cosine_similarity, so 0 = identical, 2 = opposite.
/// If either vector has zero norm, the distance is defined as 1.0.
#[inline]
pub fn cosine_distance(a: &[f32], b: &[f32]) -> f32 {
    let dot = dot_product(a, b);
    let norm_a = norm_squared(a).sqrt();
    let norm_b = norm_squared(b).sqrt();

    if norm_a == 0.0 || norm_b == 0.0 {
        return 1.0;
    }

    let cosine_sim = dot / (norm_a * norm_b);
    1.0 - cosine_sim.clamp(-1.0, 1.0)
}

// =============================================================================
// Distance trait implementations
// =============================================================================

impl Distance<f32> for SimdL2 {
    fn eval(&self, a: &[f32], b: &[f32]) -> f32 {
        l2_squared(a, b)
    }
}

impl Distance<f32> for SimdDot {
    fn eval(&self, a: &[f32], b: &[f32]) -> f32 {
        // For ANN, we want distance (lower = closer)
        // Assuming normalized vectors: distance = 1 - dot_product
        1.0 - dot_product(a, b)
    }
}

impl Distance<f32> for SimdCosine {
    fn eval(&self, a: &[f32], b: &[f32]) -> f32 {
        cosine_distance(a, b)
    }
}

// =============================================================================
// Runtime info
// =============================================================================

/// Returns information about SIMD capabilities
pub fn simd_info() -> SimdInfo {
    SimdInfo {
        #[cfg(target_arch = "x86_64")]
        avx2: x86_simd::has_avx2(),
        #[cfg(not(target_arch = "x86_64"))]
        avx2: false,

        #[cfg(target_arch = "x86_64")]
        sse41: x86_simd::has_sse41(),
        #[cfg(not(target_arch = "x86_64"))]
        sse41: false,

        #[cfg(target_arch = "aarch64")]
        neon: true,
        #[cfg(not(target_arch = "aarch64"))]
        neon: false,
    }
}

/// Information about available SIMD features
#[derive(Debug, Clone)]
pub struct SimdInfo {
    pub avx2: bool,
    pub sse41: bool,
    pub neon: bool,
}

impl std::fmt::Display for SimdInfo {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut features = Vec::new();
        if self.avx2 {
            features.push("AVX2");
        }
        if self.sse41 {
            features.push("SSE4.1");
        }
        if self.neon {
            features.push("NEON");
        }
        if features.is_empty() {
            write!(f, "SIMD: none (scalar fallback)")
        } else {
            write!(f, "SIMD: {}", features.join(", "))
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_l2_squared_basic() {
        let a = vec![1.0, 2.0, 3.0, 4.0];
        let b = vec![5.0, 6.0, 7.0, 8.0];

        let expected: f32 = a
            .iter()
            .zip(&b)
            .map(|(x, y)| (x - y) * (x - y))
            .sum();

        let result = l2_squared(&a, &b);
        assert!((result - expected).abs() < 1e-5, "expected {expected}, got {result}");
    }

    #[test]
    fn test_l2_squared_large() {
        // Test with dimension that requires multiple SIMD iterations + remainder
        let dim = 133; // Not divisible by 4 or 8
        let a: Vec<f32> = (0..dim).map(|i| i as f32).collect();
        let b: Vec<f32> = (0..dim).map(|i| (i * 2) as f32).collect();

        let expected = l2_squared_scalar(&a, &b);
        let result = l2_squared(&a, &b);

        assert!(
            (result - expected).abs() < 1e-3,
            "expected {expected}, got {result}"
        );
    }
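
    // Edge case: a dimension smaller than any SIMD width (4 or 8), so every
    // dispatched kernel falls through to its scalar remainder loop.
    #[test]
    fn test_l2_squared_small_dim() {
        let a = vec![1.0_f32, 2.0, 3.0];
        let b = vec![2.0_f32, 4.0, 6.0];

        let expected = l2_squared_scalar(&a, &b);
        let result = l2_squared(&a, &b);

        assert!(
            (result - expected).abs() < 1e-6,
            "expected {expected}, got {result}"
        );
    }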

    #[test]
    fn test_dot_product_basic() {
        let a = vec![1.0, 2.0, 3.0, 4.0];
        let b = vec![5.0, 6.0, 7.0, 8.0];

        let expected: f32 = a.iter().zip(&b).map(|(x, y)| x * y).sum();
        let result = dot_product(&a, &b);

        assert!((result - expected).abs() < 1e-5, "expected {expected}, got {result}");
    }

    #[test]
    fn test_dot_product_large() {
        let dim = 128;
        let a: Vec<f32> = (0..dim).map(|i| (i as f32) * 0.01).collect();
        let b: Vec<f32> = (0..dim).map(|i| (i as f32) * 0.02).collect();

        let expected = dot_product_scalar(&a, &b);
        let result = dot_product(&a, &b);

        assert!(
            (result - expected).abs() < 1e-3,
            "expected {expected}, got {result}"
        );
    }
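
    // `norm_squared` has no dedicated test above, so compare the dispatched
    // version against the scalar fallback; integer-valued inputs keep every
    // intermediate exactly representable in f32.
    #[test]
    fn test_norm_squared_matches_scalar() {
        let dim = 133; // Not divisible by 4 or 8
        let a: Vec<f32> = (0..dim).map(|i| i as f32).collect();

        let expected = norm_squared_scalar(&a);
        let result = norm_squared(&a);

        assert!(
            (result - expected).abs() < 1e-3,
            "expected {expected}, got {result}"
        );
    }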

    #[test]
    fn test_cosine_identical() {
        let a = vec![1.0, 2.0, 3.0, 4.0];
        let result = cosine_distance(&a, &a);
        assert!(result.abs() < 1e-5, "identical vectors should have distance ~0, got {result}");
    }

    #[test]
    fn test_cosine_orthogonal() {
        let a = vec![1.0, 0.0];
        let b = vec![0.0, 1.0];
        let result = cosine_distance(&a, &b);
        assert!((result - 1.0).abs() < 1e-5, "orthogonal vectors should have distance ~1, got {result}");
    }

    #[test]
    fn test_cosine_opposite() {
        let a = vec![1.0, 2.0, 3.0];
        let b: Vec<f32> = a.iter().map(|x| -x).collect();
        let result = cosine_distance(&a, &b);
        assert!((result - 2.0).abs() < 1e-5, "opposite vectors should have distance ~2, got {result}");
    }
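
    // For unit-norm vectors, `SimdDot` (1 - dot) and `SimdCosine` agree,
    // matching the normalization assumption noted in the `SimdDot` impl.
    #[test]
    fn test_dot_matches_cosine_for_normalized_vectors() {
        let raw = vec![3.0_f32, 4.0, 0.0, 12.0];
        let norm = norm_squared(&raw).sqrt();
        let a: Vec<f32> = raw.iter().map(|x| x / norm).collect();
        let b = vec![1.0_f32, 0.0, 0.0, 0.0];

        let dot_dist = SimdDot.eval(&a, &b);
        let cos_dist = SimdCosine.eval(&a, &b);

        assert!(
            (dot_dist - cos_dist).abs() < 1e-5,
            "expected {cos_dist}, got {dot_dist}"
        );
    }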

    #[test]
    fn test_simd_info() {
        let info = simd_info();
        println!("{}", info);
        // Just verify it doesn't panic
    }

    #[test]
    fn test_distance_trait_impl() {
        let a = vec![1.0, 2.0, 3.0, 4.0];
        let b = vec![5.0, 6.0, 7.0, 8.0];

        let l2 = SimdL2;
        let result = l2.eval(&a, &b);
        assert!(result > 0.0);

        let cosine = SimdCosine;
        let result = cosine.eval(&a, &b);
        assert!(result >= 0.0 && result <= 2.0);
    }
}