#![allow(clippy::excessive_precision)]
use rten_simd::ops::{FloatOps, IntOps, NumOps};
use rten_simd::{Isa, Simd, SimdUnaryOp};
/// `log2(e)`, i.e. `1 / ln(2)`. Multiplying `x` by this gives `x / ln(2)`.
const INV_LOG2: f32 = std::f32::consts::LOG2_E;

/// `1.5 * 2^23` (bit pattern `0x4B400000`).
///
/// Neighbouring `f32` values at this magnitude are exactly 1.0 apart, so
/// adding this constant to a small value and subtracting it again leaves that
/// value rounded to an integer.
const ROUNDING_MAGIC: f32 = 12582912.;

/// High part of `-ln(2)`. `LOG2_HI + LOG2_LO` approximates `-ln(2)` more
/// closely than a single `f32` can.
const LOG2_HI: f32 = -6.93145752e-1;

/// Low part of `-ln(2)`; see [`LOG2_HI`].
const LOG2_LO: f32 = -1.42860677e-6;

// Coefficients of the degree-6 polynomial used to approximate `exp(r)` on the
// reduced argument `r`. `EXP_POLY_n` multiplies `r^n`. The values are close
// to, but not identical to, the Taylor coefficients `1 / n!`.

/// Constant term.
const EXP_POLY_0: f32 = 1.0;
/// Coefficient of `r`.
const EXP_POLY_1: f32 = 1.0;
/// Coefficient of `r^2` (close to `1/2!`).
const EXP_POLY_2: f32 = 4.99999851e-1;
/// Coefficient of `r^3` (close to `1/3!`).
const EXP_POLY_3: f32 = 1.66664720e-1;
/// Coefficient of `r^4` (close to `1/4!`).
const EXP_POLY_4: f32 = 4.16695364e-2;
/// Coefficient of `r^5` (close to `1/5!`).
const EXP_POLY_5: f32 = 8.37312452e-3;
/// Coefficient of `r^6` (close to `1/6!`).
const EXP_POLY_6: f32 = 1.37805939e-3;
/// Vectorized exponential function.
///
/// Evaluates `exp(x)` for every lane of an `f32` SIMD vector. The tests in
/// this file compare the result against [`f32::exp`] with a tolerance of
/// `MAX_EXP_ERROR_ULPS` (1 ULP).
///
/// Lanes with `x >= 104.0` yield `+inf` and lanes with `x <= -104.0` yield
/// `0.0`.
#[derive(Default)]
pub struct Exp {}

impl SimdUnaryOp<f32> for Exp {
    /// Computes `exp(x)` lane-wise as `exp(r) * 2^k`, where `k` is `x / ln(2)`
    /// rounded to an integer and `r = x - k * ln(2)`.
    #[inline(always)]
    fn eval<I: Isa>(&self, isa: I, x: I::F32) -> I::F32 {
        let fops = isa.f32();
        let iops = isa.i32();

        // Range reduction. Adding and then subtracting `ROUNDING_MAGIC` rounds
        // `x / ln(2)` to an integer which is still stored as a float.
        let magic = fops.splat(ROUNDING_MAGIC);
        let shifted = fops.mul_add(x, fops.splat(INV_LOG2), magic);
        let k_float = fops.sub(shifted, magic);

        // `LOG2_HI` / `LOG2_LO` are both negative, so these two steps compute
        // `x - k * ln(2)` using the split constant.
        let reduced = fops.mul_add(k_float, fops.splat(LOG2_HI), x);
        let reduced = fops.mul_add(k_float, fops.splat(LOG2_LO), reduced);
        let k = fops.to_int_trunc(k_float);

        // Horner evaluation of the degree-6 polynomial in `reduced`.
        let poly = fops.splat(EXP_POLY_6);
        let poly = fops.mul_add(poly, reduced, fops.splat(EXP_POLY_5));
        let poly = fops.mul_add(poly, reduced, fops.splat(EXP_POLY_4));
        let poly = fops.mul_add(poly, reduced, fops.splat(EXP_POLY_3));
        let poly = fops.mul_add(poly, reduced, fops.splat(EXP_POLY_2));
        let poly = fops.mul_add(poly, reduced, fops.splat(EXP_POLY_1));
        let exp_reduced = fops.mul_add(poly, reduced, fops.splat(EXP_POLY_0));

        // Range checks on the original input.
        let too_large = fops.ge(x, fops.splat(104.0));
        let too_small = fops.le(x, fops.splat(-104.0));

        // Build `2^k` as the product of two floats assembled from raw bits, so
        // that `k` may lie outside the exponent range of a single `f32`.
        //
        // For `k > 0` the offset is zero and the factors are `2^127` and
        // `2^(k - 127)`. Otherwise the offset moves both exponent fields and
        // the factors are `2^-123` and `2^(k + 123)`.
        let k_is_positive = iops.gt(k, iops.zero());
        #[allow(overflowing_literals)]
        let non_positive_offset = iops.splat(0x83000000);
        let offset = iops.select(iops.zero(), non_positive_offset, k_is_positive);
        let first_bits = iops.add(offset, iops.splat(0x7f000000));
        let second_bits = iops.sub(iops.shift_left::<23>(k), offset);
        let first_scale: I::F32 = first_bits.reinterpret_cast();
        let second_scale: I::F32 = second_bits.reinterpret_cast();
        let scaled = fops.mul(fops.mul(exp_reduced, first_scale), second_scale);

        // Replace out-of-range lanes with the saturated results.
        let saturated = fops.select(fops.splat(f32::INFINITY), scaled, too_large);
        fops.select(fops.zero(), saturated, too_small)
    }
}
/// Smallest input for which [`ReducedRangeExp`] returns a non-zero result
/// (approximately `-87.67`).
///
/// At this value `x / ln(2)` rounds to `-126`, so the biased exponent
/// `k + 127` used to build `2^k` is still at least 1.
const EXP_LOWER_CUTOFF: f32 = -126.5 * std::f32::consts::LN_2 + 0.01;

/// Vectorized exponential function with a simplified scaling step.
///
/// Compared to [`Exp`], `2^k` is built as a single float and there is no
/// overflow handling. Lanes below [`EXP_LOWER_CUTOFF`] yield `0.0`.
///
/// NOTE(review): the tests only exercise inputs in `[EXP_LOWER_CUTOFF, 0)`;
/// this op appears to be intended for non-positive inputs only.
#[derive(Default)]
pub struct ReducedRangeExp {}

impl SimdUnaryOp<f32> for ReducedRangeExp {
    /// Computes `exp(x)` lane-wise as `exp(r) * 2^k`, where `k` is `x / ln(2)`
    /// rounded to an integer and `r = x - k * ln(2)`.
    #[inline(always)]
    fn eval<I: Isa>(&self, isa: I, x: I::F32) -> I::F32 {
        let fops = isa.f32();
        let iops = isa.i32();

        // Range reduction. Adding and then subtracting `ROUNDING_MAGIC` rounds
        // `x / ln(2)` to an integer which is still stored as a float.
        let magic = fops.splat(ROUNDING_MAGIC);
        let shifted = fops.mul_add(x, fops.splat(INV_LOG2), magic);
        let k_float = fops.sub(shifted, magic);

        // `x - k * ln(2)` using the split (negative) `ln(2)` constant.
        let reduced = fops.mul_add(k_float, fops.splat(LOG2_HI), x);
        let reduced = fops.mul_add(k_float, fops.splat(LOG2_LO), reduced);
        let k = fops.to_int_trunc(k_float);

        // Horner evaluation of the degree-6 polynomial in `reduced`.
        let poly = fops.splat(EXP_POLY_6);
        let poly = fops.mul_add(poly, reduced, fops.splat(EXP_POLY_5));
        let poly = fops.mul_add(poly, reduced, fops.splat(EXP_POLY_4));
        let poly = fops.mul_add(poly, reduced, fops.splat(EXP_POLY_3));
        let poly = fops.mul_add(poly, reduced, fops.splat(EXP_POLY_2));
        let poly = fops.mul_add(poly, reduced, fops.splat(EXP_POLY_1));
        let exp_reduced = fops.mul_add(poly, reduced, fops.splat(EXP_POLY_0));

        // Build `2^k` by placing the biased exponent `k + 127` directly into
        // the exponent field of a float.
        let biased_exponent = iops.add(k, iops.splat(127));
        let pow2_bits = iops.shift_left::<23>(biased_exponent);
        let pow2: I::F32 = pow2_bits.reinterpret_cast();
        let scaled = fops.mul(exp_reduced, pow2);

        // Flush lanes below the cutoff to zero.
        let below_cutoff = fops.lt(x, fops.splat(EXP_LOWER_CUTOFF));
        fops.select(fops.zero(), scaled, below_cutoff)
    }
}
/// Vectorized sigmoid (standard logistic) function, `1 / (1 + exp(-x))`.
#[derive(Default)]
pub struct Sigmoid {}

impl SimdUnaryOp<f32> for Sigmoid {
    /// Returns the reciprocal of `1 + exp(-x)` for every lane of `x`.
    #[inline(always)]
    fn eval<I: Isa>(&self, isa: I, x: I::F32) -> I::F32 {
        let fops = isa.f32();
        let one = fops.one();
        let exp_neg_x = Exp::apply(isa, fops.neg(x));
        let one_plus_exp = fops.add(one, exp_neg_x);
        fops.reciprocal(one_plus_exp)
    }
}
/// Vectorized Sigmoid Linear Unit (SiLU) function, `x * sigmoid(x)`.
///
/// Derives `Default` like the other field-less ops in this file (`Exp`,
/// `ReducedRangeExp`, `Sigmoid`), so it can be built with `Silu::default()`.
#[derive(Default)]
pub struct Silu {}
impl SimdUnaryOp<f32> for Silu {
    /// Returns `x / (1 + exp(-x))` for every lane of `x`, which equals
    /// `x * sigmoid(x)`.
    #[inline(always)]
    fn eval<I: Isa>(&self, isa: I, x: I::F32) -> I::F32 {
        let fops = isa.f32();
        let one = fops.one();
        let exp_neg_x = Exp::apply(isa, fops.neg(x));
        let one_plus_exp = fops.add(one, exp_neg_x);
        fops.div(x, one_plus_exp)
    }
}
/// Vectorized Swish function, `x * sigmoid(beta * x)`.
pub struct Swish {
    /// Factor applied to `x` before the sigmoid is evaluated.
    pub beta: f32,
}

impl SimdUnaryOp<f32> for Swish {
    /// Returns `x * sigmoid(x * beta)` for every lane of `x`.
    #[inline(always)]
    fn eval<I: Isa>(&self, isa: I, x: I::F32) -> I::F32 {
        let fops = isa.f32();
        let scaled_x = fops.mul(x, fops.splat(self.beta));
        let gate = Sigmoid::apply(isa, scaled_x);
        fops.mul(x, gate)
    }
}
/// Vectorized exponential linear unit (ELU) function.
///
/// Lanes with `x >= 0` pass through unchanged; other lanes become
/// `alpha * (exp(x) - 1)`.
pub struct Elu {
    /// Multiplier applied to `exp(x) - 1` on the negative branch.
    pub alpha: f32,
}

impl SimdUnaryOp<f32> for Elu {
    /// Evaluates the negative-branch formula for every lane, then keeps `x`
    /// itself wherever `x >= 0`.
    #[inline(always)]
    fn eval<I: Isa>(&self, isa: I, x: I::F32) -> I::F32 {
        let fops = isa.f32();
        let is_non_negative = fops.ge(x, fops.zero());
        let alpha = fops.splat(self.alpha);
        let exp_minus_one = fops.sub(Exp::apply(isa, x), fops.splat(1.));
        let negative_branch = fops.mul(alpha, exp_minus_one);
        fops.select(x, negative_branch, is_non_negative)
    }
}
#[cfg(test)]
mod tests {
    use rten_simd::SimdUnaryOp;

    use super::{EXP_LOWER_CUTOFF, ReducedRangeExp};
    use crate::testing::{AllF32s, Tolerance, UnaryOpTester, arange, benchmark_op};
    use crate::{Elu, Exp, Sigmoid, Silu, Swish};

    /// ULP tolerance used when comparing `Exp` / `ReducedRangeExp` to `f32::exp`.
    const MAX_EXP_ERROR_ULPS: f32 = 1.0;

    /// ULP tolerance used for the sigmoid-based ops (`Sigmoid`, `Silu`, `Swish`).
    const MAX_SIGMOID_ERROR_ULPS: f32 = 4.0;

    /// Scalar ELU: `x` for `x >= 0`, otherwise `alpha * (exp(x) - 1)`.
    fn reference_elu(x: f32, alpha: f32) -> f32 {
        if x >= 0. {
            return x;
        }
        alpha * (x.exp() - 1.)
    }

    /// Scalar sigmoid: `1 / (1 + exp(-x))`.
    fn reference_sigmoid(x: f32) -> f32 {
        let denom = 1. + (-x).exp();
        1. / denom
    }

    /// Scalar SiLU: `x * sigmoid(x)`.
    fn reference_silu(x: f32) -> f32 {
        let gate = reference_sigmoid(x);
        x * gate
    }

    /// Scalar Swish: `x * sigmoid(beta * x)`.
    fn reference_swish(x: f32, beta: f32) -> f32 {
        let gate = reference_sigmoid(beta * x);
        x * gate
    }

    /// Checks `Exp` on a handful of inputs, including zero and values beyond
    /// the `-104` / `104` cutoffs. Finite results must match `f32::exp` exactly.
    #[test]
    fn test_exp_basic() {
        let exp_op = Exp {};
        // NOTE(review): `0.1` appears twice here; one of them may have been
        // meant to be `-0.1`. Left unchanged.
        let inputs = [-2.0f32, -1., -0.5, 0.1, 0., 0.1, 0.5, 1., 2., -105., 105.];

        for input in inputs {
            let expected = input.exp();
            let actual = exp_op.scalar_eval(input);

            if actual.is_infinite() || expected.is_infinite() {
                assert_eq!(actual, expected);
                continue;
            }

            let diff = (expected - actual).abs();
            assert_eq!(diff, 0.);
        }
    }

    /// Compares `Exp` with `f32::exp` over `arange(-6., 6., 0.001)`.
    #[test]
    fn test_exp() {
        UnaryOpTester {
            reference: f32::exp,
            simd: Exp {},
            range: arange(-6., 6., 0.001),
            tolerance: Tolerance::Ulp(MAX_EXP_ERROR_ULPS),
        }
        .run();
    }

    /// Compares `ReducedRangeExp` with `f32::exp` over
    /// `arange(EXP_LOWER_CUTOFF, 0., 0.015)`.
    #[test]
    fn test_reduced_range_exp() {
        UnaryOpTester {
            reference: f32::exp,
            simd: ReducedRangeExp {},
            range: arange(EXP_LOWER_CUTOFF, 0., 0.015),
            tolerance: Tolerance::Ulp(MAX_EXP_ERROR_ULPS),
        }
        .run();
    }

    /// Compares `Elu` with the scalar reference on a few inputs on both sides
    /// of zero.
    #[test]
    fn test_elu() {
        let alpha = 0.5;
        UnaryOpTester {
            reference: |x| reference_elu(x, alpha),
            simd: Elu { alpha },
            range: [-2., -1., 0., 1., 2.].into_iter(),
            tolerance: Tolerance::Ulp(1.0),
        }
        .run();
    }

    /// Compares `Exp` with `f32::exp` over the inputs produced by
    /// `AllF32s::new()` (presumably every `f32` value). Ignored by default.
    #[test]
    #[ignore]
    fn test_exp_exhaustive() {
        UnaryOpTester {
            reference: f32::exp,
            simd: Exp {},
            range: AllF32s::new(),
            tolerance: Tolerance::Ulp(MAX_EXP_ERROR_ULPS),
        }
        .run_with_progress();
    }

    /// Compares `Sigmoid` with the scalar reference over `arange(-6., 6., 0.001)`.
    #[test]
    fn test_sigmoid() {
        UnaryOpTester {
            reference: reference_sigmoid,
            simd: Sigmoid {},
            range: arange(-6., 6., 0.001),
            tolerance: Tolerance::Ulp(MAX_SIGMOID_ERROR_ULPS),
        }
        .run();
    }

    /// Compares `Sigmoid` with the scalar reference over the inputs produced
    /// by `AllF32s::new()` (presumably every `f32` value). Ignored by default.
    #[test]
    #[ignore]
    fn test_sigmoid_exhaustive() {
        UnaryOpTester {
            reference: reference_sigmoid,
            simd: Sigmoid {},
            range: AllF32s::new(),
            tolerance: Tolerance::Ulp(MAX_SIGMOID_ERROR_ULPS),
        }
        .run_with_progress();
    }

    /// Compares `Silu` with the scalar reference over `arange(-6., 6., 0.001)`.
    #[test]
    fn test_silu() {
        UnaryOpTester {
            reference: reference_silu,
            simd: Silu {},
            range: arange(-6., 6., 0.001),
            tolerance: Tolerance::Ulp(MAX_SIGMOID_ERROR_ULPS),
        }
        .run();
    }

    /// Compares `Swish` (with `beta = 1.7`) with the scalar reference over
    /// `arange(-6., 6., 0.001)`.
    #[test]
    fn test_swish() {
        let beta = 1.7;
        UnaryOpTester {
            reference: |x| reference_swish(x, beta),
            simd: Swish { beta },
            range: arange(-6., 6., 0.001),
            tolerance: Tolerance::Ulp(MAX_SIGMOID_ERROR_ULPS),
        }
        .run();
    }

    /// Passes the scalar reference ELU and the vectorized `Elu` to
    /// `benchmark_op`. Ignored by default.
    #[test]
    #[ignore]
    fn bench_elu() {
        let alpha = 0.5;
        benchmark_op(
            |xs, ys| {
                for (x, y) in xs.iter().zip(ys.iter_mut()) {
                    *y = reference_elu(*x, alpha);
                }
            },
            |xs, ys| {
                Elu { alpha }.map(xs, ys);
            },
        );
    }

    /// Passes `f32::exp` and the vectorized `Exp` to `benchmark_op`. Ignored
    /// by default.
    #[test]
    #[ignore]
    fn bench_exp() {
        benchmark_op(
            |xs, ys| {
                for (x, y) in xs.iter().zip(ys.iter_mut()) {
                    *y = x.exp();
                }
            },
            |xs, ys| {
                Exp {}.map(xs, ys);
            },
        );
    }

    /// Passes the scalar reference sigmoid and the vectorized `Sigmoid` to
    /// `benchmark_op`. Ignored by default.
    #[test]
    #[ignore]
    fn bench_sigmoid() {
        benchmark_op(
            |xs, ys| {
                for (x, y) in xs.iter().zip(ys.iter_mut()) {
                    *y = reference_sigmoid(*x);
                }
            },
            |xs, ys| {
                Sigmoid {}.map(xs, ys);
            },
        );
    }
}