// optirs_core/regularizers/shakedrop.rs

use scirs2_core::ndarray::{Array, ArrayBase, Data, Dimension, ScalarOperand};
use scirs2_core::numeric::{Float, FromPrimitive};
use scirs2_core::random::Rng;
use std::fmt::Debug;

use crate::error::{OptimError, Result};
use crate::regularizers::Regularizer;

/// ShakeDrop regularization
///
/// ShakeDrop is a regularization method that extends Stochastic Depth and
/// is often used in very deep neural networks. It randomly scales activations
/// during training.
///
/// # Parameters
///
/// * `p` - Probability that the gate is active during the forward pass,
///   a value between 0 and 1.
/// * `alpha_range` - Range for the alpha parameter, which scales the forward
///   pass when the gate is inactive (default: [-1.0, 1.0]).
/// * `beta_range` - Range for the beta parameter, which scales gradients in
///   the backward pass (default: [0.0, 1.0]).
///
/// # References
///
/// * Yamada, Y., Iwamura, M., & Kise, K. (2018). ShakeDrop regularization.
///   arXiv preprint arXiv:1802.02375.
///
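/// # Examples
///
/// A minimal construction sketch; the public path `optirs_core::regularizers`
/// is assumed from this file's location and may differ in the actual crate:
///
/// ```ignore
/// use optirs_core::regularizers::ShakeDrop;
///
/// // Default ranges: alpha in [-1.0, 1.0], beta in [0.0, 1.0].
/// let sd = ShakeDrop::new(0.5f64);
///
/// // Custom ranges for the alpha and beta parameters.
/// let sd = ShakeDrop::new_with_ranges(0.5f64, (-0.5, 0.5), (0.2, 0.8));
/// ```
///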
#[derive(Debug, Clone)]
pub struct ShakeDrop<A: Float + FromPrimitive + Debug> {
    /// Probability of applying ShakeDrop
    pub p: A,
    /// Range for the alpha parameter
    pub alpha_range: (A, A),
    /// Range for the beta parameter
    pub beta_range: (A, A),
    /// Random number generator (deterministically seeded for reproducibility)
    rng: scirs2_core::random::Random<scirs2_core::random::rngs::StdRng>,
}

impl<A: Float + FromPrimitive + Debug + Send + Sync> ShakeDrop<A> {
    /// Create a new ShakeDrop regularizer with default ranges
    ///
    /// Uses `alpha_range = (-1.0, 1.0)` and `beta_range = (0.0, 1.0)`.
    ///
    /// # Arguments
    ///
    /// * `p` - Probability of applying ShakeDrop, between 0 and 1
    ///
    /// # Returns
    ///
    /// A ShakeDrop regularizer
    pub fn new(p: A) -> Self {
        let zero = A::zero();
        let one = A::one();
        let neg_one = -one;

        Self {
            p,
            alpha_range: (neg_one, one),
            beta_range: (zero, one),
            rng: scirs2_core::random::Random::seed(42),
        }
    }

    /// Create a new ShakeDrop regularizer with custom ranges
    ///
    /// # Arguments
    ///
    /// * `p` - Probability of applying ShakeDrop, between 0 and 1
    /// * `alpha_range` - Range for the alpha parameter
    /// * `beta_range` - Range for the beta parameter
    ///
    /// # Returns
    ///
    /// A ShakeDrop regularizer
    pub fn new_with_ranges(p: A, alpha_range: (A, A), beta_range: (A, A)) -> Self {
        Self {
            p,
            alpha_range,
            beta_range,
            rng: scirs2_core::random::Random::seed(42),
        }
    }

    /// Get a random value within the given range
    fn random_in_range(&mut self, range: (A, A)) -> A {
        let (min, max) = range;
        let min_f = min.to_f64().unwrap();
        let max_f = max.to_f64().unwrap();

        // Handle (near-)equal min and max to avoid an empty-range panic
        if (max_f - min_f).abs() < 1e-10 {
            return min;
        }

        let random_val = self.rng.gen_range(min_f..max_f);
        A::from_f64(random_val).unwrap()
    }

    /// Get a forward pass gate for the ShakeDrop
    ///
    /// # Returns
    ///
    /// A tuple (b, alpha, beta):
    /// - b: Binary gate (1 or 0), active with probability p
    /// - alpha: Random value within alpha_range; it scales the forward pass
    ///   only when the gate is inactive (b = 0)
    /// - beta: Random value within beta_range, used in the backward pass
    fn get_gate(&mut self) -> (A, A, A) {
        let zero = A::zero();
        let one = A::one();

        // Determine if the gate is active
        let u: f64 = self.rng.gen_range(0.0..1.0);
        let b = if u < self.p.to_f64().unwrap() {
            one
        } else {
            zero
        };

        // Sample alpha independently of the gate, as in the ShakeDrop paper;
        // the forward factor b + alpha - b*alpha only depends on it when b is 0
        let alpha = self.random_in_range(self.alpha_range);

        // Get random beta regardless of gate
        let beta = self.random_in_range(self.beta_range);

        (b, alpha, beta)
    }

    /// Apply ShakeDrop to input activations
    ///
    /// # Arguments
    ///
    /// * `x` - Input activation tensor
    ///
    /// # Returns
    ///
    /// The transformed activations and the gate parameters needed by the backward pass
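    ///
    /// # Examples
    ///
    /// A sketch of a forward/backward round trip (the crate path
    /// `optirs_core` is an assumption from this file's location):
    ///
    /// ```ignore
    /// use optirs_core::regularizers::ShakeDrop;
    /// use scirs2_core::ndarray::Array1;
    ///
    /// let mut sd = ShakeDrop::new(0.9f64);
    /// let x = Array1::from_vec(vec![1.0f64, 2.0, 3.0]);
    ///
    /// // Forward: scales x by (b + alpha - b*alpha) and returns the gate.
    /// let (y, gate) = sd.forward(&x);
    ///
    /// // Backward: the same gate scales gradients by (b + beta - b*beta).
    /// let grad = Array1::from_vec(vec![0.1f64, 0.1, 0.1]);
    /// let grad_x = sd.backward(&grad, gate);
    /// ```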
    pub fn forward<S, D>(&mut self, x: &ArrayBase<S, D>) -> (Array<A, D>, (A, A, A))
    where
        S: Data<Elem = A>,
        D: Dimension,
    {
        // Get the gate values
        let (b, alpha, beta) = self.get_gate();

        // Apply the ShakeDrop transformation: x' = x * (b + alpha - b*alpha),
        // i.e. x when the gate is active (b = 1) and alpha * x when it is
        // dropped (b = 0)
        let factor = b + alpha - b * alpha;
        let result = x.mapv(|v| v * factor);

        (result, (b, alpha, beta))
    }

    /// Backward pass for ShakeDrop
    ///
    /// # Arguments
    ///
    /// * `grad_output` - Gradient from the next layer
    /// * `gate_params` - The gate parameters (b, alpha, beta) from the forward pass
    ///
    /// # Returns
    ///
    /// The modified gradients
    pub fn backward<S, D>(
        &self,
        grad_output: &ArrayBase<S, D>,
        gate_params: (A, A, A),
    ) -> Array<A, D>
    where
        S: Data<Elem = A>,
        D: Dimension,
    {
        // alpha is only used in the forward pass
        let (b, _alpha, beta) = gate_params;

        // During the backward pass: grad_x = grad_output * (b + beta - b*beta)
        let factor = b + beta - b * beta;
        grad_output.mapv(|g| g * factor)
    }
}

impl<A: Float + FromPrimitive + Debug + ScalarOperand, D: Dimension + Send + Sync> Regularizer<A, D>
    for ShakeDrop<A>
{
    fn apply(&self, _params: &Array<A, D>, _gradients: &mut Array<A, D>) -> Result<A> {
        // ShakeDrop is applied to activations, not parameters, so apply()
        // isn't the intended usage pattern: call forward() during the
        // forward pass and backward() during the backward pass instead
        Err(OptimError::InvalidConfig(
            "ShakeDrop should be applied to activations during forward/backward passes, \
             not through the Regularizer trait's apply method"
                .to_string(),
        ))
    }

    fn penalty(&self, _params: &Array<A, D>) -> Result<A> {
        // ShakeDrop doesn't add a penalty term to the loss function
        Ok(A::zero())
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::{Array1, Array2};

    #[test]
    fn test_shakedrop_new() {
        let sd = ShakeDrop::new(0.5f64);
        assert_eq!(sd.p, 0.5);
        assert_eq!(sd.alpha_range, (-1.0, 1.0));
        assert_eq!(sd.beta_range, (0.0, 1.0));
    }

    #[test]
    fn test_shakedrop_new_with_ranges() {
        let sd = ShakeDrop::new_with_ranges(0.7f64, (-0.5, 0.5), (0.2, 0.8));
        assert_eq!(sd.p, 0.7);
        assert_eq!(sd.alpha_range, (-0.5, 0.5));
        assert_eq!(sd.beta_range, (0.2, 0.8));
    }

    #[test]
    fn test_shakedrop_forward_backward() {
        // Create a simple 2D array
        let x = Array2::from_elem((2, 3), 1.0f64);

        // Initialize ShakeDrop with p=1.0 to ensure the gate is always active.
        // Use slightly different values for min and max to avoid an empty range
        let mut sd = ShakeDrop::new_with_ranges(1.0f64, (0.5, 0.500001), (0.5, 0.500001));

        // Forward pass
        let (output, gate_params) = sd.forward(&x);

        // Verify the gate parameters
        assert_eq!(gate_params.0, 1.0); // b should be 1 since p=1.0
        assert_abs_diff_eq!(gate_params.1, 0.5, epsilon = 1e-5); // alpha should be approximately 0.5
        assert_abs_diff_eq!(gate_params.2, 0.5, epsilon = 1e-5); // beta should be approximately 0.5

        // The expected output is x * (b + alpha - b*alpha) = x * (1 + 0.5 - 1*0.5) = x * 1
        for &val in output.iter() {
            assert_abs_diff_eq!(val, 1.0, epsilon = 1e-5);
        }

        // Backward pass
        let grad_output = Array2::from_elem((2, 3), 2.0f64);
        let grad_input = sd.backward(&grad_output, gate_params);

        // The expected gradient is grad_output * (b + beta - b*beta) = grad_output * (1 + 0.5 - 1*0.5) = grad_output * 1
        for &val in grad_input.iter() {
            assert_abs_diff_eq!(val, 2.0, epsilon = 1e-5);
        }
    }

    #[test]
    fn test_shakedrop_forward_inactive() {
        // Create a simple 1D array
        let x = Array1::from_vec(vec![1.0f64, 2.0, 3.0]);

        // Initialize ShakeDrop with p=0.0 to ensure the gate is always inactive.
        // Use slightly different values for min and max to avoid an empty range
        let mut sd = ShakeDrop::new_with_ranges(0.0f64, (-0.5, -0.499999), (0.5, 0.500001));

        // Forward pass - gate should be inactive
        let (output, gate_params) = sd.forward(&x);

        // Verify the gate parameters
        assert_eq!(gate_params.0, 0.0); // b should be 0 since p=0.0
        assert_abs_diff_eq!(gate_params.1, -0.5, epsilon = 1e-5); // alpha should be approximately -0.5
        assert_abs_diff_eq!(gate_params.2, 0.5, epsilon = 1e-5); // beta should be approximately 0.5

        // With the gate inactive, the output is x * (0 + alpha - 0*alpha) = alpha * x
        for (&out, &inp) in output.iter().zip(x.iter()) {
            assert_abs_diff_eq!(out, -0.5 * inp, epsilon = 1e-5);
        }
    }

    #[test]
    fn test_shakedrop_random_in_range() {
        let mut sd = ShakeDrop::new(0.5f64);

        // Test random value generation within range
        for _ in 0..100 {
            let value = sd.random_in_range((-0.5, 0.5));
            assert!((-0.5..=0.5).contains(&value));
        }

        // Test with a degenerate range (should not panic)
        let value = sd.random_in_range((0.5, 0.5));
        assert_eq!(value, 0.5);
    }

    #[test]
    fn test_shakedrop_get_gate() {
        // Test with p=1.0 - gate should always be active
        let mut sd = ShakeDrop::new(1.0f64);
        for _ in 0..10 {
            let (b, alpha, beta) = sd.get_gate();
            assert_eq!(b, 1.0);
            assert!((-1.0..=1.0).contains(&alpha));
            assert!((0.0..=1.0).contains(&beta));
        }

        // Test with p=0.0 - gate should always be inactive
        let mut sd = ShakeDrop::new(0.0f64);
        for _ in 0..10 {
            let (b, alpha, beta) = sd.get_gate();
            assert_eq!(b, 0.0);
            assert!((-1.0..=1.0).contains(&alpha)); // alpha is still drawn from alpha_range
            assert!((0.0..=1.0).contains(&beta));
        }
    }

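    // New test, added on the assumption that `Random::seed` yields a
    // deterministic stream: two identically seeded instances should make
    // identical gating decisions, which is what makes runs reproducible.
    #[test]
    fn test_shakedrop_deterministic_seed() {
        let x = Array2::from_elem((2, 3), 1.0f64);

        let mut sd_a = ShakeDrop::new(0.5f64);
        let mut sd_b = ShakeDrop::new(0.5f64);

        for _ in 0..10 {
            let (out_a, gate_a) = sd_a.forward(&x);
            let (out_b, gate_b) = sd_b.forward(&x);
            assert_eq!(gate_a, gate_b);
            assert_eq!(out_a, out_b);
        }
    }
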
    #[test]
    fn test_regularizer_trait() {
        let sd = ShakeDrop::new(0.5f64);
        let params = Array2::from_elem((2, 3), 1.0f64);
        let mut grads = Array2::from_elem((2, 3), 1.0f64);

        // apply() should return an error for ShakeDrop
        assert!(sd.apply(&params, &mut grads).is_err());

        // penalty() should return zero
        let penalty = sd.penalty(&params).unwrap();
        assert_eq!(penalty, 0.0);
    }
}