
linreg_core/regularized/lasso.rs

//! Lasso regression (L1-regularized linear regression).
//!
//! This module is a thin wrapper around the elastic net implementation with `alpha = 1.0`,
//! which reduces the elastic net penalty to a pure L1 penalty.

use crate::error::Result;
use crate::linalg::Matrix;
use crate::regularized::elastic_net::{elastic_net_fit, ElasticNetOptions};
use crate::regularized::preprocess::predict;

#[cfg(feature = "wasm")]
use serde::Serialize;

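/// Soft-thresholding operator, re-exported from the elastic net module for convenience.
///
/// For a value `z` and a non-negative threshold `gamma` this is the standard operator
/// `sign(z) * max(|z| - gamma, 0)`: inputs whose magnitude does not exceed the threshold
/// are mapped to exactly zero, which is what makes lasso fits sparse. The example below
/// mirrors the unit tests at the bottom of this file.
///
/// ```
/// use linreg_core::regularized::lasso::soft_threshold;
/// assert_eq!(soft_threshold(5.0, 2.0), 3.0);
/// assert_eq!(soft_threshold(-5.0, 2.0), -3.0);
/// assert_eq!(soft_threshold(1.0, 2.0), 0.0);
/// ```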
pub use crate::regularized::elastic_net::soft_threshold;

/// Options for lasso regression fitting.
///
/// Configuration options for lasso regression (L1-regularized linear regression).
///
/// # Fields
///
/// - `lambda` - Regularization strength (≥ 0, higher = more sparsity)
/// - `intercept` - Whether to include an intercept term
/// - `standardize` - Whether to standardize predictors to unit variance
/// - `max_iter` - Maximum coordinate descent iterations
/// - `tol` - Convergence tolerance on coefficient changes
/// - `penalty_factor` - Optional per-feature penalty multipliers
/// - `warm_start` - Optional initial coefficient values for warm starts
/// - `weights` - Optional observation weights
///
/// # Example
///
/// ```
/// # use linreg_core::regularized::lasso::LassoFitOptions;
/// let options = LassoFitOptions {
///     lambda: 0.1,
///     intercept: true,
///     standardize: true,
///     ..Default::default()
/// };
/// ```
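///
/// # Defaults
///
/// The `Default` implementation uses `lambda = 1.0`, fits an intercept, standardizes
/// predictors, and runs up to 100000 coordinate descent iterations with a convergence
/// tolerance of `1e-7` (matching the `ElasticNetOptions` default).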
#[derive(Clone, Debug)]
pub struct LassoFitOptions {
    pub lambda: f64,
    pub intercept: bool,
    pub standardize: bool,
    pub max_iter: usize,
    pub tol: f64,
    pub penalty_factor: Option<Vec<f64>>,
    pub warm_start: Option<Vec<f64>>,
    pub weights: Option<Vec<f64>>, // Observation weights
}

impl Default for LassoFitOptions {
    fn default() -> Self {
        LassoFitOptions {
            lambda: 1.0,
            intercept: true,
            standardize: true,
            max_iter: 100000,
            tol: 1e-7, // Match ElasticNetOptions default
            penalty_factor: None,
            warm_start: None,
            weights: None,
        }
    }
}

/// Result of a lasso regression fit.
///
/// Contains the fitted model coefficients, convergence information, and diagnostic metrics.
///
/// # Fields
///
/// - `lambda` - The regularization strength used
/// - `intercept` - Intercept coefficient (never penalized)
/// - `coefficients` - Slope coefficients (some may be exactly zero due to L1 penalty)
/// - `fitted_values` - Predicted values on training data
/// - `residuals` - Residuals (y - fitted_values)
/// - `n_nonzero` - Number of non-zero coefficients (excluding intercept)
/// - `iterations` - Number of coordinate descent iterations performed
/// - `converged` - Whether the algorithm converged
/// - `r_squared` - Coefficient of determination
/// - `adj_r_squared` - Adjusted R²
/// - `mse` - Mean squared error
/// - `rmse` - Root mean squared error
/// - `mae` - Mean absolute error
///
/// # Example
///
/// ```
/// # use linreg_core::regularized::lasso::{lasso_fit, LassoFitOptions};
/// # use linreg_core::linalg::Matrix;
/// # let y = vec![2.0, 4.0, 6.0, 8.0];
/// # let x = Matrix::new(4, 2, vec![1.0, 1.0, 1.0, 2.0, 1.0, 3.0, 1.0, 4.0]);
/// # let options = LassoFitOptions { lambda: 0.01, intercept: true, standardize: true, ..Default::default() };
/// let fit = lasso_fit(&x, &y, &options).unwrap();
///
/// // Check convergence and sparsity
/// println!("Converged: {}", fit.converged);
/// println!("Non-zero coefficients: {}", fit.n_nonzero);
/// println!("Iterations: {}", fit.iterations);
///
/// // Access model coefficients
/// println!("Intercept: {}", fit.intercept);
/// println!("Slopes: {:?}", fit.coefficients);
/// # Ok::<(), linreg_core::Error>(())
/// ```
#[derive(Clone, Debug)]
#[cfg_attr(feature = "wasm", derive(Serialize))]
pub struct LassoFit {
    pub lambda: f64,
    pub intercept: f64,
    pub coefficients: Vec<f64>,
    pub fitted_values: Vec<f64>,
    pub residuals: Vec<f64>,
    pub n_nonzero: usize,
    pub iterations: usize,
    pub converged: bool,
    pub r_squared: f64,
    pub adj_r_squared: f64,
    pub mse: f64,
    pub rmse: f64,
    pub mae: f64,
}

/// Fits lasso regression for a single lambda value.
///
/// Lasso regression adds an L1 penalty to the coefficients, which performs
/// automatic variable selection by shrinking some coefficients to exactly zero.
/// The intercept is never penalized.
///
/// # Arguments
///
/// * `x` - Design matrix (n rows × p columns including intercept)
/// * `y` - Response variable (n observations)
/// * `options` - Configuration options for lasso regression
///
/// # Returns
///
/// A `LassoFit` containing coefficients, convergence info, and metrics.
///
/// # Example
///
/// ```
/// # use linreg_core::regularized::lasso::{lasso_fit, LassoFitOptions};
/// # use linreg_core::linalg::Matrix;
/// let y = vec![2.0, 4.0, 6.0, 8.0];
/// let x = Matrix::new(4, 2, vec![1.0, 1.0, 1.0, 2.0, 1.0, 3.0, 1.0, 4.0]);
///
/// let options = LassoFitOptions {
///     lambda: 0.01,
///     intercept: true,
///     standardize: true,
///     ..Default::default()
/// };
///
/// let fit = lasso_fit(&x, &y, &options).unwrap();
/// assert!(fit.converged);
/// assert!(fit.n_nonzero <= 1); // At most 1 non-zero coefficient
/// # Ok::<(), linreg_core::Error>(())
/// ```
pub fn lasso_fit(x: &Matrix, y: &[f64], options: &LassoFitOptions) -> Result<LassoFit> {
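    // Lasso is the alpha = 1.0 special case of the elastic net, so fitting is delegated
    // to the elastic net coordinate-descent solver; the lasso options are forwarded
    // unchanged and no coefficient bounds are applied.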
    let en_options = ElasticNetOptions {
        lambda: options.lambda,
        alpha: 1.0, // Lasso
        intercept: options.intercept,
        standardize: options.standardize,
        max_iter: options.max_iter,
        tol: options.tol,
        penalty_factor: options.penalty_factor.clone(),
        warm_start: options.warm_start.clone(),
        weights: options.weights.clone(),
        coefficient_bounds: None,
    };

    let fit = elastic_net_fit(x, y, &en_options)?;

    Ok(LassoFit {
        lambda: fit.lambda,
        intercept: fit.intercept,
        coefficients: fit.coefficients,
        fitted_values: fit.fitted_values,
        residuals: fit.residuals,
        n_nonzero: fit.n_nonzero,
        iterations: fit.iterations,
        converged: fit.converged,
        r_squared: fit.r_squared,
        adj_r_squared: fit.adj_r_squared,
        mse: fit.mse,
        rmse: fit.rmse,
        mae: fit.mae,
    })
}

/// Makes predictions using a lasso regression fit.
///
/// Computes predictions for new observations using the fitted lasso regression model.
///
/// # Arguments
///
/// * `fit` - Fitted lasso regression model
/// * `x_new` - New design matrix (same number of columns as training data)
///
/// # Returns
///
/// Vector of predicted values.
///
/// # Example
///
/// ```
/// # use linreg_core::regularized::lasso::{lasso_fit, predict_lasso, LassoFitOptions};
/// # use linreg_core::linalg::Matrix;
/// // Training data
/// let y = vec![2.0, 4.0, 6.0, 8.0];
/// let x = Matrix::new(4, 2, vec![1.0, 1.0, 1.0, 2.0, 1.0, 3.0, 1.0, 4.0]);
///
/// let options = LassoFitOptions {
///     lambda: 0.01,
///     intercept: true,
///     standardize: true,
///     ..Default::default()
/// };
/// let fit = lasso_fit(&x, &y, &options).unwrap();
///
/// // Predict on new data
/// let x_new = Matrix::new(2, 2, vec![1.0, 5.0, 1.0, 6.0]);
/// let predictions = predict_lasso(&fit, &x_new);
///
/// assert_eq!(predictions.len(), 2);
/// // Predictions should be close to [10.0, 12.0] for the linear relationship y = 2*x
/// # Ok::<(), linreg_core::Error>(())
/// ```
pub fn predict_lasso(fit: &LassoFit, x_new: &Matrix) -> Vec<f64> {
    predict(x_new, fit.intercept, &fit.coefficients)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_soft_threshold() {
        assert_eq!(soft_threshold(5.0, 2.0), 3.0);
        assert_eq!(soft_threshold(-5.0, 2.0), -3.0);
        assert_eq!(soft_threshold(1.0, 2.0), 0.0);
    }

    #[test]
    fn test_lasso_fit_simple() {
        // Simple test: y = 2*x with perfect linear relationship
        let x_data = vec![1.0, 1.0, 1.0, 2.0, 1.0, 3.0, 1.0, 4.0];
        let x = Matrix::new(4, 2, x_data);
        let y = vec![2.0, 4.0, 6.0, 8.0];

        let options = LassoFitOptions {
            lambda: 0.01,
            intercept: true,
            standardize: true,
            ..Default::default()
        };

        let fit = lasso_fit(&x, &y, &options).unwrap();

        assert!(fit.converged);
        // Predictions should be close to actual values
        for i in 0..4 {
            assert!((fit.fitted_values[i] - y[i]).abs() < 0.5);
        }
    }
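
    // Additional sketch tests based only on behaviour documented above: the L1 penalty
    // should shrink every penalized slope to exactly zero once lambda is large enough,
    // while the intercept is never penalized. The lambda value and tolerances here are
    // illustrative assumptions, chosen deliberately loose.
    #[test]
    fn test_lasso_large_lambda_is_sparse() {
        let x = Matrix::new(4, 2, vec![1.0, 1.0, 1.0, 2.0, 1.0, 3.0, 1.0, 4.0]);
        let y = vec![2.0, 4.0, 6.0, 8.0];

        let options = LassoFitOptions {
            lambda: 1e6, // Far above any plausible lambda_max for this tiny problem
            ..Default::default()
        };

        let fit = lasso_fit(&x, &y, &options).unwrap();

        // All slopes should be soft-thresholded to exactly zero.
        assert_eq!(fit.n_nonzero, 0);
        // With no active slopes, every fitted value collapses to the (unpenalized) intercept.
        for &value in &fit.fitted_values {
            assert!((value - fit.intercept).abs() < 1e-6);
        }
    }

    // Prediction sketch mirroring the `predict_lasso` doc example: with a weak penalty the
    // fit should roughly recover y = 2*x, so predictions at x = 5 and x = 6 should land
    // near 10 and 12 (a loose tolerance allows for mild shrinkage).
    #[test]
    fn test_predict_lasso_extrapolates() {
        let x = Matrix::new(4, 2, vec![1.0, 1.0, 1.0, 2.0, 1.0, 3.0, 1.0, 4.0]);
        let y = vec![2.0, 4.0, 6.0, 8.0];

        let options = LassoFitOptions {
            lambda: 0.01,
            ..Default::default()
        };
        let fit = lasso_fit(&x, &y, &options).unwrap();

        let x_new = Matrix::new(2, 2, vec![1.0, 5.0, 1.0, 6.0]);
        let predictions = predict_lasso(&fit, &x_new);

        assert_eq!(predictions.len(), 2);
        assert!((predictions[0] - 10.0).abs() < 1.0);
        assert!((predictions[1] - 12.0).abs() < 1.0);
    }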
}