optirs_core/regularizers/dropout.rs
// Dropout regularization

use scirs2_core::ndarray::{Array, Dimension, ScalarOperand, Zip};
use scirs2_core::numeric::Float;
use scirs2_core::random::Rng;
use scirs2_core::Random;
use std::fmt::Debug;

use crate::error::Result;
use crate::regularizers::Regularizer;

/// Dropout regularization
///
/// Randomly sets a fraction of the input units to 0 at each update during training,
/// which helps prevent overfitting. During inference, all units are used with appropriate
/// scaling to maintain the same expected output.
///
/// # Examples
///
/// ```
/// use scirs2_core::ndarray::Array1;
/// use optirs_core::regularizers::Dropout;
/// use scirs2_core::random::SeedableRng;
/// use scirs2_core::random::rngs::SmallRng;
///
/// // Create a dropout regularizer with a 0.5 dropout rate
/// let seed = [0u8; 32];
/// let mut rng = SmallRng::from_seed(seed);
/// let mut dropout = Dropout::new(0.5f64, &mut rng);
///
/// // Set to training mode
/// dropout.train();
///
/// // Check the dropout rate
/// assert_eq!(dropout.rate(), 0.5);
///
/// // Set to evaluation mode
/// dropout.eval();
/// assert!(!dropout.is_training());
/// ```
#[derive(Debug, Clone)]
pub struct Dropout<A: Float + Debug> {
    /// Dropout rate (fraction of units that are dropped)
    rate: A,
    /// Random number generator
    rng: Random<scirs2_core::random::rngs::StdRng>,
    /// Boolean indicating whether in training mode
    training: bool,
    /// Cached dropout mask
    mask: Option<Array<A, scirs2_core::ndarray::IxDyn>>,
}

impl<A: Float + Debug + Send + Sync> Dropout<A> {
    /// Create a new dropout regularizer
    ///
    /// # Arguments
    ///
    /// * `rate` - Dropout rate (0.0 to 1.0, fraction of units that are dropped)
    /// * `rng` - Random number generator
    pub fn new<R: Rng>(rate: A, rng: &mut R) -> Self {
        // Ensure rate is between 0 and 1
        let rate = rate.max(A::zero()).min(A::one());

        // Seed an owned RNG from the provided one
        let mut seed_bytes = [0u8; 8];
        rng.fill_bytes(&mut seed_bytes);
        let seed = u64::from_ne_bytes(seed_bytes);
        let rng = Random::seed(seed);

        Self {
            rate,
            rng,
            training: true,
            mask: None,
        }
    }

    /// Get the dropout rate
    pub fn rate(&self) -> A {
        self.rate
    }

    /// Set the dropout rate
    ///
    /// # Arguments
    ///
    /// * `rate` - Dropout rate (0.0 to 1.0, fraction of units that are dropped)
    pub fn set_rate(&mut self, rate: A) -> &mut Self {
        // Ensure rate is between 0 and 1
        self.rate = rate.max(A::zero()).min(A::one());
        // Clear the mask cache
        self.mask = None;
        self
    }

    /// Set to training mode (apply dropout)
    pub fn train(&mut self) -> &mut Self {
        self.training = true;
        self
    }

    /// Set to inference mode (no dropout, scale outputs)
    pub fn eval(&mut self) -> &mut Self {
        self.training = false;
        self
    }

    /// Get the training mode
    pub fn is_training(&self) -> bool {
        self.training
    }

    /// Create a new dropout mask for the given shape
    ///
    /// During training, randomly sets units to 0 with probability `rate`,
    /// and scales the remaining by 1/(1-rate) to maintain the same expected output.
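    ///
    /// For example, with `rate = 0.25` each unit is kept with probability 0.75 and the
    /// kept units are scaled by 1 / 0.75 ≈ 1.333, so each unit's expected value stays at
    /// 0.75 × 1.333 ≈ 1.0 times its original value.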
    fn create_mask<D: Dimension>(&mut self, shape: D) -> Array<A, D> {
        if !self.training || self.rate <= A::zero() {
            // In eval mode or with 0 dropout rate, no masking is applied
            return Array::ones(shape);
        }

        // The scale factor for the kept units is 1/(1-rate)
        // This maintains the expected sum of the layer outputs
        let keep_prob = A::one() - self.rate;
        let scale = A::one() / keep_prob;

        // Create a mask where units are kept with probability (1-rate)
        // and scaled by 1/(1-rate)
        let mut mask = Array::zeros(shape);
        for elem in mask.iter_mut() {
            let rand_val = A::from(self.rng.gen_range(0.0..1.0)).unwrap();
            if rand_val > self.rate {
                *elem = scale;
            }
        }

        mask
    }
}

impl<A, D> Regularizer<A, D> for Dropout<A>
where
    A: Float + ScalarOperand + Debug + Send + Sync,
    D: Dimension<Pattern = D>,
{
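    /// Apply the dropout mask to the gradients.
    ///
    /// In training mode each gradient element is either zeroed (dropped) or scaled by
    /// 1/(1-rate); in evaluation mode, or with a rate of zero, the gradients are left
    /// untouched. Dropout adds no penalty term, so the returned value is always zero.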
    fn apply(&self, _params: &Array<A, D>, gradients: &mut Array<A, D>) -> Result<A> {
        if !self.training || self.rate <= A::zero() {
            // In eval mode or with 0 dropout rate, no dropout is applied
            return Ok(A::zero());
        }

        // Create or get the dropout mask
        let mask = match &self.mask {
            Some(m) if m.shape() == gradients.shape() => {
                // Use cached mask if shapes match
                m.clone().into_dimensionality::<D>().unwrap()
            }
            _ => {
                // `apply` only has `&self`, so generate a fresh mask on a clone;
                // the mask cannot be written back into `self.mask` from here.
                let mut dropout = self.clone();
                dropout.create_mask(gradients.dim())
            }
        };

        // Apply the mask to the gradients
        Zip::from(gradients).and(&mask).for_each(|grad, &mask_val| {
            *grad = *grad * mask_val;
        });

        // Dropout doesn't add a penalty term to the loss
        Ok(A::zero())
    }

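    /// Dropout is realized as a stochastic mask on the gradients rather than a term in
    /// the loss, so the penalty is always zero.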
    fn penalty(&self, _params: &Array<A, D>) -> Result<A> {
        // Dropout doesn't add a penalty term to the loss
        Ok(A::zero())
    }
}
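
// A minimal test sketch for the deterministic paths of Dropout (evaluation mode and a
// dropout rate of 1.0). It assumes the scirs2_core re-exports used in the doc example
// above (`Array1`, `SmallRng`, `SeedableRng`) and uses dynamic-dimension arrays via
// `.into_dyn()`, since the `Regularizer` impl here requires `D: Dimension<Pattern = D>`.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::regularizers::Regularizer;
    use scirs2_core::ndarray::Array1;
    use scirs2_core::random::rngs::SmallRng;
    use scirs2_core::random::SeedableRng;

    #[test]
    fn eval_mode_leaves_gradients_unchanged() {
        let mut rng = SmallRng::from_seed([0u8; 32]);
        let mut dropout = Dropout::new(0.5f64, &mut rng);
        dropout.eval();

        let params = Array1::from_vec(vec![1.0, 2.0, 3.0]).into_dyn();
        let mut gradients = Array1::from_vec(vec![0.1, 0.2, 0.3]).into_dyn();
        let penalty = dropout.apply(&params, &mut gradients).unwrap();

        // No penalty and no masking in evaluation mode
        assert_eq!(penalty, 0.0);
        assert_eq!(gradients, Array1::from_vec(vec![0.1, 0.2, 0.3]).into_dyn());
    }

    #[test]
    fn full_dropout_zeroes_gradients() {
        let mut rng = SmallRng::from_seed([0u8; 32]);
        let dropout = Dropout::new(1.0f64, &mut rng);

        let params = Array1::from_vec(vec![1.0, 2.0, 3.0]).into_dyn();
        let mut gradients = Array1::from_vec(vec![0.1, 0.2, 0.3]).into_dyn();
        dropout.apply(&params, &mut gradients).unwrap();

        // With rate = 1.0 every unit is dropped, so every gradient becomes zero
        assert!(gradients.iter().all(|&g| g == 0.0));
    }
}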