nodedb_vector/distance/simd/
avx2.rs1#![cfg(target_arch = "x86_64")]
6
7pub fn l2_squared(a: &[f32], b: &[f32]) -> f32 {
8 assert_eq!(a.len(), b.len(), "avx2 l2: length mismatch");
9 unsafe { l2_squared_impl(a, b) }
11}
12
13#[target_feature(enable = "avx2,fma")]
14unsafe fn l2_squared_impl(a: &[f32], b: &[f32]) -> f32 {
15 assert_eq!(a.len(), b.len(), "avx2 l2_impl: length mismatch");
16 unsafe {
17 use std::arch::x86_64::*;
18 let n = a.len();
19 let mut sum = _mm256_setzero_ps();
20 let chunks = n / 8;
21 for i in 0..chunks {
22 let off = i * 8;
23 let va = _mm256_loadu_ps(a.as_ptr().add(off));
24 let vb = _mm256_loadu_ps(b.as_ptr().add(off));
25 let diff = _mm256_sub_ps(va, vb);
26 sum = _mm256_fmadd_ps(diff, diff, sum);
27 }
28 let mut result = hsum256(sum);
29 for i in (chunks * 8)..n {
30 let d = a[i] - b[i];
31 result += d * d;
32 }
33 result
34 }
35}
36
37pub fn cosine_distance(a: &[f32], b: &[f32]) -> f32 {
38 assert_eq!(a.len(), b.len(), "avx2 cosine: length mismatch");
39 unsafe { cosine_impl(a, b) }
40}
41
42#[target_feature(enable = "avx2,fma")]
43unsafe fn cosine_impl(a: &[f32], b: &[f32]) -> f32 {
44 assert_eq!(a.len(), b.len(), "avx2 cosine_impl: length mismatch");
45 unsafe {
46 use std::arch::x86_64::*;
47 let n = a.len();
48 let mut vdot = _mm256_setzero_ps();
49 let mut vna = _mm256_setzero_ps();
50 let mut vnb = _mm256_setzero_ps();
51 let chunks = n / 8;
52 for i in 0..chunks {
53 let off = i * 8;
54 let va = _mm256_loadu_ps(a.as_ptr().add(off));
55 let vb = _mm256_loadu_ps(b.as_ptr().add(off));
56 vdot = _mm256_fmadd_ps(va, vb, vdot);
57 vna = _mm256_fmadd_ps(va, va, vna);
58 vnb = _mm256_fmadd_ps(vb, vb, vnb);
59 }
60 let mut dot = hsum256(vdot);
61 let mut na = hsum256(vna);
62 let mut nb = hsum256(vnb);
63 for i in (chunks * 8)..n {
64 dot += a[i] * b[i];
65 na += a[i] * a[i];
66 nb += b[i] * b[i];
67 }
68 let denom = (na * nb).sqrt();
69 if denom < f32::EPSILON {
70 1.0
71 } else {
72 (1.0 - dot / denom).max(0.0)
73 }
74 }
75}
76
77pub fn neg_inner_product(a: &[f32], b: &[f32]) -> f32 {
78 assert_eq!(a.len(), b.len(), "avx2 ip: length mismatch");
79 unsafe { ip_impl(a, b) }
80}
81
82#[target_feature(enable = "avx2,fma")]
83unsafe fn ip_impl(a: &[f32], b: &[f32]) -> f32 {
84 assert_eq!(a.len(), b.len(), "avx2 ip_impl: length mismatch");
85 unsafe {
86 use std::arch::x86_64::*;
87 let n = a.len();
88 let mut vdot = _mm256_setzero_ps();
89 let chunks = n / 8;
90 for i in 0..chunks {
91 let off = i * 8;
92 let va = _mm256_loadu_ps(a.as_ptr().add(off));
93 let vb = _mm256_loadu_ps(b.as_ptr().add(off));
94 vdot = _mm256_fmadd_ps(va, vb, vdot);
95 }
96 let mut dot = hsum256(vdot);
97 for i in (chunks * 8)..n {
98 dot += a[i] * b[i];
99 }
100 -dot
101 }
102}
103
/// Horizontally adds the eight `f32` lanes of `v` and returns the scalar total.
///
/// The reduction works in three steps:
/// 1. Add the upper 128-bit half onto the lower half.
/// 2. Add the odd lanes onto the even lanes.
/// 3. Add lane 2 onto lane 0.
///
/// # Safety
///
/// The running CPU must support `avx2`, the feature this function is compiled
/// with.
#[target_feature(enable = "avx2")]
unsafe fn hsum256(v: std::arch::x86_64::__m256) -> f32 {
    use std::arch::x86_64::*;
    // [v0+v4, v1+v5, v2+v6, v3+v7]
    let quad = _mm_add_ps(_mm256_castps256_ps128(v), _mm256_extractf128_ps(v, 1));
    // Lane 0 = q0+q1, lane 2 = q2+q3.
    let pairs = _mm_add_ps(quad, _mm_movehdup_ps(quad));
    // Lane 0 = (q0+q1) + (q2+q3).
    _mm_cvtss_f32(_mm_add_ss(pairs, _mm_movehl_ps(pairs, pairs)))
}