use super::{Simd, UnaryFn1, UnaryFn2};
use core::arch::asm;
use core::arch::x86_64::*;

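/// SIMD backend for CPUs with AVX2 and FMA: 256-bit vectors (eight f32 or
/// four f64 lanes) with native 256-bit integer ops and fused multiply-add.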
pub(crate) struct Avx2Fma {}

impl Simd for Avx2Fma {
    const F32_WIDTH: usize = 8;
    const F64_WIDTH: usize = 4;

    type Vf32 = __m256;
    type Vf64 = __m256d;
    type Vi32 = __m256i;
    #[inline(always)]
    unsafe fn set1_f32(x: f32) -> __m256 {
        _mm256_set1_ps(x)
    }
    #[inline(always)]
    unsafe fn set1_i32(x: i32) -> __m256i {
        _mm256_set1_epi32(x)
    }

    #[inline(always)]
    unsafe fn loadu_f32(ptr: *const f32) -> __m256 {
        _mm256_loadu_ps(ptr)
    }

    #[inline(always)]
    unsafe fn loadu_f64(ptr: *const f64) -> __m256d {
        _mm256_loadu_pd(ptr)
    }

    #[inline(always)]
    unsafe fn sqrt_f32(a: __m256) -> __m256 {
        _mm256_sqrt_ps(a)
    }

    #[inline(always)]
    unsafe fn sqrt_f64(a: __m256d) -> __m256d {
        _mm256_sqrt_pd(a)
    }

    #[inline(always)]
    unsafe fn and_f32(a: __m256, b: __m256) -> __m256 {
        _mm256_and_ps(a, b)
    }

    #[inline(always)]
    unsafe fn mul_f32(a: __m256, b: __m256) -> __m256 {
        _mm256_mul_ps(a, b)
    }

    #[inline(always)]
    unsafe fn add_i32(a: __m256i, b: __m256i) -> __m256i {
        _mm256_add_epi32(a, b)
    }

    #[inline(always)]
    unsafe fn and_i32(a: __m256i, b: __m256i) -> __m256i {
        _mm256_and_si256(a, b)
    }
    #[inline(always)]
    unsafe fn cvt_i32_f32(a: __m256i) -> __m256 {
        _mm256_cvtepi32_ps(a)
    }
    #[inline(always)]
    unsafe fn cvt_f32_i32(a: __m256) -> __m256i {
        _mm256_cvtps_epi32(a)
    }

    #[inline(always)]
    unsafe fn cvtt_f32_i32(a: __m256) -> __m256i {
        _mm256_cvttps_epi32(a)
    }

    #[inline(always)]
    unsafe fn sub_i32(a: __m256i, b: __m256i) -> __m256i {
        _mm256_sub_epi32(a, b)
    }

    #[inline(always)]
    unsafe fn andnot_i32(a: __m256i, b: __m256i) -> __m256i {
        _mm256_andnot_si256(a, b)
    }

    #[inline(always)]
    unsafe fn slli_i32<const IMM8: i32>(a: __m256i) -> __m256i {
        _mm256_slli_epi32(a, IMM8)
    }

    #[inline(always)]
    unsafe fn cmpeq_i32(a: __m256i, b: __m256i) -> __m256i {
        _mm256_cmpeq_epi32(a, b)
    }

    #[inline(always)]
    unsafe fn cast_i32_f32(a: __m256i) -> __m256 {
        _mm256_castsi256_ps(a)
    }

    #[inline(always)]
    unsafe fn fmadd_f32(a: __m256, b: __m256, c: __m256) -> __m256 {
        _mm256_fmadd_ps(a, b, c)
    }

    #[inline(always)]
    unsafe fn andnot_f32(a: __m256, b: __m256) -> __m256 {
        _mm256_andnot_ps(a, b)
    }

    #[inline(always)]
    unsafe fn add_f32(a: __m256, b: __m256) -> __m256 {
        _mm256_add_ps(a, b)
    }

    #[inline(always)]
    unsafe fn xor_f32(a: __m256, b: __m256) -> __m256 {
        _mm256_xor_ps(a, b)
    }

    #[inline(always)]
    unsafe fn storeu_f32(ptr: *mut f32, a: __m256) {
        _mm256_storeu_ps(ptr, a)
    }

    #[inline(always)]
    unsafe fn storeu_f64(ptr: *mut f64, a: __m256d) {
        _mm256_storeu_pd(ptr, a)
    }

    #[inline(always)]
    unsafe fn sub_f32(a: __m256, b: __m256) -> __m256 {
        _mm256_sub_ps(a, b)
    }

    #[inline(always)]
    unsafe fn cmp_eq_f32(a: __m256, b: __m256) -> __m256 {
        _mm256_cmp_ps(a, b, _CMP_EQ_OS)
    }

    #[inline(always)]
    unsafe fn cmp_lt_f32(a: __m256, b: __m256) -> __m256 {
        _mm256_cmp_ps(a, b, _CMP_LT_OS)
    }

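    // The mask_* helpers below apply an op only in the lanes selected by an
    // all-ones comparison mask; unmasked lanes pass `a` through unchanged
    // (for mask_mul_f32 this means multiplying by 1.0).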
    #[inline(always)]
    unsafe fn mask_mul_f32(mask: __m256, a: __m256, b: __m256) -> __m256 {
        // Blend b (masked lanes) with 1.0 (unmasked lanes), then multiply.
        let one = _mm256_andnot_ps(mask, _mm256_set1_ps(1.0));
        let masked_b = _mm256_and_ps(b, mask);
        let blended = _mm256_or_ps(masked_b, one);
        _mm256_mul_ps(a, blended)
    }

    #[inline(always)]
    unsafe fn mask_sub_f32(mask: __m256, a: __m256, b: __m256) -> __m256 {
        // Subtract b only in the masked lanes.
        let masked_b = _mm256_and_ps(b, mask);
        _mm256_sub_ps(a, masked_b)
    }

    #[inline(always)]
    unsafe fn or_f32(a: __m256, b: __m256) -> __m256 {
        _mm256_or_ps(a, b)
    }

    #[inline(always)]
    unsafe fn mask_add_f32(mask: __m256, a: __m256, b: __m256) -> __m256 {
        // Add b only in the masked lanes.
        let masked_b = _mm256_and_ps(b, mask);
        _mm256_add_ps(a, masked_b)
    }

    #[inline(always)]
    unsafe fn cast_f32_i32(a: __m256) -> __m256i {
        _mm256_castps_si256(a)
    }

    #[inline(always)]
    unsafe fn srli_i32<const IMM8: i32>(a: __m256i) -> __m256i {
        _mm256_srli_epi32(a, IMM8)
    }

    #[inline(always)]
    unsafe fn min_f32(a: __m256, b: __m256) -> __m256 {
        _mm256_min_ps(a, b)
    }

    #[inline(always)]
    unsafe fn floor_f32(a: __m256) -> __m256 {
        _mm256_round_ps(a, _MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC)
    }

    #[inline(always)]
    unsafe fn max_f32(a: __m256, b: __m256) -> __m256 {
        _mm256_max_ps(a, b)
    }

    #[inline(always)]
    unsafe fn div_f32(a: __m256, b: __m256) -> __m256 {
        _mm256_div_ps(a, b)
    }

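    // Splits `a` into a floating-point exponent `e` and a mantissa scaled
    // into [0.5, 1.0), as consumed by the log kernels. Denormals are
    // pre-scaled by ~2^27 (with `e` corrected by 27), zero maps to -inf,
    // +inf stays +inf, and negative inputs yield NaN.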
    #[inline(always)]
    unsafe fn get_exp_mant_f32(a: __m256) -> (__m256, __m256) {
        let a_0 = a;
        let zero_mask = Self::cmp_eq_f32(a, Self::set1_f32(0.0));
        let nan_mask = Self::cmp_lt_f32(a, Self::set1_f32(0.0));
        let inv_mant_mask = Self::cast_i32_f32(Self::set1_i32(!0x7f800000));
        let inf_mask = Self::cmp_eq_f32(a, Self::set1_f32(f32::INFINITY));
        // Pre-scale denormals by ~2^27 so their exponent bits become usable.
        let denorm_mul = Self::set1_f32(134217730.);
        let denorm_th = Self::set1_f32(1.1754945e-38);
        let denorm_mask = Self::cmp_lt_f32(a, denorm_th);
        let mut a = Self::mask_mul_f32(denorm_mask, a, denorm_mul);

        // Biased exponent from bits 23..31.
        let mut imm0 = Self::srli_i32::<23>(Self::cast_f32_i32(a));

        // Clear the exponent field and force the mantissa into [0.5, 1.0).
        a = Self::and_f32(a, inv_mant_mask);
        a = Self::or_f32(a, Self::set1_f32(0.5));

        imm0 = Self::sub_i32(imm0, Self::set1_i32(0x7f));

        let e = Self::cvt_i32_f32(imm0);

        // Undo the denormal pre-scaling and patch the special cases.
        let e = Self::mask_sub_f32(denorm_mask, e, Self::set1_f32(27.0));
        let e = Self::mask_sub_f32(zero_mask, e, Self::set1_f32(f32::INFINITY));
        let e = Self::mask_add_f32(inf_mask, e, Self::set1_f32(f32::INFINITY));
        let e = Self::min_f32(e, a_0);
        let e = Self::mask_add_f32(nan_mask, e, Self::set1_f32(f32::NAN));

        (e, a)
    }
}

impl UnaryFn1 for Avx2Fma {}
impl UnaryFn2 for Avx2Fma {}

impl Avx2Fma {
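    // Safety: callers must verify at runtime (e.g. with
    // `is_x86_feature_detected!`) that the CPU supports AVX2, AVX and FMA
    // before calling any of these wrappers.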
    #[target_feature(enable = "avx2,avx,fma")]
    pub(crate) unsafe fn vs_exp(n: usize, a: *const f32, b: *mut f32) {
        Self::vs_exp_0(n, a, b)
    }

    #[target_feature(enable = "avx2,avx,fma")]
    pub(crate) unsafe fn vs_ln(n: usize, a: *const f32, b: *mut f32) {
        Self::vs_ln_0(n, a, b)
    }

    #[target_feature(enable = "avx2,avx,fma")]
    pub(crate) unsafe fn vs_tanh(n: usize, a: *const f32, b: *mut f32) {
        Self::vs_tanh_0(n, a, b)
    }

    #[target_feature(enable = "avx2,avx,fma")]
    pub(crate) unsafe fn vs_sin(n: usize, a: *const f32, b: *mut f32) {
        Self::vs_sin_0(n, a, b)
    }

    #[target_feature(enable = "avx2,avx,fma")]
    pub(crate) unsafe fn vs_cos(n: usize, a: *const f32, b: *mut f32) {
        Self::vs_cos_0(n, a, b)
    }
}

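/// SIMD backend for CPUs with AVX but without AVX2/FMA: 256-bit float ops
/// are native, 256-bit integer ops are emulated on two 128-bit SSE2 halves,
/// and `fmadd_f32` falls back to a separate multiply and add.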
pub(crate) struct AvxSse2 {}

impl Simd for AvxSse2 {
    const F32_WIDTH: usize = 8;
    const F64_WIDTH: usize = 4;

    type Vf32 = __m256;
    type Vf64 = __m256d;
    type Vi32 = __m256i;
    #[inline(always)]
    unsafe fn set1_f32(x: f32) -> __m256 {
        _mm256_set1_ps(x)
    }
    #[inline(always)]
    unsafe fn set1_i32(x: i32) -> __m256i {
        _mm256_set1_epi32(x)
    }
    #[inline(always)]
    unsafe fn loadu_f64(ptr: *const f64) -> __m256d {
        _mm256_loadu_pd(ptr)
    }

    #[inline(always)]
    unsafe fn storeu_f64(ptr: *mut f64, a: __m256d) {
        _mm256_storeu_pd(ptr, a)
    }

    #[inline(always)]
    unsafe fn sqrt_f32(a: __m256) -> __m256 {
        _mm256_sqrt_ps(a)
    }

    #[inline(always)]
    unsafe fn sqrt_f64(a: __m256d) -> __m256d {
        _mm256_sqrt_pd(a)
    }

    #[inline(always)]
    unsafe fn loadu_f32(ptr: *const f32) -> __m256 {
        _mm256_loadu_ps(ptr)
    }

    #[inline(always)]
    unsafe fn and_f32(a: __m256, b: __m256) -> __m256 {
        _mm256_and_ps(a, b)
    }

    #[inline(always)]
    unsafe fn mul_f32(a: __m256, b: __m256) -> __m256 {
        _mm256_mul_ps(a, b)
    }

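    // AVX1 has no 256-bit integer arithmetic, so the integer ops below split
    // each __m256i into its two 128-bit halves, apply the SSE2 instruction
    // per half, and reassemble the 256-bit result.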
    #[inline(always)]
    unsafe fn add_i32(a: __m256i, b: __m256i) -> __m256i {
        let a1 = _mm256_extractf128_si256(a, 1);
        let b1 = _mm256_extractf128_si256(b, 1);
        let a0 = _mm256_castsi256_si128(a);
        let b0 = _mm256_castsi256_si128(b);
        let c0 = _mm_add_epi32(a0, b0);
        let c1 = _mm_add_epi32(a1, b1);
        _mm256_insertf128_si256(_mm256_castsi128_si256(c0), c1, 1)
    }

    #[inline(always)]
    unsafe fn and_i32(a: __m256i, b: __m256i) -> __m256i {
        let a1 = _mm256_extractf128_si256(a, 1);
        let b1 = _mm256_extractf128_si256(b, 1);
        let a0 = _mm256_castsi256_si128(a);
        let b0 = _mm256_castsi256_si128(b);
        let c0 = _mm_and_si128(a0, b0);
        let c1 = _mm_and_si128(a1, b1);
        _mm256_insertf128_si256(_mm256_castsi128_si256(c0), c1, 1)
    }
    #[inline(always)]
    unsafe fn cvt_i32_f32(a: __m256i) -> __m256 {
        _mm256_cvtepi32_ps(a)
    }
    #[inline(always)]
    unsafe fn cvt_f32_i32(a: __m256) -> __m256i {
        _mm256_cvtps_epi32(a)
    }

    #[inline(always)]
    unsafe fn cvtt_f32_i32(a: __m256) -> __m256i {
        _mm256_cvttps_epi32(a)
    }

    #[inline(always)]
    unsafe fn sub_i32(a: __m256i, b: __m256i) -> __m256i {
        let a1 = _mm256_extractf128_si256(a, 1);
        let b1 = _mm256_extractf128_si256(b, 1);
        let a0 = _mm256_castsi256_si128(a);
        let b0 = _mm256_castsi256_si128(b);
        let c0 = _mm_sub_epi32(a0, b0);
        let c1 = _mm_sub_epi32(a1, b1);
        _mm256_insertf128_si256(_mm256_castsi128_si256(c0), c1, 1)
    }

    #[inline(always)]
    unsafe fn andnot_i32(a: __m256i, b: __m256i) -> __m256i {
        let a1 = _mm256_extractf128_si256(a, 1);
        let b1 = _mm256_extractf128_si256(b, 1);
        let a0 = _mm256_castsi256_si128(a);
        let b0 = _mm256_castsi256_si128(b);
        let c0 = _mm_andnot_si128(a0, b0);
        let c1 = _mm_andnot_si128(a1, b1);
        _mm256_insertf128_si256(_mm256_castsi128_si256(c0), c1, 1)
    }

    #[inline(always)]
    unsafe fn slli_i32<const IMM8: i32>(a: __m256i) -> __m256i {
        let a1 = _mm256_extractf128_si256(a, 1);
        let a0 = _mm256_castsi256_si128(a);
        let c0 = _mm_slli_epi32(a0, IMM8);
        let c1 = _mm_slli_epi32(a1, IMM8);
        _mm256_insertf128_si256(_mm256_castsi128_si256(c0), c1, 1)
    }

    #[inline(always)]
    unsafe fn cmpeq_i32(a: __m256i, b: __m256i) -> __m256i {
        let a1 = _mm256_extractf128_si256(a, 1);
        let b1 = _mm256_extractf128_si256(b, 1);
        let a0 = _mm256_castsi256_si128(a);
        let b0 = _mm256_castsi256_si128(b);
        let c0 = _mm_cmpeq_epi32(a0, b0);
        let c1 = _mm_cmpeq_epi32(a1, b1);
        _mm256_insertf128_si256(_mm256_castsi128_si256(c0), c1, 1)
    }

    #[inline(always)]
    unsafe fn cast_i32_f32(a: __m256i) -> __m256 {
        _mm256_castsi256_ps(a)
    }

    #[inline(always)]
    unsafe fn fmadd_f32(a: __m256, b: __m256, c: __m256) -> __m256 {
        // No FMA available: emulate with separate multiply and add
        // (one extra rounding step).
        let mul = _mm256_mul_ps(a, b);
        _mm256_add_ps(mul, c)
    }

    #[inline(always)]
    unsafe fn andnot_f32(a: __m256, b: __m256) -> __m256 {
        _mm256_andnot_ps(a, b)
    }

    #[inline(always)]
    unsafe fn add_f32(a: __m256, b: __m256) -> __m256 {
        _mm256_add_ps(a, b)
    }

    #[inline(always)]
    unsafe fn xor_f32(a: __m256, b: __m256) -> __m256 {
        _mm256_xor_ps(a, b)
    }

    #[inline(always)]
    unsafe fn storeu_f32(ptr: *mut f32, a: __m256) {
        _mm256_storeu_ps(ptr, a)
    }

    #[inline(always)]
    unsafe fn sub_f32(a: __m256, b: __m256) -> __m256 {
        _mm256_sub_ps(a, b)
    }

    #[inline(always)]
    unsafe fn cmp_eq_f32(a: __m256, b: __m256) -> __m256 {
        _mm256_cmp_ps(a, b, _CMP_EQ_OS)
    }

    #[inline(always)]
    unsafe fn cmp_lt_f32(a: __m256, b: __m256) -> __m256 {
        _mm256_cmp_ps(a, b, _CMP_LT_OS)
    }

    #[inline(always)]
    unsafe fn mask_mul_f32(mask: __m256, a: __m256, b: __m256) -> __m256 {
        // Blend b (masked lanes) with 1.0 (unmasked lanes), then multiply.
        let one = _mm256_andnot_ps(mask, _mm256_set1_ps(1.0));
        let masked_b = _mm256_and_ps(b, mask);
        let blended = _mm256_or_ps(masked_b, one);
        _mm256_mul_ps(a, blended)
    }

    #[inline(always)]
    unsafe fn mask_sub_f32(mask: __m256, a: __m256, b: __m256) -> __m256 {
        // Subtract b only in the masked lanes.
        let masked_b = _mm256_and_ps(b, mask);
        _mm256_sub_ps(a, masked_b)
    }

    #[inline(always)]
    unsafe fn or_f32(a: __m256, b: __m256) -> __m256 {
        _mm256_or_ps(a, b)
    }

    #[inline(always)]
    unsafe fn mask_add_f32(mask: __m256, a: __m256, b: __m256) -> __m256 {
        // Add b only in the masked lanes.
        let masked_b = _mm256_and_ps(b, mask);
        _mm256_add_ps(a, masked_b)
    }

    #[inline(always)]
    unsafe fn cast_f32_i32(a: __m256) -> __m256i {
        _mm256_castps_si256(a)
    }

    #[inline(always)]
    unsafe fn srli_i32<const IMM8: i32>(a: __m256i) -> __m256i {
        let a1 = _mm256_extractf128_si256(a, 1);
        let a0 = _mm256_castsi256_si128(a);
        let c0 = _mm_srli_epi32(a0, IMM8);
        let c1 = _mm_srli_epi32(a1, IMM8);
        _mm256_insertf128_si256(_mm256_castsi128_si256(c0), c1, 1)
    }

    #[inline(always)]
    unsafe fn min_f32(a: __m256, b: __m256) -> __m256 {
        _mm256_min_ps(a, b)
    }

    #[inline(always)]
    unsafe fn floor_f32(a: __m256) -> __m256 {
        _mm256_round_ps(a, _MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC)
    }

    #[inline(always)]
    unsafe fn max_f32(a: __m256, b: __m256) -> __m256 {
        _mm256_max_ps(a, b)
    }

    #[inline(always)]
    unsafe fn div_f32(a: __m256, b: __m256) -> __m256 {
        _mm256_div_ps(a, b)
    }

    // Identical to Avx2Fma::get_exp_mant_f32; see the comments there.
    #[inline(always)]
    unsafe fn get_exp_mant_f32(a: __m256) -> (__m256, __m256) {
        let a_0 = a;
        let zero_mask = Self::cmp_eq_f32(a, Self::set1_f32(0.0));
        let nan_mask = Self::cmp_lt_f32(a, Self::set1_f32(0.0));
        let inv_mant_mask = Self::cast_i32_f32(Self::set1_i32(!0x7f800000));
        let inf_mask = Self::cmp_eq_f32(a, Self::set1_f32(f32::INFINITY));
        let denorm_mul = Self::set1_f32(134217730.);
        let denorm_th = Self::set1_f32(1.1754945e-38);
        let denorm_mask = Self::cmp_lt_f32(a, denorm_th);
        let mut a = Self::mask_mul_f32(denorm_mask, a, denorm_mul);

        let mut imm0 = Self::srli_i32::<23>(Self::cast_f32_i32(a));

        a = Self::and_f32(a, inv_mant_mask);
        a = Self::or_f32(a, Self::set1_f32(0.5));

        imm0 = Self::sub_i32(imm0, Self::set1_i32(0x7f));

        let e = Self::cvt_i32_f32(imm0);

        let e = Self::mask_sub_f32(denorm_mask, e, Self::set1_f32(27.0));
        let e = Self::mask_sub_f32(zero_mask, e, Self::set1_f32(f32::INFINITY));
        let e = Self::mask_add_f32(inf_mask, e, Self::set1_f32(f32::INFINITY));
        let e = Self::min_f32(e, a_0);
        let e = Self::mask_add_f32(nan_mask, e, Self::set1_f32(f32::NAN));

        (e, a)
    }
}

impl UnaryFn1 for AvxSse2 {
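    // Vectorized expf: clamp x to [EXP_LO, EXP_HI], take k = floor(x * log2(e)),
    // reduce r = x - k*ln(2) in two pieces (L2_U + L2_L) for extra precision,
    // evaluate a polynomial in r - 0.5 anchored at sqrt(e), and scale by 2^k
    // assembled directly in the exponent bits. A scalar loop handles the tail.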
    unsafe fn vs_exp_0(n: usize, a: *const f32, b: *mut f32) {
        const EXP_HI: f32 = 88.7228391117;
        const EXP_LO: f32 = -88.7228391117;
        const LOG2EF: f32 = 1.44269504088896341;
        const EXP_P0: f32 = 0.00032712723;
        const EXP_P1: f32 = 0.00228989065;
        const EXP_P2: f32 = 0.01373934392;
        const EXP_P3: f32 = 0.06869671961;
        const EXP_P4: f32 = 0.27478687845;
        const EXP_P5: f32 = 0.82436063535;
        const L2_U: f32 = -0.693_145_751_953_125;
        const L2_L: f32 = -1.428_606_765_330_187_045_e-6;

        let exp_hi = Self::set1_f32(EXP_HI);
        let exp_lo = Self::set1_f32(EXP_LO);
        let log2ef = Self::set1_f32(LOG2EF);

        let half = Self::set1_f32(0.5);

        let exp_p0 = Self::set1_f32(EXP_P0);
        let exp_p1 = Self::set1_f32(EXP_P1);
        let exp_p2 = Self::set1_f32(EXP_P2);
        let exp_p3 = Self::set1_f32(EXP_P3);
        let exp_p4 = Self::set1_f32(EXP_P4);
        let exp_p5 = Self::set1_f32(EXP_P5);
        let min_exponent = Self::set1_f32(-127.0);
        let e_sqrt = Self::set1_f32(1.6487212707);

        let l2_u = Self::set1_f32(L2_U);
        let l2_l = Self::set1_f32(L2_L);

        let mut i = 0;
        while (i + Self::F32_WIDTH) <= n {
            let x = Self::loadu_f32(a.offset(i as isize));

            let x = Self::min_f32(exp_hi, x);
            let x = Self::max_f32(exp_lo, x);

            let mut fx = Self::mul_f32(x, log2ef);
            fx = Self::floor_f32(fx);
            fx = Self::max_f32(fx, min_exponent);

            let x = Self::fmadd_f32(fx, l2_u, x);
            let x = Self::fmadd_f32(fx, l2_l, x);

            let x = Self::sub_f32(x, half);

            let mut y = exp_p0;
            y = Self::fmadd_f32(y, x, exp_p1);
            y = Self::fmadd_f32(y, x, exp_p2);
            y = Self::fmadd_f32(y, x, exp_p3);
            y = Self::fmadd_f32(y, x, exp_p4);
            y = Self::fmadd_f32(y, x, exp_p5);
            y = Self::fmadd_f32(y, x, e_sqrt);
            y = Self::fmadd_f32(y, x, e_sqrt);

            let mut imm0 = Self::cvt_f32_i32(fx);
            imm0 = Self::add_i32(imm0, Self::set1_i32(0x7f));
            imm0 = Self::slli_i32::<23>(imm0);
            let pow2n = Self::cast_i32_f32(imm0);

            let r = Self::mul_f32(y, pow2n);
            Self::storeu_f32(b.offset(i as isize), r);
            i += Self::F32_WIDTH;
        }
        while i < n {
            *b.offset(i as isize) = (*a.offset(i as isize)).exp();
            i += 1;
        }
    }
}

impl UnaryFn2 for AvxSse2 {}

impl AvxSse2 {
    #[target_feature(enable = "avx,sse2,sse")]
    pub(crate) unsafe fn vs_exp(n: usize, a: *const f32, b: *mut f32) {
        Self::vs_exp_0(n, a, b)
    }

    #[target_feature(enable = "avx,sse2,sse")]
    pub(crate) unsafe fn vs_ln(n: usize, a: *const f32, b: *mut f32) {
        Self::vs_ln_0(n, a, b)
    }

    #[target_feature(enable = "avx,sse2,sse")]
    pub(crate) unsafe fn vs_tanh(n: usize, a: *const f32, b: *mut f32) {
        Self::vs_tanh_0(n, a, b)
    }

    #[target_feature(enable = "avx,sse2,sse")]
    pub(crate) unsafe fn vs_sin(n: usize, a: *const f32, b: *mut f32) {
        Self::vs_sin_0(n, a, b)
    }

    #[target_feature(enable = "avx,sse2,sse")]
    pub(crate) unsafe fn vs_cos(n: usize, a: *const f32, b: *mut f32) {
        Self::vs_cos_0(n, a, b)
    }

    #[target_feature(enable = "avx")]
    pub(crate) unsafe fn vs_sqrt(n: usize, a: *const f32, b: *mut f32) {
        Self::vs_sqrt_0(n, a, b)
    }

    #[target_feature(enable = "avx")]
    pub(crate) unsafe fn vd_sqrt(n: usize, a: *const f64, b: *mut f64) {
        Self::vd_sqrt_0(n, a, b)
    }
}

pub(crate) struct Avx512f {}
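/// Vectorized lnf in AVX-512F inline asm (AT&T syntax): vgetexpps/vgetmantps
/// split each lane into exponent and mantissa, a polynomial approximates the
/// log of the mantissa, the exponent is folded back in via ln(2), and
/// negative inputs are patched to NaN through mask register k1. A scalar
/// loop handles the last n % 16 elements.
///
/// Safety: `a` and `b` must be valid for `n` reads/writes and the CPU must
/// support AVX-512F.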
pub(crate) unsafe fn vs_ln_avx512f_asm(n: usize, a: *const f32, b: *mut f32) {
    const NR: usize = 16;
    const LN2F_HI: f32 = 0.693359375;
    const LN2F_LO: f32 = -2.12194440E-4;
    const P0LOGF: f32 = -0.5;
    const P1LOGF: f32 = 3.3333331174E-1;
    const P2LOGF: f32 = -2.4999993993E-1;
    const P3LOGF: f32 = 2.0000714765E-1;
    const P4LOGF: f32 = -1.6666657665E-1;
    const P5LOGF: f32 = 1.4249322787E-1;
    const P6LOGF: f32 = -1.250000140846E-1;
    const P7LOGF: f32 = 1.1676998740E-1;
    let constant_arr = [
        0.73337,
        1.5,
        1.0,
        -0.58496250072,
        -1.0,
        P7LOGF,
        P6LOGF,
        P5LOGF,
        P4LOGF,
        P3LOGF,
        P2LOGF,
        P1LOGF,
        P0LOGF,
        LN2F_HI + LN2F_LO,
        f32::NAN,
    ];

    let mut i = 0;
    asm!(
        "vxorps %zmm0, %zmm0, %zmm0",
        "vbroadcastss 0({constant_arrx}), %zmm1",
        "vbroadcastss 4({constant_arrx}), %zmm2",
        "vbroadcastss 8({constant_arrx}), %zmm3",
        "vbroadcastss 12({constant_arrx}), %zmm4",
        "vbroadcastss 16({constant_arrx}), %zmm5",

        "vbroadcastss 20({constant_arrx}), %zmm6",
        "vbroadcastss 24({constant_arrx}), %zmm7",
        "vbroadcastss 28({constant_arrx}), %zmm8",
        "vbroadcastss 32({constant_arrx}), %zmm9",
        "vbroadcastss 36({constant_arrx}), %zmm10",
        "vbroadcastss 40({constant_arrx}), %zmm11",
        "vbroadcastss 44({constant_arrx}), %zmm12",
        "vbroadcastss 48({constant_arrx}), %zmm13",

        "vbroadcastss 52({constant_arrx}), %zmm14",

        "test {nx:e}, {nx:e}",
        "je 3f",

        "2:",
        "vmovups ({ax}), %zmm15",
        "vcmpltps %zmm0, %zmm15, %k1",
        "vgetexpps %zmm15, %zmm16",
        "vgetmantps $2, %zmm15, %zmm15",
        "vcmplt_oqps %zmm1, %zmm15, %k2",
        "vmulps %zmm2, %zmm15, %zmm15 {{%k2}}",
        "vaddps %zmm3, %zmm16, %zmm16",
        "vaddps %zmm4, %zmm16, %zmm16 {{%k2}}",
        "vaddps %zmm5, %zmm15, %zmm15",
        "vmovaps %zmm6, %zmm17",
        "vfmadd213ps %zmm7, %zmm15, %zmm17",
        "vfmadd213ps %zmm8, %zmm15, %zmm17",
        "vfmadd213ps %zmm9, %zmm15, %zmm17",
        "vfmadd213ps %zmm10, %zmm15, %zmm17",
        "vfmadd213ps %zmm11, %zmm15, %zmm17",
        "vfmadd213ps %zmm12, %zmm15, %zmm17",
        "vfmadd213ps %zmm13, %zmm15, %zmm17",
        "vfmadd213ps %zmm3, %zmm15, %zmm17",
        "vmulps %zmm17, %zmm15, %zmm15",
        "vfmadd231ps %zmm14, %zmm16, %zmm15",
        "vbroadcastss 56({constant_arrx}), %zmm15 {{%k1}}",
        "vmovups %zmm15, ({bx})",
        "add $64, {ax}",
        "add $64, {bx}",
        "add $16, {ix:e}",
        "cmp {nx:e}, {ix:e}",
        "jl 2b",

        "3:",
        "vzeroupper",

        constant_arrx = in(reg) &constant_arr,
        ax = inout(reg) a => _,
        bx = inout(reg) b => _,
        ix = inout(reg) i => i,
        nx = inout(reg) n / NR * NR => _,
        out("zmm0") _, out("zmm1") _, out("zmm2") _, out("zmm3") _, out("zmm4") _, out("zmm5") _, out("zmm6") _, out("zmm7") _, out("zmm8") _, out("zmm9") _,
        out("zmm10") _, out("zmm11") _, out("zmm12") _, out("zmm13") _, out("zmm14") _, out("zmm15") _, out("zmm16") _, out("zmm17") _, out("zmm18") _, out("zmm19") _,
        out("zmm20") _, out("zmm21") _, out("zmm22") _, out("zmm23") _, out("zmm24") _, out("zmm25") _, out("zmm26") _, out("zmm27") _, out("zmm28") _, out("zmm29") _,
        out("zmm30") _, out("zmm31") _, out("k1") _, out("k2") _,
        options(att_syntax)
    );
    while i < n {
        *b.offset(i as isize) = (*a.offset(i as isize)).ln();
        i += 1;
    }
}

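/// Vectorized expf in AVX-512F inline asm: k = round(x * log2(e)) via
/// vrndscaleps, r = x - k*ln(2) via a fused multiply-add, a Cephes-style
/// polynomial approximates e^r, and vscalefps applies the final 2^k (which
/// also drives clamped out-of-range lanes to +inf or 0). A scalar loop
/// handles the last n % 16 elements.
///
/// Safety: `a` and `b` must be valid for `n` reads/writes and the CPU must
/// support AVX-512F.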
pub(crate) unsafe fn vs_exp_avx512f_asm(n: usize, a: *const f32, b: *mut f32) {
    const NR: usize = 16;
    const EXP_HI: f32 = 88.3762626647949 * 2.0;
    const EXP_LO: f32 = -88.3762626647949 * 2.0;
    const LOG2EF: f32 = 1.44269504088896341;
    const INV_LOG2EF: f32 = 0.693359375;
    const CEPHES_EXP_C2: f32 = -2.12194440e-4;
    const EXP_P0: f32 = 1.9875691500E-4;
    const EXP_P1: f32 = 1.3981999507E-3;
    const EXP_P2: f32 = 8.3334519073E-3;
    const EXP_P3: f32 = 4.1665795894E-2;
    const EXP_P4: f32 = 1.6666665459E-1;
    const EXP_P5: f32 = 5.0000001201E-1;

    let constant_arr =
        [LOG2EF, -CEPHES_EXP_C2 - INV_LOG2EF, EXP_P0, EXP_P1, EXP_P2, EXP_P3, EXP_P4, EXP_P5, 1.0, EXP_HI, EXP_LO];
    let mut i = 0;
    asm!(
        "vbroadcastss ({constant_arrx}), %zmm0",
        "vbroadcastss 4({constant_arrx}), %zmm1",
        "vbroadcastss 8({constant_arrx}), %zmm2",
        "vbroadcastss 12({constant_arrx}), %zmm3",
        "vbroadcastss 16({constant_arrx}), %zmm4",
        "vbroadcastss 20({constant_arrx}), %zmm5",
        "vbroadcastss 24({constant_arrx}), %zmm6",
        "vbroadcastss 28({constant_arrx}), %zmm7",
        "vbroadcastss 32({constant_arrx}), %zmm8",

        "vbroadcastss 36({constant_arrx}), %zmm13",
        "vbroadcastss 40({constant_arrx}), %zmm14",

        "test {nx:e}, {nx:e}",
        "je 3f",

        "2:",
        "vmovups ({ax}), %zmm9",
        "vminps %zmm9, %zmm13, %zmm9",
        "vmaxps %zmm9, %zmm14, %zmm9",
        "vmulps %zmm0, %zmm9, %zmm10",
        "vrndscaleps $8, %zmm10, %zmm10",
        "vfmadd231ps %zmm1, %zmm10, %zmm9",
        "vmovaps %zmm2, %zmm11",
        "vfmadd213ps %zmm3, %zmm9, %zmm11",
        "vfmadd213ps %zmm4, %zmm9, %zmm11",
        "vfmadd213ps %zmm5, %zmm9, %zmm11",
        "vfmadd213ps %zmm6, %zmm9, %zmm11",
        "vmulps %zmm9, %zmm9, %zmm12",
        "vfmadd213ps %zmm7, %zmm9, %zmm11",
        "vfmadd213ps %zmm9, %zmm12, %zmm11",
        "vaddps %zmm8, %zmm11, %zmm9",
        "vscalefps %zmm10, %zmm9, %zmm9",
        "vmovups %zmm9, ({bx})",
        "add $64, {ax}",
        "add $64, {bx}",
        "add $16, {ix:e}",
        "cmp {nx:e}, {ix:e}",
        "jl 2b",

        "3:",
        "vzeroupper",

        constant_arrx = in(reg) &constant_arr,
        ax = inout(reg) a => _,
        bx = inout(reg) b => _,
        ix = inout(reg) i => i,
        nx = inout(reg) n / NR * NR => _,
        out("zmm0") _, out("zmm1") _, out("zmm2") _, out("zmm3") _, out("zmm4") _, out("zmm5") _, out("zmm6") _, out("zmm7") _, out("zmm8") _, out("zmm9") _,
        out("zmm10") _, out("zmm11") _, out("zmm12") _, out("zmm13") _, out("zmm14") _, out("zmm15") _, out("zmm16") _, out("zmm17") _, out("zmm18") _, out("zmm19") _,
        out("zmm20") _, out("zmm21") _, out("zmm22") _, out("zmm23") _, out("zmm24") _, out("zmm25") _, out("zmm26") _, out("zmm27") _, out("zmm28") _, out("zmm29") _,
        out("zmm30") _, out("zmm31") _,
        options(att_syntax)
    );
    while i < n {
        *b.offset(i as isize) = (*a.offset(i as isize)).exp();
        i += 1;
    }
}

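/// Vectorized tanhf in AVX-512F inline asm: y = e^(2x) is computed by running
/// the exp kernel on the (half-range) clamped input and squaring the result,
/// then tanh(x) = (y - 1) / (y + 1), clamped to [-1, 1]. A scalar loop
/// handles the last n % 16 elements.
///
/// Safety: `a` and `b` must be valid for `n` reads/writes and the CPU must
/// support AVX-512F.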
pub(crate) unsafe fn vs_tanh_avx512f_asm(n: usize, a: *const f32, b: *mut f32) {
    const NR: usize = 16;
    // exp is evaluated at 2*x, so clamp the input to half the usual range.
    const EXP_HI: f32 = 88.3762626647949 * 0.5;
    const EXP_LO: f32 = -88.3762626647949 * 0.5;
    const LOG2EF: f32 = 1.44269504088896341;
    const INV_LOG2EF: f32 = 0.693359375;
    const CEPHES_EXP_C2: f32 = -2.12194440e-4;
    const EXP_P0: f32 = 1.9875691500E-4;
    const EXP_P1: f32 = 1.3981999507E-3;
    const EXP_P2: f32 = 8.3334519073E-3;
    const EXP_P3: f32 = 4.1665795894E-2;
    const EXP_P4: f32 = 1.6666665459E-1;
    const EXP_P5: f32 = 5.0000001201E-1;

    let constant_arr = [
        LOG2EF,
        -CEPHES_EXP_C2 - INV_LOG2EF,
        EXP_P0,
        EXP_P1,
        EXP_P2,
        EXP_P3,
        EXP_P4,
        EXP_P5,
        1.0,
        EXP_HI,
        EXP_LO,
        -1.0,
    ];
    let mut i = 0;
    asm!(
        "vbroadcastss ({constant_arrx}), %zmm0",
        "vbroadcastss 4({constant_arrx}), %zmm1",
        "vbroadcastss 8({constant_arrx}), %zmm2",
        "vbroadcastss 12({constant_arrx}), %zmm3",
        "vbroadcastss 16({constant_arrx}), %zmm4",
        "vbroadcastss 20({constant_arrx}), %zmm5",
        "vbroadcastss 24({constant_arrx}), %zmm6",
        "vbroadcastss 28({constant_arrx}), %zmm7",
        "vbroadcastss 32({constant_arrx}), %zmm8",

        "vbroadcastss 36({constant_arrx}), %zmm13",
        "vbroadcastss 40({constant_arrx}), %zmm14",
        "vbroadcastss 44({constant_arrx}), %zmm15",

        "test {nx:e}, {nx:e}",
        "je 3f",

        "2:",
        "vmovups ({ax}), %zmm31",
        "vminps %zmm31, %zmm13, %zmm9",
        "vmaxps %zmm9, %zmm14, %zmm9",
        "vmulps %zmm0, %zmm9, %zmm10",
        "vrndscaleps $8, %zmm10, %zmm10",
        "vfmadd231ps %zmm1, %zmm10, %zmm9",
        "vmovaps %zmm2, %zmm11",
        "vfmadd213ps %zmm3, %zmm9, %zmm11",
        "vfmadd213ps %zmm4, %zmm9, %zmm11",
        "vfmadd213ps %zmm5, %zmm9, %zmm11",
        "vfmadd213ps %zmm6, %zmm9, %zmm11",
        "vmulps %zmm9, %zmm9, %zmm12",
        "vfmadd213ps %zmm7, %zmm9, %zmm11",
        "vfmadd213ps %zmm9, %zmm12, %zmm11",
        "vaddps %zmm8, %zmm11, %zmm9",
        "vscalefps %zmm10, %zmm9, %zmm9",
        "vmulps %zmm9, %zmm9, %zmm9",

        "vaddps %zmm9, %zmm15, %zmm16",
        "vaddps %zmm9, %zmm8, %zmm9",
        "vdivps %zmm9, %zmm16, %zmm9",
        "vminps %zmm9, %zmm8, %zmm9",
        "vmaxps %zmm9, %zmm15, %zmm9",

        "vmovups %zmm9, ({bx})",
        "add $64, {ax}",
        "add $64, {bx}",
        "add $16, {ix:e}",
        "cmp {nx:e}, {ix:e}",
        "jl 2b",

        "3:",
        "vzeroupper",

        constant_arrx = in(reg) &constant_arr,
        ax = inout(reg) a => _,
        bx = inout(reg) b => _,
        ix = inout(reg) i => i,
        nx = inout(reg) n / NR * NR => _,
        out("zmm0") _, out("zmm1") _, out("zmm2") _, out("zmm3") _, out("zmm4") _, out("zmm5") _, out("zmm6") _, out("zmm7") _, out("zmm8") _, out("zmm9") _,
        out("zmm10") _, out("zmm11") _, out("zmm12") _, out("zmm13") _, out("zmm14") _, out("zmm15") _, out("zmm16") _, out("zmm17") _, out("zmm18") _, out("zmm19") _,
        out("zmm20") _, out("zmm21") _, out("zmm22") _, out("zmm23") _, out("zmm24") _, out("zmm25") _, out("zmm26") _, out("zmm27") _, out("zmm28") _, out("zmm29") _,
        out("zmm30") _, out("zmm31") _,
        options(att_syntax)
    );
    while i < n {
        *b.offset(i as isize) = (*a.offset(i as isize)).tanh();
        i += 1;
    }
}

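/// Vectorized sqrtf in AVX-512F inline asm: a plain vsqrtps loop, 16 f32
/// lanes per iteration, with a scalar tail for the remainder.
///
/// Safety: `a` and `b` must be valid for `n` reads/writes and the CPU must
/// support AVX-512F.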
pub(crate) unsafe fn vs_sqrt_avx512f_asm(n: usize, a: *const f32, b: *mut f32) {
    const NR: usize = 16;
    let mut i = 0;
    asm!(
        "test {nx:e}, {nx:e}",
        "je 3f",

        "2:",
        "vmovups ({ax}), %zmm9",
        "vsqrtps %zmm9, %zmm9",
        "vmovups %zmm9, ({bx})",
        "add $64, {ax}",
        "add $64, {bx}",
        "add $16, {ix:e}",
        "cmp {nx:e}, {ix:e}",
        "jl 2b",

        "3:",
        "vzeroupper",

        ax = inout(reg) a => _,
        bx = inout(reg) b => _,
        ix = inout(reg) i => i,
        nx = inout(reg) n / NR * NR => _,
        out("zmm9") _,
        options(att_syntax)
    );
    while i < n {
        *b.offset(i as isize) = (*a.offset(i as isize)).sqrt();
        i += 1;
    }
}

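/// Vectorized sqrt for f64 in AVX-512F inline asm: vsqrtpd over 8 lanes
/// (one 64-byte zmm register) per iteration, with a scalar tail for the
/// remainder.
///
/// Safety: `a` and `b` must be valid for `n` reads/writes and the CPU must
/// support AVX-512F.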
pub(crate) unsafe fn vd_sqrt_avx512f_asm(n: usize, a: *const f64, b: *mut f64) {
    const NR: usize = 8;
    let mut i = 0;
    asm!(
        "test {nx:e}, {nx:e}",
        "je 3f",

        "2:",
        "vmovupd ({ax}), %zmm9",
        "vsqrtpd %zmm9, %zmm9",
        "vmovupd %zmm9, ({bx})",
        "add $64, {ax}",
        "add $64, {bx}",
        // 8 f64 lanes per 64-byte step, so advance the counter by 8, not 16.
        "add $8, {ix:e}",
        "cmp {nx:e}, {ix:e}",
        "jl 2b",

        "3:",
        "vzeroupper",

        ax = inout(reg) a => _,
        bx = inout(reg) b => _,
        ix = inout(reg) i => i,
        nx = inout(reg) n / NR * NR => _,
        out("zmm9") _,
        options(att_syntax)
    );
    while i < n {
        *b.offset(i as isize) = (*a.offset(i as isize)).sqrt();
        i += 1;
    }
}

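// Thin dispatch wrappers over the AVX-512F asm kernels above. Safety: callers
// must verify that the CPU supports AVX-512F before calling any of these.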
impl Avx512f {
    pub(crate) unsafe fn vs_exp(n: usize, a: *const f32, b: *mut f32) {
        vs_exp_avx512f_asm(n, a, b);
    }

    pub(crate) unsafe fn vs_ln(n: usize, a: *const f32, b: *mut f32) {
        vs_ln_avx512f_asm(n, a, b);
    }

    pub(crate) unsafe fn vs_tanh(n: usize, a: *const f32, b: *mut f32) {
        vs_tanh_avx512f_asm(n, a, b);
    }

    pub(crate) unsafe fn vs_sqrt(n: usize, a: *const f32, b: *mut f32) {
        vs_sqrt_avx512f_asm(n, a, b);
    }

    pub(crate) unsafe fn vd_sqrt(n: usize, a: *const f64, b: *mut f64) {
        vd_sqrt_avx512f_asm(n, a, b);
    }
}