use std::cell::Cell;
use crate::autograd::Variable;
use crate::tensor::Result;
use super::Module;
/// Dropout layer that delegates to the tensor-level `dropout` op while in training mode.
#[derive(Debug)]
pub struct Dropout {
    /// Drop probability forwarded to the tensor op.
    /// NOTE(review): expected to lie in [0, 1]; it is not validated here.
    p: f64,
    /// Training flag. It is stored in a `Cell` so it can be toggled through `&self`
    /// (`Module::set_training` takes `&self`).
    training: Cell<bool>,
}

impl Dropout {
    /// Creates a dropout layer with drop probability `p`, starting in training mode.
    pub fn new(p: f64) -> Self {
        Self {
            p,
            training: Cell::new(true),
        }
    }
}
impl Module for Dropout {
    fn name(&self) -> &str {
        "dropout"
    }

    /// Applies dropout when the layer is in training mode and `p` is non-zero.
    ///
    /// In evaluation mode, or when `p == 0.0`, the input is returned as a clone.
    ///
    /// # Errors
    /// Propagates any error returned by the tensor `dropout` op.
    fn forward(&self, input: &Variable) -> Result<Variable> {
        let should_drop = self.training.get() && self.p != 0.0;
        if should_drop {
            // The second argument is presumably the op's "train" flag; it is always `true`
            // here because the eval-mode case is handled in the `else` branch.
            // NOTE(review): the result is rebuilt with `Variable::wrap` from `input.data()`;
            // confirm this keeps `input` in the autograd graph, otherwise gradients stop here.
            let dropped = input.data().dropout(self.p, true)?;
            Ok(Variable::wrap(dropped))
        } else {
            Ok(input.clone())
        }
    }

    /// Toggles training mode: `true` enables dropout, `false` makes `forward` return a clone.
    fn set_training(&self, training: bool) {
        self.training.set(training)
    }
}
/// Channel-wise dropout layer that delegates to the tensor-level `feature_dropout` op.
///
/// Whole feature maps are either zeroed or kept, rather than individual elements.
#[derive(Debug)]
pub struct Dropout2d {
    /// Drop probability forwarded to the tensor op.
    /// NOTE(review): expected to lie in [0, 1]; it is not validated here.
    p: f64,
    /// Training flag. It is stored in a `Cell` so it can be toggled through `&self`
    /// (`Module::set_training` takes `&self`).
    training: Cell<bool>,
}

impl Dropout2d {
    /// Creates a channel-wise dropout layer with drop probability `p`, starting in training mode.
    pub fn new(p: f64) -> Self {
        Self {
            p,
            training: Cell::new(true),
        }
    }
}
impl Module for Dropout2d {
    fn name(&self) -> &str {
        "dropout2d"
    }

    /// Applies channel-wise (feature) dropout when in training mode and `p` is non-zero.
    ///
    /// In evaluation mode, or when `p == 0.0`, the input is returned as a clone.
    ///
    /// # Errors
    /// Propagates any error returned by the tensor `feature_dropout` op.
    fn forward(&self, input: &Variable) -> Result<Variable> {
        let should_drop = self.training.get() && self.p != 0.0;
        if should_drop {
            // The second argument is presumably the op's "train" flag; it is always `true`
            // here because the eval-mode case is handled in the `else` branch.
            // NOTE(review): the result is rebuilt with `Variable::wrap` from `input.data()`;
            // confirm this keeps `input` in the autograd graph, otherwise gradients stop here.
            let dropped = input.data().feature_dropout(self.p, true)?;
            Ok(Variable::wrap(dropped))
        } else {
            Ok(input.clone())
        }
    }

    /// Toggles training mode: `true` enables dropout, `false` makes `forward` return a clone.
    fn set_training(&self, training: bool) {
        self.training.set(training)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tensor::{DType, Tensor, TensorOptions};

    /// Float32 tensor options on the test device, shared by every test below.
    fn f32_opts() -> TensorOptions {
        TensorOptions { dtype: DType::Float32, device: crate::tensor::test_device() }
    }

    /// True when `v` is (approximately) either 0 or `scale`.
    fn is_zero_or(v: f32, scale: f32) -> bool {
        v.abs() < 1e-5 || (v - scale).abs() < 1e-5
    }

    #[test]
    fn test_dropout_values_zero_or_scaled() {
        let p = 0.5_f64;
        let d = Dropout::new(p);
        let input = Variable::new(Tensor::ones(&[4, 16], f32_opts()).unwrap(), false);
        let output = d.forward(&input).unwrap();
        let data = output.data().to_f32_vec().unwrap();
        assert_eq!(data.len(), 4 * 16);
        // Surviving elements are expected to be rescaled by 1 / (1 - p).
        let scale = (1.0 / (1.0 - p)) as f32;
        for (i, &v) in data.iter().enumerate() {
            assert!(is_zero_or(v, scale), "element {} should be 0 or {}: got {}", i, scale, v);
        }
    }

    #[test]
    fn test_dropout_eval_identity() {
        let d = Dropout::new(0.5);
        d.set_training(false);
        let input = Variable::new(Tensor::ones(&[2, 5], f32_opts()).unwrap(), false);
        let output = d.forward(&input).unwrap();
        let data = output.data().to_f32_vec().unwrap();
        assert!(data.iter().all(|&v| (v - 1.0).abs() < 1e-5));
    }

    #[test]
    fn test_dropout_zero_p_identity() {
        // Still in training mode: p == 0 alone must short-circuit to the identity.
        let d = Dropout::new(0.0);
        let input = Variable::new(Tensor::ones(&[3, 7], f32_opts()).unwrap(), false);
        let output = d.forward(&input).unwrap();
        let data = output.data().to_f32_vec().unwrap();
        assert!(data.iter().all(|&v| (v - 1.0).abs() < 1e-5));
    }

    #[test]
    fn test_dropout2d_whole_channels_zeroed() {
        let p = 0.5_f64;
        let d = Dropout2d::new(p);
        let input = Variable::new(Tensor::ones(&[2, 8, 4, 4], f32_opts()).unwrap(), false);
        let output = d.forward(&input).unwrap();
        let data = output.data().to_f32_vec().unwrap();
        let channels = 8_usize;
        let plane = 4_usize * 4;
        assert_eq!(data.len(), 2 * channels * plane);
        let scale = (1.0 / (1.0 - p)) as f32;
        // Each contiguous (h * w) chunk is one (batch, channel) feature map.
        for (idx, channel) in data.chunks_exact(plane).enumerate() {
            let (b, c) = (idx / channels, idx % channels);
            let first = channel[0];
            for &v in channel {
                assert!((v - first).abs() < 1e-5,
                    "channel [{},{}] not uniform: {} vs {}", b, c, v, first);
            }
            assert!(is_zero_or(first, scale),
                "channel value should be 0 or {}: got {}", scale, first);
        }
    }

    #[test]
    fn test_dropout2d_eval_identity() {
        let d = Dropout2d::new(0.5);
        d.set_training(false);
        let input = Variable::new(Tensor::ones(&[1, 3, 4, 4], f32_opts()).unwrap(), false);
        let output = d.forward(&input).unwrap();
        let data = output.data().to_f32_vec().unwrap();
        assert!(data.iter().all(|&v| (v - 1.0).abs() < 1e-5));
    }
}