rten_vecmath/
exp.rs

1//! Vectorized version of the exponential and closely related functions.
2
3#![allow(clippy::excessive_precision)]
4
5use rten_simd::ops::{FloatOps, IntOps, NumOps};
6use rten_simd::{Isa, Simd, SimdUnaryOp};
7
const INV_LOG2: f32 = std::f32::consts::LOG2_E; // aka. 1 / ln2

// Added to and then subtracted from `x / ln2` in order to compute
// `rintf(x / ln2)` without a dedicated rounding instruction.
const ROUNDING_MAGIC: f32 = 12582912.; // 0x3 << 22

// `-ln(2)` split into large and small parts for Cody-Waite range reduction.
//
// The values are stored negated so that `r = x - k * ln2` can be evaluated as
// two multiply-adds: `k * LOG2_HI + x` followed by `k * LOG2_LO + r`.
const LOG2_HI: f32 = -6.93145752e-1;
const LOG2_LO: f32 = -1.42860677e-6;

// Coefficients of polynomial used to approximate `exp(r)` for the reduced
// argument `r` in `[-ln2/2, ln2/2]`.
//
// These values are very close to, but not exactly the same as the coefficients
// of the Taylor series around 0 (1, 1/1!, 1/2!, 1/3!, 1/4! ...).
const EXP_POLY_0: f32 = 1.0;
const EXP_POLY_1: f32 = 1.0;
const EXP_POLY_2: f32 = 4.99999851e-1; // ~ 1/2!
const EXP_POLY_3: f32 = 1.66664720e-1; // ~ 1/3! or 1/6
const EXP_POLY_4: f32 = 4.16695364e-2; // ~ 1/4! or 1/24
const EXP_POLY_5: f32 = 8.37312452e-3; // ~ 1/5! or 1/120
const EXP_POLY_6: f32 = 1.37805939e-3; // ~ 1/6! or 1/720
26
/// Vectorized exponential function.
///
/// This has a maximum error of 1 ULP compared to `f32::exp` in the Rust standard
/// library.
///
/// Inputs `>= 104` evaluate to `+inf` and inputs `<= -104` evaluate to `0`.
#[derive(Clone, Copy, Debug, Default)]
pub struct Exp {}
33
34// Implementation based on work by Norbert Juffa in
35// https://forums.developer.nvidia.com/t/a-more-accurate-performance-competitive-implementation-of-expf/47528.
36//
37// See also
38// https://justinwillmert.com/articles/2020/numerically-computing-the-exponential-function-with-polynomial-approximations/.
39//
40// Method outline:
41//
42//  1. Use the identity `exp(a + b) = exp(a) * exp(b)` to reduce the range for
43//     which a polynomial approximation needs to be valid:
44//
45//     ```text
46//        exp(x) = exp(ln2 * k) * exp(r);
47//               = 2**k * exp(r)
48//     ```
49//
50//     Such that `k` is an integer and `|r| <= 1/2 ln 2`.
51//
52//     ```text
53//             k = rintf(x / ln2)
54//             r = x - k * ln 2
55//     ```
56//
57//  2. Compute `exp(r)` using a polynomial approximation.
58//
59//  3. Compute result as `exp(x) = exp(r) * 2**k`. The reconstruction is split
60//     into multiple steps to extend the domain.
61impl SimdUnaryOp<f32> for Exp {
62    #[inline(always)]
63    fn eval<I: Isa>(&self, isa: I, x: I::F32) -> I::F32 {
64        let ops = isa.f32();
65        let int_ops = isa.i32();
66
67        // Load constants
68        let inv_log_2 = ops.splat(INV_LOG2);
69        let rounding_magic = ops.splat(ROUNDING_MAGIC);
70        let ln2_hi = ops.splat(LOG2_HI);
71        let ln2_lo = ops.splat(LOG2_LO);
72
73        let p6 = ops.splat(EXP_POLY_6);
74        let p5 = ops.splat(EXP_POLY_5);
75        let p4 = ops.splat(EXP_POLY_4);
76        let p3 = ops.splat(EXP_POLY_3);
77        let p2 = ops.splat(EXP_POLY_2);
78        let p1 = ops.splat(EXP_POLY_1);
79        let p0 = ops.splat(EXP_POLY_0);
80
81        // Compute `k = rintf(x / ln2), r = x - k * ln2`.
82        let j = ops.mul_add(x, inv_log_2, rounding_magic);
83        let j = ops.sub(j, rounding_magic);
84        let r = ops.mul_add(j, ln2_hi, x);
85        let r = ops.mul_add(j, ln2_lo, r);
86        let k = ops.to_int_trunc(j);
87
88        // Approximate `exp(r)` on interval [-ln2 / 2, +ln2 / 2]
89        let mut tmp = p6;
90        tmp = ops.mul_add(tmp, r, p5);
91        tmp = ops.mul_add(tmp, r, p4);
92        tmp = ops.mul_add(tmp, r, p3);
93        tmp = ops.mul_add(tmp, r, p2);
94        tmp = ops.mul_add(tmp, r, p1);
95        let r = ops.mul_add(tmp, r, p0);
96
97        // Reconstruct `exp(x) = 2**k * exp(r`).
98        //
99        // Reconstruction is split into steps to extend the input domain of the
100        // function. The split reconstruction is effectively:
101        //
102        //   When k > 0:  exp(r) * exp2(127) * exp2(k - 127)
103        //   When k <= 0: exp(r) * exp2(-123) * exp2(k + 123)
104        //
105        // Where 127 is the exponent bias for f32.
106        let ia = int_ops.gt(k, int_ops.zero());
107        let x7f = int_ops.splat(0x7f000000);
108        #[allow(overflowing_literals)]
109        let x83 = int_ops.splat(0x83000000);
110        let ia = int_ops.select(int_ops.zero(), x83, ia);
111        let is = int_ops.add(ia, x7f);
112
113        let it = int_ops.shift_left::<23>(k);
114        let it = int_ops.sub(it, ia);
115
116        let s: I::F32 = is.reinterpret_cast();
117        let t: I::F32 = it.reinterpret_cast();
118        let r = ops.mul(r, s);
119        let r = ops.mul(r, t);
120
121        // Handle overflow and underflow when `x.abs() >= 104.`
122        let overflow_mask = ops.ge(x, ops.splat(104.0));
123        let underflow_mask = ops.le(x, ops.splat(-104.0));
124        let r = ops.select(ops.splat(f32::INFINITY), r, overflow_mask);
125        ops.select(ops.zero(), r, underflow_mask)
126    }
127}
128
/// Cutoff value chosen such that if `k = round(x / ln2)`, `2**k` is a normal
/// number.
///
/// [`ReducedRangeExp`] returns zero for inputs below this cutoff.
const EXP_LOWER_CUTOFF: f32 = -126.5 * std::f32::consts::LN_2 + 0.01; // ~-87.67
132
/// A simplified and faster version of [`Exp`] with a reduced domain and range.
///
/// 1. The input value must be <= 0
/// 2. The lower cutoff below which `exp(x)` returns 0 is higher (~-87.67
///    instead of ~-104).
#[derive(Clone, Copy, Debug, Default)]
pub struct ReducedRangeExp {}
139
140impl SimdUnaryOp<f32> for ReducedRangeExp {
141    #[inline(always)]
142    fn eval<I: Isa>(&self, isa: I, x: I::F32) -> I::F32 {
143        let ops = isa.f32();
144        let int_ops = isa.i32();
145
146        // Load constants
147        let inv_log_2 = ops.splat(INV_LOG2);
148        let rounding_magic = ops.splat(ROUNDING_MAGIC);
149        let ln2_hi = ops.splat(LOG2_HI);
150        let ln2_lo = ops.splat(LOG2_LO);
151
152        let p6 = ops.splat(EXP_POLY_6);
153        let p5 = ops.splat(EXP_POLY_5);
154        let p4 = ops.splat(EXP_POLY_4);
155        let p3 = ops.splat(EXP_POLY_3);
156        let p2 = ops.splat(EXP_POLY_2);
157        let p1 = ops.splat(EXP_POLY_1);
158        let p0 = ops.splat(EXP_POLY_0);
159
160        // Compute `k = rintf(x / ln2), r = x - k * ln2`.
161        //
162        // Since x <= 0, also k <= 0.
163        let j = ops.mul_add(x, inv_log_2, rounding_magic);
164        let j = ops.sub(j, rounding_magic);
165        let r = ops.mul_add(j, ln2_hi, x);
166        let r = ops.mul_add(j, ln2_lo, r);
167        let k = ops.to_int_trunc(j);
168
169        // Approximate `exp(r)` on interval [-ln2 / 2, +ln2 / 2]
170        let mut tmp = p6;
171        tmp = ops.mul_add(tmp, r, p5);
172        tmp = ops.mul_add(tmp, r, p4);
173        tmp = ops.mul_add(tmp, r, p3);
174        tmp = ops.mul_add(tmp, r, p2);
175        tmp = ops.mul_add(tmp, r, p1);
176        let r = ops.mul_add(tmp, r, p0);
177
178        // Reconstruct `exp(x) = 2**k * exp(r)`.
179        //
180        // This is valid as long as `k >= -126`, so that `2**k` as f32 is a
181        // normal number.
182        let exponent_bias = int_ops.splat(127);
183        let k_pow2 = int_ops.shift_left::<23>(int_ops.add(k, exponent_bias));
184        let k_pow2: I::F32 = k_pow2.reinterpret_cast();
185        let r = ops.mul(r, k_pow2);
186
187        // Handle underflow. We don't need to handle overflow since x <= 0.
188        let underflow_mask = ops.lt(x, ops.splat(EXP_LOWER_CUTOFF));
189        ops.select(ops.zero(), r, underflow_mask)
190    }
191}
192
/// Computes the [sigmoid function][sigmoid], aka. the standard logistic function, `1. /
/// (1. + (-x).exp())`.
///
/// This has a maximum error of 4 ULPs compared to a reference implementation
/// using `1. / (1. + (-x).exp())`.
///
/// [sigmoid]: https://en.wikipedia.org/wiki/Logistic_function#Mathematical_properties
#[derive(Clone, Copy, Debug, Default)]
pub struct Sigmoid {}
202
203impl SimdUnaryOp<f32> for Sigmoid {
204    #[inline(always)]
205    fn eval<I: Isa>(&self, isa: I, x: I::F32) -> I::F32 {
206        let ops = isa.f32();
207
208        // 1. + exp(-x)
209        let denom = ops.add(ops.one(), Exp::apply(isa, ops.neg(x)));
210        ops.reciprocal(denom)
211    }
212}
213
/// Vectorized Sigmoid Linear Unit (SiLU) function.
///
/// This computes `x * sigmoid(x)` for all lanes in `x`.
#[derive(Clone, Copy, Debug, Default)]
pub struct Silu {}
218
219impl SimdUnaryOp<f32> for Silu {
220    #[inline(always)]
221    fn eval<I: Isa>(&self, isa: I, x: I::F32) -> I::F32 {
222        let ops = isa.f32();
223
224        // 1. + exp(-x)
225        let denom = ops.add(ops.one(), Exp::apply(isa, ops.neg(x)));
226        ops.div(x, denom)
227    }
228}
229
/// Vectorized Swish function.
///
/// This computes `x * sigmoid(beta * x)` for each element.
#[derive(Clone, Copy, Debug)]
pub struct Swish {
    /// Scale applied to the input before the sigmoid is evaluated.
    pub beta: f32,
}
236
237impl SimdUnaryOp<f32> for Swish {
238    #[inline(always)]
239    fn eval<I: Isa>(&self, isa: I, x: I::F32) -> I::F32 {
240        let ops = isa.f32();
241
242        let beta = ops.splat(self.beta);
243        ops.mul(x, Sigmoid::apply(isa, ops.mul(x, beta)))
244    }
245}
246
/// Computes the Exponential Linear Unit function.
///
/// Computes `if x >= 0 { x } else { alpha * (exp(x) - 1) }`.
#[derive(Clone, Copy, Debug)]
pub struct Elu {
    /// Scale applied to `exp(x) - 1` for negative inputs.
    pub alpha: f32,
}
253
254impl SimdUnaryOp<f32> for Elu {
255    #[inline(always)]
256    fn eval<I: Isa>(&self, isa: I, x: I::F32) -> I::F32 {
257        // The ONNX spec and the original paper [1] define Elu in slightly
258        // different, but equivalent ways:
259        //
260        // Original: `f(x) = x if x > 0 else alpha * (exp(x) - 1)`
261        // ONNX: `f(x) = x if x >= 0 else alpha * (exp(x) - 1)`
262        //
263        // [1] https://arxiv.org/pdf/1511.07289
264
265        let ops = isa.f32();
266        let x_pos = ops.ge(x, ops.zero());
267        let x_exp = ops.mul(
268            ops.splat(self.alpha),
269            ops.sub(Exp::apply(isa, x), ops.splat(1.)),
270        );
271        ops.select(x, x_exp, x_pos)
272    }
273}
274
// Accuracy tests and (ignored by default) benchmarks for the operations
// defined in this module. Each vectorized op is compared against a scalar
// reference implementation built on `f32::exp`.
#[cfg(test)]
mod tests {
    use rten_simd::SimdUnaryOp;

    use super::{EXP_LOWER_CUTOFF, ReducedRangeExp};
    use crate::testing::{AllF32s, Tolerance, UnaryOpTester, arange, benchmark_op};
    use crate::{Elu, Exp, Sigmoid, Silu, Swish};

    // Maximum error of `Exp` compared to Rust standard library implementation.
    const MAX_EXP_ERROR_ULPS: f32 = 1.0;

    // Maximum error of `Sigmoid` compared to reference implementation below.
    // Also used as the tolerance for the `Silu` and `Swish` tests.
    const MAX_SIGMOID_ERROR_ULPS: f32 = 4.0;

    /// Scalar reference for `Elu`, using the ONNX form of the condition.
    fn reference_elu(x: f32, alpha: f32) -> f32 {
        if x >= 0. { x } else { alpha * (x.exp() - 1.) }
    }

    /// Scalar reference for `Sigmoid`.
    fn reference_sigmoid(x: f32) -> f32 {
        1. / (1. + (-x).exp())
    }

    /// Scalar reference for `Silu`.
    fn reference_silu(x: f32) -> f32 {
        x * reference_sigmoid(x)
    }

    /// Scalar reference for `Swish`.
    fn reference_swish(x: f32, beta: f32) -> f32 {
        x * reference_sigmoid(beta * x)
    }

    #[test]
    fn test_exp_basic() {
        // A few simple test cases, including "typical" +/-ve inputs with
        // |x| above/below ln2, zero and values below/above min/max cutoffs.
        //
        // NOTE(review): `0.1` appears twice in this list; the first occurrence
        // was possibly meant to be `-0.1`. Confirm that the exact-equality
        // assertion below holds for `-0.1` before changing it.
        let cases = [-2.0f32, -1., -0.5, 0.1, 0., 0.1, 0.5, 1., 2., -105., 105.];

        let exp_op = Exp {};
        for case in cases {
            let expected = case.exp();
            let actual = exp_op.scalar_eval(case);
            let diff = (expected - actual).abs();

            // Infinite results are compared directly, since the difference of
            // two infinities is NaN.
            if actual.is_infinite() || expected.is_infinite() {
                assert_eq!(actual, expected);
            } else {
                // The expected precision is less than 1 ULP, so the diff should
                // be exactly zero.
                assert_eq!(diff, 0.);
            };
        }
    }

    // Checks `Exp` against `f32::exp` over a dense sample of [-6, 6).
    #[test]
    fn test_exp() {
        let test = UnaryOpTester {
            reference: f32::exp,
            simd: Exp {},
            range: arange(-6., 6., 0.001),
            tolerance: Tolerance::Ulp(MAX_EXP_ERROR_ULPS),
        };
        test.run();
    }

    // Checks `ReducedRangeExp` against `f32::exp` over its supported domain,
    // from the lower cutoff up to zero.
    #[test]
    fn test_reduced_range_exp() {
        let test = UnaryOpTester {
            reference: f32::exp,
            simd: ReducedRangeExp {},
            range: arange(EXP_LOWER_CUTOFF, 0., 0.015),
            tolerance: Tolerance::Ulp(MAX_EXP_ERROR_ULPS),
        };
        test.run();
    }

    // Checks `Elu` on a handful of negative, zero and positive inputs.
    #[test]
    fn test_elu() {
        let alpha = 0.5;
        let test = UnaryOpTester {
            reference: |x| reference_elu(x, alpha),
            simd: Elu { alpha },
            range: [-2., -1., 0., 1., 2.].into_iter(),
            tolerance: Tolerance::Ulp(1.0),
        };
        test.run();
    }

    // Checks `Exp` against `f32::exp` using the `AllF32s` input range.
    #[test]
    #[ignore] // Ignored by default due to long runtime
    fn test_exp_exhaustive() {
        let test = UnaryOpTester {
            reference: f32::exp,
            simd: Exp {},
            range: AllF32s::new(),
            tolerance: Tolerance::Ulp(MAX_EXP_ERROR_ULPS),
        };
        test.run_with_progress();
    }

    // Checks `Sigmoid` against the scalar reference over a dense sample of
    // [-6, 6).
    #[test]
    fn test_sigmoid() {
        let test = UnaryOpTester {
            reference: reference_sigmoid,
            simd: Sigmoid {},
            range: arange(-6., 6., 0.001),
            tolerance: Tolerance::Ulp(MAX_SIGMOID_ERROR_ULPS),
        };
        test.run();
    }

    // Checks `Sigmoid` against the scalar reference using the `AllF32s` input
    // range.
    #[test]
    #[ignore] // Ignored by default due to long runtime
    fn test_sigmoid_exhaustive() {
        let test = UnaryOpTester {
            reference: reference_sigmoid,
            simd: Sigmoid {},
            range: AllF32s::new(),
            tolerance: Tolerance::Ulp(MAX_SIGMOID_ERROR_ULPS),
        };
        test.run_with_progress();
    }

    // Checks `Silu` against the scalar reference, with the same tolerance as
    // `Sigmoid`.
    #[test]
    fn test_silu() {
        let test = UnaryOpTester {
            reference: reference_silu,
            simd: Silu {},
            range: arange(-6., 6., 0.001),
            tolerance: Tolerance::Ulp(MAX_SIGMOID_ERROR_ULPS),
        };
        test.run();
    }

    // Checks `Swish` with a non-unit `beta` against the scalar reference.
    #[test]
    fn test_swish() {
        let beta = 1.7;
        let test = UnaryOpTester {
            reference: |x| reference_swish(x, beta),
            simd: Swish { beta },
            range: arange(-6., 6., 0.001),
            tolerance: Tolerance::Ulp(MAX_SIGMOID_ERROR_ULPS),
        };
        test.run();
    }

    // Benchmarks `Elu` against the scalar reference. Ignored by default.
    #[test]
    #[ignore]
    fn bench_elu() {
        let alpha = 0.5;
        benchmark_op(
            |xs, ys| {
                xs.iter()
                    .zip(ys.iter_mut())
                    .for_each(|(x, y)| *y = reference_elu(*x, alpha))
            },
            |xs, ys| {
                Elu { alpha }.map(xs, ys);
            },
        );
    }

    // Benchmarks `Exp` against `f32::exp`. Ignored by default.
    #[test]
    #[ignore]
    fn bench_exp() {
        benchmark_op(
            |xs, ys| xs.iter().zip(ys.iter_mut()).for_each(|(x, y)| *y = x.exp()),
            |xs, ys| {
                Exp {}.map(xs, ys);
            },
        );
    }

    // Benchmarks `Sigmoid` against the scalar reference. Ignored by default.
    #[test]
    #[ignore]
    fn bench_sigmoid() {
        benchmark_op(
            |xs, ys| {
                xs.iter()
                    .zip(ys.iter_mut())
                    .for_each(|(x, y)| *y = reference_sigmoid(*x))
            },
            |xs, ys| {
                Sigmoid {}.map(xs, ys);
            },
        );
    }
}