burn_core/nn/
hard_sigmoid.rs1use burn_tensor::activation::hard_sigmoid;
2
3use crate as burn;
4use crate::config::Config;
5use crate::module::Module;
6use crate::module::{Content, DisplaySettings, ModuleDisplay};
7use crate::tensor::backend::Backend;
8use crate::tensor::Tensor;
9
10#[derive(Module, Clone, Debug)]
14#[module(custom_display)]
15pub struct HardSigmoid {
16 pub alpha: f64,
18 pub beta: f64,
20}
21#[derive(Config, Debug)]
23pub struct HardSigmoidConfig {
24 #[config(default = "0.2")]
26 pub alpha: f64,
27 #[config(default = "0.5")]
29 pub beta: f64,
30}
31impl HardSigmoidConfig {
32 pub fn init(&self) -> HardSigmoid {
34 HardSigmoid {
35 alpha: self.alpha,
36 beta: self.beta,
37 }
38 }
39}
40
41impl ModuleDisplay for HardSigmoid {
42 fn custom_settings(&self) -> Option<DisplaySettings> {
43 DisplaySettings::new()
44 .with_new_line_after_attribute(false)
45 .optional()
46 }
47
48 fn custom_content(&self, content: Content) -> Option<Content> {
49 content
50 .add("alpha", &self.alpha)
51 .add("beta", &self.beta)
52 .optional()
53 }
54}
55
56impl HardSigmoid {
57 pub fn forward<B: Backend, const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
65 hard_sigmoid(input, self.alpha, self.beta)
66 }
67}
68
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tensor::TensorData;
    use crate::TestBackend;

    #[test]
    fn test_hard_sigmoid_forward() {
        let device = <TestBackend as Backend>::Device::default();
        let model: HardSigmoid = HardSigmoidConfig::new().init();
        let input =
            Tensor::<TestBackend, 2>::from_data(TensorData::from([[0.4410, -0.2507]]), &device);
        let out = model.forward(input);
        // Linear region with the defaults: 0.2 * x + 0.5.
        let expected = TensorData::from([[0.5882, 0.44986]]);
        out.to_data().assert_approx_eq(&expected, 4);
    }

    #[test]
    fn test_hard_sigmoid_forward_saturates() {
        let device = <TestBackend as Backend>::Device::default();
        let model: HardSigmoid = HardSigmoidConfig::new().init();
        // With the defaults, 0.2 * x + 0.5 leaves [0, 1] once |x| > 2.5, so these
        // inputs exercise both clamped tails and the exact saturation boundaries.
        let input = Tensor::<TestBackend, 2>::from_data(
            TensorData::from([[10.0, -10.0, 2.5, -2.5]]),
            &device,
        );
        let out = model.forward(input);
        let expected = TensorData::from([[1.0, 0.0, 1.0, 0.0]]);
        out.to_data().assert_approx_eq(&expected, 4);
    }

    #[test]
    fn test_hard_sigmoid_forward_custom_params() {
        let device = <TestBackend as Backend>::Device::default();
        let model: HardSigmoid = HardSigmoidConfig::new()
            .with_alpha(0.5)
            .with_beta(0.0)
            .init();
        // 0.5 * x + 0.0, clamped: 0.4 -> 0.2, -0.4 -> 0.0, 3.0 -> 1.0.
        let input = Tensor::<TestBackend, 2>::from_data(
            TensorData::from([[0.4, -0.4, 3.0]]),
            &device,
        );
        let out = model.forward(input);
        let expected = TensorData::from([[0.2, 0.0, 1.0]]);
        out.to_data().assert_approx_eq(&expected, 4);
    }

    #[test]
    fn display() {
        let layer = HardSigmoidConfig::new().init();
        assert_eq!(
            alloc::format!("{}", layer),
            "HardSigmoid {alpha: 0.2, beta: 0.5}"
        );
    }
}