// burn_core/nn/hard_sigmoid.rs
use burn_tensor::activation::hard_sigmoid;

use crate as burn;
use crate::config::Config;
use crate::module::Module;
use crate::module::{Content, DisplaySettings, ModuleDisplay};
use crate::tensor::Tensor;
use crate::tensor::backend::Backend;

10#[derive(Module, Clone, Debug)]
14#[module(custom_display)]
15pub struct HardSigmoid {
16 pub alpha: f64,
18 pub beta: f64,
20}
21#[derive(Config, Debug)]
23pub struct HardSigmoidConfig {
24 #[config(default = "0.2")]
26 pub alpha: f64,
27 #[config(default = "0.5")]
29 pub beta: f64,
30}
31impl HardSigmoidConfig {
32 pub fn init(&self) -> HardSigmoid {
34 HardSigmoid {
35 alpha: self.alpha,
36 beta: self.beta,
37 }
38 }
39}
40
41impl ModuleDisplay for HardSigmoid {
42 fn custom_settings(&self) -> Option<DisplaySettings> {
43 DisplaySettings::new()
44 .with_new_line_after_attribute(false)
45 .optional()
46 }
47
48 fn custom_content(&self, content: Content) -> Option<Content> {
49 content
50 .add("alpha", &self.alpha)
51 .add("beta", &self.beta)
52 .optional()
53 }
54}
55
56impl HardSigmoid {
57 pub fn forward<B: Backend, const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
65 hard_sigmoid(input, self.alpha, self.beta)
66 }
67}
68
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;
    use crate::tensor::TensorData;
    use burn_tensor::{Tolerance, ops::FloatElem};
    type FT = FloatElem<TestBackend>;

    #[test]
    fn test_hard_sigmoid_forward() {
        let device = <TestBackend as Backend>::Device::default();
        let model: HardSigmoid = HardSigmoidConfig::new().init();
        let input =
            Tensor::<TestBackend, 2>::from_data(TensorData::from([[0.4410, -0.2507]]), &device);
        let out = model.forward(input);
        // With the defaults (alpha = 0.2, beta = 0.5): 0.2 * x + 0.5, both
        // results lying strictly inside [0, 1].
        let expected = TensorData::from([[0.5882, 0.44986]]);
        out.to_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::default());
    }

    #[test]
    fn test_hard_sigmoid_forward_saturates() {
        let device = <TestBackend as Backend>::Device::default();
        let model: HardSigmoid = HardSigmoidConfig::new().init();
        // 0.2 * -3.0 + 0.5 = -0.1 and 0.2 * 3.0 + 0.5 = 1.1 fall outside [0, 1]
        // and must be clamped; x = 0.0 maps exactly to beta.
        let input =
            Tensor::<TestBackend, 2>::from_data(TensorData::from([[-3.0, 3.0, 0.0]]), &device);
        let out = model.forward(input);
        let expected = TensorData::from([[0.0, 1.0, 0.5]]);
        out.to_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::default());
    }

    #[test]
    fn display() {
        let config = HardSigmoidConfig::new().init();
        assert_eq!(
            alloc::format!("{config}"),
            "HardSigmoid {alpha: 0.2, beta: 0.5}"
        );
    }
}