Skip to main content

trueno/backends/avx2/ops/
reductions.rs

1//! AVX2 reduction operations (dot, sum, max, min, argmax, argmin).
2
3#[cfg(target_arch = "x86_64")]
4use std::arch::x86_64::*;
5
6use crate::backends::VectorBackend;
7
/// AVX2 dot product with 4-accumulator unrolling for ILP.
///
/// Computes the sum of `a[k] * b[k]` over the common prefix of the two slices
/// (`k < min(a.len(), b.len())`); callers normally pass equal-length slices.
/// The main loop consumes 32 elements per iteration into four independent FMA
/// accumulators so successive FMAs do not depend on one another, an 8-wide loop
/// handles the next multiple of 8, and a scalar loop finishes the rest. Because
/// the additions are regrouped, the result can differ in the last bits from a
/// plain sequential loop.
///
/// # Safety
///
/// The CPU must support the `avx2` and `fma` target features.
#[inline]
#[target_feature(enable = "avx2,fma")]
// SAFETY: every 8-lane load starts at an index `j` with `j + 8 <= len`, and `len` is
// bounded by both slice lengths, so no load reads past the end of either slice.
pub(crate) unsafe fn dot(a: &[f32], b: &[f32]) -> f32 {
    unsafe {
        // Bound all loops by the shorter slice so the unaligned vector loads can never
        // run past the end of `b` (or `a`); the `zip` tail below stops there as well.
        let len = a.len().min(b.len());
        let pa = a.as_ptr();
        let pb = b.as_ptr();
        let mut i = 0;

        let mut acc0 = _mm256_setzero_ps();
        let mut acc1 = _mm256_setzero_ps();
        let mut acc2 = _mm256_setzero_ps();
        let mut acc3 = _mm256_setzero_ps();

        // 32 elements per iteration, spread over four independent accumulators.
        while i + 32 <= len {
            let va0 = _mm256_loadu_ps(pa.add(i));
            let vb0 = _mm256_loadu_ps(pb.add(i));
            let va1 = _mm256_loadu_ps(pa.add(i + 8));
            let vb1 = _mm256_loadu_ps(pb.add(i + 8));
            let va2 = _mm256_loadu_ps(pa.add(i + 16));
            let vb2 = _mm256_loadu_ps(pb.add(i + 16));
            let va3 = _mm256_loadu_ps(pa.add(i + 24));
            let vb3 = _mm256_loadu_ps(pb.add(i + 24));

            acc0 = _mm256_fmadd_ps(va0, vb0, acc0);
            acc1 = _mm256_fmadd_ps(va1, vb1, acc1);
            acc2 = _mm256_fmadd_ps(va2, vb2, acc2);
            acc3 = _mm256_fmadd_ps(va3, vb3, acc3);

            i += 32;
        }

        // Remaining full 8-element groups go into the first accumulator.
        while i + 8 <= len {
            let va = _mm256_loadu_ps(pa.add(i));
            let vb = _mm256_loadu_ps(pb.add(i));
            acc0 = _mm256_fmadd_ps(va, vb, acc0);
            i += 8;
        }

        let acc01 = _mm256_add_ps(acc0, acc1);
        let acc23 = _mm256_add_ps(acc2, acc3);
        let acc = _mm256_add_ps(acc01, acc23);

        // Horizontal reduction: 8 lanes -> 4 -> 2 -> 1.
        let mut result = {
            let sum_halves = _mm_add_ps(_mm256_castps256_ps128(acc), _mm256_extractf128_ps(acc, 1));
            let temp = _mm_add_ps(sum_halves, _mm_movehl_ps(sum_halves, sum_halves));
            let temp = _mm_add_ss(temp, _mm_shuffle_ps(temp, temp, 1));
            _mm_cvtss_f32(temp)
        };

        // Scalar tail for the last `len % 8` elements.
        result += a[i..len].iter().zip(&b[i..len]).map(|(x, y)| x * y).sum::<f32>();
        result
    }
}
62
/// AVX2 vector sum.
///
/// Adds the slice eight floats at a time into a single 256-bit accumulator,
/// collapses that accumulator to a scalar, then adds the trailing `len % 8`
/// elements with a scalar sum. An empty slice yields `0.0`.
#[inline]
#[target_feature(enable = "avx2")]
// SAFETY: caller ensures preconditions are met for this unsafe function
pub(crate) unsafe fn sum(a: &[f32]) -> f32 {
    unsafe {
        let blocks = a.chunks_exact(8);
        let tail = blocks.remainder();

        let mut total = _mm256_setzero_ps();
        for block in blocks {
            // Each block holds exactly 8 floats, so the unaligned load stays in bounds.
            total = _mm256_add_ps(total, _mm256_loadu_ps(block.as_ptr()));
        }

        // Collapse 8 lanes -> 4 -> 2 -> 1.
        let low = _mm256_castps256_ps128(total);
        let high = _mm256_extractf128_ps(total, 1);
        let quad = _mm_add_ps(low, high);
        let pair = _mm_add_ps(quad, _mm_movehl_ps(quad, quad));
        let single = _mm_add_ss(pair, _mm_shuffle_ps(pair, pair, 1));
        let head_sum = _mm_cvtss_f32(single);

        head_sum + tail.iter().sum::<f32>()
    }
}
90
/// AVX2 vector max.
///
/// Returns the largest element of `a`. Every lane is seeded with `a[0]`, full
/// 8-element blocks are folded in with `_mm256_max_ps`, the lanes are collapsed
/// to a scalar, and the trailing `len % 8` elements are compared one at a time.
///
/// # Panics
///
/// Panics if `a` is empty (the seed reads `a[0]`).
#[inline]
#[target_feature(enable = "avx2")]
// SAFETY: caller ensures preconditions are met for this unsafe function
pub(crate) unsafe fn max(a: &[f32]) -> f32 {
    unsafe {
        let mut lanes_max = _mm256_set1_ps(a[0]);

        let blocks = a.chunks_exact(8);
        let tail = blocks.remainder();
        for block in blocks {
            // Each block holds exactly 8 floats, so the unaligned load stays in bounds.
            lanes_max = _mm256_max_ps(lanes_max, _mm256_loadu_ps(block.as_ptr()));
        }

        // Collapse 8 lanes -> 4 -> 2 -> 1.
        let low = _mm256_castps256_ps128(lanes_max);
        let high = _mm256_extractf128_ps(lanes_max, 1);
        let quad = _mm_max_ps(low, high);
        let pair = _mm_max_ps(quad, _mm_movehl_ps(quad, quad));
        let single = _mm_max_ss(pair, _mm_shuffle_ps(pair, pair, 1));

        tail.iter()
            .fold(_mm_cvtss_f32(single), |best, &x| if x > best { x } else { best })
    }
}
123
/// AVX2 vector min.
///
/// Returns the smallest element of `a`. Every lane is seeded with `a[0]`, full
/// 8-element blocks are folded in with `_mm256_min_ps`, the lanes are collapsed
/// to a scalar, and the trailing `len % 8` elements are compared one at a time.
///
/// # Panics
///
/// Panics if `a` is empty (the seed reads `a[0]`).
#[inline]
#[target_feature(enable = "avx2")]
// SAFETY: caller ensures preconditions are met for this unsafe function
pub(crate) unsafe fn min(a: &[f32]) -> f32 {
    unsafe {
        let mut lanes_min = _mm256_set1_ps(a[0]);

        let blocks = a.chunks_exact(8);
        let tail = blocks.remainder();
        for block in blocks {
            // Each block holds exactly 8 floats, so the unaligned load stays in bounds.
            lanes_min = _mm256_min_ps(lanes_min, _mm256_loadu_ps(block.as_ptr()));
        }

        // Collapse 8 lanes -> 4 -> 2 -> 1.
        let low = _mm256_castps256_ps128(lanes_min);
        let high = _mm256_extractf128_ps(lanes_min, 1);
        let quad = _mm_min_ps(low, high);
        let pair = _mm_min_ps(quad, _mm_movehl_ps(quad, quad));
        let single = _mm_min_ss(pair, _mm_shuffle_ps(pair, pair, 1));

        tail.iter()
            .fold(_mm_cvtss_f32(single), |best, &x| if x < best { x } else { best })
    }
}
156
/// AVX2 argmax: index of the first occurrence of the largest element.
///
/// Each of the eight lanes tracks the largest value it has seen and the index at
/// which that value first appeared (strict `>` keeps the earlier index). The lanes
/// are then merged, preferring the smaller index when values are equal, and the
/// trailing elements are scanned with scalar code. NaN elements never compare
/// greater, so they are skipped. If `a[0]` itself is NaN, `0` is returned.
///
/// # Panics
///
/// Panics if `a` is empty (the seed reads `a[0]`).
///
/// # Safety
///
/// The CPU must support the `avx2` target feature.
#[inline]
#[target_feature(enable = "avx2")]
// SAFETY: every 8-lane load starts at an index `i` with `i + 8 <= simd_len <= a.len()`.
pub(crate) unsafe fn argmax(a: &[f32]) -> usize {
    unsafe {
        let len = a.len();
        let mut max_val = a[0];
        let mut max_idx: usize = 0;
        let mut i = 0;

        // Lane indices are kept as exact `i32` integers (an `f32` index would round once
        // it passes 2^24). The vector loop therefore only covers the first `i32::MAX`
        // elements; anything beyond that is handled by the scalar tail below.
        let simd_len = len.min(i32::MAX as usize);

        let mut vmax = _mm256_set1_ps(max_val);
        let mut vidx_max = _mm256_setzero_si256();
        let vidx_inc = _mm256_set1_epi32(8);
        let mut vcurrent_idx = _mm256_set_epi32(7, 6, 5, 4, 3, 2, 1, 0);

        while i + 8 <= simd_len {
            let va = _mm256_loadu_ps(a.as_ptr().add(i));
            let mask = _mm256_cmp_ps(va, vmax, _CMP_GT_OQ);
            vmax = _mm256_blendv_ps(vmax, va, mask);
            // Compare lanes are all-ones or all-zeros, so a byte blend moves whole i32 lanes.
            vidx_max = _mm256_blendv_epi8(vidx_max, vcurrent_idx, _mm256_castps_si256(mask));
            vcurrent_idx = _mm256_add_epi32(vcurrent_idx, vidx_inc);
            i += 8;
        }

        let mut vals = [0.0f32; 8];
        let mut idxs = [0i32; 8];
        _mm256_storeu_ps(vals.as_mut_ptr(), vmax);
        _mm256_storeu_si256(idxs.as_mut_ptr().cast::<__m256i>(), vidx_max);

        // Merge lanes. Equal values resolve to the smaller index so the result is the
        // first occurrence no matter which lane it was tracked in.
        for (&val, &idx) in vals.iter().zip(idxs.iter()) {
            // Stored indices all lie in `0..simd_len`, so they are non-negative.
            let idx = idx as usize;
            if val > max_val || (val == max_val && idx < max_idx) {
                max_val = val;
                max_idx = idx;
            }
        }

        // Scalar tail. These indices exceed every lane index, so strict `>` keeps
        // first-occurrence semantics.
        for (offset, &val) in a[i..].iter().enumerate() {
            if val > max_val {
                max_val = val;
                max_idx = i + offset;
            }
        }

        max_idx
    }
}
206
/// AVX2 argmin: index of the first occurrence of the smallest element.
///
/// Each of the eight lanes tracks the smallest value it has seen and the index at
/// which that value first appeared (strict `<` keeps the earlier index). The lanes
/// are then merged, preferring the smaller index when values are equal, and the
/// trailing elements are scanned with scalar code. NaN elements never compare
/// less, so they are skipped. If `a[0]` itself is NaN, `0` is returned.
///
/// # Panics
///
/// Panics if `a` is empty (the seed reads `a[0]`).
///
/// # Safety
///
/// The CPU must support the `avx2` target feature.
#[inline]
#[target_feature(enable = "avx2")]
// SAFETY: every 8-lane load starts at an index `i` with `i + 8 <= simd_len <= a.len()`.
pub(crate) unsafe fn argmin(a: &[f32]) -> usize {
    unsafe {
        let len = a.len();
        let mut min_val = a[0];
        let mut min_idx: usize = 0;
        let mut i = 0;

        // Lane indices are kept as exact `i32` integers (an `f32` index would round once
        // it passes 2^24). The vector loop therefore only covers the first `i32::MAX`
        // elements; anything beyond that is handled by the scalar tail below.
        let simd_len = len.min(i32::MAX as usize);

        let mut vmin = _mm256_set1_ps(min_val);
        let mut vidx_min = _mm256_setzero_si256();
        let vidx_inc = _mm256_set1_epi32(8);
        let mut vcurrent_idx = _mm256_set_epi32(7, 6, 5, 4, 3, 2, 1, 0);

        while i + 8 <= simd_len {
            let va = _mm256_loadu_ps(a.as_ptr().add(i));
            let mask = _mm256_cmp_ps(va, vmin, _CMP_LT_OQ);
            vmin = _mm256_blendv_ps(vmin, va, mask);
            // Compare lanes are all-ones or all-zeros, so a byte blend moves whole i32 lanes.
            vidx_min = _mm256_blendv_epi8(vidx_min, vcurrent_idx, _mm256_castps_si256(mask));
            vcurrent_idx = _mm256_add_epi32(vcurrent_idx, vidx_inc);
            i += 8;
        }

        let mut vals = [0.0f32; 8];
        let mut idxs = [0i32; 8];
        _mm256_storeu_ps(vals.as_mut_ptr(), vmin);
        _mm256_storeu_si256(idxs.as_mut_ptr().cast::<__m256i>(), vidx_min);

        // Merge lanes. Equal values resolve to the smaller index so the result is the
        // first occurrence no matter which lane it was tracked in.
        for (&val, &idx) in vals.iter().zip(idxs.iter()) {
            // Stored indices all lie in `0..simd_len`, so they are non-negative.
            let idx = idx as usize;
            if val < min_val || (val == min_val && idx < min_idx) {
                min_val = val;
                min_idx = idx;
            }
        }

        // Scalar tail. These indices exceed every lane index, so strict `<` keeps
        // first-occurrence semantics.
        for (offset, &val) in a[i..].iter().enumerate() {
            if val < min_val {
                min_val = val;
                min_idx = i + offset;
            }
        }

        min_idx
    }
}
254
255/// Kahan sum for numerical stability (delegates to scalar).
256#[inline]
257// SAFETY: caller ensures preconditions are met for this unsafe function
258pub(crate) unsafe fn sum_kahan(a: &[f32]) -> f32 {
259    unsafe { crate::backends::scalar::ScalarBackend::sum_kahan(a) }
260}