linreg_core/regularized/lasso.rs
1//! Lasso regression (L1-regularized linear regression).
2//!
3//! This module provides a wrapper around the elastic net implementation with `alpha=1.0`.
4
5use crate::error::Result;
6use crate::linalg::Matrix;
7use crate::regularized::elastic_net::{elastic_net_fit, ElasticNetOptions};
8use crate::regularized::preprocess::predict;
9use crate::serialization::types::ModelType;
10use crate::impl_serialization;
11use serde::{Deserialize, Serialize};
12
13pub use crate::regularized::elastic_net::soft_threshold;
14
/// Configuration for a single-lambda lasso (L1-regularized least squares) fit.
///
/// [`lasso_fit`] forwards every field below unchanged to the elastic-net solver
/// and pins the elastic-net mixing parameter to `alpha = 1.0`.
///
/// # Example
///
/// ```
/// # use linreg_core::regularized::lasso::LassoFitOptions;
/// let options = LassoFitOptions {
///     lambda: 0.1,
///     intercept: true,
///     standardize: true,
///     ..Default::default()
/// };
/// ```
#[derive(Clone, Debug)]
pub struct LassoFitOptions {
    /// Regularization strength (≥ 0); a larger value yields a sparser model.
    pub lambda: f64,
    /// Whether to include an intercept term.
    pub intercept: bool,
    /// Whether to standardize predictors to unit variance before fitting.
    pub standardize: bool,
    /// Maximum number of coordinate-descent iterations.
    pub max_iter: usize,
    /// Convergence tolerance on coefficient changes.
    pub tol: f64,
    /// Optional per-feature penalty multipliers.
    pub penalty_factor: Option<Vec<f64>>,
    /// Optional initial coefficient values for a warm start.
    pub warm_start: Option<Vec<f64>>,
    /// Optional observation weights.
    pub weights: Option<Vec<f64>>,
}
52
53impl Default for LassoFitOptions {
54 fn default() -> Self {
55 LassoFitOptions {
56 lambda: 1.0,
57 intercept: true,
58 standardize: true,
59 max_iter: 100000,
60 tol: 1e-7, // Match ElasticNetOptions default
61 penalty_factor: None,
62 warm_start: None,
63 weights: None,
64 }
65 }
66}
67
68/// Result of a lasso regression fit.
69///
70/// Contains the fitted model coefficients, convergence information, and diagnostic metrics.
71///
72/// # Fields
73///
74/// - `lambda` - The regularization strength used
75/// - `intercept` - Intercept coefficient (never penalized)
76/// - `coefficients` - Slope coefficients (some may be exactly zero due to L1 penalty)
77/// - `fitted_values` - Predicted values on training data
78/// - `residuals` - Residuals (y - fitted_values)
79/// - `n_nonzero` - Number of non-zero coefficients (excluding intercept)
80/// - `iterations` - Number of coordinate descent iterations performed
81/// - `converged` - Whether the algorithm converged
82/// - `r_squared` - Coefficient of determination
83/// - `adj_r_squared` - Adjusted R²
84/// - `mse` - Mean squared error
85/// - `rmse` - Root mean squared error
86/// - `mae` - Mean absolute error
87/// - `log_likelihood` - Log-likelihood of the model (for model comparison)
88/// - `aic` - Akaike Information Criterion (lower = better)
89/// - `bic` - Bayesian Information Criterion (lower = better)
90///
91/// # Example
92///
93/// ```
94/// # use linreg_core::regularized::lasso::{lasso_fit, LassoFitOptions};
95/// # use linreg_core::linalg::Matrix;
96/// # let y = vec![2.0, 4.0, 6.0, 8.0];
97/// # let x = Matrix::new(4, 2, vec![1.0, 1.0, 1.0, 2.0, 1.0, 3.0, 1.0, 4.0]);
98/// # let options = LassoFitOptions { lambda: 0.01, intercept: true, standardize: true, ..Default::default() };
99/// let fit = lasso_fit(&x, &y, &options).unwrap();
100///
101/// // Check convergence and sparsity
102/// println!("Converged: {}", fit.converged);
103/// println!("Non-zero coefficients: {}", fit.n_nonzero);
104/// println!("Iterations: {}", fit.iterations);
105///
106/// // Access model coefficients
107/// println!("Intercept: {}", fit.intercept);
108/// println!("Slopes: {:?}", fit.coefficients);
109/// println!("AIC: {}", fit.aic);
110/// # Ok::<(), linreg_core::Error>(())
111/// ```
112#[derive(Clone, Debug, Serialize, Deserialize)]
113pub struct LassoFit {
114 pub lambda: f64,
115 pub intercept: f64,
116 pub coefficients: Vec<f64>,
117 pub fitted_values: Vec<f64>,
118 pub residuals: Vec<f64>,
119 pub n_nonzero: usize,
120 pub iterations: usize,
121 pub converged: bool,
122 pub r_squared: f64,
123 pub adj_r_squared: f64,
124 pub mse: f64,
125 pub rmse: f64,
126 pub mae: f64,
127 pub log_likelihood: f64,
128 pub aic: f64,
129 pub bic: f64,
130}
131
132/// Fits lasso regression for a single lambda value.
133///
134/// Lasso regression adds an L1 penalty to the coefficients, which performs
135/// automatic variable selection by shrinking some coefficients to exactly zero.
136/// The intercept is never penalized.
137///
138/// # Arguments
139///
140/// * `x` - Design matrix (n rows × p columns including intercept)
141/// * `y` - Response variable (n observations)
142/// * `options` - Configuration options for lasso regression
143///
144/// # Returns
145///
146/// A `LassoFit` containing coefficients, convergence info, and metrics.
147///
148/// # Example
149///
150/// ```
151/// # use linreg_core::regularized::lasso::{lasso_fit, LassoFitOptions};
152/// # use linreg_core::linalg::Matrix;
153/// let y = vec![2.0, 4.0, 6.0, 8.0];
154/// let x = Matrix::new(4, 2, vec![1.0, 1.0, 1.0, 2.0, 1.0, 3.0, 1.0, 4.0]);
155///
156/// let options = LassoFitOptions {
157/// lambda: 0.01,
158/// intercept: true,
159/// standardize: true,
160/// ..Default::default()
161/// };
162///
163/// let fit = lasso_fit(&x, &y, &options).unwrap();
164/// assert!(fit.converged);
165/// assert!(fit.n_nonzero <= 1); // At most 1 non-zero coefficient
166/// # Ok::<(), linreg_core::Error>(())
167/// ```
168///
169/// # Errors
170///
171/// Returns `Error::InsufficientData` if `x.rows() <= x.cols()`.
172/// Returns `Error::SingularMatrix` if the design matrix is singular.
173/// Returns `Error::InvalidInput` if `lambda` is negative.
174///
175/// # Panics
176///
177/// Panics if `x.cols()` is 0 (no predictors including intercept).
178pub fn lasso_fit(x: &Matrix, y: &[f64], options: &LassoFitOptions) -> Result<LassoFit> {
179 let en_options = ElasticNetOptions {
180 lambda: options.lambda,
181 alpha: 1.0, // Lasso
182 intercept: options.intercept,
183 standardize: options.standardize,
184 max_iter: options.max_iter,
185 tol: options.tol,
186 penalty_factor: options.penalty_factor.clone(),
187 warm_start: options.warm_start.clone(),
188 weights: options.weights.clone(),
189 coefficient_bounds: None,
190 };
191
192 let fit = elastic_net_fit(x, y, &en_options)?;
193
194 Ok(LassoFit {
195 lambda: fit.lambda,
196 intercept: fit.intercept,
197 coefficients: fit.coefficients,
198 fitted_values: fit.fitted_values,
199 residuals: fit.residuals,
200 n_nonzero: fit.n_nonzero,
201 iterations: fit.iterations,
202 converged: fit.converged,
203 r_squared: fit.r_squared,
204 adj_r_squared: fit.adj_r_squared,
205 mse: fit.mse,
206 rmse: fit.rmse,
207 mae: fit.mae,
208 log_likelihood: fit.log_likelihood,
209 aic: fit.aic,
210 bic: fit.bic,
211 })
212}
213
214/// Makes predictions using a lasso regression fit.
215///
216/// Computes predictions for new observations using the fitted lasso regression model.
217///
218/// # Arguments
219///
220/// * `fit` - Fitted lasso regression model
221/// * `x_new` - New design matrix (same number of columns as training data)
222///
223/// # Returns
224///
225/// Vector of predicted values.
226///
227/// # Example
228///
229/// ```
230/// # use linreg_core::regularized::lasso::{lasso_fit, predict_lasso, LassoFitOptions};
231/// # use linreg_core::linalg::Matrix;
232/// // Training data
233/// let y = vec![2.0, 4.0, 6.0, 8.0];
234/// let x = Matrix::new(4, 2, vec![1.0, 1.0, 1.0, 2.0, 1.0, 3.0, 1.0, 4.0]);
235///
236/// let options = LassoFitOptions {
237/// lambda: 0.01,
238/// intercept: true,
239/// standardize: true,
240/// ..Default::default()
241/// };
242/// let fit = lasso_fit(&x, &y, &options).unwrap();
243///
244/// // Predict on new data
245/// let x_new = Matrix::new(2, 2, vec![1.0, 5.0, 1.0, 6.0]);
246/// let predictions = predict_lasso(&fit, &x_new);
247///
248/// assert_eq!(predictions.len(), 2);
249/// // Predictions should be close to [10.0, 12.0] for the linear relationship y = 2*x
250/// # Ok::<(), linreg_core::Error>(())
251/// ```
252pub fn predict_lasso(fit: &LassoFit, x_new: &Matrix) -> Vec<f64> {
253 predict(x_new, fit.intercept, &fit.coefficients)
254}
255
// ============================================================================
// Model Serialization Traits
// ============================================================================

// Generate the `ModelSave` and `ModelLoad` implementations for `LassoFit`
// using the crate-root `impl_serialization!` macro (imported at the top of this
// file). `ModelType::Lasso` and the "Lasso" label are handed to the macro.
// NOTE(review): they presumably tag saved files so a load can check it got a
// lasso model; confirm against the macro definition.
impl_serialization!(LassoFit, ModelType::Lasso, "Lasso");
262
#[cfg(test)]
mod tests {
    use super::*;

    /// Noise-free data y = 2x for x = 1..=4. The design matrix carries a leading
    /// column of ones for the intercept, in the layout used by the doc examples.
    fn linear_data() -> (Matrix, Vec<f64>) {
        let x = Matrix::new(4, 2, vec![1.0, 1.0, 1.0, 2.0, 1.0, 3.0, 1.0, 4.0]);
        let y = vec![2.0, 4.0, 6.0, 8.0];
        (x, y)
    }

    /// A lightly penalized configuration that should track the data closely.
    fn light_penalty() -> LassoFitOptions {
        LassoFitOptions {
            lambda: 0.01,
            intercept: true,
            standardize: true,
            ..Default::default()
        }
    }

    #[test]
    fn test_soft_threshold() {
        assert_eq!(soft_threshold(5.0, 2.0), 3.0);
        assert_eq!(soft_threshold(-5.0, 2.0), -3.0);
        assert_eq!(soft_threshold(1.0, 2.0), 0.0);
    }

    #[test]
    fn test_soft_threshold_boundaries() {
        // |z| == gamma is the edge of the dead zone and must map to zero.
        assert_eq!(soft_threshold(2.0, 2.0), 0.0);
        assert_eq!(soft_threshold(-2.0, 2.0), 0.0);
        // A zero threshold leaves the input unchanged.
        assert_eq!(soft_threshold(3.0, 0.0), 3.0);
        assert_eq!(soft_threshold(-3.0, 0.0), -3.0);
    }

    #[test]
    fn test_default_options() {
        let options = LassoFitOptions::default();
        assert_eq!(options.lambda, 1.0);
        assert!(options.intercept);
        assert!(options.standardize);
        assert_eq!(options.max_iter, 100_000);
        assert_eq!(options.tol, 1e-7);
        assert!(options.penalty_factor.is_none());
        assert!(options.warm_start.is_none());
        assert!(options.weights.is_none());
    }

    #[test]
    fn test_lasso_fit_simple() {
        let (x, y) = linear_data();
        let fit = lasso_fit(&x, &y, &light_penalty()).unwrap();

        assert!(fit.converged);
        // Guard the length explicitly: `zip` below would silently truncate.
        assert_eq!(fit.fitted_values.len(), y.len());
        for (fitted, actual) in fit.fitted_values.iter().zip(&y) {
            assert!(
                (fitted - actual).abs() < 0.5,
                "fitted {} too far from actual {}",
                fitted,
                actual
            );
        }
    }

    #[test]
    fn test_lasso_heavy_penalty_zeroes_slopes() {
        let (x, y) = linear_data();
        let options = LassoFitOptions {
            lambda: 1e6,
            ..Default::default()
        };
        let fit = lasso_fit(&x, &y, &options).unwrap();
        // An overwhelming L1 penalty must eliminate every slope.
        assert_eq!(fit.n_nonzero, 0);
    }

    #[test]
    fn test_predict_lasso_matches_fitted_values() {
        let (x, y) = linear_data();
        let fit = lasso_fit(&x, &y, &light_penalty()).unwrap();
        let predictions = predict_lasso(&fit, &x);

        assert_eq!(predictions.len(), y.len());
        for (predicted, fitted) in predictions.iter().zip(&fit.fitted_values) {
            assert!(
                (predicted - fitted).abs() < 1e-6,
                "prediction {} disagrees with fitted value {}",
                predicted,
                fitted
            );
        }
    }
}