1#[inline]
2pub fn dot(a: &[f32], b: &[f32]) -> f32 {
3 debug_assert_eq!(a.len(), b.len());
4 #[cfg(target_arch = "x86_64")]
5 {
6 if x86_avx2_fma::get() {
7 unsafe {
8 return dot_avx2_fma(a, b);
9 };
10 } else if x86_avx2::get() {
11 unsafe {
12 return dot_avx2(a, b);
13 };
14 } else if x86_sse2::get() {
15 unsafe {
16 return dot_sse2(a, b);
17 };
18 }
19 return dot_scalar(a, b);
20 }
21 #[cfg(target_arch = "aarch64")]
22 {
23 unsafe {
24 return dot_neon(a, b);
25 };
26 }
27 #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
28 {
29 return dot_scalar(a, b);
30 }
31}
32
/// Portable dot product: multiplies the slices element by element and sums
/// the products in order.
///
/// Pairs are formed with `zip`, so if the slices differ in length the extra
/// elements of the longer one are ignored.
#[inline]
#[allow(dead_code)] // Not referenced by `dot` on every target (e.g. aarch64 always takes the NEON path).
fn dot_scalar(a: &[f32], b: &[f32]) -> f32 {
    let products = a.iter().zip(b.iter()).map(|(&lhs, &rhs)| lhs * rhs);
    products.sum::<f32>()
}
38
// Runtime CPU-feature detectors for x86_64, declared with the `cpufeatures`
// crate. Each invocation produces a module whose `get()` function is queried
// by `dot` to pick a kernel.
// NOTE(review): presumably `get()` caches the detection result after the
// first call — confirm against the `cpufeatures` documentation.

// Reports whether both AVX2 and FMA are available (gates `dot_avx2_fma`).
#[cfg(target_arch = "x86_64")]
cpufeatures::new!(x86_avx2_fma, "avx2", "fma");
// Reports whether AVX2 is available (gates `dot_avx2`).
#[cfg(target_arch = "x86_64")]
cpufeatures::new!(x86_avx2, "avx2");
// Reports whether SSE2 is available (gates `dot_sse2`).
#[cfg(target_arch = "x86_64")]
cpufeatures::new!(x86_sse2, "sse2");
46
/// AVX2 dot-product kernel: processes eight `f32` lanes per iteration with a
/// separate multiply and add, then finishes the leftover elements one by one.
///
/// # Safety
///
/// * The running CPU must support AVX2.
/// * `b` must contain at least `a.len()` elements; `b` is read with unchecked
///   accesses driven by `a.len()`.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn dot_avx2(a: &[f32], b: &[f32]) -> f32 {
    use core::arch::x86_64::{
        _mm256_add_ps, _mm256_loadu_ps, _mm256_mul_ps, _mm256_setzero_ps, _mm256_storeu_ps,
    };

    const LANES: usize = 8;
    let len = a.len();
    let (lhs_ptr, rhs_ptr) = (a.as_ptr(), b.as_ptr());

    // Vector body: one unaligned load per operand, multiply, accumulate.
    let mut lane_acc = _mm256_setzero_ps();
    let mut offset = 0usize;
    while len - offset >= LANES {
        let lhs = _mm256_loadu_ps(lhs_ptr.add(offset));
        let rhs = _mm256_loadu_ps(rhs_ptr.add(offset));
        lane_acc = _mm256_add_ps(lane_acc, _mm256_mul_ps(lhs, rhs));
        offset += LANES;
    }

    // Horizontal reduction: spill the lanes and add them in lane order.
    let mut spilled = [0f32; LANES];
    _mm256_storeu_ps(spilled.as_mut_ptr(), lane_acc);
    let mut total = spilled.iter().sum::<f32>();

    // Scalar tail for the final `len % 8` elements.
    for idx in offset..len {
        total += *a.get_unchecked(idx) * *b.get_unchecked(idx);
    }
    total
}
69
/// AVX2 + FMA dot-product kernel: processes eight `f32` lanes per iteration
/// with a fused multiply-add, then finishes the leftover elements one by one.
///
/// # Safety
///
/// * The running CPU must support both AVX2 and FMA.
/// * `b` must contain at least `a.len()` elements; `b` is read with unchecked
///   accesses driven by `a.len()`.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2,fma")]
unsafe fn dot_avx2_fma(a: &[f32], b: &[f32]) -> f32 {
    use core::arch::x86_64::{
        _mm256_fmadd_ps, _mm256_loadu_ps, _mm256_setzero_ps, _mm256_storeu_ps,
    };

    const LANES: usize = 8;
    let len = a.len();
    let (lhs_ptr, rhs_ptr) = (a.as_ptr(), b.as_ptr());

    // Vector body: lane_acc = lhs * rhs + lane_acc, fused per lane.
    let mut lane_acc = _mm256_setzero_ps();
    let mut offset = 0usize;
    while len - offset >= LANES {
        let lhs = _mm256_loadu_ps(lhs_ptr.add(offset));
        let rhs = _mm256_loadu_ps(rhs_ptr.add(offset));
        lane_acc = _mm256_fmadd_ps(lhs, rhs, lane_acc);
        offset += LANES;
    }

    // Horizontal reduction: spill the lanes and add them in lane order.
    let mut spilled = [0f32; LANES];
    _mm256_storeu_ps(spilled.as_mut_ptr(), lane_acc);
    let mut total = spilled.iter().sum::<f32>();

    // Scalar tail for the final `len % 8` elements.
    for idx in offset..len {
        total += *a.get_unchecked(idx) * *b.get_unchecked(idx);
    }
    total
}
91
/// SSE2 dot-product kernel: processes four `f32` lanes per iteration with a
/// separate multiply and add, then finishes the leftover elements one by one.
///
/// # Safety
///
/// * The running CPU must support SSE2.
/// * `b` must contain at least `a.len()` elements; `b` is read with unchecked
///   accesses driven by `a.len()`.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "sse2")]
unsafe fn dot_sse2(a: &[f32], b: &[f32]) -> f32 {
    use core::arch::x86_64::{
        _mm_add_ps, _mm_loadu_ps, _mm_mul_ps, _mm_setzero_ps, _mm_storeu_ps,
    };

    const LANES: usize = 4;
    let len = a.len();
    let (lhs_ptr, rhs_ptr) = (a.as_ptr(), b.as_ptr());

    // Vector body: one unaligned load per operand, multiply, accumulate.
    let mut lane_acc = _mm_setzero_ps();
    let mut offset = 0usize;
    while len - offset >= LANES {
        let lhs = _mm_loadu_ps(lhs_ptr.add(offset));
        let rhs = _mm_loadu_ps(rhs_ptr.add(offset));
        lane_acc = _mm_add_ps(lane_acc, _mm_mul_ps(lhs, rhs));
        offset += LANES;
    }

    // Horizontal reduction: spill the lanes and add them in lane order.
    let mut spilled = [0f32; LANES];
    _mm_storeu_ps(spilled.as_mut_ptr(), lane_acc);
    let mut total = spilled.iter().sum::<f32>();

    // Scalar tail for the final `len % 4` elements.
    for idx in offset..len {
        total += *a.get_unchecked(idx) * *b.get_unchecked(idx);
    }
    total
}
114
/// NEON dot-product kernel: processes four `f32` lanes per iteration with a
/// fused multiply-accumulate, reduces the lanes with an across-vector add,
/// then finishes the leftover elements one by one.
///
/// # Safety
///
/// * The running CPU must support NEON.
/// * `b` must contain at least `a.len()` elements; `b` is read with unchecked
///   accesses driven by `a.len()`.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
unsafe fn dot_neon(a: &[f32], b: &[f32]) -> f32 {
    use core::arch::aarch64::{vaddvq_f32, vdupq_n_f32, vfmaq_f32, vld1q_f32};

    const LANES: usize = 4;
    let len = a.len();
    let (lhs_ptr, rhs_ptr) = (a.as_ptr(), b.as_ptr());

    // Vector body: lane_acc += lhs * rhs, per lane.
    let mut lane_acc = vdupq_n_f32(0.0);
    let mut offset = 0usize;
    while len - offset >= LANES {
        let lhs = vld1q_f32(lhs_ptr.add(offset));
        let rhs = vld1q_f32(rhs_ptr.add(offset));
        lane_acc = vfmaq_f32(lane_acc, lhs, rhs);
        offset += LANES;
    }

    // Horizontal reduction across the four lanes.
    let mut total: f32 = vaddvq_f32(lane_acc);

    // Scalar tail for the final `len % 4` elements.
    for idx in offset..len {
        total += *a.get_unchecked(idx) * *b.get_unchecked(idx);
    }
    total
}