1use rten_simd::ops::NumOps;
8use rten_simd::{Isa, SimdUnaryOp};
9
/// Parameters for the "leaky ReLU" activation function.
///
/// The operation evaluates `if x < 0 { alpha * x } else { x }` for each
/// input element.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct LeakyRelu {
    /// Multiplier applied to inputs that are less than zero.
    pub alpha: f32,
}
16
17impl SimdUnaryOp<f32> for LeakyRelu {
18 fn eval<I: Isa>(&self, isa: I, x: I::F32) -> I::F32 {
19 let ops = isa.f32();
20 let alpha = ops.splat(self.alpha);
21 let x_neg = ops.lt(x, ops.zero());
22 let x_mul_alpha = ops.mul(x, alpha);
23 ops.select(x_mul_alpha, x, x_neg)
24 }
25}
26
#[cfg(test)]
mod tests {
    use super::LeakyRelu;
    use crate::testing::{Tolerance, UnaryOpTester};

    /// Scalar model of leaky ReLU used as the ground truth for the SIMD op.
    fn reference_leaky_relu(x: f32, alpha: f32) -> f32 {
        let is_negative = x < 0.;
        if is_negative { alpha * x } else { x }
    }

    /// Compare the SIMD implementation against the scalar reference on a few
    /// negative, zero and positive inputs.
    #[test]
    fn test_leaky_relu() {
        let alpha = 0.5;
        let inputs = [-2., -1., 0., 1., 2.];

        UnaryOpTester {
            reference: |x: f32| reference_leaky_relu(x, alpha),
            simd: LeakyRelu { alpha },
            range: inputs.iter().copied(),
            tolerance: Tolerance::Ulp(1.0),
        }
        .run();
    }
}