use crate::{TrainError, TrainResult};
use scirs2_core::ndarray::{Array2, ArrayView2};
use scirs2_core::random::{RngExt, StdRng};
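/// DropPath (stochastic depth) regularization.
///
/// During training, an entire residual path is dropped with probability
/// `drop_prob`; paths that survive are rescaled by `1 / keep_prob` so the
/// expected activation magnitude is preserved. At inference the input passes
/// through unchanged.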
#[derive(Debug, Clone)]
pub struct DropPath {
pub drop_prob: f64,
keep_prob: f64,
}
impl DropPath {
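/// Creates a new `DropPath` with the given drop probability.
///
/// Returns an error if `drop_prob` lies outside `[0, 1]`.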
pub fn new(drop_prob: f64) -> TrainResult<Self> {
if !(0.0..=1.0).contains(&drop_prob) {
return Err(TrainError::InvalidParameter(
"drop_prob must be in [0, 1]".to_string(),
));
}
Ok(Self {
drop_prob,
keep_prob: 1.0 - drop_prob,
})
}
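/// Applies DropPath to a single path.
///
/// In training mode the whole path is zeroed with probability `drop_prob`;
/// otherwise it is scaled by `1 / keep_prob`. In inference mode (or when
/// `drop_prob == 0`) the path is returned unchanged.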
pub fn apply(
&self,
path: &ArrayView2<f64>,
training: bool,
rng: &mut StdRng,
) -> TrainResult<Array2<f64>> {
if !training || self.drop_prob == 0.0 {
return Ok(path.to_owned());
}
if self.drop_prob == 1.0 {
return Ok(Array2::zeros(path.raw_dim()));
}
let should_drop = rng.random::<f64>() < self.drop_prob;
if should_drop {
Ok(Array2::zeros(path.raw_dim()))
} else {
Ok(path.mapv(|x| x / self.keep_prob))
}
}
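/// Applies DropPath independently to each row of a batch.
///
/// Every row (sample) gets its own random draw, so different samples in the
/// same batch can be dropped or kept independently.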
pub fn apply_batch(
&self,
paths: &ArrayView2<f64>,
training: bool,
rng: &mut StdRng,
) -> TrainResult<Array2<f64>> {
if !training || self.drop_prob == 0.0 {
return Ok(paths.to_owned());
}
let (batch_size, _) = paths.dim();
let mut output = paths.to_owned();
for i in 0..batch_size {
let should_drop = rng.random::<f64>() < self.drop_prob;
if should_drop {
for j in 0..output.ncols() {
output[[i, j]] = 0.0;
}
} else {
for j in 0..output.ncols() {
output[[i, j]] /= self.keep_prob;
}
}
}
Ok(output)
}
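/// Returns the keep probability, i.e. `1.0 - drop_prob`.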
pub fn keep_probability(&self) -> f64 {
self.keep_prob
}
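/// Updates the drop probability, keeping `keep_prob` in sync.
///
/// Returns an error if `drop_prob` lies outside `[0, 1]`.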
pub fn set_drop_prob(&mut self, drop_prob: f64) -> TrainResult<()> {
if !(0.0..=1.0).contains(&drop_prob) {
return Err(TrainError::InvalidParameter(
"drop_prob must be in [0, 1]".to_string(),
));
}
self.drop_prob = drop_prob;
self.keep_prob = 1.0 - drop_prob;
Ok(())
}
}
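/// Linear stochastic depth schedule.
///
/// Interpolates the drop probability linearly from `drop_prob_min` at the
/// first layer to `drop_prob_max` at the last layer, so deeper layers are
/// dropped more aggressively.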
#[derive(Debug, Clone)]
pub struct LinearStochasticDepth {
pub num_layers: usize,
pub drop_prob_min: f64,
pub drop_prob_max: f64,
}
impl LinearStochasticDepth {
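/// Creates a linear schedule over `num_layers` layers.
///
/// Returns an error if `num_layers` is zero, if either probability lies
/// outside `[0, 1]`, or if `drop_prob_min > drop_prob_max`.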
pub fn new(num_layers: usize, drop_prob_min: f64, drop_prob_max: f64) -> TrainResult<Self> {
if num_layers == 0 {
return Err(TrainError::InvalidParameter(
"num_layers must be > 0".to_string(),
));
}
if !(0.0..=1.0).contains(&drop_prob_min) || !(0.0..=1.0).contains(&drop_prob_max) {
return Err(TrainError::InvalidParameter(
"drop probabilities must be in [0, 1]".to_string(),
));
}
if drop_prob_min > drop_prob_max {
return Err(TrainError::InvalidParameter(
"drop_prob_min must be <= drop_prob_max".to_string(),
));
}
Ok(Self {
num_layers,
drop_prob_min,
drop_prob_max,
})
}
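/// Returns the drop probability for layer `layer_idx`.
///
/// Indices past the last layer are clamped to `drop_prob_max`; a
/// single-layer schedule always returns `drop_prob_min`.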
pub fn get_drop_prob(&self, layer_idx: usize) -> f64 {
if layer_idx >= self.num_layers {
return self.drop_prob_max;
}
if self.num_layers == 1 {
return self.drop_prob_min;
}
let ratio = layer_idx as f64 / (self.num_layers - 1) as f64;
self.drop_prob_min + (self.drop_prob_max - self.drop_prob_min) * ratio
}
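/// Builds one `DropPath` per layer using the scheduled drop probabilities.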
pub fn create_drop_paths(&self) -> TrainResult<Vec<DropPath>> {
let mut drop_paths = Vec::with_capacity(self.num_layers);
for i in 0..self.num_layers {
let drop_prob = self.get_drop_prob(i);
drop_paths.push(DropPath::new(drop_prob)?);
}
Ok(drop_paths)
}
}
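/// Exponential-style stochastic depth schedule.
///
/// Uses the squared layer ratio, so early layers stay close to
/// `drop_prob_min` and the drop probability ramps up mainly in the deepest
/// layers, sitting below the equivalent linear schedule in between.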
#[derive(Debug, Clone)]
pub struct ExponentialStochasticDepth {
pub num_layers: usize,
pub drop_prob_min: f64,
pub drop_prob_max: f64,
}
impl ExponentialStochasticDepth {
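/// Creates a schedule over `num_layers` layers.
///
/// Returns an error if `num_layers` is zero, if either probability lies
/// outside `[0, 1]`, or if `drop_prob_min > drop_prob_max`.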
pub fn new(num_layers: usize, drop_prob_min: f64, drop_prob_max: f64) -> TrainResult<Self> {
if num_layers == 0 {
return Err(TrainError::InvalidParameter(
"num_layers must be > 0".to_string(),
));
}
if !(0.0..=1.0).contains(&drop_prob_min) || !(0.0..=1.0).contains(&drop_prob_max) {
return Err(TrainError::InvalidParameter(
"drop probabilities must be in [0, 1]".to_string(),
));
}
if drop_prob_min > drop_prob_max {
return Err(TrainError::InvalidParameter(
"drop_prob_min must be <= drop_prob_max".to_string(),
));
}
Ok(Self {
num_layers,
drop_prob_min,
drop_prob_max,
})
}
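/// Returns the drop probability for layer `layer_idx`.
///
/// The interpolation uses the squared layer ratio, so mid-network values
/// fall below those of a linear schedule over the same range.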
pub fn get_drop_prob(&self, layer_idx: usize) -> f64 {
if layer_idx >= self.num_layers {
return self.drop_prob_max;
}
if self.num_layers == 1 {
return self.drop_prob_min;
}
let ratio = layer_idx as f64 / (self.num_layers - 1) as f64;
let exp_ratio = ratio * ratio;
self.drop_prob_min + (self.drop_prob_max - self.drop_prob_min) * exp_ratio
}
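/// Builds one `DropPath` per layer using the scheduled drop probabilities.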
pub fn create_drop_paths(&self) -> TrainResult<Vec<DropPath>> {
let mut drop_paths = Vec::with_capacity(self.num_layers);
for i in 0..self.num_layers {
let drop_prob = self.get_drop_prob(i);
drop_paths.push(DropPath::new(drop_prob)?);
}
Ok(drop_paths)
}
}
#[cfg(test)]
mod tests {
use super::*;
use scirs2_core::ndarray::array;
use scirs2_core::random::SeedableRng;
fn create_test_rng() -> StdRng {
StdRng::seed_from_u64(42)
}
#[test]
fn test_drop_path_creation() {
let dp = DropPath::new(0.2).expect("unwrap");
assert_eq!(dp.drop_prob, 0.2);
assert!((dp.keep_prob - 0.8).abs() < 1e-10);
}
#[test]
fn test_drop_path_invalid_prob() {
assert!(DropPath::new(-0.1).is_err());
assert!(DropPath::new(1.5).is_err());
}
#[test]
fn test_drop_path_zero_prob() {
let dp = DropPath::new(0.0).expect("unwrap");
let mut rng = create_test_rng();
let path = array![[1.0, 2.0], [3.0, 4.0]];
let output = dp.apply(&path.view(), true, &mut rng).expect("unwrap");
assert_eq!(output, path);
}
#[test]
fn test_drop_path_full_prob() {
let dp = DropPath::new(1.0).expect("unwrap");
let mut rng = create_test_rng();
let path = array![[1.0, 2.0], [3.0, 4.0]];
let output = dp.apply(&path.view(), true, &mut rng).expect("unwrap");
assert_eq!(output, Array2::<f64>::zeros((2, 2)));
}
#[test]
fn test_drop_path_inference_mode() {
let dp = DropPath::new(0.5).expect("unwrap");
let mut rng = create_test_rng();
let path = array![[1.0, 2.0], [3.0, 4.0]];
let output = dp.apply(&path.view(), false, &mut rng).expect("unwrap");
assert_eq!(output, path);
}
#[test]
fn test_drop_path_training_mode() {
let dp = DropPath::new(0.5).expect("unwrap");
let mut rng = create_test_rng();
let path = array![[1.0, 2.0]];
let mut dropped_count = 0;
let mut kept_count = 0;
for _ in 0..100 {
let output = dp.apply(&path.view(), true, &mut rng).expect("unwrap");
if output[[0, 0]] == 0.0 {
dropped_count += 1;
} else {
kept_count += 1;
assert!((output[[0, 0]] - 2.0).abs() < 1e-10);
}
}
assert!(dropped_count > 30 && dropped_count < 70);
assert!(kept_count > 30 && kept_count < 70);
}
#[test]
fn test_drop_path_batch() {
let dp = DropPath::new(0.5).expect("unwrap");
let mut rng = create_test_rng();
let paths = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];
let output = dp
.apply_batch(&paths.view(), true, &mut rng)
.expect("unwrap");
assert_eq!(output.shape(), paths.shape());
let mut dropped_rows = 0;
for i in 0..output.nrows() {
if output[[i, 0]] == 0.0 && output[[i, 1]] == 0.0 {
dropped_rows += 1;
}
}
assert!(dropped_rows > 0);
}
#[test]
fn test_drop_path_set_prob() {
let mut dp = DropPath::new(0.2).expect("unwrap");
assert_eq!(dp.drop_prob, 0.2);
dp.set_drop_prob(0.5).expect("unwrap");
assert_eq!(dp.drop_prob, 0.5);
assert!((dp.keep_prob - 0.5).abs() < 1e-10);
assert!(dp.set_drop_prob(1.5).is_err());
}
#[test]
fn test_linear_stochastic_depth_creation() {
let scheduler = LinearStochasticDepth::new(10, 0.0, 0.5).expect("unwrap");
assert_eq!(scheduler.num_layers, 10);
assert_eq!(scheduler.drop_prob_min, 0.0);
assert_eq!(scheduler.drop_prob_max, 0.5);
}
#[test]
fn test_linear_stochastic_depth_invalid() {
assert!(LinearStochasticDepth::new(0, 0.0, 0.5).is_err());
assert!(LinearStochasticDepth::new(10, -0.1, 0.5).is_err());
assert!(LinearStochasticDepth::new(10, 0.0, 1.5).is_err());
assert!(LinearStochasticDepth::new(10, 0.6, 0.3).is_err());
}
#[test]
fn test_linear_stochastic_depth_interpolation() {
let scheduler = LinearStochasticDepth::new(10, 0.0, 0.9).expect("unwrap");
assert!((scheduler.get_drop_prob(0) - 0.0).abs() < 1e-10);
assert!((scheduler.get_drop_prob(5) - 0.5).abs() < 1e-6);
assert!((scheduler.get_drop_prob(9) - 0.9).abs() < 1e-10);
}
#[test]
fn test_linear_stochastic_depth_create_paths() {
let scheduler = LinearStochasticDepth::new(5, 0.0, 0.4).expect("unwrap");
let paths = scheduler.create_drop_paths().expect("unwrap");
assert_eq!(paths.len(), 5);
assert!((paths[0].drop_prob - 0.0).abs() < 1e-10);
assert!((paths[2].drop_prob - 0.2).abs() < 1e-10);
assert!((paths[4].drop_prob - 0.4).abs() < 1e-10);
}
#[test]
fn test_exponential_stochastic_depth() {
let scheduler = ExponentialStochasticDepth::new(10, 0.0, 0.8).expect("unwrap");
assert!((scheduler.get_drop_prob(0) - 0.0).abs() < 1e-10);
assert!((scheduler.get_drop_prob(9) - 0.8).abs() < 1e-10);
// The quadratic curve should put the mid-network probability strictly below
// the midpoint of a linear schedule over the same range.
let mid_prob = scheduler.get_drop_prob(5);
let linear_mid = 0.4;
assert!(mid_prob < linear_mid);
}
#[test]
fn test_exponential_create_paths() {
let scheduler = ExponentialStochasticDepth::new(5, 0.0, 0.4).expect("unwrap");
let paths = scheduler.create_drop_paths().expect("unwrap");
assert_eq!(paths.len(), 5);
for i in 0..paths.len() - 1 {
assert!(paths[i].drop_prob <= paths[i + 1].drop_prob);
}
}
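// Sketch of typical end-to-end usage (added as an illustrative example): a
// linear schedule drives one DropPath per layer, applied to a small batch the
// way a residual stack would during training.
#[test]
fn test_schedule_drives_per_layer_drop_paths() {
let scheduler = LinearStochasticDepth::new(3, 0.0, 0.2).expect("unwrap");
let drop_paths = scheduler.create_drop_paths().expect("unwrap");
let mut rng = create_test_rng();
let mut activations = array![[1.0, 2.0], [3.0, 4.0]];
for dp in &drop_paths {
// Each layer either zeroes a row or rescales it by 1 / keep_prob.
let next = dp
.apply_batch(&activations.view(), true, &mut rng)
.expect("unwrap");
activations = next;
}
// Shape is preserved regardless of which rows were dropped or rescaled.
assert_eq!(activations.shape(), &[2, 2]);
// The first layer of the schedule never drops anything.
assert!(drop_paths[0].drop_prob.abs() < 1e-10);
}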
}