rten_vecmath/
relu.rs

//! Activation functions related to ReLU.
//!
//! Vanilla ReLU doesn't really need an explicitly vectorized kernel because it
//! is just `x.max(0.)`, which is easy for compilers to auto-vectorize. Variants
//! such as leaky ReLU, however, do benefit.
//!
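//! For illustration, a minimal sketch of the scalar form referred to above (a
//! hypothetical helper, not part of this crate's API) that optimizing
//! compilers will typically auto-vectorize on their own:
//!
//! ```
//! // Clamp negative values to zero, element-wise.
//! fn relu_in_place(xs: &mut [f32]) {
//!     for x in xs.iter_mut() {
//!         *x = x.max(0.);
//!     }
//! }
//!
//! let mut xs = [-1.0_f32, 0.0, 2.0];
//! relu_in_place(&mut xs);
//! assert_eq!(xs, [0.0, 0.0, 2.0]);
//! ```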
use rten_simd::ops::NumOps;
use rten_simd::{Isa, SimdUnaryOp};

/// Computes the leaky ReLU activation function.
///
/// This evaluates `if x < 0. { alpha * x } else { x }`.
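///
/// # Example
///
/// A hedged usage sketch: the `map_mut` slice helper assumed here is taken
/// from the `SimdUnaryOp` trait; verify the exact name and path against the
/// `rten_simd` API (hence the `ignore` attribute on this example).
///
/// ```ignore
/// use rten_simd::SimdUnaryOp;
/// use rten_vecmath::LeakyRelu;
///
/// let mut xs = [-2.0_f32, 0.0, 3.0];
/// // Assumed helper: applies the op element-wise, in place, over the slice.
/// LeakyRelu { alpha: 0.5 }.map_mut(&mut xs);
/// assert_eq!(xs, [-1.0, 0.0, 3.0]);
/// ```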
pub struct LeakyRelu {
    pub alpha: f32,
}

impl SimdUnaryOp<f32> for LeakyRelu {
    fn eval<I: Isa>(&self, isa: I, x: I::F32) -> I::F32 {
        let ops = isa.f32();
        let alpha = ops.splat(self.alpha);
        // Branchless evaluation: test which lanes are negative, compute
        // `alpha * x` for every lane, then select the scaled value for the
        // negative lanes and pass `x` through unchanged for the rest.
        let x_neg = ops.lt(x, ops.zero());
        let x_mul_alpha = ops.mul(x, alpha);
        ops.select(x_mul_alpha, x, x_neg)
    }
}

#[cfg(test)]
mod tests {
    use crate::testing::{Tolerance, UnaryOpTester};

    use super::LeakyRelu;

    /// Scalar reference implementation to compare the SIMD kernel against.
    fn reference_leaky_relu(x: f32, alpha: f32) -> f32 {
        if x < 0. { alpha * x } else { x }
    }

    #[test]
    fn test_leaky_relu() {
        let alpha = 0.5;
        let test = UnaryOpTester {
            reference: |x: f32| reference_leaky_relu(x, alpha),
            simd: LeakyRelu { alpha },
            range: [-2., -1., 0., 1., 2.].iter().copied(),
            tolerance: Tolerance::Ulp(1.0),
        };
        test.run();
    }
}