burn_nn/activation/
celu.rs1use burn_core as burn;
2
3use burn::config::Config;
4use burn::module::Module;
5use burn::module::{Content, DisplaySettings, ModuleDisplay};
6use burn::tensor::Tensor;
7use burn::tensor::activation::celu;
8use burn::tensor::backend::Backend;
9
10#[derive(Module, Clone, Debug)]
17#[module(custom_display)]
18pub struct Celu {
19 pub alpha: f64,
21}
22
23#[derive(Config, Debug)]
25pub struct CeluConfig {
26 #[config(default = "1.0")]
28 pub alpha: f64,
29}
30
31impl CeluConfig {
32 pub fn init(&self) -> Celu {
34 Celu { alpha: self.alpha }
35 }
36}
37
38impl ModuleDisplay for Celu {
39 fn custom_settings(&self) -> Option<DisplaySettings> {
40 DisplaySettings::new()
41 .with_new_line_after_attribute(false)
42 .optional()
43 }
44
45 fn custom_content(&self, content: Content) -> Option<Content> {
46 content.add("alpha", &self.alpha).optional()
47 }
48}
49
50impl Celu {
51 pub fn forward<B: Backend, const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
59 celu(input, self.alpha)
60 }
61}
62
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;
    use burn::tensor::TensorData;
    use burn::tensor::{Tolerance, ops::FloatElem};
    type FT = FloatElem<TestBackend>;

    #[test]
    fn test_celu_forward() {
        let device = Default::default();
        let model: Celu = CeluConfig::new().init();
        let input =
            Tensor::<TestBackend, 2>::from_data(TensorData::from([[0.5, -0.5, -1.0]]), &device);
        let out = model.forward(input);
        // alpha = 1: positive values pass through; negatives are exp(x) - 1.
        let expected = TensorData::from([[0.5, -0.393469, -0.632121]]);
        out.to_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::default());
    }

    #[test]
    fn test_celu_with_alpha() {
        let device = Default::default();
        let model: Celu = CeluConfig::new().with_alpha(2.0).init();
        let input = Tensor::<TestBackend, 2>::from_data(TensorData::from([[0.0, -2.0]]), &device);
        let out = model.forward(input);
        // alpha = 2: -2.0 -> 2 * (exp(-2 / 2) - 1) = 2 * (exp(-1) - 1).
        let expected = TensorData::from([[0.0, -1.264241]]);
        out.to_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::default());
    }

    #[test]
    fn display() {
        let config = CeluConfig::new().init();
        assert_eq!(alloc::format!("{config}"), "Celu {alpha: 1}");
    }
}