linreg_core/regularized/lasso.rs

//! Lasso regression (L1-regularized linear regression).
//!
//! This module provides a lasso regression implementation using cyclical coordinate
//! descent with soft-thresholding, matching glmnet's approach.
//!
//! # Lasso Regression Objective
//!
//! Lasso regression solves:
//!
//! ```text
//! minimize over (β₀, β):
//!
//!     (1/(2n)) * Σᵢ (yᵢ - β₀ - xᵢᵀβ)² + λ * ||β||₁
//! ```
//!
//! The intercept `β₀` is **not penalized**.
//!
//! # Solution Method
//!
//! Uses cyclical coordinate descent with soft-thresholding:
//!
//! 1. For standardized X, each coordinate update has a closed form (shown below)
//! 2. Soft-thresholding operator: S(z, γ) = sign(z) * max(|z| - γ, 0)
//! 3. Warm starts along the lambda path for efficiency
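//!
//! Concretely, for standardized predictors the update for coordinate j is
//! (a sketch of the update implemented in `coordinate_descent` below):
//!
//! ```text
//! ρⱼ = (1/n) * Σᵢ xᵢⱼ * (yᵢ - Σ_{k≠j} xᵢₖ βₖ)
//! βⱼ ← S(ρⱼ, λ)
//! ```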

use crate::error::{Error, Result};
use crate::linalg::Matrix;
use crate::regularized::preprocess::{
    predict, standardize_xy, unstandardize_coefficients, StandardizeOptions,
};

#[cfg(feature = "wasm")]
use serde::Serialize;

/// Soft-thresholding operator: S(z, γ) = sign(z) * max(|z| - γ, 0).
///
/// # Arguments
///
/// * `z` - Input value
/// * `gamma` - Threshold value (must be >= 0)
///
/// # Returns
///
/// The soft-thresholded value.
///
/// # Panics
///
/// Panics if `gamma` is negative.
///
/// # Formula
///
/// ```text
/// S(z, γ) = {
///     z - γ    if z > 0 and |z| > γ
///     z + γ    if z < 0 and |z| > γ
///     0        if |z| <= γ
/// }
/// ```
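///
/// # Example
///
/// A quick check of the three regimes; the expected values follow directly
/// from the formula above:
///
/// ```
/// use linreg_core::regularized::lasso::soft_threshold;
///
/// assert_eq!(soft_threshold(5.0, 2.0), 3.0);   // shrunk toward zero
/// assert_eq!(soft_threshold(-5.0, 2.0), -3.0); // shrunk toward zero
/// assert_eq!(soft_threshold(1.5, 2.0), 0.0);   // inside the band: clipped
/// ```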
pub fn soft_threshold(z: f64, gamma: f64) -> f64 {
    if gamma < 0.0 {
        panic!("Soft threshold gamma must be non-negative");
    }
    if z > gamma {
        z - gamma
    } else if z < -gamma {
        z + gamma
    } else {
        0.0
    }
}

/// Options for lasso regression fitting.
///
/// # Fields
///
/// * `lambda` - Regularization strength (single value)
/// * `intercept` - Whether to include an intercept term (default: true)
/// * `standardize` - Whether to standardize predictors (default: true)
/// * `max_iter` - Maximum iterations per lambda (default: 1000)
/// * `tol` - Convergence tolerance (default: 1e-7)
/// * `penalty_factor` - Optional per-feature penalty factors
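///
/// # Example
///
/// Overriding a single field while keeping the remaining defaults:
///
/// ```
/// use linreg_core::regularized::lasso::LassoFitOptions;
///
/// let options = LassoFitOptions { lambda: 0.5, ..Default::default() };
/// assert!(options.intercept);
/// assert_eq!(options.max_iter, 1000);
/// ```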
#[derive(Clone, Debug)]
pub struct LassoFitOptions {
    /// Regularization strength (must be >= 0)
    pub lambda: f64,
    /// Whether to include an intercept
    pub intercept: bool,
    /// Whether to standardize predictors
    pub standardize: bool,
    /// Maximum coordinate descent iterations
    pub max_iter: usize,
    /// Convergence tolerance on coefficient changes
    pub tol: f64,
    /// Per-feature penalty factors (optional)
    pub penalty_factor: Option<Vec<f64>>,
}

impl Default for LassoFitOptions {
    fn default() -> Self {
        LassoFitOptions {
            lambda: 1.0,
            intercept: true,
            standardize: true,
            max_iter: 1000,
            tol: 1e-7,
            penalty_factor: None,
        }
    }
}

/// Result of a lasso regression fit.
///
/// # Fields
///
/// * `lambda` - The lambda value used for fitting
/// * `intercept` - Intercept coefficient (on original scale)
/// * `coefficients` - Slope coefficients (on original scale, may contain zeros)
/// * `fitted_values` - In-sample predictions
/// * `residuals` - Residuals (y - fitted_values)
/// * `n_nonzero` - Number of non-zero coefficients (excluding intercept)
/// * `iterations` - Number of coordinate descent iterations
/// * `converged` - Whether the algorithm converged
/// * `r_squared` - R² (coefficient of determination)
/// * `adj_r_squared` - Adjusted R² (using effective df based on n_nonzero)
/// * `mse` - Mean squared error
/// * `rmse` - Root mean squared error
/// * `mae` - Mean absolute error
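///
/// # Example
///
/// Reading summary statistics off a fit (a sketch; see [`lasso_fit`] for a
/// complete fitting example):
///
/// ```no_run
/// # use linreg_core::linalg::Matrix;
/// # use linreg_core::regularized::lasso::{lasso_fit, LassoFitOptions};
/// # let x = Matrix::new(3, 2, vec![1.0, 2.0, 1.0, 3.0, 1.0, 4.0]);
/// # let y = vec![3.0, 5.0, 7.0];
/// let fit = lasso_fit(&x, &y, &LassoFitOptions::default()).unwrap();
/// println!("R² = {:.3}, {} non-zero coefficients", fit.r_squared, fit.n_nonzero);
/// ```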
#[derive(Clone, Debug)]
#[cfg_attr(feature = "wasm", derive(Serialize))]
pub struct LassoFit {
    /// Lambda value used for fitting
    pub lambda: f64,
    /// Intercept on original scale
    pub intercept: f64,
    /// Slope coefficients on original scale
    pub coefficients: Vec<f64>,
    /// Fitted values
    pub fitted_values: Vec<f64>,
    /// Residuals
    pub residuals: Vec<f64>,
    /// Number of non-zero coefficients
    pub n_nonzero: usize,
    /// Number of iterations performed
    pub iterations: usize,
    /// Whether convergence was achieved
    pub converged: bool,
    /// R² (coefficient of determination)
    pub r_squared: f64,
    /// Adjusted R² (penalized for effective number of parameters)
    pub adj_r_squared: f64,
    /// Mean squared error
    pub mse: f64,
    /// Root mean squared error
    pub rmse: f64,
    /// Mean absolute error
    pub mae: f64,
}

/// Fits lasso regression for a single lambda value.
///
/// # Arguments
///
/// * `x` - Design matrix (n × p). Should include an intercept column if `intercept=true`.
/// * `y` - Response vector (n elements)
/// * `options` - Lasso fitting options
///
/// # Returns
///
/// A [`LassoFit`] containing the fit results.
///
/// # Errors
///
/// Returns an error if:
/// - `lambda < 0`
/// - Dimensions don't match
///
/// If the iteration limit is reached without convergence, the fit is still
/// returned, with `converged` set to `false`.
///
/// # Algorithm
///
/// Uses cyclical coordinate descent:
/// 1. Standardize X and center y (if requested)
/// 2. Initialize coefficients (zeros or warm start)
/// 3. For each feature j:
///    - Compute partial residual: r = y - X_{-j} * beta_{-j}
///    - Compute correlation: rho_j = X_j^T * r / n
///    - Apply soft-thresholding: beta_j = S(rho_j, lambda)
///      (for standardized X, the denominator of the update is 1)
/// 4. Check for convergence
/// 5. Unstandardize coefficients
///
/// # Example
///
/// ```rust,no_run
/// use linreg_core::linalg::Matrix;
/// use linreg_core::regularized::lasso::{lasso_fit, LassoFitOptions};
///
/// let x = Matrix::new(3, 2, vec![
///     1.0, 2.0,
///     1.0, 3.0,
///     1.0, 4.0,
/// ]);
/// let y = vec![3.0, 5.0, 7.0];
///
/// let options = LassoFitOptions {
///     lambda: 1.0,
///     intercept: true,
///     standardize: true,
///     ..Default::default()
/// };
///
/// let fit = lasso_fit(&x, &y, &options).unwrap();
/// println!("Non-zero coefficients: {}", fit.n_nonzero);
/// ```
pub fn lasso_fit(x: &Matrix, y: &[f64], options: &LassoFitOptions) -> Result<LassoFit> {
    if options.lambda < 0.0 {
        return Err(Error::InvalidInput(
            "Lambda must be non-negative for lasso regression".to_string(),
        ));
    }

    let n = x.rows;
    let p = x.cols;

    if y.len() != n {
        return Err(Error::DimensionMismatch(
            format!("Length of y ({}) must match number of rows in X ({})", y.len(), n)
        ));
    }

    // Handle zero lambda: just do OLS
    if options.lambda == 0.0 {
        return lasso_ols_fit(x, y, options);
    }

    // Standardize X and center y
    let std_options = StandardizeOptions {
        intercept: options.intercept,
        standardize_x: options.standardize,
        standardize_y: false,
    };

    let (x_std, y_centered, std_info) = standardize_xy(x, y, &std_options);

    // Initialize coefficients to zero
    let mut beta_std = vec![0.0; p];

    // Determine which columns are penalized (the intercept column is not)
    let start_col = if options.intercept { 1 } else { 0 };

    // Run coordinate descent
    let (iterations, converged) = coordinate_descent(
        &x_std,
        &y_centered,
        &mut beta_std,
        options.lambda,
        start_col,
        options.max_iter,
        options.tol,
        options.penalty_factor.as_deref(),
    )?;

    // Unstandardize coefficients
    let (intercept, beta_orig) = unstandardize_coefficients(&beta_std, &std_info);

    // Count non-zero coefficients (excluding the intercept column)
    let n_nonzero = beta_orig.iter().skip(start_col).filter(|&&b| b.abs() > 0.0).count();

    // Compute fitted values and residuals
    let fitted = predict(x, intercept, &beta_orig);
    let residuals: Vec<f64> = y.iter().zip(fitted.iter()).map(|(yi, yh)| yi - yh).collect();

    // Compute model fit statistics
    let y_mean: f64 = y.iter().sum::<f64>() / n as f64;
    let ss_tot: f64 = y.iter().map(|yi| (yi - y_mean).powi(2)).sum();
    let ss_res: f64 = residuals.iter().map(|r| r.powi(2)).sum();
    let r_squared = if ss_tot > 1e-10 {
        1.0 - ss_res / ss_tot
    } else {
        1.0
    };

    // For lasso, effective df = intercept (if present) + n_nonzero.
    // Adjusted R² uses effective degrees of freedom.
    let eff_df = if options.intercept { 1.0 } else { 0.0 } + n_nonzero as f64;
    let adj_r_squared = if ss_tot > 1e-10 && n as f64 > eff_df {
        1.0 - (1.0 - r_squared) * ((n - 1) as f64 / (n as f64 - eff_df))
    } else {
        r_squared
    };

    let mse = ss_res / (n as f64 - eff_df).max(1.0);
    let rmse = mse.sqrt();
    let mae: f64 = residuals.iter().map(|r| r.abs()).sum::<f64>() / n as f64;

    Ok(LassoFit {
        lambda: options.lambda,
        intercept,
        coefficients: beta_orig,
        fitted_values: fitted,
        residuals,
        n_nonzero,
        iterations,
        converged,
        r_squared,
        adj_r_squared,
        mse,
        rmse,
        mae,
    })
}

/// Coordinate descent for lasso.
///
/// # Arguments
///
/// * `x` - Standardized design matrix
/// * `y` - Centered response
/// * `beta` - Coefficient vector (modified in place)
/// * `lambda` - Regularization strength
/// * `start_col` - First penalized column index
/// * `max_iter` - Maximum iterations
/// * `tol` - Convergence tolerance
/// * `penalty_factor` - Optional per-feature penalties
///
/// # Returns
///
/// A tuple `(iterations, converged)` indicating the number of iterations
/// and whether convergence was achieved.
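///
/// # Residual update
///
/// Residuals are maintained incrementally rather than recomputed: before
/// updating coordinate j, the contribution of x_j is added back, so that
///
/// ```text
/// x_jᵀ(r + x_j·βⱼ) / n  =  x_jᵀ(y - X₋ⱼ β₋ⱼ) / n  =  ρⱼ
/// ```
///
/// which yields the partial-residual correlation at O(n) cost per coordinate.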
fn coordinate_descent(
    x: &Matrix,
    y: &[f64],
    beta: &mut [f64],
    lambda: f64,
    start_col: usize,
    max_iter: usize,
    tol: f64,
    penalty_factor: Option<&[f64]>,
) -> Result<(usize, bool)> {
    let n = x.rows;
    let p = x.cols;

    // Residuals start at y, which assumes beta is zero-initialized (as done
    // by `lasso_fit`); a warm start would require r = y - X*beta instead.
    let mut residuals: Vec<f64> = y.to_vec();
    let mut converged = false;

    for iter in 0..max_iter {
        let mut max_change: f64 = 0.0;

        // Update each coordinate
        for j in start_col..p {
            // Skip if the penalty factor is infinite (feature always excluded)
            if let Some(pf) = penalty_factor {
                if j < pf.len() && pf[j] == f64::INFINITY {
                    beta[j] = 0.0;
                    continue;
                }
            }

            // We want rho_j = x_j^T * (y - X_{-j} * beta_{-j}) / n.
            // Since r = y - X*beta, this equals x_j^T * (r + x_j * beta_j) / n.

            // First, add the contribution of feature j back into the residuals
            let old_beta_j = beta[j];
            for i in 0..n {
                residuals[i] += x.get(i, j) * old_beta_j;
            }

            // Compute rho_j = x_j^T * residuals / n
            let mut rho_j = 0.0;
            for i in 0..n {
                rho_j += x.get(i, j) * residuals[i];
            }
            rho_j /= n as f64;

            // Get the penalty factor for this feature (defaults to 1)
            let pf = penalty_factor
                .and_then(|pf| pf.get(j))
                .copied()
                .unwrap_or(1.0);

            // Apply soft-thresholding.
            // For standardized X, the update denominator is 1.
            let threshold = lambda * pf;
            let new_beta_j = soft_threshold(rho_j, threshold);

            // Restore the residuals with the new coefficient
            for i in 0..n {
                residuals[i] -= x.get(i, j) * new_beta_j;
            }

            beta[j] = new_beta_j;

            // Track the maximum coefficient change
            let change = (new_beta_j - old_beta_j).abs();
            max_change = max_change.max(change);
        }

        // Check convergence
        if max_change < tol {
            converged = true;
            return Ok((iter + 1, converged));
        }
    }

    Ok((max_iter, converged))
}

/// OLS fit for lambda = 0 (special case of lasso).
fn lasso_ols_fit(x: &Matrix, y: &[f64], options: &LassoFitOptions) -> Result<LassoFit> {
    let std_options = StandardizeOptions {
        intercept: options.intercept,
        standardize_x: false,
        standardize_y: false,
    };

    let (_, _, std_info) = standardize_xy(x, y, &std_options);

    // Use QR decomposition for OLS
    let (q, r) = x.qr();

    // Solve R * beta = Q^T * y
    let n = x.rows;
    let p = x.cols;
    let mut qty = vec![0.0; p];

    for i in 0..p {
        for k in 0..n {
            qty[i] += q.get(k, i) * y[k];
        }
    }

    // Back-substitution (R is upper triangular)
    let mut beta = vec![0.0; p];
    for i in (0..p).rev() {
        let mut sum = qty[i];
        for j in (i + 1)..p {
            sum -= r.get(i, j) * beta[j];
        }
        let r_ii = r.get(i, i);
        if r_ii.abs() < 1e-12 {
            return Err(Error::InvalidInput(
                "Design matrix is rank-deficient; OLS solution is not unique".to_string(),
            ));
        }
        beta[i] = sum / r_ii;
    }

    // Unstandardize
    let (intercept, beta_orig) = unstandardize_coefficients(&beta, &std_info);

    // Compute fitted values and residuals
    let fitted = predict(x, intercept, &beta_orig);
    let residuals: Vec<f64> = y.iter().zip(fitted.iter()).map(|(yi, yh)| yi - yh).collect();

    // Count non-zero coefficients (excluding the intercept column)
    let start_col = if options.intercept { 1 } else { 0 };
    let n_nonzero = beta_orig.iter().skip(start_col).filter(|&&b| b.abs() > 0.0).count();

    // Compute model fit statistics
    let y_mean: f64 = y.iter().sum::<f64>() / n as f64;
    let ss_tot: f64 = y.iter().map(|yi| (yi - y_mean).powi(2)).sum();
    let ss_res: f64 = residuals.iter().map(|r| r.powi(2)).sum();
    let r_squared = if ss_tot > 1e-10 {
        1.0 - ss_res / ss_tot
    } else {
        1.0
    };

    // Adjusted R² with effective df = intercept (if present) + non-zero coefficients
    let eff_df = if options.intercept { 1.0 } else { 0.0 } + n_nonzero as f64;
    let adj_r_squared = if ss_tot > 1e-10 && n as f64 > eff_df {
        1.0 - (1.0 - r_squared) * ((n - 1) as f64 / (n as f64 - eff_df))
    } else {
        r_squared
    };

    let mse = ss_res / (n as f64 - p as f64).max(1.0);
    let rmse = mse.sqrt();
    let mae: f64 = residuals.iter().map(|r| r.abs()).sum::<f64>() / n as f64;

    Ok(LassoFit {
        lambda: 0.0,
        intercept,
        coefficients: beta_orig,
        fitted_values: fitted,
        residuals,
        n_nonzero,
        iterations: 1,
        converged: true,
        r_squared,
        adj_r_squared,
        mse,
        rmse,
        mae,
    })
}

/// Makes predictions using a lasso regression fit.
///
/// # Arguments
///
/// * `fit` - The lasso regression fit result
/// * `x_new` - New data matrix (n_new × p)
///
/// # Returns
///
/// Predictions for each row in x_new.
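///
/// # Example
///
/// Predicting on new rows after fitting (a sketch reusing the training layout
/// from [`lasso_fit`]'s example):
///
/// ```no_run
/// # use linreg_core::linalg::Matrix;
/// # use linreg_core::regularized::lasso::{lasso_fit, predict_lasso, LassoFitOptions};
/// # let x = Matrix::new(3, 2, vec![1.0, 2.0, 1.0, 3.0, 1.0, 4.0]);
/// # let y = vec![3.0, 5.0, 7.0];
/// let fit = lasso_fit(&x, &y, &LassoFitOptions::default()).unwrap();
/// let x_new = Matrix::new(2, 2, vec![1.0, 5.0, 1.0, 6.0]);
/// let preds = predict_lasso(&fit, &x_new);
/// assert_eq!(preds.len(), 2);
/// ```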
pub fn predict_lasso(fit: &LassoFit, x_new: &Matrix) -> Vec<f64> {
    predict(x_new, fit.intercept, &fit.coefficients)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_soft_threshold() {
        assert_eq!(soft_threshold(5.0, 2.0), 3.0);
        assert_eq!(soft_threshold(-5.0, 2.0), -3.0);
        assert_eq!(soft_threshold(1.0, 2.0), 0.0);
        assert_eq!(soft_threshold(-1.0, 2.0), 0.0);
        assert_eq!(soft_threshold(2.0, 2.0), 0.0);
        assert_eq!(soft_threshold(-2.0, 2.0), 0.0);
        assert_eq!(soft_threshold(0.0, 0.0), 0.0);
    }

    #[test]
    fn test_lasso_fit_simple() {
        // Simple test: y = 2*x with a perfect linear relationship
        let x_data = vec![
            1.0, 1.0,
            1.0, 2.0,
            1.0, 3.0,
            1.0, 4.0,
        ];
        let x = Matrix::new(4, 2, x_data);
        let y = vec![2.0, 4.0, 6.0, 8.0];

        let options = LassoFitOptions {
            lambda: 0.01,  // Very small lambda for a near-OLS solution
            intercept: true,
            standardize: true,  // Standardize for better convergence
            ..Default::default()
        };

        let fit = lasso_fit(&x, &y, &options).unwrap();

        // With small lambda, should get a good fit
        assert!(fit.converged);
        assert!(fit.n_nonzero > 0);

        // Predictions should be close to actual values
        for i in 0..4 {
            assert!((fit.fitted_values[i] - y[i]).abs() < 0.5);
        }
    }

    #[test]
    fn test_lasso_with_large_lambda() {
        let x_data = vec![
            1.0, 1.0,
            1.0, 2.0,
            1.0, 3.0,
        ];
        let x = Matrix::new(3, 2, x_data);
        let y = vec![2.0, 4.0, 6.0];

        let options = LassoFitOptions {
            lambda: 100.0,
            intercept: true,
            standardize: false,
            ..Default::default()
        };

        let fit = lasso_fit(&x, &y, &options).unwrap();

        // With large lambda, all slope coefficients should be zero;
        // only the intercept survives (equal to the mean of y)
        assert_eq!(fit.n_nonzero, 0);
        assert!((fit.coefficients[1]).abs() < 1e-10);
    }

    #[test]
    fn test_lasso_zero_lambda_is_ols() {
        let x_data = vec![
            1.0, 1.0,
            1.0, 2.0,
            1.0, 3.0,
        ];
        let x = Matrix::new(3, 2, x_data);
        let y = vec![2.0, 4.0, 6.0];

        let options = LassoFitOptions {
            lambda: 0.0,
            intercept: true,
            standardize: false,
            ..Default::default()
        };

        let fit = lasso_fit(&x, &y, &options).unwrap();

        // Should be close to a perfect fit
        assert!((fit.fitted_values[0] - 2.0).abs() < 1e-6);
        assert!((fit.fitted_values[1] - 4.0).abs() < 1e-6);
        assert!((fit.fitted_values[2] - 6.0).abs() < 1e-6);
    }

    #[test]
    fn test_predict_lasso() {
        let x_data = vec![
            1.0, 1.0,
            1.0, 2.0,
            1.0, 3.0,
        ];
        let x = Matrix::new(3, 2, x_data);
        let y = vec![2.0, 4.0, 6.0];

        let options = LassoFitOptions {
            lambda: 0.1,
            intercept: true,
            standardize: false,
            ..Default::default()
        };

        let fit = lasso_fit(&x, &y, &options).unwrap();
        let preds = predict_lasso(&fit, &x);

        // Predictions on training data should equal fitted values
        for i in 0..3 {
            assert!((preds[i] - fit.fitted_values[i]).abs() < 1e-10);
        }
    }
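
    // A sketch exercising the documented penalty_factor behavior: an infinite
    // per-feature penalty should force that coefficient to exactly zero.
    // (Illustrative test; the name and data are hypothetical.)
    #[test]
    fn test_penalty_factor_infinity_excludes_feature() {
        let x = Matrix::new(4, 2, vec![
            1.0, 1.0,
            1.0, 2.0,
            1.0, 3.0,
            1.0, 4.0,
        ]);
        let y = vec![2.0, 4.0, 6.0, 8.0];

        let options = LassoFitOptions {
            lambda: 0.01,
            penalty_factor: Some(vec![0.0, f64::INFINITY]),
            ..Default::default()
        };

        let fit = lasso_fit(&x, &y, &options).unwrap();

        // The second column is excluded, so no slope survives
        assert_eq!(fit.coefficients[1], 0.0);
        assert_eq!(fit.n_nonzero, 0);
    }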
}