use scirs2_core::ndarray::{Array, Dimension, ScalarOperand};
use scirs2_core::numeric::Float;
use scirs2_core::random::{thread_rng, Rng};
use std::fmt::Debug;
use crate::error::{OptimError, Result};
use crate::regularizers::Regularizer;
/// DropConnect regularizer.
///
/// In training mode it randomly zeroes individual elements of a weight or gradient
/// array and rescales the survivors; otherwise it returns the input unchanged.
#[derive(Clone, Debug)]
pub struct DropConnect<A: Float> {
    /// Probability that any single element is dropped; `DropConnect::new` only
    /// accepts values in `[0, 1]`.
    drop_prob: A,
}
impl<A: Float + Debug + ScalarOperand + Send + Sync> DropConnect<A> {
    /// Creates a DropConnect regularizer that drops each element with probability `dropprob`.
    ///
    /// # Errors
    ///
    /// Returns [`OptimError::InvalidConfig`] if `dropprob` is NaN or lies outside `[0, 1]`.
    pub fn new(dropprob: A) -> Result<Self> {
        // NaN must be rejected explicitly: every ordered comparison with NaN is false, so
        // the range checks alone would let it through and `random_bool` would panic later.
        if dropprob.is_nan() || dropprob < A::zero() || dropprob > A::one() {
            return Err(OptimError::InvalidConfig(
                "Drop probability must be between 0.0 and 1.0".to_string(),
            ));
        }
        Ok(Self {
            drop_prob: dropprob,
        })
    }

    /// Applies a fresh random DropConnect mask to `weights`.
    ///
    /// In training mode each element is independently zeroed with probability
    /// `drop_prob`, and the surviving elements are divided by `1 - drop_prob`. For
    /// `drop_prob < 1` this keeps every element's expected value equal to the original.
    /// When `training` is false or `drop_prob` is zero, an unchanged copy is returned.
    ///
    /// # Panics
    ///
    /// Panics if `1 - drop_prob` cannot be converted to `f64` (never the case for
    /// `f32`/`f64`).
    pub fn apply_to_weights<D: Dimension>(
        &self,
        weights: &Array<A, D>,
        training: bool,
    ) -> Array<A, D> {
        if !training || self.drop_prob == A::zero() {
            return weights.clone();
        }
        self.mask_and_scale(weights)
    }

    /// Applies a fresh random DropConnect mask to `gradients`.
    ///
    /// Masking and rescaling follow the same rules as [`DropConnect::apply_to_weights`].
    /// Each call draws its own mask, which is independent of any mask drawn by an earlier
    /// `apply_to_weights` call. `weightsshape` must describe the same number of elements
    /// as `gradients`; this is checked with a debug assertion. The mask is drawn per
    /// gradient element, so no gradient is ever left unmasked.
    ///
    /// # Panics
    ///
    /// Panics if `1 - drop_prob` cannot be converted to `f64`. In debug builds it also
    /// panics when `weightsshape` and `gradients` differ in element count.
    pub fn apply_to_gradients<D: Dimension>(
        &self,
        gradients: &Array<A, D>,
        weightsshape: D,
        training: bool,
    ) -> Array<A, D> {
        if !training || self.drop_prob == A::zero() {
            return gradients.clone();
        }
        debug_assert_eq!(
            weightsshape.size(),
            gradients.len(),
            "weight shape and gradient array must contain the same number of elements"
        );
        self.mask_and_scale(gradients)
    }

    /// Zeroes each element with probability `drop_prob` and divides the survivors by
    /// `1 - drop_prob`, returning a new array of the same shape.
    fn mask_and_scale<D: Dimension>(&self, values: &Array<A, D>) -> Array<A, D> {
        let keep_prob = A::one() - self.drop_prob;
        let keep_prob_f64 = keep_prob
            .to_f64()
            .expect("keep probability must be convertible to f64");
        let mut rng = thread_rng();
        // One Bernoulli draw per element. If keep_prob is 0, every draw is false, so the
        // division below never sees a zero divisor.
        values.mapv(|v| {
            if rng.random_bool(keep_prob_f64) {
                v / keep_prob
            } else {
                A::zero()
            }
        })
    }
}
/// Exposes DropConnect through the generic [`Regularizer`] interface.
///
/// DropConnect adds nothing to the loss. Its entire effect is the masking of gradients.
impl<A: Float + Debug + ScalarOperand + Send + Sync, D: Dimension + Send + Sync> Regularizer<A, D>
    for DropConnect<A>
{
    /// Replaces `gradients` with a DropConnect-masked version (always in training mode)
    /// and returns a zero penalty.
    fn apply(&self, params: &Array<A, D>, gradients: &mut Array<A, D>) -> Result<A> {
        // Move the freshly masked array into place rather than copying it back element-wise.
        *gradients = self.apply_to_gradients(gradients, params.raw_dim(), true);
        Ok(A::zero())
    }

    /// DropConnect contributes no loss penalty, so this always returns zero.
    fn penalty(&self, _params: &Array<A, D>) -> Result<A> {
        Ok(A::zero())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_dropconnect_creation() {
        let dc = DropConnect::<f64>::new(0.5).expect("0.5 is a valid drop probability");
        assert_eq!(dc.drop_prob, 0.5);
        assert!(DropConnect::<f64>::new(-0.1).is_err());
        assert!(DropConnect::<f64>::new(1.1).is_err());
        // Both endpoints of the closed interval are accepted.
        assert!(DropConnect::<f64>::new(0.0).is_ok());
        assert!(DropConnect::<f64>::new(1.0).is_ok());
    }

    #[test]
    fn test_dropconnect_training_mode() {
        let dc = DropConnect::new(0.5).expect("0.5 is a valid drop probability");
        let weights = array![[1.0, 2.0], [3.0, 4.0]];
        let masked_weights = dc.apply_to_weights(&weights, true);
        assert_eq!(masked_weights.shape(), weights.shape());
        // Every element is either dropped or rescaled by 1 / keep_prob = 2.
        for (&original, &masked) in weights.iter().zip(masked_weights.iter()) {
            if masked != 0.0 {
                assert_relative_eq!(masked, original * 2.0, epsilon = 1e-10);
            }
        }
    }

    #[test]
    fn test_dropconnect_inference_mode() {
        let dc = DropConnect::new(0.5).expect("0.5 is a valid drop probability");
        let weights = array![[1.0, 2.0], [3.0, 4.0]];
        let inference_weights = dc.apply_to_weights(&weights, false);
        assert_eq!(weights, inference_weights);
    }

    #[test]
    fn test_dropconnect_zero_probability() {
        let dc = DropConnect::new(0.0).expect("0.0 is a valid drop probability");
        let weights = array![[1.0, 2.0], [3.0, 4.0]];
        let result = dc.apply_to_weights(&weights, true);
        assert_eq!(weights, result);
    }

    #[test]
    fn test_dropconnect_full_probability() {
        let dc = DropConnect::new(1.0).expect("1.0 is a valid drop probability");
        let weights = array![[1.0, 2.0], [3.0, 4.0]];
        let result = dc.apply_to_weights(&weights, true);
        assert!(result.iter().all(|&x| x == 0.0));
    }

    #[test]
    fn test_dropconnect_gradients() {
        let dc = DropConnect::new(0.5).expect("0.5 is a valid drop probability");
        let gradients = array![[1.0, 1.0], [1.0, 1.0]];
        let weightsshape = gradients.raw_dim();
        let masked_grads = dc.apply_to_gradients(&gradients, weightsshape, true);
        for &grad in masked_grads.iter() {
            if grad != 0.0 {
                assert_relative_eq!(grad, 2.0, epsilon = 1e-10);
            }
        }
    }

    #[test]
    fn test_regularizer_trait() {
        let dc = DropConnect::new(0.3).expect("0.3 is a valid drop probability");
        let params = array![[1.0, 2.0], [3.0, 4.0]];
        let original = array![[0.1, 0.2], [0.3, 0.4]];
        let mut gradient = original.clone();

        let penalty = dc.penalty(&params).expect("penalty never fails");
        assert_eq!(penalty, 0.0);

        let penalty_from_apply = dc.apply(&params, &mut gradient).expect("apply never fails");
        assert_eq!(penalty_from_apply, 0.0);
        assert_eq!(gradient.shape(), original.shape());

        // Each gradient is either dropped or rescaled by 1 / keep_prob.
        let keep_prob = 1.0 - 0.3;
        for (&before, &after) in original.iter().zip(gradient.iter()) {
            if after != 0.0 {
                assert_relative_eq!(after, before / keep_prob, epsilon = 1e-10);
            }
        }
    }
}