optirs_core/regularizers/
shakedrop.rs1use scirs2_core::ndarray::{Array, ArrayBase, Data, Dimension, ScalarOperand};
2use scirs2_core::numeric::{Float, FromPrimitive};
3use scirs2_core::random::Rng;
4use std::cell::RefCell;
5use std::fmt::Debug;
6
7use crate::error::{OptimError, Result};
8use crate::regularizers::Regularizer;
9
10type ShakeDropResult<A, D> = (Array<A, D>, (A, A, A));
12
13#[derive(Debug)]
32pub struct ShakeDrop<A: Float + FromPrimitive + Debug> {
33 pub p: A,
35 pub alpha_range: (A, A),
37 pub beta_range: (A, A),
39 rng: RefCell<scirs2_core::random::Random<scirs2_core::random::rngs::StdRng>>,
41}
42
43impl<A: Float + FromPrimitive + Debug + Send + Sync> ShakeDrop<A> {
44 pub fn new(p: A) -> Self {
56 let zero = A::zero();
57 let one = A::one();
58 let neg_one = zero - one;
59
60 Self {
61 p,
62 alpha_range: (neg_one, one),
63 beta_range: (zero, one),
64 rng: RefCell::new(scirs2_core::random::Random::seed(42)),
65 }
66 }
67
68 pub fn new_with_ranges(p: A, alpharange: (A, A), beta_range: (A, A)) -> Self {
80 Self {
81 p,
82 alpha_range: alpharange,
83 beta_range,
84 rng: RefCell::new(scirs2_core::random::Random::seed(42)),
85 }
86 }
87
88 fn random_in_range(&self, range: (A, A)) -> Result<A> {
90 let (min, max) = range;
91 let min_f = min
92 .to_f64()
93 .ok_or_else(|| OptimError::InvalidConfig("Failed to convert min to f64".to_string()))?;
94 let max_f = max
95 .to_f64()
96 .ok_or_else(|| OptimError::InvalidConfig("Failed to convert max to f64".to_string()))?;
97
98 if (max_f - min_f).abs() < 1e-10 {
100 return Ok(min);
101 }
102
103 let random_val = self.rng.borrow_mut().gen_range(min_f..max_f);
104 A::from_f64(random_val).ok_or_else(|| {
105 OptimError::InvalidConfig("Failed to convert random value from f64".to_string())
106 })
107 }
108
109 fn get_gate(&self) -> Result<(A, A, A)> {
118 let zero = A::zero();
119 let one = A::one();
120
121 let u: f64 = self.rng.borrow_mut().gen_range(0.0..1.0);
123 let p_f64 = self
124 .p
125 .to_f64()
126 .ok_or_else(|| OptimError::InvalidConfig("Failed to convert p to f64".to_string()))?;
127 let b = if u < p_f64 { one } else { zero };
128
129 let alpha = if b > zero {
131 self.random_in_range(self.alpha_range)?
132 } else {
133 zero
134 };
135
136 let beta = self.random_in_range(self.beta_range)?;
138
139 Ok((b, alpha, beta))
140 }
141
142 pub fn forward<S, D>(&self, x: &ArrayBase<S, D>) -> Result<ShakeDropResult<A, D>>
152 where
153 S: Data<Elem = A>,
154 D: Dimension,
155 {
156 let (b, alpha, beta) = self.get_gate()?;
158
159 let factor = b + alpha - b * alpha;
162 let result = x.mapv(|v| v * factor);
163
164 Ok((result, (b, alpha, beta)))
165 }
166
167 pub fn backward<S, D>(
178 &self,
179 grad_output: &ArrayBase<S, D>,
180 gate_params: (A, A, A),
181 ) -> Array<A, D>
182 where
183 S: Data<Elem = A>,
184 D: Dimension,
185 {
186 let (b, alpha, beta) = gate_params;
187
188 let factor = b + beta - b * beta;
190 grad_output.mapv(|g| g * factor)
191 }
192}
193
194impl<A: Float + FromPrimitive + Debug + ScalarOperand, D: Dimension + Send + Sync> Regularizer<A, D>
195 for ShakeDrop<A>
196{
197 fn apply(&self, _params: &Array<A, D>, gradients: &mut Array<A, D>) -> Result<A> {
198 Err(OptimError::InvalidConfig(
203 "ShakeDrop should be applied to activations during forward/backward passes, \
204 not through the Regularizer trait's apply method"
205 .to_string(),
206 ))
207 }
208
209 fn penalty(&self, params: &Array<A, D>) -> Result<A> {
210 Ok(A::zero())
212 }
213}
214
215#[cfg(test)]
216mod tests {
217 use super::*;
218 use approx::assert_abs_diff_eq;
219 use scirs2_core::ndarray::{Array1, Array2};
220
221 #[test]
222 fn test_shakedrop_new() {
223 let sd = ShakeDrop::new(0.5f64);
224 assert_eq!(sd.p, 0.5);
225 assert_eq!(sd.alpha_range, (-1.0, 1.0));
226 assert_eq!(sd.beta_range, (0.0, 1.0));
227 }
228
229 #[test]
230 fn test_shakedrop_new_with_ranges() {
231 let sd = ShakeDrop::new_with_ranges(0.7f64, (-0.5, 0.5), (0.2, 0.8));
232 assert_eq!(sd.p, 0.7);
233 assert_eq!(sd.alpha_range, (-0.5, 0.5));
234 assert_eq!(sd.beta_range, (0.2, 0.8));
235 }
236
237 #[test]
238 fn test_shakedrop_forward_backward() {
239 let x = Array2::from_elem((2, 3), 1.0f64);
241
242 let sd = ShakeDrop::new_with_ranges(1.0f64, (0.5, 0.500001), (0.5, 0.500001));
245
246 let (output, gate_params) = sd.forward(&x).expect("forward failed");
248
249 assert_eq!(gate_params.0, 1.0); assert_abs_diff_eq!(gate_params.1, 0.5, epsilon = 1e-5); assert_abs_diff_eq!(gate_params.2, 0.5, epsilon = 1e-5); for &val in output.iter() {
256 assert_abs_diff_eq!(val, 1.0, epsilon = 1e-5);
257 }
258
259 let grad_output = Array2::from_elem((2, 3), 2.0f64);
261 let grad_input = sd.backward(&grad_output, gate_params);
262
263 for &val in grad_input.iter() {
265 assert_abs_diff_eq!(val, 2.0, epsilon = 1e-5);
266 }
267 }
268
269 #[test]
270 fn test_shakedrop_forward_inactive() {
271 let x = Array1::from_vec(vec![1.0f64, 2.0, 3.0]);
273
274 let sd = ShakeDrop::new_with_ranges(0.0f64, (-0.5, -0.499999), (0.5, 0.500001));
277
278 let (output, gate_params) = sd.forward(&x).expect("forward failed");
280
281 assert_eq!(gate_params.0, 0.0); assert_eq!(gate_params.1, 0.0); assert_abs_diff_eq!(gate_params.2, 0.5, epsilon = 1e-5); for &val in output.iter() {
288 assert_abs_diff_eq!(val, 0.0, epsilon = 1e-10);
289 }
290 }
291
292 #[test]
293 fn test_shakedrop_gen_range() {
294 let sd = ShakeDrop::new(0.5f64);
295
296 for _ in 0..100 {
298 let value = sd
299 .random_in_range((-0.5, 0.5))
300 .expect("random_in_range failed");
301 assert!((-0.5..=0.5).contains(&value));
302 }
303
304 let value = sd
306 .random_in_range((0.5, 0.5))
307 .expect("random_in_range failed");
308 assert_eq!(value, 0.5);
309 }
310
311 #[test]
312 fn test_shakedrop_get_gate() {
313 let sd = ShakeDrop::new(1.0f64);
315 for _ in 0..10 {
316 let (b, alpha, beta) = sd.get_gate().expect("get_gate failed");
317 assert_eq!(b, 1.0);
318 assert!((-1.0..=1.0).contains(&alpha));
319 assert!((0.0..=1.0).contains(&beta));
320 }
321
322 let sd = ShakeDrop::new(0.0f64);
324 for _ in 0..10 {
325 let (b, alpha, beta) = sd.get_gate().expect("get_gate failed");
326 assert_eq!(b, 0.0);
327 assert_eq!(alpha, 0.0);
328 assert!((0.0..=1.0).contains(&beta));
329 }
330 }
331
332 #[test]
333 fn test_regularizer_trait() {
334 let sd = ShakeDrop::new(0.5f64);
335 let params = Array2::from_elem((2, 3), 1.0f64);
336 let mut grads = Array2::from_elem((2, 3), 1.0f64);
337
338 assert!(sd.apply(¶ms, &mut grads).is_err());
340
341 let penalty = sd.penalty(¶ms).expect("penalty failed");
343 assert_eq!(penalty, 0.0);
344 }
345}