Skip to main content

trueno/backends/avx512/
mod.rs

1//! AVX-512 backend implementation (x86_64 advanced SIMD)
2//!
3//! This backend uses AVX-512 intrinsics for 512-bit SIMD operations.
4//! AVX-512 is available on Intel Skylake-X/Sapphire Rapids (2017+) and AMD Zen 4 (2022+) CPUs.
5//!
6//! # Performance
7//!
8//! Expected speedup: 16x for operations on f32 vectors (16 elements per register)
9//! This provides 2x improvement over AVX2 (8 elements) and ~16x over scalar.
10//!
11//! # Safety
12//!
13//! All AVX-512 intrinsics are marked `unsafe` by Rust. This module carefully isolates
14//! all unsafe code and verifies correctness through comprehensive testing.
15
16mod ops;
17
18#[cfg(target_arch = "x86_64")]
19use std::arch::x86_64::*;
20
21use super::VectorBackend;
22
/// Marker type selecting the AVX-512 (512-bit SIMD, x86_64) implementation of
/// [`VectorBackend`].
///
/// The type carries no data; every operation is an associated function on the
/// trait implementation below.
pub struct Avx512Backend;
25
26impl VectorBackend for Avx512Backend {
27    #[inline]
28    #[target_feature(enable = "avx512f")]
29    // SAFETY: caller ensures preconditions are met for this unsafe function
30    unsafe fn add(a: &[f32], b: &[f32], result: &mut [f32]) {
31        unsafe {
32            ops::arithmetic::add(a, b, result);
33        }
34    }
35
36    #[inline]
37    #[target_feature(enable = "avx512f")]
38    // SAFETY: caller ensures preconditions are met for this unsafe function
39    unsafe fn sub(a: &[f32], b: &[f32], result: &mut [f32]) {
40        unsafe {
41            ops::arithmetic::sub(a, b, result);
42        }
43    }
44
45    #[inline]
46    #[target_feature(enable = "avx512f")]
47    // SAFETY: caller ensures preconditions are met for this unsafe function
48    unsafe fn mul(a: &[f32], b: &[f32], result: &mut [f32]) {
49        unsafe {
50            ops::arithmetic::mul(a, b, result);
51        }
52    }
53
54    #[inline]
55    #[target_feature(enable = "avx512f")]
56    // SAFETY: caller ensures preconditions are met for this unsafe function
57    unsafe fn div(a: &[f32], b: &[f32], result: &mut [f32]) {
58        unsafe {
59            ops::arithmetic::div(a, b, result);
60        }
61    }
62
63    #[inline]
64    #[target_feature(enable = "avx512f")]
65    // SAFETY: caller ensures preconditions are met for this unsafe function
66    unsafe fn dot(a: &[f32], b: &[f32]) -> f32 {
67        unsafe { ops::reductions::dot(a, b) }
68    }
69
70    #[inline]
71    #[target_feature(enable = "avx512f")]
72    // SAFETY: caller ensures preconditions are met for this unsafe function
73    unsafe fn sum(a: &[f32]) -> f32 {
74        unsafe { ops::reductions::sum(a) }
75    }
76
77    #[inline]
78    #[target_feature(enable = "avx512f")]
79    // SAFETY: caller ensures preconditions are met for this unsafe function
80    unsafe fn max(a: &[f32]) -> f32 {
81        unsafe { ops::reductions::max(a) }
82    }
83
84    #[inline]
85    #[target_feature(enable = "avx512f")]
86    // SAFETY: caller ensures preconditions are met for this unsafe function
87    unsafe fn min(a: &[f32]) -> f32 {
88        unsafe { ops::reductions::min(a) }
89    }
90
91    #[inline]
92    #[target_feature(enable = "avx512f")]
93    // SAFETY: caller ensures preconditions are met for this unsafe function
94    unsafe fn argmax(a: &[f32]) -> usize {
95        unsafe { ops::reductions::argmax(a) }
96    }
97
98    #[inline]
99    #[target_feature(enable = "avx512f")]
100    // SAFETY: caller ensures preconditions are met for this unsafe function
101    unsafe fn argmin(a: &[f32]) -> usize {
102        unsafe { ops::reductions::argmin(a) }
103    }
104
105    // SAFETY: caller ensures preconditions are met for this unsafe function
106    unsafe fn sum_kahan(a: &[f32]) -> f32 {
107        unsafe { ops::reductions::sum_kahan(a) }
108    }
109
110    #[inline]
111    #[target_feature(enable = "avx512f")]
112    // SAFETY: caller ensures preconditions are met for this unsafe function
113    unsafe fn norm_l2(a: &[f32]) -> f32 {
114        unsafe {
115            if a.is_empty() {
116                return 0.0;
117            }
118            let len = a.len();
119            let mut i = 0;
120            let mut acc = _mm512_setzero_ps();
121            while i + 16 <= len {
122                let va = _mm512_loadu_ps(a.as_ptr().add(i));
123                acc = _mm512_add_ps(acc, _mm512_mul_ps(va, va));
124                i += 16;
125            }
126            let mut sum_sq = _mm512_reduce_add_ps(acc);
127            for &val in &a[i..] {
128                sum_sq += val * val;
129            }
130            sum_sq.sqrt()
131        }
132    }
133
134    #[inline]
135    #[target_feature(enable = "avx512f")]
136    // SAFETY: caller ensures preconditions are met for this unsafe function
137    unsafe fn norm_l1(a: &[f32]) -> f32 {
138        unsafe {
139            let len = a.len();
140            let mut i = 0;
141            let sign_mask = _mm512_set1_ps(f32::from_bits(0x7FFF_FFFF));
142            let mut acc = _mm512_setzero_ps();
143            while i + 16 <= len {
144                acc = _mm512_add_ps(
145                    acc,
146                    _mm512_and_ps(_mm512_loadu_ps(a.as_ptr().add(i)), sign_mask),
147                );
148                i += 16;
149            }
150            let mut result = _mm512_reduce_add_ps(acc);
151            for &val in &a[i..] {
152                result += val.abs();
153            }
154            result
155        }
156    }
157
158    #[inline]
159    #[target_feature(enable = "avx512f")]
160    // SAFETY: caller ensures preconditions are met for this unsafe function
161    unsafe fn norm_linf(a: &[f32]) -> f32 {
162        unsafe {
163            let len = a.len();
164            let mut i = 0;
165            let sign_mask = _mm512_set1_ps(f32::from_bits(0x7FFF_FFFF));
166            let mut max_vec = _mm512_setzero_ps();
167            while i + 16 <= len {
168                max_vec = _mm512_max_ps(
169                    max_vec,
170                    _mm512_and_ps(_mm512_loadu_ps(a.as_ptr().add(i)), sign_mask),
171                );
172                i += 16;
173            }
174            let mut result = _mm512_reduce_max_ps(max_vec);
175            for &val in &a[i..] {
176                let abs_val = val.abs();
177                if abs_val > result {
178                    result = abs_val;
179                }
180            }
181            result
182        }
183    }
184
185    #[inline]
186    #[target_feature(enable = "avx512f")]
187    // SAFETY: caller ensures preconditions are met for this unsafe function
188    unsafe fn scale(a: &[f32], scalar: f32, result: &mut [f32]) {
189        unsafe {
190            let len = a.len();
191            let mut i = 0;
192            let scalar_vec = _mm512_set1_ps(scalar);
193            while i + 16 <= len {
194                _mm512_storeu_ps(
195                    result.as_mut_ptr().add(i),
196                    _mm512_mul_ps(_mm512_loadu_ps(a.as_ptr().add(i)), scalar_vec),
197                );
198                i += 16;
199            }
200            for j in i..len {
201                result[j] = a[j] * scalar;
202            }
203        }
204    }
205
206    #[inline]
207    #[target_feature(enable = "avx512f")]
208    // SAFETY: caller ensures preconditions are met for this unsafe function
209    unsafe fn abs(a: &[f32], result: &mut [f32]) {
210        unsafe {
211            let len = a.len();
212            let mut i = 0;
213            let sign_mask = _mm512_set1_ps(f32::from_bits(0x7FFF_FFFF));
214            while i + 16 <= len {
215                _mm512_storeu_ps(
216                    result.as_mut_ptr().add(i),
217                    _mm512_and_ps(_mm512_loadu_ps(a.as_ptr().add(i)), sign_mask),
218                );
219                i += 16;
220            }
221            for j in i..len {
222                result[j] = a[j].abs();
223            }
224        }
225    }
226
227    #[inline]
228    #[target_feature(enable = "avx512f")]
229    // SAFETY: caller ensures preconditions are met for this unsafe function
230    unsafe fn clamp(a: &[f32], min_val: f32, max_val: f32, result: &mut [f32]) {
231        unsafe {
232            let len = a.len();
233            let mut i = 0;
234            let min_vec = _mm512_set1_ps(min_val);
235            let max_vec = _mm512_set1_ps(max_val);
236            while i + 16 <= len {
237                let va = _mm512_loadu_ps(a.as_ptr().add(i));
238                _mm512_storeu_ps(
239                    result.as_mut_ptr().add(i),
240                    _mm512_min_ps(_mm512_max_ps(va, min_vec), max_vec),
241                );
242                i += 16;
243            }
244            for j in i..len {
245                result[j] = a[j].max(min_val).min(max_val);
246            }
247        }
248    }
249
250    #[inline]
251    #[target_feature(enable = "avx512f")]
252    // SAFETY: caller ensures preconditions are met for this unsafe function
253    unsafe fn lerp(a: &[f32], b: &[f32], t: f32, result: &mut [f32]) {
254        unsafe {
255            let len = a.len();
256            let mut i = 0;
257            let t_vec = _mm512_set1_ps(t);
258            while i + 16 <= len {
259                let va = _mm512_loadu_ps(a.as_ptr().add(i));
260                let vb = _mm512_loadu_ps(b.as_ptr().add(i));
261                _mm512_storeu_ps(
262                    result.as_mut_ptr().add(i),
263                    _mm512_fmadd_ps(t_vec, _mm512_sub_ps(vb, va), va),
264                );
265                i += 16;
266            }
267            for j in i..len {
268                result[j] = a[j] + t * (b[j] - a[j]);
269            }
270        }
271    }
272
273    #[inline]
274    #[target_feature(enable = "avx512f")]
275    // SAFETY: caller ensures preconditions are met for this unsafe function
276    unsafe fn fma(a: &[f32], b: &[f32], c: &[f32], result: &mut [f32]) {
277        unsafe {
278            let len = a.len();
279            let mut i = 0;
280            while i + 16 <= len {
281                let va = _mm512_loadu_ps(a.as_ptr().add(i));
282                let vb = _mm512_loadu_ps(b.as_ptr().add(i));
283                let vc = _mm512_loadu_ps(c.as_ptr().add(i));
284                _mm512_storeu_ps(result.as_mut_ptr().add(i), _mm512_fmadd_ps(va, vb, vc));
285                i += 16;
286            }
287            for j in i..len {
288                result[j] = a[j] * b[j] + c[j];
289            }
290        }
291    }
292
293    #[inline]
294    #[target_feature(enable = "avx512f")]
295    // SAFETY: caller ensures preconditions are met for this unsafe function
296    unsafe fn relu(a: &[f32], result: &mut [f32]) {
297        unsafe {
298            let len = a.len();
299            let ap = a.as_ptr();
300            let rp = result.as_mut_ptr();
301            let mut i = 0;
302            let zero = _mm512_setzero_ps();
303
304            if len >= 8192 {
305                // Non-temporal path: 4-way unrolled with prefetch
306                while i + 64 <= len {
307                    _mm_prefetch(ap.add(i + 128).cast::<i8>(), _MM_HINT_T0);
308
309                    _mm512_stream_ps(rp.add(i), _mm512_max_ps(_mm512_loadu_ps(ap.add(i)), zero));
310                    _mm512_stream_ps(
311                        rp.add(i + 16),
312                        _mm512_max_ps(_mm512_loadu_ps(ap.add(i + 16)), zero),
313                    );
314                    _mm512_stream_ps(
315                        rp.add(i + 32),
316                        _mm512_max_ps(_mm512_loadu_ps(ap.add(i + 32)), zero),
317                    );
318                    _mm512_stream_ps(
319                        rp.add(i + 48),
320                        _mm512_max_ps(_mm512_loadu_ps(ap.add(i + 48)), zero),
321                    );
322
323                    i += 64;
324                }
325                while i + 16 <= len {
326                    _mm512_stream_ps(rp.add(i), _mm512_max_ps(_mm512_loadu_ps(ap.add(i)), zero));
327                    i += 16;
328                }
329                _mm_sfence();
330            } else {
331                while i + 16 <= len {
332                    _mm512_storeu_ps(rp.add(i), _mm512_max_ps(_mm512_loadu_ps(ap.add(i)), zero));
333                    i += 16;
334                }
335            }
336            for j in i..len {
337                result[j] = a[j].max(0.0);
338            }
339        }
340    }
341
342    #[inline]
343    #[target_feature(enable = "avx512f")]
344    // SAFETY: caller ensures preconditions are met for this unsafe function
345    unsafe fn exp(a: &[f32], result: &mut [f32]) {
346        unsafe {
347            let len = a.len();
348            let mut i = 0;
349            let ln2 = _mm512_set1_ps(std::f32::consts::LN_2);
350            let inv_ln2 = _mm512_set1_ps(1.0 / std::f32::consts::LN_2);
351            let one = _mm512_set1_ps(1.0);
352            let c2 = _mm512_set1_ps(0.5);
353            let c3 = _mm512_set1_ps(0.166_666_67);
354            let c4 = _mm512_set1_ps(0.041_666_668);
355            let c5 = _mm512_set1_ps(0.008_333_334);
356            while i + 16 <= len {
357                let x = _mm512_loadu_ps(a.as_ptr().add(i));
358                let k = _mm512_cvtps_epi32(_mm512_mul_ps(x, inv_ln2));
359                let kf = _mm512_cvtepi32_ps(k);
360                let r = _mm512_sub_ps(x, _mm512_mul_ps(kf, ln2));
361                let mut poly = _mm512_fmadd_ps(r, c5, one);
362                poly = _mm512_fmadd_ps(r, _mm512_fmadd_ps(r, poly, c4), one);
363                poly = _mm512_fmadd_ps(r, _mm512_fmadd_ps(r, poly, c3), one);
364                poly = _mm512_fmadd_ps(r, _mm512_fmadd_ps(r, poly, c2), one);
365                poly = _mm512_fmadd_ps(r, poly, one);
366                let exp_k = _mm512_castsi512_ps(_mm512_slli_epi32(
367                    _mm512_add_epi32(k, _mm512_set1_epi32(127)),
368                    23,
369                ));
370                _mm512_storeu_ps(result.as_mut_ptr().add(i), _mm512_mul_ps(poly, exp_k));
371                i += 16;
372            }
373            for j in i..len {
374                result[j] = a[j].exp();
375            }
376        }
377    }
378
379    #[inline]
380    #[target_feature(enable = "avx512f")]
381    // SAFETY: caller ensures preconditions are met for this unsafe function
382    unsafe fn sigmoid(a: &[f32], result: &mut [f32]) {
383        // sigmoid(x) = 1 / (1 + exp(-x))
384        let len = a.len();
385        for j in 0..len {
386            result[j] = 1.0 / (1.0 + (-a[j]).exp());
387        }
388    }
389
390    #[inline]
391    #[target_feature(enable = "avx512f")]
392    // SAFETY: caller ensures preconditions are met for this unsafe function
393    unsafe fn gelu(a: &[f32], result: &mut [f32]) {
394        for j in 0..a.len() {
395            let x = a[j];
396            let inner = 0.797_884_56 * (x + 0.044_715 * x * x * x);
397            result[j] = 0.5 * x * (1.0 + inner.tanh());
398        }
399    }
400
401    #[inline]
402    #[target_feature(enable = "avx512f")]
403    // SAFETY: caller ensures preconditions are met for this unsafe function
404    unsafe fn swish(a: &[f32], result: &mut [f32]) {
405        for j in 0..a.len() {
406            result[j] = a[j] / (1.0 + (-a[j]).exp());
407        }
408    }
409
410    #[inline]
411    #[target_feature(enable = "avx512f")]
412    // SAFETY: caller ensures preconditions are met for this unsafe function
413    unsafe fn tanh(a: &[f32], result: &mut [f32]) {
414        for j in 0..a.len() {
415            result[j] = a[j].tanh();
416        }
417    }
418
419    #[inline]
420    #[target_feature(enable = "avx512f")]
421    // SAFETY: caller ensures preconditions are met for this unsafe function
422    unsafe fn sqrt(a: &[f32], result: &mut [f32]) {
423        unsafe {
424            let len = a.len();
425            let mut i = 0;
426            while i + 16 <= len {
427                _mm512_storeu_ps(
428                    result.as_mut_ptr().add(i),
429                    _mm512_sqrt_ps(_mm512_loadu_ps(a.as_ptr().add(i))),
430                );
431                i += 16;
432            }
433            for j in i..len {
434                result[j] = a[j].sqrt();
435            }
436        }
437    }
438
439    #[inline]
440    #[target_feature(enable = "avx512f")]
441    // SAFETY: caller ensures preconditions are met for this unsafe function
442    unsafe fn recip(a: &[f32], result: &mut [f32]) {
443        unsafe {
444            let len = a.len();
445            let mut i = 0;
446            let one = _mm512_set1_ps(1.0);
447            while i + 16 <= len {
448                _mm512_storeu_ps(
449                    result.as_mut_ptr().add(i),
450                    _mm512_div_ps(one, _mm512_loadu_ps(a.as_ptr().add(i))),
451                );
452                i += 16;
453            }
454            for j in i..len {
455                result[j] = a[j].recip();
456            }
457        }
458    }
459
460    // SAFETY: caller ensures preconditions are met for this unsafe function
461    unsafe fn ln(a: &[f32], result: &mut [f32]) {
462        unsafe {
463            super::scalar::ScalarBackend::ln(a, result);
464        }
465    }
466    // SAFETY: caller ensures preconditions are met for this unsafe function
467    unsafe fn log2(a: &[f32], result: &mut [f32]) {
468        unsafe {
469            super::scalar::ScalarBackend::log2(a, result);
470        }
471    }
472    // SAFETY: caller ensures preconditions are met for this unsafe function
473    unsafe fn log10(a: &[f32], result: &mut [f32]) {
474        unsafe {
475            super::scalar::ScalarBackend::log10(a, result);
476        }
477    }
478    // SAFETY: caller ensures preconditions are met for this unsafe function
479    unsafe fn sin(a: &[f32], result: &mut [f32]) {
480        unsafe {
481            super::scalar::ScalarBackend::sin(a, result);
482        }
483    }
484    // SAFETY: caller ensures preconditions are met for this unsafe function
485    unsafe fn cos(a: &[f32], result: &mut [f32]) {
486        unsafe {
487            super::scalar::ScalarBackend::cos(a, result);
488        }
489    }
490    // SAFETY: caller ensures preconditions are met for this unsafe function
491    unsafe fn tan(a: &[f32], result: &mut [f32]) {
492        unsafe {
493            super::scalar::ScalarBackend::tan(a, result);
494        }
495    }
496
497    // SAFETY: caller ensures preconditions are met for this unsafe function
498    unsafe fn floor(a: &[f32], result: &mut [f32]) {
499        unsafe {
500            super::scalar::ScalarBackend::floor(a, result);
501        }
502    }
503    // SAFETY: caller ensures preconditions are met for this unsafe function
504    unsafe fn ceil(a: &[f32], result: &mut [f32]) {
505        unsafe {
506            super::scalar::ScalarBackend::ceil(a, result);
507        }
508    }
509    // SAFETY: caller ensures preconditions are met for this unsafe function
510    unsafe fn round(a: &[f32], result: &mut [f32]) {
511        unsafe {
512            super::scalar::ScalarBackend::round(a, result);
513        }
514    }
515}
516
517#[cfg(all(test, target_arch = "x86_64"))]
518mod tests;