alopex_core/vector/
simd.rs

1//! SIMD ベースの距離カーネル群。
2//!
3//! - DistanceKernel: 共通インターフェイス(Send + Sync)
4//! - ScalarKernel: 参照実装(unsafe なし)
5//! - Avx2Kernel: x86_64 AVX2 実装(条件コンパイル)
6//! - NeonKernel: aarch64 NEON 実装(条件コンパイル)
7//! - select_kernel: 実行時に最適カーネルを選択
8
9use crate::vector::Metric;
10
/// Common interface shared by all distance kernels in this module.
///
/// Implementors must be `Send + Sync` so a kernel can be shared across threads.
pub trait DistanceKernel: Send + Sync {
    /// Cosine similarity between `query` and `vector`.
    ///
    /// The kernels in this file compute `dot / (|query| * |vector|)` and return
    /// `0.0` when either norm is zero. Note that the quotient is negative for
    /// vectors pointing in opposing directions.
    fn cosine(&self, query: &[f32], vector: &[f32]) -> f32;
    /// Negated L2 (Euclidean) distance, so a larger value means a closer vector.
    fn l2(&self, query: &[f32], vector: &[f32]) -> f32;
    /// Inner (dot) product of `query` and `vector`.
    fn inner_product(&self, query: &[f32], vector: &[f32]) -> f32;
    /// Batch scoring. `vectors` is a contiguous array of `dimension * n` values,
    /// i.e. `n` candidate vectors laid out back to back; the score of the
    /// i-th candidate under `metric` is written to `scores[i]`.
    fn batch_score(
        &self,
        metric: Metric,
        query: &[f32],
        vectors: &[f32],
        dimension: usize,
        scores: &mut [f32],
    );
}
29
/// Scalar kernel: the reference implementation, written without `unsafe`.
#[derive(Debug, Default)]
pub struct ScalarKernel;

impl ScalarKernel {
    /// Dot product over the common prefix of `query` and `vector`
    /// (the pairing stops at the end of the shorter slice).
    #[inline]
    fn dot(query: &[f32], vector: &[f32]) -> f32 {
        let products = query.iter().zip(vector).map(|(&q, &x)| q * x);
        products.sum()
    }

    /// Euclidean norm of `v`: the square root of the sum of squared components.
    #[inline]
    fn norm(v: &[f32]) -> f32 {
        let squared_sum: f32 = v.iter().map(|&x| x * x).sum();
        squared_sum.sqrt()
    }
}
49
50impl DistanceKernel for ScalarKernel {
51    fn cosine(&self, query: &[f32], vector: &[f32]) -> f32 {
52        let dot = Self::dot(query, vector);
53        let q_norm = Self::norm(query);
54        let v_norm = Self::norm(vector);
55        if q_norm == 0.0 || v_norm == 0.0 {
56            0.0
57        } else {
58            dot / (q_norm * v_norm)
59        }
60    }
61
62    fn l2(&self, query: &[f32], vector: &[f32]) -> f32 {
63        let dist = query
64            .iter()
65            .zip(vector.iter())
66            .map(|(a, b)| {
67                let d = a - b;
68                d * d
69            })
70            .sum::<f32>()
71            .sqrt();
72        -dist
73    }
74
75    fn inner_product(&self, query: &[f32], vector: &[f32]) -> f32 {
76        Self::dot(query, vector)
77    }
78
79    fn batch_score(
80        &self,
81        metric: Metric,
82        query: &[f32],
83        vectors: &[f32],
84        dimension: usize,
85        scores: &mut [f32],
86    ) {
87        for (i, chunk) in vectors.chunks(dimension).enumerate() {
88            if i >= scores.len() {
89                break;
90            }
91            scores[i] = match metric {
92                Metric::Cosine => self.cosine(query, chunk),
93                Metric::L2 => self.l2(query, chunk),
94                Metric::InnerProduct => self.inner_product(query, chunk),
95            };
96        }
97    }
98}
99
100// ============================================================================
101// AVX2 実装 (x86_64)
102// ============================================================================
103
104#[cfg(target_arch = "x86_64")]
105mod avx2 {
106    use super::{DistanceKernel, Metric};
107    use std::arch::x86_64::*;
108
109    #[derive(Debug, Default)]
110    pub struct Avx2Kernel;
111
112    #[inline]
113    fn horizontal_sum_ps(v: __m256) -> f32 {
114        unsafe {
115            // Reduce 8 lanes -> 4 -> 2 -> 1 using horizontal add.
116            let lo = _mm256_castps256_ps128(v);
117            let hi = _mm256_extractf128_ps(v, 1);
118            let sum128 = _mm_add_ps(lo, hi); // [a0+a4, a1+a5, a2+a6, a3+a7]
119            let sum64 = _mm_hadd_ps(sum128, sum128); // [a0+a4+a1+a5, a2+a6+a3+a7, ...]
120            let sum32 = _mm_hadd_ps(sum64, sum64); // [total, total, ...]
121            _mm_cvtss_f32(sum32)
122        }
123    }
124
125    impl Avx2Kernel {
126        #[inline]
127        unsafe fn dot(query: &[f32], vector: &[f32]) -> f32 {
128            let mut acc = _mm256_setzero_ps();
129            let mut i = 0;
130            while i + 8 <= query.len() {
131                let q = _mm256_loadu_ps(query.as_ptr().add(i));
132                let v = _mm256_loadu_ps(vector.as_ptr().add(i));
133                // AVX2+FMA 前提: FMA 非搭載AVX2はサポート対象外(設計方針)。
134                acc = _mm256_fmadd_ps(q, v, acc);
135                i += 8;
136            }
137            let mut sum = horizontal_sum_ps(acc);
138            for j in i..query.len() {
139                sum += *query.get_unchecked(j) * *vector.get_unchecked(j);
140            }
141            sum
142        }
143
144        #[inline]
145        unsafe fn norm(v: &[f32]) -> f32 {
146            let mut acc = _mm256_setzero_ps();
147            let mut i = 0;
148            while i + 8 <= v.len() {
149                let x = _mm256_loadu_ps(v.as_ptr().add(i));
150                acc = _mm256_fmadd_ps(x, x, acc);
151                i += 8;
152            }
153            let mut sum = horizontal_sum_ps(acc);
154            for j in i..v.len() {
155                let x = *v.get_unchecked(j);
156                sum += x * x;
157            }
158            sum.sqrt()
159        }
160
161        #[inline]
162        unsafe fn cosine_impl(&self, query: &[f32], vector: &[f32]) -> f32 {
163            let dot = Self::dot(query, vector);
164            let q_norm = Self::norm(query);
165            let v_norm = Self::norm(vector);
166            if q_norm == 0.0 || v_norm == 0.0 {
167                0.0
168            } else {
169                dot / (q_norm * v_norm)
170            }
171        }
172
173        #[inline]
174        unsafe fn l2_impl(&self, query: &[f32], vector: &[f32]) -> f32 {
175            let mut acc = _mm256_setzero_ps();
176            let mut i = 0;
177            while i + 8 <= query.len() {
178                let q = _mm256_loadu_ps(query.as_ptr().add(i));
179                let v = _mm256_loadu_ps(vector.as_ptr().add(i));
180                let diff = _mm256_sub_ps(q, v);
181                acc = _mm256_fmadd_ps(diff, diff, acc);
182                i += 8;
183            }
184            let mut sum = horizontal_sum_ps(acc);
185            for j in i..query.len() {
186                let d = *query.get_unchecked(j) - *vector.get_unchecked(j);
187                sum += d * d;
188            }
189            -sum.sqrt()
190        }
191    }
192
193    impl DistanceKernel for Avx2Kernel {
194        fn cosine(&self, query: &[f32], vector: &[f32]) -> f32 {
195            unsafe { self.cosine_impl(query, vector) }
196        }
197
198        fn l2(&self, query: &[f32], vector: &[f32]) -> f32 {
199            unsafe { self.l2_impl(query, vector) }
200        }
201
202        fn inner_product(&self, query: &[f32], vector: &[f32]) -> f32 {
203            unsafe { Self::dot(query, vector) }
204        }
205
206        fn batch_score(
207            &self,
208            metric: Metric,
209            query: &[f32],
210            vectors: &[f32],
211            dimension: usize,
212            scores: &mut [f32],
213        ) {
214            for (i, chunk) in vectors.chunks(dimension).enumerate() {
215                if i >= scores.len() {
216                    break;
217                }
218                scores[i] = match metric {
219                    Metric::Cosine => unsafe { self.cosine_impl(query, chunk) },
220                    Metric::L2 => unsafe { self.l2_impl(query, chunk) },
221                    Metric::InnerProduct => unsafe { Self::dot(query, chunk) },
222                };
223            }
224        }
225    }
226
227    pub fn create() -> Box<dyn DistanceKernel> {
228        Box::new(Avx2Kernel)
229    }
230
231    #[cfg(all(test, not(target_arch = "wasm32")))]
232    mod tests {
233        use super::*;
234
235        #[test]
236        fn horizontal_sum_correct_for_ones() {
237            if !std::is_x86_feature_detected!("avx2") {
238                return;
239            }
240            unsafe {
241                let v = _mm256_set1_ps(1.0);
242                let total = horizontal_sum_ps(v);
243                assert!((total - 8.0).abs() < 1e-6);
244            }
245        }
246    }
247}
248
249// ============================================================================
250// NEON 実装 (aarch64)
251// ============================================================================
252
#[cfg(target_arch = "aarch64")]
mod neon {
    use super::{DistanceKernel, Metric};
    use core::arch::aarch64::*;

    /// Distance kernel built on aarch64 NEON intrinsics.
    ///
    /// Creating a value performs no CPU check; it must only be instantiated
    /// where NEON is available.
    #[derive(Debug, Default)]
    pub struct NeonKernel;

    /// Adds up the four `f32` lanes of `v`.
    ///
    /// # Safety
    ///
    /// The executing CPU must support NEON.
    #[inline]
    unsafe fn horizontal_sum(v: float32x4_t) -> f32 {
        let pair_sum = vadd_f32(vget_low_f32(v), vget_high_f32(v));
        let sum = vpadd_f32(pair_sum, pair_sum);
        vget_lane_f32(sum, 0)
    }

    /// Dot product over the common prefix of `query` and `vector`.
    ///
    /// The length is clamped to the shorter slice, mirroring the `zip`
    /// truncation of the scalar kernel, so no load can run past either buffer.
    ///
    /// # Safety
    ///
    /// The executing CPU must support NEON.
    #[inline]
    unsafe fn dot(query: &[f32], vector: &[f32]) -> f32 {
        let len = query.len().min(vector.len());
        let mut acc = vdupq_n_f32(0.0);
        let mut i = 0;
        while i + 4 <= len {
            // In bounds: i + 4 <= len <= query.len() and len <= vector.len().
            let q = vld1q_f32(query.as_ptr().add(i));
            let v = vld1q_f32(vector.as_ptr().add(i));
            acc = vfmaq_f32(acc, q, v);
            i += 4;
        }
        let mut sum = horizontal_sum(acc);
        // Scalar tail for the remaining (< 4) elements.
        for (q, v) in query[i..len].iter().zip(&vector[i..len]) {
            sum += q * v;
        }
        sum
    }

    /// Euclidean norm of `v`.
    ///
    /// # Safety
    ///
    /// The executing CPU must support NEON.
    #[inline]
    unsafe fn norm(v: &[f32]) -> f32 {
        let mut acc = vdupq_n_f32(0.0);
        let mut i = 0;
        while i + 4 <= v.len() {
            // In bounds: i + 4 <= v.len().
            let x = vld1q_f32(v.as_ptr().add(i));
            acc = vfmaq_f32(acc, x, x);
            i += 4;
        }
        let mut sum = horizontal_sum(acc);
        for x in &v[i..] {
            sum += x * x;
        }
        sum.sqrt()
    }

    impl DistanceKernel for NeonKernel {
        /// Cosine similarity; yields `0.0` when either norm is zero.
        fn cosine(&self, query: &[f32], vector: &[f32]) -> f32 {
            // SAFETY: this kernel is only meant to exist where NEON is
            // available, and the helpers clamp every access to slice bounds.
            unsafe {
                let dot = dot(query, vector);
                let q_norm = norm(query);
                let v_norm = norm(vector);
                if q_norm == 0.0 || v_norm == 0.0 {
                    0.0
                } else {
                    dot / (q_norm * v_norm)
                }
            }
        }

        /// Negated Euclidean distance over the common prefix of the inputs.
        fn l2(&self, query: &[f32], vector: &[f32]) -> f32 {
            let len = query.len().min(vector.len());
            // SAFETY: NEON availability as in `cosine`; every vector load stays
            // inside both slices because i + 4 <= len <= each slice length.
            unsafe {
                let mut acc = vdupq_n_f32(0.0);
                let mut i = 0;
                while i + 4 <= len {
                    let q = vld1q_f32(query.as_ptr().add(i));
                    let v = vld1q_f32(vector.as_ptr().add(i));
                    let diff = vsubq_f32(q, v);
                    acc = vfmaq_f32(acc, diff, diff);
                    i += 4;
                }
                let mut sum = horizontal_sum(acc);
                for (q, v) in query[i..len].iter().zip(&vector[i..len]) {
                    let d = q - v;
                    sum += d * d;
                }
                -sum.sqrt()
            }
        }

        fn inner_product(&self, query: &[f32], vector: &[f32]) -> f32 {
            // SAFETY: see `cosine`.
            unsafe { dot(query, vector) }
        }

        /// Scores each `dimension`-sized chunk of `vectors` into `scores`.
        ///
        /// Stops at whichever runs out first, chunks or score slots. A
        /// `dimension` of zero describes no vectors, so nothing is written.
        fn batch_score(
            &self,
            metric: Metric,
            query: &[f32],
            vectors: &[f32],
            dimension: usize,
            scores: &mut [f32],
        ) {
            // `slice::chunks` panics on a chunk size of zero.
            if dimension == 0 {
                return;
            }
            for (slot, chunk) in scores.iter_mut().zip(vectors.chunks(dimension)) {
                *slot = match metric {
                    Metric::Cosine => self.cosine(query, chunk),
                    Metric::L2 => self.l2(query, chunk),
                    Metric::InnerProduct => self.inner_product(query, chunk),
                };
            }
        }
    }

    /// Boxes a new [`NeonKernel`].
    ///
    /// The caller is responsible for having confirmed NEON support.
    pub fn create() -> Box<dyn DistanceKernel> {
        Box::new(NeonKernel)
    }
}
365
366/// 実行時に最適なカーネルを選択する。
367pub fn select_kernel() -> Box<dyn DistanceKernel> {
368    #[cfg(target_arch = "x86_64")]
369    {
370        if std::is_x86_feature_detected!("avx2") {
371            return avx2::create();
372        }
373    }
374
375    #[cfg(target_arch = "aarch64")]
376    {
377        if std::arch::is_aarch64_feature_detected!("neon") {
378            return neon::create();
379        }
380    }
381
382    Box::new(ScalarKernel)
383}
384
#[cfg(all(test, not(target_arch = "wasm32")))]
mod tests {
    use super::*;
    use crate::vector::score;

    /// Every metric exercised by the comparison tests.
    const ALL_METRICS: [Metric; 3] = [Metric::Cosine, Metric::L2, Metric::InnerProduct];

    /// Routes `metric` to the matching single-pair method of `kernel`.
    fn eval(kernel: &dyn DistanceKernel, metric: Metric, q: &[f32], v: &[f32]) -> f32 {
        match metric {
            Metric::Cosine => kernel.cosine(q, v),
            Metric::L2 => kernel.l2(q, v),
            Metric::InnerProduct => kernel.inner_product(q, v),
        }
    }

    /// Treats two values as equal when both are NaN, both are infinities of
    /// the same sign, or they differ by less than `1e-5`.
    fn assert_same_f32(a: f32, b: f32) {
        let both_nan = a.is_nan() && b.is_nan();
        if both_nan {
            return;
        }
        let both_infinite = a.is_infinite() && b.is_infinite();
        if both_infinite {
            assert_eq!(a.is_sign_positive(), b.is_sign_positive());
            return;
        }
        assert!((a - b).abs() < 1e-5, "a={a}, b={b}");
    }

    #[test]
    fn scalar_matches_reference() {
        let k = ScalarKernel;
        let q = [1.0, 2.0, 3.0, 4.0];
        let v = [4.0, 3.0, 2.0, 1.0];
        for &m in &ALL_METRICS {
            let ref_score = score(m, &q, &v).unwrap();
            let k_score = eval(&k, m, &q, &v);
            assert!((ref_score - k_score).abs() < 1e-6);
        }
    }

    #[test]
    fn scalar_cosine_zero_norm_returns_zero() {
        let zero = [0.0, 0.0, 0.0];
        let other = [1.0, 2.0, 3.0];
        assert_eq!(ScalarKernel.cosine(&zero, &other), 0.0);
    }

    #[test]
    fn batch_score_populates_all() {
        let q = [1.0, 0.0];
        let vectors = [1.0, 0.0, 0.0, 1.0];
        let mut scores = [0.0f32; 2];
        ScalarKernel.batch_score(Metric::InnerProduct, &q, &vectors, 2, &mut scores);
        assert_eq!(scores[0], 1.0);
        assert_eq!(scores[1], 0.0);
    }

    #[test]
    fn select_kernel_returns_any() {
        let q = [1.0, 2.0];
        let v = [2.0, 1.0];
        let s = select_kernel().inner_product(&q, &v);
        assert!((s - 4.0).abs() < 1e-6);
    }

    #[test]
    fn select_kernel_matches_scalar_for_all_metrics() {
        let kernel = select_kernel();
        let scalar = ScalarKernel;
        let q = vec![1.0f32, 2.0, 3.0, 4.0];
        let v1 = vec![4.0f32, 3.0, 2.0, 1.0];
        let v2 = vec![1.0f32, 1.0, 1.0, 1.0];

        for &m in &ALL_METRICS {
            for candidate in [&v1, &v2] {
                let expected = eval(&scalar, m, &q, candidate);
                let actual = eval(kernel.as_ref(), m, &q, candidate);
                assert!((expected - actual).abs() < 1e-6);
            }
        }
    }

    #[test]
    fn kernel_handles_nan_and_inf_like_scalar() {
        let kernel = select_kernel();
        let scalar = ScalarKernel;
        let cases = vec![
            (
                Metric::Cosine,
                vec![f32::NAN, 1.0, 2.0],
                vec![1.0, 2.0, 3.0],
            ),
            (
                Metric::InnerProduct,
                vec![f32::INFINITY, 1.0],
                vec![1.0, 2.0],
            ),
            (Metric::L2, vec![f32::INFINITY, 0.0], vec![1.0, 0.0]),
        ];

        for (metric, q, v) in cases {
            let expected = eval(&scalar, metric, &q, &v);
            let actual = eval(kernel.as_ref(), metric, &q, &v);
            assert_same_f32(expected, actual);
        }
    }

    #[test]
    fn cosine_with_nan_matches_scalar() {
        let q = [f32::NAN, 1.0, 2.0, 3.0];
        let v = [1.0, 2.0, 3.0, 4.0];
        let expected = ScalarKernel.cosine(&q, &v);
        let actual = select_kernel().cosine(&q, &v);
        assert_same_f32(expected, actual);
    }

    #[test]
    fn l2_with_inf_matches_scalar() {
        let q = [f32::INFINITY, 0.0, 1.0];
        let v = [1.0, 0.0, 1.0];
        let expected = ScalarKernel.l2(&q, &v);
        let actual = select_kernel().l2(&q, &v);
        assert_same_f32(expected, actual);
    }

    #[test]
    fn inner_product_with_nan_matches_scalar() {
        let q = [1.0, f32::NAN];
        let v = [2.0, 3.0];
        let expected = ScalarKernel.inner_product(&q, &v);
        let actual = select_kernel().inner_product(&q, &v);
        assert_same_f32(expected, actual);
    }

    #[test]
    fn batch_score_propagates_nan_inf_like_scalar() {
        let kernel = select_kernel();
        let scalar = ScalarKernel;
        let q = [1.0, f32::NAN];
        let vectors = [2.0, 3.0, f32::INFINITY, 0.0];
        let mut scores_kernel = [0.0f32; 2];
        let mut scores_scalar = [0.0f32; 2];
        kernel.batch_score(Metric::InnerProduct, &q, &vectors, 2, &mut scores_kernel);
        scalar.batch_score(Metric::InnerProduct, &q, &vectors, 2, &mut scores_scalar);
        for (expected, actual) in scores_scalar.iter().zip(scores_kernel.iter()) {
            assert_same_f32(*expected, *actual);
        }
    }
}