optirs_core/regularizers/dropout.rs
1// Dropout regularization
2
3use scirs2_core::ndarray::{Array, Dimension, ScalarOperand, Zip};
4use scirs2_core::numeric::Float;
5use scirs2_core::random::Rng;
6use scirs2_core::Random;
7use std::cell::RefCell;
8use std::fmt::Debug;
9
10use crate::error::Result;
11use crate::regularizers::Regularizer;
12
/// Dropout regularization
///
/// Randomly sets a fraction of the input units to 0 at each update during training,
/// which helps prevent overfitting. During inference, all units are used with appropriate
/// scaling to maintain the same expected output.
///
/// # Examples
///
/// ```
/// use scirs2_core::ndarray::Array1;
/// use optirs_core::regularizers::Dropout;
/// use scirs2_core::random::SeedableRng;
/// use scirs2_core::random::rngs::SmallRng;
///
/// // Create a dropout regularizer with 0.5 dropout rate
/// let seed = [0u8; 32];
/// let mut rng = SmallRng::from_seed(seed);
/// let mut dropout = Dropout::new(0.5f64, &mut rng);
///
/// // Set to training mode
/// dropout.train();
///
/// // Check the dropout rate
/// assert_eq!(dropout.rate(), 0.5);
///
/// // Set to evaluation mode
/// dropout.eval();
/// assert!(!dropout.is_training());
/// ```
#[derive(Debug)]
pub struct Dropout<A: Float + Debug> {
    /// Dropout rate (fraction of units that are dropped); clamped to [0, 1] on construction
    rate: A,
    /// Internal RNG used to sample dropout masks. Wrapped in `RefCell` so masks can be
    /// drawn through `&self` (the `Regularizer` trait methods take `&self`).
    rng: RefCell<Random<scirs2_core::random::rngs::StdRng>>,
    /// Whether the regularizer is in training mode (dropout applied) or eval mode (no-op)
    training: bool,
    /// Cached dropout mask, stored with dynamic dimensionality (`IxDyn`) so a single
    /// field can hold a mask for arrays of any rank; cleared when the rate changes
    mask: RefCell<Option<Array<A, scirs2_core::ndarray::IxDyn>>>,
}
53
54impl<A: Float + Debug + Send + Sync> Dropout<A> {
55 /// Create a new dropout regularizer
56 ///
57 /// # Arguments
58 ///
59 /// * `rate` - Dropout rate (0.0 to 1.0, fraction of units that are dropped)
60 /// * `rng` - Random number generator
61 pub fn new<R: Rng>(rate: A, rng: &mut R) -> Self {
62 // Ensure _rate is between 0 and 1
63 let rate = rate.max(A::zero()).min(A::one());
64
65 // Create a new RNG from the provided one
66 let mut seed_bytes = [0u8; 8];
67 rng.fill_bytes(&mut seed_bytes);
68 let seed = u64::from_ne_bytes(seed_bytes);
69 let rng = Random::seed(seed);
70
71 Self {
72 rate,
73 rng: RefCell::new(rng),
74 training: true,
75 mask: RefCell::new(None),
76 }
77 }
78
79 /// Get the dropout rate
80 pub fn rate(&self) -> A {
81 self.rate
82 }
83
84 /// Set the dropout rate
85 ///
86 /// # Arguments
87 ///
88 /// * `rate` - Dropout rate (0.0 to 1.0, fraction of units that are dropped)
89 pub fn set_rate(&mut self, rate: A) -> &mut Self {
90 // Ensure rate is between 0 and 1
91 self.rate = rate.max(A::zero()).min(A::one());
92 // Clear the mask cache
93 *self.mask.borrow_mut() = None;
94 self
95 }
96
97 /// Set to training mode (apply dropout)
98 pub fn train(&mut self) -> &mut Self {
99 self.training = true;
100 self
101 }
102
103 /// Set to inference mode (no dropout, scale outputs)
104 pub fn eval(&mut self) -> &mut Self {
105 self.training = false;
106 self
107 }
108
109 /// Get the training mode
110 pub fn is_training(&self) -> bool {
111 self.training
112 }
113
114 /// Create a new dropout mask for the given shape
115 ///
116 /// During training, randomly sets units to 0 with probability `rate`,
117 /// and scales the remaining by 1/(1-rate) to maintain the same expected output.
118 fn create_mask<D: Dimension>(&self, shape: D) -> Array<A, D> {
119 if !self.training || self.rate <= A::zero() {
120 // In eval mode or with 0 dropout rate, no masking is applied
121 return Array::ones(shape);
122 }
123
124 // The scale factor for the kept units is 1/(1-rate)
125 // This maintains the expected sum of the layer outputs
126 let keep_prob = A::one() - self.rate;
127 let scale = A::one() / keep_prob;
128
129 // Create a mask where units are kept with probability (1-rate)
130 // and scaled by 1/(1-rate)
131 let mut rng = self.rng.borrow_mut();
132 let mut mask = Array::zeros(shape);
133 for elem in mask.iter_mut() {
134 let rand_val =
135 A::from(rng.gen_range(0.0..1.0)).expect("failed to convert random value");
136 if rand_val > self.rate {
137 *elem = scale;
138 }
139 }
140
141 mask
142 }
143}
144
145impl<A, D> Regularizer<A, D> for Dropout<A>
146where
147 A: Float + ScalarOperand + Debug + Send + Sync,
148 D: Dimension<Pattern = D>,
149{
150 fn apply(&self, params: &Array<A, D>, gradients: &mut Array<A, D>) -> Result<A> {
151 if !self.training || self.rate <= A::zero() {
152 // In eval mode or with 0 dropout rate, no dropout is applied
153 return Ok(A::zero());
154 }
155
156 // Create or get the dropout mask
157 let mask = {
158 let mask_ref = self.mask.borrow();
159 match &*mask_ref {
160 Some(m) if m.shape() == gradients.shape() => {
161 // Use cached mask if shapes match
162 m.clone()
163 .into_dimensionality::<D>()
164 .expect("mask dimensionality conversion failed")
165 }
166 _ => {
167 // Drop the borrow before calling create_mask which also borrows
168 drop(mask_ref);
169 // Create a new mask
170 self.create_mask(gradients.dim())
171 }
172 }
173 };
174
175 // Apply the mask to the gradients
176 Zip::from(gradients).and(&mask).for_each(|grad, &mask_val| {
177 *grad = *grad * mask_val;
178 });
179
180 // Dropout doesn't add a penalty term to the loss
181 Ok(A::zero())
182 }
183
184 fn penalty(&self, params: &Array<A, D>) -> Result<A> {
185 // Dropout doesn't add a penalty term to the loss
186 Ok(A::zero())
187 }
188}