// trueno/backends/avx2/mod.rs
1//! AVX2 backend implementation (x86_64 advanced SIMD)
2//!
3//! This backend uses AVX2 intrinsics for 256-bit SIMD operations with FMA.
4//! AVX2 is available on Intel Haswell (2013+) and AMD Excavator (2015+) CPUs.
5
6#[cfg(target_arch = "x86_64")]
7use std::arch::x86_64::*;
8
9use super::VectorBackend;
10
11mod ops;
12
/// AVX2 backend (256-bit SIMD for x86_64)
///
/// Zero-sized marker type; all operations are associated functions on the
/// `VectorBackend` impl below. Requires AVX2 (and FMA for some ops) at the
/// call site — callers must verify CPU support before dispatching here.
pub struct Avx2Backend;
15
16impl VectorBackend for Avx2Backend {
17    #[inline]
18    #[target_feature(enable = "avx2")]
19    // SAFETY: caller ensures preconditions are met for this unsafe function
20    unsafe fn add(a: &[f32], b: &[f32], result: &mut [f32]) {
21        unsafe { ops::arithmetic::add(a, b, result) }
22    }
23
24    #[inline]
25    #[target_feature(enable = "avx2")]
26    // SAFETY: caller ensures preconditions are met for this unsafe function
27    unsafe fn sub(a: &[f32], b: &[f32], result: &mut [f32]) {
28        unsafe { ops::arithmetic::sub(a, b, result) }
29    }
30
31    #[inline]
32    #[target_feature(enable = "avx2")]
33    // SAFETY: caller ensures preconditions are met for this unsafe function
34    unsafe fn mul(a: &[f32], b: &[f32], result: &mut [f32]) {
35        unsafe { ops::arithmetic::mul(a, b, result) }
36    }
37
38    #[inline]
39    #[target_feature(enable = "avx2")]
40    // SAFETY: caller ensures preconditions are met for this unsafe function
41    unsafe fn div(a: &[f32], b: &[f32], result: &mut [f32]) {
42        unsafe { ops::arithmetic::div(a, b, result) }
43    }
44
45    #[inline]
46    #[target_feature(enable = "avx2,fma")]
47    // SAFETY: caller ensures preconditions are met for this unsafe function
48    unsafe fn dot(a: &[f32], b: &[f32]) -> f32 {
49        unsafe { ops::reductions::dot(a, b) }
50    }
51
52    #[inline]
53    #[target_feature(enable = "avx2")]
54    // SAFETY: caller ensures preconditions are met for this unsafe function
55    unsafe fn sum(a: &[f32]) -> f32 {
56        unsafe { ops::reductions::sum(a) }
57    }
58
59    #[inline]
60    #[target_feature(enable = "avx2")]
61    // SAFETY: caller ensures preconditions are met for this unsafe function
62    unsafe fn max(a: &[f32]) -> f32 {
63        unsafe { ops::reductions::max(a) }
64    }
65
66    #[inline]
67    #[target_feature(enable = "avx2")]
68    // SAFETY: caller ensures preconditions are met for this unsafe function
69    unsafe fn min(a: &[f32]) -> f32 {
70        unsafe { ops::reductions::min(a) }
71    }
72
73    #[inline]
74    #[target_feature(enable = "avx2")]
75    // SAFETY: caller ensures preconditions are met for this unsafe function
76    unsafe fn argmax(a: &[f32]) -> usize {
77        unsafe { ops::reductions::argmax(a) }
78    }
79
80    #[inline]
81    #[target_feature(enable = "avx2")]
82    // SAFETY: caller ensures preconditions are met for this unsafe function
83    unsafe fn argmin(a: &[f32]) -> usize {
84        unsafe { ops::reductions::argmin(a) }
85    }
86
87    #[inline]
88    // SAFETY: caller ensures preconditions are met for this unsafe function
89    unsafe fn sum_kahan(a: &[f32]) -> f32 {
90        unsafe { ops::reductions::sum_kahan(a) }
91    }
92
93    #[inline]
94    #[target_feature(enable = "avx2,fma")]
95    // SAFETY: caller ensures preconditions are met for this unsafe function
96    unsafe fn norm_l2(a: &[f32]) -> f32 {
97        unsafe {
98            if a.is_empty() {
99                return 0.0;
100            }
101            Self::dot(a, a).sqrt()
102        }
103    }
104
105    #[inline]
106    #[target_feature(enable = "avx2")]
107    // SAFETY: caller ensures preconditions are met for this unsafe function
108    unsafe fn norm_l1(a: &[f32]) -> f32 {
109        unsafe {
110            if a.is_empty() {
111                return 0.0;
112            }
113            let len = a.len();
114            let mut i = 0;
115            let mut acc = _mm256_setzero_ps();
116            let sign_mask = _mm256_set1_ps(f32::from_bits(0x7FFF_FFFF));
117
118            while i + 8 <= len {
119                let va = _mm256_loadu_ps(a.as_ptr().add(i));
120                let abs_va = _mm256_and_ps(va, sign_mask);
121                acc = _mm256_add_ps(acc, abs_va);
122                i += 8;
123            }
124
125            let mut result = {
126                let sum_halves =
127                    _mm_add_ps(_mm256_castps256_ps128(acc), _mm256_extractf128_ps(acc, 1));
128                let temp = _mm_add_ps(sum_halves, _mm_movehl_ps(sum_halves, sum_halves));
129                let temp = _mm_add_ss(temp, _mm_shuffle_ps(temp, temp, 1));
130                _mm_cvtss_f32(temp)
131            };
132
133            for &val in &a[i..] {
134                result += val.abs();
135            }
136            result
137        }
138    }
139
140    #[inline]
141    #[target_feature(enable = "avx2")]
142    // SAFETY: caller ensures preconditions are met for this unsafe function
143    unsafe fn norm_linf(a: &[f32]) -> f32 {
144        unsafe {
145            if a.is_empty() {
146                return 0.0;
147            }
148            let len = a.len();
149            let mut i = 0;
150            let mut max_vec = _mm256_setzero_ps();
151            let sign_mask = _mm256_set1_ps(f32::from_bits(0x7FFF_FFFF));
152
153            while i + 8 <= len {
154                let va = _mm256_loadu_ps(a.as_ptr().add(i));
155                let abs_va = _mm256_and_ps(va, sign_mask);
156                max_vec = _mm256_max_ps(max_vec, abs_va);
157                i += 8;
158            }
159
160            let mut result = {
161                let max_halves =
162                    _mm_max_ps(_mm256_castps256_ps128(max_vec), _mm256_extractf128_ps(max_vec, 1));
163                let temp = _mm_max_ps(max_halves, _mm_movehl_ps(max_halves, max_halves));
164                let temp = _mm_max_ss(temp, _mm_shuffle_ps(temp, temp, 1));
165                _mm_cvtss_f32(temp)
166            };
167
168            for &val in &a[i..] {
169                let abs_val = val.abs();
170                if abs_val > result {
171                    result = abs_val;
172                }
173            }
174            result
175        }
176    }
177
178    #[inline]
179    #[target_feature(enable = "avx2")]
180    // SAFETY: caller ensures preconditions are met for this unsafe function
181    unsafe fn scale(a: &[f32], scalar: f32, result: &mut [f32]) {
182        unsafe {
183            let len = a.len();
184            let mut i = 0;
185            let scalar_vec = _mm256_set1_ps(scalar);
186
187            while i + 8 <= len {
188                let va = _mm256_loadu_ps(a.as_ptr().add(i));
189                let vresult = _mm256_mul_ps(va, scalar_vec);
190                _mm256_storeu_ps(result.as_mut_ptr().add(i), vresult);
191                i += 8;
192            }
193
194            while i < len {
195                result[i] = a[i] * scalar;
196                i += 1;
197            }
198        }
199    }
200
201    #[inline]
202    #[target_feature(enable = "avx2")]
203    // SAFETY: caller ensures preconditions are met for this unsafe function
204    unsafe fn abs(a: &[f32], result: &mut [f32]) {
205        unsafe {
206            let len = a.len();
207            let mut i = 0;
208            let sign_mask = _mm256_set1_ps(f32::from_bits(0x7FFF_FFFF));
209
210            while i + 8 <= len {
211                let va = _mm256_loadu_ps(a.as_ptr().add(i));
212                let vresult = _mm256_and_ps(va, sign_mask);
213                _mm256_storeu_ps(result.as_mut_ptr().add(i), vresult);
214                i += 8;
215            }
216
217            for j in i..len {
218                result[j] = a[j].abs();
219            }
220        }
221    }
222
223    #[inline]
224    #[target_feature(enable = "avx2")]
225    // SAFETY: caller ensures preconditions are met for this unsafe function
226    unsafe fn clamp(a: &[f32], min_val: f32, max_val: f32, result: &mut [f32]) {
227        unsafe {
228            let len = a.len();
229            let mut i = 0;
230            let vmin = _mm256_set1_ps(min_val);
231            let vmax = _mm256_set1_ps(max_val);
232
233            while i + 8 <= len {
234                let va = _mm256_loadu_ps(a.as_ptr().add(i));
235                let vresult = _mm256_min_ps(_mm256_max_ps(va, vmin), vmax);
236                _mm256_storeu_ps(result.as_mut_ptr().add(i), vresult);
237                i += 8;
238            }
239
240            for j in i..len {
241                result[j] = a[j].clamp(min_val, max_val);
242            }
243        }
244    }
245
246    #[inline]
247    #[target_feature(enable = "avx2,fma")]
248    // SAFETY: caller ensures preconditions are met for this unsafe function
249    unsafe fn lerp(a: &[f32], b: &[f32], t: f32, result: &mut [f32]) {
250        unsafe {
251            let len = a.len();
252            let mut i = 0;
253            let vt = _mm256_set1_ps(t);
254            let v1_minus_t = _mm256_set1_ps(1.0 - t);
255
256            while i + 8 <= len {
257                let va = _mm256_loadu_ps(a.as_ptr().add(i));
258                let vb = _mm256_loadu_ps(b.as_ptr().add(i));
259                let vresult = _mm256_fmadd_ps(vb, vt, _mm256_mul_ps(va, v1_minus_t));
260                _mm256_storeu_ps(result.as_mut_ptr().add(i), vresult);
261                i += 8;
262            }
263
264            for j in i..len {
265                result[j] = a[j] * (1.0 - t) + b[j] * t;
266            }
267        }
268    }
269
270    #[inline]
271    #[target_feature(enable = "avx2,fma")]
272    // SAFETY: caller ensures preconditions are met for this unsafe function
273    unsafe fn fma(a: &[f32], b: &[f32], c: &[f32], result: &mut [f32]) {
274        unsafe {
275            let len = a.len();
276            let mut i = 0;
277
278            while i + 8 <= len {
279                let va = _mm256_loadu_ps(a.as_ptr().add(i));
280                let vb = _mm256_loadu_ps(b.as_ptr().add(i));
281                let vc = _mm256_loadu_ps(c.as_ptr().add(i));
282                let vresult = _mm256_fmadd_ps(va, vb, vc);
283                _mm256_storeu_ps(result.as_mut_ptr().add(i), vresult);
284                i += 8;
285            }
286
287            for j in i..len {
288                result[j] = a[j] * b[j] + c[j];
289            }
290        }
291    }
292
293    #[inline]
294    #[target_feature(enable = "avx2")]
295    // SAFETY: caller ensures preconditions are met for this unsafe function
296    unsafe fn relu(a: &[f32], result: &mut [f32]) {
297        unsafe {
298            let len = a.len();
299            let mut i = 0;
300            let vzero = _mm256_setzero_ps();
301
302            while i + 8 <= len {
303                let va = _mm256_loadu_ps(a.as_ptr().add(i));
304                let vresult = _mm256_max_ps(va, vzero);
305                _mm256_storeu_ps(result.as_mut_ptr().add(i), vresult);
306                i += 8;
307            }
308
309            for j in i..len {
310                result[j] = a[j].max(0.0);
311            }
312        }
313    }
314
315    // Delegate transcendental functions to scalar backend
316    #[inline]
317    // SAFETY: caller ensures preconditions are met for this unsafe function
318    unsafe fn exp(a: &[f32], result: &mut [f32]) {
319        unsafe { super::scalar::ScalarBackend::exp(a, result) }
320    }
321
322    #[inline]
323    // SAFETY: caller ensures preconditions are met for this unsafe function
324    unsafe fn sigmoid(a: &[f32], result: &mut [f32]) {
325        unsafe { super::scalar::ScalarBackend::sigmoid(a, result) }
326    }
327
328    #[inline]
329    // SAFETY: caller ensures preconditions are met for this unsafe function
330    unsafe fn gelu(a: &[f32], result: &mut [f32]) {
331        unsafe { super::scalar::ScalarBackend::gelu(a, result) }
332    }
333
334    #[inline]
335    // SAFETY: caller ensures preconditions are met for this unsafe function
336    unsafe fn swish(a: &[f32], result: &mut [f32]) {
337        unsafe { super::scalar::ScalarBackend::swish(a, result) }
338    }
339
340    #[inline]
341    // SAFETY: caller ensures preconditions are met for this unsafe function
342    unsafe fn tanh(a: &[f32], result: &mut [f32]) {
343        unsafe { super::scalar::ScalarBackend::tanh(a, result) }
344    }
345
346    #[inline]
347    #[target_feature(enable = "avx2")]
348    // SAFETY: caller ensures preconditions are met for this unsafe function
349    unsafe fn sqrt(a: &[f32], result: &mut [f32]) {
350        unsafe {
351            let len = a.len();
352            let mut i = 0;
353
354            while i + 8 <= len {
355                let va = _mm256_loadu_ps(a.as_ptr().add(i));
356                let vresult = _mm256_sqrt_ps(va);
357                _mm256_storeu_ps(result.as_mut_ptr().add(i), vresult);
358                i += 8;
359            }
360
361            for j in i..len {
362                result[j] = a[j].sqrt();
363            }
364        }
365    }
366
367    #[inline]
368    #[target_feature(enable = "avx2")]
369    // SAFETY: caller ensures preconditions are met for this unsafe function
370    unsafe fn recip(a: &[f32], result: &mut [f32]) {
371        unsafe {
372            let len = a.len();
373            let mut i = 0;
374            let vone = _mm256_set1_ps(1.0);
375
376            while i + 8 <= len {
377                let va = _mm256_loadu_ps(a.as_ptr().add(i));
378                let vresult = _mm256_div_ps(vone, va);
379                _mm256_storeu_ps(result.as_mut_ptr().add(i), vresult);
380                i += 8;
381            }
382
383            for j in i..len {
384                result[j] = 1.0 / a[j];
385            }
386        }
387    }
388
389    // Delegate log functions to scalar backend
390    #[inline]
391    // SAFETY: caller ensures preconditions are met for this unsafe function
392    unsafe fn ln(a: &[f32], result: &mut [f32]) {
393        unsafe { super::scalar::ScalarBackend::ln(a, result) }
394    }
395
396    #[inline]
397    // SAFETY: caller ensures preconditions are met for this unsafe function
398    unsafe fn log2(a: &[f32], result: &mut [f32]) {
399        unsafe { super::scalar::ScalarBackend::log2(a, result) }
400    }
401
402    #[inline]
403    // SAFETY: caller ensures preconditions are met for this unsafe function
404    unsafe fn log10(a: &[f32], result: &mut [f32]) {
405        unsafe { super::scalar::ScalarBackend::log10(a, result) }
406    }
407
408    // Delegate trig functions to scalar backend
409    #[inline]
410    // SAFETY: caller ensures preconditions are met for this unsafe function
411    unsafe fn sin(a: &[f32], result: &mut [f32]) {
412        unsafe { super::scalar::ScalarBackend::sin(a, result) }
413    }
414
415    #[inline]
416    // SAFETY: caller ensures preconditions are met for this unsafe function
417    unsafe fn cos(a: &[f32], result: &mut [f32]) {
418        unsafe { super::scalar::ScalarBackend::cos(a, result) }
419    }
420
421    #[inline]
422    // SAFETY: caller ensures preconditions are met for this unsafe function
423    unsafe fn tan(a: &[f32], result: &mut [f32]) {
424        unsafe { super::scalar::ScalarBackend::tan(a, result) }
425    }
426
427    #[inline]
428    #[target_feature(enable = "avx2")]
429    // SAFETY: caller ensures preconditions are met for this unsafe function
430    unsafe fn floor(a: &[f32], result: &mut [f32]) {
431        unsafe {
432            let len = a.len();
433            let mut i = 0;
434
435            while i + 8 <= len {
436                let va = _mm256_loadu_ps(a.as_ptr().add(i));
437                let vresult = _mm256_floor_ps(va);
438                _mm256_storeu_ps(result.as_mut_ptr().add(i), vresult);
439                i += 8;
440            }
441
442            for j in i..len {
443                result[j] = a[j].floor();
444            }
445        }
446    }
447
448    #[inline]
449    #[target_feature(enable = "avx2")]
450    // SAFETY: caller ensures preconditions are met for this unsafe function
451    unsafe fn ceil(a: &[f32], result: &mut [f32]) {
452        unsafe {
453            let len = a.len();
454            let mut i = 0;
455
456            while i + 8 <= len {
457                let va = _mm256_loadu_ps(a.as_ptr().add(i));
458                let vresult = _mm256_ceil_ps(va);
459                _mm256_storeu_ps(result.as_mut_ptr().add(i), vresult);
460                i += 8;
461            }
462
463            for j in i..len {
464                result[j] = a[j].ceil();
465            }
466        }
467    }
468
469    #[inline]
470    #[target_feature(enable = "avx2")]
471    // SAFETY: caller ensures preconditions are met for this unsafe function
472    unsafe fn round(a: &[f32], result: &mut [f32]) {
473        unsafe {
474            let len = a.len();
475            let mut i = 0;
476
477            // Round ties away from zero to match Rust's f32::round()
478            let half = _mm256_set1_ps(0.5);
479            let sign_mask = _mm256_set1_ps(f32::from_bits(0x8000_0000));
480            let abs_mask = _mm256_set1_ps(f32::from_bits(0x7FFF_FFFF));
481
482            while i + 8 <= len {
483                let va = _mm256_loadu_ps(a.as_ptr().add(i));
484                let sign = _mm256_and_ps(va, sign_mask);
485                let abs_val = _mm256_and_ps(va, abs_mask);
486                let shifted = _mm256_add_ps(abs_val, half);
487                let rounded_abs = _mm256_floor_ps(shifted);
488                let vresult = _mm256_or_ps(rounded_abs, sign);
489                _mm256_storeu_ps(result.as_mut_ptr().add(i), vresult);
490                i += 8;
491            }
492
493            for j in i..len {
494                result[j] = a[j].round();
495            }
496        }
497    }
498}