Skip to main content

optirs_core/regularizers/mod.rs

// Regularization techniques for machine learning
//
// This module provides regularization techniques commonly used in machine
// learning to reduce overfitting, such as L1 (Lasso), L2 (Ridge),
// ElasticNet, and Dropout.
6
7use scirs2_core::ndarray::{Array, Dimension, ScalarOperand};
8use scirs2_core::numeric::Float;
9use std::fmt::Debug;
10
11use crate::error::Result;
12
13/// Trait for regularizers that can be applied to parameters and gradients
14pub trait Regularizer<A, D>
15where
16    A: Float + ScalarOperand + Debug,
17    D: Dimension,
18{
19    /// Apply regularization to parameters and gradients
20    ///
21    /// # Arguments
22    ///
23    /// * `params` - The parameters to regularize
24    /// * `gradients` - The gradients to modify
25    ///
26    /// # Returns
27    ///
28    /// The regularization penalty value
29    fn apply(&self, params: &Array<A, D>, gradients: &mut Array<A, D>) -> Result<A>;
30
31    /// Compute the regularization penalty value
32    ///
33    /// # Arguments
34    ///
35    /// * `params` - The parameters to compute the penalty for
36    ///
37    /// # Returns
38    ///
39    /// The regularization penalty value
40    fn penalty(&self, params: &Array<A, D>) -> Result<A>;
41}
42
43mod activity;
44mod dropconnect;
45mod dropout;
46mod elastic_net;
47mod entropy;
48mod group_lasso;
49mod l1;
50mod l2;
51mod label_smoothing;
52mod manifold;
53mod mixup;
54mod orthogonal;
55mod shakedrop;
56mod spatial_dropout;
57mod spectral_norm;
58mod stochastic_depth;
59mod weight_standardization;
60
61// Re-export regularizers
62pub use activity::{ActivityNorm, ActivityRegularization};
63pub use dropconnect::DropConnect;
64pub use dropout::Dropout;
65pub use elastic_net::ElasticNet;
66pub use entropy::{EntropyRegularization, EntropyRegularizerType};
67pub use group_lasso::{GroupLasso, SparsityPattern, StructuredSparsity};
68pub use l1::L1;
69pub use l2::L2;
70pub use label_smoothing::LabelSmoothing;
71pub use manifold::ManifoldRegularization;
72pub use mixup::{CutMix, MixUp};
73pub use orthogonal::OrthogonalRegularization;
74pub use shakedrop::ShakeDrop;
75pub use spatial_dropout::{FeatureDropout, SpatialDropout};
76pub use spectral_norm::SpectralNorm;
77pub use stochastic_depth::StochasticDepth;
78pub use weight_standardization::WeightStandardization;