mathfun/simd/x86_64.rs

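// SIMD implementations of elementary vector functions for x86_64.
//
// Three backends live in this file:
// - `Avx2Fma`: AVX2 integer ops plus hardware fused multiply-add.
// - `AvxSse2`: AVX for f32/f64 math, with 256-bit integer ops emulated on two
//   128-bit SSE2 halves and fmadd emulated as multiply + add.
// - `Avx512f`: hand-written AVX-512F inline-assembly kernels.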
use super::{Simd, UnaryFn1, UnaryFn2};
use core::arch::asm;
use core::arch::x86_64::*;

pub(crate) struct Avx2Fma {}

impl Simd for Avx2Fma {
    const F32_WIDTH: usize = 8;
    const F64_WIDTH: usize = 4;

    type Vf32 = __m256;
    type Vf64 = __m256d;
    type Vi32 = __m256i;
    #[inline(always)]
    unsafe fn set1_f32(x: f32) -> __m256 {
        _mm256_set1_ps(x)
    }
    #[inline(always)]
    unsafe fn set1_i32(x: i32) -> __m256i {
        _mm256_set1_epi32(x)
    }

    #[inline(always)]
    unsafe fn loadu_f32(ptr: *const f32) -> __m256 {
        _mm256_loadu_ps(ptr)
    }

    #[inline(always)]
    unsafe fn loadu_f64(ptr: *const f64) -> __m256d {
        _mm256_loadu_pd(ptr)
    }

    #[inline(always)]
    unsafe fn sqrt_f32(a: __m256) -> __m256 {
        _mm256_sqrt_ps(a)
    }

    #[inline(always)]
    unsafe fn sqrt_f64(a: __m256d) -> __m256d {
        _mm256_sqrt_pd(a)
    }

    #[inline(always)]
    unsafe fn and_f32(a: __m256, b: __m256) -> __m256 {
        _mm256_and_ps(a, b)
    }

    #[inline(always)]
    unsafe fn mul_f32(a: __m256, b: __m256) -> __m256 {
        _mm256_mul_ps(a, b)
    }

    #[inline(always)]
    unsafe fn add_i32(a: __m256i, b: __m256i) -> __m256i {
        _mm256_add_epi32(a, b)
    }

    #[inline(always)]
    unsafe fn and_i32(a: __m256i, b: __m256i) -> __m256i {
        _mm256_and_si256(a, b)
    }
    #[inline(always)]
    unsafe fn cvt_i32_f32(a: __m256i) -> __m256 {
        _mm256_cvtepi32_ps(a)
    }
    #[inline(always)]
    unsafe fn cvt_f32_i32(a: __m256) -> __m256i {
        _mm256_cvtps_epi32(a)
    }

    #[inline(always)]
    unsafe fn cvtt_f32_i32(a: __m256) -> __m256i {
        _mm256_cvttps_epi32(a)
    }

    #[inline(always)]
    unsafe fn sub_i32(a: __m256i, b: __m256i) -> __m256i {
        _mm256_sub_epi32(a, b)
    }

    #[inline(always)]
    unsafe fn andnot_i32(a: __m256i, b: __m256i) -> __m256i {
        _mm256_andnot_si256(a, b)
    }

    #[inline(always)]
    unsafe fn slli_i32<const IMM8: i32>(a: __m256i) -> __m256i {
        _mm256_slli_epi32(a, IMM8)
    }

    #[inline(always)]
    unsafe fn cmpeq_i32(a: __m256i, b: __m256i) -> __m256i {
        _mm256_cmpeq_epi32(a, b)
    }

    #[inline(always)]
    unsafe fn cast_i32_f32(a: __m256i) -> __m256 {
        _mm256_castsi256_ps(a)
    }

    #[inline(always)]
    unsafe fn fmadd_f32(a: __m256, b: __m256, c: __m256) -> __m256 {
        _mm256_fmadd_ps(a, b, c)
    }

    #[inline(always)]
    unsafe fn andnot_f32(a: __m256, b: __m256) -> __m256 {
        _mm256_andnot_ps(a, b)
    }

    #[inline(always)]
    unsafe fn add_f32(a: __m256, b: __m256) -> __m256 {
        _mm256_add_ps(a, b)
    }

    #[inline(always)]
    unsafe fn xor_f32(a: __m256, b: __m256) -> __m256 {
        _mm256_xor_ps(a, b)
    }

    #[inline(always)]
    unsafe fn storeu_f32(ptr: *mut f32, a: __m256) {
        _mm256_storeu_ps(ptr, a)
    }

    #[inline(always)]
    unsafe fn storeu_f64(ptr: *mut f64, a: __m256d) {
        _mm256_storeu_pd(ptr, a)
    }

    #[inline(always)]
    unsafe fn sub_f32(a: __m256, b: __m256) -> __m256 {
        _mm256_sub_ps(a, b)
    }

    #[inline(always)]
    unsafe fn cmp_eq_f32(a: __m256, b: __m256) -> __m256 {
        _mm256_cmp_ps(a, b, _CMP_EQ_OS)
    }

    #[inline(always)]
    unsafe fn cmp_lt_f32(a: __m256, b: __m256) -> __m256 {
        _mm256_cmp_ps(a, b, _CMP_LT_OS)
    }

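    // Masked multiply: lanes selected by `mask` (all-ones, as produced by the
    // cmp_* helpers) are multiplied by `b`; in the unselected lanes `b` is
    // replaced with 1.0 so those lanes of `a` pass through unchanged.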
    #[inline(always)]
    unsafe fn mask_mul_f32(mask: __m256, a: __m256, b: __m256) -> __m256 {
        let one = _mm256_set1_ps(1.0);
        let one = _mm256_andnot_ps(mask, one);
        let masked_one = _mm256_and_ps(b, mask);
        let masked_b = _mm256_or_ps(masked_one, one);
        let c = _mm256_mul_ps(a, masked_b);
        c
    }

    #[inline(always)]
    unsafe fn mask_sub_f32(mask: __m256, a: __m256, b: __m256) -> __m256 {
        let masked_b = _mm256_and_ps(b, mask);
        let c = _mm256_sub_ps(a, masked_b);
        c
    }

    #[inline(always)]
    unsafe fn or_f32(a: __m256, b: __m256) -> __m256 {
        _mm256_or_ps(a, b)
    }

    #[inline(always)]
    unsafe fn mask_add_f32(mask: __m256, a: __m256, b: __m256) -> __m256 {
        let masked_b = _mm256_and_ps(b, mask);
        let c = _mm256_add_ps(a, masked_b);
        c
    }

    #[inline(always)]
    unsafe fn cast_f32_i32(a: __m256) -> __m256i {
        _mm256_castps_si256(a)
    }

    #[inline(always)]
    unsafe fn srli_i32<const IMM8: i32>(a: __m256i) -> __m256i {
        _mm256_srli_epi32(a, IMM8)
    }

    #[inline(always)]
    unsafe fn min_f32(a: __m256, b: __m256) -> __m256 {
        _mm256_min_ps(a, b)
    }

    #[inline(always)]
    unsafe fn floor_f32(a: __m256) -> __m256 {
        _mm256_round_ps(a, _MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC)
    }

    #[inline(always)]
    unsafe fn max_f32(a: __m256, b: __m256) -> __m256 {
        _mm256_max_ps(a, b)
    }

    #[inline(always)]
    unsafe fn div_f32(a: __m256, b: __m256) -> __m256 {
        _mm256_div_ps(a, b)
    }

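    // Splits `a` into (exponent, mantissa), frexp-style: the mantissa keeps
    // the exponent bits of 0.5 so it lands in [0.5, 1.0), and the unbiased
    // exponent is returned as an f32. Denormals are pre-scaled (and the
    // exponent corrected by 27), while zero, negative and infinite inputs are
    // patched through masks so the log kernel built on top of this yields
    // -inf, NaN and +inf respectively.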
    #[inline(always)]
    unsafe fn get_exp_mant_f32(a: __m256) -> (__m256, __m256) {
        let a_0 = a;
        let zero_mask = Self::cmp_eq_f32(a, Self::set1_f32(0.0));
        let nan_mask = Self::cmp_lt_f32(a, Self::set1_f32(0.0));
        let inv_mant_mask = Self::cast_i32_f32(Self::set1_i32(!0x7f800000));
        let inf_mask = Self::cmp_eq_f32(a, Self::set1_f32(f32::INFINITY));
        let denorm_mul = Self::set1_f32(134217730.);
        let denorm_th = Self::set1_f32(1.1754945e-38);
        let denorm_mask = Self::cmp_lt_f32(a, denorm_th);
        let mut a = Self::mask_mul_f32(denorm_mask, a, denorm_mul);

        let mut imm0 = Self::srli_i32::<23>(Self::cast_f32_i32(a));

        /* keep only the fractional part */
        a = Self::and_f32(a, inv_mant_mask);
        a = Self::or_f32(a, Self::set1_f32(0.5));

        // subtract the exponent bias (0x7f)
        imm0 = Self::sub_i32(imm0, Self::set1_i32(0x7f));

        let e = Self::cvt_i32_f32(imm0);

        let e = Self::mask_sub_f32(denorm_mask, e, Self::set1_f32(27.0));
        let e = Self::mask_sub_f32(zero_mask, e, Self::set1_f32(f32::INFINITY));
        let e = Self::mask_add_f32(inf_mask, e, Self::set1_f32(f32::INFINITY));
        let e = Self::min_f32(e, a_0);
        let e = Self::mask_add_f32(nan_mask, e, Self::set1_f32(f32::NAN));

        (e, a)
    }
}

impl UnaryFn1 for Avx2Fma {}
impl UnaryFn2 for Avx2Fma {}

impl Avx2Fma {
    #[target_feature(enable = "avx2,avx,fma")]
    pub(crate) unsafe fn vs_exp(n: usize, a: *const f32, b: *mut f32) {
        Self::vs_exp_0(n, a, b)
    }

    #[target_feature(enable = "avx2,avx,fma")]
    pub(crate) unsafe fn vs_ln(n: usize, a: *const f32, b: *mut f32) {
        Self::vs_ln_0(n, a, b)
    }

    #[target_feature(enable = "avx2,avx,fma")]
    pub(crate) unsafe fn vs_tanh(n: usize, a: *const f32, b: *mut f32) {
        Self::vs_tanh_0(n, a, b)
    }

    #[target_feature(enable = "avx2,avx,fma")]
    pub(crate) unsafe fn vs_sin(n: usize, a: *const f32, b: *mut f32) {
        Self::vs_sin_0(n, a, b)
    }

    #[target_feature(enable = "avx2,avx,fma")]
    pub(crate) unsafe fn vs_cos(n: usize, a: *const f32, b: *mut f32) {
        Self::vs_cos_0(n, a, b)
    }
}

pub(crate) struct AvxSse2 {}

impl Simd for AvxSse2 {
    const F32_WIDTH: usize = 8;
    const F64_WIDTH: usize = 4;

    type Vf32 = __m256;
    type Vf64 = __m256d;
    type Vi32 = __m256i;
    #[inline(always)]
    unsafe fn set1_f32(x: f32) -> __m256 {
        _mm256_set1_ps(x)
    }
    #[inline(always)]
    unsafe fn set1_i32(x: i32) -> __m256i {
        _mm256_set1_epi32(x)
    }
    #[inline(always)]
    unsafe fn loadu_f64(ptr: *const f64) -> __m256d {
        _mm256_loadu_pd(ptr)
    }

    #[inline(always)]
    unsafe fn storeu_f64(ptr: *mut f64, a: __m256d) {
        _mm256_storeu_pd(ptr, a)
    }

    #[inline(always)]
    unsafe fn sqrt_f32(a: __m256) -> __m256 {
        _mm256_sqrt_ps(a)
    }

    #[inline(always)]
    unsafe fn sqrt_f64(a: __m256d) -> __m256d {
        _mm256_sqrt_pd(a)
    }

    #[inline(always)]
    unsafe fn loadu_f32(ptr: *const f32) -> __m256 {
        _mm256_loadu_ps(ptr)
    }

    #[inline(always)]
    unsafe fn and_f32(a: __m256, b: __m256) -> __m256 {
        _mm256_and_ps(a, b)
    }

    #[inline(always)]
    unsafe fn mul_f32(a: __m256, b: __m256) -> __m256 {
        _mm256_mul_ps(a, b)
    }

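    // AVX without AVX2 has no 256-bit integer arithmetic, so the integer
    // helpers below split each vector into two 128-bit halves, operate on
    // them with SSE2, and reassemble the 256-bit result.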
    #[inline(always)]
    unsafe fn add_i32(a: __m256i, b: __m256i) -> __m256i {
        // extract second half of a and b
        let a1 = _mm256_extractf128_si256(a, 1);
        let b1 = _mm256_extractf128_si256(b, 1);
        let a0 = _mm256_castsi256_si128(a);
        let b0 = _mm256_castsi256_si128(b);
        let c0 = _mm_add_epi32(a0, b0);
        let c1 = _mm_add_epi32(a1, b1);
        _mm256_insertf128_si256(_mm256_castsi128_si256(c0), c1, 1)
    }

    #[inline(always)]
    unsafe fn and_i32(a: __m256i, b: __m256i) -> __m256i {
        // extract second half of a and b
        let a1 = _mm256_extractf128_si256(a, 1);
        let b1 = _mm256_extractf128_si256(b, 1);
        let a0 = _mm256_castsi256_si128(a);
        let b0 = _mm256_castsi256_si128(b);
        let c0 = _mm_and_si128(a0, b0);
        let c1 = _mm_and_si128(a1, b1);
        _mm256_insertf128_si256(_mm256_castsi128_si256(c0), c1, 1)
    }
    #[inline(always)]
    unsafe fn cvt_i32_f32(a: __m256i) -> __m256 {
        _mm256_cvtepi32_ps(a)
    }
    #[inline(always)]
    unsafe fn cvt_f32_i32(a: __m256) -> __m256i {
        _mm256_cvtps_epi32(a)
    }

    #[inline(always)]
    unsafe fn cvtt_f32_i32(a: __m256) -> __m256i {
        _mm256_cvttps_epi32(a)
    }

    #[inline(always)]
    unsafe fn sub_i32(a: __m256i, b: __m256i) -> __m256i {
        // extract second half of a and b
        let a1 = _mm256_extractf128_si256(a, 1);
        let b1 = _mm256_extractf128_si256(b, 1);
        let a0 = _mm256_castsi256_si128(a);
        let b0 = _mm256_castsi256_si128(b);
        let c0 = _mm_sub_epi32(a0, b0);
        let c1 = _mm_sub_epi32(a1, b1);
        _mm256_insertf128_si256(_mm256_castsi128_si256(c0), c1, 1)
    }

    #[inline(always)]
    unsafe fn andnot_i32(a: __m256i, b: __m256i) -> __m256i {
        // extract second half of a and b
        let a1 = _mm256_extractf128_si256(a, 1);
        let b1 = _mm256_extractf128_si256(b, 1);
        let a0 = _mm256_castsi256_si128(a);
        let b0 = _mm256_castsi256_si128(b);
        let c0 = _mm_andnot_si128(a0, b0);
        let c1 = _mm_andnot_si128(a1, b1);
        _mm256_insertf128_si256(_mm256_castsi128_si256(c0), c1, 1)
    }

    #[inline(always)]
    unsafe fn slli_i32<const IMM8: i32>(a: __m256i) -> __m256i {
        // extract second half of a
        let a1 = _mm256_extractf128_si256(a, 1);
        let a0 = _mm256_castsi256_si128(a);
        let c0 = _mm_slli_epi32(a0, IMM8);
        let c1 = _mm_slli_epi32(a1, IMM8);
        _mm256_insertf128_si256(_mm256_castsi128_si256(c0), c1, 1)
    }

    #[inline(always)]
    unsafe fn cmpeq_i32(a: __m256i, b: __m256i) -> __m256i {
        // extract second half of a and b
        let a1 = _mm256_extractf128_si256(a, 1);
        let b1 = _mm256_extractf128_si256(b, 1);
        let a0 = _mm256_castsi256_si128(a);
        let b0 = _mm256_castsi256_si128(b);
        let c0 = _mm_cmpeq_epi32(a0, b0);
        let c1 = _mm_cmpeq_epi32(a1, b1);
        _mm256_insertf128_si256(_mm256_castsi128_si256(c0), c1, 1)
    }

    #[inline(always)]
    unsafe fn cast_i32_f32(a: __m256i) -> __m256 {
        _mm256_castsi256_ps(a)
    }

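    // No FMA on this path: fmadd is emulated as a separate multiply and add,
    // which rounds twice and may differ from a fused result in the last bit.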
    #[inline(always)]
    unsafe fn fmadd_f32(a: __m256, b: __m256, c: __m256) -> __m256 {
        let mul = _mm256_mul_ps(a, b);
        _mm256_add_ps(mul, c)
    }

    #[inline(always)]
    unsafe fn andnot_f32(a: __m256, b: __m256) -> __m256 {
        _mm256_andnot_ps(a, b)
    }

    #[inline(always)]
    unsafe fn add_f32(a: __m256, b: __m256) -> __m256 {
        _mm256_add_ps(a, b)
    }

    #[inline(always)]
    unsafe fn xor_f32(a: __m256, b: __m256) -> __m256 {
        _mm256_xor_ps(a, b)
    }

    #[inline(always)]
    unsafe fn storeu_f32(ptr: *mut f32, a: __m256) {
        _mm256_storeu_ps(ptr, a)
    }

    #[inline(always)]
    unsafe fn sub_f32(a: __m256, b: __m256) -> __m256 {
        _mm256_sub_ps(a, b)
    }

    #[inline(always)]
    unsafe fn cmp_eq_f32(a: __m256, b: __m256) -> __m256 {
        _mm256_cmp_ps(a, b, _CMP_EQ_OS)
    }

    #[inline(always)]
    unsafe fn cmp_lt_f32(a: __m256, b: __m256) -> __m256 {
        _mm256_cmp_ps(a, b, _CMP_LT_OS)
    }

    #[inline(always)]
    unsafe fn mask_mul_f32(mask: __m256, a: __m256, b: __m256) -> __m256 {
        let one = _mm256_set1_ps(1.0);
        let one = _mm256_andnot_ps(mask, one);
        let masked_one = _mm256_and_ps(b, mask);
        let masked_b = _mm256_or_ps(masked_one, one);
        let c = _mm256_mul_ps(a, masked_b);
        c
    }

    #[inline(always)]
    unsafe fn mask_sub_f32(mask: __m256, a: __m256, b: __m256) -> __m256 {
        let masked_b = _mm256_and_ps(b, mask);
        let c = _mm256_sub_ps(a, masked_b);
        c
    }

    #[inline(always)]
    unsafe fn or_f32(a: __m256, b: __m256) -> __m256 {
        _mm256_or_ps(a, b)
    }

    #[inline(always)]
    unsafe fn mask_add_f32(mask: __m256, a: __m256, b: __m256) -> __m256 {
        let masked_b = _mm256_and_ps(b, mask);
        let c = _mm256_add_ps(a, masked_b);
        c
    }

    #[inline(always)]
    unsafe fn cast_f32_i32(a: __m256) -> __m256i {
        _mm256_castps_si256(a)
    }

    #[inline(always)]
    unsafe fn srli_i32<const IMM8: i32>(a: __m256i) -> __m256i {
        // extract second half of a
        let a1 = _mm256_extractf128_si256(a, 1);
        let a0 = _mm256_castsi256_si128(a);
        let c0 = _mm_srli_epi32(a0, IMM8);
        let c1 = _mm_srli_epi32(a1, IMM8);
        _mm256_insertf128_si256(_mm256_castsi128_si256(c0), c1, 1)
    }

    #[inline(always)]
    unsafe fn min_f32(a: __m256, b: __m256) -> __m256 {
        _mm256_min_ps(a, b)
    }

    #[inline(always)]
    unsafe fn floor_f32(a: __m256) -> __m256 {
        _mm256_round_ps(a, _MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC)
    }

    #[inline(always)]
    unsafe fn max_f32(a: __m256, b: __m256) -> __m256 {
        _mm256_max_ps(a, b)
    }

    #[inline(always)]
    unsafe fn div_f32(a: __m256, b: __m256) -> __m256 {
        _mm256_div_ps(a, b)
    }

    #[inline(always)]
    unsafe fn get_exp_mant_f32(a: __m256) -> (__m256, __m256) {
        let a_0 = a;
        let zero_mask = Self::cmp_eq_f32(a, Self::set1_f32(0.0));
        let nan_mask = Self::cmp_lt_f32(a, Self::set1_f32(0.0));
        let inv_mant_mask = Self::cast_i32_f32(Self::set1_i32(!0x7f800000));
        let inf_mask = Self::cmp_eq_f32(a, Self::set1_f32(f32::INFINITY));
        let denorm_mul = Self::set1_f32(134217730.);
        let denorm_th = Self::set1_f32(1.1754945e-38);
        let denorm_mask = Self::cmp_lt_f32(a, denorm_th);
        let mut a = Self::mask_mul_f32(denorm_mask, a, denorm_mul);

        let mut imm0 = Self::srli_i32::<23>(Self::cast_f32_i32(a));

        /* keep only the fractional part */
        a = Self::and_f32(a, inv_mant_mask);
        a = Self::or_f32(a, Self::set1_f32(0.5));

        // subtract the exponent bias (0x7f)
        imm0 = Self::sub_i32(imm0, Self::set1_i32(0x7f));

        let e = Self::cvt_i32_f32(imm0);

        let e = Self::mask_sub_f32(denorm_mask, e, Self::set1_f32(27.0));
        let e = Self::mask_sub_f32(zero_mask, e, Self::set1_f32(f32::INFINITY));
        let e = Self::mask_add_f32(inf_mask, e, Self::set1_f32(f32::INFINITY));
        let e = Self::min_f32(e, a_0);
        let e = Self::mask_add_f32(nan_mask, e, Self::set1_f32(f32::NAN));

        (e, a)
    }
}

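// Vectorized expf for the non-FMA path: clamp x, take k = floor(x * log2(e)),
// subtract k * ln(2) in two pieces (L2_U, L2_L) for accuracy, approximate the
// exponential of the reduced argument with a short polynomial, and scale by
// 2^k assembled directly from the exponent bits. A scalar loop handles the
// tail elements.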
impl UnaryFn1 for AvxSse2 {
    unsafe fn vs_exp_0(n: usize, a: *const f32, b: *mut f32) {
        // define constants
        const EXP_HI: f32 = 88.7228391117 * 1.0;
        const EXP_LO: f32 = -88.7228391117 * 1.0;
        const LOG2EF: f32 = 1.44269504088896341;
        const EXP_P0: f32 = 0.00032712723;
        const EXP_P1: f32 = 0.00228989065 + 1e-6;
        // const EXP_P1: f32 = 0.00138888888;
        const EXP_P2: f32 = 0.01373934392;
        const EXP_P3: f32 = 0.06869671961;
        const EXP_P4: f32 = 0.27478687845;
        const EXP_P5: f32 = 0.82436063535;
        const L2_U: f32 = -0.693_145_751_953_125;
        const L2_L: f32 = -1.428_606_765_330_187_045_e-6;

        let exp_hi = Self::set1_f32(EXP_HI);
        let exp_lo = Self::set1_f32(EXP_LO);
        let log2ef = Self::set1_f32(LOG2EF);

        let half = Self::set1_f32(0.5);

        let exp_p0 = Self::set1_f32(EXP_P0);
        let exp_p1 = Self::set1_f32(EXP_P1);
        let exp_p2 = Self::set1_f32(EXP_P2);
        let exp_p3 = Self::set1_f32(EXP_P3);
        let exp_p4 = Self::set1_f32(EXP_P4);
        let exp_p5 = Self::set1_f32(EXP_P5);
        let min_exponent = Self::set1_f32(-127.0);
        let e_sqrt = Self::set1_f32(1.6487212707);

        let l2_u = Self::set1_f32(L2_U);
        let l2_l = Self::set1_f32(L2_L);

        let mut i = 0;
        while (i + Self::F32_WIDTH) <= n {
            let x = Self::loadu_f32(a.offset(i as isize));

            // Clamp x
            let x = Self::min_f32(exp_hi, x);
            let x = Self::max_f32(exp_lo, x);

            // Compute fx = floor(x * log2(e))
            let mut fx = Self::mul_f32(x, log2ef);
            // use floor (round toward -inf); round-to-nearest is problematic for values near overflow and underflow
            fx = Self::floor_f32(fx);
            // to prevent denormalized values
            fx = Self::max_f32(fx, min_exponent);

            // Approximation for exp(x)
            let x = Self::fmadd_f32(fx, l2_u, x);
            let x = Self::fmadd_f32(fx, l2_l, x);

            let x = Self::sub_f32(x, half);

            let mut y = exp_p0;
            y = Self::fmadd_f32(y, x, exp_p1);
            y = Self::fmadd_f32(y, x, exp_p2);
            y = Self::fmadd_f32(y, x, exp_p3);
            y = Self::fmadd_f32(y, x, exp_p4);
            y = Self::fmadd_f32(y, x, exp_p5);
            y = Self::fmadd_f32(y, x, e_sqrt);
            y = Self::fmadd_f32(y, x, e_sqrt);

            // Compute 2^fx
            let mut imm0 = Self::cvt_f32_i32(fx);
            imm0 = Self::add_i32(imm0, Self::set1_i32(0x7f));
            imm0 = Self::slli_i32::<23>(imm0);
            let pow2n = Self::cast_i32_f32(imm0);

            // Final result
            let r = Self::mul_f32(y, pow2n);
            Self::storeu_f32(b.offset(i as isize), r);
            i += Self::F32_WIDTH;
        }
        while i < n {
            *b.offset(i as isize) = (*a.offset(i as isize)).exp();
            i += 1;
        }
    }
}

impl UnaryFn2 for AvxSse2 {}

impl AvxSse2 {
    #[target_feature(enable = "avx,sse2,sse")]
    pub(crate) unsafe fn vs_exp(n: usize, a: *const f32, b: *mut f32) {
        Self::vs_exp_0(n, a, b)
    }

    #[target_feature(enable = "avx,sse2,sse")]
    pub(crate) unsafe fn vs_ln(n: usize, a: *const f32, b: *mut f32) {
        Self::vs_ln_0(n, a, b)
    }

    #[target_feature(enable = "avx,sse2,sse")]
    pub(crate) unsafe fn vs_tanh(n: usize, a: *const f32, b: *mut f32) {
        Self::vs_tanh_0(n, a, b)
    }

    #[target_feature(enable = "avx,sse2,sse")]
    pub(crate) unsafe fn vs_sin(n: usize, a: *const f32, b: *mut f32) {
        Self::vs_sin_0(n, a, b)
    }

    #[target_feature(enable = "avx,sse2,sse")]
    pub(crate) unsafe fn vs_cos(n: usize, a: *const f32, b: *mut f32) {
        Self::vs_cos_0(n, a, b)
    }

    #[target_feature(enable = "avx")]
    pub(crate) unsafe fn vs_sqrt(n: usize, a: *const f32, b: *mut f32) {
        Self::vs_sqrt_0(n, a, b)
    }

    #[target_feature(enable = "avx")]
    pub(crate) unsafe fn vd_sqrt(n: usize, a: *const f64, b: *mut f64) {
        Self::vd_sqrt_0(n, a, b)
    }
}

pub(crate) struct Avx512f {}
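// AVX-512F natural-log kernel in inline assembly (the avx512f intrinsics are
// not yet stable). Each iteration handles 16 f32 lanes: vgetexpps/vgetmantps
// split the input into exponent and mantissa, a polynomial is evaluated in
// the mantissa, the result is recombined with exponent * ln(2), and negative
// inputs are patched to NaN under mask k1. Leftover elements go through the
// scalar tail loop.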
pub(crate) unsafe fn vs_ln_avx512f_asm(n: usize, a: *const f32, b: *mut f32) {
    const NR: usize = 16;
    // define constants
    const LN2F_HI: f32 = 0.693359375;
    const LN2F_LO: f32 = -2.12194440E-4;
    const P0LOGF: f32 = -0.5;
    const P1LOGF: f32 = 3.3333331174E-1;
    const P2LOGF: f32 = -2.4999993993E-1;
    const P3LOGF: f32 = 2.0000714765E-1;
    const P4LOGF: f32 = -1.6666657665E-1;
    const P5LOGF: f32 = 1.4249322787E-1;
    const P6LOGF: f32 = -1.250000140846E-1;
    const P7LOGF: f32 = 1.1676998740E-1;
    // const P8LOGF: f32 = -1.1514610310E-1;
    // const P9LOGF: f32 = 7.0376836292E-2;

    let constant_arr = [
        0.73337,
        1.5,
        1.0,
        -0.58496250072,
        -1.0,
        P7LOGF,
        P6LOGF,
        P5LOGF,
        P4LOGF,
        P3LOGF,
        P2LOGF,
        P1LOGF,
        P0LOGF,
        LN2F_HI + LN2F_LO,
        f32::NAN,
    ];

    let mut i = 0;
    asm!(
        "vxorps %zmm0, %zmm0, %zmm0",
        "vbroadcastss 0({constant_arrx}), %zmm1",
        "vbroadcastss 4({constant_arrx}), %zmm2",
        "vbroadcastss 8({constant_arrx}), %zmm3",
        "vbroadcastss 12({constant_arrx}), %zmm4",
        "vbroadcastss 16({constant_arrx}), %zmm5",

        "vbroadcastss 20({constant_arrx}), %zmm6",
        "vbroadcastss 24({constant_arrx}), %zmm7",
        "vbroadcastss 28({constant_arrx}), %zmm8",
        "vbroadcastss 32({constant_arrx}), %zmm9",
        "vbroadcastss 36({constant_arrx}), %zmm10",
        "vbroadcastss 40({constant_arrx}), %zmm11",
        "vbroadcastss 44({constant_arrx}), %zmm12",
        "vbroadcastss 48({constant_arrx}), %zmm13",

        "vbroadcastss 52({constant_arrx}), %zmm14",

        "test {nx:e}, {nx:e}",
        "je 3f",

        "2:",
        "vmovups ({ax}), %zmm15",
        "vcmpltps        %zmm0, %zmm15, %k1",
        "vgetexpps       %zmm15, %zmm16",
        "vgetmantps      $2, %zmm15, %zmm15",
        "vcmplt_oqps     %zmm1, %zmm15, %k2",
        "vmulps  %zmm2, %zmm15, %zmm15 {{%k2}}",
        "vaddps  %zmm3, %zmm16, %zmm16",
        "vaddps  %zmm4, %zmm16, %zmm16 {{%k2}}",
        "vaddps  %zmm5, %zmm15, %zmm15",
        "vmovaps %zmm6, %zmm17",
        "vfmadd213ps     %zmm7, %zmm15, %zmm17",
        "vfmadd213ps     %zmm8, %zmm15, %zmm17",
        "vfmadd213ps     %zmm9, %zmm15, %zmm17",
        "vfmadd213ps     %zmm10, %zmm15, %zmm17",
        "vfmadd213ps     %zmm11, %zmm15, %zmm17",
        "vfmadd213ps     %zmm12, %zmm15, %zmm17",
        "vfmadd213ps     %zmm13, %zmm15, %zmm17",
        "vfmadd213ps     %zmm3, %zmm15, %zmm17",
        "vmulps  %zmm17, %zmm15, %zmm15",
        "vfmadd231ps     %zmm14, %zmm16, %zmm15",
        "vbroadcastss    56({constant_arrx}), %zmm15 {{%k1}}",
        "vmovups %zmm15, ({bx})",
        "add  $64, {ax}",
        "add  $64, {bx}",
        "add $16, {ix:e}",
        "cmp {nx:e}, {ix:e}",
        "jl 2b",

        "3:",
        "vzeroupper",

        constant_arrx = in(reg) &constant_arr,
        ax = inout(reg) a => _,
        bx = inout(reg) b => _,
        ix = inout(reg) i => i,
        nx = inout(reg) n / NR * NR => _,
        out("zmm0") _, out("zmm1") _, out("zmm2") _, out("zmm3") _, out("zmm4") _, out("zmm5") _, out("zmm6") _, out("zmm7") _, out("zmm8") _, out("zmm9") _,
        out("zmm10") _, out("zmm11") _, out("zmm12") _, out("zmm13") _, out("zmm14") _, out("zmm15") _, out("zmm16") _, out("zmm17") _, out("zmm18") _, out("zmm19") _,
        out("zmm20") _, out("zmm21") _, out("zmm22") _, out("zmm23") _, out("zmm24") _, out("zmm25") _, out("zmm26") _, out("zmm27") _, out("zmm28") _, out("zmm29") _,
        out("zmm30") _, out("zmm31") _, out("k1") _, out("k2") _,
        options(att_syntax)
    );
    while i < n {
        *b.offset(i as isize) = (*a.offset(i as isize)).ln();
        i += 1;
    }
}

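// AVX-512F expf kernel in inline assembly. Per 16-lane iteration: clamp the
// input (operand order chosen so NaNs propagate), take k = round(x * log2(e))
// with vrndscaleps, subtract k * ln(2), evaluate the polynomial, and apply the
// 2^k scaling with vscalefps. Remaining elements use the scalar tail loop.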
pub(crate) unsafe fn vs_exp_avx512f_asm(n: usize, a: *const f32, b: *mut f32) {
    const NR: usize = 16;
    // Constants
    // use asm until avx512f is stabilized
    const EXP_HI: f32 = 88.3762626647949 * 2.0;
    const EXP_LO: f32 = -88.3762626647949 * 2.0;
    const LOG2EF: f32 = 1.44269504088896341;
    const INV_LOG2EF: f32 = 0.693359375;
    const CEHPES_EXP_C2: f32 = -2.12194440e-4;
    const EXP_P0: f32 = 1.9875691500E-4;
    const EXP_P1: f32 = 1.3981999507E-3;
    const EXP_P2: f32 = 8.3334519073E-3;
    const EXP_P3: f32 = 4.1665795894E-2;
    const EXP_P4: f32 = 1.6666665459E-1;
    const EXP_P5: f32 = 5.0000001201E-1;

    let constant_arr =
        [LOG2EF, -CEHPES_EXP_C2 - INV_LOG2EF, EXP_P0, EXP_P1, EXP_P2, EXP_P3, EXP_P4, EXP_P5, 1.0, EXP_HI, EXP_LO];
    let mut i = 0;
    asm!(
        "vbroadcastss ({constant_arrx}), %zmm0",
        "vbroadcastss 4({constant_arrx}), %zmm1",
        "vbroadcastss 8({constant_arrx}), %zmm2",
        "vbroadcastss 12({constant_arrx}), %zmm3",
        "vbroadcastss 16({constant_arrx}), %zmm4",
        "vbroadcastss 20({constant_arrx}), %zmm5",
        "vbroadcastss 24({constant_arrx}), %zmm6",
        "vbroadcastss 28({constant_arrx}), %zmm7",
        "vbroadcastss 32({constant_arrx}), %zmm8",

        "vbroadcastss 36({constant_arrx}), %zmm13",
        "vbroadcastss 40({constant_arrx}), %zmm14",

        "test {nx:e}, {nx:e}",
        "je 3f",

        "2:",
        "vmovups ({ax}), %zmm9",
        // order of input for max and min is important
        // since it leads to correct NaN propagation
        "vminps %zmm9, %zmm13, %zmm9",
        "vmaxps %zmm9, %zmm14, %zmm9",
        "vmulps  %zmm0, %zmm9, %zmm10",
        "vrndscaleps     $8, %zmm10, %zmm10",
        "vfmadd231ps     %zmm1, %zmm10, %zmm9",
        "vmovaps %zmm2, %zmm11",
        "vfmadd213ps     %zmm3, %zmm9, %zmm11",
        "vfmadd213ps     %zmm4, %zmm9, %zmm11",
        "vfmadd213ps     %zmm5, %zmm9, %zmm11",
        "vfmadd213ps     %zmm6, %zmm9, %zmm11",
        "vmulps  %zmm9, %zmm9, %zmm12",
        "vfmadd213ps     %zmm7, %zmm9, %zmm11",
        "vfmadd213ps     %zmm9, %zmm12, %zmm11",
        "vaddps  %zmm8, %zmm11, %zmm9",
        "vscalefps       %zmm10, %zmm9, %zmm9",
        "vmovups %zmm9, ({bx})",
        "add  $64, {ax}",
        "add  $64, {bx}",
        "add $16, {ix:e}",
        "cmp {nx:e}, {ix:e}",
        "jl 2b",

        "3:",
        "vzeroupper",

        constant_arrx = in(reg) &constant_arr,
        ax = inout(reg) a => _,
        bx = inout(reg) b => _,
        ix = inout(reg) i => i,
        nx = inout(reg) n / NR * NR => _,
        out("zmm0") _, out("zmm1") _, out("zmm2") _, out("zmm3") _, out("zmm4") _, out("zmm5") _, out("zmm6") _, out("zmm7") _, out("zmm8") _, out("zmm9") _,
        out("zmm10") _, out("zmm11") _, out("zmm12") _, out("zmm13") _, out("zmm14") _, out("zmm15") _, out("zmm16") _, out("zmm17") _, out("zmm18") _, out("zmm19") _,
        out("zmm20") _, out("zmm21") _, out("zmm22") _, out("zmm23") _, out("zmm24") _, out("zmm25") _, out("zmm26") _, out("zmm27") _, out("zmm28") _, out("zmm29") _,
        out("zmm30") _, out("zmm31") _,
        options(att_syntax)
    );
    while i < n {
        *b.offset(i as isize) = (*a.offset(i as isize)).exp();
        i += 1;
    }
}

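// AVX-512F tanh kernel: reuses the exp evaluation above on the clamped input,
// squares the result to obtain exp(2x), forms (exp(2x) - 1) / (exp(2x) + 1)
// and clamps the quotient to [-1, 1]. The clamp bound EXP_HI is halved here so
// that the squaring cannot overflow.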
pub(crate) unsafe fn vs_tanh_avx512f_asm(n: usize, a: *const f32, b: *mut f32) {
    const NR: usize = 16;
    // Constants
    // use asm until avx512f is stabilized
    const EXP_HI: f32 = 88.3762626647949 * 0.5;
    const EXP_LO: f32 = -88.3762626647949 * 0.5;
    const LOG2EF: f32 = 1.44269504088896341;
    const INV_LOG2EF: f32 = 0.693359375;
    const CEHPES_EXP_C2: f32 = -2.12194440e-4;
    const EXP_P0: f32 = 1.9875691500E-4;
    const EXP_P1: f32 = 1.3981999507E-3;
    const EXP_P2: f32 = 8.3334519073E-3;
    const EXP_P3: f32 = 4.1665795894E-2;
    const EXP_P4: f32 = 1.6666665459E-1;
    const EXP_P5: f32 = 5.0000001201E-1;

    let constant_arr = [
        LOG2EF,
        -CEHPES_EXP_C2 - INV_LOG2EF,
        EXP_P0,
        EXP_P1,
        EXP_P2,
        EXP_P3,
        EXP_P4,
        EXP_P5,
        1.0,
        EXP_HI,
        EXP_LO,
        -1.0,
    ];
    let mut i = 0;
    asm!(
        "vbroadcastss ({constant_arrx}), %zmm0",
        "vbroadcastss 4({constant_arrx}), %zmm1",
        "vbroadcastss 8({constant_arrx}), %zmm2",
        "vbroadcastss 12({constant_arrx}), %zmm3",
        "vbroadcastss 16({constant_arrx}), %zmm4",
        "vbroadcastss 20({constant_arrx}), %zmm5",
        "vbroadcastss 24({constant_arrx}), %zmm6",
        "vbroadcastss 28({constant_arrx}), %zmm7",
        "vbroadcastss 32({constant_arrx}), %zmm8",

        "vbroadcastss 36({constant_arrx}), %zmm13",
        "vbroadcastss 40({constant_arrx}), %zmm14",
        "vbroadcastss 44({constant_arrx}), %zmm15",

        "test {nx:e}, {nx:e}",
        "je 3f",

        "2:",
        "vmovups ({ax}), %zmm31",
        // order of input for max and min is important
        // since it leads to correct NaN propagation
        "vminps %zmm31, %zmm13, %zmm9",
        "vmaxps %zmm9, %zmm14, %zmm9",
        "vmulps  %zmm0, %zmm9, %zmm10",
        "vrndscaleps     $8, %zmm10, %zmm10",
        "vfmadd231ps     %zmm1, %zmm10, %zmm9",
        "vmovaps %zmm2, %zmm11",
        "vfmadd213ps     %zmm3, %zmm9, %zmm11",
        "vfmadd213ps     %zmm4, %zmm9, %zmm11",
        "vfmadd213ps     %zmm5, %zmm9, %zmm11",
        "vfmadd213ps     %zmm6, %zmm9, %zmm11",
        "vmulps  %zmm9, %zmm9, %zmm12",
        "vfmadd213ps     %zmm7, %zmm9, %zmm11",
        "vfmadd213ps     %zmm9, %zmm12, %zmm11",
        "vaddps  %zmm8, %zmm11, %zmm9",
        "vscalefps       %zmm10, %zmm9, %zmm9",
        "vmulps %zmm9, %zmm9, %zmm9",

        "vaddps %zmm9, %zmm15, %zmm16",
        "vaddps %zmm9, %zmm8, %zmm9",
        "vdivps %zmm9, %zmm16, %zmm9",
        "vminps %zmm9, %zmm8, %zmm9",
        "vmaxps %zmm9, %zmm15, %zmm9",

        "vmovups %zmm9, ({bx})",
        "add  $64, {ax}",
        "add  $64, {bx}",
        "add $16, {ix:e}",
        "cmp {nx:e}, {ix:e}",
        "jl 2b",

        "3:",
        "vzeroupper",

        constant_arrx = in(reg) &constant_arr,
        ax = inout(reg) a => _,
        bx = inout(reg) b => _,
        ix = inout(reg) i => i,
        nx = inout(reg) n / NR * NR => _,
        out("zmm0") _, out("zmm1") _, out("zmm2") _, out("zmm3") _, out("zmm4") _, out("zmm5") _, out("zmm6") _, out("zmm7") _, out("zmm8") _, out("zmm9") _,
        out("zmm10") _, out("zmm11") _, out("zmm12") _, out("zmm13") _, out("zmm14") _, out("zmm15") _, out("zmm16") _, out("zmm17") _, out("zmm18") _, out("zmm19") _,
        out("zmm20") _, out("zmm21") _, out("zmm22") _, out("zmm23") _, out("zmm24") _, out("zmm25") _, out("zmm26") _, out("zmm27") _, out("zmm28") _, out("zmm29") _,
        out("zmm30") _, out("zmm31") _,
        options(att_syntax)
    );
    while i < n {
        *b.offset(i as isize) = (*a.offset(i as isize)).tanh();
        i += 1;
    }
}

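// AVX-512F sqrt kernels: a plain vsqrtps/vsqrtpd over 64-byte vectors (16 f32
// or 8 f64 lanes per iteration) followed by a scalar tail loop.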
pub(crate) unsafe fn vs_sqrt_avx512f_asm(n: usize, a: *const f32, b: *mut f32) {
    const NR: usize = 16;
    // Constants
    // use asm until avx512f is stabilized
    let mut i = 0;
    asm!(
        "test {nx:e}, {nx:e}",
        "je 3f",

        "2:",
        "vmovups ({ax}), %zmm9",
        "vsqrtps %zmm9, %zmm9",
        "vmovups %zmm9, ({bx})",
        "add  $64, {ax}",
        "add  $64, {bx}",
        "add $16, {ix:e}",
        "cmp {nx:e}, {ix:e}",
        "jl 2b",

        "3:",
        "vzeroupper",

        ax = inout(reg) a => _,
        bx = inout(reg) b => _,
        ix = inout(reg) i => i,
        nx = inout(reg) n / NR * NR => _,
        out("zmm9") _,
        options(att_syntax)
    );
    while i < n {
        *b.offset(i as isize) = (*a.offset(i as isize)).sqrt();
        i += 1;
    }
}

pub(crate) unsafe fn vd_sqrt_avx512f_asm(n: usize, a: *const f64, b: *mut f64) {
    const NR: usize = 8;
    // Constants
    // use asm until avx512f is stabilized
    let mut i = 0;
    asm!(
        "test {nx:e}, {nx:e}",
        "je 3f",

        "2:",
        "vmovupd ({ax}), %zmm9",
        "vsqrtpd %zmm9, %zmm9",
        "vmovupd %zmm9, ({bx})",
        "add  $64, {ax}",
        "add  $64, {bx}",
        // 64 bytes per iteration is 8 f64 lanes, so advance the index by 8
        "add $8, {ix:e}",
        "cmp {nx:e}, {ix:e}",
        "jl 2b",

        "3:",
        "vzeroupper",

        ax = inout(reg) a => _,
        bx = inout(reg) b => _,
        ix = inout(reg) i => i,
        nx = inout(reg) n / NR * NR => _,
        out("zmm9") _,
        options(att_syntax)
    );
    while i < n {
        *b.offset(i as isize) = (*a.offset(i as isize)).sqrt();
        i += 1;
    }
}

impl Avx512f {
    pub(crate) unsafe fn vs_exp(n: usize, a: *const f32, b: *mut f32) {
        vs_exp_avx512f_asm(n, a, b);
    }

    pub(crate) unsafe fn vs_ln(n: usize, a: *const f32, b: *mut f32) {
        vs_ln_avx512f_asm(n, a, b);
    }

    pub(crate) unsafe fn vs_tanh(n: usize, a: *const f32, b: *mut f32) {
        vs_tanh_avx512f_asm(n, a, b);
    }

    pub(crate) unsafe fn vs_sqrt(n: usize, a: *const f32, b: *mut f32) {
        vs_sqrt_avx512f_asm(n, a, b);
    }

    pub(crate) unsafe fn vd_sqrt(n: usize, a: *const f64, b: *mut f64) {
        vd_sqrt_avx512f_asm(n, a, b);
    }
}
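
// Illustrative usage sketch (not part of this module): callers are expected to
// verify the required CPU features at runtime before dispatching to a backend,
// e.g.
//
//     let x = [0.0_f32, 1.0, 2.0, 3.0];
//     let mut y = [0.0_f32; 4];
//     if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
//         unsafe { Avx2Fma::vs_exp(x.len(), x.as_ptr(), y.as_mut_ptr()) };
//     }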