1#![allow(clippy::excessive_precision)]
2
3use rten_simd::ops::{FloatOps, NumOps};
4use rten_simd::{Isa, Simd, SimdUnaryOp};
5
6use crate::Exp;
7
/// Stateless SIMD unary operation that evaluates an approximation of the
/// hyperbolic tangent on `f32` lanes.
///
/// The type has no fields; it is constructed as `Tanh {}` and used through
/// its `SimdUnaryOp<f32>` implementation.
pub struct Tanh {}
10
11impl SimdUnaryOp<f32> for Tanh {
12 #[inline(always)]
13 fn eval<I: Isa, S: Simd<Elem = f32, Isa = I>>(&self, isa: I, x: S) -> S {
14 let ops = isa.f32();
15 let x = x.same_cast();
16
17 let x_negative = ops.le(x, ops.zero());
18 let abs_x = ops.abs(x);
19
20 let x_cutoff = ops.ge(abs_x, ops.splat(9.02));
22
23 let x_tiny = ops.le(abs_x, ops.splat(0.0004));
25
26 let x_small = ops.le(abs_x, ops.splat(0.55));
29
30 const P1: f32 = 0.999999940395355224609375;
33 const P3: f32 = -0.33332359790802001953125;
34 const P5: f32 = 0.13310669362545013427734375;
35 const P7: f32 = -5.21197654306888580322265625e-2;
36 const P9: f32 = 1.5497927553951740264892578125e-2;
37
38 let p1 = ops.splat(P1);
39 let p3 = ops.splat(P3);
40 let p5 = ops.splat(P5);
41 let p7 = ops.splat(P7);
42 let p9 = ops.splat(P9);
43
44 let x_sqr = ops.mul(x, x);
45 let y_small = ops.mul_add(p9, x_sqr, p7);
46 let y_small = ops.mul_add(y_small, x_sqr, p5);
47 let y_small = ops.mul_add(y_small, x_sqr, p3);
48 let y_small = ops.mul_add(y_small, x_sqr, p1);
49 let y_small = ops.mul(y_small, abs_x);
50
51 let x2 = ops.mul(abs_x, ops.splat(2.0));
53 let exp_2x = Exp::apply(isa, x2);
54 let exp_2x_m1 = ops.sub(exp_2x, ops.one());
55 let exp_2x_p1 = ops.add(exp_2x, ops.one());
56 let y_medium = ops.div(exp_2x_m1, exp_2x_p1);
57
58 let y = ops.select(ops.one(), y_medium, x_cutoff);
60 let y = ops.select(y_small, y, x_small);
61 let y = ops.select(abs_x, y, x_tiny);
62
63 ops.select(ops.neg(y), y, x_negative).same_cast()
65 }
66}
67
#[cfg(test)]
mod tests {
    use rten_simd::SimdUnaryOp;

    use crate::testing::{
        arange, benchmark_op, check_f32s_are_equal_ulps, check_with_all_f32s, AsUninit,
    };
    use crate::Tanh;

    /// Largest difference, in ULPs, tolerated between the vectorized result
    /// and `f32::tanh`.
    const MAX_TANH_ERROR_ULPS: f32 = 3.0;

    /// Compare the vectorized result against `f32::tanh` via
    /// `check_with_all_f32s`. Not run by default.
    #[test]
    #[ignore]
    fn test_tanh_exhaustive() {
        let eval_pair = |x: f32| {
            let mut out = [0.; 1];
            Tanh {}.map(&[x], out.as_mut().as_uninit());
            (out[0], x.tanh())
        };
        check_with_all_f32s(eval_pair, MAX_TANH_ERROR_ULPS, "testing vec_tanh");
    }

    /// Compare against `f32::tanh` on a grid over [-8, 8) with step 0.001.
    #[test]
    fn test_tanh() {
        let inputs: Vec<f32> = arange(-8., 8., 0.001f32).collect();

        // Start from a copy of the inputs; `map` overwrites every element.
        let mut outputs = inputs.clone();
        Tanh {}.map(&inputs, outputs.as_mut_slice().as_uninit());

        // Triples of (input, actual, expected).
        let triples = inputs
            .iter()
            .zip(outputs.iter())
            .map(|(&x, &actual)| (x, actual, x.tanh()));
        check_f32s_are_equal_ulps(triples, MAX_TANH_ERROR_ULPS);
    }

    /// Benchmark the scalar `f32::tanh` loop against the vectorized
    /// implementation. Not run by default.
    #[test]
    #[ignore]
    fn bench_tanh() {
        benchmark_op(
            |xs, ys| {
                for (x, y) in xs.iter().zip(ys.iter_mut()) {
                    *y = x.tanh();
                }
            },
            |xs, ys| Tanh {}.map(xs, ys),
        );
    }
}