// trueno/backends/avx512/ops/arithmetic.rs
//
// AVX-512F element-wise arithmetic kernels over `f32` slices.

#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;

/// Element count at or above which `add`, `sub` and `mul` switch from the
/// cached path (`_mm512_storeu_ps`) to the non-temporal path
/// (`_mm512_stream_ps`).
///
/// 8192 `f32`s are 32 KiB per operand.
const NT_THRESHOLD: usize = 8192;
14
15#[inline]
17#[target_feature(enable = "avx512f")]
18pub(crate) unsafe fn add(a: &[f32], b: &[f32], result: &mut [f32]) {
20 unsafe {
21 let len = a.len();
22 if len >= NT_THRESHOLD {
23 add_nt(a, b, result);
24 } else {
25 add_cached(a, b, result);
26 }
27 }
28}
29
30#[inline]
32#[target_feature(enable = "avx512f")]
33unsafe fn add_cached(a: &[f32], b: &[f32], result: &mut [f32]) {
34 unsafe {
35 let len = a.len();
36 let mut i = 0;
37 while i + 16 <= len {
38 let va = _mm512_loadu_ps(a.as_ptr().add(i));
39 let vb = _mm512_loadu_ps(b.as_ptr().add(i));
40 _mm512_storeu_ps(result.as_mut_ptr().add(i), _mm512_add_ps(va, vb));
41 i += 16;
42 }
43 for j in i..len {
44 result[j] = a[j] + b[j];
45 }
46 }
47}
48
49#[inline]
53#[target_feature(enable = "avx512f")]
54unsafe fn add_nt(a: &[f32], b: &[f32], result: &mut [f32]) {
55 unsafe {
56 let len = a.len();
57 let ap = a.as_ptr();
58 let bp = b.as_ptr();
59 let rp = result.as_mut_ptr();
60 let mut i = 0;
61
62 while i + 64 <= len {
64 _mm_prefetch(ap.add(i + 128).cast::<i8>(), _MM_HINT_T0);
66 _mm_prefetch(bp.add(i + 128).cast::<i8>(), _MM_HINT_T0);
67
68 let va0 = _mm512_loadu_ps(ap.add(i));
69 let vb0 = _mm512_loadu_ps(bp.add(i));
70 let va1 = _mm512_loadu_ps(ap.add(i + 16));
71 let vb1 = _mm512_loadu_ps(bp.add(i + 16));
72 let va2 = _mm512_loadu_ps(ap.add(i + 32));
73 let vb2 = _mm512_loadu_ps(bp.add(i + 32));
74 let va3 = _mm512_loadu_ps(ap.add(i + 48));
75 let vb3 = _mm512_loadu_ps(bp.add(i + 48));
76
77 _mm512_stream_ps(rp.add(i), _mm512_add_ps(va0, vb0));
78 _mm512_stream_ps(rp.add(i + 16), _mm512_add_ps(va1, vb1));
79 _mm512_stream_ps(rp.add(i + 32), _mm512_add_ps(va2, vb2));
80 _mm512_stream_ps(rp.add(i + 48), _mm512_add_ps(va3, vb3));
81
82 i += 64;
83 }
84
85 while i + 16 <= len {
87 let va = _mm512_loadu_ps(ap.add(i));
88 let vb = _mm512_loadu_ps(bp.add(i));
89 _mm512_stream_ps(rp.add(i), _mm512_add_ps(va, vb));
90 i += 16;
91 }
92
93 _mm_sfence();
95
96 for j in i..len {
98 result[j] = a[j] + b[j];
99 }
100 }
101}
102
103#[inline]
105#[target_feature(enable = "avx512f")]
106pub(crate) unsafe fn sub(a: &[f32], b: &[f32], result: &mut [f32]) {
108 unsafe {
109 let len = a.len();
110 if len >= NT_THRESHOLD {
111 sub_nt(a, b, result);
112 } else {
113 sub_cached(a, b, result);
114 }
115 }
116}
117
118#[inline]
119#[target_feature(enable = "avx512f")]
120unsafe fn sub_cached(a: &[f32], b: &[f32], result: &mut [f32]) {
121 unsafe {
122 let len = a.len();
123 let mut i = 0;
124 while i + 16 <= len {
125 let va = _mm512_loadu_ps(a.as_ptr().add(i));
126 let vb = _mm512_loadu_ps(b.as_ptr().add(i));
127 _mm512_storeu_ps(result.as_mut_ptr().add(i), _mm512_sub_ps(va, vb));
128 i += 16;
129 }
130 for j in i..len {
131 result[j] = a[j] - b[j];
132 }
133 }
134}
135
136#[inline]
137#[target_feature(enable = "avx512f")]
138unsafe fn sub_nt(a: &[f32], b: &[f32], result: &mut [f32]) {
139 unsafe {
140 let len = a.len();
141 let ap = a.as_ptr();
142 let bp = b.as_ptr();
143 let rp = result.as_mut_ptr();
144 let mut i = 0;
145
146 while i + 64 <= len {
147 _mm_prefetch(ap.add(i + 128).cast::<i8>(), _MM_HINT_T0);
148 _mm_prefetch(bp.add(i + 128).cast::<i8>(), _MM_HINT_T0);
149
150 let va0 = _mm512_loadu_ps(ap.add(i));
151 let vb0 = _mm512_loadu_ps(bp.add(i));
152 let va1 = _mm512_loadu_ps(ap.add(i + 16));
153 let vb1 = _mm512_loadu_ps(bp.add(i + 16));
154 let va2 = _mm512_loadu_ps(ap.add(i + 32));
155 let vb2 = _mm512_loadu_ps(bp.add(i + 32));
156 let va3 = _mm512_loadu_ps(ap.add(i + 48));
157 let vb3 = _mm512_loadu_ps(bp.add(i + 48));
158
159 _mm512_stream_ps(rp.add(i), _mm512_sub_ps(va0, vb0));
160 _mm512_stream_ps(rp.add(i + 16), _mm512_sub_ps(va1, vb1));
161 _mm512_stream_ps(rp.add(i + 32), _mm512_sub_ps(va2, vb2));
162 _mm512_stream_ps(rp.add(i + 48), _mm512_sub_ps(va3, vb3));
163
164 i += 64;
165 }
166
167 while i + 16 <= len {
168 let va = _mm512_loadu_ps(ap.add(i));
169 let vb = _mm512_loadu_ps(bp.add(i));
170 _mm512_stream_ps(rp.add(i), _mm512_sub_ps(va, vb));
171 i += 16;
172 }
173
174 _mm_sfence();
175
176 for j in i..len {
177 result[j] = a[j] - b[j];
178 }
179 }
180}
181
182#[inline]
184#[target_feature(enable = "avx512f")]
185pub(crate) unsafe fn mul(a: &[f32], b: &[f32], result: &mut [f32]) {
187 unsafe {
188 let len = a.len();
189 if len >= NT_THRESHOLD {
190 mul_nt(a, b, result);
191 } else {
192 mul_cached(a, b, result);
193 }
194 }
195}
196
197#[inline]
198#[target_feature(enable = "avx512f")]
199unsafe fn mul_cached(a: &[f32], b: &[f32], result: &mut [f32]) {
200 unsafe {
201 let len = a.len();
202 let mut i = 0;
203 while i + 16 <= len {
204 let va = _mm512_loadu_ps(a.as_ptr().add(i));
205 let vb = _mm512_loadu_ps(b.as_ptr().add(i));
206 _mm512_storeu_ps(result.as_mut_ptr().add(i), _mm512_mul_ps(va, vb));
207 i += 16;
208 }
209 for j in i..len {
210 result[j] = a[j] * b[j];
211 }
212 }
213}
214
215#[inline]
216#[target_feature(enable = "avx512f")]
217unsafe fn mul_nt(a: &[f32], b: &[f32], result: &mut [f32]) {
218 unsafe {
219 let len = a.len();
220 let ap = a.as_ptr();
221 let bp = b.as_ptr();
222 let rp = result.as_mut_ptr();
223 let mut i = 0;
224
225 while i + 64 <= len {
226 _mm_prefetch(ap.add(i + 128).cast::<i8>(), _MM_HINT_T0);
227 _mm_prefetch(bp.add(i + 128).cast::<i8>(), _MM_HINT_T0);
228
229 let va0 = _mm512_loadu_ps(ap.add(i));
230 let vb0 = _mm512_loadu_ps(bp.add(i));
231 let va1 = _mm512_loadu_ps(ap.add(i + 16));
232 let vb1 = _mm512_loadu_ps(bp.add(i + 16));
233 let va2 = _mm512_loadu_ps(ap.add(i + 32));
234 let vb2 = _mm512_loadu_ps(bp.add(i + 32));
235 let va3 = _mm512_loadu_ps(ap.add(i + 48));
236 let vb3 = _mm512_loadu_ps(bp.add(i + 48));
237
238 _mm512_stream_ps(rp.add(i), _mm512_mul_ps(va0, vb0));
239 _mm512_stream_ps(rp.add(i + 16), _mm512_mul_ps(va1, vb1));
240 _mm512_stream_ps(rp.add(i + 32), _mm512_mul_ps(va2, vb2));
241 _mm512_stream_ps(rp.add(i + 48), _mm512_mul_ps(va3, vb3));
242
243 i += 64;
244 }
245
246 while i + 16 <= len {
247 let va = _mm512_loadu_ps(ap.add(i));
248 let vb = _mm512_loadu_ps(bp.add(i));
249 _mm512_stream_ps(rp.add(i), _mm512_mul_ps(va, vb));
250 i += 16;
251 }
252
253 _mm_sfence();
254
255 for j in i..len {
256 result[j] = a[j] * b[j];
257 }
258 }
259}
260
261#[inline]
263#[target_feature(enable = "avx512f")]
264pub(crate) unsafe fn div(a: &[f32], b: &[f32], result: &mut [f32]) {
266 unsafe {
267 let len = a.len();
268 let mut i = 0;
269 while i + 16 <= len {
270 let va = _mm512_loadu_ps(a.as_ptr().add(i));
271 let vb = _mm512_loadu_ps(b.as_ptr().add(i));
272 _mm512_storeu_ps(result.as_mut_ptr().add(i), _mm512_div_ps(va, vb));
273 i += 16;
274 }
275 for j in i..len {
276 result[j] = a[j] / b[j];
277 }
278 }
279}