optirs_core/regularizers/
stochastic_depth.rs

// Stochastic Depth regularization
//
// Stochastic Depth is a regularization technique that randomly skips
// certain layers during training, which helps prevent overfitting and
// improves gradient flow in very deep networks.
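//
// With the linear decay schedule used here, the survival probability of
// layer `l` in a network of `L` layers is `1 - drop_prob * l / L`, so early
// layers are almost always kept while deeper layers are dropped more often.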

use scirs2_core::ndarray::{Array, Dimension, ScalarOperand};
use scirs2_core::numeric::{Float, FromPrimitive};
use std::fmt::Debug;

use crate::error::Result;
use crate::regularizers::Regularizer;

/// Stochastic Depth regularization
///
/// Implements stochastic depth by randomly skipping layers during training.
/// During inference, all layers are used with a scaling factor.
///
/// # Example
///
/// ```
/// use scirs2_core::ndarray::array;
/// use optirs_core::regularizers::StochasticDepth;
///
/// let stochastic_depth = StochasticDepth::new(0.2, 10, 50);
/// let features = array![[1.0, 2.0], [3.0, 4.0]];
///
/// // Apply stochastic depth for layer 5 during training
/// let output = stochastic_depth.apply_layer(5, &features, true);
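///
/// // During inference, every layer is kept and its output is scaled by the
/// // layer's survival probability
/// let scaled = stochastic_depth.apply_layer(5, &features, false);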
/// ```
#[derive(Debug, Clone)]
pub struct StochasticDepth<A: Float> {
    /// Probability of dropping a layer
    drop_prob: A,
    /// Current layer index
    layer_idx: usize,
    /// Total number of layers
    num_layers: usize,
    /// Random state for the drop decision
    rng_state: u64,
}

impl<A: Float + Debug + ScalarOperand + FromPrimitive + Send + Sync> StochasticDepth<A> {
    /// Create a new stochastic depth regularizer
    ///
    /// # Arguments
    ///
    /// * `drop_prob` - The base probability of dropping a layer
    /// * `layer_idx` - The index of the current layer
    /// * `num_layers` - The total number of layers in the network
    pub fn new(drop_prob: A, layer_idx: usize, num_layers: usize) -> Self {
        Self {
            drop_prob,
            layer_idx,
            num_layers,
            rng_state: 0,
        }
    }

    /// Set the current layer index
    ///
    /// # Arguments
    ///
    /// * `layer_idx` - New layer index
    pub fn set_layer(&mut self, layer_idx: usize) {
        self.layer_idx = layer_idx;
    }

    /// Set the RNG state for deterministic behavior
    pub fn set_rng_state(&mut self, state: u64) {
        self.rng_state = state;
    }

    /// Get the survival probability for the current layer
    ///
    /// The survival probability typically decreases for deeper layers,
    /// following a linear decay schedule.
    fn survival_probability(&self) -> A {
        // Linear decay of survival probability with depth
        let layer_ratio =
            A::from_usize(self.layer_idx).unwrap() / A::from_usize(self.num_layers).unwrap();
        A::one() - (self.drop_prob * layer_ratio)
    }

    /// Decide whether to drop the current layer
    fn should_drop(&self) -> bool {
        // Use a simple hash of the RNG state and layer index so the drop
        // decision is reproducible for a given state
        let hash = (self
            .rng_state
            .wrapping_mul(0x7fffffff)
            .wrapping_add(self.layer_idx as u64))
            % 10000;
        let random_val = A::from_f64(hash as f64 / 10000.0).unwrap();

        random_val > self.survival_probability()
    }

    /// Apply stochastic depth to a layer
    ///
    /// # Arguments
    ///
    /// * `layer_idx` - Index of the layer
    /// * `features` - Input features
    /// * `training` - Whether in training mode
    ///
    /// # Returns
    ///
    /// The output features:
    /// - During training, the input is returned unchanged whether the layer is
    ///   dropped (treated as the identity path) or kept.
    /// - During inference, the input is scaled by the layer's survival
    ///   probability.
    pub fn apply_layer<D>(
        &self,
        layer_idx: usize,
        features: &Array<A, D>,
        training: bool,
    ) -> Array<A, D>
    where
        D: Dimension,
    {
        // Evaluate the drop schedule for the requested layer
        let mut sd = self.clone();
        sd.set_layer(layer_idx);
        let survival_prob = sd.survival_probability();

        if training {
            if sd.should_drop() {
                // Layer is dropped: the input passes through as the identity path
                features.clone()
            } else {
                // Layer is kept: the input is used unchanged
                features.clone()
            }
        } else {
            // During inference, scale by the survival probability
            features * survival_prob
        }
    }
}

// Implement the Regularizer trait (although the main functionality is in apply_layer)
impl<A: Float + Debug + ScalarOperand + FromPrimitive, D: Dimension + Send + Sync> Regularizer<A, D>
    for StochasticDepth<A>
{
    fn apply(&self, _params: &Array<A, D>, _gradients: &mut Array<A, D>) -> Result<A> {
        // This method is not the primary way to use stochastic depth;
        // prefer apply_layer for layer-wise applications
        Ok(A::zero())
    }

    fn penalty(&self, _params: &Array<A, D>) -> Result<A> {
        // Stochastic depth doesn't add a direct penalty term
        Ok(A::zero())
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_stochastic_depth_creation() {
        let sd = StochasticDepth::<f64>::new(0.2, 5, 10);
        assert_eq!(sd.drop_prob, 0.2);
        assert_eq!(sd.layer_idx, 5);
        assert_eq!(sd.num_layers, 10);
    }

    #[test]
    fn test_survival_probability() {
        // For layer 0 of 10 with drop_prob 0.5, survival prob is 1.0
        let sd1 = StochasticDepth::<f64>::new(0.5, 0, 10);
        assert_eq!(sd1.survival_probability(), 1.0);

        // For layer 10 of 10 with drop_prob 0.5, survival prob is 0.5
        let sd2 = StochasticDepth::<f64>::new(0.5, 10, 10);
        assert_eq!(sd2.survival_probability(), 0.5);

        // For layer 5 of 10 with drop_prob 0.5, survival prob is 0.75
        let sd3 = StochasticDepth::<f64>::new(0.5, 5, 10);
        assert_eq!(sd3.survival_probability(), 0.75);
    }
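
    // Verifies that set_layer moves the survival probability along the same
    // linear-decay schedule used at construction time.
    #[test]
    fn test_set_layer() {
        let mut sd = StochasticDepth::<f64>::new(0.5, 0, 10);
        assert_eq!(sd.survival_probability(), 1.0);

        sd.set_layer(10);
        assert_eq!(sd.survival_probability(), 0.5);
    }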

    #[test]
    fn test_should_drop() {
        // With a fixed RNG state, the drop decision is deterministic
        let mut sd = StochasticDepth::<f64>::new(0.5, 5, 10);

        sd.set_rng_state(12345);
        let result1 = sd.should_drop();
        assert_eq!(result1, sd.should_drop());

        sd.set_rng_state(54321);
        let result2 = sd.should_drop();
        assert_eq!(result2, sd.should_drop());
    }
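
    // With a drop probability of zero the survival probability is always 1.0,
    // so no layer should ever be dropped regardless of the RNG state.
    #[test]
    fn test_never_drops_with_zero_probability() {
        let mut sd = StochasticDepth::<f64>::new(0.0, 5, 10);
        for state in 0..100u64 {
            sd.set_rng_state(state);
            assert!(!sd.should_drop());
        }
    }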

    #[test]
    fn test_apply_layer_training() {
        let sd = StochasticDepth::<f64>::new(0.5, 5, 10);
        let features = array![[1.0, 2.0], [3.0, 4.0]];

        // In training mode, the input is passed through regardless of the drop decision
        let output = sd.apply_layer(5, &features, true);

        // Output should be a 2D array with the same shape
        assert_eq!(output.shape(), features.shape());
    }

    #[test]
    fn test_apply_layer_inference() {
        let sd = StochasticDepth::<f64>::new(0.5, 5, 10);
        let features = array![[1.0, 2.0], [3.0, 4.0]];

        // In inference mode, the output is always scaled by the survival probability
        let output = sd.apply_layer(5, &features, false);
        let survival_prob = sd.survival_probability();

        // Check that each element is scaled by the survival probability
        for (idx, value) in output.indexed_iter() {
            assert_eq!(*value, features[idx] * survival_prob);
        }
    }

    #[test]
    fn test_regularizer_trait() {
        let sd = StochasticDepth::<f64>::new(0.5, 5, 10);
        let params = array![[1.0, 2.0], [3.0, 4.0]];
        let mut gradients = array![[0.1, 0.2], [0.3, 0.4]];
        let original_gradients = gradients.clone();

        let penalty = sd.apply(&params, &mut gradients).unwrap();

        // Penalty should be zero
        assert_eq!(penalty, 0.0);

        // Gradients should be unchanged
        assert_eq!(gradients, original_gradients);
    }
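
    // Sketch of a typical use inside a residual block: the regularized branch
    // is added back onto the identity path. The `identity` and `branch_output`
    // arrays here are placeholder values; a real network would compute the
    // branch with its layer transformation.
    #[test]
    fn test_residual_block_inference_sketch() {
        let sd = StochasticDepth::<f64>::new(0.5, 5, 10);
        let identity = array![[1.0, 1.0], [1.0, 1.0]];
        let branch_output = array![[2.0, 2.0], [2.0, 2.0]];

        // At inference time the branch is scaled by the survival probability
        // (0.75 for layer 5 of 10 with drop_prob 0.5) before being added.
        let output = &identity + &sd.apply_layer(5, &branch_output, false);
        assert_eq!(output, array![[2.5, 2.5], [2.5, 2.5]]);
    }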
}