1mod ops;
17
18#[cfg(target_arch = "x86_64")]
19use std::arch::x86_64::*;
20
21use super::VectorBackend;
22
/// SIMD vector backend implemented with AVX-512F intrinsics.
///
/// Stateless marker type; all operations are associated functions on the
/// `VectorBackend` trait impl. Callers must verify AVX-512F support before
/// invoking any of the (unsafe, `target_feature`-gated) methods.
pub struct Avx512Backend;
25
26impl VectorBackend for Avx512Backend {
27 #[inline]
28 #[target_feature(enable = "avx512f")]
29 unsafe fn add(a: &[f32], b: &[f32], result: &mut [f32]) {
31 unsafe {
32 ops::arithmetic::add(a, b, result);
33 }
34 }
35
36 #[inline]
37 #[target_feature(enable = "avx512f")]
38 unsafe fn sub(a: &[f32], b: &[f32], result: &mut [f32]) {
40 unsafe {
41 ops::arithmetic::sub(a, b, result);
42 }
43 }
44
45 #[inline]
46 #[target_feature(enable = "avx512f")]
47 unsafe fn mul(a: &[f32], b: &[f32], result: &mut [f32]) {
49 unsafe {
50 ops::arithmetic::mul(a, b, result);
51 }
52 }
53
54 #[inline]
55 #[target_feature(enable = "avx512f")]
56 unsafe fn div(a: &[f32], b: &[f32], result: &mut [f32]) {
58 unsafe {
59 ops::arithmetic::div(a, b, result);
60 }
61 }
62
63 #[inline]
64 #[target_feature(enable = "avx512f")]
65 unsafe fn dot(a: &[f32], b: &[f32]) -> f32 {
67 unsafe { ops::reductions::dot(a, b) }
68 }
69
70 #[inline]
71 #[target_feature(enable = "avx512f")]
72 unsafe fn sum(a: &[f32]) -> f32 {
74 unsafe { ops::reductions::sum(a) }
75 }
76
77 #[inline]
78 #[target_feature(enable = "avx512f")]
79 unsafe fn max(a: &[f32]) -> f32 {
81 unsafe { ops::reductions::max(a) }
82 }
83
84 #[inline]
85 #[target_feature(enable = "avx512f")]
86 unsafe fn min(a: &[f32]) -> f32 {
88 unsafe { ops::reductions::min(a) }
89 }
90
91 #[inline]
92 #[target_feature(enable = "avx512f")]
93 unsafe fn argmax(a: &[f32]) -> usize {
95 unsafe { ops::reductions::argmax(a) }
96 }
97
98 #[inline]
99 #[target_feature(enable = "avx512f")]
100 unsafe fn argmin(a: &[f32]) -> usize {
102 unsafe { ops::reductions::argmin(a) }
103 }
104
105 unsafe fn sum_kahan(a: &[f32]) -> f32 {
107 unsafe { ops::reductions::sum_kahan(a) }
108 }
109
110 #[inline]
111 #[target_feature(enable = "avx512f")]
112 unsafe fn norm_l2(a: &[f32]) -> f32 {
114 unsafe {
115 if a.is_empty() {
116 return 0.0;
117 }
118 let len = a.len();
119 let mut i = 0;
120 let mut acc = _mm512_setzero_ps();
121 while i + 16 <= len {
122 let va = _mm512_loadu_ps(a.as_ptr().add(i));
123 acc = _mm512_add_ps(acc, _mm512_mul_ps(va, va));
124 i += 16;
125 }
126 let mut sum_sq = _mm512_reduce_add_ps(acc);
127 for &val in &a[i..] {
128 sum_sq += val * val;
129 }
130 sum_sq.sqrt()
131 }
132 }
133
134 #[inline]
135 #[target_feature(enable = "avx512f")]
136 unsafe fn norm_l1(a: &[f32]) -> f32 {
138 unsafe {
139 let len = a.len();
140 let mut i = 0;
141 let sign_mask = _mm512_set1_ps(f32::from_bits(0x7FFF_FFFF));
142 let mut acc = _mm512_setzero_ps();
143 while i + 16 <= len {
144 acc = _mm512_add_ps(
145 acc,
146 _mm512_and_ps(_mm512_loadu_ps(a.as_ptr().add(i)), sign_mask),
147 );
148 i += 16;
149 }
150 let mut result = _mm512_reduce_add_ps(acc);
151 for &val in &a[i..] {
152 result += val.abs();
153 }
154 result
155 }
156 }
157
158 #[inline]
159 #[target_feature(enable = "avx512f")]
160 unsafe fn norm_linf(a: &[f32]) -> f32 {
162 unsafe {
163 let len = a.len();
164 let mut i = 0;
165 let sign_mask = _mm512_set1_ps(f32::from_bits(0x7FFF_FFFF));
166 let mut max_vec = _mm512_setzero_ps();
167 while i + 16 <= len {
168 max_vec = _mm512_max_ps(
169 max_vec,
170 _mm512_and_ps(_mm512_loadu_ps(a.as_ptr().add(i)), sign_mask),
171 );
172 i += 16;
173 }
174 let mut result = _mm512_reduce_max_ps(max_vec);
175 for &val in &a[i..] {
176 let abs_val = val.abs();
177 if abs_val > result {
178 result = abs_val;
179 }
180 }
181 result
182 }
183 }
184
185 #[inline]
186 #[target_feature(enable = "avx512f")]
187 unsafe fn scale(a: &[f32], scalar: f32, result: &mut [f32]) {
189 unsafe {
190 let len = a.len();
191 let mut i = 0;
192 let scalar_vec = _mm512_set1_ps(scalar);
193 while i + 16 <= len {
194 _mm512_storeu_ps(
195 result.as_mut_ptr().add(i),
196 _mm512_mul_ps(_mm512_loadu_ps(a.as_ptr().add(i)), scalar_vec),
197 );
198 i += 16;
199 }
200 for j in i..len {
201 result[j] = a[j] * scalar;
202 }
203 }
204 }
205
206 #[inline]
207 #[target_feature(enable = "avx512f")]
208 unsafe fn abs(a: &[f32], result: &mut [f32]) {
210 unsafe {
211 let len = a.len();
212 let mut i = 0;
213 let sign_mask = _mm512_set1_ps(f32::from_bits(0x7FFF_FFFF));
214 while i + 16 <= len {
215 _mm512_storeu_ps(
216 result.as_mut_ptr().add(i),
217 _mm512_and_ps(_mm512_loadu_ps(a.as_ptr().add(i)), sign_mask),
218 );
219 i += 16;
220 }
221 for j in i..len {
222 result[j] = a[j].abs();
223 }
224 }
225 }
226
227 #[inline]
228 #[target_feature(enable = "avx512f")]
229 unsafe fn clamp(a: &[f32], min_val: f32, max_val: f32, result: &mut [f32]) {
231 unsafe {
232 let len = a.len();
233 let mut i = 0;
234 let min_vec = _mm512_set1_ps(min_val);
235 let max_vec = _mm512_set1_ps(max_val);
236 while i + 16 <= len {
237 let va = _mm512_loadu_ps(a.as_ptr().add(i));
238 _mm512_storeu_ps(
239 result.as_mut_ptr().add(i),
240 _mm512_min_ps(_mm512_max_ps(va, min_vec), max_vec),
241 );
242 i += 16;
243 }
244 for j in i..len {
245 result[j] = a[j].max(min_val).min(max_val);
246 }
247 }
248 }
249
250 #[inline]
251 #[target_feature(enable = "avx512f")]
252 unsafe fn lerp(a: &[f32], b: &[f32], t: f32, result: &mut [f32]) {
254 unsafe {
255 let len = a.len();
256 let mut i = 0;
257 let t_vec = _mm512_set1_ps(t);
258 while i + 16 <= len {
259 let va = _mm512_loadu_ps(a.as_ptr().add(i));
260 let vb = _mm512_loadu_ps(b.as_ptr().add(i));
261 _mm512_storeu_ps(
262 result.as_mut_ptr().add(i),
263 _mm512_fmadd_ps(t_vec, _mm512_sub_ps(vb, va), va),
264 );
265 i += 16;
266 }
267 for j in i..len {
268 result[j] = a[j] + t * (b[j] - a[j]);
269 }
270 }
271 }
272
273 #[inline]
274 #[target_feature(enable = "avx512f")]
275 unsafe fn fma(a: &[f32], b: &[f32], c: &[f32], result: &mut [f32]) {
277 unsafe {
278 let len = a.len();
279 let mut i = 0;
280 while i + 16 <= len {
281 let va = _mm512_loadu_ps(a.as_ptr().add(i));
282 let vb = _mm512_loadu_ps(b.as_ptr().add(i));
283 let vc = _mm512_loadu_ps(c.as_ptr().add(i));
284 _mm512_storeu_ps(result.as_mut_ptr().add(i), _mm512_fmadd_ps(va, vb, vc));
285 i += 16;
286 }
287 for j in i..len {
288 result[j] = a[j] * b[j] + c[j];
289 }
290 }
291 }
292
293 #[inline]
294 #[target_feature(enable = "avx512f")]
295 unsafe fn relu(a: &[f32], result: &mut [f32]) {
297 unsafe {
298 let len = a.len();
299 let ap = a.as_ptr();
300 let rp = result.as_mut_ptr();
301 let mut i = 0;
302 let zero = _mm512_setzero_ps();
303
304 if len >= 8192 {
305 while i + 64 <= len {
307 _mm_prefetch(ap.add(i + 128).cast::<i8>(), _MM_HINT_T0);
308
309 _mm512_stream_ps(rp.add(i), _mm512_max_ps(_mm512_loadu_ps(ap.add(i)), zero));
310 _mm512_stream_ps(
311 rp.add(i + 16),
312 _mm512_max_ps(_mm512_loadu_ps(ap.add(i + 16)), zero),
313 );
314 _mm512_stream_ps(
315 rp.add(i + 32),
316 _mm512_max_ps(_mm512_loadu_ps(ap.add(i + 32)), zero),
317 );
318 _mm512_stream_ps(
319 rp.add(i + 48),
320 _mm512_max_ps(_mm512_loadu_ps(ap.add(i + 48)), zero),
321 );
322
323 i += 64;
324 }
325 while i + 16 <= len {
326 _mm512_stream_ps(rp.add(i), _mm512_max_ps(_mm512_loadu_ps(ap.add(i)), zero));
327 i += 16;
328 }
329 _mm_sfence();
330 } else {
331 while i + 16 <= len {
332 _mm512_storeu_ps(rp.add(i), _mm512_max_ps(_mm512_loadu_ps(ap.add(i)), zero));
333 i += 16;
334 }
335 }
336 for j in i..len {
337 result[j] = a[j].max(0.0);
338 }
339 }
340 }
341
342 #[inline]
343 #[target_feature(enable = "avx512f")]
344 unsafe fn exp(a: &[f32], result: &mut [f32]) {
346 unsafe {
347 let len = a.len();
348 let mut i = 0;
349 let ln2 = _mm512_set1_ps(std::f32::consts::LN_2);
350 let inv_ln2 = _mm512_set1_ps(1.0 / std::f32::consts::LN_2);
351 let one = _mm512_set1_ps(1.0);
352 let c2 = _mm512_set1_ps(0.5);
353 let c3 = _mm512_set1_ps(0.166_666_67);
354 let c4 = _mm512_set1_ps(0.041_666_668);
355 let c5 = _mm512_set1_ps(0.008_333_334);
356 while i + 16 <= len {
357 let x = _mm512_loadu_ps(a.as_ptr().add(i));
358 let k = _mm512_cvtps_epi32(_mm512_mul_ps(x, inv_ln2));
359 let kf = _mm512_cvtepi32_ps(k);
360 let r = _mm512_sub_ps(x, _mm512_mul_ps(kf, ln2));
361 let mut poly = _mm512_fmadd_ps(r, c5, one);
362 poly = _mm512_fmadd_ps(r, _mm512_fmadd_ps(r, poly, c4), one);
363 poly = _mm512_fmadd_ps(r, _mm512_fmadd_ps(r, poly, c3), one);
364 poly = _mm512_fmadd_ps(r, _mm512_fmadd_ps(r, poly, c2), one);
365 poly = _mm512_fmadd_ps(r, poly, one);
366 let exp_k = _mm512_castsi512_ps(_mm512_slli_epi32(
367 _mm512_add_epi32(k, _mm512_set1_epi32(127)),
368 23,
369 ));
370 _mm512_storeu_ps(result.as_mut_ptr().add(i), _mm512_mul_ps(poly, exp_k));
371 i += 16;
372 }
373 for j in i..len {
374 result[j] = a[j].exp();
375 }
376 }
377 }
378
379 #[inline]
380 #[target_feature(enable = "avx512f")]
381 unsafe fn sigmoid(a: &[f32], result: &mut [f32]) {
383 let len = a.len();
385 for j in 0..len {
386 result[j] = 1.0 / (1.0 + (-a[j]).exp());
387 }
388 }
389
390 #[inline]
391 #[target_feature(enable = "avx512f")]
392 unsafe fn gelu(a: &[f32], result: &mut [f32]) {
394 for j in 0..a.len() {
395 let x = a[j];
396 let inner = 0.797_884_56 * (x + 0.044_715 * x * x * x);
397 result[j] = 0.5 * x * (1.0 + inner.tanh());
398 }
399 }
400
401 #[inline]
402 #[target_feature(enable = "avx512f")]
403 unsafe fn swish(a: &[f32], result: &mut [f32]) {
405 for j in 0..a.len() {
406 result[j] = a[j] / (1.0 + (-a[j]).exp());
407 }
408 }
409
410 #[inline]
411 #[target_feature(enable = "avx512f")]
412 unsafe fn tanh(a: &[f32], result: &mut [f32]) {
414 for j in 0..a.len() {
415 result[j] = a[j].tanh();
416 }
417 }
418
419 #[inline]
420 #[target_feature(enable = "avx512f")]
421 unsafe fn sqrt(a: &[f32], result: &mut [f32]) {
423 unsafe {
424 let len = a.len();
425 let mut i = 0;
426 while i + 16 <= len {
427 _mm512_storeu_ps(
428 result.as_mut_ptr().add(i),
429 _mm512_sqrt_ps(_mm512_loadu_ps(a.as_ptr().add(i))),
430 );
431 i += 16;
432 }
433 for j in i..len {
434 result[j] = a[j].sqrt();
435 }
436 }
437 }
438
439 #[inline]
440 #[target_feature(enable = "avx512f")]
441 unsafe fn recip(a: &[f32], result: &mut [f32]) {
443 unsafe {
444 let len = a.len();
445 let mut i = 0;
446 let one = _mm512_set1_ps(1.0);
447 while i + 16 <= len {
448 _mm512_storeu_ps(
449 result.as_mut_ptr().add(i),
450 _mm512_div_ps(one, _mm512_loadu_ps(a.as_ptr().add(i))),
451 );
452 i += 16;
453 }
454 for j in i..len {
455 result[j] = a[j].recip();
456 }
457 }
458 }
459
460 unsafe fn ln(a: &[f32], result: &mut [f32]) {
462 unsafe {
463 super::scalar::ScalarBackend::ln(a, result);
464 }
465 }
466 unsafe fn log2(a: &[f32], result: &mut [f32]) {
468 unsafe {
469 super::scalar::ScalarBackend::log2(a, result);
470 }
471 }
472 unsafe fn log10(a: &[f32], result: &mut [f32]) {
474 unsafe {
475 super::scalar::ScalarBackend::log10(a, result);
476 }
477 }
478 unsafe fn sin(a: &[f32], result: &mut [f32]) {
480 unsafe {
481 super::scalar::ScalarBackend::sin(a, result);
482 }
483 }
484 unsafe fn cos(a: &[f32], result: &mut [f32]) {
486 unsafe {
487 super::scalar::ScalarBackend::cos(a, result);
488 }
489 }
490 unsafe fn tan(a: &[f32], result: &mut [f32]) {
492 unsafe {
493 super::scalar::ScalarBackend::tan(a, result);
494 }
495 }
496
497 unsafe fn floor(a: &[f32], result: &mut [f32]) {
499 unsafe {
500 super::scalar::ScalarBackend::floor(a, result);
501 }
502 }
503 unsafe fn ceil(a: &[f32], result: &mut [f32]) {
505 unsafe {
506 super::scalar::ScalarBackend::ceil(a, result);
507 }
508 }
509 unsafe fn round(a: &[f32], result: &mut [f32]) {
511 unsafe {
512 super::scalar::ScalarBackend::round(a, result);
513 }
514 }
515}
516
517#[cfg(all(test, target_arch = "x86_64"))]
518mod tests;