optirs_core/regularizers/stochastic_depth.rs
// Stochastic Depth regularization
//
// Stochastic Depth is a regularization technique that randomly skips
// certain layers during training, which helps prevent overfitting and
// improves gradient flow in very deep networks.

use scirs2_core::ndarray::{Array, Dimension, ScalarOperand};
use scirs2_core::numeric::{Float, FromPrimitive};
use std::fmt::Debug;

use crate::error::Result;
use crate::regularizers::Regularizer;

/// Stochastic Depth regularization
///
/// Implements stochastic depth by randomly skipping layers during training.
/// During inference, all layers are used, scaled by their survival probability.
///
/// # Example
///
/// ```
/// use scirs2_core::ndarray::array;
/// use optirs_core::regularizers::StochasticDepth;
///
/// let stochastic_depth = StochasticDepth::new(0.2, 10, 50);
/// let features = array![[1.0, 2.0], [3.0, 4.0]];
///
/// // Apply stochastic depth for layer 5 during training
/// let output = stochastic_depth.apply_layer(5, &features, true);
/// ```
#[derive(Debug, Clone)]
pub struct StochasticDepth<A: Float> {
    /// Probability of dropping a layer
    drop_prob: A,
    /// Current layer index
    layer_idx: usize,
    /// Total number of layers
    num_layers: usize,
    /// Random state for the drop decision
    rng_state: u64,
}

impl<A: Float + Debug + ScalarOperand + FromPrimitive + Send + Sync> StochasticDepth<A> {
    /// Create a new stochastic depth regularizer
    ///
    /// # Arguments
    ///
    /// * `drop_prob` - The base probability of dropping a layer
    /// * `layer_idx` - The index of the current layer
    /// * `num_layers` - The total number of layers in the network
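    ///
    /// # Example
    ///
    /// A minimal construction sketch (the 0.2 drop probability and 50-layer
    /// depth are illustrative values, not defaults of this crate):
    ///
    /// ```
    /// use optirs_core::regularizers::StochasticDepth;
    ///
    /// let mut sd = StochasticDepth::<f64>::new(0.2, 0, 50);
    /// // The regularizer can be pointed at a different layer later on.
    /// sd.set_layer(10);
    /// ```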
    pub fn new(drop_prob: A, layer_idx: usize, num_layers: usize) -> Self {
        Self {
            drop_prob,
            layer_idx,
            num_layers,
            rng_state: 0,
        }
    }

    /// Set the layer index
    ///
    /// # Arguments
    ///
    /// * `layer_idx` - New layer index
    pub fn set_layer(&mut self, layer_idx: usize) {
        self.layer_idx = layer_idx;
    }

    /// Set the RNG state for deterministic behavior
    pub fn set_rng_state(&mut self, state: u64) {
        self.rng_state = state;
    }

    /// Get the survival probability for the current layer
    ///
    /// The survival probability decreases for deeper layers,
    /// following a linear decay schedule.
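    ///
    /// Concretely, the schedule implemented below is
    /// `p_survive = 1 - drop_prob * (layer_idx / num_layers)`; for example,
    /// with `drop_prob = 0.5` and `num_layers = 10`, layer 0 survives with
    /// probability 1.0 and layer 10 with probability 0.5.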
    fn survival_probability(&self) -> A {
        // Linear decay of survival probability with depth
        let layer_ratio =
            A::from_usize(self.layer_idx).unwrap() / A::from_usize(self.num_layers).unwrap();
        A::one() - (self.drop_prob * layer_ratio)
    }

    /// Decide whether to drop the current layer
    fn should_drop(&self) -> bool {
        // Use a simple deterministic hash of the RNG state and layer index
        // so that the drop decision is reproducible
        let hash = (self
            .rng_state
            .wrapping_mul(0x7fffffff)
            .wrapping_add(self.layer_idx as u64))
            % 10000;
        let random_val = A::from_f64(hash as f64 / 10000.0).unwrap();

        random_val > self.survival_probability()
    }

    /// Apply stochastic depth to a layer
    ///
    /// # Arguments
    ///
    /// * `layer_idx` - Index of the layer
    /// * `features` - Input features
    /// * `training` - Whether in training mode
    ///
    /// # Returns
    ///
    /// The output features, which are:
    /// - The input, unchanged, if the layer is dropped during training
    /// - The input, unchanged, if the layer is kept during training
    /// - The input scaled by the survival probability during inference
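    ///
    /// # Example
    ///
    /// A small sketch of the inference-time scaling (layer 5 of 10 with a
    /// 0.5 drop probability survives with probability 0.75):
    ///
    /// ```
    /// use scirs2_core::ndarray::array;
    /// use optirs_core::regularizers::StochasticDepth;
    ///
    /// let sd = StochasticDepth::new(0.5, 5, 10);
    /// let features = array![[1.0, 2.0], [3.0, 4.0]];
    ///
    /// // During inference every element is scaled by the survival probability.
    /// let scaled = sd.apply_layer(5, &features, false);
    /// assert_eq!(scaled[[0, 0]], 0.75);
    /// ```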
    pub fn apply_layer<D>(
        &self,
        layer_idx: usize,
        features: &Array<A, D>,
        training: bool,
    ) -> Array<A, D>
    where
        D: Dimension,
    {
        let survival_prob = self.survival_probability();

        if training {
            let mut sd = self.clone();
            sd.set_layer(layer_idx);

            if sd.should_drop() {
                // Layer is dropped: pass the input through (identity path)
                features.clone()
            } else {
                // Layer is kept: use the input unchanged
                features.clone()
            }
        } else {
            // During inference, scale by the survival probability
            features * survival_prob
        }
    }
}

// Implement the Regularizer trait (although the main functionality is in apply_layer)
impl<A: Float + Debug + ScalarOperand + FromPrimitive, D: Dimension + Send + Sync> Regularizer<A, D>
    for StochasticDepth<A>
{
    fn apply(&self, _params: &Array<A, D>, _gradients: &mut Array<A, D>) -> Result<A> {
        // This method is not the primary way to use stochastic depth;
        // prefer apply_layer for layer-wise application
        Ok(A::zero())
    }

    fn penalty(&self, _params: &Array<A, D>) -> Result<A> {
        // Stochastic depth doesn't add a direct penalty term
        Ok(A::zero())
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_stochastic_depth_creation() {
        let sd = StochasticDepth::<f64>::new(0.2, 5, 10);
        assert_eq!(sd.drop_prob, 0.2);
        assert_eq!(sd.layer_idx, 5);
        assert_eq!(sd.num_layers, 10);
    }

    #[test]
    fn test_survival_probability() {
        // For layer 0 of 10 with drop_prob 0.5, the survival prob is 1.0
        let sd1 = StochasticDepth::<f64>::new(0.5, 0, 10);
        assert_eq!(sd1.survival_probability(), 1.0);

        // For layer 10 of 10 with drop_prob 0.5, the survival prob is 0.5
        let sd2 = StochasticDepth::<f64>::new(0.5, 10, 10);
        assert_eq!(sd2.survival_probability(), 0.5);

        // For layer 5 of 10 with drop_prob 0.5, the survival prob is 0.75
        let sd3 = StochasticDepth::<f64>::new(0.5, 5, 10);
        assert_eq!(sd3.survival_probability(), 0.75);
    }

    #[test]
    fn test_should_drop() {
        // With a fixed RNG state, the drop decision is deterministic
        let mut sd = StochasticDepth::<f64>::new(0.5, 5, 10);

        // The same RNG state must always yield the same decision
        sd.set_rng_state(12345);
        let result1 = sd.should_drop();
        assert_eq!(result1, sd.should_drop());

        sd.set_rng_state(54321);
        let result2 = sd.should_drop();
        assert_eq!(result2, sd.should_drop());
    }

    #[test]
    fn test_apply_layer_training() {
        let sd = StochasticDepth::<f64>::new(0.5, 5, 10);
        let features = array![[1.0, 2.0], [3.0, 4.0]];

        // In training mode, the output is either the input or a modified version of it
        let output = sd.apply_layer(5, &features, true);

        // Output should be a 2D array with the same shape as the input
        assert_eq!(output.shape(), features.shape());
    }

    #[test]
    fn test_apply_layer_inference() {
        let sd = StochasticDepth::<f64>::new(0.5, 5, 10);
        let features = array![[1.0, 2.0], [3.0, 4.0]];

        // In inference mode, the output is always scaled by the survival probability
        let output = sd.apply_layer(5, &features, false);
        let survival_prob = sd.survival_probability();

        // Check that each element is scaled by the survival probability
        for (i, value) in output.indexed_iter() {
            assert_eq!(*value, features[i] * survival_prob);
        }
    }

    #[test]
    fn test_regularizer_trait() {
        let sd = StochasticDepth::<f64>::new(0.5, 5, 10);
        let params = array![[1.0, 2.0], [3.0, 4.0]];
        let mut gradients = array![[0.1, 0.2], [0.3, 0.4]];
        let original_gradients = gradients.clone();

        let penalty = sd.apply(&params, &mut gradients).unwrap();

        // Penalty should be zero
        assert_eq!(penalty, 0.0);

        // Gradients should be unchanged
        assert_eq!(gradients, original_gradients);
    }
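
    // A small additional check (a sketch, not in the original suite): with the
    // current implementation, training mode passes the input through unchanged
    // whether or not the layer is dropped.
    #[test]
    fn test_apply_layer_training_passthrough() {
        let sd = StochasticDepth::<f64>::new(0.5, 5, 10);
        let features = array![[1.0, 2.0], [3.0, 4.0]];

        let output = sd.apply_layer(5, &features, true);
        assert_eq!(output, features);
    }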
}