Skip to main content

ferrolearn_linear/
lasso.rs

//! Lasso regression (L1-regularized linear regression).
//!
//! This module provides [`Lasso`], which fits a linear model with L1
//! regularization using coordinate descent with soft-thresholding:
//!
//! ```text
//! minimize (1 / (2 * n_samples)) * ||X @ w - y||^2 + alpha * ||w||_1
//! ```
//!
//! The L1 penalty encourages sparse solutions where some coefficients
//! are exactly zero, making Lasso useful for feature selection.
//!
//! # Examples
//!
//! ```
//! use ferrolearn_linear::Lasso;
//! use ferrolearn_core::{Fit, Predict};
//! use ndarray::{array, Array1, Array2};
//!
//! let model = Lasso::<f64>::new().with_alpha(0.1);
//! let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
//! let y = array![2.0, 4.0, 6.0, 8.0];
//!
//! let fitted = model.fit(&x, &y).unwrap();
//! let preds = fitted.predict(&x).unwrap();
//! ```
27
28use ferrolearn_core::error::FerroError;
29use ferrolearn_core::introspection::HasCoefficients;
30use ferrolearn_core::pipeline::{FittedPipelineEstimator, PipelineEstimator};
31use ferrolearn_core::traits::{Fit, Predict};
32use ndarray::{Array1, Array2, Axis, ScalarOperand};
33use num_traits::{Float, FromPrimitive};
34
/// Lasso regression (L1-regularized least squares).
///
/// Uses coordinate descent with soft-thresholding to solve the L1-penalized
/// regression problem. The `alpha` parameter controls the strength of the
/// L1 penalty; larger values drive more coefficients to exactly zero.
///
/// Construct with [`Lasso::new`] (or [`Default`]) and customize via the
/// `with_*` builder methods.
///
/// # Type Parameters
///
/// - `F`: The floating-point type (`f32` or `f64`).
#[derive(Debug, Clone)]
pub struct Lasso<F> {
    /// Regularization strength. Larger values specify stronger
    /// regularization and sparser solutions. Must be non-negative.
    pub alpha: F,
    /// Maximum number of coordinate descent iterations (full sweeps).
    pub max_iter: usize,
    /// Convergence tolerance on the maximum coefficient change per sweep.
    pub tol: F,
    /// Whether to fit an intercept (bias) term by centering the data.
    pub fit_intercept: bool,
}
56
57impl<F: Float> Lasso<F> {
58    /// Create a new `Lasso` with default settings.
59    ///
60    /// Defaults: `alpha = 1.0`, `max_iter = 1000`, `tol = 1e-4`,
61    /// `fit_intercept = true`.
62    #[must_use]
63    pub fn new() -> Self {
64        Self {
65            alpha: F::one(),
66            max_iter: 1000,
67            tol: F::from(1e-4).unwrap(),
68            fit_intercept: true,
69        }
70    }
71
72    /// Set the regularization strength.
73    #[must_use]
74    pub fn with_alpha(mut self, alpha: F) -> Self {
75        self.alpha = alpha;
76        self
77    }
78
79    /// Set the maximum number of iterations.
80    #[must_use]
81    pub fn with_max_iter(mut self, max_iter: usize) -> Self {
82        self.max_iter = max_iter;
83        self
84    }
85
86    /// Set the convergence tolerance.
87    #[must_use]
88    pub fn with_tol(mut self, tol: F) -> Self {
89        self.tol = tol;
90        self
91    }
92
93    /// Set whether to fit an intercept term.
94    #[must_use]
95    pub fn with_fit_intercept(mut self, fit_intercept: bool) -> Self {
96        self.fit_intercept = fit_intercept;
97        self
98    }
99}
100
101impl<F: Float> Default for Lasso<F> {
102    fn default() -> Self {
103        Self::new()
104    }
105}
106
/// Fitted Lasso regression model.
///
/// Stores the learned (potentially sparse) coefficients and intercept
/// produced by [`Lasso`]'s `fit`. Implements [`Predict`] and
/// [`HasCoefficients`].
#[derive(Debug, Clone)]
pub struct FittedLasso<F> {
    /// Learned coefficient vector (some entries may be exactly zero).
    coefficients: Array1<F>,
    /// Learned intercept (bias) term; zero when fitted without an intercept.
    intercept: F,
}
118
119/// Soft-thresholding operator for L1 penalty.
120///
121/// Returns `sign(x) * max(|x| - threshold, 0)`.
122fn soft_threshold<F: Float>(x: F, threshold: F) -> F {
123    if x > threshold {
124        x - threshold
125    } else if x < -threshold {
126        x + threshold
127    } else {
128        F::zero()
129    }
130}
131
132impl<F: Float + Send + Sync + ScalarOperand + FromPrimitive + 'static> Fit<Array2<F>, Array1<F>>
133    for Lasso<F>
134{
135    type Fitted = FittedLasso<F>;
136    type Error = FerroError;
137
138    /// Fit the Lasso model using coordinate descent.
139    ///
140    /// # Errors
141    ///
142    /// Returns [`FerroError::ShapeMismatch`] if the number of samples in
143    /// `x` and `y` differ.
144    /// Returns [`FerroError::InvalidParameter`] if `alpha` is negative.
145    /// Returns [`FerroError::ConvergenceFailure`] if the algorithm does
146    /// not converge within `max_iter` iterations.
147    fn fit(&self, x: &Array2<F>, y: &Array1<F>) -> Result<FittedLasso<F>, FerroError> {
148        let (n_samples, n_features) = x.dim();
149
150        if n_samples != y.len() {
151            return Err(FerroError::ShapeMismatch {
152                expected: vec![n_samples],
153                actual: vec![y.len()],
154                context: "y length must match number of samples in X".into(),
155            });
156        }
157
158        if self.alpha < F::zero() {
159            return Err(FerroError::InvalidParameter {
160                name: "alpha".into(),
161                reason: "must be non-negative".into(),
162            });
163        }
164
165        if n_samples == 0 {
166            return Err(FerroError::InsufficientSamples {
167                required: 1,
168                actual: 0,
169                context: "Lasso requires at least one sample".into(),
170            });
171        }
172
173        let n_f = F::from(n_samples).unwrap();
174
175        // Center data if fitting intercept.
176        let (x_work, y_work, x_mean, y_mean) = if self.fit_intercept {
177            let x_mean = x
178                .mean_axis(Axis(0))
179                .ok_or_else(|| FerroError::NumericalInstability {
180                    message: "failed to compute column means".into(),
181                })?;
182            let y_mean = y.mean().ok_or_else(|| FerroError::NumericalInstability {
183                message: "failed to compute target mean".into(),
184            })?;
185
186            let x_c = x - &x_mean;
187            let y_c = y - y_mean;
188            (x_c, y_c, Some(x_mean), Some(y_mean))
189        } else {
190            (x.clone(), y.clone(), None, None)
191        };
192
193        // Precompute column norms (X_j^T X_j / n).
194        let col_norms: Vec<F> = (0..n_features)
195            .map(|j| {
196                let col = x_work.column(j);
197                col.dot(&col) / n_f
198            })
199            .collect();
200
201        // Initialize coefficients to zero.
202        let mut w = Array1::<F>::zeros(n_features);
203        let mut residual = y_work;
204
205        for _iter in 0..self.max_iter {
206            let mut max_change = F::zero();
207
208            for j in 0..n_features {
209                let col_j = x_work.column(j);
210
211                // Compute partial residual: r + X_j * w_j
212                let w_old = w[j];
213                if w_old != F::zero() {
214                    for i in 0..n_samples {
215                        residual[i] = residual[i] + col_j[i] * w_old;
216                    }
217                }
218
219                // Compute the unpenalized update: X_j^T r / n.
220                let rho = col_j.dot(&residual) / n_f;
221
222                // Apply soft-thresholding.
223                let w_new = if col_norms[j] > F::zero() {
224                    soft_threshold(rho, self.alpha) / col_norms[j]
225                } else {
226                    F::zero()
227                };
228
229                // Update residual: r = r - X_j * w_new.
230                if w_new != F::zero() {
231                    for i in 0..n_samples {
232                        residual[i] = residual[i] - col_j[i] * w_new;
233                    }
234                }
235
236                let change = (w_new - w_old).abs();
237                if change > max_change {
238                    max_change = change;
239                }
240
241                w[j] = w_new;
242            }
243
244            // Check convergence.
245            if max_change < self.tol {
246                let intercept = if let (Some(xm), Some(ym)) = (&x_mean, &y_mean) {
247                    *ym - xm.dot(&w)
248                } else {
249                    F::zero()
250                };
251
252                return Ok(FittedLasso {
253                    coefficients: w,
254                    intercept,
255                });
256            }
257        }
258
259        // Did not converge, but still return the current solution.
260        let intercept = if let (Some(xm), Some(ym)) = (&x_mean, &y_mean) {
261            *ym - xm.dot(&w)
262        } else {
263            F::zero()
264        };
265
266        Ok(FittedLasso {
267            coefficients: w,
268            intercept,
269        })
270    }
271}
272
273impl<F: Float + Send + Sync + ScalarOperand + 'static> Predict<Array2<F>> for FittedLasso<F> {
274    type Output = Array1<F>;
275    type Error = FerroError;
276
277    /// Predict target values for the given feature matrix.
278    ///
279    /// Computes `X @ coefficients + intercept`.
280    ///
281    /// # Errors
282    ///
283    /// Returns [`FerroError::ShapeMismatch`] if the number of features
284    /// does not match the fitted model.
285    fn predict(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
286        let n_features = x.ncols();
287        if n_features != self.coefficients.len() {
288            return Err(FerroError::ShapeMismatch {
289                expected: vec![self.coefficients.len()],
290                actual: vec![n_features],
291                context: "number of features must match fitted model".into(),
292            });
293        }
294
295        let preds = x.dot(&self.coefficients) + self.intercept;
296        Ok(preds)
297    }
298}
299
impl<F: Float + Send + Sync + ScalarOperand + 'static> HasCoefficients<F> for FittedLasso<F> {
    /// Borrow the learned coefficient vector.
    fn coefficients(&self) -> &Array1<F> {
        &self.coefficients
    }

    /// Return the learned intercept (zero when fitted without an intercept).
    fn intercept(&self) -> F {
        self.intercept
    }
}
309
310// Pipeline integration.
311impl<F> PipelineEstimator<F> for Lasso<F>
312where
313    F: Float + FromPrimitive + ScalarOperand + Send + Sync + 'static,
314{
315    fn fit_pipeline(
316        &self,
317        x: &Array2<F>,
318        y: &Array1<F>,
319    ) -> Result<Box<dyn FittedPipelineEstimator<F>>, FerroError> {
320        let fitted = self.fit(x, y)?;
321        Ok(Box::new(fitted))
322    }
323}
324
impl<F> FittedPipelineEstimator<F> for FittedLasso<F>
where
    F: Float + ScalarOperand + Send + Sync + 'static,
{
    /// Delegates directly to [`Predict::predict`].
    fn predict_pipeline(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
        self.predict(x)
    }
}
333
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use ndarray::array;

    // The soft-thresholding operator shrinks toward zero by the threshold
    // and zeroes anything within [-threshold, threshold].
    #[test]
    fn test_soft_threshold() {
        assert_relative_eq!(soft_threshold(5.0_f64, 1.0), 4.0);
        assert_relative_eq!(soft_threshold(-5.0_f64, 1.0), -4.0);
        assert_relative_eq!(soft_threshold(0.5_f64, 1.0), 0.0);
        assert_relative_eq!(soft_threshold(-0.5_f64, 1.0), 0.0);
        assert_relative_eq!(soft_threshold(0.0_f64, 1.0), 0.0);
    }

    #[test]
    fn test_lasso_zero_alpha() {
        // With alpha=0, Lasso should behave like OLS.
        // Data follows y = 2x + 1 exactly.
        let x = Array2::from_shape_vec((5, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0]).unwrap();
        let y = array![3.0, 5.0, 7.0, 9.0, 11.0];

        let model = Lasso::<f64>::new().with_alpha(0.0);
        let fitted = model.fit(&x, &y).unwrap();

        assert_relative_eq!(fitted.coefficients()[0], 2.0, epsilon = 1e-4);
        assert_relative_eq!(fitted.intercept(), 1.0, epsilon = 1e-4);
    }

    #[test]
    fn test_lasso_sparsity() {
        // With high alpha, most coefficients should be zero.
        // Columns 1 and 2 are all zeros (uninformative features).
        let x = Array2::from_shape_vec(
            (10, 3),
            vec![
                1.0, 0.0, 0.0, 2.0, 0.0, 0.0, 3.0, 0.0, 0.0, 4.0, 0.0, 0.0, 5.0, 0.0, 0.0, 6.0,
                0.0, 0.0, 7.0, 0.0, 0.0, 8.0, 0.0, 0.0, 9.0, 0.0, 0.0, 10.0, 0.0, 0.0,
            ],
        )
        .unwrap();
        let y = array![2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0, 18.0, 20.0];

        let model = Lasso::<f64>::new().with_alpha(5.0);
        let fitted = model.fit(&x, &y).unwrap();

        // Irrelevant features should have zero coefficients.
        assert_relative_eq!(fitted.coefficients()[1], 0.0, epsilon = 1e-10);
        assert_relative_eq!(fitted.coefficients()[2], 0.0, epsilon = 1e-10);
    }

    // A larger L1 penalty should never produce a larger coefficient
    // magnitude than a smaller one on the same data.
    #[test]
    fn test_lasso_shrinks_coefficients() {
        let x = Array2::from_shape_vec((5, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0]).unwrap();
        let y = array![3.0, 5.0, 7.0, 9.0, 11.0];

        let model_low = Lasso::<f64>::new().with_alpha(0.01);
        let model_high = Lasso::<f64>::new().with_alpha(1.0);

        let fitted_low = model_low.fit(&x, &y).unwrap();
        let fitted_high = model_high.fit(&x, &y).unwrap();

        assert!(fitted_high.coefficients()[0].abs() <= fitted_low.coefficients()[0].abs());
    }

    // With fit_intercept=false the intercept must be exactly zero; the data
    // follows y = 2x through the origin so the slope is still recovered.
    #[test]
    fn test_lasso_no_intercept() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![2.0, 4.0, 6.0, 8.0];

        let model = Lasso::<f64>::new()
            .with_alpha(0.0)
            .with_fit_intercept(false);
        let fitted = model.fit(&x, &y).unwrap();

        assert_relative_eq!(fitted.coefficients()[0], 2.0, epsilon = 1e-4);
        assert_relative_eq!(fitted.intercept(), 0.0, epsilon = 1e-10);
    }

    // Negative alpha is rejected as an invalid parameter.
    #[test]
    fn test_lasso_negative_alpha() {
        let x = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        let y = array![1.0, 2.0, 3.0];

        let model = Lasso::<f64>::new().with_alpha(-1.0);
        let result = model.fit(&x, &y);
        assert!(result.is_err());
    }

    // Mismatched sample counts between x and y are rejected.
    #[test]
    fn test_lasso_shape_mismatch() {
        let x = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        let y = array![1.0, 2.0];

        let model = Lasso::<f64>::new();
        let result = model.fit(&x, &y);
        assert!(result.is_err());
    }

    // predict returns one value per input row.
    #[test]
    fn test_lasso_predict() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![2.0, 4.0, 6.0, 8.0];

        let model = Lasso::<f64>::new().with_alpha(0.01);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 4);
    }

    // The boxed pipeline path (fit_pipeline/predict_pipeline) round-trips.
    #[test]
    fn test_lasso_pipeline_integration() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![3.0, 5.0, 7.0, 9.0];

        let model = Lasso::<f64>::new().with_alpha(0.01);
        let fitted = model.fit_pipeline(&x, &y).unwrap();
        let preds = fitted.predict_pipeline(&x).unwrap();
        assert_eq!(preds.len(), 4);
    }

    // The coefficient vector length matches the number of features.
    #[test]
    fn test_lasso_has_coefficients() {
        let x = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![1.0, 2.0, 3.0];

        let model = Lasso::<f64>::new().with_alpha(0.1);
        let fitted = model.fit(&x, &y).unwrap();

        assert_eq!(fitted.coefficients().len(), 2);
    }
}