optirs-core 0.3.1

OptiRS core optimization algorithms and utilities
Documentation
// Stochastic Depth regularization
//
// Stochastic Depth is a regularization technique that randomly skips
// certain layers during training, which helps prevent overfitting and
// improves gradient flow in very deep networks.

use scirs2_core::ndarray::{Array, Dimension, ScalarOperand};
use scirs2_core::numeric::{Float, FromPrimitive};
use std::fmt::Debug;

use crate::error::Result;
use crate::regularizers::Regularizer;

/// Stochastic Depth regularization
///
/// Implements stochastic depth by randomly skipping layers during training.
/// During inference, all layers are used with a scaling factor.
///
/// # Example
///
/// ```
/// use scirs2_core::ndarray::array;
/// use optirs_core::regularizers::StochasticDepth;
///
/// let stochastic_depth = StochasticDepth::new(0.2, 10, 50);
/// let features = array![[1.0, 2.0], [3.0, 4.0]];
///
/// // Apply stochastic depth for layer 5 during training
/// let output = stochastic_depth.apply_layer(5, &features, true);
/// ```
#[derive(Debug, Clone)]
pub struct StochasticDepth<A: Float> {
    /// Probability of dropping a layer
    drop_prob: A,
    /// Current layer index
    layer_idx: usize,
    /// Total number of layers
    num_layers: usize,
    /// Random state for drop decision
    rng_state: u64,
}

impl<A: Float + Debug + ScalarOperand + FromPrimitive + Send + Sync> StochasticDepth<A> {
    /// Create a new stochastic depth regularization
    ///
    /// # Arguments
    ///
    /// * `drop_prob` - The base probability of dropping a layer
    /// * `layer_idx` - The index of the current layer
    /// * `num_layers` - The total number of layers in the network
    pub fn new(drop_prob: A, layer_idx: usize, numlayers: usize) -> Self {
        Self {
            drop_prob,
            layer_idx,
            num_layers: numlayers,
            rng_state: 0,
        }
    }

    /// Set layer index
    ///
    /// # Arguments
    ///
    /// * `layer_idx` - New layer index
    pub fn set_layer(&mut self, layeridx: usize) {
        self.layer_idx = layeridx;
    }

    /// Set the RNG state for deterministic behavior
    pub fn set_rng_state(&mut self, state: u64) {
        self.rng_state = state;
    }

    /// Get the survival probability for the current layer
    ///
    /// The survival probability typically decreases for deeper layers,
    /// following a linear decay schedule.
    fn survival_probability(&self) -> A {
        // Linear decay of survival probability with depth
        let layer_ratio = A::from_usize(self.layer_idx).expect("unwrap failed")
            / A::from_usize(self.num_layers).expect("unwrap failed");
        A::one() - (self.drop_prob * layer_ratio)
    }

    /// Decide whether to drop the current layer
    fn should_drop(&self) -> bool {
        // Use simple random hash function for reproducibility
        let hash = (self
            .rng_state
            .wrapping_mul(0x7fffffff)
            .wrapping_add(self.layer_idx as u64))
            % 10000;
        let random_val = A::from_f64(hash as f64 / 10000.0).expect("unwrap failed");

        random_val > self.survival_probability()
    }

    /// Apply stochastic depth to a layer
    ///
    /// # Arguments
    ///
    /// * `layer_idx` - Index of the layer
    /// * `features` - Input features
    /// * `training` - Whether in training mode
    ///
    /// # Returns
    ///
    /// The output features, which are either:
    /// - The identity (input) if the layer is dropped during training
    /// - The input scaled by the survival probability during inference
    /// - The input if not dropped during training
    pub fn apply_layer<D>(
        &self,
        layer_idx: usize,
        features: &Array<A, D>,
        training: bool,
    ) -> Array<A, D>
    where
        D: Dimension,
    {
        let survival_prob = self.survival_probability();

        if training {
            let mut sd = self.clone();
            sd.set_layer(layer_idx);

            if sd.should_drop() {
                // Skip this layer
                features.clone()
            } else {
                // Use this layer normally
                features.clone()
            }
        } else {
            // During inference, scale by survival probability
            features * survival_prob
        }
    }
}

// Implement Regularizer trait (although the main functionality is in apply_layer)
impl<A: Float + Debug + ScalarOperand + FromPrimitive, D: Dimension + Send + Sync> Regularizer<A, D>
    for StochasticDepth<A>
{
    fn apply(&self, _params: &Array<A, D>, gradients: &mut Array<A, D>) -> Result<A> {
        // This method is not the primary way to use stochastic depth,
        // prefer apply_layer for layer-wise applications
        Ok(A::zero())
    }

    fn penalty(&self, params: &Array<A, D>) -> Result<A> {
        // Stochastic depth doesn't add a direct penalty term
        Ok(A::zero())
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_stochastic_depth_creation() {
        let sd = StochasticDepth::<f64>::new(0.2, 5, 10);
        assert_eq!(sd.drop_prob, 0.2);
        assert_eq!(sd.layer_idx, 5);
        assert_eq!(sd.num_layers, 10);
    }

    #[test]
    fn test_survival_probability() {
        // For layer 0 of 10 with drop_prob 0.5, survival prob is 1.0
        let sd1 = StochasticDepth::<f64>::new(0.5, 0, 10);
        assert_eq!(sd1.survival_probability(), 1.0);

        // For layer 10 of 10 with drop_prob 0.5, survival prob is 0.5
        let sd2 = StochasticDepth::<f64>::new(0.5, 10, 10);
        assert_eq!(sd2.survival_probability(), 0.5);

        // For layer 5 of 10 with drop_prob 0.5, survival prob is 0.75
        let sd3 = StochasticDepth::<f64>::new(0.5, 5, 10);
        assert_eq!(sd3.survival_probability(), 0.75);
    }

    #[test]
    fn test_should_drop() {
        // With fixed RNG states, we can test deterministic behavior
        let mut sd = StochasticDepth::<f64>::new(0.5, 5, 10);

        // Try different RNG states
        sd.set_rng_state(12345);
        let _result1 = sd.should_drop();

        sd.set_rng_state(54321);
        let _result2 = sd.should_drop();

        // The results should be deterministic for given RNG states
        // result1 is already a boolean, no need to assert
        // result2 is already a boolean, no need to assert
    }

    #[test]
    fn test_apply_layer_training() {
        let sd = StochasticDepth::<f64>::new(0.5, 5, 10);
        let features = array![[1.0, 2.0], [3.0, 4.0]];

        // In training mode, the output is either features or modified features
        let output = sd.apply_layer(5, &features, true);

        // Output should be 2D array with same shape
        assert_eq!(output.shape(), features.shape());
    }

    #[test]
    fn test_apply_layer_inference() {
        let sd = StochasticDepth::<f64>::new(0.5, 5, 10);
        let features = array![[1.0, 2.0], [3.0, 4.0]];

        // In inference mode, output is always scaled by survival probability
        let output = sd.apply_layer(5, &features, false);
        let survival_prob = sd.survival_probability();

        // Check that each element is scaled by survival probability
        for (i, j) in output.indexed_iter() {
            assert_eq!(*j, features[i] * survival_prob);
        }
    }

    #[test]
    fn test_regularizer_trait() {
        let sd = StochasticDepth::<f64>::new(0.5, 5, 10);
        let params = array![[1.0, 2.0], [3.0, 4.0]];
        let mut gradients = array![[0.1, 0.2], [0.3, 0.4]];
        let original_gradients = gradients.clone();

        let penalty = sd.apply(&params, &mut gradients).expect("unwrap failed");

        // Penalty should be zero
        assert_eq!(penalty, 0.0);

        // Gradients should be unchanged
        assert_eq!(gradients, original_gradients);
    }
}