optirs_core/regularizers/elastic_net.rs

// ElasticNet regularization (L1 + L2)

use scirs2_core::ndarray::{Array, Dimension, ScalarOperand};
use scirs2_core::numeric::Float;
use std::fmt::Debug;

use crate::error::Result;
use crate::regularizers::{Regularizer, L1, L2};

/// ElasticNet regularization
///
/// Combines L1 and L2 regularization to get the benefits of both:
/// - L1 encourages sparsity (many parameters become exactly 0)
/// - L2 discourages large weights, giving more stable solutions
///
/// Penalty: alpha * (l1_ratio * L1_penalty + (1 - l1_ratio) * L2_penalty)
///
/// # Examples
///
/// ```
/// use scirs2_core::ndarray::Array1;
/// use optirs_core::regularizers::{ElasticNet, Regularizer};
///
/// // Create an ElasticNet regularizer with strength 0.01 and l1_ratio 0.5
/// // (equal weight on L1 and L2)
/// let regularizer = ElasticNet::new(0.01, 0.5);
///
/// // Parameters and gradients
/// let params = Array1::from_vec(vec![0.5, -0.3, 0.0, 0.2]);
/// let mut gradients = Array1::from_vec(vec![0.1, 0.2, -0.3, 0.0]);
///
/// // Apply regularization
/// let penalty = regularizer.apply(&params, &mut gradients).unwrap();
///
/// // Gradients are modified in place to include both the L1 and L2 penalty gradients
/// ```
#[derive(Debug, Clone, Copy)]
pub struct ElasticNet<A: Float + Debug> {
    /// Total regularization strength
    alpha: A,
    /// Mixing parameter, with 0 <= l1_ratio <= 1
    /// l1_ratio = 1 means only L1 penalty, l1_ratio = 0 means only L2 penalty
    l1_ratio: A,
    /// L1 regularizer
    l1: L1<A>,
    /// L2 regularizer
    l2: L2<A>,
}

impl<A: Float + Debug + Send + Sync> ElasticNet<A> {
    /// Create a new ElasticNet regularizer
    ///
    /// # Arguments
    ///
    /// * `alpha` - Total regularization strength
    /// * `l1_ratio` - Mixing parameter (0 <= l1_ratio <= 1)
    ///   - l1_ratio = 1: only L1 penalty
    ///   - l1_ratio = 0: only L2 penalty
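    ///
    /// # Examples
    ///
    /// A short sketch showing that out-of-range ratios are clamped into `[0, 1]`:
    ///
    /// ```
    /// use optirs_core::regularizers::ElasticNet;
    ///
    /// // An l1_ratio above 1 is clamped down to 1 (pure L1)
    /// let regularizer = ElasticNet::new(0.01, 1.5);
    /// assert_eq!(regularizer.l1_ratio(), 1.0);
    /// assert_eq!(regularizer.alpha(), 0.01);
    /// ```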
    pub fn new(alpha: A, l1_ratio: A) -> Self {
        // Ensure l1_ratio is between 0 and 1
        let l1_ratio = l1_ratio.max(A::zero()).min(A::one());

        // Compute individual strengths for L1 and L2
        let l1_alpha = alpha * l1_ratio;
        let l2_alpha = alpha * (A::one() - l1_ratio);

        Self {
            alpha,
            l1_ratio,
            l1: L1::new(l1_alpha),
            l2: L2::new(l2_alpha),
        }
    }

    /// Get the total regularization strength
    pub fn alpha(&self) -> A {
        self.alpha
    }

    /// Get the L1 ratio
    pub fn l1_ratio(&self) -> A {
        self.l1_ratio
    }

    /// Set the regularization parameters
    ///
    /// # Arguments
    ///
    /// * `alpha` - Total regularization strength
    /// * `l1_ratio` - Mixing parameter (0 <= l1_ratio <= 1)
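    ///
    /// # Examples
    ///
    /// A minimal sketch of updating an existing regularizer in place:
    ///
    /// ```
    /// use optirs_core::regularizers::ElasticNet;
    ///
    /// let mut regularizer = ElasticNet::new(0.01, 0.5);
    /// regularizer.set_params(0.1, 0.25);
    /// assert_eq!(regularizer.alpha(), 0.1);
    /// assert_eq!(regularizer.l1_ratio(), 0.25);
    /// ```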
    pub fn set_params(&mut self, alpha: A, l1_ratio: A) -> &mut Self {
        // Ensure l1_ratio is between 0 and 1
        let l1_ratio = l1_ratio.max(A::zero()).min(A::one());

        self.alpha = alpha;
        self.l1_ratio = l1_ratio;

        // Update the individual strengths
        let l1_alpha = alpha * l1_ratio;
        let l2_alpha = alpha * (A::one() - l1_ratio);

        self.l1.set_alpha(l1_alpha);
        self.l2.set_alpha(l2_alpha);

        self
    }
}

impl<A, D> Regularizer<A, D> for ElasticNet<A>
where
    A: Float + ScalarOperand + Debug,
    D: Dimension,
{
    fn apply(&self, params: &Array<A, D>, gradients: &mut Array<A, D>) -> Result<A> {
        // The inner L1 and L2 regularizers were constructed with strengths
        // alpha * l1_ratio and alpha * (1 - l1_ratio), so applying them in
        // sequence adds each pre-weighted contribution to the gradients
        // exactly once; re-mixing the two results by l1_ratio afterwards
        // would apply the ratio twice and disagree with the returned penalty.
        let l1_penalty = self.l1.apply(params, gradients)?;
        let l2_penalty = self.l2.apply(params, gradients)?;

        // Return combined penalty
        Ok(l1_penalty + l2_penalty)
    }

    fn penalty(&self, params: &Array<A, D>) -> Result<A> {
        // The individual penalties already carry their alpha * l1_ratio and
        // alpha * (1 - l1_ratio) weights, so their sum is the elastic net penalty
        let l1_penalty = self.l1.penalty(params)?;
        let l2_penalty = self.l2.penalty(params)?;

        // Return combined penalty
        Ok(l1_penalty + l2_penalty)
    }
}
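
// A few sanity checks for the construction logic and the combined penalty.
// These are a minimal sketch: they rely only on behaviour defined in this file
// (l1_ratio clamping and the accessors) plus the assumption that the L1 and L2
// penalties vanish for all-zero parameters.
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array1;

    #[test]
    fn l1_ratio_is_clamped_to_unit_interval() {
        let reg = ElasticNet::new(0.1_f64, 2.0);
        assert_eq!(reg.l1_ratio(), 1.0);

        let reg = ElasticNet::new(0.1_f64, -0.5);
        assert_eq!(reg.l1_ratio(), 0.0);
    }

    #[test]
    fn penalty_is_zero_for_zero_parameters() {
        let reg = ElasticNet::new(0.01_f64, 0.5);
        let params = Array1::from_vec(vec![0.0_f64, 0.0, 0.0]);
        let penalty = reg.penalty(&params).unwrap();
        assert_eq!(penalty, 0.0);
    }
}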