burn_nn/activation/selu.rs

1use burn_core as burn;
2
3use burn::module::Module;
4use burn::tensor::Tensor;
5use burn::tensor::backend::Backend;
6
7/// Applies the Scaled Exponential Linear Unit function element-wise.
8/// See also [selu](burn::tensor::activation::selu)
9#[derive(Module, Clone, Debug, Default)]
10pub struct Selu;
11
12impl Selu {
13    /// Create the module.
14    pub fn new() -> Self {
15        Self {}
16    }
17    /// Applies the forward pass on the input tensor.
18    ///
19    /// # Shapes
20    ///
21    /// - input: `[..., any]`
22    /// - output: `[..., any]`
23    pub fn forward<B: Backend, const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
24        burn::tensor::activation::selu(input)
25    }
26}
27
#[cfg(test)]
mod tests {
    use super::*;

    /// The module's `Display` output (via the `Module` derive) should be its bare name.
    #[test]
    fn display() {
        let module = Selu::new();
        let rendered = alloc::format!("{module}");

        assert_eq!(rendered, "Selu");
    }
}