nodedb_vector/distance/simd/
avx512.rs1#![cfg(target_arch = "x86_64")]
6
7pub fn l2_squared(a: &[f32], b: &[f32]) -> f32 {
8 assert_eq!(a.len(), b.len(), "avx512 l2: length mismatch");
9 unsafe { l2_impl(a, b) }
10}
11
12#[target_feature(enable = "avx512f")]
13unsafe fn l2_impl(a: &[f32], b: &[f32]) -> f32 {
14 assert_eq!(a.len(), b.len(), "avx512 l2_impl: length mismatch");
15 unsafe {
16 use std::arch::x86_64::*;
17 let n = a.len();
18 let mut sum = _mm512_setzero_ps();
19 let chunks = n / 16;
20 for i in 0..chunks {
21 let off = i * 16;
22 let va = _mm512_loadu_ps(a.as_ptr().add(off));
23 let vb = _mm512_loadu_ps(b.as_ptr().add(off));
24 let diff = _mm512_sub_ps(va, vb);
25 sum = _mm512_fmadd_ps(diff, diff, sum);
26 }
27 let mut result = _mm512_reduce_add_ps(sum);
28 for i in (chunks * 16)..n {
29 let d = a[i] - b[i];
30 result += d * d;
31 }
32 result
33 }
34}
35
36pub fn cosine_distance(a: &[f32], b: &[f32]) -> f32 {
37 assert_eq!(a.len(), b.len(), "avx512 cosine: length mismatch");
38 unsafe { cosine_impl(a, b) }
39}
40
41#[target_feature(enable = "avx512f")]
42unsafe fn cosine_impl(a: &[f32], b: &[f32]) -> f32 {
43 assert_eq!(a.len(), b.len(), "avx512 cosine_impl: length mismatch");
44 unsafe {
45 use std::arch::x86_64::*;
46 let n = a.len();
47 let mut vdot = _mm512_setzero_ps();
48 let mut vna = _mm512_setzero_ps();
49 let mut vnb = _mm512_setzero_ps();
50 let chunks = n / 16;
51 for i in 0..chunks {
52 let off = i * 16;
53 let va = _mm512_loadu_ps(a.as_ptr().add(off));
54 let vb = _mm512_loadu_ps(b.as_ptr().add(off));
55 vdot = _mm512_fmadd_ps(va, vb, vdot);
56 vna = _mm512_fmadd_ps(va, va, vna);
57 vnb = _mm512_fmadd_ps(vb, vb, vnb);
58 }
59 let mut dot = _mm512_reduce_add_ps(vdot);
60 let mut na = _mm512_reduce_add_ps(vna);
61 let mut nb = _mm512_reduce_add_ps(vnb);
62 for i in (chunks * 16)..n {
63 dot += a[i] * b[i];
64 na += a[i] * a[i];
65 nb += b[i] * b[i];
66 }
67 let denom = (na * nb).sqrt();
68 if denom < f32::EPSILON {
69 1.0
70 } else {
71 (1.0 - dot / denom).max(0.0)
72 }
73 }
74}
75
76pub fn neg_inner_product(a: &[f32], b: &[f32]) -> f32 {
77 assert_eq!(a.len(), b.len(), "avx512 ip: length mismatch");
78 unsafe { ip_impl(a, b) }
79}
80
81#[target_feature(enable = "avx512f")]
82unsafe fn ip_impl(a: &[f32], b: &[f32]) -> f32 {
83 assert_eq!(a.len(), b.len(), "avx512 ip_impl: length mismatch");
84 unsafe {
85 use std::arch::x86_64::*;
86 let n = a.len();
87 let mut vdot = _mm512_setzero_ps();
88 let chunks = n / 16;
89 for i in 0..chunks {
90 let off = i * 16;
91 let va = _mm512_loadu_ps(a.as_ptr().add(off));
92 let vb = _mm512_loadu_ps(b.as_ptr().add(off));
93 vdot = _mm512_fmadd_ps(va, vb, vdot);
94 }
95 let mut dot = _mm512_reduce_add_ps(vdot);
96 for i in (chunks * 16)..n {
97 dot += a[i] * b[i];
98 }
99 -dot
100 }
101}