optirs_core/regularizers/elastic_net.rs
// ElasticNet regularization (L1 + L2)

use scirs2_core::ndarray::{Array, Dimension, ScalarOperand};
use scirs2_core::numeric::Float;
use std::fmt::Debug;

use crate::error::Result;
use crate::regularizers::{Regularizer, L1, L2};

/// ElasticNet regularization
///
/// Combines L1 and L2 regularization to get the benefits of both:
/// - L1 encourages sparsity (many parameters become exactly 0)
/// - L2 discourages large weights, yielding more stable solutions
///
/// Penalty: alpha * (l1_ratio * L1_term + (1 - l1_ratio) * L2_term)
///
/// # Examples
///
/// ```
/// use scirs2_core::ndarray::Array1;
/// use optirs_core::regularizers::{ElasticNet, Regularizer};
///
/// // Create an ElasticNet regularizer with strength 0.01 and l1_ratio 0.5
/// // (equal weight to L1 and L2)
/// let regularizer = ElasticNet::new(0.01, 0.5);
///
/// // Parameters and gradients
/// let params = Array1::from_vec(vec![0.5, -0.3, 0.0, 0.2]);
/// let mut gradients = Array1::from_vec(vec![0.1, 0.2, -0.3, 0.0]);
///
/// // Apply regularization
/// let penalty = regularizer.apply(&params, &mut gradients).unwrap();
///
/// // Gradients are modified in place to include both penalty gradients
/// ```
#[derive(Debug, Clone, Copy)]
pub struct ElasticNet<A: Float + Debug> {
    /// Total regularization strength
    alpha: A,
    /// Mixing parameter, with 0 <= l1_ratio <= 1
    /// l1_ratio = 1 means only L1 penalty, l1_ratio = 0 means only L2 penalty
    l1_ratio: A,
    /// L1 regularizer
    l1: L1<A>,
    /// L2 regularizer
    l2: L2<A>,
}

impl<A: Float + Debug + Send + Sync> ElasticNet<A> {
    /// Create a new ElasticNet regularizer
    ///
    /// # Arguments
    ///
    /// * `alpha` - Total regularization strength
    /// * `l1_ratio` - Mixing parameter (0 <= l1_ratio <= 1)
    ///   - l1_ratio = 1: only L1 penalty
    ///   - l1_ratio = 0: only L2 penalty
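    ///
    /// # Examples
    ///
    /// ```
    /// use optirs_core::regularizers::ElasticNet;
    ///
    /// // Illustrative values; out-of-range ratios are clamped into [0, 1]
    /// let regularizer = ElasticNet::new(0.01_f64, 1.5);
    /// assert_eq!(regularizer.l1_ratio(), 1.0);
    /// ```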
    pub fn new(alpha: A, l1_ratio: A) -> Self {
        // Clamp l1_ratio into [0, 1]
        let l1_ratio = l1_ratio.max(A::zero()).min(A::one());

        // Split the total strength between the L1 and L2 terms
        let l1_alpha = alpha * l1_ratio;
        let l2_alpha = alpha * (A::one() - l1_ratio);

        Self {
            alpha,
            l1_ratio,
            l1: L1::new(l1_alpha),
            l2: L2::new(l2_alpha),
        }
    }

    /// Get the total regularization strength
    pub fn alpha(&self) -> A {
        self.alpha
    }

    /// Get the L1 ratio
    pub fn l1_ratio(&self) -> A {
        self.l1_ratio
    }

    /// Set the regularization parameters
    ///
    /// # Arguments
    ///
    /// * `alpha` - Total regularization strength
    /// * `l1_ratio` - Mixing parameter (0 <= l1_ratio <= 1)
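    ///
    /// # Examples
    ///
    /// ```
    /// use optirs_core::regularizers::ElasticNet;
    ///
    /// // Illustrative values: strengthen the penalty and lean more on L1
    /// let mut regularizer = ElasticNet::new(0.01, 0.5);
    /// regularizer.set_params(0.1, 0.8);
    /// assert_eq!(regularizer.l1_ratio(), 0.8);
    /// ```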
    pub fn set_params(&mut self, alpha: A, l1_ratio: A) -> &mut Self {
        // Clamp l1_ratio into [0, 1]
        let l1_ratio = l1_ratio.max(A::zero()).min(A::one());

        self.alpha = alpha;
        self.l1_ratio = l1_ratio;

        // Update the individual strengths
        let l1_alpha = alpha * l1_ratio;
        let l2_alpha = alpha * (A::one() - l1_ratio);

        self.l1.set_alpha(l1_alpha);
        self.l2.set_alpha(l2_alpha);

        self
    }
}

impl<A, D> Regularizer<A, D> for ElasticNet<A>
where
    A: Float + ScalarOperand + Debug,
    D: Dimension,
{
    fn apply(&self, params: &Array<A, D>, gradients: &mut Array<A, D>) -> Result<A> {
        // `l1` and `l2` already carry the mixed strengths (alpha * l1_ratio
        // and alpha * (1 - l1_ratio)), so applying them in sequence adds each
        // penalty gradient to `gradients` exactly once; weighting the two
        // gradient arrays by l1_ratio again here would apply the mixing
        // ratio twice.
        let l1_penalty = self.l1.apply(params, gradients)?;
        let l2_penalty = self.l2.apply(params, gradients)?;

        // Return combined penalty
        Ok(l1_penalty + l2_penalty)
    }

    fn penalty(&self, params: &Array<A, D>) -> Result<A> {
        // Compute L1 and L2 penalties (already scaled by the mixed strengths)
        let l1_penalty = self.l1.penalty(params)?;
        let l2_penalty = self.l2.penalty(params)?;

        // Return combined penalty
        Ok(l1_penalty + l2_penalty)
    }
}
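
// Minimal sanity checks (a sketch): these rely only on behavior visible in
// this file, namely the clamping of l1_ratio in `new` and the fact that
// `penalty` sums the penalties of the two component regularizers built with
// the mixed strengths.
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array1;

    #[test]
    fn penalty_matches_component_sum() {
        let alpha = 0.01_f64;
        let l1_ratio = 0.5_f64;
        let elastic = ElasticNet::new(alpha, l1_ratio);

        // Build the components exactly as ElasticNet::new does
        let l1 = L1::new(alpha * l1_ratio);
        let l2 = L2::new(alpha * (1.0 - l1_ratio));

        let params = Array1::from_vec(vec![0.5, -0.3, 0.0, 0.2]);
        let expected = l1.penalty(&params).unwrap() + l2.penalty(&params).unwrap();
        let actual = elastic.penalty(&params).unwrap();
        assert!((actual - expected).abs() < 1e-12);
    }

    #[test]
    fn l1_ratio_is_clamped() {
        assert_eq!(ElasticNet::new(0.01_f64, 1.5).l1_ratio(), 1.0);
        assert_eq!(ElasticNet::new(0.01_f64, -0.5).l1_ratio(), 0.0);
    }
}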