rten_vecmath/
tanh.rs

1#![allow(clippy::excessive_precision)]
2
3use rten_simd::ops::{FloatOps, NumOps};
4use rten_simd::{Isa, Simd, SimdUnaryOp};
5
6use crate::Exp;
7
/// Vectorized tanh implementation.
///
/// This is a stateless marker type: all of the work happens in its
/// `SimdUnaryOp<f32>` implementation. It derives the common standard traits
/// so that it can be printed, copied and default-constructed like any other
/// zero-sized operator.
#[derive(Debug, Clone, Copy, Default)]
pub struct Tanh {}
10
11impl SimdUnaryOp<f32> for Tanh {
12    #[inline(always)]
13    fn eval<I: Isa, S: Simd<Elem = f32, Isa = I>>(&self, isa: I, x: S) -> S {
14        let ops = isa.f32();
15        let x = x.same_cast();
16
17        let x_negative = ops.le(x, ops.zero());
18        let abs_x = ops.abs(x);
19
20        // Cutoff beyond which `f32::tanh(x)` saturates at +/- 1.0.
21        let x_cutoff = ops.ge(abs_x, ops.splat(9.02));
22
23        // tanh(x) ~ x when |x| is very small.
24        let x_tiny = ops.le(abs_x, ops.splat(0.0004));
25
26        // Threshold below which `tanh(x) = (exp(2x) - 1) / (exp(2x) + 1)` method
27        // produces errors >= 2 ULP.
28        let x_small = ops.le(abs_x, ops.splat(0.55));
29
30        // For small x, use polynomial approximation. Computed using Sollya with
31        // `P = fpminimax(f, [|1, 3, 5, 7, 9|], [|SG...|], [0, 0.6])`.
32        const P1: f32 = 0.999999940395355224609375;
33        const P3: f32 = -0.33332359790802001953125;
34        const P5: f32 = 0.13310669362545013427734375;
35        const P7: f32 = -5.21197654306888580322265625e-2;
36        const P9: f32 = 1.5497927553951740264892578125e-2;
37
38        let p1 = ops.splat(P1);
39        let p3 = ops.splat(P3);
40        let p5 = ops.splat(P5);
41        let p7 = ops.splat(P7);
42        let p9 = ops.splat(P9);
43
44        let x_sqr = ops.mul(x, x);
45        let y_small = ops.mul_add(p9, x_sqr, p7);
46        let y_small = ops.mul_add(y_small, x_sqr, p5);
47        let y_small = ops.mul_add(y_small, x_sqr, p3);
48        let y_small = ops.mul_add(y_small, x_sqr, p1);
49        let y_small = ops.mul(y_small, abs_x);
50
51        // For medium x, compute `tanh(x) = (exp(2x) - 1) / (exp(2x) + 1)`.
52        let x2 = ops.mul(abs_x, ops.splat(2.0));
53        let exp_2x = Exp::apply(isa, x2);
54        let exp_2x_m1 = ops.sub(exp_2x, ops.one());
55        let exp_2x_p1 = ops.add(exp_2x, ops.one());
56        let y_medium = ops.div(exp_2x_m1, exp_2x_p1);
57
58        // Select output to use depending on |x|.
59        let y = ops.select(ops.one(), y_medium, x_cutoff);
60        let y = ops.select(y_small, y, x_small);
61        let y = ops.select(abs_x, y, x_tiny);
62
63        // Flip sign if input was negative.
64        ops.select(ops.neg(y), y, x_negative).same_cast()
65    }
66}
67
#[cfg(test)]
mod tests {
    use rten_simd::SimdUnaryOp;

    use crate::testing::{
        arange, benchmark_op, check_f32s_are_equal_ulps, check_with_all_f32s, AsUninit,
    };
    use crate::Tanh;

    /// Largest tolerated difference, in ULPs, between the vectorized tanh
    /// and `f32::tanh`.
    const MAX_TANH_ERROR_ULPS: f32 = 3.0;

    /// Compare the vectorized tanh against `f32::tanh` via
    /// `check_with_all_f32s`.
    #[test]
    #[ignore] // Ignored by default due to long runtime
    fn test_tanh_exhaustive() {
        let eval_pair = |x: f32| {
            let mut out = [0.0_f32; 1];
            Tanh {}.map(&[x], out.as_mut().as_uninit());
            (out[0], x.tanh())
        };
        check_with_all_f32s(eval_pair, MAX_TANH_ERROR_ULPS, "testing vec_tanh");
    }

    /// Compare the vectorized tanh against `f32::tanh` on the values
    /// produced by `arange(-8., 8., 0.001)`.
    #[test]
    fn test_tanh() {
        let inputs: Vec<f32> = arange(-8., 8., 0.001f32).collect();
        let reference: Vec<f32> = inputs.iter().map(|x| x.tanh()).collect();

        // The output buffer only needs the right length; every element is
        // overwritten by `map`.
        let mut outputs = inputs.clone();
        Tanh {}.map(&inputs, outputs.as_mut_slice().as_uninit());

        let triples = inputs
            .iter()
            .zip(outputs.iter())
            .zip(reference.iter())
            .map(|((x, actual), expected)| (*x, *actual, *expected));
        check_f32s_are_equal_ulps(triples, MAX_TANH_ERROR_ULPS);
    }

    /// Benchmark the scalar `f32::tanh` loop against the vectorized version.
    #[test]
    #[ignore]
    fn bench_tanh() {
        benchmark_op(
            |xs, ys| {
                for (x, y) in xs.iter().zip(ys.iter_mut()) {
                    *y = x.tanh();
                }
            },
            |xs, ys| Tanh {}.map(xs, ys),
        );
    }
}