
trueno/backends/avx512/ops/arithmetic.rs

//! AVX-512 arithmetic operations (add, sub, mul, div).
//!
//! For large vectors (≥8192 elements), `add`, `sub`, and `mul` use
//! non-temporal stores, which bypass the cache and avoid polluting it,
//! plus 4-way unrolling for instruction-level parallelism.
//! Based on: Drepper (2007), "What Every Programmer Should Know About
//! Memory", Section 6.1: non-temporal stores eliminate read-for-ownership
//! traffic (with write-back stores, every output cache line is first read
//! for ownership and later written back, so eliminating the RFO cuts the
//! memory traffic per output line from four bus transfers to three).

#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;

/// Threshold above which non-temporal stores are beneficial.
/// Below this, data fits in L2 and regular cached stores are faster:
/// at 8192 elements, each f32 operand is 32 KiB, so the two inputs and
/// the output together occupy 96 KiB.
const NT_THRESHOLD: usize = 8192;

/// AVX-512 vector addition.
#[inline]
#[target_feature(enable = "avx512f")]
// SAFETY: caller must ensure AVX-512F is available at runtime, that `a`,
// `b`, and `result` all have the same length, and that `result` is
// 64-byte aligned when `len >= NT_THRESHOLD` (see `add_nt`).
pub(crate) unsafe fn add(a: &[f32], b: &[f32], result: &mut [f32]) {
    unsafe {
        let len = a.len();
        if len >= NT_THRESHOLD {
            add_nt(a, b, result);
        } else {
            add_cached(a, b, result);
        }
    }
}
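
// A minimal sketch of a safe call site (hypothetical; `add_f32` is not
// part of this module's API). It enforces the length contract, checks
// the alignment the non-temporal path needs, and performs the runtime
// CPUID check required by the SAFETY comment on `add`, falling back to
// scalar code when AVX-512F is absent.
#[allow(dead_code)]
pub(crate) fn add_f32(a: &[f32], b: &[f32], out: &mut [f32]) {
    assert_eq!(a.len(), b.len(), "inputs must have the same length");
    assert_eq!(a.len(), out.len(), "output must match input length");
    // `_mm512_stream_ps` in the NT path faults on unaligned addresses.
    assert!(
        a.len() < NT_THRESHOLD || out.as_ptr() as usize % 64 == 0,
        "outputs of >= NT_THRESHOLD elements must be 64-byte aligned"
    );
    if is_x86_feature_detected!("avx512f") {
        // SAFETY: lengths and alignment checked above; AVX-512F
        // confirmed at runtime.
        unsafe { add(a, b, out) };
    } else {
        // Scalar fallback for CPUs without AVX-512F.
        for i in 0..a.len() {
            out[i] = a[i] + b[i];
        }
    }
}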

/// Cached-store path for small vectors (fits in L2).
#[inline]
#[target_feature(enable = "avx512f")]
unsafe fn add_cached(a: &[f32], b: &[f32], result: &mut [f32]) {
    unsafe {
        let len = a.len();
        let mut i = 0;
        while i + 16 <= len {
            let va = _mm512_loadu_ps(a.as_ptr().add(i));
            let vb = _mm512_loadu_ps(b.as_ptr().add(i));
            _mm512_storeu_ps(result.as_mut_ptr().add(i), _mm512_add_ps(va, vb));
            i += 16;
        }
        for j in i..len {
            result[j] = a[j] + b[j];
        }
    }
}

/// Non-temporal store path for large vectors.
/// 4-way unrolled (64 f32 = 256 bytes per iteration = 4 cache lines).
/// Prefetches 8 cache lines ahead (512 bytes).
/// `result` must be 64-byte aligned: `_mm512_stream_ps` can raise a
/// general-protection fault on unaligned addresses.
#[inline]
#[target_feature(enable = "avx512f")]
unsafe fn add_nt(a: &[f32], b: &[f32], result: &mut [f32]) {
    unsafe {
        let len = a.len();
        let ap = a.as_ptr();
        let bp = b.as_ptr();
        let rp = result.as_mut_ptr();
        let mut i = 0;

        // 4-way unrolled non-temporal loop
        while i + 64 <= len {
            // Prefetch 8 cache lines ahead (512 bytes = 128 f32)
            _mm_prefetch(ap.add(i + 128).cast::<i8>(), _MM_HINT_T0);
            _mm_prefetch(bp.add(i + 128).cast::<i8>(), _MM_HINT_T0);

            let va0 = _mm512_loadu_ps(ap.add(i));
            let vb0 = _mm512_loadu_ps(bp.add(i));
            let va1 = _mm512_loadu_ps(ap.add(i + 16));
            let vb1 = _mm512_loadu_ps(bp.add(i + 16));
            let va2 = _mm512_loadu_ps(ap.add(i + 32));
            let vb2 = _mm512_loadu_ps(bp.add(i + 32));
            let va3 = _mm512_loadu_ps(ap.add(i + 48));
            let vb3 = _mm512_loadu_ps(bp.add(i + 48));

            _mm512_stream_ps(rp.add(i), _mm512_add_ps(va0, vb0));
            _mm512_stream_ps(rp.add(i + 16), _mm512_add_ps(va1, vb1));
            _mm512_stream_ps(rp.add(i + 32), _mm512_add_ps(va2, vb2));
            _mm512_stream_ps(rp.add(i + 48), _mm512_add_ps(va3, vb3));

            i += 64;
        }

        // Cleanup: remaining full SIMD widths
        while i + 16 <= len {
            let va = _mm512_loadu_ps(ap.add(i));
            let vb = _mm512_loadu_ps(bp.add(i));
            _mm512_stream_ps(rp.add(i), _mm512_add_ps(va, vb));
            i += 16;
        }

        // Fence so the non-temporal stores are globally visible before
        // the scalar writes below and before the function returns
        _mm_sfence();

        // Scalar remainder
        for j in i..len {
            result[j] = a[j] + b[j];
        }
    }
}
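
// A sketch of one way to obtain the 64-byte-aligned output buffer that
// the streaming stores above require (hypothetical helpers, not part of
// this file; the crate may already provide its own aligned storage):
// allocate in cache-line-sized, 64-byte-aligned chunks, then view them
// as a plain `&mut [f32]`.
#[repr(C, align(64))]
#[derive(Clone, Copy)]
#[allow(dead_code)]
struct CacheLine([f32; 16]);

#[allow(dead_code)]
fn alloc_aligned_f32(len: usize) -> Vec<CacheLine> {
    // Round up to a whole number of 16-f32 cache lines.
    vec![CacheLine([0.0; 16]); (len + 15) / 16]
}

#[allow(dead_code)]
fn as_f32_mut(buf: &mut [CacheLine], len: usize) -> &mut [f32] {
    assert!(len <= buf.len() * 16);
    // SAFETY: `CacheLine` is `repr(C)` over `[f32; 16]` with no padding
    // (size 64 == alignment), so the buffer holds `buf.len() * 16`
    // contiguous f32s, and `len` is within that range per the assert.
    unsafe { std::slice::from_raw_parts_mut(buf.as_mut_ptr().cast::<f32>(), len) }
}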

/// AVX-512 vector subtraction.
#[inline]
#[target_feature(enable = "avx512f")]
// SAFETY: caller must ensure AVX-512F is available at runtime, that `a`,
// `b`, and `result` all have the same length, and that `result` is
// 64-byte aligned when `len >= NT_THRESHOLD` (see `add_nt`).
pub(crate) unsafe fn sub(a: &[f32], b: &[f32], result: &mut [f32]) {
    unsafe {
        let len = a.len();
        if len >= NT_THRESHOLD {
            sub_nt(a, b, result);
        } else {
            sub_cached(a, b, result);
        }
    }
}

/// Cached-store path for small vectors (fits in L2).
#[inline]
#[target_feature(enable = "avx512f")]
unsafe fn sub_cached(a: &[f32], b: &[f32], result: &mut [f32]) {
    unsafe {
        let len = a.len();
        let mut i = 0;
        while i + 16 <= len {
            let va = _mm512_loadu_ps(a.as_ptr().add(i));
            let vb = _mm512_loadu_ps(b.as_ptr().add(i));
            _mm512_storeu_ps(result.as_mut_ptr().add(i), _mm512_sub_ps(va, vb));
            i += 16;
        }
        for j in i..len {
            result[j] = a[j] - b[j];
        }
    }
}

/// Non-temporal store path for large vectors; layout mirrors `add_nt`,
/// including the 64-byte alignment requirement on `result`.
#[inline]
#[target_feature(enable = "avx512f")]
unsafe fn sub_nt(a: &[f32], b: &[f32], result: &mut [f32]) {
    unsafe {
        let len = a.len();
        let ap = a.as_ptr();
        let bp = b.as_ptr();
        let rp = result.as_mut_ptr();
        let mut i = 0;

        while i + 64 <= len {
            _mm_prefetch(ap.add(i + 128).cast::<i8>(), _MM_HINT_T0);
            _mm_prefetch(bp.add(i + 128).cast::<i8>(), _MM_HINT_T0);

            let va0 = _mm512_loadu_ps(ap.add(i));
            let vb0 = _mm512_loadu_ps(bp.add(i));
            let va1 = _mm512_loadu_ps(ap.add(i + 16));
            let vb1 = _mm512_loadu_ps(bp.add(i + 16));
            let va2 = _mm512_loadu_ps(ap.add(i + 32));
            let vb2 = _mm512_loadu_ps(bp.add(i + 32));
            let va3 = _mm512_loadu_ps(ap.add(i + 48));
            let vb3 = _mm512_loadu_ps(bp.add(i + 48));

            _mm512_stream_ps(rp.add(i), _mm512_sub_ps(va0, vb0));
            _mm512_stream_ps(rp.add(i + 16), _mm512_sub_ps(va1, vb1));
            _mm512_stream_ps(rp.add(i + 32), _mm512_sub_ps(va2, vb2));
            _mm512_stream_ps(rp.add(i + 48), _mm512_sub_ps(va3, vb3));

            i += 64;
        }

        while i + 16 <= len {
            let va = _mm512_loadu_ps(ap.add(i));
            let vb = _mm512_loadu_ps(bp.add(i));
            _mm512_stream_ps(rp.add(i), _mm512_sub_ps(va, vb));
            i += 16;
        }

        _mm_sfence();

        for j in i..len {
            result[j] = a[j] - b[j];
        }
    }
}

/// AVX-512 vector multiplication.
#[inline]
#[target_feature(enable = "avx512f")]
// SAFETY: caller must ensure AVX-512F is available at runtime, that `a`,
// `b`, and `result` all have the same length, and that `result` is
// 64-byte aligned when `len >= NT_THRESHOLD` (see `add_nt`).
pub(crate) unsafe fn mul(a: &[f32], b: &[f32], result: &mut [f32]) {
    unsafe {
        let len = a.len();
        if len >= NT_THRESHOLD {
            mul_nt(a, b, result);
        } else {
            mul_cached(a, b, result);
        }
    }
}

/// Cached-store path for small vectors (fits in L2).
#[inline]
#[target_feature(enable = "avx512f")]
unsafe fn mul_cached(a: &[f32], b: &[f32], result: &mut [f32]) {
    unsafe {
        let len = a.len();
        let mut i = 0;
        while i + 16 <= len {
            let va = _mm512_loadu_ps(a.as_ptr().add(i));
            let vb = _mm512_loadu_ps(b.as_ptr().add(i));
            _mm512_storeu_ps(result.as_mut_ptr().add(i), _mm512_mul_ps(va, vb));
            i += 16;
        }
        for j in i..len {
            result[j] = a[j] * b[j];
        }
    }
}

/// Non-temporal store path for large vectors; layout mirrors `add_nt`,
/// including the 64-byte alignment requirement on `result`.
#[inline]
#[target_feature(enable = "avx512f")]
unsafe fn mul_nt(a: &[f32], b: &[f32], result: &mut [f32]) {
    unsafe {
        let len = a.len();
        let ap = a.as_ptr();
        let bp = b.as_ptr();
        let rp = result.as_mut_ptr();
        let mut i = 0;

        while i + 64 <= len {
            _mm_prefetch(ap.add(i + 128).cast::<i8>(), _MM_HINT_T0);
            _mm_prefetch(bp.add(i + 128).cast::<i8>(), _MM_HINT_T0);

            let va0 = _mm512_loadu_ps(ap.add(i));
            let vb0 = _mm512_loadu_ps(bp.add(i));
            let va1 = _mm512_loadu_ps(ap.add(i + 16));
            let vb1 = _mm512_loadu_ps(bp.add(i + 16));
            let va2 = _mm512_loadu_ps(ap.add(i + 32));
            let vb2 = _mm512_loadu_ps(bp.add(i + 32));
            let va3 = _mm512_loadu_ps(ap.add(i + 48));
            let vb3 = _mm512_loadu_ps(bp.add(i + 48));

            _mm512_stream_ps(rp.add(i), _mm512_mul_ps(va0, vb0));
            _mm512_stream_ps(rp.add(i + 16), _mm512_mul_ps(va1, vb1));
            _mm512_stream_ps(rp.add(i + 32), _mm512_mul_ps(va2, vb2));
            _mm512_stream_ps(rp.add(i + 48), _mm512_mul_ps(va3, vb3));

            i += 64;
        }

        while i + 16 <= len {
            let va = _mm512_loadu_ps(ap.add(i));
            let vb = _mm512_loadu_ps(bp.add(i));
            _mm512_stream_ps(rp.add(i), _mm512_mul_ps(va, vb));
            i += 16;
        }

        _mm_sfence();

        for j in i..len {
            result[j] = a[j] * b[j];
        }
    }
}

/// AVX-512 vector division. Cached stores are used at every size; there
/// is no non-temporal path (for division, `vdivps` throughput rather
/// than store bandwidth is the likely bottleneck on large inputs).
#[inline]
#[target_feature(enable = "avx512f")]
// SAFETY: caller must ensure AVX-512F is available at runtime and that
// `a`, `b`, and `result` all have the same length.
pub(crate) unsafe fn div(a: &[f32], b: &[f32], result: &mut [f32]) {
    unsafe {
        let len = a.len();
        let mut i = 0;
        while i + 16 <= len {
            let va = _mm512_loadu_ps(a.as_ptr().add(i));
            let vb = _mm512_loadu_ps(b.as_ptr().add(i));
            _mm512_storeu_ps(result.as_mut_ptr().add(i), _mm512_div_ps(va, vb));
            i += 16;
        }
        for j in i..len {
            result[j] = a[j] / b[j];
        }
    }
}
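
// A minimal test sketch (hypothetical; the names, lengths, and data are
// assumptions, not part of the original file). Each kernel is compared
// against scalar arithmetic; both compute the same IEEE single-precision
// operation, so results must match exactly. Lengths exercise the
// full-vector loop and the scalar remainder; the non-temporal path is
// not exercised here because it would additionally require a
// 64-byte-aligned output buffer.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn matches_scalar() {
        if !is_x86_feature_detected!("avx512f") {
            return; // skip on hosts without AVX-512F
        }
        for len in [0usize, 1, 16, 67, 1024] {
            let a: Vec<f32> = (0..len).map(|i| i as f32 + 1.0).collect();
            // Keep divisors nonzero so `div` is well defined.
            let b: Vec<f32> = (0..len).map(|i| (i % 7) as f32 + 1.0).collect();
            let mut out = vec![0.0f32; len];
            // SAFETY (all calls): slices have equal lengths and AVX-512F
            // was verified above; lengths stay below NT_THRESHOLD.
            unsafe { add(&a, &b, &mut out) };
            for i in 0..len {
                assert_eq!(out[i], a[i] + b[i]);
            }
            unsafe { sub(&a, &b, &mut out) };
            for i in 0..len {
                assert_eq!(out[i], a[i] - b[i]);
            }
            unsafe { mul(&a, &b, &mut out) };
            for i in 0..len {
                assert_eq!(out[i], a[i] * b[i]);
            }
            unsafe { div(&a, &b, &mut out) };
            for i in 0..len {
                assert_eq!(out[i], a[i] / b[i]);
            }
        }
    }
}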