use scirs2_core::ndarray::{Array, Axis, Dimension, Ix3, ScalarOperand};
use scirs2_core::numeric::Float;
use scirs2_core::random::{thread_rng, Rng};
use std::fmt::Debug;
use crate::error::{OptimError, Result};
use crate::regularizers::Regularizer;
/// Dropout regularizer that drops or keeps whole slices along one feature axis.
///
/// `A` is the floating-point element type of the arrays being masked.
#[derive(Debug, Clone)]
pub struct SpatialDropout<A: Float> {
/// Probability of dropping each slice along `feature_dim`; `new` rejects
/// values below 0 or above 1.
dropprob: A,
/// Axis whose indices are dropped or kept as a unit; `new` sets it to `Axis(1)`.
feature_dim: Axis,
}
impl<A: Float + Debug + ScalarOperand + Send + Sync> SpatialDropout<A> {
    /// Creates a spatial dropout regularizer with the given drop probability.
    ///
    /// The feature axis starts as `Axis(1)`; use `with_feature_dim` to change it.
    ///
    /// # Errors
    ///
    /// Returns `OptimError::InvalidConfig` when `dropprob` is NaN or lies
    /// outside the closed interval `[0, 1]`.
    pub fn new(dropprob: A) -> Result<Self> {
        // NaN must be rejected explicitly: both `<` and `>` are false for NaN,
        // so a plain range comparison would accept it and the value would later
        // be handed to `random_bool` in `apply` as a probability.
        if dropprob.is_nan() || dropprob < A::zero() || dropprob > A::one() {
            return Err(OptimError::InvalidConfig(
                "Drop probability must be between 0.0 and 1.0".to_string(),
            ));
        }
        Ok(Self {
            dropprob,
            feature_dim: Axis(1),
        })
    }

    /// Sets the axis along which whole slices are dropped and returns the
    /// updated regularizer (builder style).
    ///
    /// The axis is not validated here; it is only checked against the array
    /// passed to `apply`.
    pub fn with_feature_dim(mut self, dim: usize) -> Self {
        self.feature_dim = Axis(dim);
        self
    }

    /// Applies the dropout mask to `features` and returns the masked copy.
    ///
    /// When `training` is `false` or the drop probability is zero, an
    /// unmodified clone of `features` is returned. Otherwise one keep/drop
    /// decision is drawn per index along the feature axis and shared by every
    /// element of that slice: dropped slices are filled with zero, kept slices
    /// are divided by the keep probability (`1 - dropprob`).
    ///
    /// # Panics
    ///
    /// Panics if the configured feature axis is out of bounds for `features`,
    /// or if the keep probability cannot be converted to `f64`.
    pub fn apply<D>(&self, features: &Array<A, D>, training: bool) -> Array<A, D>
    where
        D: Dimension + scirs2_core::ndarray::RemoveAxis,
    {
        if !training || self.dropprob == A::zero() {
            return features.clone();
        }

        let keep_prob = A::one() - self.dropprob;
        let keep_prob_f64 = keep_prob.to_f64().expect("unwrap failed");
        let mut rng = thread_rng();

        let mut result = features.clone();
        // One random draw per index along the feature axis, taken in index
        // order, so no intermediate mask vector is needed.
        for mut lane in result.axis_iter_mut(self.feature_dim) {
            if rng.random_bool(keep_prob_f64) {
                lane.mapv_inplace(|x| x / keep_prob);
            } else {
                lane.fill(A::zero());
            }
        }
        result
    }
}
/// Dropout regularizer that drops or keeps entire features along one axis.
///
/// `A` is the floating-point element type of the arrays being masked.
#[derive(Debug, Clone)]
pub struct FeatureDropout<A: Float> {
/// Probability of dropping each feature; `new` rejects values below 0 or
/// above 1.
dropprob: A,
/// Axis that enumerates the features; `new` sets it to `Axis(1)`.
feature_dim: Axis,
}
impl<A: Float + Debug + ScalarOperand + Send + Sync> FeatureDropout<A> {
    /// Creates a feature dropout regularizer with the given drop probability.
    ///
    /// The feature axis starts as `Axis(1)`; use `with_feature_dim` to change it.
    ///
    /// # Errors
    ///
    /// Returns `OptimError::InvalidConfig` when `dropprob` is NaN or lies
    /// outside the closed interval `[0, 1]`.
    pub fn new(dropprob: A) -> Result<Self> {
        // NaN must be rejected explicitly: both `<` and `>` are false for NaN,
        // so a plain range comparison would accept it and the value would later
        // be handed to `random_bool` in `apply` as a probability.
        if dropprob.is_nan() || dropprob < A::zero() || dropprob > A::one() {
            return Err(OptimError::InvalidConfig(
                "Drop probability must be between 0.0 and 1.0".to_string(),
            ));
        }
        Ok(Self {
            dropprob,
            feature_dim: Axis(1),
        })
    }

    /// Sets the axis that enumerates the features and returns the updated
    /// regularizer (builder style).
    ///
    /// The axis is not validated here; it is only checked against the array
    /// passed to `apply`.
    pub fn with_feature_dim(mut self, dim: usize) -> Self {
        self.feature_dim = Axis(dim);
        self
    }

    /// Applies the dropout mask to `features` and returns the masked copy.
    ///
    /// When `training` is `false` or the drop probability is zero, an
    /// unmodified clone of `features` is returned. Otherwise one keep/drop
    /// decision is drawn per feature index and applied to every element that
    /// shares that index: dropped features are filled with zero, kept features
    /// are divided by the keep probability (`1 - dropprob`).
    ///
    /// # Panics
    ///
    /// Panics if the configured feature axis is out of bounds for `features`,
    /// or if the keep probability cannot be converted to `f64`.
    pub fn apply<D>(&self, features: &Array<A, D>, training: bool) -> Array<A, D>
    where
        D: Dimension + scirs2_core::ndarray::RemoveAxis,
    {
        if !training || self.dropprob == A::zero() {
            return features.clone();
        }

        let keep_prob = A::one() - self.dropprob;
        let keep_prob_f64 = keep_prob.to_f64().expect("unwrap failed");
        let mut rng = thread_rng();

        let mut result = features.clone();
        // One random draw per feature index, taken in index order, so no
        // intermediate mask vector is needed.
        for mut feature in result.axis_iter_mut(self.feature_dim) {
            if rng.random_bool(keep_prob_f64) {
                feature.mapv_inplace(|x| x / keep_prob);
            } else {
                feature.fill(A::zero());
            }
        }
        result
    }
}
impl<
        A: Float + Debug + ScalarOperand + Send + Sync,
        D: Dimension + scirs2_core::ndarray::RemoveAxis + Send + Sync,
    > Regularizer<A, D> for SpatialDropout<A>
{
    /// Replaces `gradients` with a spatially masked copy of itself, always in
    /// training mode, and reports a zero penalty.
    ///
    /// The parameter values are not consulted; dropout only acts on the
    /// gradients.
    fn apply(&self, _params: &Array<A, D>, gradients: &mut Array<A, D>) -> Result<A> {
        let masked = SpatialDropout::apply(self, gradients, true);
        *gradients = masked;
        Ok(A::zero())
    }

    /// Dropout adds no penalty term, so this always returns zero.
    fn penalty(&self, _params: &Array<A, D>) -> Result<A> {
        Ok(A::zero())
    }
}
impl<
        A: Float + Debug + ScalarOperand + Send + Sync,
        D: Dimension + scirs2_core::ndarray::RemoveAxis + Send + Sync,
    > Regularizer<A, D> for FeatureDropout<A>
{
    /// Replaces `gradients` with a feature-masked copy of itself, always in
    /// training mode, and reports a zero penalty.
    ///
    /// The parameter values are not consulted; dropout only acts on the
    /// gradients.
    fn apply(&self, _params: &Array<A, D>, gradients: &mut Array<A, D>) -> Result<A> {
        let masked = FeatureDropout::apply(self, gradients, true);
        *gradients = masked;
        Ok(A::zero())
    }

    /// Dropout adds no penalty term, so this always returns zero.
    fn penalty(&self, _params: &Array<A, D>) -> Result<A> {
        Ok(A::zero())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::array;

    /// `new` stores an in-range probability and rejects out-of-range ones.
    #[test]
    fn test_spatial_dropout_creation() {
        let sd = SpatialDropout::<f64>::new(0.3).expect("unwrap failed");
        assert_eq!(sd.dropprob, 0.3);
        assert!(SpatialDropout::<f64>::new(-0.1).is_err());
        assert!(SpatialDropout::<f64>::new(1.1).is_err());
    }

    /// Every (batch, channel) plane must be either fully zeroed or fully
    /// scaled by `1 / keep_prob` (= 2 for a drop probability of 0.5).
    #[test]
    fn test_spatial_dropout_4d() {
        let sd = SpatialDropout::new(0.5).expect("unwrap failed");
        // All inputs are >= 1.0, so a kept plane can never contain a zero.
        let features = Array::from_shape_fn((2, 4, 3, 3), |(b, c, h, w)| {
            1.0 + b as f64 + c as f64 * 10.0 + h as f64 * 0.1 + w as f64 * 0.01
        });
        let masked = sd.apply(&features, true);
        for b in 0..2 {
            for c in 0..4 {
                let masked_batch = masked.index_axis(Axis(0), b);
                let channel = masked_batch.index_axis(Axis(0), c);
                let channel_clone = channel.to_owned();
                let is_dropped = channel_clone.iter().all(|&x| x.abs() < 1e-10);
                let is_kept = channel_clone.iter().all(|&x| x.abs() > 1e-10);
                if is_dropped {
                    for &val in channel_clone.iter() {
                        assert_eq!(val, 0.0);
                    }
                } else if is_kept {
                    let original_batch = features.index_axis(Axis(0), b);
                    let original_channel = original_batch.index_axis(Axis(0), c);
                    for ((i, j), &val) in channel_clone.indexed_iter() {
                        assert_relative_eq!(val, original_channel[[i, j]] * 2.0, epsilon = 1e-10);
                    }
                } else {
                    println!("Channel {c} in batch {b} has mixed values:");
                    for val in channel_clone.iter() {
                        println!(" Value: {val}");
                    }
                    panic!("Channel should be entirely dropped or kept");
                }
            }
        }
    }

    /// `new` stores an in-range probability and rejects out-of-range ones.
    #[test]
    fn test_feature_dropout_creation() {
        let fd = FeatureDropout::<f64>::new(0.4).expect("unwrap failed");
        assert_eq!(fd.dropprob, 0.4);
        assert!(FeatureDropout::<f64>::new(-0.1).is_err());
        assert!(FeatureDropout::<f64>::new(1.1).is_err());
    }

    /// A feature's keep/drop decision must be identical across the batch, and
    /// kept features must be scaled by `1 / keep_prob` (= 2 here).
    #[test]
    fn test_feature_dropout_3d() {
        let fd = FeatureDropout::new(0.5).expect("unwrap failed");
        let features = Array::from_shape_fn((2, 5, 10), |(_b, f, s)| f as f64 + s as f64);
        let masked = fd.apply(&features, true);
        for f in 0..5 {
            // The first batch entry serves as the reference decision for feature `f`.
            let first_batch = masked.index_axis(Axis(0), 0);
            let first_batch_feature = first_batch.index_axis(Axis(0), f);
            let first_batch_clone = first_batch_feature.to_owned();
            let is_dropped = first_batch_clone.iter().all(|&x| x == 0.0);
            for b in 0..2 {
                let batch = masked.index_axis(Axis(0), b);
                let feature_slice = batch.index_axis(Axis(0), f);
                let feature_clone = feature_slice.to_owned();
                let all_dropped = feature_clone.iter().all(|&x| x == 0.0);
                assert_eq!(
                    is_dropped, all_dropped,
                    "Feature dropout should be consistent"
                );
                if !all_dropped {
                    let original_batch = features.index_axis(Axis(0), b);
                    let original_slice = original_batch.index_axis(Axis(0), f);
                    for (i, &val) in feature_clone.iter().enumerate() {
                        assert_relative_eq!(val, original_slice[i] * 2.0, epsilon = 1e-10);
                    }
                }
            }
        }
    }

    /// With `training == false` both regularizers return the input unchanged.
    #[test]
    fn test_inference_mode() {
        let sd = SpatialDropout::new(0.5).expect("unwrap failed");
        let fd = FeatureDropout::new(0.5).expect("unwrap failed");
        let features = array![[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]];
        let sd_inference = sd.apply(&features, false);
        let fd_inference = fd.apply(&features, false);
        assert_eq!(features, sd_inference);
        assert_eq!(features, fd_inference);
    }

    /// The `Regularizer` impl reports a zero penalty and masks the gradients
    /// slice-by-slice along the default feature axis (`Axis(1)`).
    #[test]
    fn test_regularizer_trait() {
        let sd = SpatialDropout::<f64>::new(0.3).expect("unwrap failed");
        let params = array![[[1.0, 2.0], [3.0, 4.0]]];
        let original = array![[[0.1, 0.2], [0.3, 0.4]]];
        let mut gradient = original.clone();

        let penalty = sd.penalty(&params).expect("unwrap failed");
        assert_eq!(penalty, 0.0);

        // The inherent `apply` (features + training flag) must stay callable
        // alongside the trait method of the same name.
        let _penalty_apply = sd.apply(&params, true);

        let penalty_reg =
            <SpatialDropout<f64> as Regularizer<f64, Ix3>>::apply(&sd, &params, &mut gradient)
                .expect("unwrap failed");
        assert_eq!(penalty_reg, 0.0);

        // The mask is random, so assert the invariant rather than a fixed
        // outcome: each axis-1 slice is either all zero or the original slice
        // divided by the keep probability. All original values are non-zero,
        // so an all-zero slice can only come from being dropped.
        let keep_prob = 1.0 - 0.3;
        for (masked, source) in gradient
            .axis_iter(Axis(1))
            .zip(original.axis_iter(Axis(1)))
        {
            if masked.iter().all(|&g| g == 0.0) {
                continue;
            }
            for (&g, &o) in masked.iter().zip(source.iter()) {
                assert_relative_eq!(g, o / keep_prob, epsilon = 1e-12);
            }
        }
    }
}