use scirs2_core::ndarray::{Array, ArrayBase, Data, Dimension, ScalarOperand};
use scirs2_core::numeric::{Float, FromPrimitive};
use scirs2_core::random::Rng;
use std::cell::RefCell;
use std::fmt::Debug;
use crate::error::{OptimError, Result};
use crate::regularizers::Regularizer;
/// Return value of `ShakeDrop::forward`: the scaled array together with the
/// sampled gate parameters `(b, alpha, beta)` that `ShakeDrop::backward` needs.
type ShakeDropResult<A, D> = (Array<A, D>, (A, A, A));
/// ShakeDrop regularizer that acts on activations (via `forward`/`backward`)
/// rather than on parameters.
///
/// Gate sampling draws from an internal RNG wrapped in a `RefCell`, so it can
/// happen through `&self`; as a consequence this type is not `Sync`.
#[derive(Debug)]
pub struct ShakeDrop<A: Float + FromPrimitive + Debug> {
/// Probability that the sampled gate `b` equals 1.
pub p: A,
/// `(min, max)` range from which the forward scaling factor `alpha` is sampled.
pub alpha_range: (A, A),
/// `(min, max)` range from which the backward scaling factor `beta` is sampled.
pub beta_range: (A, A),
/// Seeded RNG used for all gate sampling; interior mutability lets the
/// sampling methods take `&self`.
rng: RefCell<scirs2_core::random::Random<scirs2_core::random::rngs::StdRng>>,
}
impl<A: Float + FromPrimitive + Debug + Send + Sync> ShakeDrop<A> {
pub fn new(p: A) -> Self {
let zero = A::zero();
let one = A::one();
let neg_one = zero - one;
Self {
p,
alpha_range: (neg_one, one),
beta_range: (zero, one),
rng: RefCell::new(scirs2_core::random::Random::seed(42)),
}
}
pub fn new_with_ranges(p: A, alpharange: (A, A), beta_range: (A, A)) -> Self {
Self {
p,
alpha_range: alpharange,
beta_range,
rng: RefCell::new(scirs2_core::random::Random::seed(42)),
}
}
fn random_in_range(&self, range: (A, A)) -> Result<A> {
let (min, max) = range;
let min_f = min
.to_f64()
.ok_or_else(|| OptimError::InvalidConfig("Failed to convert min to f64".to_string()))?;
let max_f = max
.to_f64()
.ok_or_else(|| OptimError::InvalidConfig("Failed to convert max to f64".to_string()))?;
if (max_f - min_f).abs() < 1e-10 {
return Ok(min);
}
let random_val = self.rng.borrow_mut().gen_range(min_f..max_f);
A::from_f64(random_val).ok_or_else(|| {
OptimError::InvalidConfig("Failed to convert random value from f64".to_string())
})
}
fn get_gate(&self) -> Result<(A, A, A)> {
let zero = A::zero();
let one = A::one();
let u: f64 = self.rng.borrow_mut().gen_range(0.0..1.0);
let p_f64 = self
.p
.to_f64()
.ok_or_else(|| OptimError::InvalidConfig("Failed to convert p to f64".to_string()))?;
let b = if u < p_f64 { one } else { zero };
let alpha = if b > zero {
self.random_in_range(self.alpha_range)?
} else {
zero
};
let beta = self.random_in_range(self.beta_range)?;
Ok((b, alpha, beta))
}
pub fn forward<S, D>(&self, x: &ArrayBase<S, D>) -> Result<ShakeDropResult<A, D>>
where
S: Data<Elem = A>,
D: Dimension,
{
let (b, alpha, beta) = self.get_gate()?;
let factor = b + alpha - b * alpha;
let result = x.mapv(|v| v * factor);
Ok((result, (b, alpha, beta)))
}
pub fn backward<S, D>(
&self,
grad_output: &ArrayBase<S, D>,
gate_params: (A, A, A),
) -> Array<A, D>
where
S: Data<Elem = A>,
D: Dimension,
{
let (b, alpha, beta) = gate_params;
let factor = b + beta - b * beta;
grad_output.mapv(|g| g * factor)
}
}
impl<A: Float + FromPrimitive + Debug + ScalarOperand, D: Dimension + Send + Sync> Regularizer<A, D>
for ShakeDrop<A>
{
fn apply(&self, _params: &Array<A, D>, gradients: &mut Array<A, D>) -> Result<A> {
Err(OptimError::InvalidConfig(
"ShakeDrop should be applied to activations during forward/backward passes, \
not through the Regularizer trait's apply method"
.to_string(),
))
}
fn penalty(&self, params: &Array<A, D>) -> Result<A> {
Ok(A::zero())
}
}
#[cfg(test)]
mod tests {
use super::*;
use approx::assert_abs_diff_eq;
use scirs2_core::ndarray::{Array1, Array2};
#[test]
fn test_shakedrop_new() {
let sd = ShakeDrop::new(0.5f64);
assert_eq!(sd.p, 0.5);
assert_eq!(sd.alpha_range, (-1.0, 1.0));
assert_eq!(sd.beta_range, (0.0, 1.0));
}
#[test]
fn test_shakedrop_new_with_ranges() {
let sd = ShakeDrop::new_with_ranges(0.7f64, (-0.5, 0.5), (0.2, 0.8));
assert_eq!(sd.p, 0.7);
assert_eq!(sd.alpha_range, (-0.5, 0.5));
assert_eq!(sd.beta_range, (0.2, 0.8));
}
#[test]
fn test_shakedrop_forward_backward() {
let x = Array2::from_elem((2, 3), 1.0f64);
let sd = ShakeDrop::new_with_ranges(1.0f64, (0.5, 0.500001), (0.5, 0.500001));
let (output, gate_params) = sd.forward(&x).expect("forward failed");
assert_eq!(gate_params.0, 1.0); assert_abs_diff_eq!(gate_params.1, 0.5, epsilon = 1e-5); assert_abs_diff_eq!(gate_params.2, 0.5, epsilon = 1e-5);
for &val in output.iter() {
assert_abs_diff_eq!(val, 1.0, epsilon = 1e-5);
}
let grad_output = Array2::from_elem((2, 3), 2.0f64);
let grad_input = sd.backward(&grad_output, gate_params);
for &val in grad_input.iter() {
assert_abs_diff_eq!(val, 2.0, epsilon = 1e-5);
}
}
#[test]
fn test_shakedrop_forward_inactive() {
let x = Array1::from_vec(vec![1.0f64, 2.0, 3.0]);
let sd = ShakeDrop::new_with_ranges(0.0f64, (-0.5, -0.499999), (0.5, 0.500001));
let (output, gate_params) = sd.forward(&x).expect("forward failed");
assert_eq!(gate_params.0, 0.0); assert_eq!(gate_params.1, 0.0); assert_abs_diff_eq!(gate_params.2, 0.5, epsilon = 1e-5);
for &val in output.iter() {
assert_abs_diff_eq!(val, 0.0, epsilon = 1e-10);
}
}
#[test]
fn test_shakedrop_gen_range() {
let sd = ShakeDrop::new(0.5f64);
for _ in 0..100 {
let value = sd
.random_in_range((-0.5, 0.5))
.expect("random_in_range failed");
assert!((-0.5..=0.5).contains(&value));
}
let value = sd
.random_in_range((0.5, 0.5))
.expect("random_in_range failed");
assert_eq!(value, 0.5);
}
#[test]
fn test_shakedrop_get_gate() {
let sd = ShakeDrop::new(1.0f64);
for _ in 0..10 {
let (b, alpha, beta) = sd.get_gate().expect("get_gate failed");
assert_eq!(b, 1.0);
assert!((-1.0..=1.0).contains(&alpha));
assert!((0.0..=1.0).contains(&beta));
}
let sd = ShakeDrop::new(0.0f64);
for _ in 0..10 {
let (b, alpha, beta) = sd.get_gate().expect("get_gate failed");
assert_eq!(b, 0.0);
assert_eq!(alpha, 0.0);
assert!((0.0..=1.0).contains(&beta));
}
}
#[test]
fn test_regularizer_trait() {
let sd = ShakeDrop::new(0.5f64);
let params = Array2::from_elem((2, 3), 1.0f64);
let mut grads = Array2::from_elem((2, 3), 1.0f64);
assert!(sd.apply(¶ms, &mut grads).is_err());
let penalty = sd.penalty(¶ms).expect("penalty failed");
assert_eq!(penalty, 0.0);
}
}