optirs_core/regularizers/shakedrop.rs
use scirs2_core::ndarray::{Array, ArrayBase, Data, Dimension, ScalarOperand};
use scirs2_core::numeric::{Float, FromPrimitive};
use scirs2_core::random::Rng;
use std::fmt::Debug;

use crate::error::{OptimError, Result};
use crate::regularizers::Regularizer;

/// ShakeDrop regularization
///
/// ShakeDrop is a regularization method that extends Stochastic Depth and
/// is often used in very deep neural networks. It randomly scales
/// activations during training: with probability `p` the gate `b` is 1 and
/// the activations pass through unchanged, otherwise the forward pass is
/// scaled by a random `alpha` and the backward pass by an independent
/// random `beta`.
///
/// # Parameters
///
/// * `p` - The probability that the gate is 1 and the layer is left
///   unperturbed, value between 0 and 1.
/// * `alpha_range` - The range for the alpha parameter used in the forward pass (default: [-1.0, 1.0]).
/// * `beta_range` - The range for the beta parameter used in the backward pass (default: [0.0, 1.0]).
///
/// # References
///
/// * Yamada, Y., Iwamura, M., & Kise, K. (2018). ShakeDrop regularization.
///   arXiv preprint arXiv:1802.02375.
///
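/// # Example
///
/// A minimal usage sketch applied to activations in a training step. The
/// import path is assumed from this file's location, and the block is
/// `ignore`d since it depends on crate-local types:
///
/// ```ignore
/// use scirs2_core::ndarray::Array2;
/// use optirs_core::regularizers::shakedrop::ShakeDrop; // path assumed
///
/// let mut sd = ShakeDrop::new(0.5f64);
/// let activations = Array2::from_elem((2, 3), 1.0f64);
///
/// // Forward pass: transform the activations and keep the sampled gate.
/// let (output, gate_params) = sd.forward(&activations);
///
/// // Backward pass: rescale the incoming gradient with the same gate.
/// let grad_output = Array2::from_elem((2, 3), 1.0f64);
/// let grad_input = sd.backward(&grad_output, gate_params);
/// ```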
#[derive(Debug, Clone)]
pub struct ShakeDrop<A: Float + FromPrimitive + Debug> {
    /// Probability that the gate is 1 and the layer is left unperturbed
    pub p: A,
    /// Range for the alpha parameter
    pub alpha_range: (A, A),
    /// Range for the beta parameter
    pub beta_range: (A, A),
    /// Random number generator (fixed seed for reproducibility)
    rng: scirs2_core::random::Random<scirs2_core::random::rngs::StdRng>,
}

impl<A: Float + FromPrimitive + Debug + Send + Sync> ShakeDrop<A> {
    /// Create a new ShakeDrop regularizer with the default ranges
    ///
    /// Uses `alpha_range = (-1.0, 1.0)` and `beta_range = (0.0, 1.0)`.
    ///
    /// # Arguments
    ///
    /// * `p` - Probability of leaving the layer unperturbed, between 0 and 1
    ///
    /// # Returns
    ///
    /// A ShakeDrop regularizer
    pub fn new(p: A) -> Self {
        let zero = A::zero();
        let one = A::one();
        let neg_one = zero - one;

        Self {
            p,
            alpha_range: (neg_one, one),
            beta_range: (zero, one),
            rng: scirs2_core::random::Random::seed(42),
        }
    }

    /// Create a new ShakeDrop regularizer with custom ranges
    ///
    /// # Arguments
    ///
    /// * `p` - Probability of leaving the layer unperturbed, between 0 and 1
    /// * `alpha_range` - Range for the alpha parameter
    /// * `beta_range` - Range for the beta parameter
    ///
    /// # Returns
    ///
    /// A ShakeDrop regularizer
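    ///
    /// # Example
    ///
    /// A small sketch with custom ranges, narrower than the paper defaults
    /// (`ignore`d since it depends on crate-local types):
    ///
    /// ```ignore
    /// let sd = ShakeDrop::new_with_ranges(0.7f64, (-0.5, 0.5), (0.2, 0.8));
    /// assert_eq!(sd.alpha_range, (-0.5, 0.5));
    /// ```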
    pub fn new_with_ranges(p: A, alpha_range: (A, A), beta_range: (A, A)) -> Self {
        Self {
            p,
            alpha_range,
            beta_range,
            rng: scirs2_core::random::Random::seed(42),
        }
    }

    /// Get a random value within the given range
    fn random_in_range(&mut self, range: (A, A)) -> A {
        let (min, max) = range;
        let min_f = min.to_f64().unwrap();
        let max_f = max.to_f64().unwrap();

        // Handle (near-)equal min and max to avoid an "empty range" error
        if (max_f - min_f).abs() < 1e-10 {
            return min;
        }

        let random_val = self.rng.gen_range(min_f..max_f);
        A::from_f64(random_val).unwrap()
    }

    /// Sample the gate parameters for one ShakeDrop application
    ///
    /// # Returns
    ///
    /// A tuple (b, alpha, beta):
    /// - b: Binary gate, 1 with probability `p` and 0 otherwise
    /// - alpha: Random value within alpha_range (only affects the output when b is 0)
    /// - beta: Random value within beta_range (only affects the gradient when b is 0)
    fn get_gate(&mut self) -> (A, A, A) {
        let zero = A::zero();
        let one = A::one();

        // Determine whether the gate is 1 (layer left unperturbed)
        let u: f64 = self.rng.gen_range(0.0..1.0);
        let b = if u < self.p.to_f64().unwrap() {
            one
        } else {
            zero
        };

        // Sample alpha and beta unconditionally. The combined factors
        // b + alpha - b*alpha (forward) and b + beta - b*beta (backward)
        // reduce to 1 when b = 1 and to alpha / beta when b = 0, matching
        // the formulation in the ShakeDrop paper.
        let alpha = self.random_in_range(self.alpha_range);
        let beta = self.random_in_range(self.beta_range);

        (b, alpha, beta)
    }

    /// Apply ShakeDrop to input activations
    ///
    /// # Arguments
    ///
    /// * `x` - Input activation tensor
    ///
    /// # Returns
    ///
    /// The transformed activations and the gate parameters for use in the backward pass
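    ///
    /// With b = 1 the factor `b + alpha - b*alpha` is exactly 1 and the input
    /// passes through unchanged; with b = 0 the output is `alpha * x`. For
    /// example, b = 0 and alpha = -0.3 yields `x' = -0.3 * x`.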
    pub fn forward<S, D>(&mut self, x: &ArrayBase<S, D>) -> (Array<A, D>, (A, A, A))
    where
        S: Data<Elem = A>,
        D: Dimension,
    {
        // Sample the gate values
        let (b, alpha, beta) = self.get_gate();

        // Apply the ShakeDrop transformation
        // During the forward pass: x' = x * (b + alpha - b*alpha)
        let factor = b + alpha - b * alpha;
        let result = x.mapv(|v| v * factor);

        (result, (b, alpha, beta))
    }

    /// Backward pass for ShakeDrop
    ///
    /// # Arguments
    ///
    /// * `grad_output` - Gradient from the next layer
    /// * `gate_params` - The gate parameters (b, alpha, beta) from the forward pass
    ///
    /// # Returns
    ///
    /// The modified gradients
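    ///
    /// With b = 1 the factor `b + beta - b*beta` is 1 and the gradient flows
    /// through unchanged; with b = 0 the gradient is scaled by `beta`,
    /// decoupling the backward scale from the forward `alpha`.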
    pub fn backward<S, D>(
        &self,
        grad_output: &ArrayBase<S, D>,
        gate_params: (A, A, A),
    ) -> Array<A, D>
    where
        S: Data<Elem = A>,
        D: Dimension,
    {
        // alpha is only used in the forward pass
        let (b, _alpha, beta) = gate_params;

        // During the backward pass: grad_x = grad_output * (b + beta - b*beta)
        let factor = b + beta - b * beta;
        grad_output.mapv(|g| g * factor)
    }
}

impl<A: Float + FromPrimitive + Debug + ScalarOperand, D: Dimension + Send + Sync> Regularizer<A, D>
    for ShakeDrop<A>
{
    fn apply(&self, _params: &Array<A, D>, _gradients: &mut Array<A, D>) -> Result<A> {
        // ShakeDrop is applied to activations, not parameters, so the
        // Regularizer trait's apply() is not the intended usage pattern.
        // Users should call forward() during the forward pass and
        // backward() during the backward pass instead.
        Err(OptimError::InvalidConfig(
            "ShakeDrop should be applied to activations during forward/backward passes, \
             not through the Regularizer trait's apply method"
                .to_string(),
        ))
    }

    fn penalty(&self, _params: &Array<A, D>) -> Result<A> {
        // ShakeDrop doesn't add a penalty term to the loss function
        Ok(A::zero())
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::{Array1, Array2};

    #[test]
    fn test_shakedrop_new() {
        let sd = ShakeDrop::new(0.5f64);
        assert_eq!(sd.p, 0.5);
        assert_eq!(sd.alpha_range, (-1.0, 1.0));
        assert_eq!(sd.beta_range, (0.0, 1.0));
    }

    #[test]
    fn test_shakedrop_new_with_ranges() {
        let sd = ShakeDrop::new_with_ranges(0.7f64, (-0.5, 0.5), (0.2, 0.8));
        assert_eq!(sd.p, 0.7);
        assert_eq!(sd.alpha_range, (-0.5, 0.5));
        assert_eq!(sd.beta_range, (0.2, 0.8));
    }

    #[test]
    fn test_shakedrop_forward_backward() {
        // Create a simple 2D array
        let x = Array2::from_elem((2, 3), 1.0f64);

        // Initialize ShakeDrop with p=1.0 so the gate is always 1.
        // Use slightly different values for min and max to avoid an empty range.
        let mut sd = ShakeDrop::new_with_ranges(1.0f64, (0.5, 0.500001), (0.5, 0.500001));

        // Forward pass
        let (output, gate_params) = sd.forward(&x);

        // Verify the gate parameters
        assert_eq!(gate_params.0, 1.0); // b should be 1 since p=1.0
        assert_abs_diff_eq!(gate_params.1, 0.5, epsilon = 1e-5); // alpha should be approximately 0.5
        assert_abs_diff_eq!(gate_params.2, 0.5, epsilon = 1e-5); // beta should be approximately 0.5

        // The expected output is x * (b + alpha - b*alpha) = x * (1 + 0.5 - 0.5) = x * 1
        for &val in output.iter() {
            assert_abs_diff_eq!(val, 1.0, epsilon = 1e-5);
        }

        // Backward pass
        let grad_output = Array2::from_elem((2, 3), 2.0f64);
        let grad_input = sd.backward(&grad_output, gate_params);

        // The expected gradient is grad_output * (b + beta - b*beta) = grad_output * 1
        for &val in grad_input.iter() {
            assert_abs_diff_eq!(val, 2.0, epsilon = 1e-5);
        }
    }

    #[test]
    fn test_shakedrop_forward_inactive() {
        // Create a simple 1D array
        let x = Array1::from_vec(vec![1.0f64, 2.0, 3.0]);

        // Initialize ShakeDrop with p=0.0 so the gate is always 0.
        // Use slightly different values for min and max to avoid an empty range.
        let mut sd = ShakeDrop::new_with_ranges(0.0f64, (-0.5, -0.499999), (0.5, 0.500001));

        // Forward pass - the gate should be 0
        let (output, gate_params) = sd.forward(&x);

        // Verify the gate parameters
        assert_eq!(gate_params.0, 0.0); // b should be 0 since p=0.0
        assert_abs_diff_eq!(gate_params.1, -0.5, epsilon = 1e-5); // alpha should be approximately -0.5
        assert_abs_diff_eq!(gate_params.2, 0.5, epsilon = 1e-5); // beta should be approximately 0.5

        // With b = 0 the output is x * (0 + alpha - 0*alpha) = alpha * x
        for (&out, &inp) in output.iter().zip(x.iter()) {
            assert_abs_diff_eq!(out, -0.5 * inp, epsilon = 1e-5);
        }
    }

    #[test]
    fn test_shakedrop_gen_range() {
        let mut sd = ShakeDrop::new(0.5f64);

        // Test random value generation within the range
        for _ in 0..100 {
            let value = sd.random_in_range((-0.5, 0.5));
            assert!((-0.5..=0.5).contains(&value));
        }

        // Test with a degenerate range (should not panic)
        let value = sd.random_in_range((0.5, 0.5));
        assert_eq!(value, 0.5);
    }

    #[test]
    fn test_shakedrop_get_gate() {
        // Test with p=1.0 - the gate should always be 1
        let mut sd = ShakeDrop::new(1.0f64);
        for _ in 0..10 {
            let (b, alpha, beta) = sd.get_gate();
            assert_eq!(b, 1.0);
            assert!((-1.0..=1.0).contains(&alpha));
            assert!((0.0..=1.0).contains(&beta));
        }

        // Test with p=0.0 - the gate should always be 0;
        // alpha and beta are still sampled from their ranges
        let mut sd = ShakeDrop::new(0.0f64);
        for _ in 0..10 {
            let (b, alpha, beta) = sd.get_gate();
            assert_eq!(b, 0.0);
            assert!((-1.0..=1.0).contains(&alpha));
            assert!((0.0..=1.0).contains(&beta));
        }
    }

    #[test]
    fn test_regularizer_trait() {
        let sd = ShakeDrop::new(0.5f64);
        let params = Array2::from_elem((2, 3), 1.0f64);
        let mut grads = Array2::from_elem((2, 3), 1.0f64);

        // apply() should return an error for ShakeDrop
        assert!(sd.apply(&params, &mut grads).is_err());

        // penalty() should return zero
        let penalty = sd.penalty(&params).unwrap();
        assert_eq!(penalty, 0.0);
    }
}
331}