use burn::{
config::Config,
module::Module,
prelude::{
Backend,
Tensor,
},
tensor::Distribution,
};
use crate::utility::probability;
/// Applies drop-path to `x` using a randomly sampled Bernoulli mask.
///
/// This is the public entry point; all of the masking and scaling logic lives
/// in `_drop_path_sample`, which is handed a sampler that draws the mask from
/// `Distribution::Bernoulli(keep_prob)` on the device of `x`.
///
/// # Arguments
///
/// * `x` - the input tensor.
/// * `drop_prob` - the probability of dropping; `keep_prob` is `1.0 - drop_prob`.
/// * `training` - when `false`, `x` is returned unchanged.
/// * `scale_by_keep` - when `true` (and `keep_prob > 0.0`), the mask is divided
///   by `keep_prob` before being applied.
///
/// # Returns
///
/// A tensor with the same type and rank as `x`.
#[must_use]
pub fn drop_path<B: Backend, const D: usize>(
    x: Tensor<B, D>,
    drop_prob: f64,
    training: bool,
    scale_by_keep: bool,
) -> Tensor<B, D> {
    /// Draws a Bernoulli(`keep_prob`) mask of the requested shape.
    fn bernoulli_mask<Bk: Backend, const N: usize>(
        mask_shape: [usize; N],
        keep_prob: f64,
        device: &Bk::Device,
    ) -> Tensor<Bk, N> {
        Tensor::<Bk, N>::random(mask_shape, Distribution::Bernoulli(keep_prob), device)
    }

    _drop_path_sample(
        x,
        drop_prob,
        training,
        scale_by_keep,
        bernoulli_mask::<B, D>,
    )
}
/// Core of drop-path, parameterised over how the mask is sampled.
///
/// The mask has the size of `x` along dimension 0 and size 1 along every
/// other dimension, so one mask value is multiplied into each slice of `x`
/// along the first dimension (presumably the batch axis).
///
/// # Arguments
///
/// * `x` - the input tensor.
/// * `drop_prob` - the drop probability; it is passed through
///   `probability::expect_probability` first (defined elsewhere in the crate;
///   presumably it rejects values that are not valid probabilities).
/// * `training` - when `false`, `x` is returned unchanged.
/// * `scale_by_keep` - when `true` and `keep_prob > 0.0`, the mask is divided
///   by `keep_prob`.
/// * `sample` - called as `sample(mask_shape, keep_prob, device)` to produce the
///   mask; it is not called at all when `x` is returned unchanged.
///
/// # Returns
///
/// `x` itself when `training` is `false` or `drop_prob == 0.0`; otherwise `x`
/// multiplied by the (optionally rescaled) mask.
#[inline(always)]
#[must_use]
fn _drop_path_sample<B: Backend, const D: usize>(
    x: Tensor<B, D>,
    drop_prob: f64,
    training: bool,
    scale_by_keep: bool,
    sample: fn([usize; D], f64, &B::Device) -> Tensor<B, D>,
) -> Tensor<B, D> {
    probability::expect_probability(drop_prob);

    // Nothing to do outside of training or when nothing can be dropped.
    let dropping_active = training && drop_prob != 0.0;
    if !dropping_active {
        return x;
    }

    let keep_prob = 1.0 - drop_prob;

    // One mask entry per index of the first dimension; all other dims are 1.
    let mut mask_shape = [1; D];
    mask_shape[0] = x.dims()[0];

    let mask = sample(mask_shape, keep_prob, &x.device());

    // The `keep_prob > 0.0` check avoids dividing by zero when drop_prob is 1.
    if scale_by_keep && keep_prob > 0.0 {
        x * mask.div_scalar(keep_prob)
    } else {
        x * mask
    }
}
/// Read-only access to the drop-path settings shared by the config and the module.
pub trait DropPathMeta {
    /// The probability of dropping.
    fn drop_prob(&self) -> f64;

    /// The probability of keeping, i.e. `1.0 - drop_prob()`.
    fn keep_prob(&self) -> f64 {
        let drop_prob = self.drop_prob();
        1.0 - drop_prob
    }

    /// Whether the mask is rescaled by the keep probability.
    fn scale_by_keep(&self) -> bool;
}
/// Configuration for the [`DropPath`] module.
#[derive(Config, Debug)]
pub struct DropPathConfig {
/// Probability of dropping; defaults to `0.0`.
#[config(default = 0.0)]
pub drop_prob: f64,
/// Whether the mask is divided by the keep probability; defaults to `true`.
#[config(default = true)]
pub scale_by_keep: bool,
}
impl DropPathMeta for DropPathConfig {
    /// Returns the configured drop probability.
    fn drop_prob(&self) -> f64 {
        let Self { drop_prob, .. } = self;
        *drop_prob
    }

    /// Returns the configured `scale_by_keep` flag.
    fn scale_by_keep(&self) -> bool {
        let Self { scale_by_keep, .. } = self;
        *scale_by_keep
    }
}
impl DropPathConfig {
    /// Builds a [`DropPath`] module from this configuration.
    ///
    /// The drop probability is passed through `probability::expect_probability`
    /// and the value it returns is what the module stores (the helper is defined
    /// elsewhere in the crate; presumably it rejects invalid probabilities).
    #[inline(always)]
    #[must_use]
    pub fn init(&self) -> DropPath {
        let drop_prob = probability::expect_probability(self.drop_prob);
        let scale_by_keep = self.scale_by_keep;

        DropPath {
            drop_prob,
            scale_by_keep,
        }
    }
}
/// Module that applies [`drop_path`] to its input; built by `DropPathConfig::init`.
#[derive(Module, Clone, Debug)]
pub struct DropPath {
/// Probability of dropping, forwarded to [`drop_path`].
pub drop_prob: f64,
/// Whether the mask is divided by the keep probability, forwarded to [`drop_path`].
pub scale_by_keep: bool,
}
impl DropPathMeta for DropPath {
    /// Returns the module's drop probability.
    fn drop_prob(&self) -> f64 {
        let Self { drop_prob, .. } = self;
        *drop_prob
    }

    /// Returns the module's `scale_by_keep` flag.
    fn scale_by_keep(&self) -> bool {
        let Self { scale_by_keep, .. } = self;
        *scale_by_keep
    }
}
impl DropPath {
    /// Applies drop-path to `input`.
    ///
    /// The `training` flag handed to [`drop_path`] is `B::ad_enabled()`, so the
    /// backend type decides whether any dropping can happen.
    ///
    /// # Returns
    ///
    /// A tensor with the same type and rank as `input`.
    #[must_use]
    pub fn forward<B: Backend, const D: usize>(
        &self,
        input: Tensor<B, D>,
    ) -> Tensor<B, D> {
        drop_path(
            input,
            self.drop_prob,
            B::ad_enabled(),
            self.scale_by_keep,
        )
    }

    /// Computes `x + forward(f(x))`, i.e. a skip connection around `f` whose
    /// branch output goes through drop-path.
    ///
    /// # Arguments
    ///
    /// * `x` - the input tensor; it is both fed to `f` and added to the result.
    /// * `f` - the branch function, called exactly once.
    #[inline]
    #[must_use]
    pub fn with_skip<B: Backend, const D: usize, F>(
        &self,
        x: Tensor<B, D>,
        f: F,
    ) -> Tensor<B, D>
    where
        F: FnOnce(Tensor<B, D>) -> Tensor<B, D>,
    {
        // Keep a handle on the input before `f` consumes it.
        let skip = x.clone();
        let branch = self.forward(f(x));
        skip + branch
    }
}
#[cfg(test)]
mod tests {
use burn::{
backend::NdArray,
prelude::Tensor,
tensor::Distribution,
};
use super::*;
/// A module built from a struct-literal config keeps the input dimensions.
///
/// Only the output dims are asserted; whether any dropping happens depends on
/// `NdArray::ad_enabled()`, which `forward` uses as the training flag.
#[test]
fn test_drop_path() {
    let device = Default::default();

    let module = DropPathConfig {
        drop_prob: 0.5,
        scale_by_keep: true,
    }
    .init();

    let input =
        Tensor::<NdArray, 4>::random([2, 3, 4, 5], Distribution::Uniform(0.0, 1.0), &device);
    let input_dims = input.dims();

    let output = module.forward(input);

    assert_eq!(input_dims, output.dims());
}
/// The public `drop_path` wrapper keeps the input dimensions when called with
/// `training = false` and `drop_prob = 0.0`.
#[test]
fn test_drop_path_wrapper() {
    let device = Default::default();
    let x = Tensor::<NdArray, 3>::random([3, 2, 4], Distribution::Uniform(0.0, 1.0), &device);

    // drop_prob = 0.0, training = false, scale_by_keep = false.
    let res = drop_path(x.clone(), 0.0, false, false);

    assert_eq!(res.dims(), x.dims());
}
/// Exercises `_drop_path_sample` with a fixed mask in place of random sampling.
///
/// Four scenarios are covered: not training, training with `drop_prob = 0.0`,
/// training with dropping but no rescale, and training with dropping and
/// rescale by `keep_prob`.
#[test]
fn test_drop_path_sample() {
    let device = Default::default();
    let x = Tensor::<NdArray, 3>::random([3, 2, 4], Distribution::Uniform(0.0, 1.0), &device);

    // The fixed mask every sampler below returns: keep index 0 and 2, drop index 1.
    let mask = Tensor::<NdArray, 3>::from_data([[[1.0]], [[0.0]], [[1.0]]], &device);

    // Scenario 1: not training -> input is returned untouched.
    {
        let out = _drop_path_sample(
            x.clone(),
            0.0,
            false,
            false,
            |mask_shape, p_keep, dev| {
                assert_eq!(mask_shape, [3, 1, 1]);
                assert_eq!(p_keep, 1.0);
                Tensor::<NdArray, 3>::from_data([[[1.0]], [[0.0]], [[1.0]]], dev)
            },
        );
        out.to_data().assert_eq(&x.to_data(), true);
    }

    // Scenario 2: training, but drop_prob == 0.0 -> input is returned untouched.
    {
        let out = _drop_path_sample(
            x.clone(),
            0.0,
            true,
            false,
            |mask_shape, p_keep, dev| {
                assert_eq!(mask_shape, [3, 1, 1]);
                assert_eq!(p_keep, 1.0);
                Tensor::<NdArray, 3>::from_data([[[1.0]], [[0.0]], [[1.0]]], dev)
            },
        );
        out.to_data().assert_eq(&x.to_data(), true);
    }

    // Scenario 3: training with dropping, no rescale -> plain masked input.
    {
        let out = _drop_path_sample(
            x.clone(),
            0.5,
            true,
            false,
            |mask_shape, p_keep, dev| {
                assert_eq!(mask_shape, [3, 1, 1]);
                assert_eq!(p_keep, 0.5);
                Tensor::<NdArray, 3>::from_data([[[1.0]], [[0.0]], [[1.0]]], dev)
            },
        );
        let expected = x.clone() * mask.clone();
        out.to_data().assert_eq(&expected.to_data(), true);
    }

    // Scenario 4: training with dropping and rescale -> masked input / keep_prob.
    {
        let drop_prob = 0.5;
        let keep_prob = 1.0 - drop_prob;
        let out = _drop_path_sample(
            x.clone(),
            drop_prob,
            true,
            true,
            |mask_shape, p_keep, dev| {
                assert_eq!(mask_shape, [3, 1, 1]);
                assert_eq!(p_keep, 0.5);
                Tensor::<NdArray, 3>::from_data([[[1.0]], [[0.0]], [[1.0]]], dev)
            },
        );
        let expected = (x.clone() * mask.clone()).div_scalar(keep_prob);
        out.to_data().assert_eq(&expected.to_data(), true);
    }
}
/// The config and the module built from it report the same drop-path settings,
/// and the module's `forward` keeps the input dimensions.
#[test]
fn test_droppath_module() {
    let drop_prob = 0.2;
    let expected_keep = 1.0 - drop_prob;

    // Settings as seen through the config.
    let config = DropPathConfig::new().with_drop_prob(drop_prob);
    assert_eq!(config.drop_prob(), 0.2);
    assert_eq!(config.keep_prob(), expected_keep);
    assert!(config.scale_by_keep());

    // Settings as seen through the initialised module.
    let module = config.init();
    assert_eq!(module.drop_prob(), 0.2);
    assert_eq!(module.keep_prob(), expected_keep);
    assert!(module.scale_by_keep());

    let device = Default::default();
    let x = Tensor::<NdArray, 3>::random([2, 3, 4], Distribution::Uniform(0.0, 1.0), &device);
    let x_dims = x.dims();

    let output = module.forward(x);

    assert_eq!(x_dims, output.dims());
}
}