optirs_core/regularizers/l1.rs
1// L1 (Lasso) regularization
2
3use scirs2_core::ndarray::{Array, Dimension, ScalarOperand, Zip};
4use scirs2_core::numeric::Float;
5use std::fmt::Debug;
6
7use crate::error::Result;
8use crate::regularizers::Regularizer;
9
10/// L1 (Lasso) regularization
11///
12/// Adds a penalty equal to the sum of the absolute values of the parameters,
13/// which encourages sparsity (many parameters will be exactly 0).
14///
15/// Penalty: alpha * sum(abs(params))
16///
17/// # Examples
18///
19/// ```
20/// use scirs2_core::ndarray::Array1;
21/// use optirs_core::regularizers::{L1, Regularizer};
22///
23/// // Create an L1 regularizer with strength 0.01
24/// let regularizer = L1::new(0.01);
25///
26/// // Parameters and gradients
27/// let params = Array1::from_vec(vec![0.5, -0.3, 0.0, 0.2]);
28/// let mut gradients = Array1::from_vec(vec![0.1, 0.2, -0.3, 0.0]);
29///
30/// // Apply regularization
31/// let penalty = regularizer.apply(¶ms, &mut gradients).unwrap();
32///
33/// // Gradients will be modified to include the L1 penalty gradient
34/// // Penalty will be: 0.01 * (|0.5| + |-0.3| + |0.0| + |0.2|) = 0.01 * 1.0 = 0.01
35/// ```
36#[derive(Debug, Clone, Copy)]
37pub struct L1<A: Float + Debug> {
38 /// Regularization strength
39 alpha: A,
40}
41
42impl<A: Float + Debug + Send + Sync> L1<A> {
43 /// Create a new L1 regularizer
44 ///
45 /// # Arguments
46 ///
47 /// * `alpha` - Regularization strength
48 pub fn new(alpha: A) -> Self {
49 Self { alpha }
50 }
51
52 /// Get the regularization strength
53 pub fn alpha(&self) -> A {
54 self.alpha
55 }
56
57 /// Set the regularization strength
58 pub fn set_alpha(&mut self, alpha: A) -> &mut Self {
59 self.alpha = alpha;
60 self
61 }
62}
63
64impl<A, D> Regularizer<A, D> for L1<A>
65where
66 A: Float + ScalarOperand + Debug,
67 D: Dimension,
68{
69 fn apply(&self, params: &Array<A, D>, gradients: &mut Array<A, D>) -> Result<A> {
70 // L1 gradient: alpha * sign(params)
71 Zip::from(params).and(gradients).for_each(|¶m, grad| {
72 // Sign function: 1 for positive, -1 for negative, 0 for zero
73 let sign = if param > A::zero() {
74 A::one()
75 } else if param < A::zero() {
76 -A::one()
77 } else {
78 A::zero()
79 };
80
81 *grad = *grad + self.alpha * sign;
82 });
83
84 self.penalty(params)
85 }
86
87 fn penalty(&self, params: &Array<A, D>) -> Result<A> {
88 // L1 penalty: alpha * sum(abs(params))
89 let sum_abs = params.iter().fold(A::zero(), |acc, &x| acc + x.abs());
90 Ok(self.alpha * sum_abs)
91 }
92}