Skip to main content

nodedb_vector/distance/simd/
avx512.rs

1// SPDX-License-Identifier: Apache-2.0
2
3//! AVX-512 kernels for x86_64.
4
5#![cfg(target_arch = "x86_64")]
6
7pub fn l2_squared(a: &[f32], b: &[f32]) -> f32 {
8    assert_eq!(a.len(), b.len(), "avx512 l2: length mismatch");
9    unsafe { l2_impl(a, b) }
10}
11
12#[target_feature(enable = "avx512f")]
13unsafe fn l2_impl(a: &[f32], b: &[f32]) -> f32 {
14    assert_eq!(a.len(), b.len(), "avx512 l2_impl: length mismatch");
15    unsafe {
16        use std::arch::x86_64::*;
17        let n = a.len();
18        let mut sum = _mm512_setzero_ps();
19        let chunks = n / 16;
20        for i in 0..chunks {
21            let off = i * 16;
22            let va = _mm512_loadu_ps(a.as_ptr().add(off));
23            let vb = _mm512_loadu_ps(b.as_ptr().add(off));
24            let diff = _mm512_sub_ps(va, vb);
25            sum = _mm512_fmadd_ps(diff, diff, sum);
26        }
27        let mut result = _mm512_reduce_add_ps(sum);
28        for i in (chunks * 16)..n {
29            let d = a[i] - b[i];
30            result += d * d;
31        }
32        result
33    }
34}
35
36pub fn cosine_distance(a: &[f32], b: &[f32]) -> f32 {
37    assert_eq!(a.len(), b.len(), "avx512 cosine: length mismatch");
38    unsafe { cosine_impl(a, b) }
39}
40
41#[target_feature(enable = "avx512f")]
42unsafe fn cosine_impl(a: &[f32], b: &[f32]) -> f32 {
43    assert_eq!(a.len(), b.len(), "avx512 cosine_impl: length mismatch");
44    unsafe {
45        use std::arch::x86_64::*;
46        let n = a.len();
47        let mut vdot = _mm512_setzero_ps();
48        let mut vna = _mm512_setzero_ps();
49        let mut vnb = _mm512_setzero_ps();
50        let chunks = n / 16;
51        for i in 0..chunks {
52            let off = i * 16;
53            let va = _mm512_loadu_ps(a.as_ptr().add(off));
54            let vb = _mm512_loadu_ps(b.as_ptr().add(off));
55            vdot = _mm512_fmadd_ps(va, vb, vdot);
56            vna = _mm512_fmadd_ps(va, va, vna);
57            vnb = _mm512_fmadd_ps(vb, vb, vnb);
58        }
59        let mut dot = _mm512_reduce_add_ps(vdot);
60        let mut na = _mm512_reduce_add_ps(vna);
61        let mut nb = _mm512_reduce_add_ps(vnb);
62        for i in (chunks * 16)..n {
63            dot += a[i] * b[i];
64            na += a[i] * a[i];
65            nb += b[i] * b[i];
66        }
67        let denom = (na * nb).sqrt();
68        if denom < f32::EPSILON {
69            1.0
70        } else {
71            (1.0 - dot / denom).max(0.0)
72        }
73    }
74}
75
76pub fn neg_inner_product(a: &[f32], b: &[f32]) -> f32 {
77    assert_eq!(a.len(), b.len(), "avx512 ip: length mismatch");
78    unsafe { ip_impl(a, b) }
79}
80
81#[target_feature(enable = "avx512f")]
82unsafe fn ip_impl(a: &[f32], b: &[f32]) -> f32 {
83    assert_eq!(a.len(), b.len(), "avx512 ip_impl: length mismatch");
84    unsafe {
85        use std::arch::x86_64::*;
86        let n = a.len();
87        let mut vdot = _mm512_setzero_ps();
88        let chunks = n / 16;
89        for i in 0..chunks {
90            let off = i * 16;
91            let va = _mm512_loadu_ps(a.as_ptr().add(off));
92            let vb = _mm512_loadu_ps(b.as_ptr().add(off));
93            vdot = _mm512_fmadd_ps(va, vb, vdot);
94        }
95        let mut dot = _mm512_reduce_add_ps(vdot);
96        for i in (chunks * 16)..n {
97            dot += a[i] * b[i];
98        }
99        -dot
100    }
101}