#![allow(clippy::excessive_precision)]
use std::f32::consts::SQRT_2;
use rten_simd::ops::{FloatOps, NumOps};
use rten_simd::{Isa, SimdUnaryOp};
use crate::exp::ReducedRangeExp;
use crate::tanh::Tanh;
/// Vectorized error function, `erf(x)`.
///
/// The approximation is evaluated on `|x|` and the sign is restored at the end,
/// using the fact that erf is odd (`erf(-x) = -erf(x)`). The constants below
/// are those of Abramowitz & Stegun formula 7.1.26:
///
/// `erf(x) ≈ 1 - (a0*t + a1*t^2 + a2*t^3 + a3*t^4 + a4*t^5) * exp(-x^2)`,
/// with `t = 1 / (1 + p*x)` for `x >= 0`.
#[derive(Default)]
pub struct Erf {}

impl SimdUnaryOp<f32> for Erf {
    #[inline(always)]
    fn eval<I: Isa>(&self, isa: I, x: I::F32) -> I::F32 {
        let ops = isa.f32();

        // Record which lanes are negative, then work on the magnitude only.
        let is_negative = ops.lt(x, ops.zero());
        let abs_x = ops.abs(x);

        // t = 1 / (1 + p * |x|)
        let p = ops.splat(0.3275911);
        let t = ops.reciprocal(ops.mul_add(abs_x, p, ops.one()));

        // NOTE(review): this relies on `poly_eval` computing
        // `c[0]*t + c[1]*t^2 + ... + c[4]*t^5` (no constant term), which is the
        // shape the A&S formula needs — confirm against rten_simd's `poly_eval`.
        let coeffs = [
            ops.splat(0.254829592),
            ops.splat(-0.284496736),
            ops.splat(1.421413741),
            ops.splat(-1.453152027),
            ops.splat(1.061405429),
        ];
        let poly = ops.poly_eval(t, &coeffs);

        // exp(-x^2). The argument is always <= 0 (for non-NaN input); presumably
        // that is what makes the reduced-range exp variant acceptable here.
        let neg_x_sq = ops.neg(ops.mul(abs_x, abs_x));
        let gaussian = ReducedRangeExp::apply(isa, neg_x_sq);

        // erf(|x|) ≈ 1 - poly(t) * exp(-x^2)
        let erf_abs = ops.sub(ops.one(), ops.mul(poly, gaussian));

        // Restore the sign for lanes that were negative.
        ops.select(ops.neg(erf_abs), erf_abs, is_negative)
    }
}
/// `1 / sqrt(2)`, used by [`Gelu`] to compute `erf(x / sqrt(2))` with a
/// multiplication instead of a division.
const SQRT_2_RCP: f32 = 1.0 / SQRT_2;

/// Vectorized GELU activation, using the exact erf-based formulation:
///
/// `gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2)))`
///
/// The erf term is computed with [`Erf`]. For the cheaper tanh-based
/// approximation see [`ApproxGelu`].
#[derive(Default)]
pub struct Gelu {}

impl SimdUnaryOp<f32> for Gelu {
    #[inline(always)]
    fn eval<I: Isa>(&self, isa: I, x: I::F32) -> I::F32 {
        let ops = isa.f32();
        let half_x = ops.mul(x, ops.splat(0.5));

        // x / sqrt(2), via the precomputed reciprocal.
        let scaled = ops.mul(x, ops.splat(SQRT_2_RCP));

        // 1 + erf(x / sqrt(2))
        let one_plus_erf = ops.add(Erf::apply(isa, scaled), ops.one());

        ops.mul(half_x, one_plus_erf)
    }
}
/// `sqrt(2 / π)` ≈ 0.79788456.
///
/// Note: despite the name, this is the square root of `2 / π`, not
/// `sqrt(2) * π` or `sqrt(2π)`.
const SQRT_2_PI: f32 = 0.7978845608028654;

/// Vectorized GELU activation, using the tanh approximation:
///
/// `gelu(x) ≈ 0.5 * x * (1 + tanh(sqrt(2 / π) * (x + 0.044715 * x^3)))`
///
/// See [`Gelu`] for the exact erf-based formulation.
#[derive(Default)]
pub struct ApproxGelu {}

impl SimdUnaryOp<f32> for ApproxGelu {
    #[inline(always)]
    fn eval<I: Isa>(&self, isa: I, x: I::F32) -> I::F32 {
        let ops = isa.f32();
        let half_x = ops.mul(x, ops.splat(0.5));

        // x + 0.044715 * x^3
        let x_cubed = ops.mul(ops.mul(x, x), x);
        let y = ops.mul_add(x_cubed, ops.splat(0.044715), x);

        // tanh(sqrt(2 / π) * (x + 0.044715 * x^3))
        let y = ops.mul(y, ops.splat(SQRT_2_PI));
        let y = Tanh::apply(isa, y);

        // 0.5 * x * (1 + tanh(...))
        let y = ops.add(y, ops.one());
        ops.mul(half_x, y)
    }
}
#[cfg(test)]
mod tests {
    use rten_simd::SimdUnaryOp;

    use super::{ApproxGelu, Erf, Gelu};
    use crate::testing::{AllF32s, Tolerance, UnaryOpTester, arange, benchmark_op};

    /// Scalar reference for exact GELU: `0.5 * x * (1 + erf(x / sqrt(2)))`,
    /// using `libm::erff`.
    fn reference_gelu(x: f32) -> f32 {
        let erf_term = libm::erff(x / (2.0f32).sqrt());
        0.5 * x * (1. + erf_term)
    }

    /// Scalar reference for tanh-approximated GELU:
    /// `0.5 * x * (1 + tanh(sqrt(2 / π) * (x + 0.044715 * x^3)))`.
    fn reference_approx_gelu(x: f32) -> f32 {
        let cube = x * x * x;
        let scale = (2.0f32 / std::f32::consts::PI).sqrt();
        let tanh_term = (scale * (x + 0.044715 * cube)).tanh();
        0.5 * x * (1. + tanh_term)
    }

    /// Absolute tolerance used by the `Erf` and `Gelu` tests.
    const MAX_EXPECTED_DIFF: f32 = 6.631017e-7;

    #[test]
    fn test_erf() {
        UnaryOpTester {
            reference: libm::erff,
            simd: Erf {},
            range: arange(-6., 6., 0.001),
            tolerance: Tolerance::Absolute(MAX_EXPECTED_DIFF),
        }
        .run();
    }

    /// Checks `Erf` against `libm::erff` over every f32 value.
    /// Ignored by default; run explicitly with `cargo test -- --ignored`.
    #[test]
    #[ignore]
    fn test_erf_exhaustive() {
        UnaryOpTester {
            reference: libm::erff,
            simd: Erf {},
            range: AllF32s::new(),
            tolerance: Tolerance::Absolute(MAX_EXPECTED_DIFF),
        }
        .run_with_progress();
    }

    #[test]
    fn test_gelu() {
        UnaryOpTester {
            reference: reference_gelu,
            simd: Gelu {},
            range: arange(-6., 6., 0.001),
            tolerance: Tolerance::Absolute(MAX_EXPECTED_DIFF),
        }
        .run();
    }

    #[test]
    fn test_approx_gelu() {
        UnaryOpTester {
            reference: reference_approx_gelu,
            simd: ApproxGelu {},
            range: arange(-6., 6., 0.001),
            tolerance: Tolerance::Absolute(5e-7),
        }
        .run();
    }

    /// Compares the scalar `libm::erff` against the SIMD `Erf` kernel.
    /// Ignored by default because it is a benchmark, not a correctness test.
    #[test]
    #[ignore]
    fn bench_erf() {
        benchmark_op(
            |xs, ys| {
                for (x, y) in xs.iter().zip(ys.iter_mut()) {
                    *y = libm::erff(*x);
                }
            },
            |xs, ys| {
                Erf {}.map(xs, ys);
            },
        );
    }
}