//! Dropout-style regularization layers: element-wise [`Dropout`], channel-wise
//! [`Dropout2d`], and SELU-compatible [`AlphaDropout`].
use super::module::Module;
use crate::autograd::Tensor;
use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};
use std::sync::Mutex;
/// Element-wise dropout with inverted scaling.
///
/// During training, each element is zeroed independently with probability `p`
/// and every surviving element is multiplied by `1 / (1 - p)`, so the expected
/// value of the output matches the input. In eval mode the input is returned
/// unchanged.
pub struct Dropout {
p: f32,
training: bool,
rng: Mutex<StdRng>,
}
impl Dropout {
/// Creates a dropout layer with drop probability `p`, seeding its RNG from
/// the operating system.
///
/// # Panics
/// Panics if `p` is not in `[0, 1)`.
#[must_use]
pub fn new(p: f32) -> Self {
assert!(
(0.0..1.0).contains(&p),
"Dropout probability must be in [0, 1), got {p}",
);
Self {
p,
training: true,
rng: Mutex::new(StdRng::from_os_rng()),
}
}
/// Like [`Dropout::new`], but with a fixed RNG seed for reproducible output.
///
/// # Panics
/// Panics if `p` is not in `[0, 1)`.
#[must_use]
pub fn with_seed(p: f32, seed: u64) -> Self {
assert!(
(0.0..1.0).contains(&p),
"Dropout probability must be in [0, 1), got {p}",
);
Self {
p,
training: true,
rng: Mutex::new(StdRng::seed_from_u64(seed)),
}
}
/// Returns the drop probability `p`.
#[must_use]
pub fn probability(&self) -> f32 {
self.p
}
}
impl Module for Dropout {
#[allow(clippy::expect_used)]
fn forward(&self, input: &Tensor) -> Tensor {
if !self.training || self.p == 0.0 {
return input.clone();
}
let mut rng = self.rng.lock().expect("Dropout RNG lock poisoned");
let scale = 1.0 / (1.0 - self.p);
let data: Vec<f32> = input
.data()
.iter()
.map(|&x| {
if rng.random::<f32>() < self.p {
0.0
} else {
x * scale
}
})
.collect();
Tensor::new(&data, input.shape())
}
fn train(&mut self) {
self.training = true;
}
fn eval(&mut self) {
self.training = false;
}
fn training(&self) -> bool {
self.training
}
}
impl std::fmt::Debug for Dropout {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("Dropout")
.field("p", &self.p)
.field("training", &self.training)
.finish_non_exhaustive()
}
}
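// A minimal usage sketch for `Dropout`, written as an in-file test module. It
// assumes that `Tensor::new` accepts a data slice plus a shape slice and that
// `Tensor::data()` exposes the elements in order, matching how the constructor
// and accessor are used in `forward` above; those details of the `Tensor` API
// are assumptions, not guarantees.
#[cfg(test)]
mod dropout_example {
    use super::*;

    #[test]
    fn eval_mode_is_identity() {
        let mut layer = Dropout::with_seed(0.5, 7);
        layer.eval();
        let values = [1.0_f32, -2.0, 3.5, 0.0, 4.25, -6.0];
        let input = Tensor::new(&values, &[2, 3]);
        let output = layer.forward(&input);
        for (&y, &x) in output.data().iter().zip(values.iter()) {
            assert!((y - x).abs() < f32::EPSILON);
        }
    }

    #[test]
    fn training_mode_zeroes_or_rescales_each_element() {
        // With p = 0.5, every surviving element must be scaled by 1 / (1 - p) = 2.
        let layer = Dropout::with_seed(0.5, 42);
        let values = [1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        let input = Tensor::new(&values, &[2, 3]);
        let output = layer.forward(&input);
        for (&y, &x) in output.data().iter().zip(values.iter()) {
            assert!(y == 0.0 || (y - x * 2.0).abs() < 1e-6);
        }
    }

    #[test]
    #[should_panic(expected = "Dropout probability")]
    fn rejects_probability_of_one() {
        let _ = Dropout::new(1.0);
    }
}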
/// Channel-wise (spatial) dropout.
///
/// For input of shape `[N, C, ...]`, whole channels are zeroed with
/// probability `p` during training and the surviving channels are scaled by
/// `1 / (1 - p)`. `forward` panics if the input has fewer than 3 dimensions.
pub struct Dropout2d {
p: f32,
training: bool,
rng: Mutex<StdRng>,
}
impl Dropout2d {
/// Creates a channel-wise dropout layer with drop probability `p`, seeding
/// its RNG from the operating system.
///
/// # Panics
/// Panics if `p` is not in `[0, 1)`.
#[must_use]
pub fn new(p: f32) -> Self {
assert!(
(0.0..1.0).contains(&p),
"Dropout probability must be in [0, 1), got {p}",
);
Self {
p,
training: true,
rng: Mutex::new(StdRng::from_os_rng()),
}
}
/// Like [`Dropout2d::new`], but with a fixed RNG seed for reproducible output.
///
/// # Panics
/// Panics if `p` is not in `[0, 1)`.
#[must_use]
pub fn with_seed(p: f32, seed: u64) -> Self {
assert!(
(0.0..1.0).contains(&p),
"Dropout probability must be in [0, 1), got {p}",
);
Self {
p,
training: true,
rng: Mutex::new(StdRng::seed_from_u64(seed)),
}
}
/// Returns the drop probability `p`.
#[must_use]
pub fn probability(&self) -> f32 {
self.p
}
}
impl Module for Dropout2d {
#[allow(clippy::expect_used)]
fn forward(&self, input: &Tensor) -> Tensor {
if !self.training || self.p == 0.0 {
return input.clone();
}
let shape = input.shape();
assert!(
shape.len() >= 3,
"Dropout2d expects at least 3D input [N, C, ...], got {}D",
shape.len()
);
let batch_size = shape[0];
let num_channels = shape[1];
let spatial_size: usize = shape[2..].iter().product();
let mut rng = self.rng.lock().expect("Dropout2d RNG lock poisoned");
let scale = 1.0 / (1.0 - self.p);
let mut channel_masks: Vec<bool> = Vec::with_capacity(batch_size * num_channels);
for _ in 0..(batch_size * num_channels) {
channel_masks.push(rng.random::<f32>() >= self.p);
}
let input_data = input.data();
let mut output = vec![0.0; input_data.len()];
for n in 0..batch_size {
for c in 0..num_channels {
let keep = channel_masks[n * num_channels + c];
for s in 0..spatial_size {
let idx = n * num_channels * spatial_size + c * spatial_size + s;
output[idx] = if keep { input_data[idx] * scale } else { 0.0 };
}
}
}
Tensor::new(&output, shape)
}
fn train(&mut self) {
self.training = true;
}
fn eval(&mut self) {
self.training = false;
}
fn training(&self) -> bool {
self.training
}
}
impl std::fmt::Debug for Dropout2d {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("Dropout2d")
.field("p", &self.p)
.field("training", &self.training)
.finish_non_exhaustive()
}
}
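// A sketch of `Dropout2d`'s channel-wise behaviour, under the same assumptions
// about the `Tensor` constructor and `data()` accessor as the `dropout_example`
// module above: each channel of an `[N, C, H, W]` input should be either zeroed
// entirely or kept entirely and scaled by 1 / (1 - p).
#[cfg(test)]
mod dropout2d_example {
    use super::*;

    #[test]
    fn drops_whole_channels() {
        let layer = Dropout2d::with_seed(0.5, 42);
        // One sample, three channels of 2 x 2 spatial values.
        let values: Vec<f32> = (1..=12).map(|v| v as f32).collect();
        let input = Tensor::new(&values, &[1, 3, 2, 2]);
        let output = layer.forward(&input);
        let out = output.data();
        for c in 0..3 {
            let dropped = (0..4).all(|s| out[c * 4 + s] == 0.0);
            let kept = (0..4).all(|s| (out[c * 4 + s] - values[c * 4 + s] * 2.0).abs() < 1e-5);
            assert!(
                dropped || kept,
                "channel {c} must be fully dropped or fully kept"
            );
        }
    }

    #[test]
    #[should_panic(expected = "at least 3D")]
    fn rejects_inputs_without_a_channel_dimension() {
        let layer = Dropout2d::with_seed(0.5, 1);
        let input = Tensor::new(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
        let _ = layer.forward(&input);
    }
}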
/// Alpha dropout for self-normalizing (SELU) networks.
///
/// Dropped elements are set to the negative saturation value of SELU
/// (`-SCALE * ALPHA`) instead of zero, and an affine transform `a * x + b` is
/// applied so the output keeps approximately zero mean and unit variance.
pub struct AlphaDropout {
p: f32,
training: bool,
rng: Mutex<StdRng>,
}
/// SELU `alpha` constant.
const ALPHA: f32 = 1.673_263_2;
/// SELU `lambda` (scale) constant.
const SCALE: f32 = 1.050_701;
impl AlphaDropout {
/// Creates an alpha-dropout layer with drop probability `p`, seeding its RNG
/// from the operating system.
///
/// # Panics
/// Panics if `p` is not in `[0, 1)`.
#[must_use]
pub fn new(p: f32) -> Self {
assert!(
(0.0..1.0).contains(&p),
"Dropout probability must be in [0, 1), got {p}",
);
Self {
p,
training: true,
rng: Mutex::new(StdRng::from_os_rng()),
}
}
/// Like [`AlphaDropout::new`], but with a fixed RNG seed for reproducible
/// output.
///
/// # Panics
/// Panics if `p` is not in `[0, 1)`.
#[must_use]
pub fn with_seed(p: f32, seed: u64) -> Self {
assert!(
(0.0..1.0).contains(&p),
"Dropout probability must be in [0, 1), got {p}",
);
Self {
p,
training: true,
rng: Mutex::new(StdRng::seed_from_u64(seed)),
}
}
}
impl Module for AlphaDropout {
#[allow(clippy::expect_used)]
fn forward(&self, input: &Tensor) -> Tensor {
if !self.training || self.p == 0.0 {
return input.clone();
}
let mut rng = self.rng.lock().expect("AlphaDropout RNG lock poisoned");
let alpha_p = -ALPHA * SCALE;
let a = ((1.0 - self.p) * (1.0 + self.p * alpha_p.powi(2))).powf(-0.5);
let b = -a * alpha_p * self.p;
let data: Vec<f32> = input
.data()
.iter()
.map(|&x| {
if rng.random::<f32>() < self.p {
a * alpha_p + b
} else {
a * x + b
}
})
.collect();
Tensor::new(&data, input.shape())
}
fn train(&mut self) {
self.training = true;
}
fn eval(&mut self) {
self.training = false;
}
fn training(&self) -> bool {
self.training
}
}
impl std::fmt::Debug for AlphaDropout {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("AlphaDropout")
.field("p", &self.p)
.field("training", &self.training)
.finish_non_exhaustive()
}
}
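// A sketch of `AlphaDropout`'s forward mapping: every output element should be
// either the affine image `a * x + b` of its input or the constant
// `a * alpha_p + b` used for dropped positions. The coefficients are recomputed
// here exactly as in `forward`; the `Tensor` API assumptions are the same as in
// the example modules above.
#[cfg(test)]
mod alpha_dropout_example {
    use super::*;

    #[test]
    fn output_is_affine_image_or_dropped_constant() {
        let p = 0.3_f32;
        let layer = AlphaDropout::with_seed(p, 11);
        let values = [0.5_f32, -1.0, 2.0, 0.0, -0.25, 1.5];
        let input = Tensor::new(&values, &[2, 3]);
        let output = layer.forward(&input);
        // Same coefficients as in `AlphaDropout::forward`.
        let alpha_p = -ALPHA * SCALE;
        let a = ((1.0 - p) * (1.0 + p * alpha_p * alpha_p)).powf(-0.5);
        let b = -a * alpha_p * p;
        for (&y, &x) in output.data().iter().zip(values.iter()) {
            let kept = (y - (a * x + b)).abs() < 1e-5;
            let dropped = (y - (a * alpha_p + b)).abs() < 1e-5;
            assert!(kept || dropped);
        }
    }

    #[test]
    fn eval_mode_is_identity() {
        let mut layer = AlphaDropout::with_seed(0.3, 11);
        layer.eval();
        let values = [0.5_f32, -1.0, 2.0, 0.0];
        let input = Tensor::new(&values, &[2, 2]);
        let output = layer.forward(&input);
        for (&y, &x) in output.data().iter().zip(values.iter()) {
            assert!((y - x).abs() < f32::EPSILON);
        }
    }
}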
/// Structured dropout that zeroes contiguous `block_size` x `block_size`
/// regions of a feature map with probability `p` (DropBlock).
pub struct DropBlock {
block_size: usize,
p: f32,
training: bool,
rng: Mutex<StdRng>,
}
mod drop_connect;
pub use drop_connect::*;