Skip to main content

trueno/backends/sse2/ops/
activations.rs

1//! SSE2 activation functions (exp, sigmoid, gelu, swish, tanh).
2
3#[cfg(target_arch = "x86_64")]
4use std::arch::x86_64::*;
5
/// Approximates `e^x` for four packed `f32` lanes.
///
/// The input is clamped to `[-87.0, 88.0]`, split as `x = k * ln(2) + r` with
/// `k` the integer nearest to `x / ln(2)`, and evaluated as `2^k * p(r)` where
/// `p` is the degree-5 Taylor polynomial of `e^r`.
///
/// Because of the clamp the result is finite and positive for every non-NaN
/// input: it saturates at roughly `1.65e38` for large `x` and roughly
/// `1.6e-38` for very negative `x` rather than producing `inf` or `0`. That
/// keeps quotients such as `(e - 1) / (e + 1)` well defined in callers.
/// NaN lanes stay NaN.
///
/// # Safety
///
/// Caller must ensure SSE2 is available on the current CPU.
#[inline]
#[target_feature(enable = "sse2")]
pub(crate) unsafe fn exp_approx_sse2(x: __m128) -> __m128 {
    let ln2 = _mm_set1_ps(std::f32::consts::LN_2);
    let inv_ln2 = _mm_set1_ps(1.0 / std::f32::consts::LN_2);
    let one = _mm_set1_ps(1.0);
    // Taylor coefficients 1/2!, 1/3!, 1/4!, 1/5!.
    let c2 = _mm_set1_ps(0.5);
    let c3 = _mm_set1_ps(0.166_666_67);
    let c4 = _mm_set1_ps(0.041_666_668);
    let c5 = _mm_set1_ps(0.008_333_334);

    // Keep k + 127 inside the normal exponent range [1, 254]; outside it the
    // shifted bits would spill into the sign bit or encode inf/NaN.
    // The bound is the FIRST operand of min/max on purpose: when a lane is NaN
    // the instruction returns its second operand, so NaN is preserved.
    let x = _mm_max_ps(_mm_set1_ps(-87.0), _mm_min_ps(_mm_set1_ps(88.0), x));

    // Range reduction: k = round(x / ln 2), r = x - k * ln 2.
    let k = _mm_cvtps_epi32(_mm_mul_ps(x, inv_ln2));
    let kf = _mm_cvtepi32_ps(k);
    let r = _mm_sub_ps(x, _mm_mul_ps(kf, ln2));

    // Horner form of 1 + r + r^2/2 + r^3/6 + r^4/24 + r^5/120.
    let mut poly = _mm_add_ps(c4, _mm_mul_ps(r, c5));
    poly = _mm_add_ps(c3, _mm_mul_ps(r, poly));
    poly = _mm_add_ps(c2, _mm_mul_ps(r, poly));
    poly = _mm_add_ps(one, _mm_mul_ps(r, poly));
    poly = _mm_add_ps(one, _mm_mul_ps(r, poly));

    // Build 2^k by writing the biased exponent straight into the float bits.
    let exp_k = _mm_castsi128_ps(_mm_slli_epi32(_mm_add_epi32(k, _mm_set1_epi32(127)), 23));
    _mm_mul_ps(poly, exp_k)
}
33
/// Computes `result[i] = e^(a[i])` for every `i < a.len()`.
///
/// Full groups of four elements are evaluated with SSE2 using range reduction
/// (`x = k * ln(2) + r`) and the degree-5 Taylor polynomial of `e^r`; the
/// remaining 0-3 elements fall back to `f32::exp`.
///
/// In the vector path inputs are clamped to `[-104.0, 89.0]`. The scale `2^k`
/// is applied as two factors so that results outside the `f32` range overflow
/// to `+inf` or underflow towards `0` instead of corrupting the exponent bits.
///
/// Elements of `result` beyond `a.len()` are left untouched.
///
/// # Panics
///
/// Panics if `result` is shorter than `a`.
///
/// # Safety
///
/// Caller must ensure SSE2 is available on the current CPU.
#[inline]
#[target_feature(enable = "sse2")]
pub(crate) unsafe fn exp(a: &[f32], result: &mut [f32]) {
    // The vector loop stores through a raw pointer, so the output length must
    // be validated before any store happens.
    assert!(result.len() >= a.len(), "result must be at least as long as the input");
    // SAFETY: SSE2 availability is the caller's obligation. Every 4-lane
    // load/store starts at `i` with `i + 4 <= a.len() <= result.len()`, and the
    // unaligned load/store intrinsics have no alignment requirement.
    unsafe {
        let len = a.len();
        let mut i = 0;
        let ln2 = _mm_set1_ps(std::f32::consts::LN_2);
        let inv_ln2 = _mm_set1_ps(1.0 / std::f32::consts::LN_2);
        let one = _mm_set1_ps(1.0);
        // Taylor coefficients 1/2!, 1/3!, 1/4!, 1/5!.
        let c2 = _mm_set1_ps(0.5);
        let c3 = _mm_set1_ps(0.166_666_67);
        let c4 = _mm_set1_ps(0.041_666_668);
        let c5 = _mm_set1_ps(0.008_333_334);
        // 89 already overflows f32 and -104 already rounds to zero, so clamping
        // here does not change any representable result.
        let lo = _mm_set1_ps(-104.0);
        let hi = _mm_set1_ps(89.0);
        let bias = _mm_set1_epi32(127);
        while i + 4 <= len {
            let x = _mm_loadu_ps(a.as_ptr().add(i));
            // Bounds are the first operand so that NaN lanes (returned as the
            // second operand by min/max) are preserved.
            let x = _mm_max_ps(lo, _mm_min_ps(hi, x));
            // Range reduction: k = round(x / ln 2), r = x - k * ln 2.
            let k = _mm_cvtps_epi32(_mm_mul_ps(x, inv_ln2));
            let kf = _mm_cvtepi32_ps(k);
            let r = _mm_sub_ps(x, _mm_mul_ps(kf, ln2));
            // Horner form of 1 + r + r^2/2 + r^3/6 + r^4/24 + r^5/120.
            let mut poly = _mm_add_ps(c4, _mm_mul_ps(r, c5));
            poly = _mm_add_ps(c3, _mm_mul_ps(r, poly));
            poly = _mm_add_ps(c2, _mm_mul_ps(r, poly));
            poly = _mm_add_ps(one, _mm_mul_ps(r, poly));
            poly = _mm_add_ps(one, _mm_mul_ps(r, poly));
            // k can be as large as 128 or as small as -150, which does not fit
            // one exponent field. Split it into two halves whose biased
            // exponents are always normal and multiply twice.
            let k_lo = _mm_srai_epi32(k, 1);
            let k_hi = _mm_sub_epi32(k, k_lo);
            let scale_lo = _mm_castsi128_ps(_mm_slli_epi32(_mm_add_epi32(k_lo, bias), 23));
            let scale_hi = _mm_castsi128_ps(_mm_slli_epi32(_mm_add_epi32(k_hi, bias), 23));
            let y = _mm_mul_ps(_mm_mul_ps(poly, scale_lo), scale_hi);
            _mm_storeu_ps(result.as_mut_ptr().add(i), y);
            i += 4;
        }
        // Scalar tail for the last len % 4 elements.
        for j in i..len {
            result[j] = a[j].exp();
        }
    }
}
69
/// Computes the logistic sigmoid `result[i] = 1 / (1 + e^(-a[i]))` for every
/// `i < a.len()`.
///
/// Full groups of four elements are evaluated with SSE2: `e^(-x)` is obtained
/// by range reduction (`-x = k * ln(2) + r`) and the degree-5 Taylor polynomial
/// of `e^r`. `-x` is clamped to `[-87.0, 88.0]` first, so the exponential stays
/// finite and the output saturates towards `0` / `1` for large `|x|`. The
/// remaining 0-3 elements use `f32::exp`.
///
/// Elements of `result` beyond `a.len()` are left untouched.
///
/// # Panics
///
/// Panics if `result` is shorter than `a`.
///
/// # Safety
///
/// Caller must ensure SSE2 is available on the current CPU.
#[inline]
#[target_feature(enable = "sse2")]
pub(crate) unsafe fn sigmoid(a: &[f32], result: &mut [f32]) {
    // The vector loop stores through a raw pointer, so the output length must
    // be validated before any store happens.
    assert!(result.len() >= a.len(), "result must be at least as long as the input");
    // SAFETY: SSE2 availability is the caller's obligation. Every 4-lane
    // load/store starts at `i` with `i + 4 <= a.len() <= result.len()`, and the
    // unaligned load/store intrinsics have no alignment requirement.
    unsafe {
        let len = a.len();
        let mut i = 0;
        let one = _mm_set1_ps(1.0);
        let neg_one = _mm_set1_ps(-1.0);
        let ln2 = _mm_set1_ps(std::f32::consts::LN_2);
        let inv_ln2 = _mm_set1_ps(1.0 / std::f32::consts::LN_2);
        // Taylor coefficients 1/2!, 1/3!, 1/4!, 1/5!.
        let c2 = _mm_set1_ps(0.5);
        let c3 = _mm_set1_ps(0.166_666_67);
        let c4 = _mm_set1_ps(0.041_666_668);
        let c5 = _mm_set1_ps(0.008_333_334);
        // Keeps k + 127 inside the normal exponent range [1, 254].
        let lo = _mm_set1_ps(-87.0);
        let hi = _mm_set1_ps(88.0);
        let bias = _mm_set1_epi32(127);
        while i + 4 <= len {
            let x = _mm_loadu_ps(a.as_ptr().add(i));
            // Bounds are the first operand so that NaN lanes (returned as the
            // second operand by min/max) are preserved.
            let neg_x = _mm_max_ps(lo, _mm_min_ps(hi, _mm_mul_ps(x, neg_one)));
            // Range reduction: k = round(-x / ln 2), r = -x - k * ln 2.
            let k = _mm_cvtps_epi32(_mm_mul_ps(neg_x, inv_ln2));
            let kf = _mm_cvtepi32_ps(k);
            let r = _mm_sub_ps(neg_x, _mm_mul_ps(kf, ln2));
            // Horner form of 1 + r + r^2/2 + r^3/6 + r^4/24 + r^5/120.
            let mut poly = _mm_add_ps(c4, _mm_mul_ps(r, c5));
            poly = _mm_add_ps(c3, _mm_mul_ps(r, poly));
            poly = _mm_add_ps(c2, _mm_mul_ps(r, poly));
            poly = _mm_add_ps(one, _mm_mul_ps(r, poly));
            poly = _mm_add_ps(one, _mm_mul_ps(r, poly));
            // 2^k built by writing the biased exponent into the float bits.
            let exp_k = _mm_castsi128_ps(_mm_slli_epi32(_mm_add_epi32(k, bias), 23));
            let exp_neg_x = _mm_mul_ps(poly, exp_k);
            let sig = _mm_div_ps(one, _mm_add_ps(one, exp_neg_x));
            _mm_storeu_ps(result.as_mut_ptr().add(i), sig);
            i += 4;
        }
        // Scalar tail for the last len % 4 elements.
        for j in i..len {
            result[j] = 1.0 / (1.0 + (-a[j]).exp());
        }
    }
}
108
109/// SSE2 GELU activation.
110#[inline]
111#[target_feature(enable = "sse2")]
112// SAFETY: caller ensures preconditions are met for this unsafe function
113pub(crate) unsafe fn gelu(a: &[f32], result: &mut [f32]) {
114    unsafe {
115        // GELU(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
116        let len = a.len();
117        let mut i = 0;
118        let half = _mm_set1_ps(0.5);
119        let one = _mm_set1_ps(1.0);
120        let sqrt_2_pi = _mm_set1_ps(0.797_884_56);
121        let coeff = _mm_set1_ps(0.044_715);
122        while i + 4 <= len {
123            let x = _mm_loadu_ps(a.as_ptr().add(i));
124            let x3 = _mm_mul_ps(_mm_mul_ps(x, x), x);
125            let inner = _mm_mul_ps(sqrt_2_pi, _mm_add_ps(x, _mm_mul_ps(coeff, x3)));
126            // tanh approximation: (e^2x - 1) / (e^2x + 1)
127            let two_inner = _mm_add_ps(inner, inner);
128            let exp_2x = exp_approx_sse2(two_inner);
129            let tanh_val = _mm_div_ps(_mm_sub_ps(exp_2x, one), _mm_add_ps(exp_2x, one));
130            _mm_storeu_ps(
131                result.as_mut_ptr().add(i),
132                _mm_mul_ps(half, _mm_mul_ps(x, _mm_add_ps(one, tanh_val))),
133            );
134            i += 4;
135        }
136        for j in i..len {
137            let x = a[j];
138            result[j] = 0.5
139                * x
140                * (1.0 + ((0.797_884_56 * (x + 0.044_715 * x * x * x)) as f64).tanh() as f32);
141        }
142    }
143}
144
/// Computes swish (SiLU) `result[i] = a[i] * sigmoid(a[i])` for every
/// `i < a.len()`.
///
/// Full groups of four elements are evaluated with SSE2: `e^(-x)` is obtained
/// by range reduction (`-x = k * ln(2) + r`) and the degree-5 Taylor polynomial
/// of `e^r`. `-x` is clamped to `[-87.0, 88.0]` first, so the exponential stays
/// finite and the sigmoid factor saturates towards `0` / `1` for large `|x|`.
/// The remaining 0-3 elements use `f32::exp`.
///
/// Elements of `result` beyond `a.len()` are left untouched.
///
/// # Panics
///
/// Panics if `result` is shorter than `a`.
///
/// # Safety
///
/// Caller must ensure SSE2 is available on the current CPU.
#[inline]
#[target_feature(enable = "sse2")]
pub(crate) unsafe fn swish(a: &[f32], result: &mut [f32]) {
    // The vector loop stores through a raw pointer, so the output length must
    // be validated before any store happens.
    assert!(result.len() >= a.len(), "result must be at least as long as the input");
    // SAFETY: SSE2 availability is the caller's obligation. Every 4-lane
    // load/store starts at `i` with `i + 4 <= a.len() <= result.len()`, and the
    // unaligned load/store intrinsics have no alignment requirement.
    unsafe {
        let len = a.len();
        let mut i = 0;
        let one = _mm_set1_ps(1.0);
        let neg_one = _mm_set1_ps(-1.0);
        let ln2 = _mm_set1_ps(std::f32::consts::LN_2);
        let inv_ln2 = _mm_set1_ps(1.0 / std::f32::consts::LN_2);
        // Taylor coefficients 1/2!, 1/3!, 1/4!, 1/5!.
        let c2 = _mm_set1_ps(0.5);
        let c3 = _mm_set1_ps(0.166_666_67);
        let c4 = _mm_set1_ps(0.041_666_668);
        let c5 = _mm_set1_ps(0.008_333_334);
        // Keeps k + 127 inside the normal exponent range [1, 254].
        let lo = _mm_set1_ps(-87.0);
        let hi = _mm_set1_ps(88.0);
        let bias = _mm_set1_epi32(127);
        while i + 4 <= len {
            let x = _mm_loadu_ps(a.as_ptr().add(i));
            // Bounds are the first operand so that NaN lanes (returned as the
            // second operand by min/max) are preserved.
            let neg_x = _mm_max_ps(lo, _mm_min_ps(hi, _mm_mul_ps(x, neg_one)));
            // Range reduction: k = round(-x / ln 2), r = -x - k * ln 2.
            let k = _mm_cvtps_epi32(_mm_mul_ps(neg_x, inv_ln2));
            let kf = _mm_cvtepi32_ps(k);
            let r = _mm_sub_ps(neg_x, _mm_mul_ps(kf, ln2));
            // Horner form of 1 + r + r^2/2 + r^3/6 + r^4/24 + r^5/120.
            let mut poly = _mm_add_ps(c4, _mm_mul_ps(r, c5));
            poly = _mm_add_ps(c3, _mm_mul_ps(r, poly));
            poly = _mm_add_ps(c2, _mm_mul_ps(r, poly));
            poly = _mm_add_ps(one, _mm_mul_ps(r, poly));
            poly = _mm_add_ps(one, _mm_mul_ps(r, poly));
            // 2^k built by writing the biased exponent into the float bits.
            let exp_k = _mm_castsi128_ps(_mm_slli_epi32(_mm_add_epi32(k, bias), 23));
            let exp_neg_x = _mm_mul_ps(poly, exp_k);
            let sig = _mm_div_ps(one, _mm_add_ps(one, exp_neg_x));
            _mm_storeu_ps(result.as_mut_ptr().add(i), _mm_mul_ps(x, sig));
            i += 4;
        }
        // Scalar tail for the last len % 4 elements.
        for j in i..len {
            result[j] = a[j] / (1.0 + (-a[j]).exp());
        }
    }
}
184
185/// SSE2 tanh activation.
186#[inline]
187#[target_feature(enable = "sse2")]
188// SAFETY: caller ensures preconditions are met for this unsafe function
189pub(crate) unsafe fn tanh(a: &[f32], result: &mut [f32]) {
190    unsafe {
191        // tanh(x) = (e^2x - 1) / (e^2x + 1)
192        let len = a.len();
193        let mut i = 0;
194        let one = _mm_set1_ps(1.0);
195        let two = _mm_set1_ps(2.0);
196        while i + 4 <= len {
197            let x = _mm_loadu_ps(a.as_ptr().add(i));
198            let exp_2x = exp_approx_sse2(_mm_mul_ps(two, x));
199            _mm_storeu_ps(
200                result.as_mut_ptr().add(i),
201                _mm_div_ps(_mm_sub_ps(exp_2x, one), _mm_add_ps(exp_2x, one)),
202            );
203            i += 4;
204        }
205        for j in i..len {
206            let exp_2x = (2.0 * a[j]).exp();
207            result[j] = (exp_2x - 1.0) / (exp_2x + 1.0);
208        }
209    }
210}