optirs_core/regularizers/dropconnect.rs

// DropConnect regularization
//
// DropConnect is a regularization technique that randomly drops connections between layers
// during training. Unlike Dropout, which drops entire units (activations), DropConnect drops
// individual weights.

use scirs2_core::ndarray::{Array, Dimension, ScalarOperand};
use scirs2_core::numeric::Float;
use scirs2_core::random::{thread_rng, Rng};
use std::fmt::Debug;

use crate::error::{OptimError, Result};
use crate::regularizers::Regularizer;

/// DropConnect regularizer
///
/// Randomly drops connections (weights) during training to prevent overfitting.
///
/// # Example
///
/// ```
/// use scirs2_core::ndarray::array;
/// use optirs_core::regularizers::DropConnect;
///
/// let dropconnect = DropConnect::new(0.5).unwrap(); // 50% connection dropout
/// let weights = array![[1.0, 2.0], [3.0, 4.0]];
///
/// // During training: some connections are zeroed out at random
/// let masked_weights = dropconnect.apply_to_weights(&weights, true);
///
/// // During inference: no dropout is applied and the weights are returned unchanged
/// let inference_weights = dropconnect.apply_to_weights(&weights, false);
/// ```
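///
/// This implementation uses "inverted" scaling: during training, kept weights are divided
/// by `1 - drop_prob`, so the expected value of each masked weight matches the original
/// weight and no extra rescaling is needed at inference time.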
#[derive(Debug, Clone)]
pub struct DropConnect<A: Float> {
    /// Probability of dropping a connection
    drop_prob: A,
}

impl<A: Float + Debug + ScalarOperand + Send + Sync> DropConnect<A> {
    /// Create a new DropConnect regularizer
    ///
    /// # Arguments
    ///
    /// * `drop_prob` - Probability of dropping each connection (0.0 to 1.0)
    ///
    /// # Returns
    ///
    /// A new DropConnect instance, or an error if the probability is invalid
    pub fn new(drop_prob: A) -> Result<Self> {
        if drop_prob < A::zero() || drop_prob > A::one() {
            return Err(OptimError::InvalidConfig(
                "Drop probability must be between 0.0 and 1.0".to_string(),
            ));
        }

        Ok(Self { drop_prob })
    }

    /// Apply DropConnect to weights
    ///
    /// # Arguments
    ///
    /// * `weights` - The weight matrix to apply DropConnect to
    /// * `training` - Whether we're in training mode (applies dropout) or inference mode
    pub fn apply_to_weights<D: Dimension>(
        &self,
        weights: &Array<A, D>,
        training: bool,
    ) -> Array<A, D> {
        if !training || self.drop_prob == A::zero() {
            // During inference, or when nothing is dropped, return the weights unchanged
            return weights.clone();
        }

        // Keep probability used for sampling and for inverted scaling
        let keep_prob = A::one() - self.drop_prob;
        let keep_prob_f64 = keep_prob.to_f64().unwrap();

        // Sample a Bernoulli keep-mask with the same shape as the weights
        let mut rng = thread_rng();
        let mask = Array::from_shape_fn(weights.raw_dim(), |_| rng.random_bool(keep_prob_f64));

        // Zero dropped weights and scale the kept ones by 1 / keep_prob
        let mut result = weights.clone();
        for (r, &m) in result.iter_mut().zip(mask.iter()) {
            if !m {
                *r = A::zero();
            } else {
                // Scale the kept weights to preserve the expected value
                *r = *r / keep_prob;
            }
        }

        result
    }

    /// Apply DropConnect during gradient computation
    ///
    /// Intended to be called during backpropagation so that dropped connections do not
    /// receive gradient updates. Note that this method samples a fresh mask; it does not
    /// reuse the mask that was applied to the weights in the forward pass.
    pub fn apply_to_gradients<D: Dimension>(
        &self,
        gradients: &Array<A, D>,
        weights_shape: D,
        training: bool,
    ) -> Array<A, D> {
        if !training || self.drop_prob == A::zero() {
            return gradients.clone();
        }

        // Keep probability used for sampling and for inverted scaling
        let keep_prob = A::one() - self.drop_prob;
        let keep_prob_f64 = keep_prob.to_f64().unwrap();

        // Sample a mask with the same shape as the weights
        let mut rng = thread_rng();
        let mask = Array::from_shape_fn(weights_shape, |_| rng.random_bool(keep_prob_f64));

        // Zero dropped gradients and scale the kept ones by 1 / keep_prob
        let mut result = gradients.clone();
        for (g, &m) in result.iter_mut().zip(mask.iter()) {
            if !m {
                *g = A::zero();
            } else {
                // Scale the kept gradients to preserve the expected value
                *g = *g / keep_prob;
            }
        }

        result
    }
}
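
// Note: `apply_to_weights` and `apply_to_gradients` each sample an independent mask, so the
// connections dropped in the forward pass are not necessarily the ones dropped during
// backpropagation. A minimal sketch of how a caller could share a single mask between both
// passes follows; `sample_mask` and `apply_mask` are illustrative helpers, not part of the
// original API.
impl<A: Float + Debug + ScalarOperand + Send + Sync> DropConnect<A> {
    /// Sample a boolean keep-mask with the given shape (illustrative helper).
    pub fn sample_mask<D: Dimension>(&self, shape: D) -> Array<bool, D> {
        let keep_prob_f64 = (A::one() - self.drop_prob).to_f64().unwrap();
        let mut rng = thread_rng();
        Array::from_shape_fn(shape, |_| rng.random_bool(keep_prob_f64))
    }

    /// Apply a previously sampled keep-mask with inverted scaling (illustrative helper).
    ///
    /// Applying the same mask to both the weights (forward pass) and the gradients
    /// (backward pass) keeps the two consistent.
    pub fn apply_mask<D: Dimension>(
        &self,
        values: &Array<A, D>,
        mask: &Array<bool, D>,
    ) -> Array<A, D> {
        let keep_prob = A::one() - self.drop_prob;
        let mut result = values.clone();
        for (v, &m) in result.iter_mut().zip(mask.iter()) {
            if !m {
                *v = A::zero();
            } else {
                *v = *v / keep_prob;
            }
        }
        result
    }
}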

impl<A: Float + Debug + ScalarOperand + Send + Sync, D: Dimension + Send + Sync> Regularizer<A, D>
    for DropConnect<A>
{
    fn apply(&self, params: &Array<A, D>, gradients: &mut Array<A, D>) -> Result<A> {
        // Apply a DropConnect mask to the gradients. This path assumes it is only called
        // during training, so `training` is hardcoded to true.
        let masked_gradients = self.apply_to_gradients(gradients, params.raw_dim(), true);

        // Update gradients in place
        gradients.assign(&masked_gradients);

        // DropConnect doesn't add a penalty term
        Ok(A::zero())
    }

    fn penalty(&self, _params: &Array<A, D>) -> Result<A> {
        // DropConnect doesn't add a penalty term to the loss
        Ok(A::zero())
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_dropconnect_creation() {
        // Valid creation
        let dc = DropConnect::<f64>::new(0.5).unwrap();
        assert_eq!(dc.drop_prob, 0.5);

        // Invalid probabilities
        assert!(DropConnect::<f64>::new(-0.1).is_err());
        assert!(DropConnect::<f64>::new(1.1).is_err());
    }

    #[test]
    fn test_dropconnect_training_mode() {
        let dc = DropConnect::new(0.5).unwrap();
        let weights = array![[1.0, 2.0], [3.0, 4.0]];

        // During training, some connections should be dropped
        let masked_weights = dc.apply_to_weights(&weights, true);

        // Count the dropped entries (not asserted: with only four weights the count is too variable)
        let _zeros = masked_weights.iter().filter(|&&x| x == 0.0).count();

        // Kept weights should be scaled by 1 / keep_prob = 2.0
        for (&original, &masked) in weights.iter().zip(masked_weights.iter()) {
            if masked != 0.0 {
                assert_relative_eq!(masked, original * 2.0, epsilon = 1e-10);
            }
        }
    }

    #[test]
    fn test_dropconnect_inference_mode() {
        let dc = DropConnect::new(0.5).unwrap();
        let weights = array![[1.0, 2.0], [3.0, 4.0]];

        // During inference, weights should remain unchanged
        let inference_weights = dc.apply_to_weights(&weights, false);
        assert_eq!(weights, inference_weights);
    }

    #[test]
    fn test_dropconnect_zero_probability() {
        let dc = DropConnect::new(0.0).unwrap();
        let weights = array![[1.0, 2.0], [3.0, 4.0]];

        // With 0% dropout, weights should remain unchanged
        let result = dc.apply_to_weights(&weights, true);
        assert_eq!(weights, result);
    }

    #[test]
    fn test_dropconnect_gradients() {
        let dc = DropConnect::new(0.5).unwrap();
        let gradients = array![[1.0, 1.0], [1.0, 1.0]];
        let weights_shape = gradients.raw_dim();

        // Apply to gradients
        let masked_grads = dc.apply_to_gradients(&gradients, weights_shape, true);

        // Kept gradients should be scaled by 1 / keep_prob = 2.0
        for &grad in masked_grads.iter() {
            if grad != 0.0 {
                assert_relative_eq!(grad, 2.0, epsilon = 1e-10);
            }
        }
    }

    #[test]
    fn test_regularizer_trait() {
        let dc = DropConnect::new(0.3).unwrap();
        let params = array![[1.0, 2.0], [3.0, 4.0]];
        let mut gradient = array![[0.1, 0.2], [0.3, 0.4]];
        let original_gradient = gradient.clone();

        // Test Regularizer trait methods
        let penalty = dc.penalty(&params).unwrap();
        assert_eq!(penalty, 0.0); // DropConnect has no penalty term

        // Test gradient computation
        let penalty_from_apply = dc.apply(&params, &mut gradient).unwrap();
        assert_eq!(penalty_from_apply, 0.0);

        // Each gradient is either dropped (zeroed) or scaled by 1 / keep_prob = 1 / 0.7
        for (&original, &updated) in original_gradient.iter().zip(gradient.iter()) {
            assert!(updated == 0.0 || (updated - original / 0.7).abs() < 1e-10);
        }
    }
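
    // Statistical sanity check: with inverted scaling the expected value of each masked
    // weight equals the original weight, so the mean over a large matrix of ones should
    // stay close to 1.0. The matrix size and tolerance here are illustrative choices made
    // to keep this randomized test stable.
    #[test]
    fn test_dropconnect_preserves_expected_value() {
        use scirs2_core::ndarray::Array2;

        let dc = DropConnect::new(0.5).unwrap();
        let weights = Array2::<f64>::from_elem((100, 100), 1.0);

        let masked = dc.apply_to_weights(&weights, true);
        let mean = masked.iter().sum::<f64>() / masked.len() as f64;

        // E[masked] = keep_prob * (1 / keep_prob) * 1.0 = 1.0
        assert!((mean - 1.0).abs() < 0.1);
    }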
}