burn_nn/modules/
noise.rs

1use burn_core as burn;
2
3use burn::config::Config;
4use burn::module::{Content, DisplaySettings, Module, ModuleDisplay};
5use burn::tensor::backend::Backend;
6use burn::tensor::{Distribution, Tensor};
7
8/// Configuration to create a [GaussianNoise](GaussianNoise) layer using the [init function](GaussianNoiseConfig::init).
9#[derive(Config, Debug)]
10pub struct GaussianNoiseConfig {
11    /// Standard deviation of the normal noise distribution.
12    pub std: f64,
13}
14
15/// Add pseudorandom Gaussian noise to an arbitrarily shaped tensor.
16///
17/// This is an effective regularization technique that also contributes to data augmentation.
18/// Please keep in mind that the value of [std](GaussianNoise::std) should be chosen with care in order to avoid
19/// distortion.
20///
21/// Should be created with [GaussianNoiseConfig].
22#[derive(Module, Clone, Debug)]
23#[module(custom_display)]
24pub struct GaussianNoise {
25    /// Standard deviation of the normal noise distribution.
26    pub std: f64,
27}
28
29impl GaussianNoiseConfig {
30    /// Initialize a new [Gaussian noise](GaussianNoise) module.
31    pub fn init(&self) -> GaussianNoise {
32        if self.std.is_sign_negative() {
33            panic!(
34                "Standard deviation is required to be non-negative, but got {}",
35                self.std
36            );
37        }
38        GaussianNoise { std: self.std }
39    }
40}
41
42impl GaussianNoise {
43    /// Applies the forward pass on the input tensor.
44    ///
45    /// See [GaussianNoise](GaussianNoise) for more information.
46    ///
47    /// # Shapes
48    ///
49    /// - input: `[..., any]`
50    /// - output: `[..., any]`
51    pub fn forward<B: Backend, const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
52        if B::ad_enabled() && self.std != 0.0 {
53            let noise = Tensor::random(
54                input.shape(),
55                Distribution::Normal(0.0, self.std),
56                &input.device(),
57            );
58            input + noise
59        } else {
60            input
61        }
62    }
63}
64
65impl ModuleDisplay for GaussianNoise {
66    fn custom_settings(&self) -> Option<DisplaySettings> {
67        DisplaySettings::new()
68            .with_new_line_after_attribute(false)
69            .optional()
70    }
71
72    fn custom_content(&self, content: Content) -> Option<Content> {
73        content.add("std", &self.std).optional()
74    }
75}
76
#[cfg(test)]
mod tests {
    use super::*;
    use burn::tensor::Shape;

    #[cfg(feature = "std")]
    use crate::{TestAutodiffBackend, TestBackend};

    #[cfg(not(feature = "std"))]
    use crate::TestBackend;

    /// With an autodiff backend and a positive `std`, noise must actually alter the input.
    #[cfg(feature = "std")]
    #[test]
    fn with_ad_backend_should_mark_input() {
        let device = Default::default();
        let input = Tensor::<TestAutodiffBackend, 2>::ones(Shape::new([100, 100]), &device);
        let layer = GaussianNoiseConfig::new(0.5).init();

        let noisy = layer.forward(input.clone());

        assert_ne!(input.to_data(), noisy.to_data());
    }

    /// Without autodiff the layer is an identity, regardless of `std`.
    #[test]
    fn without_ad_backend_should_not_change_input() {
        let device = Default::default();
        let input = Tensor::<TestBackend, 2>::ones(Shape::new([100, 100]), &device);
        let layer = GaussianNoiseConfig::new(0.5).init();

        let passthrough = layer.forward(input.clone());

        assert_eq!(input.to_data(), passthrough.to_data());
    }

    /// A negative standard deviation is rejected at construction time.
    #[test]
    #[should_panic(expected = "Standard deviation is required to be non-negative")]
    fn negative_std_should_panic() {
        let config = GaussianNoiseConfig { std: -0.5 };
        config.init();
    }

    /// The module renders on a single line with its `std` attribute.
    #[test]
    fn display() {
        let layer = GaussianNoiseConfig::new(0.5).init();
        let rendered = alloc::format!("{layer}");

        assert_eq!(rendered, "GaussianNoise {std: 0.5}");
    }
}