// Several `f32` constants in this file (the erf polynomial coefficients and
// `SQRT_2_PI`) are written with more decimal digits than an `f32` can hold,
// which is what `clippy::excessive_precision` reports. The extra digits are
// kept on purpose, so the lint is silenced for the whole file.
#![allow(clippy::excessive_precision)]
4
5use std::f32::consts::SQRT_2;
6
7use rten_simd::ops::{FloatOps, NumOps};
8use rten_simd::{Isa, SimdUnaryOp};
9
10use crate::exp::ReducedRangeExp;
11use crate::tanh::Tanh;
12
/// Vectorized approximation of the error function (`erf`) for `f32` lanes.
///
/// This is a stateless marker type; the computation lives in its
/// [`SimdUnaryOp`] implementation. The tests in this file compare it against
/// `libm::erff` using an absolute tolerance of `6.631017e-7`.
#[derive(Clone, Copy, Debug, Default)]
pub struct Erf {}
22
23impl SimdUnaryOp<f32> for Erf {
24 #[inline(always)]
25 fn eval<I: Isa>(&self, isa: I, x: I::F32) -> I::F32 {
26 let ops = isa.f32();
27
28 let neg_mask = ops.lt(x, ops.zero());
29
30 let x = ops.abs(x);
31
32 let p = ops.splat(0.3275911);
33
34 let a0 = ops.splat(0.254829592);
36 let a1 = ops.splat(-0.284496736);
37 let a2 = ops.splat(1.421413741);
38 let a3 = ops.splat(-1.453152027);
39 let a4 = ops.splat(1.061405429);
40
41 let t = ops.reciprocal(ops.mul_add(x, p, ops.one()));
43 let at = ops.poly_eval(t, &[a0, a1, a2, a3, a4]);
44
45 let x_m2 = ops.neg(ops.mul(x, x));
48 let exp_mx2 = ReducedRangeExp::apply(isa, x_m2);
49
50 let y = ops.sub(ops.one(), ops.mul(at, exp_mx2));
52
53 ops.select(ops.neg(y), y, neg_mask)
56 }
57}
58
/// Reciprocal of the square root of two, `1 / sqrt(2)`.
///
/// `Gelu` multiplies its input by this value before calling `Erf`.
const SQRT_2_RCP: f32 = 1.0 / SQRT_2;
60
/// Vectorized GELU activation for `f32` lanes.
///
/// Computes `0.5 * x * (1 + erf(x / sqrt(2)))`, using [`Erf`] for the error
/// function. This is a stateless marker type; the computation lives in its
/// [`SimdUnaryOp`] implementation.
#[derive(Clone, Copy, Debug, Default)]
pub struct Gelu {}
64
65impl SimdUnaryOp<f32> for Gelu {
66 #[inline(always)]
67 fn eval<I: Isa>(&self, isa: I, x: I::F32) -> I::F32 {
68 let ops = isa.f32();
69
70 let half_x = ops.mul(x, ops.splat(0.5));
71 let sqrt_2_rcp = ops.splat(SQRT_2_RCP);
72 let y = ops.mul(x, sqrt_2_rcp);
73 let y = ops.add(Erf::apply(isa, y), ops.splat(1.0));
74 ops.mul(half_x, y)
75 }
76}
77
/// Square root of `2 / pi` (not `sqrt(2 * pi)`), the scale factor applied to
/// the `tanh` argument in `ApproxGelu`.
const SQRT_2_PI: f32 = 0.7978845608028654;
80
/// Vectorized tanh-based approximation of GELU for `f32` lanes.
///
/// Computes `0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3)))`.
/// This is a stateless marker type; the computation lives in its
/// [`SimdUnaryOp`] implementation.
#[derive(Clone, Copy, Debug, Default)]
pub struct ApproxGelu {}
85
86impl SimdUnaryOp<f32> for ApproxGelu {
87 #[inline(always)]
88 fn eval<I: Isa>(&self, isa: I, x: I::F32) -> I::F32 {
89 let ops = isa.f32();
90
91 let half_x = ops.mul(x, ops.splat(0.5));
93 let x_cubed = ops.mul(ops.mul(x, x), x);
94 let y = ops.mul_add(x_cubed, ops.splat(0.044715), x);
95 let y = ops.mul(y, ops.splat(SQRT_2_PI));
96 let y = Tanh::apply(isa, y);
97 let y = ops.add(y, ops.splat(1.));
98 ops.mul(half_x, y)
99 }
100}
101
#[cfg(test)]
mod tests {
    use rten_simd::SimdUnaryOp;

    use super::{ApproxGelu, Erf, Gelu};
    use crate::testing::{AllF32s, Tolerance, UnaryOpTester, arange, benchmark_op};

    /// Scalar GELU reference built on `libm::erff`.
    fn reference_gelu(x: f32) -> f32 {
        let erf = libm::erff(x / (2.0f32).sqrt());
        0.5 * x * (1. + erf)
    }

    /// Scalar reference for the tanh-based GELU approximation.
    fn reference_approx_gelu(x: f32) -> f32 {
        let sqrt_2_over_pi = (2.0f32 / std::f32::consts::PI).sqrt();
        let x_cubed = x * x * x;
        let tanh_term = (sqrt_2_over_pi * (x + 0.044715 * x_cubed)).tanh();
        0.5 * x * (1. + tanh_term)
    }

    /// Absolute tolerance used when comparing `Erf` and `Gelu` against their
    /// `libm::erff`-based references.
    const MAX_EXPECTED_DIFF: f32 = 6.631017e-7;

    /// Check `Erf` against `libm::erff` on a grid over [-6, 6).
    #[test]
    fn test_erf() {
        UnaryOpTester {
            reference: libm::erff,
            simd: Erf {},
            range: arange(-6., 6., 0.001),
            tolerance: Tolerance::Absolute(MAX_EXPECTED_DIFF),
        }
        .run();
    }

    /// Check `Erf` against `libm::erff` over the `AllF32s` range. Ignored by
    /// default; run explicitly with `--ignored`.
    #[test]
    #[ignore]
    fn test_erf_exhaustive() {
        UnaryOpTester {
            reference: libm::erff,
            simd: Erf {},
            range: AllF32s::new(),
            tolerance: Tolerance::Absolute(MAX_EXPECTED_DIFF),
        }
        .run_with_progress();
    }

    /// Check `Gelu` against the scalar reference on a grid over [-6, 6).
    #[test]
    fn test_gelu() {
        UnaryOpTester {
            reference: reference_gelu,
            simd: Gelu {},
            range: arange(-6., 6., 0.001),
            tolerance: Tolerance::Absolute(MAX_EXPECTED_DIFF),
        }
        .run();
    }

    /// Check `ApproxGelu` against the scalar reference on a grid over [-6, 6).
    #[test]
    fn test_approx_gelu() {
        UnaryOpTester {
            reference: reference_approx_gelu,
            simd: ApproxGelu {},
            range: arange(-6., 6., 0.001),
            tolerance: Tolerance::Absolute(5e-7),
        }
        .run();
    }

    /// Benchmark the scalar `libm::erff` loop against `Erf`. Ignored by
    /// default; run explicitly with `--ignored`.
    #[test]
    #[ignore]
    fn bench_erf() {
        benchmark_op(
            |xs, ys| {
                for (x, y) in xs.iter().zip(ys.iter_mut()) {
                    *y = libm::erff(*x);
                }
            },
            |xs, ys| {
                Erf {}.map(xs, ys);
            },
        );
    }
}