// burn_nn/activation/celu.rs

use burn_core as burn;

use burn::config::Config;
use burn::module::Module;
use burn::module::{Content, DisplaySettings, ModuleDisplay};
use burn::tensor::Tensor;
use burn::tensor::activation::celu;
use burn::tensor::backend::Backend;

/// CELU (Continuously Differentiable Exponential Linear Unit) layer.
///
/// Applies the CELU function element-wise:
/// `celu(x) = max(0, x) + min(0, alpha * (exp(x / alpha) - 1))`
///
/// Should be created with [CeluConfig](CeluConfig).
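///
/// # Example
///
/// A minimal usage sketch; assumes a backend type `B`, a `device`, and
/// `burn::tensor::TensorData` are in scope:
///
/// ```ignore
/// let layer: Celu = CeluConfig::new().init();
/// let input = Tensor::<B, 2>::from_data(TensorData::from([[0.5, -0.5]]), &device);
/// let output = layer.forward(input); // same shape as `input`
/// ```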
#[derive(Module, Clone, Debug)]
#[module(custom_display)]
pub struct Celu {
    /// The alpha value for the CELU formulation.
    pub alpha: f64,
}

/// Configuration to create a [Celu](Celu) layer using the [init function](CeluConfig::init).
#[derive(Config, Debug)]
pub struct CeluConfig {
    /// The alpha value for the CELU formulation. Defaults to 1.0.
    #[config(default = "1.0")]
    pub alpha: f64,
}

impl CeluConfig {
    /// Initialize a new [Celu](Celu) layer.
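    ///
    /// A minimal sketch; `with_alpha` is the setter generated by the `Config` derive
    /// (as exercised in the tests below):
    ///
    /// ```ignore
    /// let layer = CeluConfig::new().with_alpha(2.0).init();
    /// ```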
    pub fn init(&self) -> Celu {
        Celu { alpha: self.alpha }
    }
}

impl ModuleDisplay for Celu {
    fn custom_settings(&self) -> Option<DisplaySettings> {
        DisplaySettings::new()
            .with_new_line_after_attribute(false)
            .optional()
    }

    fn custom_content(&self, content: Content) -> Option<Content> {
        content.add("alpha", &self.alpha).optional()
    }
}

impl Celu {
    /// Forward pass for the Celu layer.
    ///
    /// See [celu](burn::tensor::activation::celu) for more information.
    ///
    /// # Shapes
    /// - input: `[..., any]`
    /// - output: `[..., any]`
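    ///
    /// # Example
    ///
    /// A minimal sketch; assumes a backend type `B`, a `device`, and `TensorData`
    /// are in scope:
    ///
    /// ```ignore
    /// let layer = CeluConfig::new().init(); // alpha = 1.0
    /// let input = Tensor::<B, 2>::from_data(TensorData::from([[-1.0]]), &device);
    /// // celu(-1.0, 1.0) = exp(-1.0) - 1.0 ≈ -0.632121
    /// let output = layer.forward(input);
    /// ```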
    pub fn forward<B: Backend, const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
        celu(input, self.alpha)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;
    use burn::tensor::TensorData;
    use burn::tensor::{Tolerance, ops::FloatElem};
    type FT = FloatElem<TestBackend>;

    #[test]
    fn test_celu_forward() {
        let device = Default::default();
        let model: Celu = CeluConfig::new().init();
        let input =
            Tensor::<TestBackend, 2>::from_data(TensorData::from([[0.5, -0.5, -1.0]]), &device);
        let out = model.forward(input);
        // celu(0.5, 1) = 0.5
        // celu(-0.5, 1) = 1 * (exp(-0.5) - 1) = -0.393469
        // celu(-1.0, 1) = 1 * (exp(-1) - 1) = -0.632121
        let expected = TensorData::from([[0.5, -0.393469, -0.632121]]);
        out.to_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::default());
    }

    #[test]
    fn test_celu_with_alpha() {
        let device = Default::default();
        let model: Celu = CeluConfig::new().with_alpha(2.0).init();
        let input = Tensor::<TestBackend, 2>::from_data(TensorData::from([[0.0, -2.0]]), &device);
        let out = model.forward(input);
        // celu(0, 2) = 0
        // celu(-2, 2) = 2 * (exp(-1) - 1) = -1.264241
        let expected = TensorData::from([[0.0, -1.264241]]);
        out.to_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::default());
    }

    #[test]
    fn display() {
        // The value under test is an initialized `Celu` layer, so name it accordingly.
        let layer = CeluConfig::new().init();
        assert_eq!(alloc::format!("{layer}"), "Celu {alpha: 1}");
    }
}