burn_nn/activation/
softplus.rs1use burn_core as burn;
2
3use burn::config::Config;
4use burn::module::Module;
5use burn::module::{Content, DisplaySettings, ModuleDisplay};
6use burn::tensor::Tensor;
7use burn::tensor::activation::softplus;
8use burn::tensor::backend::Backend;
9
10#[derive(Module, Clone, Debug)]
17#[module(custom_display)]
18pub struct Softplus {
19 pub beta: f64,
21}
22
23#[derive(Config, Debug)]
25pub struct SoftplusConfig {
26 #[config(default = "1.0")]
28 pub beta: f64,
29}
30
31impl SoftplusConfig {
32 pub fn init(&self) -> Softplus {
34 Softplus { beta: self.beta }
35 }
36}
37
38impl ModuleDisplay for Softplus {
39 fn custom_settings(&self) -> Option<DisplaySettings> {
40 DisplaySettings::new()
41 .with_new_line_after_attribute(false)
42 .optional()
43 }
44
45 fn custom_content(&self, content: Content) -> Option<Content> {
46 content.add("beta", &self.beta).optional()
47 }
48}
49
50impl Softplus {
51 pub fn forward<B: Backend, const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
59 softplus(input, self.beta)
60 }
61}
62
// The reference outputs below are rounded literals (e.g. 0.6931 ≈ ln 2) that clippy's
// `approx_constant` lint would otherwise flag; they are expected values, not stand-ins
// for `core::f64::consts`.
#[cfg(test)]
#[allow(clippy::approx_constant)]
mod tests {
    use super::*;
    use crate::TestBackend;
    use burn::tensor::TensorData;
    use burn::tensor::{Tolerance, ops::FloatElem};

    type FT = FloatElem<TestBackend>;

    /// Default `beta = 1`: outputs match `ln(1 + e^x)`.
    #[test]
    fn test_softplus_forward() {
        let device = <TestBackend as Backend>::Device::default();
        let layer: Softplus = SoftplusConfig::new().init();

        let x = Tensor::<TestBackend, 2>::from_data(TensorData::from([[0.0, 1.0, -1.0]]), &device);
        let y = layer.forward(x);

        y.to_data().assert_approx_eq::<FT>(
            &TensorData::from([[0.6931, 1.3133, 0.3133]]),
            Tolerance::default(),
        );
    }

    /// `beta = 2`: outputs match `0.5 * ln(1 + e^(2x))`.
    #[test]
    fn test_softplus_with_beta() {
        let device = <TestBackend as Backend>::Device::default();
        let layer: Softplus = SoftplusConfig::new().with_beta(2.0).init();

        let x = Tensor::<TestBackend, 2>::from_data(TensorData::from([[0.0, 1.0]]), &device);
        let y = layer.forward(x);

        y.to_data()
            .assert_approx_eq::<FT>(&TensorData::from([[0.3466, 1.0635]]), Tolerance::default());
    }

    /// The custom display shows `beta` inline with the module name.
    #[test]
    fn display() {
        let layer = SoftplusConfig::new().init();
        let rendered = alloc::format!("{layer}");
        assert_eq!(rendered, "Softplus {beta: 1}");
    }
}