use burn_core as burn;
use burn::config::Config;
use burn::module::{Content, DisplaySettings, Module, ModuleDisplay};
use burn::tensor::backend::Backend;
use burn::tensor::{Distribution, Tensor};
/// Configuration for a [`GaussianNoise`] layer; build the layer with
/// [`GaussianNoiseConfig::init`].
#[derive(Config, Debug)]
pub struct GaussianNoiseConfig {
    /// Standard deviation of the zero-mean normal distribution the additive noise is drawn from.
    pub std: f64,
}
/// Layer that adds zero-mean Gaussian noise to its input.
///
/// Noise is only added when the backend has autodiff enabled and `std` is non-zero; otherwise
/// [`GaussianNoise::forward`] returns its input unchanged. It is normally created via
/// [`GaussianNoiseConfig::init`].
#[derive(Module, Clone, Debug)]
#[module(custom_display)]
pub struct GaussianNoise {
    /// Standard deviation of the normal distribution used to sample the noise.
    pub std: f64,
}
impl GaussianNoiseConfig {
    /// Builds a [`GaussianNoise`] layer from this configuration.
    ///
    /// # Panics
    ///
    /// Panics if `std` is negative, NaN, or infinite. Negative zero (`-0.0`) is accepted and is
    /// treated like `0.0` (the layer then adds no noise).
    pub fn init(&self) -> GaussianNoise {
        // `std >= 0.0` is false for every negative value and for NaN, yet true for `-0.0`.
        // A sign-bit check (`is_sign_negative`) would instead reject `-0.0` and let NaN through.
        // Infinite values are rejected too, since sampling Normal(0, inf) yields non-finite noise.
        assert!(
            self.std >= 0.0 && self.std.is_finite(),
            "Standard deviation is required to be non-negative and finite, but got {}",
            self.std
        );

        GaussianNoise { std: self.std }
    }
}
impl GaussianNoise {
    /// Applies the layer to `input`.
    ///
    /// When the backend `B` reports autodiff as enabled and `std` is non-zero, returns
    /// `input + noise`. Here `noise` has the same shape and device as `input` and is sampled
    /// from `Normal(0.0, std)`. In every other case `input` is returned untouched.
    pub fn forward<B: Backend, const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
        // This layer uses autodiff being enabled as its signal to inject noise. A zero std
        // would add nothing, so the sampling step is skipped.
        let should_perturb = B::ad_enabled() && self.std != 0.0;
        if !should_perturb {
            return input;
        }

        let shape = input.shape();
        let device = input.device();
        let distribution = Distribution::Normal(0.0, self.std);
        let perturbation = Tensor::random(shape, distribution, &device);

        input + perturbation
    }
}
impl ModuleDisplay for GaussianNoise {
    /// Renders the layer on a single line (no newline after each attribute).
    fn custom_settings(&self) -> Option<DisplaySettings> {
        let single_line = DisplaySettings::new().with_new_line_after_attribute(false);
        single_line.optional()
    }

    /// Shows the `std` attribute of the layer.
    fn custom_content(&self, content: Content) -> Option<Content> {
        let with_std = content.add("std", &self.std);
        with_std.optional()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use burn::tensor::Shape;

    #[cfg(feature = "std")]
    use crate::{TestAutodiffBackend, TestBackend};

    #[cfg(not(feature = "std"))]
    use crate::TestBackend;

    /// With an autodiff backend the layer is active, so the output must differ from the input.
    #[cfg(feature = "std")]
    #[test]
    fn with_ad_backend_should_mark_input() {
        let device = Default::default();
        let ones = Tensor::<TestAutodiffBackend, 2>::ones(Shape::new([100, 100]), &device);
        let layer = GaussianNoiseConfig::new(0.5).init();

        let noisy = layer.forward(ones.clone());

        assert_ne!(ones.to_data(), noisy.to_data());
    }

    /// Without autodiff the layer is an identity mapping.
    #[test]
    fn without_ad_backend_should_not_change_input() {
        let device = Default::default();
        let ones = Tensor::<TestBackend, 2>::ones(Shape::new([100, 100]), &device);
        let layer = GaussianNoiseConfig::new(0.5).init();

        let passthrough = layer.forward(ones.clone());

        assert_eq!(ones.to_data(), passthrough.to_data());
    }

    /// A negative standard deviation is rejected when the layer is built.
    #[test]
    #[should_panic(expected = "Standard deviation is required to be non-negative")]
    fn negative_std_should_panic() {
        let invalid = GaussianNoiseConfig { std: -0.5 };
        invalid.init();
    }

    /// The layer renders on one line together with its `std` attribute.
    #[test]
    fn display() {
        let layer = GaussianNoiseConfig::new(0.5).init();

        let rendered = alloc::format!("{layer}");

        assert_eq!(rendered, "GaussianNoise {std: 0.5}");
    }
}