// ipfrs_semantic/simd.rs

//! SIMD-optimized distance computation
//!
//! This module provides platform-specific SIMD implementations for common
//! distance metrics used in vector search. Supports:
//! - ARM NEON (aarch64)
//! - x86 SSE/AVX/AVX2
//! - Scalar fallback for other platforms
//!
//! Distance metrics:
//! - L2 (Euclidean) distance
//! - Cosine similarity/distance
//! - Dot product
//!
//! The implementations use runtime feature detection to select the best
//! available instruction set for the current CPU.
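//!
//! # Example
//!
//! A minimal usage sketch; the `ipfrs_semantic::simd` path is assumed from the
//! file location and may differ in the actual crate layout:
//!
//! ```ignore
//! use ipfrs_semantic::simd::{cosine_distance, dot_product, l2_distance};
//!
//! let a = vec![1.0_f32, 2.0, 3.0, 4.0];
//! let b = vec![4.0_f32, 3.0, 2.0, 1.0];
//!
//! // The best available instruction set is picked at runtime; callers just
//! // invoke the plain functions.
//! let l2 = l2_distance(&a, &b);
//! let dot = dot_product(&a, &b);
//! let cos = cosine_distance(&a, &b);
//! assert!(l2 > 0.0 && dot > 0.0 && cos > 0.0);
//! ```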

#[cfg(target_arch = "aarch64")]
use std::arch::is_aarch64_feature_detected;
#[cfg(target_arch = "x86_64")]
use std::arch::is_x86_feature_detected;

/// Compute L2 (Euclidean) distance between two vectors using SIMD
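///
/// # Example
///
/// A minimal usage sketch (the module path is assumed; see the module docs):
///
/// ```ignore
/// use ipfrs_semantic::simd::l2_distance;
///
/// let a = [0.0_f32, 0.0, 0.0];
/// let b = [3.0_f32, 4.0, 0.0];
/// // 3-4-5 right triangle: the Euclidean distance is 5.0
/// assert!((l2_distance(&a, &b) - 5.0).abs() < 1e-6);
/// ```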
#[inline]
pub fn l2_distance(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len(), "Vector dimensions must match");

    #[cfg(target_arch = "aarch64")]
    {
        if is_aarch64_feature_detected!("neon") {
            return unsafe { l2_distance_neon(a, b) };
        }
    }

    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("avx2") {
            return unsafe { l2_distance_avx2(a, b) };
        }
        if is_x86_feature_detected!("avx") {
            return unsafe { l2_distance_avx(a, b) };
        }
        if is_x86_feature_detected!("sse") {
            return unsafe { l2_distance_sse(a, b) };
        }
    }

    // Fallback to scalar implementation
    l2_distance_scalar(a, b)
}

/// Compute dot product between two vectors using SIMD
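///
/// # Example
///
/// A minimal usage sketch (the module path is assumed; see the module docs):
///
/// ```ignore
/// use ipfrs_semantic::simd::dot_product;
///
/// let a = [1.0_f32, 2.0, 3.0];
/// let b = [4.0_f32, 5.0, 6.0];
/// // 1*4 + 2*5 + 3*6 = 32
/// assert!((dot_product(&a, &b) - 32.0).abs() < 1e-6);
/// ```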
#[inline]
pub fn dot_product(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len(), "Vector dimensions must match");

    #[cfg(target_arch = "aarch64")]
    {
        if is_aarch64_feature_detected!("neon") {
            return unsafe { dot_product_neon(a, b) };
        }
    }

    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("avx2") {
            return unsafe { dot_product_avx2(a, b) };
        }
        if is_x86_feature_detected!("avx") {
            return unsafe { dot_product_avx(a, b) };
        }
        if is_x86_feature_detected!("sse") {
            return unsafe { dot_product_sse(a, b) };
        }
    }

    // Fallback to scalar implementation
    dot_product_scalar(a, b)
}

/// Compute cosine distance (1 - cosine similarity) using SIMD
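///
/// Returns `NaN` if either input has zero norm (e.g. an all-zero vector).
///
/// # Example
///
/// A minimal usage sketch (the module path is assumed; see the module docs):
///
/// ```ignore
/// use ipfrs_semantic::simd::cosine_distance;
///
/// let a = [1.0_f32, 0.0];
/// let b = [0.0_f32, 1.0];
/// // Orthogonal vectors: cosine similarity 0, cosine distance 1
/// assert!((cosine_distance(&a, &b) - 1.0).abs() < 1e-6);
/// ```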
#[inline]
pub fn cosine_distance(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len(), "Vector dimensions must match");

    #[cfg(target_arch = "aarch64")]
    {
        if is_aarch64_feature_detected!("neon") {
            return unsafe { cosine_distance_neon(a, b) };
        }
    }

    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("avx2") {
            return unsafe { cosine_distance_avx2(a, b) };
        }
        if is_x86_feature_detected!("avx") {
            return unsafe { cosine_distance_avx(a, b) };
        }
        if is_x86_feature_detected!("sse") {
            return unsafe { cosine_distance_sse(a, b) };
        }
    }

    // Fallback to scalar implementation
    cosine_distance_scalar(a, b)
}

// ============================================================================
// ARM NEON implementations
// ============================================================================

#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
unsafe fn l2_distance_neon(a: &[f32], b: &[f32]) -> f32 {
    use std::arch::aarch64::*;

    let len = a.len();
    let mut sum = vdupq_n_f32(0.0);
    let mut i = 0;

    // Process 4 floats at a time
    while i + 4 <= len {
        let va = vld1q_f32(a.as_ptr().add(i));
        let vb = vld1q_f32(b.as_ptr().add(i));
        let diff = vsubq_f32(va, vb);
        sum = vfmaq_f32(sum, diff, diff); // sum += diff * diff
        i += 4;
    }

    // Horizontal sum
    let mut result = vaddvq_f32(sum);

    // Handle remaining elements
    while i < len {
        let diff = a[i] - b[i];
        result += diff * diff;
        i += 1;
    }

    result.sqrt()
}

#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
unsafe fn dot_product_neon(a: &[f32], b: &[f32]) -> f32 {
    use std::arch::aarch64::*;

    let len = a.len();
    let mut sum = vdupq_n_f32(0.0);
    let mut i = 0;

    // Process 4 floats at a time
    while i + 4 <= len {
        let va = vld1q_f32(a.as_ptr().add(i));
        let vb = vld1q_f32(b.as_ptr().add(i));
        sum = vfmaq_f32(sum, va, vb); // sum += va * vb
        i += 4;
    }

    // Horizontal sum
    let mut result = vaddvq_f32(sum);

    // Handle remaining elements
    while i < len {
        result += a[i] * b[i];
        i += 1;
    }

    result
}

#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
unsafe fn cosine_distance_neon(a: &[f32], b: &[f32]) -> f32 {
    use std::arch::aarch64::*;

    let len = a.len();
    let mut dot = vdupq_n_f32(0.0);
    let mut norm_a = vdupq_n_f32(0.0);
    let mut norm_b = vdupq_n_f32(0.0);
    let mut i = 0;

    // Process 4 floats at a time
    while i + 4 <= len {
        let va = vld1q_f32(a.as_ptr().add(i));
        let vb = vld1q_f32(b.as_ptr().add(i));
        dot = vfmaq_f32(dot, va, vb);
        norm_a = vfmaq_f32(norm_a, va, va);
        norm_b = vfmaq_f32(norm_b, vb, vb);
        i += 4;
    }

    // Horizontal sum
    let mut dot_sum = vaddvq_f32(dot);
    let mut norm_a_sum = vaddvq_f32(norm_a);
    let mut norm_b_sum = vaddvq_f32(norm_b);

    // Handle remaining elements
    while i < len {
        dot_sum += a[i] * b[i];
        norm_a_sum += a[i] * a[i];
        norm_b_sum += b[i] * b[i];
        i += 1;
    }

    let similarity = dot_sum / (norm_a_sum.sqrt() * norm_b_sum.sqrt());
    1.0 - similarity
}

// ============================================================================
// x86 SSE implementations
// ============================================================================

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "sse")]
unsafe fn l2_distance_sse(a: &[f32], b: &[f32]) -> f32 {
    use std::arch::x86_64::*;

    let len = a.len();
    let mut sum = _mm_setzero_ps();
    let mut i = 0;

    // Process 4 floats at a time
    while i + 4 <= len {
        let va = _mm_loadu_ps(a.as_ptr().add(i));
        let vb = _mm_loadu_ps(b.as_ptr().add(i));
        let diff = _mm_sub_ps(va, vb);
        sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff));
        i += 4;
    }

    // Horizontal sum
    let mut result = horizontal_sum_sse(sum);

    // Handle remaining elements
    while i < len {
        let diff = a[i] - b[i];
        result += diff * diff;
        i += 1;
    }

    result.sqrt()
}

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "sse")]
unsafe fn dot_product_sse(a: &[f32], b: &[f32]) -> f32 {
    use std::arch::x86_64::*;

    let len = a.len();
    let mut sum = _mm_setzero_ps();
    let mut i = 0;

    // Process 4 floats at a time
    while i + 4 <= len {
        let va = _mm_loadu_ps(a.as_ptr().add(i));
        let vb = _mm_loadu_ps(b.as_ptr().add(i));
        sum = _mm_add_ps(sum, _mm_mul_ps(va, vb));
        i += 4;
    }

    // Horizontal sum
    let mut result = horizontal_sum_sse(sum);

    // Handle remaining elements
    while i < len {
        result += a[i] * b[i];
        i += 1;
    }

    result
}

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "sse")]
unsafe fn cosine_distance_sse(a: &[f32], b: &[f32]) -> f32 {
    use std::arch::x86_64::*;

    let len = a.len();
    let mut dot = _mm_setzero_ps();
    let mut norm_a = _mm_setzero_ps();
    let mut norm_b = _mm_setzero_ps();
    let mut i = 0;

    // Process 4 floats at a time
    while i + 4 <= len {
        let va = _mm_loadu_ps(a.as_ptr().add(i));
        let vb = _mm_loadu_ps(b.as_ptr().add(i));
        dot = _mm_add_ps(dot, _mm_mul_ps(va, vb));
        norm_a = _mm_add_ps(norm_a, _mm_mul_ps(va, va));
        norm_b = _mm_add_ps(norm_b, _mm_mul_ps(vb, vb));
        i += 4;
    }

    // Horizontal sum
    let mut dot_sum = horizontal_sum_sse(dot);
    let mut norm_a_sum = horizontal_sum_sse(norm_a);
    let mut norm_b_sum = horizontal_sum_sse(norm_b);

    // Handle remaining elements
    while i < len {
        dot_sum += a[i] * b[i];
        norm_a_sum += a[i] * a[i];
        norm_b_sum += b[i] * b[i];
        i += 1;
    }

    let similarity = dot_sum / (norm_a_sum.sqrt() * norm_b_sum.sqrt());
    1.0 - similarity
}

#[cfg(target_arch = "x86_64")]
#[inline]
unsafe fn horizontal_sum_sse(v: std::arch::x86_64::__m128) -> f32 {
    use std::arch::x86_64::*;

    // Uses only SSE instructions (`_mm_movehdup_ps` would require SSE3), so
    // the plain "sse" runtime check in the callers is sufficient.
    let shuf = _mm_shuffle_ps::<0b10_11_00_01>(v, v); // [v1, v0, v3, v2]
    let sums = _mm_add_ps(v, shuf); // [v0+v1, v0+v1, v2+v3, v2+v3]
    let shuf = _mm_movehl_ps(shuf, sums); // high half of `sums` into the low lanes
    let result = _mm_add_ss(sums, shuf); // lane 0 = (v0+v1) + (v2+v3)
    _mm_cvtss_f32(result)
}

// ============================================================================
// x86 AVX implementations
// ============================================================================

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx")]
unsafe fn l2_distance_avx(a: &[f32], b: &[f32]) -> f32 {
    use std::arch::x86_64::*;

    let len = a.len();
    let mut sum = _mm256_setzero_ps();
    let mut i = 0;

    // Process 8 floats at a time
    while i + 8 <= len {
        let va = _mm256_loadu_ps(a.as_ptr().add(i));
        let vb = _mm256_loadu_ps(b.as_ptr().add(i));
        let diff = _mm256_sub_ps(va, vb);
        sum = _mm256_add_ps(sum, _mm256_mul_ps(diff, diff));
        i += 8;
    }

    // Horizontal sum
    let mut result = horizontal_sum_avx(sum);

    // Handle remaining elements
    while i < len {
        let diff = a[i] - b[i];
        result += diff * diff;
        i += 1;
    }

    result.sqrt()
}

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx")]
unsafe fn dot_product_avx(a: &[f32], b: &[f32]) -> f32 {
    use std::arch::x86_64::*;

    let len = a.len();
    let mut sum = _mm256_setzero_ps();
    let mut i = 0;

    // Process 8 floats at a time
    while i + 8 <= len {
        let va = _mm256_loadu_ps(a.as_ptr().add(i));
        let vb = _mm256_loadu_ps(b.as_ptr().add(i));
        sum = _mm256_add_ps(sum, _mm256_mul_ps(va, vb));
        i += 8;
    }

    // Horizontal sum
    let mut result = horizontal_sum_avx(sum);

    // Handle remaining elements
    while i < len {
        result += a[i] * b[i];
        i += 1;
    }

    result
}

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx")]
unsafe fn cosine_distance_avx(a: &[f32], b: &[f32]) -> f32 {
    use std::arch::x86_64::*;

    let len = a.len();
    let mut dot = _mm256_setzero_ps();
    let mut norm_a = _mm256_setzero_ps();
    let mut norm_b = _mm256_setzero_ps();
    let mut i = 0;

    // Process 8 floats at a time
    while i + 8 <= len {
        let va = _mm256_loadu_ps(a.as_ptr().add(i));
        let vb = _mm256_loadu_ps(b.as_ptr().add(i));
        dot = _mm256_add_ps(dot, _mm256_mul_ps(va, vb));
        norm_a = _mm256_add_ps(norm_a, _mm256_mul_ps(va, va));
        norm_b = _mm256_add_ps(norm_b, _mm256_mul_ps(vb, vb));
        i += 8;
    }

    // Horizontal sum
    let mut dot_sum = horizontal_sum_avx(dot);
    let mut norm_a_sum = horizontal_sum_avx(norm_a);
    let mut norm_b_sum = horizontal_sum_avx(norm_b);

    // Handle remaining elements
    while i < len {
        dot_sum += a[i] * b[i];
        norm_a_sum += a[i] * a[i];
        norm_b_sum += b[i] * b[i];
        i += 1;
    }

    let similarity = dot_sum / (norm_a_sum.sqrt() * norm_b_sum.sqrt());
    1.0 - similarity
}

#[cfg(target_arch = "x86_64")]
#[inline]
unsafe fn horizontal_sum_avx(v: std::arch::x86_64::__m256) -> f32 {
    use std::arch::x86_64::*;

    // Add the upper and lower 128-bit halves, then reduce with the SSE helper.
    let hi = _mm256_extractf128_ps::<1>(v);
    let lo = _mm256_castps256_ps128(v);
    let sum128 = _mm_add_ps(hi, lo);
    horizontal_sum_sse(sum128)
}

// ============================================================================
// x86 AVX2 implementations
// ============================================================================

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn l2_distance_avx2(a: &[f32], b: &[f32]) -> f32 {
    // AVX2 mainly adds integer operations; these f32 kernels gain nothing
    // over AVX, so reuse the AVX path.
    l2_distance_avx(a, b)
}

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn dot_product_avx2(a: &[f32], b: &[f32]) -> f32 {
    // AVX2 mainly adds integer operations; these f32 kernels gain nothing
    // over AVX, so reuse the AVX path.
    dot_product_avx(a, b)
}

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn cosine_distance_avx2(a: &[f32], b: &[f32]) -> f32 {
    // AVX2 mainly adds integer operations; these f32 kernels gain nothing
    // over AVX, so reuse the AVX path.
    cosine_distance_avx(a, b)
}

// ============================================================================
// Scalar fallback implementations
// ============================================================================

#[inline]
fn l2_distance_scalar(a: &[f32], b: &[f32]) -> f32 {
    a.iter()
        .zip(b.iter())
        .map(|(x, y)| {
            let diff = x - y;
            diff * diff
        })
        .sum::<f32>()
        .sqrt()
}

#[inline]
fn dot_product_scalar(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b.iter()).map(|(x, y)| x * y).sum()
}

#[inline]
fn cosine_distance_scalar(a: &[f32], b: &[f32]) -> f32 {
    let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
    let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
    let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
    1.0 - (dot / (norm_a * norm_b))
}

// ============================================================================
// Tests
// ============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_l2_distance() {
        let a = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let b = vec![2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];

        let dist = l2_distance(&a, &b);
        let expected = (8.0_f32).sqrt(); // sqrt(8 * 1^2)

        assert!((dist - expected).abs() < 1e-5, "L2 distance mismatch");
    }

    #[test]
    fn test_dot_product() {
        let a = vec![1.0, 2.0, 3.0, 4.0];
        let b = vec![5.0, 6.0, 7.0, 8.0];

        let dot = dot_product(&a, &b);
        let expected = 1.0 * 5.0 + 2.0 * 6.0 + 3.0 * 7.0 + 4.0 * 8.0;

        assert!((dot - expected).abs() < 1e-5, "Dot product mismatch");
    }

    #[test]
    fn test_cosine_distance() {
        let a = vec![1.0, 0.0, 0.0, 0.0];
        let b = vec![1.0, 0.0, 0.0, 0.0];

        let dist = cosine_distance(&a, &b);

        // Same vectors should have cosine distance close to 0
        assert!(
            dist.abs() < 1e-5,
            "Cosine distance should be 0 for identical vectors"
        );
    }

    #[test]
    fn test_cosine_distance_orthogonal() {
        let a = vec![1.0, 0.0, 0.0, 0.0];
        let b = vec![0.0, 1.0, 0.0, 0.0];

        let dist = cosine_distance(&a, &b);

        // Orthogonal vectors should have cosine distance close to 1
        assert!(
            (dist - 1.0).abs() < 1e-5,
            "Cosine distance should be 1 for orthogonal vectors"
        );
    }

    #[test]
    fn test_simd_vs_scalar_l2() {
        let a: Vec<f32> = (0..128).map(|i| i as f32 * 0.1).collect();
        let b: Vec<f32> = (0..128).map(|i| (i as f32 + 1.0) * 0.1).collect();

        let simd_result = l2_distance(&a, &b);
        let scalar_result = l2_distance_scalar(&a, &b);

        assert!(
            (simd_result - scalar_result).abs() < 1e-4,
            "SIMD and scalar L2 results should match"
        );
    }

    #[test]
    fn test_simd_vs_scalar_dot() {
        let a: Vec<f32> = (0..128).map(|i| i as f32 * 0.1).collect();
        let b: Vec<f32> = (0..128).map(|i| (i as f32 + 1.0) * 0.1).collect();

        let simd_result = dot_product(&a, &b);
        let scalar_result = dot_product_scalar(&a, &b);

        assert!(
            (simd_result - scalar_result).abs() < 1e-3,
            "SIMD and scalar dot product results should match"
        );
    }

    #[test]
    fn test_simd_vs_scalar_cosine() {
        let a: Vec<f32> = (0..128).map(|i| (i as f32 * 0.1) + 1.0).collect();
        let b: Vec<f32> = (0..128).map(|i| ((i as f32 + 1.0) * 0.1) + 1.0).collect();

        let simd_result = cosine_distance(&a, &b);
        let scalar_result = cosine_distance_scalar(&a, &b);

        assert!(
            (simd_result - scalar_result).abs() < 1e-4,
            "SIMD and scalar cosine results should match"
        );
    }
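
    #[test]
    fn test_simd_vs_scalar_odd_length() {
        // A length that is not a multiple of the 4- or 8-lane SIMD width
        // forces the scalar remainder loops to run; the results should still
        // agree with the pure scalar reference implementations.
        let a: Vec<f32> = (0..13).map(|i| i as f32 * 0.3 - 1.0).collect();
        let b: Vec<f32> = (0..13).map(|i| (i as f32 + 0.5) * 0.2).collect();

        assert!((l2_distance(&a, &b) - l2_distance_scalar(&a, &b)).abs() < 1e-4);
        assert!((dot_product(&a, &b) - dot_product_scalar(&a, &b)).abs() < 1e-4);
        assert!((cosine_distance(&a, &b) - cosine_distance_scalar(&a, &b)).abs() < 1e-4);
    }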
}