Skip to main content

carbon_simd/x86_64/
floating.rs

1use crate::*;
2use core::arch::x86_64::*;
3
4unsafe impl SimdFloatingElement for f32 {
5    #[inline(always)]
6    unsafe fn sqrt(x: Self::Vector) -> Self::Vector {
7        unsafe { _mm256_sqrt_ps(x) }
8    }
9
10    #[inline(always)]
11    unsafe fn exp(x: Self::Vector) -> Self::Vector {
12        unsafe {
13            const LN_2: f32 = core::f32::consts::LN_2;
14            const C0: f32 = 1.0 / 1.0;
15            const C1: f32 = 1.0 / 1.0;
16            const C2: f32 = 1.0 / 2.0;
17            const C3: f32 = 1.0 / 6.0;
18            const C4: f32 = 1.0 / 24.0;
19            const C5: f32 = 1.0 / 120.0;
20
21            let fx = _mm256_mul_ps(x, Self::set(1.0 / LN_2));
22            let n = _mm256_floor_ps(fx);
23            let f = _mm256_sub_ps(x, _mm256_mul_ps(n, _mm256_set1_ps(LN_2)));
24
25            let poly = _mm256_set1_ps(C5);
26            let poly = _mm256_fmadd_ps(poly, f, _mm256_set1_ps(C4));
27            let poly = _mm256_fmadd_ps(poly, f, _mm256_set1_ps(C3));
28            let poly = _mm256_fmadd_ps(poly, f, _mm256_set1_ps(C2));
29            let poly = _mm256_fmadd_ps(poly, f, _mm256_set1_ps(C1));
30            let poly = _mm256_fmadd_ps(poly, f, _mm256_set1_ps(C0));
31
32            const EXP_BIAS: i32 = 0x7f;
33            const EXP_OFFSET: i32 = 23;
34
35            let exp_bias = _mm256_set1_epi32(EXP_BIAS);
36            let n_i32 = _mm256_cvtps_epi32(n);
37            let pow_2_n = _mm256_castsi256_ps(_mm256_slli_epi32(
38                _mm256_add_epi32(n_i32, exp_bias),
39                EXP_OFFSET,
40            ));
41
42            _mm256_mul_ps(pow_2_n, poly)
43        }
44    }
45
46    #[inline(always)]
47    unsafe fn tanh(x: Self::Vector) -> Self::Vector {
48        unsafe {
49            let one = Self::set(1.0);
50            let two = Self::set(2.0);
51
52            let mul_2x = <Self as SimdNumElement>::mul(x, two);
53            let exp_2x = <Self as SimdFloatingElement>::exp(mul_2x);
54            let exp_2x_plus_1 = <Self as SimdNumElement>::add(exp_2x, one);
55
56            let first_term = one;
57            let second_term = <Self as SimdNumElement>::div(two, exp_2x_plus_1);
58
59            <Self as SimdNumElement>::sub(first_term, second_term)
60        }
61    }
62}
63
64unsafe impl SimdFloatingElement for f64 {
65    #[inline(always)]
66    unsafe fn sqrt(x: Self::Vector) -> Self::Vector {
67        unsafe { _mm256_sqrt_pd(x) }
68    }
69
70    #[inline(always)]
71    unsafe fn exp(x: Self::Vector) -> Self::Vector {
72        unsafe {
73            let mut buff = [0.0f64; Self::VECTOR_LEN];
74            Self::store(buff.as_mut_ptr(), x);
75            for i in &mut buff {
76                *i = i.exp();
77            }
78            Self::load(buff.as_ptr())
79        }
80    }
81
82    #[inline(always)]
83    unsafe fn tanh(x: Self::Vector) -> Self::Vector {
84        unsafe {
85            let one = Self::set(1.0);
86            let two = Self::set(2.0);
87
88            let mul_2x = <Self as SimdNumElement>::mul(x, two);
89            let exp_2x = <Self as SimdFloatingElement>::exp(mul_2x);
90            let exp_2x_plus_1 = <Self as SimdNumElement>::add(exp_2x, one);
91
92            let first_term = one;
93            let second_term = <Self as SimdNumElement>::div(two, exp_2x_plus_1);
94
95            <Self as SimdNumElement>::sub(first_term, second_term)
96        }
97    }
98}