burn_core/nn/prelu.rs

use crate as burn;
use crate::config::Config;
use crate::module::Param;
use crate::module::{Content, DisplaySettings, Module, ModuleDisplay};
use crate::nn::Initializer;
use crate::tensor::Tensor;
use crate::tensor::backend::Backend;

/// Parametric ReLU (PReLU) layer.
///
/// Should be created using [PReluConfig].
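///
/// # Example
///
/// A minimal usage sketch; the backend alias `MyBackend` and the `device` value are
/// placeholders assumed to be defined elsewhere:
///
/// ```rust,ignore
/// // Create the layer with the default configuration (a single alpha of 0.25).
/// let prelu: PRelu<MyBackend> = PReluConfig::new().init(&device);
/// ```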
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct PRelu<B: Backend> {
    /// The learnable weights for PReLU. Can have shape \[1\] or \[num_parameters\], in which
    /// case `num_parameters` must equal the number of channels in the input tensor.
    pub alpha: Param<Tensor<B, 1>>,

    /// Alpha value the PReLU weights are initialized with.
    pub alpha_value: f64,
}

impl<B: Backend> ModuleDisplay for PRelu<B> {
    fn custom_settings(&self) -> Option<DisplaySettings> {
        DisplaySettings::new()
            .with_new_line_after_attribute(false)
            .optional()
    }

    fn custom_content(&self, content: Content) -> Option<Content> {
        let [num_parameters] = self.alpha.shape().dims();

        content
            .add("num_parameters", &num_parameters)
            .add("alpha_value", &self.alpha_value)
            .optional()
    }
}

/// Configuration to create a [Parametric ReLU](PRelu) layer using the [init function](PReluConfig::init).
#[derive(Config, Debug)]
pub struct PReluConfig {
    /// The number of learnable alpha parameters: either 1 or the number of channels in the input tensor.
    #[config(default = "1")]
    pub num_parameters: usize,
    /// The initial value of the learnable alpha weight(s). Default is 0.25.
    #[config(default = "0.25")]
    pub alpha: f64,
}

impl PReluConfig {
    /// Initialize a new [Parametric ReLU](PRelu) layer.
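    ///
    /// # Example
    ///
    /// An illustrative sketch, assuming the builder setters generated by `#[derive(Config)]`
    /// and a backend alias `MyBackend` defined elsewhere:
    ///
    /// ```rust,ignore
    /// // One alpha per channel (here 8 channels), each initialized to 0.1.
    /// let layer = PReluConfig::new()
    ///     .with_num_parameters(8)
    ///     .with_alpha(0.1)
    ///     .init::<MyBackend>(&device);
    /// ```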
    pub fn init<B: Backend>(&self, device: &B::Device) -> PRelu<B> {
        PRelu {
            // alpha is a rank-1 tensor of length num_parameters, filled with the configured alpha value
            alpha: Initializer::Constant { value: self.alpha }.init([self.num_parameters], device),
            alpha_value: self.alpha,
        }
    }
}

impl<B: Backend> PRelu<B> {
    /// Applies the forward pass on the input tensor.
    ///
    /// # Shapes
    ///
    /// - input: `[..., any]`
    /// - output: `[..., any]`
    ///
    /// See also [prelu](crate::tensor::activation::prelu) for more information.
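    ///
    /// # Example
    ///
    /// An illustrative sketch; `MyBackend`, `device` and the input values are placeholders:
    ///
    /// ```rust,ignore
    /// let layer = PReluConfig::new().init::<MyBackend>(&device);
    /// let input = Tensor::<MyBackend, 2>::from_floats([[-1.0, 0.0, 2.0]], &device);
    /// // Negative inputs are scaled by alpha (0.25 by default): [[-0.25, 0.0, 2.0]].
    /// let output = layer.forward(input);
    /// ```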
    pub fn forward<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
        crate::tensor::activation::prelu(input, self.alpha.val())
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;

    #[test]
    fn display() {
        let layer = PReluConfig::new().init::<TestBackend>(&Default::default());

        assert_eq!(
            alloc::format!("{layer}"),
            "PRelu {num_parameters: 1, alpha_value: 0.25, params: 1}"
        );
    }
}