use burn::{
module::{
Content,
ModuleDisplay,
ModuleDisplayDefault,
},
prelude::{
Backend,
Shape,
Tensor,
},
tensor::Distribution,
};
use serde::{
Deserialize,
Serialize,
};
use crate::utility::burn::{
clamp::ClampConfig,
distribution::DistributionDisplayAdapter,
};
/// Configuration describing how to generate a tensor of random noise.
///
/// The noise is sampled from `distribution` and, when `clamp` is set,
/// passed through that clamp before being returned.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct NoiseConfig {
/// Distribution handed to `Tensor::random` when sampling the noise.
pub distribution: Distribution,
/// Optional clamp applied to the sampled noise; `None` returns the sample unchanged.
pub clamp: Option<ClampConfig>,
}
// Opts `NoiseConfig` into `ModuleDisplay` without overriding any of the trait's methods.
impl ModuleDisplay for NoiseConfig {}
impl ModuleDisplayDefault for NoiseConfig {
    /// Builds the display content for this config.
    ///
    /// Two attributes are registered, in order: `"distribution"` (rendered
    /// through [`DistributionDisplayAdapter`]) and `"clamp"`.
    /// Always returns `Some`.
    fn content(
        &self,
        content: Content,
    ) -> Option<Content> {
        // The distribution is wrapped in an adapter before being added as an attribute.
        let distribution = DistributionDisplayAdapter::new(self.distribution);

        let populated = content
            .add("distribution", &distribution)
            .add("clamp", &self.clamp);

        Some(populated)
    }
}
impl Default for NoiseConfig {
    /// Creates a config that samples from `Distribution::Default` and applies no clamp.
    fn default() -> Self {
        let distribution = Distribution::Default;
        let clamp = None;

        Self {
            distribution,
            clamp,
        }
    }
}
impl NoiseConfig {
    /// Returns this config with its distribution replaced by `distribution`.
    ///
    /// The clamp setting is carried over unchanged.
    pub fn with_distribution(
        mut self,
        distribution: Distribution,
    ) -> Self {
        self.distribution = distribution;
        self
    }

    /// Returns this config with its clamp replaced by `clamp`.
    ///
    /// Accepts anything convertible into `Option<ClampConfig>`, so both a bare
    /// `ClampConfig` and an explicit `Some(..)` / `None` may be passed.
    /// The distribution is carried over unchanged.
    pub fn with_clamp<C>(
        mut self,
        clamp: C,
    ) -> Self
    where
        C: Into<Option<ClampConfig>>,
    {
        self.clamp = clamp.into();
        self
    }

    /// Generates a noise tensor of the given `shape` on `device`.
    ///
    /// The tensor is sampled with `Tensor::random` using the configured
    /// distribution; if a clamp is configured, the sample is passed through
    /// it before being returned.
    pub fn noise<B: Backend, S, const D: usize>(
        &self,
        shape: S,
        device: &B::Device,
    ) -> Tensor<B, D>
    where
        S: Into<Shape>,
    {
        let sampled: Tensor<B, D> = Tensor::random(shape.into(), self.distribution, device);

        if let Some(clamp_cfg) = &self.clamp {
            clamp_cfg.clamp(sampled)
        } else {
            sampled
        }
    }

    /// Generates a noise tensor with the same shape and device as `tensor`.
    ///
    /// Delegates to [`NoiseConfig::noise`]; only the shape and device of
    /// `tensor` are read.
    pub fn noise_like<B: Backend, const D: usize>(
        &self,
        tensor: &Tensor<B, D>,
    ) -> Tensor<B, D> {
        let shape = tensor.shape();
        let device = tensor.device();
        self.noise(shape, &device)
    }
}
#[cfg(test)]
mod tests {
    use burn::{
        backend::NdArray,
        module::DisplaySettings,
    };
    use num_traits::real::Real;

    use super::*;

    /// Checks the rendered display text of a config with a clamp set.
    #[test]
    fn test_noise_config_display() {
        let config = NoiseConfig::default().with_clamp(ClampConfig::min_max(0.5, 1.0));
        let settings = DisplaySettings::default();
        // `indoc!` strips only the common leading indentation, so the relative
        // indentation below is part of the expected string. Nested attributes
        // are expected two spaces deeper per level.
        // NOTE(review): assumes burn's default `DisplaySettings` indent size of 2 — confirm.
        assert_eq!(
            config.format(settings),
            indoc::indoc! {r#"
                NoiseConfig {
                  distribution: Distribution::Default
                  clamp: ClampConfig {
                    min: 0.5
                    max: 1
                  }
                }"#
            }
        )
    }

    /// Checks `Default` and the `with_*` builders, including the different
    /// argument forms accepted by `with_clamp`.
    #[test]
    fn test_noise_default() {
        let cfg = NoiseConfig::default();
        assert_eq!(
            cfg,
            NoiseConfig {
                distribution: Distribution::Default,
                clamp: None
            }
        );

        let cfg = NoiseConfig::default()
            .with_distribution(Distribution::Bernoulli(0.3))
            .with_clamp(ClampConfig::default());
        assert_eq!(
            cfg,
            NoiseConfig {
                distribution: Distribution::Bernoulli(0.3),
                clamp: Some(ClampConfig::default())
            }
        );

        // An explicit `Some(..)` is accepted as well as a bare `ClampConfig`.
        let cfg = NoiseConfig::default().with_clamp(Some(ClampConfig::default()));
        assert_eq!(
            cfg,
            NoiseConfig {
                distribution: Distribution::Default,
                clamp: Some(ClampConfig::default())
            }
        );

        // Passing `None` clears a previously configured clamp.
        let cfg = NoiseConfig::default()
            .with_clamp(ClampConfig::default())
            .with_clamp(None);
        assert_eq!(
            cfg,
            NoiseConfig {
                distribution: Distribution::Default,
                clamp: None,
            }
        );
    }

    /// Checks `noise_like` with the default distribution and a lower clamp of 0.5.
    #[test]
    fn test_noise_like_default_clamp() {
        type B = NdArray;
        let device = Default::default();

        let reference: Tensor<B, 2> = Tensor::ones([20, 20], &device);
        let numel = reference.shape().num_elements() as f64;

        let noise = NoiseConfig::default()
            .with_clamp(ClampConfig::default().with_min(0.5))
            .noise_like(&reference);

        assert_eq!(noise.shape(), reference.shape());
        assert_eq!(noise.device(), reference.device());

        // Roughly half of the samples are expected to be raised to the clamp
        // minimum (presumably the default distribution is uniform on [0, 1)).
        let count_05 = noise.clone().equal_elem(0.5).int().sum().into_scalar() as f64;
        assert!((0.5 - (count_05 / numel)).abs() < 0.15);

        // No sample may reach 1.0.
        let count_ge_1 = noise
            .clone()
            .greater_equal_elem(1.0)
            .int()
            .sum()
            .into_scalar();
        assert_eq!(count_ge_1, 0);
    }

    /// Checks `noise_like` with a Bernoulli distribution: the mean of the
    /// sampled tensor should be close to the configured probability.
    #[test]
    fn test_noise_like_bernoulli() {
        type B = NdArray;
        let device = Default::default();

        let reference: Tensor<B, 2> = Tensor::ones([20, 20], &device);

        let p = 0.1;
        let noise = NoiseConfig::default()
            .with_distribution(Distribution::Bernoulli(p))
            .noise_like(&reference);

        assert_eq!(noise.shape(), reference.shape());
        assert_eq!(noise.device(), reference.device());

        let ratio =
            (noise.clone().sum().into_scalar() as f64) / (noise.shape().num_elements() as f64);
        assert!((ratio - p).abs() < 0.05);
    }
}