// optirs_core/regularizers/dropout.rs
// Dropout regularization

use std::cell::RefCell;
use std::fmt::Debug;

use scirs2_core::ndarray::{Array, Dimension, ScalarOperand, Zip};
use scirs2_core::numeric::Float;
use scirs2_core::random::Rng;
use scirs2_core::Random;

use crate::error::Result;
use crate::regularizers::Regularizer;
12
/// Dropout regularization
///
/// Randomly sets a fraction of the input units to 0 at each update during training,
/// which helps prevent overfitting. During inference, all units are used with appropriate
/// scaling to maintain the same expected output.
///
/// # Examples
///
/// ```
/// use scirs2_core::ndarray::Array1;
/// use optirs_core::regularizers::Dropout;
/// use scirs2_core::random::SeedableRng;
/// use scirs2_core::random::rngs::SmallRng;
///
/// // Create a dropout regularizer with 0.5 dropout rate
/// let seed = [0u8; 32];
/// let mut rng = SmallRng::from_seed(seed);
/// let mut dropout = Dropout::new(0.5f64, &mut rng);
///
/// // Set to training mode
/// dropout.train();
///
/// // Check the dropout rate
/// assert_eq!(dropout.rate(), 0.5);
///
/// // Set to evaluation mode
/// dropout.eval();
/// assert!(!dropout.is_training());
/// ```
#[derive(Debug)]
pub struct Dropout<A: Float + Debug> {
    /// Dropout rate (fraction of units that are dropped), clamped to [0, 1] on construction
    rate: A,
    /// Internal random number generator; `RefCell` so mask sampling can run through `&self`
    rng: RefCell<Random<scirs2_core::random::rngs::StdRng>>,
    /// Boolean indicating whether in training mode (dropout only applies while training)
    training: bool,
    /// Cached dropout mask, stored dynamic-dimensional so one cache serves any input shape
    mask: RefCell<Option<Array<A, scirs2_core::ndarray::IxDyn>>>,
}
53
54impl<A: Float + Debug + Send + Sync> Dropout<A> {
55    /// Create a new dropout regularizer
56    ///
57    /// # Arguments
58    ///
59    /// * `rate` - Dropout rate (0.0 to 1.0, fraction of units that are dropped)
60    /// * `rng` - Random number generator
61    pub fn new<R: Rng>(rate: A, rng: &mut R) -> Self {
62        // Ensure _rate is between 0 and 1
63        let rate = rate.max(A::zero()).min(A::one());
64
65        // Create a new RNG from the provided one
66        let mut seed_bytes = [0u8; 8];
67        rng.fill_bytes(&mut seed_bytes);
68        let seed = u64::from_ne_bytes(seed_bytes);
69        let rng = Random::seed(seed);
70
71        Self {
72            rate,
73            rng: RefCell::new(rng),
74            training: true,
75            mask: RefCell::new(None),
76        }
77    }
78
79    /// Get the dropout rate
80    pub fn rate(&self) -> A {
81        self.rate
82    }
83
84    /// Set the dropout rate
85    ///
86    /// # Arguments
87    ///
88    /// * `rate` - Dropout rate (0.0 to 1.0, fraction of units that are dropped)
89    pub fn set_rate(&mut self, rate: A) -> &mut Self {
90        // Ensure rate is between 0 and 1
91        self.rate = rate.max(A::zero()).min(A::one());
92        // Clear the mask cache
93        *self.mask.borrow_mut() = None;
94        self
95    }
96
97    /// Set to training mode (apply dropout)
98    pub fn train(&mut self) -> &mut Self {
99        self.training = true;
100        self
101    }
102
103    /// Set to inference mode (no dropout, scale outputs)
104    pub fn eval(&mut self) -> &mut Self {
105        self.training = false;
106        self
107    }
108
109    /// Get the training mode
110    pub fn is_training(&self) -> bool {
111        self.training
112    }
113
114    /// Create a new dropout mask for the given shape
115    ///
116    /// During training, randomly sets units to 0 with probability `rate`,
117    /// and scales the remaining by 1/(1-rate) to maintain the same expected output.
118    fn create_mask<D: Dimension>(&self, shape: D) -> Array<A, D> {
119        if !self.training || self.rate <= A::zero() {
120            // In eval mode or with 0 dropout rate, no masking is applied
121            return Array::ones(shape);
122        }
123
124        // The scale factor for the kept units is 1/(1-rate)
125        // This maintains the expected sum of the layer outputs
126        let keep_prob = A::one() - self.rate;
127        let scale = A::one() / keep_prob;
128
129        // Create a mask where units are kept with probability (1-rate)
130        // and scaled by 1/(1-rate)
131        let mut rng = self.rng.borrow_mut();
132        let mut mask = Array::zeros(shape);
133        for elem in mask.iter_mut() {
134            let rand_val =
135                A::from(rng.gen_range(0.0..1.0)).expect("failed to convert random value");
136            if rand_val > self.rate {
137                *elem = scale;
138            }
139        }
140
141        mask
142    }
143}
144
145impl<A, D> Regularizer<A, D> for Dropout<A>
146where
147    A: Float + ScalarOperand + Debug + Send + Sync,
148    D: Dimension<Pattern = D>,
149{
150    fn apply(&self, params: &Array<A, D>, gradients: &mut Array<A, D>) -> Result<A> {
151        if !self.training || self.rate <= A::zero() {
152            // In eval mode or with 0 dropout rate, no dropout is applied
153            return Ok(A::zero());
154        }
155
156        // Create or get the dropout mask
157        let mask = {
158            let mask_ref = self.mask.borrow();
159            match &*mask_ref {
160                Some(m) if m.shape() == gradients.shape() => {
161                    // Use cached mask if shapes match
162                    m.clone()
163                        .into_dimensionality::<D>()
164                        .expect("mask dimensionality conversion failed")
165                }
166                _ => {
167                    // Drop the borrow before calling create_mask which also borrows
168                    drop(mask_ref);
169                    // Create a new mask
170                    self.create_mask(gradients.dim())
171                }
172            }
173        };
174
175        // Apply the mask to the gradients
176        Zip::from(gradients).and(&mask).for_each(|grad, &mask_val| {
177            *grad = *grad * mask_val;
178        });
179
180        // Dropout doesn't add a penalty term to the loss
181        Ok(A::zero())
182    }
183
184    fn penalty(&self, params: &Array<A, D>) -> Result<A> {
185        // Dropout doesn't add a penalty term to the loss
186        Ok(A::zero())
187    }
188}