1#![allow(clippy::excessive_precision)]
2
3use rten_simd::ops::{FloatOps, NumOps};
4use rten_simd::{Isa, SimdUnaryOp};
5
6use crate::Exp;
7
/// Vectorized hyperbolic tangent, `tanh(x)`, for `f32` values.
///
/// The computation lives in the [`SimdUnaryOp`] implementation; apply it to
/// slices with e.g. `Tanh {}.map(xs, ys)`.
#[derive(Clone, Copy, Debug, Default)]
pub struct Tanh {}
11
12impl SimdUnaryOp<f32> for Tanh {
13 #[inline(always)]
14 fn eval<I: Isa>(&self, isa: I, x: I::F32) -> I::F32 {
15 let ops = isa.f32();
16
17 let x_negative = ops.le(x, ops.zero());
18 let abs_x = ops.abs(x);
19
20 let x_cutoff = ops.ge(abs_x, ops.splat(9.02));
22
23 let x_tiny = ops.le(abs_x, ops.splat(0.0004));
25
26 let x_small = ops.le(abs_x, ops.splat(0.55));
29
30 const P1: f32 = 0.999999940395355224609375;
33 const P3: f32 = -0.33332359790802001953125;
34 const P5: f32 = 0.13310669362545013427734375;
35 const P7: f32 = -5.21197654306888580322265625e-2;
36 const P9: f32 = 1.5497927553951740264892578125e-2;
37
38 let p1 = ops.splat(P1);
39 let p3 = ops.splat(P3);
40 let p5 = ops.splat(P5);
41 let p7 = ops.splat(P7);
42 let p9 = ops.splat(P9);
43
44 let x_sqr = ops.mul(x, x);
45 let y_small = ops.mul_add(p9, x_sqr, p7);
46 let y_small = ops.mul_add(y_small, x_sqr, p5);
47 let y_small = ops.mul_add(y_small, x_sqr, p3);
48 let y_small = ops.mul_add(y_small, x_sqr, p1);
49 let y_small = ops.mul(y_small, abs_x);
50
51 let x2 = ops.mul(abs_x, ops.splat(2.0));
53 let exp_2x = Exp::apply(isa, x2);
54 let exp_2x_m1 = ops.sub(exp_2x, ops.one());
55 let exp_2x_p1 = ops.add(exp_2x, ops.one());
56 let y_medium = ops.div(exp_2x_m1, exp_2x_p1);
57
58 let y = ops.select(ops.one(), y_medium, x_cutoff);
60 let y = ops.select(y_small, y, x_small);
61 let y = ops.select(abs_x, y, x_tiny);
62
63 ops.select(ops.neg(y), y, x_negative)
65 }
66}
67
#[cfg(test)]
mod tests {
    use rten_simd::SimdUnaryOp;

    use crate::Tanh;
    use crate::testing::{AllF32s, Tolerance, UnaryOpTester, arange, benchmark_op};

    /// Maximum accepted error, in ULPs, relative to `f32::tanh`.
    const MAX_TANH_ERROR_ULPS: f32 = 3.0;

    /// Checks every `f32` value. Slow, so it only runs when ignored tests are
    /// requested.
    #[test]
    #[ignore]
    fn test_tanh_exhaustive() {
        let test = UnaryOpTester {
            reference: f32::tanh,
            simd: Tanh {},
            range: AllF32s::new(),
            tolerance: Tolerance::Ulp(MAX_TANH_ERROR_ULPS),
        };
        test.run_with_progress();
    }

    /// Dense sweep over the polynomial and exp-based regions.
    #[test]
    fn test_tanh() {
        let test = UnaryOpTester {
            reference: f32::tanh,
            simd: Tanh {},
            range: arange(-8., 8., 0.001),
            tolerance: Tolerance::Ulp(MAX_TANH_ERROR_ULPS),
        };
        test.run();
    }

    /// Covers both sides of the saturation cutoff (|x| = 9.02). The range
    /// used by `test_tanh` stops at |x| = 8 and never reaches it.
    #[test]
    fn test_tanh_saturation() {
        for range in [arange(-12., -8., 0.001), arange(8., 12., 0.001)] {
            let test = UnaryOpTester {
                reference: f32::tanh,
                simd: Tanh {},
                range,
                tolerance: Tolerance::Ulp(MAX_TANH_ERROR_ULPS),
            };
            test.run();
        }
    }

    /// Covers the tiny-input branch (|x| <= 0.0004) and its boundary with the
    /// polynomial branch. The 0.001 step of `test_tanh` skips over this
    /// region. Exact zero is deliberately excluded.
    #[test]
    fn test_tanh_tiny() {
        for range in [
            arange(-0.001, -0.0001, 0.00001),
            arange(0.0001, 0.001, 0.00001),
        ] {
            let test = UnaryOpTester {
                reference: f32::tanh,
                simd: Tanh {},
                range,
                tolerance: Tolerance::Ulp(MAX_TANH_ERROR_ULPS),
            };
            test.run();
        }
    }

    /// Compares throughput of the SIMD implementation against scalar
    /// `f32::tanh`.
    #[test]
    #[ignore]
    fn bench_tanh() {
        benchmark_op(
            |xs, ys| {
                xs.iter()
                    .zip(ys.iter_mut())
                    .for_each(|(x, y)| *y = x.tanh())
            },
            |xs, ys| {
                Tanh {}.map(xs, ys);
            },
        );
    }
}