trueno/backends/sse2/ops/
activations.rs1#[cfg(target_arch = "x86_64")]
4use std::arch::x86_64::*;
5
/// Approximates `e^x` for four packed `f32` lanes using SSE2.
///
/// The input is split as `x = k * ln(2) + r`, where `k` is `x / ln(2)` converted
/// to an integer, so `r` is a small remainder. `e^r` is evaluated with the
/// degree-5 Taylor polynomial in Horner form, and `2^k` is assembled directly
/// from the IEEE-754 exponent bits; the result is `e^r * 2^k`.
///
/// Inputs are clamped to `[-87.0, 88.0]` before the split, so the result
/// saturates (at roughly `1.6e-38` and `1.7e38`) instead of reaching `0` or
/// `+inf`. NaN lanes produce NaN.
///
/// # Safety
///
/// The CPU executing this function must support SSE2.
#[inline]
#[target_feature(enable = "sse2")]
pub(crate) unsafe fn exp_approx_sse2(x: __m128) -> __m128 {
    let ln2 = _mm_set1_ps(std::f32::consts::LN_2);
    let inv_ln2 = _mm_set1_ps(1.0 / std::f32::consts::LN_2);
    let one = _mm_set1_ps(1.0);
    // Taylor coefficients 1/2!, 1/3!, 1/4!, 1/5!.
    let c2 = _mm_set1_ps(0.5);
    let c3 = _mm_set1_ps(0.166_666_67);
    let c4 = _mm_set1_ps(0.041_666_668);
    let c5 = _mm_set1_ps(0.008_333_334);

    // Keep `k + 127` inside the valid biased-exponent range [1, 254]; outside it
    // the shift below would spill into the sign bit or produce a wrong exponent.
    // `x` is the second operand of both min and max because the SSE min/max
    // instructions return their second operand when either one is NaN, which
    // keeps NaN lanes NaN.
    let x = _mm_max_ps(_mm_set1_ps(-87.0), _mm_min_ps(_mm_set1_ps(88.0), x));

    // Range reduction: x = k * ln(2) + r.
    let k = _mm_cvtps_epi32(_mm_mul_ps(x, inv_ln2));
    let kf = _mm_cvtepi32_ps(k);
    let r = _mm_sub_ps(x, _mm_mul_ps(kf, ln2));

    // Horner evaluation of 1 + r + r^2/2! + r^3/3! + r^4/4! + r^5/5!.
    let mut poly = _mm_add_ps(c4, _mm_mul_ps(r, c5));
    poly = _mm_add_ps(c3, _mm_mul_ps(r, poly));
    poly = _mm_add_ps(c2, _mm_mul_ps(r, poly));
    poly = _mm_add_ps(one, _mm_mul_ps(r, poly));
    poly = _mm_add_ps(one, _mm_mul_ps(r, poly));

    // 2^k: place the biased exponent (k + 127) in bits 23..=30 of each lane.
    let exp_k = _mm_castsi128_ps(_mm_slli_epi32(_mm_add_epi32(k, _mm_set1_epi32(127)), 23));
    _mm_mul_ps(poly, exp_k)
}
33
/// Computes `result[i] = e^(a[i])` for every index of `a`.
///
/// Full groups of four elements are handled with an SSE2 approximation:
/// the input is split as `x = k * ln(2) + r`, `e^r` is evaluated with the
/// degree-5 Taylor polynomial in Horner form, and `2^k` is built from the
/// IEEE-754 exponent bits. The remaining 0-3 elements use `f32::exp`.
///
/// In the vector path inputs are clamped to `[-87.0, 88.0]`, so those lanes
/// saturate (at roughly `1.6e-38` and `1.7e38`) rather than reaching `0` or
/// `+inf`; the scalar tail follows `f32::exp`. NaN inputs produce NaN.
///
/// # Safety
///
/// The CPU must support SSE2, and `result` must be at least as long as `a`:
/// the vector stores write through a raw pointer without bounds checks.
#[inline]
#[target_feature(enable = "sse2")]
pub(crate) unsafe fn exp(a: &[f32], result: &mut [f32]) {
    let len = a.len();
    debug_assert!(result.len() >= len, "result must be at least as long as a");

    // SAFETY: every vector load/store touches indices `i..i + 4` with
    // `i + 4 <= len`, and the caller guarantees `result.len() >= len`.
    unsafe {
        let mut i = 0;
        let ln2 = _mm_set1_ps(std::f32::consts::LN_2);
        let inv_ln2 = _mm_set1_ps(1.0 / std::f32::consts::LN_2);
        let c1 = _mm_set1_ps(1.0);
        // Taylor coefficients 1/2!, 1/3!, 1/4!, 1/5!.
        let c2 = _mm_set1_ps(0.5);
        let c3 = _mm_set1_ps(0.166_666_67);
        let c4 = _mm_set1_ps(0.041_666_668);
        let c5 = _mm_set1_ps(0.008_333_334);
        let lo = _mm_set1_ps(-87.0);
        let hi = _mm_set1_ps(88.0);
        let bias = _mm_set1_epi32(127);

        while i + 4 <= len {
            let x = _mm_loadu_ps(a.as_ptr().add(i));
            // Keep `k + 127` inside the valid biased-exponent range [1, 254].
            // `x` is the second operand so that NaN lanes stay NaN (SSE min/max
            // return the second operand when either one is NaN).
            let x = _mm_max_ps(lo, _mm_min_ps(hi, x));

            // Range reduction: x = k * ln(2) + r.
            let k = _mm_cvtps_epi32(_mm_mul_ps(x, inv_ln2));
            let kf = _mm_cvtepi32_ps(k);
            let r = _mm_sub_ps(x, _mm_mul_ps(kf, ln2));

            // Horner evaluation of 1 + r + r^2/2! + r^3/3! + r^4/4! + r^5/5!.
            let mut poly = _mm_add_ps(c4, _mm_mul_ps(r, c5));
            poly = _mm_add_ps(c3, _mm_mul_ps(r, poly));
            poly = _mm_add_ps(c2, _mm_mul_ps(r, poly));
            poly = _mm_add_ps(c1, _mm_mul_ps(r, poly));
            poly = _mm_add_ps(c1, _mm_mul_ps(r, poly));

            // 2^k: biased exponent shifted into bits 23..=30.
            let exp_k = _mm_castsi128_ps(_mm_slli_epi32(_mm_add_epi32(k, bias), 23));
            _mm_storeu_ps(result.as_mut_ptr().add(i), _mm_mul_ps(poly, exp_k));
            i += 4;
        }

        // Scalar tail for the last `len % 4` elements.
        for j in i..len {
            result[j] = a[j].exp();
        }
    }
}
69
/// Computes the logistic sigmoid `result[i] = 1 / (1 + e^(-a[i]))` for every
/// index of `a`.
///
/// Full groups of four elements use an SSE2 approximation of `e^(-x)`
/// (range reduction `-x = k * ln(2) + r`, a degree-5 Taylor polynomial for
/// `e^r`, and `2^k` built from the exponent bits). The remaining 0-3 elements
/// use `f32::exp`.
///
/// In the vector path `-x` is clamped to `[-87.0, 88.0]` before the
/// exponential, so the output approaches `1` and `0` for large-magnitude
/// inputs. NaN inputs produce NaN.
///
/// # Safety
///
/// The CPU must support SSE2, and `result` must be at least as long as `a`:
/// the vector stores write through a raw pointer without bounds checks.
#[inline]
#[target_feature(enable = "sse2")]
pub(crate) unsafe fn sigmoid(a: &[f32], result: &mut [f32]) {
    let len = a.len();
    debug_assert!(result.len() >= len, "result must be at least as long as a");

    // SAFETY: every vector load/store touches indices `i..i + 4` with
    // `i + 4 <= len`, and the caller guarantees `result.len() >= len`.
    unsafe {
        let mut i = 0;
        let one = _mm_set1_ps(1.0);
        let neg_one = _mm_set1_ps(-1.0);
        let ln2 = _mm_set1_ps(std::f32::consts::LN_2);
        let inv_ln2 = _mm_set1_ps(1.0 / std::f32::consts::LN_2);
        // Taylor coefficients 1/2!, 1/3!, 1/4!, 1/5!.
        let c2 = _mm_set1_ps(0.5);
        let c3 = _mm_set1_ps(0.166_666_67);
        let c4 = _mm_set1_ps(0.041_666_668);
        let c5 = _mm_set1_ps(0.008_333_334);
        let lo = _mm_set1_ps(-87.0);
        let hi = _mm_set1_ps(88.0);
        let bias = _mm_set1_epi32(127);

        while i + 4 <= len {
            let x = _mm_loadu_ps(a.as_ptr().add(i));
            let neg_x = _mm_mul_ps(x, neg_one);
            // Keep `k + 127` inside the valid biased-exponent range [1, 254].
            // The data is the second operand so that NaN lanes stay NaN (SSE
            // min/max return the second operand when either one is NaN).
            let neg_x = _mm_max_ps(lo, _mm_min_ps(hi, neg_x));

            // Range reduction: -x = k * ln(2) + r.
            let k = _mm_cvtps_epi32(_mm_mul_ps(neg_x, inv_ln2));
            let kf = _mm_cvtepi32_ps(k);
            let r = _mm_sub_ps(neg_x, _mm_mul_ps(kf, ln2));

            // Horner evaluation of 1 + r + r^2/2! + r^3/3! + r^4/4! + r^5/5!.
            let mut poly = _mm_add_ps(c4, _mm_mul_ps(r, c5));
            poly = _mm_add_ps(c3, _mm_mul_ps(r, poly));
            poly = _mm_add_ps(c2, _mm_mul_ps(r, poly));
            poly = _mm_add_ps(one, _mm_mul_ps(r, poly));
            poly = _mm_add_ps(one, _mm_mul_ps(r, poly));

            // 2^k: biased exponent shifted into bits 23..=30.
            let exp_k = _mm_castsi128_ps(_mm_slli_epi32(_mm_add_epi32(k, bias), 23));
            let exp_neg_x = _mm_mul_ps(poly, exp_k);

            _mm_storeu_ps(
                result.as_mut_ptr().add(i),
                _mm_div_ps(one, _mm_add_ps(one, exp_neg_x)),
            );
            i += 4;
        }

        // Scalar tail for the last `len % 4` elements.
        for j in i..len {
            result[j] = 1.0 / (1.0 + (-a[j]).exp());
        }
    }
}
108
109#[inline]
111#[target_feature(enable = "sse2")]
112pub(crate) unsafe fn gelu(a: &[f32], result: &mut [f32]) {
114 unsafe {
115 let len = a.len();
117 let mut i = 0;
118 let half = _mm_set1_ps(0.5);
119 let one = _mm_set1_ps(1.0);
120 let sqrt_2_pi = _mm_set1_ps(0.797_884_56);
121 let coeff = _mm_set1_ps(0.044_715);
122 while i + 4 <= len {
123 let x = _mm_loadu_ps(a.as_ptr().add(i));
124 let x3 = _mm_mul_ps(_mm_mul_ps(x, x), x);
125 let inner = _mm_mul_ps(sqrt_2_pi, _mm_add_ps(x, _mm_mul_ps(coeff, x3)));
126 let two_inner = _mm_add_ps(inner, inner);
128 let exp_2x = exp_approx_sse2(two_inner);
129 let tanh_val = _mm_div_ps(_mm_sub_ps(exp_2x, one), _mm_add_ps(exp_2x, one));
130 _mm_storeu_ps(
131 result.as_mut_ptr().add(i),
132 _mm_mul_ps(half, _mm_mul_ps(x, _mm_add_ps(one, tanh_val))),
133 );
134 i += 4;
135 }
136 for j in i..len {
137 let x = a[j];
138 result[j] = 0.5
139 * x
140 * (1.0 + ((0.797_884_56 * (x + 0.044_715 * x * x * x)) as f64).tanh() as f32);
141 }
142 }
143}
144
/// Computes swish (SiLU) `result[i] = a[i] / (1 + e^(-a[i]))` for every index
/// of `a`.
///
/// Full groups of four elements use an SSE2 approximation of `e^(-x)`
/// (range reduction `-x = k * ln(2) + r`, a degree-5 Taylor polynomial for
/// `e^r`, and `2^k` built from the exponent bits), then multiply `x` by the
/// resulting sigmoid. The remaining 0-3 elements use `f32::exp`.
///
/// In the vector path `-x` is clamped to `[-87.0, 88.0]` before the
/// exponential, so the sigmoid factor approaches `1` and `0` for
/// large-magnitude inputs. NaN inputs produce NaN.
///
/// # Safety
///
/// The CPU must support SSE2, and `result` must be at least as long as `a`:
/// the vector stores write through a raw pointer without bounds checks.
#[inline]
#[target_feature(enable = "sse2")]
pub(crate) unsafe fn swish(a: &[f32], result: &mut [f32]) {
    let len = a.len();
    debug_assert!(result.len() >= len, "result must be at least as long as a");

    // SAFETY: every vector load/store touches indices `i..i + 4` with
    // `i + 4 <= len`, and the caller guarantees `result.len() >= len`.
    unsafe {
        let mut i = 0;
        let one = _mm_set1_ps(1.0);
        let neg_one = _mm_set1_ps(-1.0);
        let ln2 = _mm_set1_ps(std::f32::consts::LN_2);
        let inv_ln2 = _mm_set1_ps(1.0 / std::f32::consts::LN_2);
        // Taylor coefficients 1/2!, 1/3!, 1/4!, 1/5!.
        let c2 = _mm_set1_ps(0.5);
        let c3 = _mm_set1_ps(0.166_666_67);
        let c4 = _mm_set1_ps(0.041_666_668);
        let c5 = _mm_set1_ps(0.008_333_334);
        let lo = _mm_set1_ps(-87.0);
        let hi = _mm_set1_ps(88.0);
        let bias = _mm_set1_epi32(127);

        while i + 4 <= len {
            let x = _mm_loadu_ps(a.as_ptr().add(i));
            let neg_x = _mm_mul_ps(x, neg_one);
            // Keep `k + 127` inside the valid biased-exponent range [1, 254].
            // The data is the second operand so that NaN lanes stay NaN (SSE
            // min/max return the second operand when either one is NaN).
            let neg_x = _mm_max_ps(lo, _mm_min_ps(hi, neg_x));

            // Range reduction: -x = k * ln(2) + r.
            let k = _mm_cvtps_epi32(_mm_mul_ps(neg_x, inv_ln2));
            let kf = _mm_cvtepi32_ps(k);
            let r = _mm_sub_ps(neg_x, _mm_mul_ps(kf, ln2));

            // Horner evaluation of 1 + r + r^2/2! + r^3/3! + r^4/4! + r^5/5!.
            let mut poly = _mm_add_ps(c4, _mm_mul_ps(r, c5));
            poly = _mm_add_ps(c3, _mm_mul_ps(r, poly));
            poly = _mm_add_ps(c2, _mm_mul_ps(r, poly));
            poly = _mm_add_ps(one, _mm_mul_ps(r, poly));
            poly = _mm_add_ps(one, _mm_mul_ps(r, poly));

            // 2^k: biased exponent shifted into bits 23..=30.
            let exp_k = _mm_castsi128_ps(_mm_slli_epi32(_mm_add_epi32(k, bias), 23));
            let exp_neg_x = _mm_mul_ps(poly, exp_k);

            let sigmoid = _mm_div_ps(one, _mm_add_ps(one, exp_neg_x));
            _mm_storeu_ps(result.as_mut_ptr().add(i), _mm_mul_ps(x, sigmoid));
            i += 4;
        }

        // Scalar tail for the last `len % 4` elements.
        for j in i..len {
            result[j] = a[j] / (1.0 + (-a[j]).exp());
        }
    }
}
184
185#[inline]
187#[target_feature(enable = "sse2")]
188pub(crate) unsafe fn tanh(a: &[f32], result: &mut [f32]) {
190 unsafe {
191 let len = a.len();
193 let mut i = 0;
194 let one = _mm_set1_ps(1.0);
195 let two = _mm_set1_ps(2.0);
196 while i + 4 <= len {
197 let x = _mm_loadu_ps(a.as_ptr().add(i));
198 let exp_2x = exp_approx_sse2(_mm_mul_ps(two, x));
199 _mm_storeu_ps(
200 result.as_mut_ptr().add(i),
201 _mm_div_ps(_mm_sub_ps(exp_2x, one), _mm_add_ps(exp_2x, one)),
202 );
203 i += 4;
204 }
205 for j in i..len {
206 let exp_2x = (2.0 * a[j]).exp();
207 result[j] = (exp_2x - 1.0) / (exp_2x + 1.0);
208 }
209 }
210}