// optirs_core/regularizers/dropout.rs

// Dropout regularization

use scirs2_core::ndarray::{Array, Dimension, ScalarOperand, Zip};
use scirs2_core::numeric::Float;
use scirs2_core::random::Rng;
use scirs2_core::Random;
use std::fmt::Debug;

use crate::error::Result;
use crate::regularizers::Regularizer;

/// Dropout regularization
///
/// Randomly sets a fraction of the input units to 0 at each update during training,
/// which helps prevent overfitting. The surviving units are scaled by 1/(1 - rate)
/// during training (inverted dropout), so the expected output is unchanged and no
/// extra scaling is needed at inference time, when all units are used.
///
/// # Examples
///
/// ```
/// use scirs2_core::ndarray::Array1;
/// use optirs_core::regularizers::Dropout;
/// use scirs2_core::random::SeedableRng;
/// use scirs2_core::random::rngs::SmallRng;
///
/// // Create a dropout regularizer with a 0.5 dropout rate
/// let seed = [0u8; 32];
/// let mut rng = SmallRng::from_seed(seed);
/// let mut dropout = Dropout::new(0.5f64, &mut rng);
///
/// // Set to training mode
/// dropout.train();
///
/// // Check the dropout rate
/// assert_eq!(dropout.rate(), 0.5);
///
/// // Set to evaluation mode
/// dropout.eval();
/// assert!(!dropout.is_training());
/// ```
#[derive(Debug, Clone)]
pub struct Dropout<A: Float + Debug> {
    /// Dropout rate (fraction of units that are dropped)
    rate: A,
    /// Random number generator
    rng: Random<scirs2_core::random::rngs::StdRng>,
    /// Boolean indicating whether in training mode
    training: bool,
    /// Cached dropout mask
    mask: Option<Array<A, scirs2_core::ndarray::IxDyn>>,
}

impl<A: Float + Debug + Send + Sync> Dropout<A> {
    /// Create a new dropout regularizer
    ///
    /// # Arguments
    ///
    /// * `rate` - Dropout rate (0.0 to 1.0, fraction of units that are dropped)
    /// * `rng` - Random number generator
    pub fn new<R: Rng>(rate: A, rng: &mut R) -> Self {
        // Clamp the rate to the range [0, 1]
        let rate = rate.max(A::zero()).min(A::one());

        // Derive a seed from the provided RNG and create an owned RNG for this regularizer
        let mut seed_bytes = [0u8; 8];
        rng.fill_bytes(&mut seed_bytes);
        let seed = u64::from_ne_bytes(seed_bytes);
        let rng = Random::seed(seed);

        Self {
            rate,
            rng,
            training: true,
            mask: None,
        }
    }

    /// Get the dropout rate
    pub fn rate(&self) -> A {
        self.rate
    }

    /// Set the dropout rate
    ///
    /// # Arguments
    ///
    /// * `rate` - Dropout rate (0.0 to 1.0, fraction of units that are dropped)
    pub fn set_rate(&mut self, rate: A) -> &mut Self {
        // Ensure rate is between 0 and 1
        self.rate = rate.max(A::zero()).min(A::one());
        // Clear the mask cache
        self.mask = None;
        self
    }

    /// Set to training mode (apply dropout)
    pub fn train(&mut self) -> &mut Self {
        self.training = true;
        self
    }

    /// Set to inference mode (no dropout; no extra scaling is needed because the
    /// mask is already scaled during training)
    pub fn eval(&mut self) -> &mut Self {
        self.training = false;
        self
    }

    /// Return whether the regularizer is in training mode
    pub fn is_training(&self) -> bool {
        self.training
    }

    /// Create a new dropout mask for the given shape
    ///
    /// During training, randomly sets units to 0 with probability `rate`,
    /// and scales the remaining by 1/(1-rate) to maintain the same expected output.
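    ///
    /// For example, with `rate = 0.25` the kept units are multiplied by
    /// `1 / (1 - 0.25) ≈ 1.33`, so each unit's expected value stays at
    /// `0.75 × 1.33 ≈ 1.0` times its original value.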
    fn create_mask<D: Dimension>(&mut self, shape: D) -> Array<A, D> {
        if !self.training || self.rate <= A::zero() {
            // In eval mode or with 0 dropout rate, no masking is applied
            return Array::ones(shape);
        }

        // The scale factor for the kept units is 1/(1-rate)
        // This maintains the expected sum of the layer outputs
        let keep_prob = A::one() - self.rate;
        let scale = A::one() / keep_prob;

        // Create a mask where units are kept with probability (1-rate)
        // and scaled by 1/(1-rate)
        let mut mask = Array::zeros(shape);
        for elem in mask.iter_mut() {
            let rand_val = A::from(self.rng.gen_range(0.0..1.0)).unwrap();
            if rand_val > self.rate {
                *elem = scale;
            }
        }

        mask
    }
}

impl<A, D> Regularizer<A, D> for Dropout<A>
where
    A: Float + ScalarOperand + Debug + Send + Sync,
    D: Dimension,
{
    fn apply(&self, _params: &Array<A, D>, gradients: &mut Array<A, D>) -> Result<A> {
        if !self.training || self.rate <= A::zero() {
            // In eval mode or with 0 dropout rate, no dropout is applied
            return Ok(A::zero());
        }

        // Create or reuse the dropout mask
        let mask = match &self.mask {
            Some(m) if m.shape() == gradients.shape() => {
                // Use the cached mask if the shapes match
                m.clone().into_dimensionality::<D>().unwrap()
            }
            _ => {
                // `apply` takes `&self`, so the mask cannot be cached here;
                // generate a fresh mask from a clone of the regularizer instead
                let mut dropout = self.clone();
                dropout.create_mask(gradients.raw_dim())
            }
        };

        // Apply the mask to the gradients
        Zip::from(gradients).and(&mask).for_each(|grad, &mask_val| {
            *grad = *grad * mask_val;
        });

        // Dropout doesn't add a penalty term to the loss
        Ok(A::zero())
    }

    fn penalty(&self, _params: &Array<A, D>) -> Result<A> {
        // Dropout doesn't add a penalty term to the loss
        Ok(A::zero())
    }
}
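
// A minimal test sketch of the behavior documented above. It assumes the
// `SmallRng`/`SeedableRng` re-exports shown in the doc example are available in
// test builds; adjust the RNG setup if the crate exposes a different helper.
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array1;
    use scirs2_core::random::rngs::SmallRng;
    use scirs2_core::random::SeedableRng;

    #[test]
    fn rate_is_clamped_to_unit_interval() {
        let mut rng = SmallRng::from_seed([0u8; 32]);
        let dropout = Dropout::new(1.5f64, &mut rng);
        // Rates outside [0, 1] are clamped on construction.
        assert_eq!(dropout.rate(), 1.0);
    }

    #[test]
    fn eval_mode_is_a_no_op() {
        let mut rng = SmallRng::from_seed([0u8; 32]);
        let mut dropout = Dropout::new(0.5f64, &mut rng);
        dropout.eval();

        let params = Array1::from_vec(vec![1.0f64, 2.0, 3.0]).into_dyn();
        let mut gradients = Array1::from_vec(vec![0.1f64, 0.2, 0.3]).into_dyn();
        let penalty = dropout.apply(&params, &mut gradients).unwrap();

        // In eval mode `apply` leaves the gradients untouched and adds no penalty term.
        assert_eq!(penalty, 0.0);
        assert_eq!(gradients, Array1::from_vec(vec![0.1f64, 0.2, 0.3]).into_dyn());
    }
}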