// rten_vecmath/tanh.rs

1#![allow(clippy::excessive_precision)]
2
3use rten_simd::ops::{FloatOps, NumOps};
4use rten_simd::{Isa, SimdUnaryOp};
5
6use crate::Exp;
7
/// Computes the hyperbolic tangent function.
///
/// `Tanh` carries no state: it is a zero-sized marker whose behavior comes
/// from its `SimdUnaryOp<f32>` implementation. Because it has no fields it is
/// trivially `Copy`, and it derives `Debug` so that types embedding it can
/// derive `Debug` as well.
#[derive(Clone, Copy, Debug, Default)]
pub struct Tanh {}
11
12impl SimdUnaryOp<f32> for Tanh {
13    #[inline(always)]
14    fn eval<I: Isa>(&self, isa: I, x: I::F32) -> I::F32 {
15        let ops = isa.f32();
16
17        let x_negative = ops.le(x, ops.zero());
18        let abs_x = ops.abs(x);
19
20        // Cutoff beyond which `f32::tanh(x)` saturates at +/- 1.0.
21        let x_cutoff = ops.ge(abs_x, ops.splat(9.02));
22
23        // tanh(x) ~ x when |x| is very small.
24        let x_tiny = ops.le(abs_x, ops.splat(0.0004));
25
26        // Threshold below which `tanh(x) = (exp(2x) - 1) / (exp(2x) + 1)` method
27        // produces errors >= 2 ULP.
28        let x_small = ops.le(abs_x, ops.splat(0.55));
29
30        // For small x, use polynomial approximation. Computed using Sollya with
31        // `P = fpminimax(f, [|1, 3, 5, 7, 9|], [|SG...|], [0, 0.6])`.
32        const P1: f32 = 0.999999940395355224609375;
33        const P3: f32 = -0.33332359790802001953125;
34        const P5: f32 = 0.13310669362545013427734375;
35        const P7: f32 = -5.21197654306888580322265625e-2;
36        const P9: f32 = 1.5497927553951740264892578125e-2;
37
38        let p1 = ops.splat(P1);
39        let p3 = ops.splat(P3);
40        let p5 = ops.splat(P5);
41        let p7 = ops.splat(P7);
42        let p9 = ops.splat(P9);
43
44        let x_sqr = ops.mul(x, x);
45        let y_small = ops.mul_add(p9, x_sqr, p7);
46        let y_small = ops.mul_add(y_small, x_sqr, p5);
47        let y_small = ops.mul_add(y_small, x_sqr, p3);
48        let y_small = ops.mul_add(y_small, x_sqr, p1);
49        let y_small = ops.mul(y_small, abs_x);
50
51        // For medium x, compute `tanh(x) = (exp(2x) - 1) / (exp(2x) + 1)`.
52        let x2 = ops.mul(abs_x, ops.splat(2.0));
53        let exp_2x = Exp::apply(isa, x2);
54        let exp_2x_m1 = ops.sub(exp_2x, ops.one());
55        let exp_2x_p1 = ops.add(exp_2x, ops.one());
56        let y_medium = ops.div(exp_2x_m1, exp_2x_p1);
57
58        // Select output to use depending on |x|.
59        let y = ops.select(ops.one(), y_medium, x_cutoff);
60        let y = ops.select(y_small, y, x_small);
61        let y = ops.select(abs_x, y, x_tiny);
62
63        // Flip sign if input was negative.
64        ops.select(ops.neg(y), y, x_negative)
65    }
66}
67
#[cfg(test)]
mod tests {
    use rten_simd::SimdUnaryOp;

    use crate::Tanh;
    use crate::testing::{AllF32s, Tolerance, UnaryOpTester, arange, benchmark_op};

    /// Largest deviation from `f32::tanh`, in ULPs, that these tests accept.
    const MAX_TANH_ERROR_ULPS: f32 = 3.0;

    /// Checks `Tanh` against `f32::tanh` over the `AllF32s` range.
    #[test]
    #[ignore] // Ignored by default due to long runtime
    fn test_tanh_exhaustive() {
        UnaryOpTester {
            reference: f32::tanh,
            simd: Tanh {},
            range: AllF32s::new(),
            tolerance: Tolerance::Ulp(MAX_TANH_ERROR_ULPS),
        }
        .run_with_progress();
    }

    /// Checks `Tanh` against `f32::tanh` on a sweep from -8 to 8 with a step
    /// of 0.001.
    #[test]
    fn test_tanh() {
        UnaryOpTester {
            reference: f32::tanh,
            simd: Tanh {},
            range: arange(-8., 8., 0.001),
            tolerance: Tolerance::Ulp(MAX_TANH_ERROR_ULPS),
        }
        .run();
    }

    /// Benchmarks scalar `f32::tanh` against the vectorized `Tanh` operator.
    #[test]
    #[ignore]
    fn bench_tanh() {
        benchmark_op(
            // Reference: element-wise standard library `tanh`.
            |xs, ys| {
                for (x, y) in xs.iter().zip(ys.iter_mut()) {
                    *y = x.tanh();
                }
            },
            // Operator under test.
            |xs, ys| {
                Tanh {}.map(xs, ys);
            },
        );
    }
}