trueno/backends/avx2/ops/arithmetic.rs

//! AVX2 arithmetic operations (add, sub, mul, div).
//!
//! For large vectors (≥8192 elements), uses non-temporal stores (`_mm256_stream_ps`)
//! to bypass cache and 4-way unrolling for ILP.
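//!
//! A minimal calling sketch (assumes `a`, `b`, and `out` are equal-length
//! `f32` slices and that the caller checks for AVX2 first, per the safety
//! contract of `#[target_feature]` functions):
//!
//! ```ignore
//! if is_x86_feature_detected!("avx2") {
//!     // SAFETY: AVX2 confirmed above; slices are the same length.
//!     unsafe { add(&a, &b, &mut out) };
//! }
//! ```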

#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;

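// Heuristic cutoff (assumed rationale): 8192 f32 is 32 KiB per input, roughly
// L1-sized; below this the data plausibly stays cache-resident and regular
// stores tend to win, while above it NT stores avoid polluting the cache.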
const NT_THRESHOLD: usize = 8192;

/// AVX2 vector addition.
///
/// # Safety
///
/// Caller must ensure AVX2 is available on the running CPU and that `a`, `b`,
/// and `result` all have the same length.
#[inline]
#[target_feature(enable = "avx2")]
pub(crate) unsafe fn add(a: &[f32], b: &[f32], result: &mut [f32]) {
    unsafe {
        let len = a.len();
        // NT stores require 32-byte alignment (#242 SIGSEGV fix)
        let rp_aligned = (result.as_ptr() as usize) % 32 == 0;
        if len >= NT_THRESHOLD && rp_aligned {
            add_nt(a, b, result);
        } else {
            add_cached(a, b, result);
        }
    }
}

#[inline]
#[target_feature(enable = "avx2")]
unsafe fn add_cached(a: &[f32], b: &[f32], result: &mut [f32]) {
    unsafe {
        let len = a.len();
        let mut i = 0;
        // Process 8 f32 lanes per iteration with unaligned loads/stores.
        while i + 8 <= len {
            let va = _mm256_loadu_ps(a.as_ptr().add(i));
            let vb = _mm256_loadu_ps(b.as_ptr().add(i));
            _mm256_storeu_ps(result.as_mut_ptr().add(i), _mm256_add_ps(va, vb));
            i += 8;
        }
        // Scalar tail for the remaining < 8 elements.
        for j in i..len {
            result[j] = a[j] + b[j];
        }
    }
}

#[inline]
#[target_feature(enable = "avx2")]
unsafe fn add_nt(a: &[f32], b: &[f32], result: &mut [f32]) {
    unsafe {
        let len = a.len();
        let ap = a.as_ptr();
        let bp = b.as_ptr();
        let rp = result.as_mut_ptr();
        let mut i = 0;

        // 4-way unrolled NT loop (32 f32 = 128 bytes = 2 cache lines per iter)
        while i + 32 <= len {
            // Prefetch two iterations ahead. `wrapping_add`, not `add`: the
            // target may lie past the end of the slice, and prefetching an
            // invalid address is harmless while out-of-bounds `add` pointer
            // arithmetic would be UB.
            _mm_prefetch::<_MM_HINT_T0>(ap.wrapping_add(i + 64).cast::<i8>());
            _mm_prefetch::<_MM_HINT_T0>(bp.wrapping_add(i + 64).cast::<i8>());

            let va0 = _mm256_loadu_ps(ap.add(i));
            let vb0 = _mm256_loadu_ps(bp.add(i));
            let va1 = _mm256_loadu_ps(ap.add(i + 8));
            let vb1 = _mm256_loadu_ps(bp.add(i + 8));
            let va2 = _mm256_loadu_ps(ap.add(i + 16));
            let vb2 = _mm256_loadu_ps(bp.add(i + 16));
            let va3 = _mm256_loadu_ps(ap.add(i + 24));
            let vb3 = _mm256_loadu_ps(bp.add(i + 24));

            _mm256_stream_ps(rp.add(i), _mm256_add_ps(va0, vb0));
            _mm256_stream_ps(rp.add(i + 8), _mm256_add_ps(va1, vb1));
            _mm256_stream_ps(rp.add(i + 16), _mm256_add_ps(va2, vb2));
            _mm256_stream_ps(rp.add(i + 24), _mm256_add_ps(va3, vb3));

            i += 32;
        }

        // Single-vector NT loop for the remaining full 8-lane chunks.
        while i + 8 <= len {
            let va = _mm256_loadu_ps(ap.add(i));
            let vb = _mm256_loadu_ps(bp.add(i));
            _mm256_stream_ps(rp.add(i), _mm256_add_ps(va, vb));
            i += 8;
        }

        // NT stores are weakly ordered; fence before the scalar tail writes.
        _mm_sfence();

        for j in i..len {
            result[j] = a[j] + b[j];
        }
    }
}

/// AVX2 vector subtraction.
///
/// # Safety
///
/// Caller must ensure AVX2 is available on the running CPU and that `a`, `b`,
/// and `result` all have the same length.
#[inline]
#[target_feature(enable = "avx2")]
pub(crate) unsafe fn sub(a: &[f32], b: &[f32], result: &mut [f32]) {
    unsafe {
        let len = a.len();
        let ap = a.as_ptr();
        let bp = b.as_ptr();
        let rp = result.as_mut_ptr();
        let mut i = 0;

        // NT stores require 32-byte alignment (#242 SIGSEGV fix)
        let rp_aligned = (rp as usize) % 32 == 0;
        if len >= NT_THRESHOLD && rp_aligned {
            // 4-way unrolled NT loop; see `add_nt` for the layout rationale.
            while i + 32 <= len {
                // `wrapping_add`: the prefetch target may be past the slice end.
                _mm_prefetch::<_MM_HINT_T0>(ap.wrapping_add(i + 64).cast::<i8>());
                _mm_prefetch::<_MM_HINT_T0>(bp.wrapping_add(i + 64).cast::<i8>());

                _mm256_stream_ps(
                    rp.add(i),
                    _mm256_sub_ps(_mm256_loadu_ps(ap.add(i)), _mm256_loadu_ps(bp.add(i))),
                );
                _mm256_stream_ps(
                    rp.add(i + 8),
                    _mm256_sub_ps(_mm256_loadu_ps(ap.add(i + 8)), _mm256_loadu_ps(bp.add(i + 8))),
                );
                _mm256_stream_ps(
                    rp.add(i + 16),
                    _mm256_sub_ps(_mm256_loadu_ps(ap.add(i + 16)), _mm256_loadu_ps(bp.add(i + 16))),
                );
                _mm256_stream_ps(
                    rp.add(i + 24),
                    _mm256_sub_ps(_mm256_loadu_ps(ap.add(i + 24)), _mm256_loadu_ps(bp.add(i + 24))),
                );
                i += 32;
            }
            while i + 8 <= len {
                _mm256_stream_ps(
                    rp.add(i),
                    _mm256_sub_ps(_mm256_loadu_ps(ap.add(i)), _mm256_loadu_ps(bp.add(i))),
                );
                i += 8;
            }
            _mm_sfence();
        } else {
            while i + 8 <= len {
                _mm256_storeu_ps(
                    rp.add(i),
                    _mm256_sub_ps(_mm256_loadu_ps(ap.add(i)), _mm256_loadu_ps(bp.add(i))),
                );
                i += 8;
            }
        }

        for j in i..len {
            result[j] = a[j] - b[j];
        }
    }
}

/// AVX2 vector multiplication.
///
/// # Safety
///
/// Caller must ensure AVX2 is available on the running CPU and that `a`, `b`,
/// and `result` all have the same length.
#[inline]
#[target_feature(enable = "avx2")]
pub(crate) unsafe fn mul(a: &[f32], b: &[f32], result: &mut [f32]) {
    unsafe {
        let len = a.len();
        let ap = a.as_ptr();
        let bp = b.as_ptr();
        let rp = result.as_mut_ptr();
        let mut i = 0;

        // NT stores (_mm256_stream_ps) require a 32-byte-aligned output.
        // Vec<f32>'s allocation is only guaranteed 4-byte-aligned, so take the
        // NT path only when the output happens to be aligned.
        // Fix for #242 SIGSEGV: General Protection Fault from unaligned stream_ps.
        let rp_aligned = (rp as usize) % 32 == 0;
        if len >= NT_THRESHOLD && rp_aligned {
            while i + 32 <= len {
                // `wrapping_add`: the prefetch target may be past the slice end.
                _mm_prefetch::<_MM_HINT_T0>(ap.wrapping_add(i + 64).cast::<i8>());
                _mm_prefetch::<_MM_HINT_T0>(bp.wrapping_add(i + 64).cast::<i8>());

                _mm256_stream_ps(
                    rp.add(i),
                    _mm256_mul_ps(_mm256_loadu_ps(ap.add(i)), _mm256_loadu_ps(bp.add(i))),
                );
                _mm256_stream_ps(
                    rp.add(i + 8),
                    _mm256_mul_ps(_mm256_loadu_ps(ap.add(i + 8)), _mm256_loadu_ps(bp.add(i + 8))),
                );
                _mm256_stream_ps(
                    rp.add(i + 16),
                    _mm256_mul_ps(_mm256_loadu_ps(ap.add(i + 16)), _mm256_loadu_ps(bp.add(i + 16))),
                );
                _mm256_stream_ps(
                    rp.add(i + 24),
                    _mm256_mul_ps(_mm256_loadu_ps(ap.add(i + 24)), _mm256_loadu_ps(bp.add(i + 24))),
                );
                i += 32;
            }
            while i + 8 <= len {
                _mm256_stream_ps(
                    rp.add(i),
                    _mm256_mul_ps(_mm256_loadu_ps(ap.add(i)), _mm256_loadu_ps(bp.add(i))),
                );
                i += 8;
            }
            _mm_sfence();
        } else {
            while i + 8 <= len {
                _mm256_storeu_ps(
                    rp.add(i),
                    _mm256_mul_ps(_mm256_loadu_ps(ap.add(i)), _mm256_loadu_ps(bp.add(i))),
                );
                i += 8;
            }
        }

        for j in i..len {
            result[j] = a[j] * b[j];
        }
    }
}

/// AVX2 vector division.
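///
/// No non-temporal path here: given `vdivps`'s low throughput, this loop is
/// likely compute-bound rather than store-bandwidth-bound, so streaming
/// stores would presumably buy little.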
///
/// # Safety
///
/// Caller must ensure AVX2 is available on the running CPU and that `a`, `b`,
/// and `result` all have the same length.
#[inline]
#[target_feature(enable = "avx2")]
pub(crate) unsafe fn div(a: &[f32], b: &[f32], result: &mut [f32]) {
    unsafe {
        let len = a.len();
        let mut i = 0;
        while i + 8 <= len {
            let va = _mm256_loadu_ps(a.as_ptr().add(i));
            let vb = _mm256_loadu_ps(b.as_ptr().add(i));
            _mm256_storeu_ps(result.as_mut_ptr().add(i), _mm256_div_ps(va, vb));
            i += 8;
        }
        for j in i..len {
            result[j] = a[j] / b[j];
        }
    }
}
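
// A minimal correctness sketch: each op is checked against scalar arithmetic
// on a length that exercises the unrolled, single-vector, and scalar-tail
// paths. The inputs are small integers, so add/sub/mul are exact in f32, and
// both SIMD and scalar division are correctly rounded, so results match
// bitwise. Whether the NT path runs depends on the Vec allocation's alignment.
#[cfg(all(test, target_arch = "x86_64"))]
mod tests {
    use super::*;

    #[test]
    fn ops_match_scalar() {
        if !is_x86_feature_detected!("avx2") {
            return; // nothing to verify on non-AVX2 hosts
        }
        let n = NT_THRESHOLD + 11; // above the NT threshold, with a ragged tail
        let a: Vec<f32> = (0..n).map(|i| (i % 113) as f32 + 1.0).collect();
        let b: Vec<f32> = (0..n).map(|i| (i % 7) as f32 + 1.0).collect();
        let mut out = vec![0.0f32; n];

        // SAFETY: AVX2 checked above; all three slices have length `n`.
        unsafe { add(&a, &b, &mut out) };
        for i in 0..n {
            assert_eq!(out[i], a[i] + b[i]);
        }
        unsafe { sub(&a, &b, &mut out) };
        for i in 0..n {
            assert_eq!(out[i], a[i] - b[i]);
        }
        unsafe { mul(&a, &b, &mut out) };
        for i in 0..n {
            assert_eq!(out[i], a[i] * b[i]);
        }
        unsafe { div(&a, &b, &mut out) };
        for i in 0..n {
            assert_eq!(out[i], a[i] / b[i]);
        }
    }
}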