Skip to main content

nodedb_vector/distance/simd/
avx2.rs

1// SPDX-License-Identifier: Apache-2.0
2
3//! AVX2+FMA kernels for x86_64.
4
5#![cfg(target_arch = "x86_64")]
6
7pub fn l2_squared(a: &[f32], b: &[f32]) -> f32 {
8    assert_eq!(a.len(), b.len(), "avx2 l2: length mismatch");
9    // SAFETY: caller verified avx2+fma via is_x86_feature_detected.
10    unsafe { l2_squared_impl(a, b) }
11}
12
13#[target_feature(enable = "avx2,fma")]
14unsafe fn l2_squared_impl(a: &[f32], b: &[f32]) -> f32 {
15    assert_eq!(a.len(), b.len(), "avx2 l2_impl: length mismatch");
16    unsafe {
17        use std::arch::x86_64::*;
18        let n = a.len();
19        let mut sum = _mm256_setzero_ps();
20        let chunks = n / 8;
21        for i in 0..chunks {
22            let off = i * 8;
23            let va = _mm256_loadu_ps(a.as_ptr().add(off));
24            let vb = _mm256_loadu_ps(b.as_ptr().add(off));
25            let diff = _mm256_sub_ps(va, vb);
26            sum = _mm256_fmadd_ps(diff, diff, sum);
27        }
28        let mut result = hsum256(sum);
29        for i in (chunks * 8)..n {
30            let d = a[i] - b[i];
31            result += d * d;
32        }
33        result
34    }
35}
36
37pub fn cosine_distance(a: &[f32], b: &[f32]) -> f32 {
38    assert_eq!(a.len(), b.len(), "avx2 cosine: length mismatch");
39    unsafe { cosine_impl(a, b) }
40}
41
42#[target_feature(enable = "avx2,fma")]
43unsafe fn cosine_impl(a: &[f32], b: &[f32]) -> f32 {
44    assert_eq!(a.len(), b.len(), "avx2 cosine_impl: length mismatch");
45    unsafe {
46        use std::arch::x86_64::*;
47        let n = a.len();
48        let mut vdot = _mm256_setzero_ps();
49        let mut vna = _mm256_setzero_ps();
50        let mut vnb = _mm256_setzero_ps();
51        let chunks = n / 8;
52        for i in 0..chunks {
53            let off = i * 8;
54            let va = _mm256_loadu_ps(a.as_ptr().add(off));
55            let vb = _mm256_loadu_ps(b.as_ptr().add(off));
56            vdot = _mm256_fmadd_ps(va, vb, vdot);
57            vna = _mm256_fmadd_ps(va, va, vna);
58            vnb = _mm256_fmadd_ps(vb, vb, vnb);
59        }
60        let mut dot = hsum256(vdot);
61        let mut na = hsum256(vna);
62        let mut nb = hsum256(vnb);
63        for i in (chunks * 8)..n {
64            dot += a[i] * b[i];
65            na += a[i] * a[i];
66            nb += b[i] * b[i];
67        }
68        let denom = (na * nb).sqrt();
69        if denom < f32::EPSILON {
70            1.0
71        } else {
72            (1.0 - dot / denom).max(0.0)
73        }
74    }
75}
76
77pub fn neg_inner_product(a: &[f32], b: &[f32]) -> f32 {
78    assert_eq!(a.len(), b.len(), "avx2 ip: length mismatch");
79    unsafe { ip_impl(a, b) }
80}
81
82#[target_feature(enable = "avx2,fma")]
83unsafe fn ip_impl(a: &[f32], b: &[f32]) -> f32 {
84    assert_eq!(a.len(), b.len(), "avx2 ip_impl: length mismatch");
85    unsafe {
86        use std::arch::x86_64::*;
87        let n = a.len();
88        let mut vdot = _mm256_setzero_ps();
89        let chunks = n / 8;
90        for i in 0..chunks {
91            let off = i * 8;
92            let va = _mm256_loadu_ps(a.as_ptr().add(off));
93            let vb = _mm256_loadu_ps(b.as_ptr().add(off));
94            vdot = _mm256_fmadd_ps(va, vb, vdot);
95        }
96        let mut dot = hsum256(vdot);
97        for i in (chunks * 8)..n {
98            dot += a[i] * b[i];
99        }
100        -dot
101    }
102}
103
/// Reduces the eight `f32` lanes of `v` to a single scalar sum.
///
/// The reduction is pairwise: the upper 128-bit half is added onto the lower
/// half, odd lanes are then added onto even lanes, and finally the two
/// remaining partial sums are combined.
///
/// # Safety
///
/// The executing CPU must support the `avx2` target feature.
#[target_feature(enable = "avx2")]
unsafe fn hsum256(v: std::arch::x86_64::__m256) -> f32 {
    use std::arch::x86_64::{
        _mm256_castps256_ps128, _mm256_extractf128_ps, _mm_add_ps, _mm_add_ss, _mm_cvtss_f32,
        _mm_movehdup_ps, _mm_movehl_ps,
    };

    // 8 -> 4 lanes: lane i becomes v[i] + v[i + 4].
    let quad = _mm_add_ps(_mm256_castps256_ps128(v), _mm256_extractf128_ps(v, 1));
    // 4 -> 2 partial sums: lanes 0 and 2 now hold q0 + q1 and q2 + q3.
    let pairs = _mm_add_ps(quad, _mm_movehdup_ps(quad));
    // 2 -> 1: bring lane 2 down to lane 0 and add the lowest lanes only.
    let total = _mm_add_ss(pairs, _mm_movehl_ps(pairs, pairs));
    _mm_cvtss_f32(total)
}