rten_vecmath/
exp.rs

1//! Vectorized version of the exponential and closely related functions.
2
3#![allow(clippy::excessive_precision)]
4
5use rten_simd::ops::{FloatOps, IntOps, NumOps};
6use rten_simd::{Isa, Simd, SimdUnaryOp};
7
// Multiplying `x` by this constant yields `x / ln2`, the first step of the
// range reduction `k = rintf(x / ln2)`.
const INV_LOG2: f32 = std::f32::consts::LOG2_E; // aka. 1 / ln2
// Added to and then subtracted from `x / ln2` to obtain the rounded value
// `rintf(x / ln2)` as a float.
const ROUNDING_MAGIC: f32 = 12582912.; // 0x3 << 22

// `-ln(2)` split into large and small parts for Cody-Waite range reduction.
//
// The parts are stored negated so that `r = x - k * ln2` can be evaluated as
// two successive `mul_add` steps: `r = k * LOG2_HI + x`, then
// `r = k * LOG2_LO + r`.
const LOG2_HI: f32 = -6.93145752e-1;
const LOG2_LO: f32 = -1.42860677e-6;

// Coefficients of polynomial used to approximate `exp(r)` for the reduced
// argument `r` in `[-ln2/2, ln2/2]`.
//
// These values are very close to, but not exactly the same as the coefficients
// of the Taylor series around 0 (1, 1/1!, 1/2!, 1/3!, 1/4! ...).
const EXP_POLY_0: f32 = 1.0;
const EXP_POLY_1: f32 = 1.0;
const EXP_POLY_2: f32 = 4.99999851e-1; // ~ 1/2!
const EXP_POLY_3: f32 = 1.66664720e-1; // ~ 1/3! or 1/6
const EXP_POLY_4: f32 = 4.16695364e-2; // ~ 1/4! or 1/24
const EXP_POLY_5: f32 = 8.37312452e-3; // ~ 1/5! or 1/120
const EXP_POLY_6: f32 = 1.37805939e-3; // ~ 1/6! or 1/720
26
27/// Vectorized exponential function.
28///
29/// This has a maximum error of 1 ULP compared to `f32::exp` in the Rust standard
30/// library.
31#[derive(Default)]
32pub struct Exp {}
33
34// Implementation based on work by Norbert Juffa in
35// https://forums.developer.nvidia.com/t/a-more-accurate-performance-competitive-implementation-of-expf/47528.
36//
37// See also
38// https://justinwillmert.com/articles/2020/numerically-computing-the-exponential-function-with-polynomial-approximations/.
39//
40// Method outline:
41//
42//  1. Use the identity `exp(a + b) = exp(a) * exp(b)` to reduce the range for
43//     which a polynomial approximation needs to be valid:
44//
45//     ```text
46//        exp(x) = exp(ln2 * k) * exp(r);
47//               = 2**k * exp(r)
48//     ```
49//
50//     Such that `k` is an integer and `|r| <= 1/2 ln 2`.
51//
52//     ```text
53//             k = rintf(x / ln2)
54//             r = x - k * ln 2
55//     ```
56//
57//  2. Compute `exp(r)` using a polynomial approximation.
58//
59//  3. Compute result as `exp(x) = exp(r) * 2**k`. The reconstruction is split
60//     into multiple steps to extend the domain.
61impl SimdUnaryOp<f32> for Exp {
62    #[inline(always)]
63    fn eval<I: Isa, S: Simd<Elem = f32, Isa = I>>(&self, isa: I, x: S) -> S {
64        let ops = isa.f32();
65        let int_ops = isa.i32();
66
67        let x = x.same_cast();
68
69        // Load constants
70        let inv_log_2 = ops.splat(INV_LOG2);
71        let rounding_magic = ops.splat(ROUNDING_MAGIC);
72        let ln2_hi = ops.splat(LOG2_HI);
73        let ln2_lo = ops.splat(LOG2_LO);
74
75        let p6 = ops.splat(EXP_POLY_6);
76        let p5 = ops.splat(EXP_POLY_5);
77        let p4 = ops.splat(EXP_POLY_4);
78        let p3 = ops.splat(EXP_POLY_3);
79        let p2 = ops.splat(EXP_POLY_2);
80        let p1 = ops.splat(EXP_POLY_1);
81        let p0 = ops.splat(EXP_POLY_0);
82
83        // Compute `k = rintf(x / ln2), r = x - k * ln2`.
84        let j = ops.mul_add(x, inv_log_2, rounding_magic);
85        let j = ops.sub(j, rounding_magic);
86        let r = ops.mul_add(j, ln2_hi, x);
87        let r = ops.mul_add(j, ln2_lo, r);
88        let k = ops.to_int_trunc(j);
89
90        // Approximate `exp(r)` on interval [-ln2 / 2, +ln2 / 2]
91        let mut tmp = p6;
92        tmp = ops.mul_add(tmp, r, p5);
93        tmp = ops.mul_add(tmp, r, p4);
94        tmp = ops.mul_add(tmp, r, p3);
95        tmp = ops.mul_add(tmp, r, p2);
96        tmp = ops.mul_add(tmp, r, p1);
97        let r = ops.mul_add(tmp, r, p0);
98
99        // Reconstruct `exp(x) = 2**k * exp(r`).
100        //
101        // Reconstruction is split into steps to extend the input domain of the
102        // function. The split reconstruction is effectively:
103        //
104        //   When k > 0:  exp(r) * exp2(127) * exp2(k - 127)
105        //   When k <= 0: exp(r) * exp2(-123) * exp2(k + 123)
106        //
107        // Where 127 is the exponent bias for f32.
108        let ia = int_ops.gt(k, int_ops.zero());
109        let x7f = int_ops.splat(0x7f000000);
110        #[allow(overflowing_literals)]
111        let x83 = int_ops.splat(0x83000000);
112        let ia = int_ops.select(int_ops.zero(), x83, ia);
113        let is = int_ops.add(ia, x7f);
114
115        let it = int_ops.shift_left::<23>(k);
116        let it = int_ops.sub(it, ia);
117
118        let s: I::F32 = is.reinterpret_cast();
119        let t: I::F32 = it.reinterpret_cast();
120        let r = ops.mul(r, s);
121        let r = ops.mul(r, t);
122
123        // Handle overflow and underflow when `x.abs() >= 104.`
124        let overflow_mask = ops.ge(x, ops.splat(104.0));
125        let underflow_mask = ops.le(x, ops.splat(-104.0));
126        let r = ops.select(ops.splat(f32::INFINITY), r, overflow_mask);
127        ops.select(ops.zero(), r, underflow_mask).same_cast()
128    }
129}
130
/// Cutoff value chosen such that if `k = round(x / ln2)`, `2**k` is a normal
/// number.
///
/// [`ReducedRangeExp`] returns zero for inputs below this value.
const EXP_LOWER_CUTOFF: f32 = -126.5 * std::f32::consts::LN_2 + 0.01; // ~ -87.67
134
135/// A simplified and faster version of [`Exp`] with a reduced domain and range.
136///
137/// 1. The input value must be <= 0
138/// 2. The lower cutoff for which `exp(x)` returns 0 is higher (~87.67 instead of ~104).
139#[derive(Default)]
140pub struct ReducedRangeExp {}
141
142impl SimdUnaryOp<f32> for ReducedRangeExp {
143    #[inline(always)]
144    fn eval<I: Isa, S: Simd<Elem = f32, Isa = I>>(&self, isa: I, x: S) -> S {
145        let ops = isa.f32();
146        let int_ops = isa.i32();
147
148        let x = x.same_cast();
149
150        // Load constants
151        let inv_log_2 = ops.splat(INV_LOG2);
152        let rounding_magic = ops.splat(ROUNDING_MAGIC);
153        let ln2_hi = ops.splat(LOG2_HI);
154        let ln2_lo = ops.splat(LOG2_LO);
155
156        let p6 = ops.splat(EXP_POLY_6);
157        let p5 = ops.splat(EXP_POLY_5);
158        let p4 = ops.splat(EXP_POLY_4);
159        let p3 = ops.splat(EXP_POLY_3);
160        let p2 = ops.splat(EXP_POLY_2);
161        let p1 = ops.splat(EXP_POLY_1);
162        let p0 = ops.splat(EXP_POLY_0);
163
164        // Compute `k = rintf(x / ln2), r = x - k * ln2`.
165        //
166        // Since x <= 0, also k <= 0.
167        let j = ops.mul_add(x, inv_log_2, rounding_magic);
168        let j = ops.sub(j, rounding_magic);
169        let r = ops.mul_add(j, ln2_hi, x);
170        let r = ops.mul_add(j, ln2_lo, r);
171        let k = ops.to_int_trunc(j);
172
173        // Approximate `exp(r)` on interval [-ln2 / 2, +ln2 / 2]
174        let mut tmp = p6;
175        tmp = ops.mul_add(tmp, r, p5);
176        tmp = ops.mul_add(tmp, r, p4);
177        tmp = ops.mul_add(tmp, r, p3);
178        tmp = ops.mul_add(tmp, r, p2);
179        tmp = ops.mul_add(tmp, r, p1);
180        let r = ops.mul_add(tmp, r, p0);
181
182        // Reconstruct `exp(x) = 2**k * exp(r)`.
183        //
184        // This is valid as long as `k >= -126`, so that `2**k` as f32 is a
185        // normal number.
186        let exponent_bias = int_ops.splat(127);
187        let k_pow2 = int_ops.shift_left::<23>(int_ops.add(k, exponent_bias));
188        let k_pow2: I::F32 = k_pow2.reinterpret_cast();
189        let r = ops.mul(r, k_pow2);
190
191        // Handle underflow. We don't need to handle overflow since x <= 0.
192        let underflow_mask = ops.lt(x, ops.splat(EXP_LOWER_CUTOFF));
193        ops.select(ops.zero(), r, underflow_mask).same_cast()
194    }
195}
196
197/// Computes the [sigmoid function][sigmoid], aka. the standard logistic function, `1. /
198/// (1. + (-x).exp())`.
199///
200/// This has a maximum error of 4 ULPs compared to a reference implementation
201/// using `1. / (1. + (-x).exp())`.
202///
203/// [sigmoid]: https://en.wikipedia.org/wiki/Logistic_function#Mathematical_properties
204#[derive(Default)]
205pub struct Sigmoid {}
206
207impl SimdUnaryOp<f32> for Sigmoid {
208    #[inline(always)]
209    fn eval<I: Isa, S: Simd<Elem = f32, Isa = I>>(&self, isa: I, x: S) -> S {
210        let ops = isa.f32();
211        let x = x.same_cast();
212
213        // 1. + exp(-x)
214        let denom = ops.add(ops.one(), Exp::apply(isa, ops.neg(x)));
215        ops.reciprocal(denom).same_cast()
216    }
217}
218
219/// Vectorized Sigmoid Linear Unit (SiLU) function.
220///
221/// This computes `x * sigmoid(x)` for all lanes in `x`.
222pub struct Silu {}
223
224impl SimdUnaryOp<f32> for Silu {
225    #[inline(always)]
226    fn eval<I: Isa, S: Simd<Elem = f32, Isa = I>>(&self, isa: I, x: S) -> S {
227        let ops = isa.f32();
228        let x = x.same_cast();
229
230        ops.mul(x, Sigmoid::apply(isa, x)).same_cast()
231    }
232}
233
234/// Vectorized Swish function.
235///
236/// This computes `x * sigmoid(beta * x)` for each element.
237pub struct Swish {
238    pub beta: f32,
239}
240
241impl SimdUnaryOp<f32> for Swish {
242    #[inline(always)]
243    fn eval<I: Isa, S: Simd<Elem = f32, Isa = I>>(&self, isa: I, x: S) -> S {
244        let ops = isa.f32();
245        let x = x.same_cast();
246
247        let beta = ops.splat(self.beta);
248        ops.mul(x, Sigmoid::apply(isa, ops.mul(x, beta)))
249            .same_cast()
250    }
251}
252
#[cfg(test)]
mod tests {
    use std::mem::MaybeUninit;

    use rten_simd::SimdUnaryOp;

    use super::{ReducedRangeExp, EXP_LOWER_CUTOFF};
    use crate::testing::{
        arange, benchmark_op, check_f32s_are_equal_ulps, check_with_all_f32s, AsUninit,
    };
    use crate::{Exp, Sigmoid, Silu, Swish};

    // Error bound, in ULPs, for `Exp` relative to the Rust standard library.
    const MAX_EXP_ERROR_ULPS: f32 = 1.0;

    // Error bound, in ULPs, for `Sigmoid` relative to `reference_sigmoid`.
    const MAX_SIGMOID_ERROR_ULPS: f32 = 4.0;

    /// Scalar sigmoid used as the ground truth in tests.
    fn reference_sigmoid(x: f32) -> f32 {
        let denom = 1. + (-x).exp();
        1. / denom
    }

    /// Scalar SiLU built on `reference_sigmoid`.
    fn reference_silu(x: f32) -> f32 {
        let gate = reference_sigmoid(x);
        x * gate
    }

    /// Scalar Swish built on `reference_sigmoid`.
    fn reference_swish(x: f32, beta: f32) -> f32 {
        let gate = reference_sigmoid(beta * x);
        x * gate
    }

    /// Run `simd_op` over all of `values` and compare each output against
    /// `reference_op`, allowing a difference of up to `max_error_ulps`.
    fn check_simd_vs_reference<F, R, I>(
        simd_op: F,
        reference_op: R,
        max_error_ulps: f32,
        values: I,
    ) where
        F: Fn(&[f32], &mut [MaybeUninit<f32>]),
        R: Fn(f32) -> f32,
        I: Iterator<Item = f32>,
    {
        let inputs: Vec<f32> = values.collect();
        let expected: Vec<f32> = inputs.iter().map(|&x| reference_op(x)).collect();

        // The output buffer starts as a copy of the inputs and is overwritten
        // by the SIMD op.
        let mut actual = inputs.clone();
        simd_op(&inputs, actual.as_mut_slice().as_uninit());

        let triples = inputs
            .iter()
            .zip(&actual)
            .zip(&expected)
            .map(|((&x, &got), &want)| (x, got, want));
        check_f32s_are_equal_ulps(triples, max_error_ulps);
    }

    #[test]
    fn test_exp_basic() {
        // A handful of simple inputs: "typical" +/-ve values with |x|
        // above/below ln2, zero, and values beyond the lower/upper cutoffs.
        //
        // NOTE(review): `0.1` is listed twice; the first occurrence was
        // presumably meant to be `-0.1`. Confirm the exact-equality assertion
        // below still holds before changing it.
        let cases = [-2.0f32, -1., -0.5, 0.1, 0., 0.1, 0.5, 1., 2., -105., 105.];

        let exp_op = Exp {};
        for x in cases {
            let expected = x.exp();
            let actual = exp_op.scalar_eval(x);

            if actual.is_infinite() || expected.is_infinite() {
                assert_eq!(actual, expected);
            } else {
                // The expected precision is less than 1 ULP, so the two
                // results should match exactly.
                assert_eq!((expected - actual).abs(), 0.);
            }
        }
    }

    #[test]
    fn test_exp() {
        let inputs = arange(-6., 6., 0.001f32);
        check_simd_vs_reference(
            |src, dest| Exp {}.map(src, dest),
            f32::exp,
            MAX_EXP_ERROR_ULPS,
            inputs,
        );
    }

    #[test]
    fn test_reduced_range_exp() {
        // Cover the whole supported domain, from the lower cutoff up to zero.
        let inputs = arange(EXP_LOWER_CUTOFF, 0., 0.015f32);
        check_simd_vs_reference(
            |src, dest| ReducedRangeExp {}.map(src, dest),
            f32::exp,
            MAX_EXP_ERROR_ULPS,
            inputs,
        );
    }

    #[test]
    #[ignore] // Ignored by default due to long runtime
    fn test_exp_exhaustive() {
        let exp_op = Exp {};

        // Scalar evaluation path.
        check_with_all_f32s(
            |x| {
                let actual = exp_op.scalar_eval(x);
                (actual, x.exp())
            },
            MAX_EXP_ERROR_ULPS,
            "testing exp",
        );

        // Slice-mapping path, one element at a time.
        check_with_all_f32s(
            |x| {
                let mut out = [0.; 1];
                exp_op.map(&[x], out.as_mut().as_uninit());
                (out[0], x.exp())
            },
            MAX_EXP_ERROR_ULPS,
            "testing vec_expf",
        );
    }

    #[test]
    fn test_sigmoid() {
        let inputs = arange(-6., 6., 0.001f32);
        check_simd_vs_reference(
            |src, dest| Sigmoid {}.map(src, dest),
            reference_sigmoid,
            MAX_SIGMOID_ERROR_ULPS,
            inputs,
        );
    }

    #[test]
    #[ignore] // Ignored by default due to long runtime
    fn test_sigmoid_exhaustive() {
        check_with_all_f32s(
            |x| {
                let mut out = [0.; 1];
                Sigmoid {}.map(&[x], out.as_mut().as_uninit());
                (out[0], reference_sigmoid(x))
            },
            MAX_SIGMOID_ERROR_ULPS,
            "testing vec_sigmoid",
        );
    }

    #[test]
    fn test_silu() {
        let inputs = arange(-6., 6., 0.001f32);
        check_simd_vs_reference(
            |src, dest| Silu {}.map(src, dest),
            reference_silu,
            MAX_SIGMOID_ERROR_ULPS,
            inputs,
        );
    }

    #[test]
    fn test_swish() {
        let beta = 1.7;
        let inputs = arange(-6., 6., 0.001f32);
        check_simd_vs_reference(
            |src, dest| Swish { beta }.map(src, dest),
            |x| reference_swish(x, beta),
            MAX_SIGMOID_ERROR_ULPS,
            inputs,
        );
    }

    #[test]
    #[ignore]
    fn bench_exp() {
        benchmark_op(
            // Baseline: standard library `exp` applied element by element.
            |xs, ys| {
                for (x, y) in xs.iter().zip(ys.iter_mut()) {
                    *y = x.exp();
                }
            },
            |xs, ys| Exp {}.map(xs, ys),
        );
    }

    #[test]
    #[ignore]
    fn bench_sigmoid() {
        benchmark_op(
            // Baseline: scalar reference sigmoid applied element by element.
            |xs, ys| {
                for (x, y) in xs.iter().zip(ys.iter_mut()) {
                    *y = reference_sigmoid(*x);
                }
            },
            |xs, ys| Sigmoid {}.map(xs, ys),
        );
    }
}