burn_nn/activation/
hard_swish.rs

1use burn_core as burn;
2
3use burn::module::Module;
4use burn::tensor::Tensor;
5use burn::tensor::activation::hard_swish;
6use burn::tensor::backend::Backend;
7
8/// Hard Swish layer.
9#[derive(Module, Clone, Debug, Default)]
10pub struct HardSwish;
11
12impl HardSwish {
13    /// Create the module.
14    pub fn new() -> Self {
15        Self
16    }
17
18    /// Forward pass for the Hard Swish layer.
19    ///
20    /// See [hard_swish](burn::tensor::activation::hard_swish) for more information.
21    ///
22    /// # Shapes
23    /// - input: `[..., any]`
24    /// - output: `[..., any]`
25    pub fn forward<B: Backend, const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
26        hard_swish(input)
27    }
28}
29
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;
    use burn::tensor::TensorData;
    use burn::tensor::{Tolerance, ops::FloatElem};
    type FT = FloatElem<TestBackend>;

    /// Checks the saturation boundaries (±3), the origin, and one positive interior point.
    #[test]
    fn test_hard_swish_forward() {
        let device = <TestBackend as Backend>::Device::default();
        let model = HardSwish::new();

        let input = Tensor::<TestBackend, 2>::from_data(
            TensorData::from([[3.0f32, -3.0], [0.0, 1.0]]),
            &device,
        );
        let out = model.forward(input);
        let expected = TensorData::from([[3.0f32, 0.0], [0.0, 0.6666667]]);
        out.to_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::default());
    }

    /// Checks the negative interior of the curve and inputs beyond saturation.
    #[test]
    fn test_hard_swish_forward_interior_and_saturation() {
        let device = <TestBackend as Backend>::Device::default();
        let model = HardSwish::new();

        // Hard Swish is x * relu6(x + 3) / 6:
        // - inside (-3, 3): -1.0 -> -1 * 2 / 6 = -1/3, and 1.5 -> 1.5 * 4.5 / 6 = 1.125;
        // - beyond saturation: 5.0 passes through unchanged, and -5.0 collapses to 0.
        let input = Tensor::<TestBackend, 2>::from_data(
            TensorData::from([[-1.0f32, 1.5], [5.0, -5.0]]),
            &device,
        );
        let out = model.forward(input);
        let expected = TensorData::from([[-0.33333334f32, 1.125], [5.0, 0.0]]);
        out.to_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::default());
    }

    /// The derived module display prints just the type name, since the layer has no fields.
    #[test]
    fn display() {
        let layer = HardSwish::new();
        assert_eq!(alloc::format!("{layer}"), "HardSwish");
    }
}