// optirs_core/regularizers/spatial_dropout.rs

use scirs2_core::ndarray::{Array, Axis, Dimension, Ix3, ScalarOperand};
use scirs2_core::numeric::Float;
use scirs2_core::random::{thread_rng, Rng};
use std::fmt::Debug;

use crate::error::{OptimError, Result};
use crate::regularizers::Regularizer;
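/// Spatial dropout: randomly zeroes entire feature slices during training.
///
/// Rather than masking individual elements, whole slices along
/// `feature_dim` (channels, by default `Axis(1)` in NCHW layout) are
/// dropped with probability `dropprob`, and the surviving slices are
/// rescaled by `1 / (1 - dropprob)` so the expected activation is
/// unchanged (inverted dropout). At inference the input passes through
/// untouched.
///
/// A minimal usage sketch (marked `ignore`; assumes `Array4` is
/// re-exported by `scirs2_core::ndarray`):
///
/// ```ignore
/// use scirs2_core::ndarray::Array4;
///
/// let sd = SpatialDropout::new(0.5)?;           // drop half the channels
/// let x = Array4::<f64>::ones((8, 16, 32, 32)); // NCHW batch
/// let y = sd.apply(&x, true);                   // training: mask + rescale
/// assert_eq!(sd.apply(&x, false), x);           // inference: identity
/// ```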
#[derive(Debug, Clone)]
pub struct SpatialDropout<A: Float> {
    /// Probability of dropping each feature slice, in `[0.0, 1.0]`.
    dropprob: A,
    /// Axis along which whole slices are dropped (defaults to `Axis(1)`).
    feature_dim: Axis,
}

impl<A: Float + Debug + ScalarOperand + Send + Sync> SpatialDropout<A> {
    /// Creates a new spatial dropout regularizer.
    ///
    /// Returns an error unless `dropprob` lies in `[0.0, 1.0]`.
    pub fn new(dropprob: A) -> Result<Self> {
        if dropprob < A::zero() || dropprob > A::one() {
            return Err(OptimError::InvalidConfig(
                "Drop probability must be between 0.0 and 1.0".to_string(),
            ));
        }

        Ok(Self {
            dropprob,
            feature_dim: Axis(1),
        })
    }

    /// Sets the axis along which whole feature slices are dropped.
    pub fn with_feature_dim(mut self, dim: usize) -> Self {
        self.feature_dim = Axis(dim);
        self
    }

    /// Applies spatial dropout to `features`.
    ///
    /// In training mode each slice along `feature_dim` is either zeroed
    /// (with probability `dropprob`) or rescaled by `1 / keep_prob`
    /// (inverted dropout), so the expected activation is preserved. In
    /// inference mode, or when `dropprob` is zero, the input is returned
    /// unchanged.
    pub fn apply<D>(&self, features: &Array<A, D>, training: bool) -> Array<A, D>
    where
        D: Dimension + scirs2_core::ndarray::RemoveAxis,
    {
        if !training || self.dropprob == A::zero() {
            return features.clone();
        }

        let keep_prob = A::one() - self.dropprob;

        let feature_size = features.shape()[self.feature_dim.0];

        // Sample one Bernoulli keep/drop decision per feature slice.
        let keep_prob_f64 = keep_prob.to_f64().unwrap();
        let mut rng = thread_rng();
        let feature_mask: Vec<bool> = (0..feature_size)
            .map(|_| rng.random_bool(keep_prob_f64))
            .collect();

        // Zero dropped slices; rescale kept slices by 1 / keep_prob.
        let mut result = features.clone();
        for (idx, &keep) in feature_mask.iter().enumerate() {
            let mut axis_slice = result.index_axis_mut(self.feature_dim, idx);
            if keep {
                axis_slice.mapv_inplace(|x| x / keep_prob);
            } else {
                axis_slice.fill(A::zero());
            }
        }

        result
    }
}
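
/// Feature dropout: randomly drops entire features across all samples.
///
/// Identical in mechanism to [`SpatialDropout`], but intended for inputs
/// laid out as `(batch, features, ...)`: whole slices along `feature_dim`
/// (default `Axis(1)`) are zeroed during training and the kept slices are
/// rescaled by `1 / (1 - dropprob)`.
///
/// A minimal usage sketch (marked `ignore`; assumes `Array3` is
/// re-exported by `scirs2_core::ndarray`):
///
/// ```ignore
/// use scirs2_core::ndarray::Array3;
///
/// let fd = FeatureDropout::new(0.2)?;        // drop 20% of features
/// let x = Array3::<f64>::ones((8, 64, 100)); // (batch, features, seq)
/// let y = fd.apply(&x, true);                // training: mask + rescale
/// assert_eq!(fd.apply(&x, false), x);        // inference: identity
/// ```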
#[derive(Debug, Clone)]
pub struct FeatureDropout<A: Float> {
    /// Probability of dropping each feature, in `[0.0, 1.0]`.
    dropprob: A,
    /// Axis along which whole features are dropped (defaults to `Axis(1)`).
    feature_dim: Axis,
}

impl<A: Float + Debug + ScalarOperand + Send + Sync> FeatureDropout<A> {
    /// Creates a new feature dropout regularizer.
    ///
    /// Returns an error unless `dropprob` lies in `[0.0, 1.0]`.
    pub fn new(dropprob: A) -> Result<Self> {
        if dropprob < A::zero() || dropprob > A::one() {
            return Err(OptimError::InvalidConfig(
                "Drop probability must be between 0.0 and 1.0".to_string(),
            ));
        }

        Ok(Self {
            dropprob,
            feature_dim: Axis(1),
        })
    }

    /// Sets the axis along which whole features are dropped.
    pub fn with_feature_dim(mut self, dim: usize) -> Self {
        self.feature_dim = Axis(dim);
        self
    }

    /// Applies feature dropout to `features`.
    ///
    /// See [`SpatialDropout::apply`]; the masking logic is identical.
    pub fn apply<D>(&self, features: &Array<A, D>, training: bool) -> Array<A, D>
    where
        D: Dimension + scirs2_core::ndarray::RemoveAxis,
    {
        if !training || self.dropprob == A::zero() {
            return features.clone();
        }

        let keep_prob = A::one() - self.dropprob;

        let feature_size = features.shape()[self.feature_dim.0];

        // Sample one Bernoulli keep/drop decision per feature.
        let keep_prob_f64 = keep_prob.to_f64().unwrap();
        let mut rng = thread_rng();
        let feature_mask: Vec<bool> = (0..feature_size)
            .map(|_| rng.random_bool(keep_prob_f64))
            .collect();

        // Zero dropped features; rescale kept features by 1 / keep_prob.
        let mut result = features.clone();
        for (idx, &keep) in feature_mask.iter().enumerate() {
            let mut axis_slice = result.index_axis_mut(self.feature_dim, idx);
            if keep {
                axis_slice.mapv_inplace(|x| x / keep_prob);
            } else {
                axis_slice.fill(A::zero());
            }
        }

        result
    }
}

impl<
        A: Float + Debug + ScalarOperand + Send + Sync,
        D: Dimension + scirs2_core::ndarray::RemoveAxis + Send + Sync,
    > Regularizer<A, D> for SpatialDropout<A>
{
    fn apply(&self, _params: &Array<A, D>, gradients: &mut Array<A, D>) -> Result<A> {
        // As a `Regularizer`, spatial dropout masks the gradients in place:
        // whole gradient slices are dropped and the rest rescaled.
        let masked_gradients = SpatialDropout::apply(self, gradients, true);
        *gradients = masked_gradients;
        Ok(A::zero())
    }

    fn penalty(&self, _params: &Array<A, D>) -> Result<A> {
        // Dropout contributes no explicit penalty term to the loss.
        Ok(A::zero())
    }
}

impl<
        A: Float + Debug + ScalarOperand + Send + Sync,
        D: Dimension + scirs2_core::ndarray::RemoveAxis + Send + Sync,
    > Regularizer<A, D> for FeatureDropout<A>
{
    fn apply(&self, _params: &Array<A, D>, gradients: &mut Array<A, D>) -> Result<A> {
        // As a `Regularizer`, feature dropout masks the gradients in place.
        let masked_gradients = FeatureDropout::apply(self, gradients, true);
        *gradients = masked_gradients;
        Ok(A::zero())
    }

    fn penalty(&self, _params: &Array<A, D>) -> Result<A> {
        // Dropout contributes no explicit penalty term to the loss.
        Ok(A::zero())
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_spatial_dropout_creation() {
        let sd = SpatialDropout::<f64>::new(0.3).unwrap();
        assert_eq!(sd.dropprob, 0.3);

        assert!(SpatialDropout::<f64>::new(-0.1).is_err());
        assert!(SpatialDropout::<f64>::new(1.1).is_err());
    }

    #[test]
    fn test_spatial_dropout_4d() {
        let sd = SpatialDropout::new(0.5).unwrap();

        // NCHW input: every element is >= 1.0, so dropped and kept channels
        // are unambiguous after masking.
        let features = Array::from_shape_fn((2, 4, 3, 3), |(b, c, h, w)| {
            1.0 + b as f64 + c as f64 * 10.0 + h as f64 * 0.1 + w as f64 * 0.01
        });

        let masked = sd.apply(&features, true);

        // Each channel must be either entirely zeroed or entirely rescaled
        // by 1 / keep_prob = 2.0.
        for b in 0..2 {
            for c in 0..4 {
                let masked_batch = masked.index_axis(Axis(0), b);
                let channel = masked_batch.index_axis(Axis(0), c);
                let channel_clone = channel.to_owned();
                let is_dropped = channel_clone.iter().all(|&x| x.abs() < 1e-10);
                let is_kept = channel_clone.iter().all(|&x| x.abs() > 1e-10);

                if is_dropped {
                    for &val in channel_clone.iter() {
                        assert_eq!(val, 0.0);
                    }
                } else if is_kept {
                    let original_batch = features.index_axis(Axis(0), b);
                    let original_channel = original_batch.index_axis(Axis(0), c);
                    for ((i, j), &val) in channel_clone.indexed_iter() {
                        assert_relative_eq!(val, original_channel[[i, j]] * 2.0, epsilon = 1e-10);
                    }
                } else {
                    println!("Channel {c} in batch {b} has mixed values:");
                    for val in channel_clone.iter() {
                        println!("  Value: {val}");
                    }
                    panic!("Channel should be entirely dropped or kept");
                }
            }
        }
    }
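
    // A sketch of a further check, assuming `with_feature_dim` relocates the
    // masked axis as documented above: slices along `Axis(2)` must then be
    // all-zero or uniformly rescaled by 1 / keep_prob = 2.0.
    #[test]
    fn test_spatial_dropout_custom_feature_dim() {
        let sd = SpatialDropout::new(0.5).unwrap().with_feature_dim(2);

        // Strictly positive input so zeroed slices are unambiguous.
        let features = Array::from_shape_fn((2, 3, 4), |(b, f, s)| {
            1.0 + b as f64 + f as f64 * 0.1 + s as f64 * 0.01
        });

        let masked = sd.apply(&features, true);

        for s in 0..4 {
            let slice = masked.index_axis(Axis(2), s).to_owned();
            let dropped = slice.iter().all(|&x| x == 0.0);
            if !dropped {
                let original = features.index_axis(Axis(2), s);
                for (&m, &o) in slice.iter().zip(original.iter()) {
                    assert_relative_eq!(m, o * 2.0, epsilon = 1e-10);
                }
            }
        }
    }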

    #[test]
    fn test_feature_dropout_creation() {
        let fd = FeatureDropout::<f64>::new(0.4).unwrap();
        assert_eq!(fd.dropprob, 0.4);

        assert!(FeatureDropout::<f64>::new(-0.1).is_err());
        assert!(FeatureDropout::<f64>::new(1.1).is_err());
    }

    #[test]
    fn test_feature_dropout_3d() {
        let fd = FeatureDropout::new(0.5).unwrap();

        let features = Array::from_shape_fn((2, 5, 10), |(_b, f, s)| f as f64 + s as f64);

        let masked = fd.apply(&features, true);

        // The per-feature mask is sampled once, so a feature is dropped in
        // every batch element or in none of them.
        for f in 0..5 {
            let first_batch = masked.index_axis(Axis(0), 0);
            let first_batch_feature = first_batch.index_axis(Axis(0), f);
            let first_batch_clone = first_batch_feature.to_owned();
            let is_dropped = first_batch_clone.iter().all(|&x| x == 0.0);

            for b in 0..2 {
                let batch = masked.index_axis(Axis(0), b);
                let feature_slice = batch.index_axis(Axis(0), f);
                let feature_clone = feature_slice.to_owned();
                let all_dropped = feature_clone.iter().all(|&x| x == 0.0);
                assert_eq!(
                    is_dropped, all_dropped,
                    "Feature dropout should be consistent"
                );

                if !all_dropped {
                    // Kept features are rescaled by 1 / keep_prob = 2.0.
                    let original_batch = features.index_axis(Axis(0), b);
                    let original_slice = original_batch.index_axis(Axis(0), f);
                    for (i, &val) in feature_clone.iter().enumerate() {
                        assert_relative_eq!(val, original_slice[i] * 2.0, epsilon = 1e-10);
                    }
                }
            }
        }
    }
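
    // A sketch of a further check: with `dropprob = 0.0` the early-return
    // branch in `apply` should make training mode an identity.
    #[test]
    fn test_zero_drop_probability_is_identity() {
        let sd = SpatialDropout::new(0.0).unwrap();
        let fd = FeatureDropout::new(0.0).unwrap();

        let features = array![[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]];

        assert_eq!(features, sd.apply(&features, true));
        assert_eq!(features, fd.apply(&features, true));
    }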

    #[test]
    fn test_inference_mode() {
        let sd = SpatialDropout::new(0.5).unwrap();
        let fd = FeatureDropout::new(0.5).unwrap();

        let features = array![[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]];

        let sd_inference = sd.apply(&features, false);
        let fd_inference = fd.apply(&features, false);

        assert_eq!(features, sd_inference);
        assert_eq!(features, fd_inference);
    }

    #[test]
    fn test_regularizer_trait() {
        let sd = SpatialDropout::new(0.3).unwrap();
        let params = array![[[1.0, 2.0], [3.0, 4.0]]];
        let mut gradient = array![[[0.1, 0.2], [0.3, 0.4]]];

        let penalty = sd.penalty(&params).unwrap();
        assert_eq!(penalty, 0.0);

        let _masked_params = sd.apply(&params, true);
        let penalty_reg =
            <SpatialDropout<f64> as Regularizer<f64, Ix3>>::apply(&sd, &params, &mut gradient)
                .unwrap();
        assert_eq!(penalty_reg, 0.0);

        // The mask is stochastic: each gradient element is either zeroed or
        // rescaled by 1 / keep_prob = 1 / 0.7.
        let original = array![[[0.1, 0.2], [0.3, 0.4]]];
        for (&g, &o) in gradient.iter().zip(original.iter()) {
            assert!(g == 0.0 || (g - o / 0.7).abs() < 1e-10);
        }
    }
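
    // A sketch of a boundary check, assuming `random_bool(0.0)` never
    // returns true: `new` accepts `dropprob = 1.0`, and every slice should
    // then be zeroed (the rescale branch is never reached).
    #[test]
    fn test_full_drop_probability_zeroes_everything() {
        let sd = SpatialDropout::new(1.0).unwrap();
        let features = array![[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]];

        let masked = sd.apply(&features, true);
        assert!(masked.iter().all(|&x| x == 0.0));
    }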
}