1use burn_core as burn;
2
3use burn::config::Config;
4use burn::module::{Content, DisplaySettings, Module, ModuleDisplay};
5use burn::tensor::backend::Backend;
6use burn::tensor::{Distribution, Tensor};
7
/// Configuration to create a [GaussianNoise] layer using the
/// [init function](GaussianNoiseConfig::init).
#[derive(Config, Debug)]
pub struct GaussianNoiseConfig {
    /// Standard deviation of the zero-mean normal distribution the noise is sampled from.
    /// [`GaussianNoiseConfig::init`] panics if this value is negative.
    pub std: f64,
}
14
/// Adds zero-mean Gaussian noise to its input.
///
/// Noise is only added when the backend has autodiff enabled (presumably during training)
/// and `std` is non-zero; otherwise the input is passed through unchanged.
///
/// Should be created with [GaussianNoiseConfig].
#[derive(Module, Clone, Debug)]
#[module(custom_display)]
pub struct GaussianNoise {
    /// Standard deviation of the zero-mean normal distribution the noise is sampled from.
    pub std: f64,
}
28
29impl GaussianNoiseConfig {
30 pub fn init(&self) -> GaussianNoise {
32 if self.std.is_sign_negative() {
33 panic!(
34 "Standard deviation is required to be non-negative, but got {}",
35 self.std
36 );
37 }
38 GaussianNoise { std: self.std }
39 }
40}
41
42impl GaussianNoise {
43 pub fn forward<B: Backend, const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
52 if B::ad_enabled() && self.std != 0.0 {
53 let noise = Tensor::random(
54 input.shape(),
55 Distribution::Normal(0.0, self.std),
56 &input.device(),
57 );
58 input + noise
59 } else {
60 input
61 }
62 }
63}
64
65impl ModuleDisplay for GaussianNoise {
66 fn custom_settings(&self) -> Option<DisplaySettings> {
67 DisplaySettings::new()
68 .with_new_line_after_attribute(false)
69 .optional()
70 }
71
72 fn custom_content(&self, content: Content) -> Option<Content> {
73 content.add("std", &self.std).optional()
74 }
75}
76
#[cfg(test)]
mod tests {
    use super::*;
    use burn::tensor::Shape;

    #[cfg(feature = "std")]
    use crate::{TestAutodiffBackend, TestBackend};

    #[cfg(not(feature = "std"))]
    use crate::TestBackend;

    /// With an autodiff backend and a non-zero std, noise is added, so the output differs.
    #[cfg(feature = "std")]
    #[test]
    fn with_ad_backend_should_add_noise() {
        let tensor =
            Tensor::<TestAutodiffBackend, 2>::ones(Shape::new([100, 100]), &Default::default());
        let noise = GaussianNoiseConfig::new(0.5).init();

        let output = noise.forward(tensor.clone());

        assert_ne!(tensor.to_data(), output.to_data());
    }

    /// With an autodiff backend but a zero std, no noise is sampled and the input passes through.
    #[cfg(feature = "std")]
    #[test]
    fn with_ad_backend_and_zero_std_should_not_change_input() {
        let tensor =
            Tensor::<TestAutodiffBackend, 2>::ones(Shape::new([100, 100]), &Default::default());
        let noise = GaussianNoiseConfig::new(0.0).init();

        let output = noise.forward(tensor.clone());

        assert_eq!(tensor.to_data(), output.to_data());
    }

    /// Without autodiff the layer is an identity, regardless of `std`.
    #[test]
    fn without_ad_backend_should_not_change_input() {
        let tensor = Tensor::<TestBackend, 2>::ones(Shape::new([100, 100]), &Default::default());
        let noise = GaussianNoiseConfig::new(0.5).init();

        let output = noise.forward(tensor.clone());

        assert_eq!(tensor.to_data(), output.to_data());
    }

    #[test]
    #[should_panic(expected = "Standard deviation is required to be non-negative")]
    fn negative_std_should_panic() {
        GaussianNoiseConfig { std: -0.5 }.init();
    }

    #[test]
    fn display() {
        let config = GaussianNoiseConfig::new(0.5);
        let layer = config.init();

        assert_eq!(alloc::format!("{layer}"), "GaussianNoise {std: 0.5}");
    }
}