// trueno/backends/avx2/ops/reductions.rs
#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;

use crate::backends::VectorBackend;

/// Dot product of `a` and `b` using AVX2 + FMA.
///
/// The main loop processes 32 lanes per iteration with four independent
/// accumulators to hide FMA latency; a single-vector loop mops up the next
/// multiple of 8, and a scalar `zip` handles the tail.
///
/// Extra elements in the longer slice are ignored, matching the scalar
/// remainder's `zip` semantics. (Previously the SIMD loads indexed *both*
/// slices by `a.len()`, reading past the end of `b` when it was shorter.)
///
/// # Safety
/// The caller must ensure the running CPU supports AVX2 and FMA.
#[inline]
#[target_feature(enable = "avx2,fma")]
pub(crate) unsafe fn dot(a: &[f32], b: &[f32]) -> f32 {
    unsafe {
        // Walk only the common prefix so neither slice is read out of bounds.
        let len = a.len().min(b.len());
        let mut i = 0;

        // Four accumulators break the loop-carried FMA dependency chain.
        let mut acc0 = _mm256_setzero_ps();
        let mut acc1 = _mm256_setzero_ps();
        let mut acc2 = _mm256_setzero_ps();
        let mut acc3 = _mm256_setzero_ps();

        while i + 32 <= len {
            let va0 = _mm256_loadu_ps(a.as_ptr().add(i));
            let vb0 = _mm256_loadu_ps(b.as_ptr().add(i));
            let va1 = _mm256_loadu_ps(a.as_ptr().add(i + 8));
            let vb1 = _mm256_loadu_ps(b.as_ptr().add(i + 8));
            let va2 = _mm256_loadu_ps(a.as_ptr().add(i + 16));
            let vb2 = _mm256_loadu_ps(b.as_ptr().add(i + 16));
            let va3 = _mm256_loadu_ps(a.as_ptr().add(i + 24));
            let vb3 = _mm256_loadu_ps(b.as_ptr().add(i + 24));

            acc0 = _mm256_fmadd_ps(va0, vb0, acc0);
            acc1 = _mm256_fmadd_ps(va1, vb1, acc1);
            acc2 = _mm256_fmadd_ps(va2, vb2, acc2);
            acc3 = _mm256_fmadd_ps(va3, vb3, acc3);

            i += 32;
        }

        // Drain remaining full vectors into acc0.
        while i + 8 <= len {
            let va = _mm256_loadu_ps(a.as_ptr().add(i));
            let vb = _mm256_loadu_ps(b.as_ptr().add(i));
            acc0 = _mm256_fmadd_ps(va, vb, acc0);
            i += 8;
        }

        // Combine the four accumulators, then reduce 256 -> 128 -> 64 -> 32.
        let acc01 = _mm256_add_ps(acc0, acc1);
        let acc23 = _mm256_add_ps(acc2, acc3);
        let acc = _mm256_add_ps(acc01, acc23);

        let mut result = {
            let sum_halves = _mm_add_ps(_mm256_castps256_ps128(acc), _mm256_extractf128_ps(acc, 1));
            let temp = _mm_add_ps(sum_halves, _mm_movehl_ps(sum_halves, sum_halves));
            let temp = _mm_add_ss(temp, _mm_shuffle_ps(temp, temp, 1));
            _mm_cvtss_f32(temp)
        };

        // Scalar tail; zip truncates at the shorter slice, consistent with `len`.
        result += a[i..].iter().zip(&b[i..]).map(|(x, y)| x * y).sum::<f32>();
        result
    }
}
62
/// Horizontal sum of `a` using AVX2: eight lanes per step, scalar tail.
///
/// # Safety
/// The caller must ensure the running CPU supports AVX2.
#[inline]
#[target_feature(enable = "avx2")]
pub(crate) unsafe fn sum(a: &[f32]) -> f32 {
    unsafe {
        let body = a.chunks_exact(8);
        let tail = body.remainder();

        // Accumulate one 8-lane vector per full chunk, in slice order.
        let mut lanes = _mm256_setzero_ps();
        for chunk in body {
            lanes = _mm256_add_ps(lanes, _mm256_loadu_ps(chunk.as_ptr()));
        }

        // Reduce the eight lanes: 256 -> 128 -> 64 -> 32 bits.
        let lo_hi = _mm_add_ps(_mm256_castps256_ps128(lanes), _mm256_extractf128_ps(lanes, 1));
        let pairs = _mm_add_ps(lo_hi, _mm_movehl_ps(lo_hi, lo_hi));
        let single = _mm_add_ss(pairs, _mm_shuffle_ps(pairs, pairs, 1));

        _mm_cvtss_f32(single) + tail.iter().sum::<f32>()
    }
}
90
/// Maximum element of `a` using AVX2. Panics if `a` is empty.
///
/// # Safety
/// The caller must ensure the running CPU supports AVX2.
#[inline]
#[target_feature(enable = "avx2")]
pub(crate) unsafe fn max(a: &[f32]) -> f32 {
    unsafe {
        let body = a.chunks_exact(8);
        let tail = body.remainder();

        // Seed every lane with the first element so unused lanes can never
        // win the reduction. Indexing a[0] panics on empty input, as before.
        let mut lanes = _mm256_set1_ps(a[0]);
        for chunk in body {
            lanes = _mm256_max_ps(lanes, _mm256_loadu_ps(chunk.as_ptr()));
        }

        // Lane-wise reduction: 256 -> 128 -> 64 -> 32 bits.
        let lo_hi = _mm_max_ps(_mm256_castps256_ps128(lanes), _mm256_extractf128_ps(lanes, 1));
        let pairs = _mm_max_ps(lo_hi, _mm_movehl_ps(lo_hi, lo_hi));
        let mut best = _mm_cvtss_f32(_mm_max_ss(pairs, _mm_shuffle_ps(pairs, pairs, 1)));

        for &x in tail {
            if x > best {
                best = x;
            }
        }
        best
    }
}
123
/// Minimum element of `a` using AVX2. Panics if `a` is empty.
///
/// # Safety
/// The caller must ensure the running CPU supports AVX2.
#[inline]
#[target_feature(enable = "avx2")]
pub(crate) unsafe fn min(a: &[f32]) -> f32 {
    unsafe {
        let body = a.chunks_exact(8);
        let tail = body.remainder();

        // Seed every lane with the first element so unused lanes can never
        // win the reduction. Indexing a[0] panics on empty input, as before.
        let mut lanes = _mm256_set1_ps(a[0]);
        for chunk in body {
            lanes = _mm256_min_ps(lanes, _mm256_loadu_ps(chunk.as_ptr()));
        }

        // Lane-wise reduction: 256 -> 128 -> 64 -> 32 bits.
        let lo_hi = _mm_min_ps(_mm256_castps256_ps128(lanes), _mm256_extractf128_ps(lanes, 1));
        let pairs = _mm_min_ps(lo_hi, _mm_movehl_ps(lo_hi, lo_hi));
        let mut best = _mm_cvtss_f32(_mm_min_ss(pairs, _mm_shuffle_ps(pairs, pairs, 1)));

        for &x in tail {
            if x < best {
                best = x;
            }
        }
        best
    }
}
156
/// Index of the maximum element of `a` using AVX2. Panics if `a` is empty.
///
/// Lane indices are tracked as `i32` vectors instead of the previous `f32`
/// counters: consecutive integers above 2^24 are not representable in f32,
/// so the old code could report a wrong index for slices longer than
/// 16,777,216 elements. (i32 lanes are exact up to `i32::MAX` elements;
/// larger slices would need 64-bit lane indices.)
///
/// Ties resolve to the lowest index, matching a scalar first-occurrence scan.
///
/// # Safety
/// The caller must ensure the running CPU supports AVX2.
#[inline]
#[target_feature(enable = "avx2")]
pub(crate) unsafe fn argmax(a: &[f32]) -> usize {
    unsafe {
        let len = a.len();
        let mut max_idx: usize = 0;
        let mut max_val = a[0];
        let mut i = 0;

        let mut vmax = _mm256_set1_ps(a[0]);
        let mut vidx_max = _mm256_setzero_si256();
        let vidx_inc = _mm256_set1_epi32(8);
        let mut vcur_idx = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);

        while i + 8 <= len {
            let va = _mm256_loadu_ps(a.as_ptr().add(i));
            // Strict > keeps the first (lowest-index) occurrence per lane.
            let mask = _mm256_cmp_ps(va, vmax, _CMP_GT_OQ);
            vmax = _mm256_blendv_ps(vmax, va, mask);
            // The compare mask is all-ones/all-zeros per 32-bit lane, so a
            // byte-granular blend selects whole i32 indices correctly.
            vidx_max = _mm256_blendv_epi8(vidx_max, vcur_idx, _mm256_castps_si256(mask));
            vcur_idx = _mm256_add_epi32(vcur_idx, vidx_inc);
            i += 8;
        }

        let mut vals = [0.0f32; 8];
        let mut idxs = [0i32; 8];
        _mm256_storeu_ps(vals.as_mut_ptr(), vmax);
        _mm256_storeu_si256(idxs.as_mut_ptr() as *mut __m256i, vidx_max);

        for j in 0..8 {
            let idx = idxs[j] as usize;
            // Prefer strictly larger values; on cross-lane ties keep the
            // lowest index so the result is deterministic and scalar-like.
            if vals[j] > max_val || (vals[j] == max_val && idx < max_idx) {
                max_val = vals[j];
                max_idx = idx;
            }
        }

        // Scalar tail: indices only increase, so strict > preserves the
        // first occurrence on its own.
        for (j, &val) in a[i..].iter().enumerate() {
            if val > max_val {
                max_val = val;
                max_idx = i + j;
            }
        }

        max_idx
    }
}
206
/// Index of the minimum element of `a` using AVX2. Panics if `a` is empty.
///
/// Lane indices are tracked as `i32` vectors instead of the previous `f32`
/// counters: consecutive integers above 2^24 are not representable in f32,
/// so the old code could report a wrong index for slices longer than
/// 16,777,216 elements. (i32 lanes are exact up to `i32::MAX` elements;
/// larger slices would need 64-bit lane indices.)
///
/// Ties resolve to the lowest index, matching a scalar first-occurrence scan.
///
/// # Safety
/// The caller must ensure the running CPU supports AVX2.
#[inline]
#[target_feature(enable = "avx2")]
pub(crate) unsafe fn argmin(a: &[f32]) -> usize {
    unsafe {
        let len = a.len();
        let mut min_idx: usize = 0;
        let mut min_val = a[0];
        let mut i = 0;

        let mut vmin = _mm256_set1_ps(a[0]);
        let mut vidx_min = _mm256_setzero_si256();
        let vidx_inc = _mm256_set1_epi32(8);
        let mut vcur_idx = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);

        while i + 8 <= len {
            let va = _mm256_loadu_ps(a.as_ptr().add(i));
            // Strict < keeps the first (lowest-index) occurrence per lane.
            let mask = _mm256_cmp_ps(va, vmin, _CMP_LT_OQ);
            vmin = _mm256_blendv_ps(vmin, va, mask);
            // The compare mask is all-ones/all-zeros per 32-bit lane, so a
            // byte-granular blend selects whole i32 indices correctly.
            vidx_min = _mm256_blendv_epi8(vidx_min, vcur_idx, _mm256_castps_si256(mask));
            vcur_idx = _mm256_add_epi32(vcur_idx, vidx_inc);
            i += 8;
        }

        let mut vals = [0.0f32; 8];
        let mut idxs = [0i32; 8];
        _mm256_storeu_ps(vals.as_mut_ptr(), vmin);
        _mm256_storeu_si256(idxs.as_mut_ptr() as *mut __m256i, vidx_min);

        for j in 0..8 {
            let idx = idxs[j] as usize;
            // Prefer strictly smaller values; on cross-lane ties keep the
            // lowest index so the result is deterministic and scalar-like.
            if vals[j] < min_val || (vals[j] == min_val && idx < min_idx) {
                min_val = vals[j];
                min_idx = idx;
            }
        }

        // Scalar tail: indices only increase, so strict < preserves the
        // first occurrence on its own.
        for (j, &val) in a[i..].iter().enumerate() {
            if val < min_val {
                min_val = val;
                min_idx = i + j;
            }
        }

        min_idx
    }
}
254
/// Kahan-compensated sum; delegates to the scalar backend implementation.
/// NOTE(review): presumably kept scalar because a naive SIMD re-association
/// would defeat the compensation term — confirm against the backend.
/// No `target_feature` attribute is needed: no intrinsics are used here.
#[inline]
pub(crate) unsafe fn sum_kahan(a: &[f32]) -> f32 {
    unsafe { crate::backends::scalar::ScalarBackend::sum_kahan(a) }
}