Skip to main content

optirs_core/regularizers/
shakedrop.rs

1use scirs2_core::ndarray::{Array, ArrayBase, Data, Dimension, ScalarOperand};
2use scirs2_core::numeric::{Float, FromPrimitive};
3use scirs2_core::random::Rng;
4use std::cell::RefCell;
5use std::fmt::Debug;
6
7use crate::error::{OptimError, Result};
8use crate::regularizers::Regularizer;
9
/// Result type for ShakeDrop's forward pass: the transformed activations paired
/// with the sampled gate parameters `(b, alpha, beta)`, which the caller must
/// keep and pass back into `backward` so the gradient uses the same gate.
type ShakeDropResult<A, D> = (Array<A, D>, (A, A, A));
12
/// ShakeDrop regularization
///
/// ShakeDrop is a regularization method that extends Stochastic Depth and
/// is often used in very deep neural networks. It randomly scales activations
/// during training.
///
/// # Parameters
///
/// * `p` - The probability of activating ShakeDrop (probability of activating the forward
///   pass transformation), value between 0 and 1.
/// * `alpha_range` - The range for the alpha parameter used in forward pass (default: [-1.0, 1.0]).
/// * `beta_range` - The range for the beta parameter used in backward pass (default: [0.0, 1.0]).
///
/// # References
///
/// * Yamada, Y., Iwamura, M., & Kise, K. (2018). ShakeDrop regularization.
///   arXiv preprint arXiv:1802.02375.
///
#[derive(Debug)]
pub struct ShakeDrop<A: Float + FromPrimitive + Debug> {
    /// Probability of applying ShakeDrop (gate activation probability, in [0, 1])
    pub p: A,
    /// Range `(min, max)` for the alpha parameter sampled in the forward pass
    pub alpha_range: (A, A),
    /// Range `(min, max)` for the beta parameter sampled for the backward pass
    pub beta_range: (A, A),
    /// Random number generator; wrapped in `RefCell` so sampling methods can
    /// take `&self` while still mutating RNG state (interior mutability).
    /// Seeded deterministically by the constructors — see `new`.
    rng: RefCell<scirs2_core::random::Random<scirs2_core::random::rngs::StdRng>>,
}
42
43impl<A: Float + FromPrimitive + Debug + Send + Sync> ShakeDrop<A> {
44    /// Create a new ShakeDrop regularizer
45    ///
46    /// # Arguments
47    ///
48    /// * `p` - Probability of applying ShakeDrop, between 0 and 1
49    /// * `alpha_range` - Range for the alpha parameter (default: [-1.0, 1.0])
50    /// * `beta_range` - Range for the beta parameter (default: [0.0, 1.0])
51    ///
52    /// # Returns
53    ///
54    /// A ShakeDrop regularizer
55    pub fn new(p: A) -> Self {
56        let zero = A::zero();
57        let one = A::one();
58        let neg_one = zero - one;
59
60        Self {
61            p,
62            alpha_range: (neg_one, one),
63            beta_range: (zero, one),
64            rng: RefCell::new(scirs2_core::random::Random::seed(42)),
65        }
66    }
67
68    /// Create a new ShakeDrop regularizer with custom ranges
69    ///
70    /// # Arguments
71    ///
72    /// * `p` - Probability of applying ShakeDrop, between 0 and 1
73    /// * `alpha_range` - Range for the alpha parameter
74    /// * `beta_range` - Range for the beta parameter
75    ///
76    /// # Returns
77    ///
78    /// A ShakeDrop regularizer
79    pub fn new_with_ranges(p: A, alpharange: (A, A), beta_range: (A, A)) -> Self {
80        Self {
81            p,
82            alpha_range: alpharange,
83            beta_range,
84            rng: RefCell::new(scirs2_core::random::Random::seed(42)),
85        }
86    }
87
88    /// Get a random value between the given range
89    fn random_in_range(&self, range: (A, A)) -> Result<A> {
90        let (min, max) = range;
91        let min_f = min
92            .to_f64()
93            .ok_or_else(|| OptimError::InvalidConfig("Failed to convert min to f64".to_string()))?;
94        let max_f = max
95            .to_f64()
96            .ok_or_else(|| OptimError::InvalidConfig("Failed to convert max to f64".to_string()))?;
97
98        // Handle equal min and max to avoid "empty range" error
99        if (max_f - min_f).abs() < 1e-10 {
100            return Ok(min);
101        }
102
103        let random_val = self.rng.borrow_mut().gen_range(min_f..max_f);
104        A::from_f64(random_val).ok_or_else(|| {
105            OptimError::InvalidConfig("Failed to convert random value from f64".to_string())
106        })
107    }
108
109    /// Get a forward pass gate for the ShakeDrop
110    ///
111    /// # Returns
112    ///
113    /// A tuple (b, alpha, beta):
114    /// - b: Binary gate (1 or 0) based on the probability p
115    /// - alpha: Random value within alpha_range if b is 1, otherwise 0
116    /// - beta: Random value within beta_range
117    fn get_gate(&self) -> Result<(A, A, A)> {
118        let zero = A::zero();
119        let one = A::one();
120
121        // Determine if the gate is active
122        let u: f64 = self.rng.borrow_mut().gen_range(0.0..1.0);
123        let p_f64 = self
124            .p
125            .to_f64()
126            .ok_or_else(|| OptimError::InvalidConfig("Failed to convert p to f64".to_string()))?;
127        let b = if u < p_f64 { one } else { zero };
128
129        // Get random alpha if gate is active..otherwise 0
130        let alpha = if b > zero {
131            self.random_in_range(self.alpha_range)?
132        } else {
133            zero
134        };
135
136        // Get random beta regardless of gate
137        let beta = self.random_in_range(self.beta_range)?;
138
139        Ok((b, alpha, beta))
140    }
141
142    /// Apply ShakeDrop to input activations
143    ///
144    /// # Arguments
145    ///
146    /// * `x` - Input activation tensor
147    ///
148    /// # Returns
149    ///
150    /// The transformed activations and gate parameters for use in backward pass
151    pub fn forward<S, D>(&self, x: &ArrayBase<S, D>) -> Result<ShakeDropResult<A, D>>
152    where
153        S: Data<Elem = A>,
154        D: Dimension,
155    {
156        // Get the gate values
157        let (b, alpha, beta) = self.get_gate()?;
158
159        // Apply ShakeDrop transformation
160        // During forward pass: x' = x * (b + alpha - b*alpha)
161        let factor = b + alpha - b * alpha;
162        let result = x.mapv(|v| v * factor);
163
164        Ok((result, (b, alpha, beta)))
165    }
166
167    /// Backward pass for ShakeDrop
168    ///
169    /// # Arguments
170    ///
171    /// * `grad_output` - Gradient from the next layer
172    /// * `gate_params` - The gate parameters (b, alpha, beta) from the forward pass
173    ///
174    /// # Returns
175    ///
176    /// The modified gradients
177    pub fn backward<S, D>(
178        &self,
179        grad_output: &ArrayBase<S, D>,
180        gate_params: (A, A, A),
181    ) -> Array<A, D>
182    where
183        S: Data<Elem = A>,
184        D: Dimension,
185    {
186        let (b, alpha, beta) = gate_params;
187
188        // During backward pass: grad_x = grad_output * (b + beta - b*beta)
189        let factor = b + beta - b * beta;
190        grad_output.mapv(|g| g * factor)
191    }
192}
193
194impl<A: Float + FromPrimitive + Debug + ScalarOperand, D: Dimension + Send + Sync> Regularizer<A, D>
195    for ShakeDrop<A>
196{
197    fn apply(&self, _params: &Array<A, D>, gradients: &mut Array<A, D>) -> Result<A> {
198        // ShakeDrop is typically applied to activations, not parameters
199        // In this implementation, apply() isn't the primary usage pattern
200        // Instead, users would call forward() during the forward pass
201        // and backward() during the backward pass
202        Err(OptimError::InvalidConfig(
203            "ShakeDrop should be applied to activations during forward/backward passes, \
204             not through the Regularizer trait's apply method"
205                .to_string(),
206        ))
207    }
208
209    fn penalty(&self, params: &Array<A, D>) -> Result<A> {
210        // ShakeDrop doesn't add a penalty term to the loss function
211        Ok(A::zero())
212    }
213}
214
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::{Array1, Array2};

    /// Default constructor sets the documented default ranges.
    #[test]
    fn test_shakedrop_new() {
        let regularizer = ShakeDrop::new(0.5f64);
        assert_eq!(regularizer.p, 0.5);
        assert_eq!(regularizer.alpha_range, (-1.0, 1.0));
        assert_eq!(regularizer.beta_range, (0.0, 1.0));
    }

    /// Custom ranges are stored exactly as provided.
    #[test]
    fn test_shakedrop_new_with_ranges() {
        let regularizer = ShakeDrop::new_with_ranges(0.7f64, (-0.5, 0.5), (0.2, 0.8));
        assert_eq!(regularizer.p, 0.7);
        assert_eq!(regularizer.alpha_range, (-0.5, 0.5));
        assert_eq!(regularizer.beta_range, (0.2, 0.8));
    }

    /// With p = 1.0 the gate is always on; near-degenerate ranges pin
    /// alpha and beta to ~0.5 so the expected factors are deterministic.
    #[test]
    fn test_shakedrop_forward_backward() {
        let activations = Array2::from_elem((2, 3), 1.0f64);
        let regularizer = ShakeDrop::new_with_ranges(1.0f64, (0.5, 0.500001), (0.5, 0.500001));

        let (output, gate) = regularizer.forward(&activations).expect("forward failed");
        let (b, alpha, beta) = gate;

        assert_eq!(b, 1.0); // gate must fire when p = 1.0
        assert_abs_diff_eq!(alpha, 0.5, epsilon = 1e-5);
        assert_abs_diff_eq!(beta, 0.5, epsilon = 1e-5);

        // Forward factor: b + alpha - b*alpha = 1 + 0.5 - 0.5 = 1 (identity).
        output
            .iter()
            .for_each(|&v| assert_abs_diff_eq!(v, 1.0, epsilon = 1e-5));

        // Backward factor: b + beta - b*beta = 1, so gradients pass through.
        let upstream = Array2::from_elem((2, 3), 2.0f64);
        let downstream = regularizer.backward(&upstream, gate);
        downstream
            .iter()
            .for_each(|&g| assert_abs_diff_eq!(g, 2.0, epsilon = 1e-5));
    }

    /// With p = 0.0 the gate never fires: alpha is forced to 0 and the
    /// forward output is zeroed out.
    #[test]
    fn test_shakedrop_forward_inactive() {
        let activations = Array1::from_vec(vec![1.0f64, 2.0, 3.0]);
        let regularizer = ShakeDrop::new_with_ranges(0.0f64, (-0.5, -0.499999), (0.5, 0.500001));

        let (output, gate) = regularizer.forward(&activations).expect("forward failed");
        let (b, alpha, beta) = gate;

        assert_eq!(b, 0.0); // gate must stay off when p = 0.0
        assert_eq!(alpha, 0.0); // alpha is zeroed for an inactive gate
        assert_abs_diff_eq!(beta, 0.5, epsilon = 1e-5);

        // Forward factor: 0 + 0 - 0*0 = 0, so every activation is dropped.
        output
            .iter()
            .for_each(|&v| assert_abs_diff_eq!(v, 0.0, epsilon = 1e-10));
    }

    /// Sampling stays inside the requested bounds, and a degenerate range
    /// returns its lower bound instead of panicking.
    #[test]
    fn test_shakedrop_gen_range() {
        let regularizer = ShakeDrop::new(0.5f64);

        for _ in 0..100 {
            let sample = regularizer
                .random_in_range((-0.5, 0.5))
                .expect("random_in_range failed");
            assert!((-0.5..=0.5).contains(&sample));
        }

        let sample = regularizer
            .random_in_range((0.5, 0.5))
            .expect("random_in_range failed");
        assert_eq!(sample, 0.5);
    }

    /// Gate sampling respects p at both extremes and keeps alpha/beta
    /// inside their default ranges.
    #[test]
    fn test_shakedrop_get_gate() {
        // p = 1.0: the gate is active on every draw.
        let always_on = ShakeDrop::new(1.0f64);
        for _ in 0..10 {
            let (b, alpha, beta) = always_on.get_gate().expect("get_gate failed");
            assert_eq!(b, 1.0);
            assert!((-1.0..=1.0).contains(&alpha));
            assert!((0.0..=1.0).contains(&beta));
        }

        // p = 0.0: the gate is inactive and alpha is forced to zero.
        let always_off = ShakeDrop::new(0.0f64);
        for _ in 0..10 {
            let (b, alpha, beta) = always_off.get_gate().expect("get_gate failed");
            assert_eq!(b, 0.0);
            assert_eq!(alpha, 0.0);
            assert!((0.0..=1.0).contains(&beta));
        }
    }

    /// The Regularizer trait surface: apply() is rejected, penalty() is zero.
    #[test]
    fn test_regularizer_trait() {
        let regularizer = ShakeDrop::new(0.5f64);
        let params = Array2::from_elem((2, 3), 1.0f64);
        let mut grads = Array2::from_elem((2, 3), 1.0f64);

        assert!(regularizer.apply(&params, &mut grads).is_err());

        let penalty = regularizer.penalty(&params).expect("penalty failed");
        assert_eq!(penalty, 0.0);
    }
}