rten_vecmath/
erf.rs

1//! Error function ("erf") and closely related operations.
2
3#![allow(clippy::excessive_precision)]
4
5use std::f32::consts::SQRT_2;
6
7use rten_simd::ops::{FloatOps, NumOps};
8use rten_simd::{Isa, SimdUnaryOp};
9
10use crate::exp::ReducedRangeExp;
11use crate::tanh::Tanh;
12
13/// Vectorized error function (erf).
14///
15/// The implementation uses an approximation from Abramowitz and Stegun,
16/// see <https://en.wikipedia.org/wiki/Error_function#Approximation_with_elementary_functions>.
17///
18/// This has a maximum absolute error of 6.631017e-7 when comparing to
19/// `libm::erff` as a source of truth.
20#[derive(Default)]
21pub struct Erf {}
22
23impl SimdUnaryOp<f32> for Erf {
24    #[inline(always)]
25    fn eval<I: Isa>(&self, isa: I, x: I::F32) -> I::F32 {
26        let ops = isa.f32();
27
28        let neg_mask = ops.lt(x, ops.zero());
29
30        let x = ops.abs(x);
31
32        let p = ops.splat(0.3275911);
33
34        // Coefficients for polynomial approximation.
35        let a0 = ops.splat(0.254829592);
36        let a1 = ops.splat(-0.284496736);
37        let a2 = ops.splat(1.421413741);
38        let a3 = ops.splat(-1.453152027);
39        let a4 = ops.splat(1.061405429);
40
41        // t = 1. / (1. + p * x);
42        let t = ops.reciprocal(ops.mul_add(x, p, ops.one()));
43        let at = ops.poly_eval(t, &[a0, a1, a2, a3, a4]);
44
45        // exp_mx2 = e^(-x^2). `-(x^2)` is always <= 0, so we can use
46        // reduced-range exp.
47        let x_m2 = ops.neg(ops.mul(x, x));
48        let exp_mx2 = ReducedRangeExp::apply(isa, x_m2);
49
50        // y = 1. - at * exp_mx2;
51        let y = ops.sub(ops.one(), ops.mul(at, exp_mx2));
52
53        // Approximation is valid only for x >= 0. For negative values approximation
54        // can be computed as -erf(-x).
55        ops.select(ops.neg(y), y, neg_mask)
56    }
57}
58
/// Reciprocal of the square root of two, i.e. `1 / sqrt(2)`.
const SQRT_2_RCP: f32 = 1.0_f32 / SQRT_2;
60
61/// Computes the [GELU](https://onnx.ai/onnx/operators/onnx__Gelu.html)
62/// function.
63pub struct Gelu {}
64
65impl SimdUnaryOp<f32> for Gelu {
66    #[inline(always)]
67    fn eval<I: Isa>(&self, isa: I, x: I::F32) -> I::F32 {
68        let ops = isa.f32();
69
70        let half_x = ops.mul(x, ops.splat(0.5));
71        let sqrt_2_rcp = ops.splat(SQRT_2_RCP);
72        let y = ops.mul(x, sqrt_2_rcp);
73        let y = ops.add(Erf::apply(isa, y), ops.splat(1.0));
74        ops.mul(half_x, y)
75    }
76}
77
/// Square root of `2 / pi`.
const SQRT_2_PI: f32 = 0.797_884_560_802_865_4;
80
81/// Approximate Gelu function.
82///
83/// See <https://onnx.ai/onnx/operators/onnx__Gelu.html>.
84pub struct ApproxGelu {}
85
86impl SimdUnaryOp<f32> for ApproxGelu {
87    #[inline(always)]
88    fn eval<I: Isa>(&self, isa: I, x: I::F32) -> I::F32 {
89        let ops = isa.f32();
90
91        // 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
92        let half_x = ops.mul(x, ops.splat(0.5));
93        let x_cubed = ops.mul(ops.mul(x, x), x);
94        let y = ops.mul_add(x_cubed, ops.splat(0.044715), x);
95        let y = ops.mul(y, ops.splat(SQRT_2_PI));
96        let y = Tanh::apply(isa, y);
97        let y = ops.add(y, ops.splat(1.));
98        ops.mul(half_x, y)
99    }
100}
101
#[cfg(test)]
mod tests {
    use rten_simd::SimdUnaryOp;

    use super::{ApproxGelu, Erf, Gelu};
    use crate::testing::{AllF32s, Tolerance, UnaryOpTester, arange, benchmark_op};

    // Largest absolute difference between our erf implementation and
    // `libm::erff`, as found by an exhaustive test.
    //
    // An absolute-difference check is used instead of a ULP comparison: close
    // to zero the ULP difference is large, yet the absolute difference remains
    // small enough for the practical uses this library cares most about.
    const MAX_EXPECTED_DIFF: f32 = 6.631017e-7;

    /// Scalar GELU reference built on `libm::erff`.
    fn reference_gelu(x: f32) -> f32 {
        let erf = libm::erff(x / 2.0_f32.sqrt());
        0.5 * x * (1. + erf)
    }

    /// Scalar reference for the tanh-based GELU approximation.
    fn reference_approx_gelu(x: f32) -> f32 {
        let scale = (2.0f32 / std::f32::consts::PI).sqrt();
        let inner = x + 0.044715 * (x * x * x);
        let approx_erf = (scale * inner).tanh();
        0.5 * x * (1. + approx_erf)
    }

    #[test]
    fn test_erf() {
        UnaryOpTester {
            reference: libm::erff,
            simd: Erf {},
            // Covers both the region where erf is still changing and the
            // regions where it has saturated at +/- 1.
            range: arange(-6., 6., 0.001),
            tolerance: Tolerance::Absolute(MAX_EXPECTED_DIFF),
        }
        .run();
    }

    #[test]
    #[ignore] // Ignored by default due to long runtime
    fn test_erf_exhaustive() {
        UnaryOpTester {
            reference: libm::erff,
            simd: Erf {},
            range: AllF32s::new(),
            tolerance: Tolerance::Absolute(MAX_EXPECTED_DIFF),
        }
        .run_with_progress();
    }

    #[test]
    fn test_gelu() {
        UnaryOpTester {
            reference: reference_gelu,
            simd: Gelu {},
            range: arange(-6., 6., 0.001),
            tolerance: Tolerance::Absolute(MAX_EXPECTED_DIFF),
        }
        .run();
    }

    #[test]
    fn test_approx_gelu() {
        UnaryOpTester {
            reference: reference_approx_gelu,
            simd: ApproxGelu {},
            range: arange(-6., 6., 0.001),
            tolerance: Tolerance::Absolute(5e-7),
        }
        .run();
    }

    #[test]
    #[ignore]
    fn bench_erf() {
        benchmark_op(
            // Scalar baseline: one `libm::erff` call per element.
            |xs, ys| {
                for (x, y) in xs.iter().zip(ys.iter_mut()) {
                    *y = libm::erff(*x);
                }
            },
            // Vectorized implementation under test.
            |xs, ys| {
                Erf {}.map(xs, ys);
            },
        );
    }
}