burn_nn/activation/hard_swish.rs

use burn_core as burn;

use burn::module::Module;
use burn::tensor::Tensor;
use burn::tensor::activation::hard_swish;
use burn::tensor::backend::Backend;

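/// Hard Swish activation layer.
///
/// Applies `x * min(max(x + 3, 0), 6) / 6` element-wise; see
/// [`hard_swish`](burn::tensor::activation::hard_swish).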
#[derive(Module, Clone, Debug, Default)]
pub struct HardSwish;

impl HardSwish {
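    /// Create a new hard swish activation module.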
    pub fn new() -> Self {
        Self
    }

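    /// Apply hard swish element-wise to an input tensor of any rank `D`.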
    pub fn forward<B: Backend, const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
        hard_swish(input)
    }
}
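
// A minimal usage sketch (assumptions: the `burn_ndarray::NdArray` backend and
// its default device; any burn backend works the same way):
//
//     use burn_ndarray::NdArray;
//
//     let device = Default::default();
//     let layer = HardSwish::new();
//     let x = Tensor::<NdArray, 1>::from_floats([-3.0, 0.0, 1.0, 3.0], &device);
//     let y = layer.forward(x); // ≈ [0.0, 0.0, 0.6667, 3.0]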

#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;
    use burn::tensor::TensorData;
    use burn::tensor::{Tolerance, ops::FloatElem};
    type FT = FloatElem<TestBackend>;

    #[test]
    fn test_hard_swish_forward() {
        let device = <TestBackend as Backend>::Device::default();
        let model = HardSwish::new();

        let input = Tensor::<TestBackend, 2>::from_data(
            TensorData::from([[3.0f32, -3.0], [0.0, 1.0]]),
            &device,
        );
        let out = model.forward(input);
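        // hard_swish(x) = x * min(max(x + 3, 0), 6) / 6:
        // 3 -> 3, -3 -> 0, 0 -> 0, 1 -> 4/6 ≈ 0.6666667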
        let expected = TensorData::from([[3.0f32, 0.0], [0.0, 0.6666667]]);
        out.to_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::default());
    }

    #[test]
    fn display() {
        let layer = HardSwish::new();
        assert_eq!(alloc::format!("{layer}"), "HardSwish");
    }
}