use crate::error::Result;
use crate::nn::module::TrainMode;
use numr::autograd::Var;
use numr::ops::{BinaryOps, RandomOps, ScalarOps, TensorOps};
use numr::runtime::{Runtime, RuntimeClient};
/// Stochastic depth ("DropPath"): during training, zeroes entire samples
/// along the leading axis at random and rescales the survivors so the
/// expected activation is unchanged. Identity in eval mode.
pub struct StochasticDepth {
    // Probability in [0, 1] of dropping a sample; validated in `new`.
    drop_prob: f64,
    // Training-mode flag; dropping is active only while `true`.
    training: bool,
}
impl StochasticDepth {
    /// Builds a stochastic-depth layer that drops whole samples with
    /// probability `drop_prob` while training. Starts in training mode.
    ///
    /// # Panics
    /// Panics when `drop_prob` lies outside `[0, 1]`.
    pub fn new(drop_prob: f64) -> Self {
        assert!(
            (0.0..=1.0).contains(&drop_prob),
            "drop probability must be in [0, 1], got {drop_prob}"
        );
        Self {
            drop_prob,
            training: true,
        }
    }

    /// Returns the configured drop probability.
    pub fn drop_prob(&self) -> f64 {
        self.drop_prob
    }

    /// Applies stochastic depth to `input`.
    ///
    /// Eval mode (or `drop_prob == 0`) is a pass-through. In training mode,
    /// each sample along the leading axis is independently zeroed with
    /// probability `drop_prob`, and kept samples are scaled by
    /// `1 / (1 - drop_prob)` (inverted-dropout scaling).
    ///
    /// # Errors
    /// Propagates any error raised by the runtime client's tensor ops.
    pub fn forward<R, C>(&self, client: &C, input: &Var<R>) -> Result<Var<R>>
    where
        R: Runtime<DType = numr::dtype::DType>,
        C: RuntimeClient<R> + TensorOps<R> + RandomOps<R> + ScalarOps<R> + BinaryOps<R>,
        R::Client: TensorOps<R> + ScalarOps<R> + BinaryOps<R>,
    {
        // Pass-through when inactive. This guard MUST come before the
        // p == 1 branch so eval mode is always an identity.
        if !self.training || self.drop_prob == 0.0 {
            return Ok(input.clone());
        }
        // `new` caps drop_prob at 1.0, so reaching here with >= 1.0 means
        // exactly 1: every sample is dropped and the output is all zeros.
        if self.drop_prob >= 1.0 {
            let src = input.tensor();
            let all_zero =
                numr::tensor::Tensor::<R>::zeros(src.shape(), src.dtype(), src.device());
            return Ok(Var::new(all_zero, false));
        }
        // Per-sample mask shape: [batch, 1, 1, ...] so broadcasting keeps
        // or drops each sample as a whole. Empty (scalar) shapes yield an
        // empty mask shape, matching the input.
        let mask_shape: Vec<usize> = input
            .tensor()
            .shape()
            .iter()
            .enumerate()
            .map(|(axis, &dim)| if axis == 0 { dim } else { 1 })
            .collect();
        let retain = 1.0 - self.drop_prob;
        let keep_mask = client.bernoulli(retain, &mask_shape, input.tensor().dtype())?;
        // Scale surviving samples up so the expected value is preserved.
        let scaled_mask = client.mul_scalar(&keep_mask, 1.0 / retain)?;
        let masked = client.mul(input.tensor(), &scaled_mask)?;
        // NOTE(review): the output Var is constructed with
        // requires_grad = false, so gradients do not appear to flow through
        // the training path — confirm this matches the intended autograd
        // integration for this layer.
        Ok(Var::new(masked, false))
    }
}
/// Train/eval switching for [`StochasticDepth`]; sample dropping happens
/// only while `is_training()` reports `true`.
impl TrainMode for StochasticDepth {
    fn is_training(&self) -> bool {
        self.training
    }

    fn set_training(&mut self, training: bool) {
        self.training = training;
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use numr::runtime::cpu::{CpuDevice, CpuRuntime};
    use numr::tensor::Tensor;

    /// Eval mode must be an exact identity regardless of drop probability.
    #[test]
    fn test_eval_mode_is_identity() {
        let device = CpuDevice::new();
        let client = CpuRuntime::default_client(&device);
        let values = vec![1.0f32; 12];
        let x = Var::new(
            Tensor::<CpuRuntime>::from_slice(&values, &[3, 4], &device),
            false,
        );
        let mut layer = StochasticDepth::new(0.5);
        layer.set_training(false);
        let result: Vec<f32> = layer.forward(&client, &x).unwrap().tensor().to_vec();
        assert_eq!(result, values);
    }

    /// A zero drop probability is a pass-through even in training mode.
    #[test]
    fn test_zero_drop_prob_is_identity() {
        let device = CpuDevice::new();
        let client = CpuRuntime::default_client(&device);
        let values = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        let x = Var::new(
            Tensor::<CpuRuntime>::from_slice(&values, &[2, 3], &device),
            false,
        );
        let layer = StochasticDepth::new(0.0);
        let result: Vec<f32> = layer.forward(&client, &x).unwrap().tensor().to_vec();
        assert_eq!(result, values);
    }

    /// With p = 0.5 each row is either all zero or all 2.0 (rescaled), and
    /// the dropped fraction is statistically near one half.
    #[test]
    fn test_drops_entire_samples() {
        let device = CpuDevice::new();
        let client = CpuRuntime::default_client(&device);
        let ones = vec![1.0f32; 400];
        let x = Var::new(
            Tensor::<CpuRuntime>::from_slice(&ones, &[100, 4], &device),
            false,
        );
        let layer = StochasticDepth::new(0.5);
        let out: Vec<f32> = layer.forward(&client, &x).unwrap().tensor().to_vec();
        let mut dropped = 0;
        let mut kept = 0;
        // Walk row by row: a dropped sample is entirely zero; a kept one is
        // uniformly scaled by 1 / 0.5 = 2.
        for (sample, row) in out.chunks(4).enumerate() {
            if row.iter().all(|&v| v == 0.0) {
                dropped += 1;
            } else {
                assert!(
                    row.iter().all(|&v| (v - 2.0).abs() < 1e-5),
                    "sample {sample} has inconsistent values: {row:?}"
                );
                kept += 1;
            }
        }
        assert_eq!(dropped + kept, 100);
        // Loose bounds: ~50 expected, allow wide slack for randomness.
        assert!(dropped > 20 && dropped < 80, "dropped: {dropped}");
    }

    /// Probabilities outside [0, 1] must be rejected at construction.
    #[test]
    #[should_panic(expected = "drop probability must be in [0, 1]")]
    fn test_invalid_prob() {
        StochasticDepth::new(1.5);
    }
}