optirs_core/regularizers/spatial_dropout.rs

// Spatial and Feature Dropout regularization
//
// This module provides specialized dropout variants that preserve spatial or feature-wise structure:
// - Spatial Dropout: drops entire feature maps (useful for CNNs)
// - Feature Dropout: drops specific features across all spatial locations

use scirs2_core::ndarray::{Array, Axis, Dimension, Ix3, ScalarOperand};
use scirs2_core::numeric::Float;
use scirs2_core::random::{thread_rng, Rng};
use std::fmt::Debug;

use crate::error::{OptimError, Result};
use crate::regularizers::Regularizer;

/// Spatial Dropout regularizer
///
/// Drops entire feature maps instead of individual units. This helps
/// preserve spatial structure in convolutional neural networks.
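/// Kept feature maps are scaled by `1 / (1 - p)` during training (inverted dropout),
/// so inference requires no additional rescaling.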
///
/// # Example
///
/// ```
/// use scirs2_core::ndarray::Array4;
/// use optirs_core::regularizers::SpatialDropout;
///
/// let spatial_dropout = SpatialDropout::new(0.3).unwrap(); // 30% dropout rate
///
/// // 4D tensor (batch, channels, height, width)
/// let features = Array4::<f64>::ones((2, 3, 4, 4));
///
/// // During training - drops entire channels
/// let masked_features = spatial_dropout.apply(&features, true);
/// ```
#[derive(Debug, Clone)]
pub struct SpatialDropout<A: Float> {
    /// Probability of dropping a channel/feature map
    dropprob: A,
    /// Dimension along which to drop (default is 1 for channels)
    feature_dim: Axis,
}

impl<A: Float + Debug + ScalarOperand + Send + Sync> SpatialDropout<A> {
    /// Create a new SpatialDropout regularizer
    ///
    /// # Arguments
    ///
    /// * `dropprob` - Probability of dropping each feature map (0.0 to 1.0)
    pub fn new(dropprob: A) -> Result<Self> {
        if dropprob < A::zero() || dropprob > A::one() {
            return Err(OptimError::InvalidConfig(
                "Drop probability must be between 0.0 and 1.0".to_string(),
            ));
        }

        Ok(Self {
            dropprob,
            feature_dim: Axis(1), // Default to channel dimension
        })
    }

    /// Set the dimension along which to drop features
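    ///
    /// # Example
    ///
    /// A minimal sketch (assuming `SpatialDropout` is re-exported as in the struct-level
    /// example); dropout is applied along axis 0 instead of the default channel axis 1:
    ///
    /// ```
    /// use optirs_core::regularizers::SpatialDropout;
    ///
    /// let dropout = SpatialDropout::<f64>::new(0.2).unwrap().with_feature_dim(0);
    /// ```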
    pub fn with_feature_dim(mut self, dim: usize) -> Self {
        self.feature_dim = Axis(dim);
        self
    }

    /// Apply spatial dropout to a tensor
    pub fn apply<D>(&self, features: &Array<A, D>, training: bool) -> Array<A, D>
    where
        D: Dimension + scirs2_core::ndarray::RemoveAxis,
    {
        if !training || self.dropprob == A::zero() {
            return features.clone();
        }

        let keep_prob = A::one() - self.dropprob;
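        // Inverted dropout: feature maps that survive are later scaled by 1/keep_prob
        // so that the expected activation matches inference, where no dropout is applied.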

        // Get the size of the feature dimension
        let feature_size = features.shape()[self.feature_dim.0];

        // Create a mask for each feature map
        let keep_prob_f64 = keep_prob.to_f64().unwrap();
        let mut rng = thread_rng();
        let feature_mask: Vec<bool> = (0..feature_size)
            .map(|_| rng.random_bool(keep_prob_f64))
            .collect();

        // Apply mask to each feature map
        let mut result = features.clone();
        for (idx, &keep) in feature_mask.iter().enumerate() {
            if !keep {
                // Drop the entire feature map
                let mut axis_slice = result.index_axis_mut(self.feature_dim, idx);
                axis_slice.fill(A::zero());
            } else {
                // Scale kept features
                let mut axis_slice = result.index_axis_mut(self.feature_dim, idx);
                axis_slice.mapv_inplace(|x| x / keep_prob);
            }
        }

        result
    }
}

/// Feature Dropout regularizer
///
/// Drops specific features across all spatial locations. This is useful when
/// you want to maintain spatial consistency while dropping features.
///
/// # Example
///
/// ```
/// use scirs2_core::ndarray::Array3;
/// use optirs_core::regularizers::FeatureDropout;
///
/// let feature_dropout = FeatureDropout::new(0.5).unwrap(); // 50% dropout rate
///
/// // 3D tensor (batch, features, sequence_length)
/// let features = Array3::<f64>::ones((2, 10, 20));
///
/// // During training - drops specific features across all positions
/// let masked_features = feature_dropout.apply(&features, true);
/// ```
#[derive(Debug, Clone)]
pub struct FeatureDropout<A: Float> {
    /// Probability of dropping each feature
    dropprob: A,
    /// Dimension along which features are located (default is 1)
    feature_dim: Axis,
}

impl<A: Float + Debug + ScalarOperand + Send + Sync> FeatureDropout<A> {
    /// Create a new FeatureDropout regularizer
    ///
    /// # Arguments
    ///
    /// * `dropprob` - Probability of dropping each feature (0.0 to 1.0)
    pub fn new(dropprob: A) -> Result<Self> {
        if dropprob < A::zero() || dropprob > A::one() {
            return Err(OptimError::InvalidConfig(
                "Drop probability must be between 0.0 and 1.0".to_string(),
            ));
        }

        Ok(Self {
            dropprob,
            feature_dim: Axis(1), // Default to feature dimension
        })
    }

    /// Set the dimension along which features are located
    pub fn with_feature_dim(mut self, dim: usize) -> Self {
        self.feature_dim = Axis(dim);
        self
    }

    /// Apply feature dropout to a tensor
    pub fn apply<D>(&self, features: &Array<A, D>, training: bool) -> Array<A, D>
    where
        D: Dimension + scirs2_core::ndarray::RemoveAxis,
    {
        if !training || self.dropprob == A::zero() {
            return features.clone();
        }

        let keep_prob = A::one() - self.dropprob;

        // Get the size of the feature dimension
        let feature_size = features.shape()[self.feature_dim.0];

        // Create a consistent mask for each feature
        let keep_prob_f64 = keep_prob.to_f64().unwrap();
        let mut rng = thread_rng();
        let feature_mask: Vec<bool> = (0..feature_size)
            .map(|_| rng.random_bool(keep_prob_f64))
            .collect();

        // Apply the same mask across all spatial/temporal locations
        let mut result = features.clone();
        for (idx, &keep) in feature_mask.iter().enumerate() {
            if !keep {
                // Drop this feature everywhere
                let mut axis_slice = result.index_axis_mut(self.feature_dim, idx);
                axis_slice.fill(A::zero());
            } else {
                // Scale kept features
                let mut axis_slice = result.index_axis_mut(self.feature_dim, idx);
                axis_slice.mapv_inplace(|x| x / keep_prob);
            }
        }

        result
    }
}

// Implement Regularizer trait for SpatialDropout - only for dimensions that support RemoveAxis
impl<
        A: Float + Debug + ScalarOperand + Send + Sync,
        D: Dimension + scirs2_core::ndarray::RemoveAxis + Send + Sync,
    > Regularizer<A, D> for SpatialDropout<A>
{
    fn apply(&self, _params: &Array<A, D>, gradients: &mut Array<A, D>) -> Result<A> {
        // Apply spatial dropout to gradients during training
        let masked_gradients = SpatialDropout::apply(self, gradients, true);
        *gradients = masked_gradients;
        Ok(A::zero())
    }

    fn penalty(&self, _params: &Array<A, D>) -> Result<A> {
        // Spatial dropout doesn't add a penalty term
        Ok(A::zero())
    }
}

// Implement Regularizer trait for FeatureDropout - only for dimensions that support RemoveAxis
impl<
        A: Float + Debug + ScalarOperand + Send + Sync,
        D: Dimension + scirs2_core::ndarray::RemoveAxis + Send + Sync,
    > Regularizer<A, D> for FeatureDropout<A>
{
    fn apply(&self, _params: &Array<A, D>, gradients: &mut Array<A, D>) -> Result<A> {
        // Apply feature dropout to gradients during training
        let masked_gradients = FeatureDropout::apply(self, gradients, true);
        *gradients = masked_gradients;
        Ok(A::zero())
    }

    fn penalty(&self, _params: &Array<A, D>) -> Result<A> {
        // Feature dropout doesn't add a penalty term
        Ok(A::zero())
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_spatial_dropout_creation() {
        // Valid creation
        let sd = SpatialDropout::<f64>::new(0.3).unwrap();
        assert_eq!(sd.dropprob, 0.3);

        // Invalid probabilities
        assert!(SpatialDropout::<f64>::new(-0.1).is_err());
        assert!(SpatialDropout::<f64>::new(1.1).is_err());
    }

    #[test]
    fn test_spatial_dropout_4d() {
        let sd = SpatialDropout::new(0.5).unwrap();

        // Create a 4D tensor (batch, channels, height, width)
        // Use values that are always non-zero to better test dropout
        let features = Array::from_shape_fn((2, 4, 3, 3), |(b, c, h, w)| {
            1.0 + b as f64 + c as f64 * 10.0 + h as f64 * 0.1 + w as f64 * 0.01
        });

        // Apply spatial dropout
        let masked = sd.apply(&features, true);

        // Check that entire channels are either kept or dropped
        for b in 0..2 {
            for c in 0..4 {
                let masked_batch = masked.index_axis(Axis(0), b);
                let channel = masked_batch.index_axis(Axis(0), c);
                let channel_clone = channel.to_owned();
                let is_dropped = channel_clone.iter().all(|&x| x.abs() < 1e-10);
                let is_kept = channel_clone.iter().all(|&x| x.abs() > 1e-10);

                // For dropped channels, all values should be 0
                // For kept channels, all values should be scaled by 1/keep_prob
                if is_dropped {
                    for &val in channel_clone.iter() {
                        assert_eq!(val, 0.0);
                    }
                } else if is_kept {
                    // Check scaling
                    let original_batch = features.index_axis(Axis(0), b);
                    let original_channel = original_batch.index_axis(Axis(0), c);
                    for ((i, j), &val) in channel_clone.indexed_iter() {
                        assert_relative_eq!(val, original_channel[[i, j]] * 2.0, epsilon = 1e-10);
                    }
                } else {
                    // Mixed values - this shouldn't happen
                    println!("Channel {c} in batch {b} has mixed values:");
                    for val in channel_clone.iter() {
                        println!("  Value: {val}");
                    }
                    panic!("Channel should be entirely dropped or kept");
                }
            }
        }
    }

    #[test]
    fn test_feature_dropout_creation() {
        // Valid creation
        let fd = FeatureDropout::<f64>::new(0.4).unwrap();
        assert_eq!(fd.dropprob, 0.4);

        // Invalid probabilities
        assert!(FeatureDropout::<f64>::new(-0.1).is_err());
        assert!(FeatureDropout::<f64>::new(1.1).is_err());
    }

    #[test]
    fn test_feature_dropout_3d() {
        let fd = FeatureDropout::new(0.5).unwrap();

        // Create a 3D tensor (batch, features, sequence)
        let features = Array::from_shape_fn((2, 5, 10), |(_b, f, s)| f as f64 + s as f64);

        // Apply feature dropout
        let masked = fd.apply(&features, true);

        // Check that features are consistently dropped across all positions
        for f in 0..5 {
            let first_batch = masked.index_axis(Axis(0), 0);
            let first_batch_feature = first_batch.index_axis(Axis(0), f);
            let first_batch_clone = first_batch_feature.to_owned();
            let is_dropped = first_batch_clone.iter().all(|&x| x == 0.0);

            // Check consistency across batches and positions
            for b in 0..2 {
                let batch = masked.index_axis(Axis(0), b);
                let feature_slice = batch.index_axis(Axis(0), f);
                let feature_clone = feature_slice.to_owned();
                let all_dropped = feature_clone.iter().all(|&x| x == 0.0);
                assert_eq!(
                    is_dropped, all_dropped,
                    "Feature dropout should be consistent"
                );

                if !all_dropped {
                    // Check scaling
                    let original_batch = features.index_axis(Axis(0), b);
                    let original_slice = original_batch.index_axis(Axis(0), f);
                    for (i, &val) in feature_clone.iter().enumerate() {
                        assert_relative_eq!(val, original_slice[i] * 2.0, epsilon = 1e-10);
                    }
                }
            }
        }
    }
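
    // Sketch of a check for `with_feature_dim`: dropping along the last axis instead of
    // the default axis 1 should still zero or rescale whole slices consistently.
    #[test]
    fn test_with_feature_dim_axis() {
        let fd = FeatureDropout::new(0.5).unwrap().with_feature_dim(2);

        // 3D tensor (batch, features, sequence); values depend only on the sequence index
        let features = Array::from_shape_fn((2, 3, 4), |(_b, _f, s)| 1.0 + s as f64);
        let masked = fd.apply(&features, true);

        // Each index along axis 2 must be either zeroed everywhere or scaled by 1/keep_prob
        for s in 0..4 {
            let slice = masked.index_axis(Axis(2), s).to_owned();
            let dropped = slice.iter().all(|&x| x == 0.0);
            let kept = slice
                .iter()
                .all(|&x| (x - (1.0 + s as f64) * 2.0).abs() < 1e-10);
            assert!(dropped || kept, "Slice must be entirely dropped or kept");
        }
    }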

    #[test]
    fn test_inference_mode() {
        let sd = SpatialDropout::new(0.5).unwrap();
        let fd = FeatureDropout::new(0.5).unwrap();

        let features = array![[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]];

        // During inference, features should remain unchanged
        let sd_inference = sd.apply(&features, false);
        let fd_inference = fd.apply(&features, false);

        assert_eq!(features, sd_inference);
        assert_eq!(features, fd_inference);
    }

    #[test]
    fn test_regularizer_trait() {
        let sd = SpatialDropout::new(0.3).unwrap();
        let params = array![[[1.0, 2.0], [3.0, 4.0]]];
        let mut gradient = array![[[0.1, 0.2], [0.3, 0.4]]];

        // Test Regularizer trait
        let penalty = sd.penalty(&params).unwrap();
        assert_eq!(penalty, 0.0);

        let _penalty_apply = sd.apply(&params, true);
        let penalty_reg =
            <SpatialDropout<f64> as Regularizer<f64, Ix3>>::apply(&sd, &params, &mut gradient)
                .unwrap();
        assert_eq!(penalty_reg, 0.0);

        // Gradient must be modified: every slice along the feature axis is either zeroed
        // or rescaled by 1/keep_prob, so no element keeps its original value
        assert_ne!(gradient, array![[[0.1, 0.2], [0.3, 0.4]]]);
    }
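
    // With a drop probability of 0.0 the early return in `apply` should leave the
    // input unchanged even in training mode.
    #[test]
    fn test_zero_drop_prob_is_identity() {
        let sd = SpatialDropout::new(0.0).unwrap();
        let features = array![[[1.0, 2.0], [3.0, 4.0]]];

        // Training mode, but nothing is dropped or rescaled
        let masked = sd.apply(&features, true);
        assert_eq!(features, masked);
    }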
}