//! Softplus activation layer (`burn_nn/activation/softplus.rs`).

1use burn_core as burn;
2
3use burn::config::Config;
4use burn::module::Module;
5use burn::module::{Content, DisplaySettings, ModuleDisplay};
6use burn::tensor::Tensor;
7use burn::tensor::activation::softplus;
8use burn::tensor::backend::Backend;
9
/// Softplus layer.
///
/// Applies the softplus function element-wise:
/// `softplus(x) = (1/beta) * log(1 + exp(beta * x))`
///
/// Should be created with [SoftplusConfig](SoftplusConfig).
#[derive(Module, Clone, Debug)]
#[module(custom_display)]
pub struct Softplus {
    /// The beta value scaling the input inside the softplus formula.
    /// Set via [SoftplusConfig]; defaults to 1.0 there.
    pub beta: f64,
}
22
/// Configuration to create a [Softplus](Softplus) layer using the [init function](SoftplusConfig::init).
#[derive(Config, Debug)]
pub struct SoftplusConfig {
    /// The beta value used by the softplus formula. Default is 1.0
    #[config(default = "1.0")]
    pub beta: f64,
}
30
31impl SoftplusConfig {
32    /// Initialize a new [Softplus](Softplus) Layer
33    pub fn init(&self) -> Softplus {
34        Softplus { beta: self.beta }
35    }
36}
37
38impl ModuleDisplay for Softplus {
39    fn custom_settings(&self) -> Option<DisplaySettings> {
40        DisplaySettings::new()
41            .with_new_line_after_attribute(false)
42            .optional()
43    }
44
45    fn custom_content(&self, content: Content) -> Option<Content> {
46        content.add("beta", &self.beta).optional()
47    }
48}
49
50impl Softplus {
51    /// Forward pass for the Softplus layer.
52    ///
53    /// See [softplus](burn::tensor::activation::softplus) for more information.
54    ///
55    /// # Shapes
56    /// - input: `[..., any]`
57    /// - output: `[..., any]`
58    pub fn forward<B: Backend, const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
59        softplus(input, self.beta)
60    }
61}
62
63#[cfg(test)]
64#[allow(clippy::approx_constant)]
65mod tests {
66    use super::*;
67    use crate::TestBackend;
68    use burn::tensor::TensorData;
69    use burn::tensor::{Tolerance, ops::FloatElem};
70    type FT = FloatElem<TestBackend>;
71
72    #[test]
73    fn test_softplus_forward() {
74        let device = <TestBackend as Backend>::Device::default();
75        let model: Softplus = SoftplusConfig::new().init();
76        let input =
77            Tensor::<TestBackend, 2>::from_data(TensorData::from([[0.0, 1.0, -1.0]]), &device);
78        let out = model.forward(input);
79        // softplus(0) = log(2) ≈ 0.6931
80        // softplus(1) = log(1 + e) ≈ 1.3133
81        // softplus(-1) = log(1 + e^-1) ≈ 0.3133
82        let expected = TensorData::from([[0.6931, 1.3133, 0.3133]]);
83        out.to_data()
84            .assert_approx_eq::<FT>(&expected, Tolerance::default());
85    }
86
87    #[test]
88    fn test_softplus_with_beta() {
89        let device = <TestBackend as Backend>::Device::default();
90        let model: Softplus = SoftplusConfig::new().with_beta(2.0).init();
91        let input = Tensor::<TestBackend, 2>::from_data(TensorData::from([[0.0, 1.0]]), &device);
92        let out = model.forward(input);
93        // softplus(0, beta=2) = (1/2) * log(1 + exp(0)) = 0.5 * log(2) ≈ 0.3466
94        // softplus(1, beta=2) = (1/2) * log(1 + exp(2)) = 0.5 * log(8.389) ≈ 1.0635
95        let expected = TensorData::from([[0.3466, 1.0635]]);
96        out.to_data()
97            .assert_approx_eq::<FT>(&expected, Tolerance::default());
98    }
99
100    #[test]
101    fn display() {
102        let config = SoftplusConfig::new().init();
103        assert_eq!(alloc::format!("{config}"), "Softplus {beta: 1}");
104    }
105}