1use crate as burn;
2use crate::config::Config;
3use crate::module::Param;
4use crate::module::{Content, DisplaySettings, Module, ModuleDisplay};
5use crate::nn::Initializer;
6use crate::tensor::Tensor;
7use crate::tensor::backend::Backend;
/// Parametric ReLU layer.
///
/// The activation is applied in [`PRelu::forward`] via
/// `crate::tensor::activation::prelu`, using the learnable `alpha` weight.
///
/// Should be created with [`PReluConfig`].
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct PRelu<B: Backend> {
    /// The learnable weight alpha, a rank-1 tensor of shape `[num_parameters]`.
    pub alpha: Param<Tensor<B, 1>>,

    /// The constant used to initialize `alpha`; kept so the display output
    /// (see the `ModuleDisplay` impl) can report it without reading the tensor.
    pub alpha_value: f64,
}
21
22impl<B: Backend> ModuleDisplay for PRelu<B> {
23 fn custom_settings(&self) -> Option<DisplaySettings> {
24 DisplaySettings::new()
25 .with_new_line_after_attribute(false)
26 .optional()
27 }
28
29 fn custom_content(&self, content: Content) -> Option<Content> {
30 let [num_parameters] = self.alpha.shape().dims();
31
32 content
33 .add("num_parameters", &num_parameters)
34 .add("alpha_value", &self.alpha_value)
35 .optional()
36 }
37}
38
/// Configuration to create a [`PRelu`] layer using the [init function](PReluConfig::init).
#[derive(Config, Debug)]
pub struct PReluConfig {
    /// The number of learnable alpha parameters (length of the `alpha` tensor).
    #[config(default = "1")]
    pub num_parameters: usize,

    /// The constant value with which every alpha parameter is initialized.
    #[config(default = "0.25")]
    pub alpha: f64,
}
49
50impl PReluConfig {
51 pub fn init<B: Backend>(&self, device: &B::Device) -> PRelu<B> {
53 PRelu {
54 alpha: Initializer::Constant { value: self.alpha }.init([self.num_parameters], device),
56 alpha_value: self.alpha,
57 }
58 }
59}
60
61impl<B: Backend> PRelu<B> {
62 pub fn forward<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
71 crate::tensor::activation::prelu(input, self.alpha.val())
72 }
73}
74
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;

    /// Verify the custom `ModuleDisplay` output for a default-configured layer
    /// (1 parameter, alpha initialized to 0.25).
    #[test]
    fn display() {
        let layer = PReluConfig::new().init::<TestBackend>(&Default::default());

        assert_eq!(
            alloc::format!("{layer}"),
            "PRelu {num_parameters: 1, alpha_value: 0.25, params: 1}"
        );
    }
}
89}