burn_core/nn/
hard_sigmoid.rs

1use burn_tensor::activation::hard_sigmoid;
2
3use crate as burn;
4use crate::config::Config;
5use crate::module::Module;
6use crate::module::{Content, DisplaySettings, ModuleDisplay};
7use crate::tensor::Tensor;
8use crate::tensor::backend::Backend;
9
10/// Hard Sigmoid layer.
11///
12/// Should be created with [HardSigmoidConfig](HardSigmoidConfig).
13#[derive(Module, Clone, Debug)]
14#[module(custom_display)]
15pub struct HardSigmoid {
16    /// The alpha value.
17    pub alpha: f64,
18    /// The beta value.
19    pub beta: f64,
20}
21/// Configuration to create a [Hard Sigmoid](HardSigmoid) layer using the [init function](HardSigmoidConfig::init).
22#[derive(Config, Debug)]
23pub struct HardSigmoidConfig {
24    /// The alpha value. Default is 0.2
25    #[config(default = "0.2")]
26    pub alpha: f64,
27    /// The beta value. Default is 0.5
28    #[config(default = "0.5")]
29    pub beta: f64,
30}
31impl HardSigmoidConfig {
32    /// Initialize a new [Hard Sigmoid](HardSigmoid) Layer
33    pub fn init(&self) -> HardSigmoid {
34        HardSigmoid {
35            alpha: self.alpha,
36            beta: self.beta,
37        }
38    }
39}
40
41impl ModuleDisplay for HardSigmoid {
42    fn custom_settings(&self) -> Option<DisplaySettings> {
43        DisplaySettings::new()
44            .with_new_line_after_attribute(false)
45            .optional()
46    }
47
48    fn custom_content(&self, content: Content) -> Option<Content> {
49        content
50            .add("alpha", &self.alpha)
51            .add("beta", &self.beta)
52            .optional()
53    }
54}
55
56impl HardSigmoid {
57    /// Forward pass for the Hard Sigmoid layer.
58    ///
59    /// See [hard_sigmoid](crate::tensor::activation::hard_sigmoid) for more information.
60    ///
61    /// # Shapes
62    /// - input: `[..., any]`
63    /// - output: `[..., any]`
64    pub fn forward<B: Backend, const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
65        hard_sigmoid(input, self.alpha, self.beta)
66    }
67}
68
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;
    use crate::tensor::TensorData;
    use burn_tensor::{Tolerance, ops::FloatElem};
    type FT = FloatElem<TestBackend>;

    /// Linear region with the default config: `0.2 * x + 0.5`.
    #[test]
    fn test_hard_sigmoid_forward() {
        let device = <TestBackend as Backend>::Device::default();
        let model: HardSigmoid = HardSigmoidConfig::new().init();
        let input =
            Tensor::<TestBackend, 2>::from_data(TensorData::from([[0.4410, -0.2507]]), &device);
        let out = model.forward(input);
        let expected = TensorData::from([[0.5882, 0.44986]]);
        out.to_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::default());
    }

    /// Saturated region with the default config: `0.2 * x + 0.5` leaves `[0, 1]` once
    /// `|x| > 2.5`, so the output must clamp to exactly `0` or `1`.
    #[test]
    fn test_hard_sigmoid_forward_saturates() {
        let device = <TestBackend as Backend>::Device::default();
        let model = HardSigmoidConfig::new().init();
        let input = Tensor::<TestBackend, 2>::from_data(
            TensorData::from([[-10.0, -3.0, 0.0, 3.0, 10.0]]),
            &device,
        );
        let out = model.forward(input);
        let expected = TensorData::from([[0.0, 0.0, 0.5, 1.0, 1.0]]);
        out.to_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::default());
    }

    /// Non-default parameters (`alpha = 1/6`, `beta = 0.5`) are honored in both the linear
    /// and the saturated regions.
    #[test]
    fn test_hard_sigmoid_forward_custom_config() {
        let device = <TestBackend as Backend>::Device::default();
        let model = HardSigmoidConfig::new()
            .with_alpha(1.0 / 6.0)
            .with_beta(0.5)
            .init();
        let input = Tensor::<TestBackend, 2>::from_data(
            TensorData::from([[-4.0, -1.5, 1.5, 4.0]]),
            &device,
        );
        let out = model.forward(input);
        let expected = TensorData::from([[0.0, 0.25, 0.75, 1.0]]);
        out.to_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::default());
    }

    #[test]
    fn display() {
        let config = HardSigmoidConfig::new().init();
        assert_eq!(
            alloc::format!("{config}"),
            "HardSigmoid {alpha: 0.2, beta: 0.5}"
        );
    }
}