linreg_core/regularized/lasso.rs
//! Lasso regression (L1-regularized linear regression).
//!
//! This module provides a lasso regression implementation using cyclical
//! coordinate descent with soft-thresholding, matching glmnet's approach.
//!
//! # Lasso Regression Objective
//!
//! Lasso regression solves:
//!
//! ```text
//! minimize over (β₀, β):
//!
//! (1/(2n)) * Σᵢ (yᵢ - β₀ - xᵢᵀβ)² + λ * ||β||₁
//! ```
//!
//! The intercept `β₀` is **not penalized**.
//!
//! # Solution Method
//!
//! Uses cyclical coordinate descent with soft-thresholding:
//!
//! 1. For standardized X, each coordinate update has a closed form (shown below)
//! 2. Soft-thresholding operator: S(z, γ) = sign(z) * max(|z| - γ, 0)
//! 3. Warm starts along the lambda path for efficiency
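//!
//! For standardized predictors, the closed-form update for coordinate j is:
//!
//! ```text
//! βⱼ ← S( xⱼᵀ r₍ⱼ₎ / n , λ )
//! ```
//!
//! where r₍ⱼ₎ is the partial residual with feature j's contribution removed,
//! and λ is scaled by any per-feature penalty factor.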

use crate::error::{Error, Result};
use crate::linalg::Matrix;
use crate::regularized::preprocess::{
    predict, standardize_xy, unstandardize_coefficients, StandardizeOptions,
};

#[cfg(feature = "wasm")]
use serde::Serialize;

/// Soft-thresholding operator: S(z, γ) = sign(z) * max(|z| - γ, 0).
///
/// # Arguments
///
/// * `z` - Input value
/// * `gamma` - Threshold value (must be >= 0)
///
/// # Returns
///
/// The soft-thresholded value.
///
/// # Panics
///
/// Panics if `gamma` is negative.
///
/// # Formula
///
/// ```text
/// S(z, γ) = {
///     z - γ    if z > 0 and |z| > γ
///     z + γ    if z < 0 and |z| > γ
///     0        if |z| <= γ
/// }
/// ```
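///
/// # Example
///
/// A few spot checks mirroring the unit tests at the bottom of this module:
///
/// ```rust
/// use linreg_core::regularized::lasso::soft_threshold;
///
/// // Inside the band |z| <= γ the result collapses to zero;
/// // outside it, z is shrunk toward zero by γ.
/// assert_eq!(soft_threshold(5.0, 2.0), 3.0);
/// assert_eq!(soft_threshold(-5.0, 2.0), -3.0);
/// assert_eq!(soft_threshold(1.0, 2.0), 0.0);
/// ```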
pub fn soft_threshold(z: f64, gamma: f64) -> f64 {
    if gamma < 0.0 {
        panic!("Soft threshold gamma must be non-negative");
    }
    if z > gamma {
        z - gamma
    } else if z < -gamma {
        z + gamma
    } else {
        0.0
    }
}

/// Options for lasso regression fitting.
///
/// # Fields
///
/// * `lambda` - Regularization strength (single value, default: 1.0)
/// * `intercept` - Whether to include an intercept term (default: true)
/// * `standardize` - Whether to standardize predictors (default: true)
/// * `max_iter` - Maximum iterations per lambda (default: 1000)
/// * `tol` - Convergence tolerance (default: 1e-7)
/// * `penalty_factor` - Optional per-feature penalty factors
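///
/// # Example
///
/// Overriding a single field while keeping the defaults listed above:
///
/// ```rust
/// use linreg_core::regularized::lasso::LassoFitOptions;
///
/// let opts = LassoFitOptions { lambda: 0.5, ..Default::default() };
/// assert!(opts.intercept);
/// assert_eq!(opts.max_iter, 1000);
/// ```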
#[derive(Clone, Debug)]
pub struct LassoFitOptions {
    /// Regularization strength (must be >= 0)
    pub lambda: f64,
    /// Whether to include an intercept
    pub intercept: bool,
    /// Whether to standardize predictors
    pub standardize: bool,
    /// Maximum coordinate descent iterations
    pub max_iter: usize,
    /// Convergence tolerance on coefficient changes
    pub tol: f64,
    /// Per-feature penalty factors (optional)
    pub penalty_factor: Option<Vec<f64>>,
}

impl Default for LassoFitOptions {
    fn default() -> Self {
        LassoFitOptions {
            lambda: 1.0,
            intercept: true,
            standardize: true,
            max_iter: 1000,
            tol: 1e-7,
            penalty_factor: None,
        }
    }
}

/// Result of a lasso regression fit.
///
/// # Fields
///
/// * `lambda` - The lambda value used for fitting
/// * `intercept` - Intercept coefficient (on original scale)
/// * `coefficients` - Slope coefficients (on original scale, may contain zeros)
/// * `fitted_values` - In-sample predictions
/// * `residuals` - Residuals (y - fitted_values)
/// * `n_nonzero` - Number of non-zero coefficients (excluding intercept)
/// * `iterations` - Number of coordinate descent iterations
/// * `converged` - Whether the algorithm converged
/// * `r_squared` - R² (coefficient of determination)
/// * `adj_r_squared` - Adjusted R² (using effective df based on n_nonzero)
/// * `mse` - Mean squared error
/// * `rmse` - Root mean squared error
/// * `mae` - Mean absolute error
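///
/// # Example
///
/// A minimal sketch (same data shape as the tests below); note that
/// `coefficients` holds only the slopes, while the intercept is stored
/// separately:
///
/// ```rust,no_run
/// use linreg_core::linalg::Matrix;
/// use linreg_core::regularized::lasso::{lasso_fit, LassoFitOptions};
///
/// let x = Matrix::new(3, 2, vec![1.0, 1.0, 1.0, 2.0, 1.0, 3.0]);
/// let y = vec![2.0, 4.0, 6.0];
/// let fit = lasso_fit(&x, &y, &LassoFitOptions::default()).unwrap();
///
/// // One slope for the single non-intercept column.
/// assert_eq!(fit.coefficients.len(), 1);
/// assert_eq!(fit.residuals.len(), 3);
/// ```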
#[derive(Clone, Debug)]
#[cfg_attr(feature = "wasm", derive(Serialize))]
pub struct LassoFit {
    /// Lambda value used for fitting
    pub lambda: f64,
    /// Intercept on original scale
    pub intercept: f64,
    /// Slope coefficients on original scale
    pub coefficients: Vec<f64>,
    /// Fitted values
    pub fitted_values: Vec<f64>,
    /// Residuals
    pub residuals: Vec<f64>,
    /// Number of non-zero coefficients
    pub n_nonzero: usize,
    /// Number of iterations performed
    pub iterations: usize,
    /// Whether convergence was achieved
    pub converged: bool,
    /// R² (coefficient of determination)
    pub r_squared: f64,
    /// Adjusted R² (penalized for effective number of parameters)
    pub adj_r_squared: f64,
    /// Mean squared error
    pub mse: f64,
    /// Root mean squared error
    pub rmse: f64,
    /// Mean absolute error
    pub mae: f64,
}

/// Fits lasso regression for a single lambda value.
///
/// # Arguments
///
/// * `x` - Design matrix (n × p). Should include an intercept column if `intercept=true`.
/// * `y` - Response vector (n elements)
/// * `options` - Lasso fitting options
///
/// # Returns
///
/// A [`LassoFit`] containing the fit results.
///
/// # Errors
///
/// Returns an error if:
/// - `lambda < 0`
/// - Dimensions don't match
///
/// If the iteration limit is reached without convergence, the fit is still
/// returned, with `converged` set to `false`.
///
/// # Algorithm
///
/// Uses cyclical coordinate descent:
/// 1. Standardize X and center y (if requested)
/// 2. Initialize coefficients to zero
/// 3. For each feature j:
///    - Compute the partial residual: r = y - X_{-j} * beta_{-j}
///    - Compute the correlation: rho_j = X_j^T * r / n
///    - Apply soft-thresholding: beta_j = S(rho_j, lambda)
///      (for standardized X, the update's denominator X_j^T X_j / n equals 1)
/// 4. Check for convergence
/// 5. Unstandardize coefficients
///
/// # Example
///
/// ```rust,no_run
/// use linreg_core::linalg::Matrix;
/// use linreg_core::regularized::lasso::{lasso_fit, LassoFitOptions};
///
/// let x = Matrix::new(3, 2, vec![
///     1.0, 2.0,
///     1.0, 3.0,
///     1.0, 4.0,
/// ]);
/// let y = vec![3.0, 5.0, 7.0];
///
/// let options = LassoFitOptions {
///     lambda: 1.0,
///     intercept: true,
///     standardize: true,
///     ..Default::default()
/// };
///
/// let fit = lasso_fit(&x, &y, &options).unwrap();
/// println!("Non-zero coefficients: {}", fit.n_nonzero);
/// ```
pub fn lasso_fit(x: &Matrix, y: &[f64], options: &LassoFitOptions) -> Result<LassoFit> {
    if options.lambda < 0.0 {
        return Err(Error::InvalidInput(
            "Lambda must be non-negative for lasso regression".to_string(),
        ));
    }

    let n = x.rows;
    let p = x.cols;

    if y.len() != n {
        return Err(Error::DimensionMismatch(format!(
            "Length of y ({}) must match number of rows in X ({})",
            y.len(),
            n
        )));
    }

    // Handle zero lambda: just do OLS
    if options.lambda == 0.0 {
        return lasso_ols_fit(x, y, options);
    }

    // Standardize X and center y
    let std_options = StandardizeOptions {
        intercept: options.intercept,
        standardize_x: options.standardize,
        standardize_y: false,
    };

    let (x_std, y_centered, std_info) = standardize_xy(x, y, &std_options);

    // Initialize coefficients to zero
    let mut beta_std = vec![0.0; p];

    // Determine which columns are penalized
    let start_col = if options.intercept { 1 } else { 0 };

    // Run coordinate descent
    let (iterations, converged) = coordinate_descent(
        &x_std,
        &y_centered,
        &mut beta_std,
        options.lambda,
        start_col,
        options.max_iter,
        options.tol,
        options.penalty_factor.as_deref(),
    )?;

    // Unstandardize coefficients (beta_orig contains only slope coefficients)
    let (intercept, beta_orig) = unstandardize_coefficients(&beta_std, &std_info);

    // Count non-zero coefficients (beta_orig already excludes the intercept)
    let n_nonzero = beta_orig.iter().filter(|&&b| b.abs() > 0.0).count();

    // Compute fitted values and residuals
    let fitted = predict(x, intercept, &beta_orig);
    let residuals: Vec<f64> = y
        .iter()
        .zip(fitted.iter())
        .map(|(yi, yh)| yi - yh)
        .collect();

    // Compute model fit statistics
    let y_mean: f64 = y.iter().sum::<f64>() / n as f64;
    let ss_tot: f64 = y.iter().map(|yi| (yi - y_mean).powi(2)).sum();
    let ss_res: f64 = residuals.iter().map(|r| r.powi(2)).sum();
    let r_squared = if ss_tot > 1e-10 {
        1.0 - ss_res / ss_tot
    } else {
        1.0
    };

    // Effective df = intercept (if present) + number of non-zero coefficients;
    // adjusted R² uses the effective degrees of freedom.
    let eff_df = if options.intercept { 1.0 } else { 0.0 } + n_nonzero as f64;
    let adj_r_squared = if ss_tot > 1e-10 && n > eff_df as usize {
        1.0 - (1.0 - r_squared) * ((n - 1) as f64 / (n as f64 - eff_df))
    } else {
        r_squared
    };

    let mse = ss_res / (n as f64 - eff_df).max(1.0);
    let rmse = mse.sqrt();
    let mae: f64 = residuals.iter().map(|r| r.abs()).sum::<f64>() / n as f64;

    Ok(LassoFit {
        lambda: options.lambda,
        intercept,
        coefficients: beta_orig,
        fitted_values: fitted,
        residuals,
        n_nonzero,
        iterations,
        converged,
        r_squared,
        adj_r_squared,
        mse,
        rmse,
        mae,
    })
}

/// Coordinate descent for lasso.
///
/// # Arguments
///
/// * `x` - Standardized design matrix
/// * `y` - Centered response
/// * `beta` - Coefficient vector (modified in place; assumed to be all zeros on entry)
/// * `lambda` - Regularization strength
/// * `start_col` - First penalized column index
/// * `max_iter` - Maximum iterations
/// * `tol` - Convergence tolerance
/// * `penalty_factor` - Optional per-feature penalties
///
/// # Returns
///
/// A tuple `(iterations, converged)` indicating the number of iterations
/// and whether convergence was achieved.
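///
/// # Notes
///
/// Residuals are maintained incrementally: feature j's contribution is removed
/// and re-added in O(n), so one full sweep over the penalized columns costs
/// O(n·p).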
#[allow(clippy::too_many_arguments)]
#[allow(clippy::needless_range_loop)]
fn coordinate_descent(
    x: &Matrix,
    y: &[f64],
    beta: &mut [f64],
    lambda: f64,
    start_col: usize,
    max_iter: usize,
    tol: f64,
    penalty_factor: Option<&[f64]>,
) -> Result<(usize, bool)> {
    let n = x.rows;
    let p = x.cols;

    // Residuals start at y, which relies on beta being all zeros on entry.
    let mut residuals: Vec<f64> = y.to_vec();
    let mut converged = false;

    for iter in 0..max_iter {
        let mut max_change: f64 = 0.0;

        // Update each coordinate
        for j in start_col..p {
            // Skip if the penalty factor is infinite (feature always excluded)
            if let Some(pf) = penalty_factor {
                if j < pf.len() && pf[j] == f64::INFINITY {
                    beta[j] = 0.0;
                    continue;
                }
            }

            // We need rho_j = x_j^T * (y - X_{-j} * beta_{-j}) / n.
            // Since r = y - X*beta, this equals x_j^T * (r + x_j * beta_j) / n,
            // so first restore feature j's contribution to the residuals.
            let old_beta_j = beta[j];
            for i in 0..n {
                residuals[i] += x.get(i, j) * old_beta_j;
            }

            // Compute rho_j = x_j^T * residuals / n
            let mut rho_j = 0.0;
            for i in 0..n {
                rho_j += x.get(i, j) * residuals[i];
            }
            rho_j /= n as f64;

            // Get the penalty factor for this feature
            let pf = penalty_factor
                .and_then(|pf| pf.get(j))
                .copied()
                .unwrap_or(1.0);

            // Apply soft-thresholding
            // (for standardized X, the update's denominator is 1)
            let threshold = lambda * pf;
            let new_beta_j = soft_threshold(rho_j, threshold);

            // Update residuals with the new coefficient
            for i in 0..n {
                residuals[i] -= x.get(i, j) * new_beta_j;
            }

            beta[j] = new_beta_j;

            // Track the maximum coefficient change
            let change = (new_beta_j - old_beta_j).abs();
            max_change = max_change.max(change);
        }

        // Check convergence
        if max_change < tol {
            converged = true;
            return Ok((iter + 1, converged));
        }
    }

    Ok((max_iter, converged))
}

/// OLS fit for lambda = 0 (special case of lasso).
#[allow(clippy::needless_range_loop)]
fn lasso_ols_fit(x: &Matrix, y: &[f64], options: &LassoFitOptions) -> Result<LassoFit> {
    // Use QR decomposition for OLS on the original (non-standardized) data
    let (q, r) = x.qr();

    // Solve R * beta = Q^T * y
    let n = x.rows;
    let p = x.cols;
    let mut qty = vec![0.0; p];

    for i in 0..p {
        for k in 0..n {
            qty[i] += q.get(k, i) * y[k];
        }
    }

    // Back-substitution: R is upper triangular, so solve from the last row up.
    // Assumes X has full column rank (non-zero diagonal entries in R).
    let mut beta = vec![0.0; p];
    for i in (0..p).rev() {
        let mut sum = qty[i];
        for j in (i + 1)..p {
            sum -= r.get(i, j) * beta[j];
        }
        beta[i] = sum / r.get(i, i);
    }

    // Extract intercept and slope coefficients directly (no unstandardization
    // needed): OLS on original data gives coefficients on the original scale.
    let (intercept, beta_orig) = if options.intercept {
        // beta[0] is the intercept, beta[1..] are the slopes
        let slopes: Vec<f64> = beta[1..].to_vec();
        (beta[0], slopes)
    } else {
        // No intercept, all coefficients are slopes
        (0.0, beta)
    };

    // Compute fitted values and residuals
    let fitted = predict(x, intercept, &beta_orig);
    let residuals: Vec<f64> = y
        .iter()
        .zip(fitted.iter())
        .map(|(yi, yh)| yi - yh)
        .collect();

    // Count non-zero coefficients (beta_orig already excludes the intercept)
    let n_nonzero = beta_orig.iter().filter(|&&b| b.abs() > 0.0).count();

    // Compute model fit statistics
    let y_mean: f64 = y.iter().sum::<f64>() / n as f64;
    let ss_tot: f64 = y.iter().map(|yi| (yi - y_mean).powi(2)).sum();
    let ss_res: f64 = residuals.iter().map(|r| r.powi(2)).sum();
    let r_squared = if ss_tot > 1e-10 {
        1.0 - ss_res / ss_tot
    } else {
        1.0
    };

    // Adjusted R² (count the intercept in the effective df, as in lasso_fit)
    let eff_df = if options.intercept { 1.0 } else { 0.0 } + n_nonzero as f64;
    let adj_r_squared = if ss_tot > 1e-10 && n > eff_df as usize {
        1.0 - (1.0 - r_squared) * ((n - 1) as f64 / (n as f64 - eff_df))
    } else {
        r_squared
    };

    // Guard the denominator so a saturated fit (n == p) cannot divide by zero
    let mse = ss_res / (n as f64 - p as f64).max(1.0);
    let rmse = mse.sqrt();
    let mae: f64 = residuals.iter().map(|r| r.abs()).sum::<f64>() / n as f64;

    Ok(LassoFit {
        lambda: 0.0,
        intercept,
        coefficients: beta_orig,
        fitted_values: fitted,
        residuals,
        n_nonzero,
        iterations: 1,
        converged: true,
        r_squared,
        adj_r_squared,
        mse,
        rmse,
        mae,
    })
}

/// Makes predictions using a lasso regression fit.
///
/// # Arguments
///
/// * `fit` - The lasso regression fit result
/// * `x_new` - New data matrix (n_new × p, same column layout as the training matrix)
///
/// # Returns
///
/// Predictions for each row in `x_new`.
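///
/// # Example
///
/// A minimal sketch: the new matrix keeps the intercept column, matching the
/// layout used at fit time.
///
/// ```rust,no_run
/// use linreg_core::linalg::Matrix;
/// use linreg_core::regularized::lasso::{lasso_fit, predict_lasso, LassoFitOptions};
///
/// let x = Matrix::new(3, 2, vec![1.0, 1.0, 1.0, 2.0, 1.0, 3.0]);
/// let y = vec![2.0, 4.0, 6.0];
/// let fit = lasso_fit(&x, &y, &LassoFitOptions::default()).unwrap();
///
/// // Predict at a new point x = 5.0 (intercept column first).
/// let x_new = Matrix::new(1, 2, vec![1.0, 5.0]);
/// let preds = predict_lasso(&fit, &x_new);
/// assert_eq!(preds.len(), 1);
/// ```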
pub fn predict_lasso(fit: &LassoFit, x_new: &Matrix) -> Vec<f64> {
    predict(x_new, fit.intercept, &fit.coefficients)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_soft_threshold() {
        assert_eq!(soft_threshold(5.0, 2.0), 3.0);
        assert_eq!(soft_threshold(-5.0, 2.0), -3.0);
        assert_eq!(soft_threshold(1.0, 2.0), 0.0);
        assert_eq!(soft_threshold(-1.0, 2.0), 0.0);
        assert_eq!(soft_threshold(2.0, 2.0), 0.0);
        assert_eq!(soft_threshold(-2.0, 2.0), 0.0);
        assert_eq!(soft_threshold(0.0, 0.0), 0.0);
    }

    #[test]
    fn test_lasso_fit_simple() {
        // Simple test: y = 2*x with perfect linear relationship
        let x_data = vec![1.0, 1.0, 1.0, 2.0, 1.0, 3.0, 1.0, 4.0];
        let x = Matrix::new(4, 2, x_data);
        let y = vec![2.0, 4.0, 6.0, 8.0];

        let options = LassoFitOptions {
            lambda: 0.01, // Very small lambda for near-OLS solution
            intercept: true,
            standardize: true, // Standardize for better convergence
            ..Default::default()
        };

        let fit = lasso_fit(&x, &y, &options).unwrap();

        // With small lambda, should get a good fit
        assert!(fit.converged);
        assert!(fit.n_nonzero > 0);

        // Predictions should be close to actual values
        for i in 0..4 {
            assert!((fit.fitted_values[i] - y[i]).abs() < 0.5);
        }
    }

    #[test]
    fn test_lasso_with_large_lambda() {
        let x_data = vec![1.0, 1.0, 1.0, 2.0, 1.0, 3.0];
        let x = Matrix::new(3, 2, x_data);
        let y = vec![2.0, 4.0, 6.0];

        let options = LassoFitOptions {
            lambda: 100.0,
            intercept: true,
            standardize: false,
            ..Default::default()
        };

        let fit = lasso_fit(&x, &y, &options).unwrap();

        // With large lambda, all coefficients should be zero
        // Only intercept should be non-zero (equal to mean of y)
        assert_eq!(fit.n_nonzero, 0);
        // coefficients[0] is the first (and only) slope coefficient
        assert!((fit.coefficients[0]).abs() < 1e-10);
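        // With all slopes zeroed, the unpenalized intercept should reduce to
        // the mean of y (4.0 for this data).
        assert!((fit.intercept - 4.0).abs() < 1e-6);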
576 }
577
578 #[test]
579 fn test_lasso_zero_lambda_is_ols() {
580 let x_data = vec![1.0, 1.0, 1.0, 2.0, 1.0, 3.0];
581 let x = Matrix::new(3, 2, x_data);
582 let y = vec![2.0, 4.0, 6.0];
583
584 let options = LassoFitOptions {
585 lambda: 0.0,
586 intercept: true,
587 standardize: false,
588 ..Default::default()
589 };
590
591 let fit = lasso_fit(&x, &y, &options).unwrap();
592
593 // Should be close to perfect fit
594 assert!((fit.fitted_values[0] - 2.0).abs() < 1e-6);
595 assert!((fit.fitted_values[1] - 4.0).abs() < 1e-6);
596 assert!((fit.fitted_values[2] - 6.0).abs() < 1e-6);
597 }
598
599 #[test]
600 fn test_predict_lasso() {
601 let x_data = vec![1.0, 1.0, 1.0, 2.0, 1.0, 3.0];
602 let x = Matrix::new(3, 2, x_data);
603 let y = vec![2.0, 4.0, 6.0];
604
605 let options = LassoFitOptions {
606 lambda: 0.1,
607 intercept: true,
608 standardize: false,
609 ..Default::default()
610 };
611
612 let fit = lasso_fit(&x, &y, &options).unwrap();
613 let preds = predict_lasso(&fit, &x);
614
615 // Predictions on training data should equal fitted values
616 for i in 0..3 {
617 assert!((preds[i] - fit.fitted_values[i]).abs() < 1e-10);
618 }
619 }
620}