use scirs2_core::ndarray::{Array, Dimension, ScalarOperand};
use scirs2_core::numeric::{Float, FromPrimitive};
use std::fmt::Debug;
use crate::error::Result;
use crate::regularizers::Regularizer;
/// Stochastic-depth settings for a single layer position in a stack of layers.
///
/// The survival probability of a layer decreases linearly with its depth:
/// `1 - drop_prob * layer_idx / num_layers`.
#[derive(Debug, Clone)]
pub struct StochasticDepth<A>
where
    A: Float,
{
    /// Drop probability applied at full depth (`layer_idx == num_layers`).
    drop_prob: A,
    /// Index of the layer these settings currently describe.
    layer_idx: usize,
    /// Total number of layers, used to normalise `layer_idx` into a depth ratio.
    num_layers: usize,
    /// Seed mixed with the layer index to derive the deterministic drop decision.
    rng_state: u64,
}
impl<A: Float + Debug + ScalarOperand + FromPrimitive + Send + Sync> StochasticDepth<A> {
pub fn new(drop_prob: A, layer_idx: usize, numlayers: usize) -> Self {
Self {
drop_prob,
layer_idx,
num_layers: numlayers,
rng_state: 0,
}
}
pub fn set_layer(&mut self, layeridx: usize) {
self.layer_idx = layeridx;
}
pub fn set_rng_state(&mut self, state: u64) {
self.rng_state = state;
}
fn survival_probability(&self) -> A {
let layer_ratio = A::from_usize(self.layer_idx).expect("unwrap failed")
/ A::from_usize(self.num_layers).expect("unwrap failed");
A::one() - (self.drop_prob * layer_ratio)
}
fn should_drop(&self) -> bool {
let hash = (self
.rng_state
.wrapping_mul(0x7fffffff)
.wrapping_add(self.layer_idx as u64))
% 10000;
let random_val = A::from_f64(hash as f64 / 10000.0).expect("unwrap failed");
random_val > self.survival_probability()
}
pub fn apply_layer<D>(
&self,
layer_idx: usize,
features: &Array<A, D>,
training: bool,
) -> Array<A, D>
where
D: Dimension,
{
let survival_prob = self.survival_probability();
if training {
let mut sd = self.clone();
sd.set_layer(layer_idx);
if sd.should_drop() {
features.clone()
} else {
features.clone()
}
} else {
features * survival_prob
}
}
}
impl<A: Float + Debug + ScalarOperand + FromPrimitive, D: Dimension + Send + Sync> Regularizer<A, D>
    for StochasticDepth<A>
{
    /// No-op: leaves `_gradients` untouched and always reports a zero penalty.
    fn apply(&self, _params: &Array<A, D>, _gradients: &mut Array<A, D>) -> Result<A> {
        Ok(A::zero())
    }

    /// Always returns a zero penalty; the parameters are not inspected.
    fn penalty(&self, _params: &Array<A, D>) -> Result<A> {
        Ok(A::zero())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// The constructor stores its arguments verbatim.
    #[test]
    fn test_stochastic_depth_creation() {
        let sd = StochasticDepth::<f64>::new(0.2, 5, 10);
        assert_eq!(sd.drop_prob, 0.2);
        assert_eq!(sd.layer_idx, 5);
        assert_eq!(sd.num_layers, 10);
    }

    /// Survival probability decays linearly from 1 at depth 0 to
    /// `1 - drop_prob` at full depth.
    #[test]
    fn test_survival_probability() {
        let sd1 = StochasticDepth::<f64>::new(0.5, 0, 10);
        assert_eq!(sd1.survival_probability(), 1.0);
        let sd2 = StochasticDepth::<f64>::new(0.5, 10, 10);
        assert_eq!(sd2.survival_probability(), 0.5);
        let sd3 = StochasticDepth::<f64>::new(0.5, 5, 10);
        assert_eq!(sd3.survival_probability(), 0.75);
    }

    /// The drop decision is a deterministic function of the RNG state and layer.
    #[test]
    fn test_should_drop() {
        let mut sd = StochasticDepth::<f64>::new(0.5, 5, 10);

        // (12345 * 0x7fffffff + 5) % 10000 == 2220 -> 0.2220 <= 0.75 -> kept.
        sd.set_rng_state(12345);
        let result1 = sd.should_drop();
        assert!(!result1);
        assert_eq!(result1, sd.should_drop());

        // (54321 * 0x7fffffff + 5) % 10000 == 8692 -> 0.8692 > 0.75 -> dropped.
        sd.set_rng_state(54321);
        let result2 = sd.should_drop();
        assert!(result2);
        assert_eq!(result2, sd.should_drop());
    }

    /// A kept layer passes its features through unchanged during training.
    #[test]
    fn test_apply_layer_training() {
        let sd = StochasticDepth::<f64>::new(0.5, 5, 10);
        let features = array![[1.0, 2.0], [3.0, 4.0]];
        let output = sd.apply_layer(5, &features, true);
        assert_eq!(output.shape(), features.shape());
        // With the default RNG state the draw for layer 5 is 0.0005, which is
        // below the survival probability of 0.75, so the layer is kept.
        assert_eq!(output, features);
    }

    /// Inference scales every feature by the survival probability.
    #[test]
    fn test_apply_layer_inference() {
        let sd = StochasticDepth::<f64>::new(0.5, 5, 10);
        let features = array![[1.0, 2.0], [3.0, 4.0]];
        let output = sd.apply_layer(5, &features, false);
        let survival_prob = sd.survival_probability();
        for (i, j) in output.indexed_iter() {
            assert_eq!(*j, features[i] * survival_prob);
        }
    }

    /// The `Regularizer` implementation is a no-op with zero penalty.
    #[test]
    fn test_regularizer_trait() {
        let sd = StochasticDepth::<f64>::new(0.5, 5, 10);
        let params = array![[1.0, 2.0], [3.0, 4.0]];
        let mut gradients = array![[0.1, 0.2], [0.3, 0.4]];
        let original_gradients = gradients.clone();
        let penalty = sd.apply(&params, &mut gradients).expect("unwrap failed");
        assert_eq!(penalty, 0.0);
        assert_eq!(gradients, original_gradients);
    }
}