Skip to main content

ferrolearn_linear/
lars.rs

1//! Least Angle Regression (LARS) and Lasso-LARS.
2//!
3//! This module provides two estimators:
4//!
5//! - **[`Lars`]** — Least Angle Regression, a forward stepwise method that
6//!   builds sparse linear models by iteratively adding the feature most
7//!   correlated with the current residual.
8//! - **[`LassoLars`]** — A variant that enforces the Lasso (L1) constraint
9//!   by removing features from the active set when their coefficients cross
10//!   zero.
11//!
//! Both estimators use a simplified forward-selection approach (the active
//! set is refit by OLS after every addition):
13//!
14//! 1. Find the feature most correlated with the residual.
15//! 2. Add it to the active set.
16//! 3. Solve OLS on the active features.
17//! 4. Update the residual.
18//! 5. Repeat until the desired number of non-zero coefficients is reached
19//!    (LARS) or convergence (LassoLars).
20//!
21//! # Examples
22//!
23//! ```
24//! use ferrolearn_linear::Lars;
25//! use ferrolearn_core::{Fit, Predict};
26//! use ndarray::{array, Array1, Array2};
27//!
28//! let x = Array2::from_shape_vec((5, 2), vec![
29//!     1.0, 0.0, 2.0, 0.1, 3.0, 0.2, 4.0, 0.3, 5.0, 0.4,
30//! ]).unwrap();
31//! let y = array![2.0, 4.0, 6.0, 8.0, 10.0];
32//!
33//! let model = Lars::<f64>::new().with_n_nonzero_coefs(1);
34//! let fitted = model.fit(&x, &y).unwrap();
35//! let preds = fitted.predict(&x).unwrap();
36//! assert_eq!(preds.len(), 5);
37//! ```
38
39use ferrolearn_core::error::FerroError;
40use ferrolearn_core::introspection::HasCoefficients;
41use ferrolearn_core::pipeline::{FittedPipelineEstimator, PipelineEstimator};
42use ferrolearn_core::traits::{Fit, Predict};
43use ndarray::{Array1, Array2, Axis, ScalarOperand};
44use num_traits::{Float, FromPrimitive};
45
46// ---------------------------------------------------------------------------
47// LARS
48// ---------------------------------------------------------------------------
49
/// Configuration for the Least Angle Regression (LARS) estimator.
///
/// Fitting grows an active set one feature at a time — always the inactive
/// feature whose correlation with the current residual is largest — and
/// re-solves ordinary least squares on that set after every addition.
///
/// # Type Parameters
///
/// - `F`: Floating-point scalar type (`f32` or `f64`).
#[derive(Debug, Clone)]
pub struct Lars<F> {
    /// Upper bound on the number of non-zero coefficients. `None` (the
    /// default) lets every feature enter the model.
    pub n_nonzero_coefs: Option<usize>,
    /// When `true`, an intercept (bias) term is fitted.
    pub fit_intercept: bool,
    /// Zero-sized marker that ties the estimator to its scalar type `F`.
    _marker: core::marker::PhantomData<F>,
}
68
69impl<F: Float> Lars<F> {
70    /// Create a new `Lars` with default settings.
71    ///
72    /// Defaults: `n_nonzero_coefs = None`, `fit_intercept = true`.
73    #[must_use]
74    pub fn new() -> Self {
75        Self {
76            n_nonzero_coefs: None,
77            fit_intercept: true,
78            _marker: core::marker::PhantomData,
79        }
80    }
81
82    /// Set the maximum number of non-zero coefficients.
83    #[must_use]
84    pub fn with_n_nonzero_coefs(mut self, n: usize) -> Self {
85        self.n_nonzero_coefs = Some(n);
86        self
87    }
88
89    /// Set whether to fit an intercept term.
90    #[must_use]
91    pub fn with_fit_intercept(mut self, fit_intercept: bool) -> Self {
92        self.fit_intercept = fit_intercept;
93        self
94    }
95}
96
97impl<F: Float> Default for Lars<F> {
98    fn default() -> Self {
99        Self::new()
100    }
101}
102
103/// Fitted LARS model.
104///
105/// Stores the learned (sparse) coefficients and intercept. Implements
106/// [`Predict`] and [`HasCoefficients`].
107#[derive(Debug, Clone)]
108pub struct FittedLars<F> {
109    /// Learned coefficient vector (many entries may be zero).
110    coefficients: Array1<F>,
111    /// Learned intercept (bias) term.
112    intercept: F,
113}
114
115// ---------------------------------------------------------------------------
116// LassoLars
117// ---------------------------------------------------------------------------
118
/// Configuration for the Lasso-LARS estimator (LARS with a Lasso rule).
///
/// Works like [`Lars`], except that a feature is removed from the active
/// set when the sign of its coefficient flips during an OLS refit, and the
/// forward search stops once the residual correlation falls to `alpha`.
///
/// # Type Parameters
///
/// - `F`: Floating-point scalar type (`f32` or `f64`).
#[derive(Debug, Clone)]
pub struct LassoLars<F> {
    /// L1 regularization strength; larger values give sparser models.
    pub alpha: F,
    /// Upper bound on the number of forward steps.
    pub max_iter: usize,
    /// When `true`, an intercept (bias) term is fitted.
    pub fit_intercept: bool,
}
137
138impl<F: Float> LassoLars<F> {
139    /// Create a new `LassoLars` with default settings.
140    ///
141    /// Defaults: `alpha = 1.0`, `max_iter = 500`, `fit_intercept = true`.
142    #[must_use]
143    pub fn new() -> Self {
144        Self {
145            alpha: F::one(),
146            max_iter: 500,
147            fit_intercept: true,
148        }
149    }
150
151    /// Set the L1 regularization strength.
152    #[must_use]
153    pub fn with_alpha(mut self, alpha: F) -> Self {
154        self.alpha = alpha;
155        self
156    }
157
158    /// Set the maximum number of forward steps.
159    #[must_use]
160    pub fn with_max_iter(mut self, max_iter: usize) -> Self {
161        self.max_iter = max_iter;
162        self
163    }
164
165    /// Set whether to fit an intercept term.
166    #[must_use]
167    pub fn with_fit_intercept(mut self, fit_intercept: bool) -> Self {
168        self.fit_intercept = fit_intercept;
169        self
170    }
171}
172
173impl<F: Float> Default for LassoLars<F> {
174    fn default() -> Self {
175        Self::new()
176    }
177}
178
179/// Fitted Lasso-LARS model.
180///
181/// Stores the learned (sparse) coefficients and intercept.
182#[derive(Debug, Clone)]
183pub struct FittedLassoLars<F> {
184    /// Learned coefficient vector.
185    coefficients: Array1<F>,
186    /// Learned intercept (bias) term.
187    intercept: F,
188}
189
190// ---------------------------------------------------------------------------
191// Internal helpers
192// ---------------------------------------------------------------------------
193
194/// Solve the OLS sub-problem on the active columns of `x` for target `y`.
195///
196/// Returns the full-length coefficient vector (inactive entries = 0).
197fn ols_active<F: Float + FromPrimitive + 'static>(
198    x: &Array2<F>,
199    y: &Array1<F>,
200    active: &[usize],
201    n_features: usize,
202) -> Result<Array1<F>, FerroError> {
203    let n_samples = x.nrows();
204    let k = active.len();
205
206    // Build X_active  (n_samples x k).
207    let mut xa = Array2::<F>::zeros((n_samples, k));
208    for (col_idx, &j) in active.iter().enumerate() {
209        for i in 0..n_samples {
210            xa[[i, col_idx]] = x[[i, j]];
211        }
212    }
213
214    // Solve (Xa^T Xa) w_active = Xa^T y  via Cholesky / Gauss fallback.
215    let xat = xa.t();
216    let xtx = xat.dot(&xa);
217    let xty = xat.dot(y);
218
219    let w_active = cholesky_solve(&xtx, &xty)
220        .or_else(|_| gaussian_solve(k, &xtx, &xty))?;
221
222    // Scatter into full-length vector.
223    let mut w = Array1::<F>::zeros(n_features);
224    for (col_idx, &j) in active.iter().enumerate() {
225        w[j] = w_active[col_idx];
226    }
227    Ok(w)
228}
229
230/// Cholesky solve for `A x = b`.
231fn cholesky_solve<F: Float>(a: &Array2<F>, b: &Array1<F>) -> Result<Array1<F>, FerroError> {
232    let n = a.nrows();
233    let mut l = Array2::<F>::zeros((n, n));
234
235    for i in 0..n {
236        for j in 0..=i {
237            let mut s = a[[i, j]];
238            for k in 0..j {
239                s = s - l[[i, k]] * l[[j, k]];
240            }
241            if i == j {
242                if s <= F::zero() {
243                    return Err(FerroError::NumericalInstability {
244                        message: "Cholesky: matrix not positive definite".into(),
245                    });
246                }
247                l[[i, j]] = s.sqrt();
248            } else {
249                l[[i, j]] = s / l[[j, j]];
250            }
251        }
252    }
253
254    let mut z = Array1::<F>::zeros(n);
255    for i in 0..n {
256        let mut s = b[i];
257        for k in 0..i {
258            s = s - l[[i, k]] * z[k];
259        }
260        z[i] = s / l[[i, i]];
261    }
262
263    let mut x_sol = Array1::<F>::zeros(n);
264    for i in (0..n).rev() {
265        let mut s = z[i];
266        for k in (i + 1)..n {
267            s = s - l[[k, i]] * x_sol[k];
268        }
269        x_sol[i] = s / l[[i, i]];
270    }
271
272    Ok(x_sol)
273}
274
275/// Gaussian elimination with partial pivoting.
276fn gaussian_solve<F: Float>(
277    n: usize,
278    a: &Array2<F>,
279    b: &Array1<F>,
280) -> Result<Array1<F>, FerroError> {
281    let mut aug = Array2::<F>::zeros((n, n + 1));
282    for i in 0..n {
283        for j in 0..n {
284            aug[[i, j]] = a[[i, j]];
285        }
286        aug[[i, n]] = b[i];
287    }
288
289    for col in 0..n {
290        let mut max_val = aug[[col, col]].abs();
291        let mut max_row = col;
292        for row in (col + 1)..n {
293            let v = aug[[row, col]].abs();
294            if v > max_val {
295                max_val = v;
296                max_row = row;
297            }
298        }
299
300        if max_val < F::from(1e-12).unwrap_or_else(F::epsilon) {
301            return Err(FerroError::NumericalInstability {
302                message: "singular matrix in Gaussian elimination".into(),
303            });
304        }
305
306        if max_row != col {
307            for j in 0..=n {
308                let tmp = aug[[col, j]];
309                aug[[col, j]] = aug[[max_row, j]];
310                aug[[max_row, j]] = tmp;
311            }
312        }
313
314        let pivot = aug[[col, col]];
315        for row in (col + 1)..n {
316            let factor = aug[[row, col]] / pivot;
317            for j in col..=n {
318                let above = aug[[col, j]];
319                aug[[row, j]] = aug[[row, j]] - factor * above;
320            }
321        }
322    }
323
324    let mut x_sol = Array1::<F>::zeros(n);
325    for i in (0..n).rev() {
326        let mut s = aug[[i, n]];
327        for j in (i + 1)..n {
328            s = s - aug[[i, j]] * x_sol[j];
329        }
330        if aug[[i, i]].abs() < F::from(1e-12).unwrap_or_else(F::epsilon) {
331            return Err(FerroError::NumericalInstability {
332                message: "near-zero pivot in back substitution".into(),
333            });
334        }
335        x_sol[i] = s / aug[[i, i]];
336    }
337
338    Ok(x_sol)
339}
340
341/// Centred data: `(x_centred, y_centred, x_mean, y_mean)`.
342type CentredData<F> = (Array2<F>, Array1<F>, Option<Array1<F>>, Option<F>);
343
344/// Center `x` and `y` for intercept fitting, returning centred arrays and means.
345fn center_data<F: Float + FromPrimitive + ScalarOperand + 'static>(
346    x: &Array2<F>,
347    y: &Array1<F>,
348    fit_intercept: bool,
349) -> Result<CentredData<F>, FerroError> {
350    if fit_intercept {
351        let x_mean = x
352            .mean_axis(Axis(0))
353            .ok_or_else(|| FerroError::NumericalInstability {
354                message: "failed to compute column means".into(),
355            })?;
356        let y_mean = y.mean().ok_or_else(|| FerroError::NumericalInstability {
357            message: "failed to compute target mean".into(),
358        })?;
359        let x_c = x - &x_mean;
360        let y_c = y - y_mean;
361        Ok((x_c, y_c, Some(x_mean), Some(y_mean)))
362    } else {
363        Ok((x.clone(), y.clone(), None, None))
364    }
365}
366
367/// Compute the intercept from centred means and coefficients.
368fn compute_intercept<F: Float + 'static>(
369    x_mean: &Option<Array1<F>>,
370    y_mean: &Option<F>,
371    w: &Array1<F>,
372) -> F {
373    if let (Some(xm), Some(ym)) = (x_mean, y_mean) {
374        *ym - xm.dot(w)
375    } else {
376        F::zero()
377    }
378}
379
380/// Common input validation for LARS / LassoLars.
381fn validate_input<F: Float>(
382    x: &Array2<F>,
383    y: &Array1<F>,
384    name: &str,
385) -> Result<(usize, usize), FerroError> {
386    let (n_samples, n_features) = x.dim();
387
388    if n_samples != y.len() {
389        return Err(FerroError::ShapeMismatch {
390            expected: vec![n_samples],
391            actual: vec![y.len()],
392            context: "y length must match number of samples in X".into(),
393        });
394    }
395
396    if n_samples == 0 {
397        return Err(FerroError::InsufficientSamples {
398            required: 1,
399            actual: 0,
400            context: format!("{name} requires at least one sample"),
401        });
402    }
403
404    Ok((n_samples, n_features))
405}
406
407// ---------------------------------------------------------------------------
408// Fit — Lars
409// ---------------------------------------------------------------------------
410
411impl<F: Float + Send + Sync + ScalarOperand + FromPrimitive + 'static> Fit<Array2<F>, Array1<F>>
412    for Lars<F>
413{
414    type Fitted = FittedLars<F>;
415    type Error = FerroError;
416
417    /// Fit the LARS model.
418    ///
419    /// Iteratively adds the feature most correlated with the residual to the
420    /// active set and solves OLS on that subset.
421    ///
422    /// # Errors
423    ///
424    /// - [`FerroError::ShapeMismatch`] — sample count mismatch.
425    /// - [`FerroError::InsufficientSamples`] — zero samples.
426    /// - [`FerroError::InvalidParameter`] — `n_nonzero_coefs` exceeds feature count.
427    fn fit(&self, x: &Array2<F>, y: &Array1<F>) -> Result<FittedLars<F>, FerroError> {
428        let (_n_samples, n_features) = validate_input(x, y, "Lars")?;
429
430        let max_active = self.n_nonzero_coefs.unwrap_or(n_features);
431        if max_active > n_features {
432            return Err(FerroError::InvalidParameter {
433                name: "n_nonzero_coefs".into(),
434                reason: format!(
435                    "cannot exceed number of features ({n_features})"
436                ),
437            });
438        }
439
440        let (x_work, y_work, x_mean, y_mean) =
441            center_data(x, y, self.fit_intercept)?;
442
443        let mut active: Vec<usize> = Vec::with_capacity(max_active);
444        let mut in_active = vec![false; n_features];
445        let mut w = Array1::<F>::zeros(n_features);
446        let mut residual = y_work.clone();
447
448        for _step in 0..max_active {
449            // Find feature most correlated with residual (not already active).
450            let mut best_j = None;
451            let mut best_corr = F::zero();
452            for (j, &is_active) in in_active.iter().enumerate() {
453                if is_active {
454                    continue;
455                }
456                let corr = x_work.column(j).dot(&residual).abs();
457                if corr > best_corr {
458                    best_corr = corr;
459                    best_j = Some(j);
460                }
461            }
462
463            let j = match best_j {
464                Some(j) => j,
465                None => break, // all features active
466            };
467
468            active.push(j);
469            in_active[j] = true;
470
471            // OLS on active set.
472            w = ols_active(&x_work, &y_work, &active, n_features)?;
473
474            // Update residual.
475            residual = &y_work - x_work.dot(&w);
476        }
477
478        let intercept = compute_intercept(&x_mean, &y_mean, &w);
479
480        Ok(FittedLars {
481            coefficients: w,
482            intercept,
483        })
484    }
485}
486
487// ---------------------------------------------------------------------------
488// Fit — LassoLars
489// ---------------------------------------------------------------------------
490
491impl<F: Float + Send + Sync + ScalarOperand + FromPrimitive + 'static> Fit<Array2<F>, Array1<F>>
492    for LassoLars<F>
493{
494    type Fitted = FittedLassoLars<F>;
495    type Error = FerroError;
496
497    /// Fit the Lasso-LARS model.
498    ///
499    /// Like LARS, but features whose coefficients cross zero during the OLS
500    /// step are removed from the active set, enforcing an implicit L1
501    /// penalty. The iteration stops when the maximum absolute correlation
502    /// with the residual drops below `alpha`.
503    ///
504    /// # Errors
505    ///
506    /// - [`FerroError::ShapeMismatch`] — sample count mismatch.
507    /// - [`FerroError::InsufficientSamples`] — zero samples.
508    /// - [`FerroError::InvalidParameter`] — `alpha` is negative.
509    fn fit(&self, x: &Array2<F>, y: &Array1<F>) -> Result<FittedLassoLars<F>, FerroError> {
510        let (n_samples, n_features) = validate_input(x, y, "LassoLars")?;
511
512        if self.alpha < F::zero() {
513            return Err(FerroError::InvalidParameter {
514                name: "alpha".into(),
515                reason: "must be non-negative".into(),
516            });
517        }
518
519        let n_f = F::from(n_samples).unwrap();
520        let (x_work, y_work, x_mean, y_mean) =
521            center_data(x, y, self.fit_intercept)?;
522
523        let mut active: Vec<usize> = Vec::new();
524        let mut in_active = vec![false; n_features];
525        let mut w = Array1::<F>::zeros(n_features);
526        let mut residual = y_work.clone();
527
528        for _step in 0..self.max_iter {
529            // Check stopping criterion: max |X^T r| / n <= alpha.
530            let mut best_j = None;
531            let mut best_corr = F::zero();
532            for (j, &is_active) in in_active.iter().enumerate() {
533                if is_active {
534                    continue;
535                }
536                let corr = x_work.column(j).dot(&residual).abs() / n_f;
537                if corr > best_corr {
538                    best_corr = corr;
539                    best_j = Some(j);
540                }
541            }
542
543            // If maximum correlation is below alpha, stop.
544            if best_corr <= self.alpha && !active.is_empty() {
545                break;
546            }
547
548            // Add best feature (if any remain).
549            if let Some(j) = best_j {
550                active.push(j);
551                in_active[j] = true;
552            } else {
553                break;
554            }
555
556            // OLS on active set.
557            let w_new = ols_active(&x_work, &y_work, &active, n_features)?;
558
559            // Drop features that crossed zero (Lasso modification).
560            let mut dropped = false;
561            for idx in (0..active.len()).rev() {
562                let feat = active[idx];
563                // A sign change (or zero) means it crossed zero.
564                if w[feat] != F::zero()
565                    && w_new[feat].signum() != w[feat].signum()
566                {
567                    active.remove(idx);
568                    in_active[feat] = false;
569                    dropped = true;
570                }
571            }
572
573            if dropped && !active.is_empty() {
574                // Re-solve OLS without the dropped features.
575                w = ols_active(&x_work, &y_work, &active, n_features)?;
576            } else {
577                w = w_new;
578            }
579
580            // Update residual.
581            residual = &y_work - x_work.dot(&w);
582        }
583
584        let intercept = compute_intercept(&x_mean, &y_mean, &w);
585
586        Ok(FittedLassoLars {
587            coefficients: w,
588            intercept,
589        })
590    }
591}
592
593// ---------------------------------------------------------------------------
594// Predict / HasCoefficients / Pipeline — FittedLars
595// ---------------------------------------------------------------------------
596
597impl<F: Float + Send + Sync + ScalarOperand + 'static> Predict<Array2<F>> for FittedLars<F> {
598    type Output = Array1<F>;
599    type Error = FerroError;
600
601    /// Predict target values for the given feature matrix.
602    ///
603    /// Computes `X @ coefficients + intercept`.
604    ///
605    /// # Errors
606    ///
607    /// Returns [`FerroError::ShapeMismatch`] if the number of features
608    /// does not match the fitted model.
609    fn predict(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
610        if x.ncols() != self.coefficients.len() {
611            return Err(FerroError::ShapeMismatch {
612                expected: vec![self.coefficients.len()],
613                actual: vec![x.ncols()],
614                context: "number of features must match fitted model".into(),
615            });
616        }
617        Ok(x.dot(&self.coefficients) + self.intercept)
618    }
619}
620
621impl<F: Float + Send + Sync + ScalarOperand + 'static> HasCoefficients<F> for FittedLars<F> {
622    fn coefficients(&self) -> &Array1<F> {
623        &self.coefficients
624    }
625
626    fn intercept(&self) -> F {
627        self.intercept
628    }
629}
630
631impl<F> PipelineEstimator<F> for Lars<F>
632where
633    F: Float + FromPrimitive + ScalarOperand + Send + Sync + 'static,
634{
635    fn fit_pipeline(
636        &self,
637        x: &Array2<F>,
638        y: &Array1<F>,
639    ) -> Result<Box<dyn FittedPipelineEstimator<F>>, FerroError> {
640        let fitted = self.fit(x, y)?;
641        Ok(Box::new(fitted))
642    }
643}
644
645impl<F> FittedPipelineEstimator<F> for FittedLars<F>
646where
647    F: Float + ScalarOperand + Send + Sync + 'static,
648{
649    fn predict_pipeline(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
650        self.predict(x)
651    }
652}
653
654// ---------------------------------------------------------------------------
655// Predict / HasCoefficients / Pipeline — FittedLassoLars
656// ---------------------------------------------------------------------------
657
658impl<F: Float + Send + Sync + ScalarOperand + 'static> Predict<Array2<F>> for FittedLassoLars<F> {
659    type Output = Array1<F>;
660    type Error = FerroError;
661
662    /// Predict target values for the given feature matrix.
663    ///
664    /// Computes `X @ coefficients + intercept`.
665    ///
666    /// # Errors
667    ///
668    /// Returns [`FerroError::ShapeMismatch`] if the number of features
669    /// does not match the fitted model.
670    fn predict(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
671        if x.ncols() != self.coefficients.len() {
672            return Err(FerroError::ShapeMismatch {
673                expected: vec![self.coefficients.len()],
674                actual: vec![x.ncols()],
675                context: "number of features must match fitted model".into(),
676            });
677        }
678        Ok(x.dot(&self.coefficients) + self.intercept)
679    }
680}
681
682impl<F: Float + Send + Sync + ScalarOperand + 'static> HasCoefficients<F> for FittedLassoLars<F> {
683    fn coefficients(&self) -> &Array1<F> {
684        &self.coefficients
685    }
686
687    fn intercept(&self) -> F {
688        self.intercept
689    }
690}
691
692impl<F> PipelineEstimator<F> for LassoLars<F>
693where
694    F: Float + FromPrimitive + ScalarOperand + Send + Sync + 'static,
695{
696    fn fit_pipeline(
697        &self,
698        x: &Array2<F>,
699        y: &Array1<F>,
700    ) -> Result<Box<dyn FittedPipelineEstimator<F>>, FerroError> {
701        let fitted = self.fit(x, y)?;
702        Ok(Box::new(fitted))
703    }
704}
705
706impl<F> FittedPipelineEstimator<F> for FittedLassoLars<F>
707where
708    F: Float + ScalarOperand + Send + Sync + 'static,
709{
710    fn predict_pipeline(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
711        self.predict(x)
712    }
713}
714
715// ---------------------------------------------------------------------------
716// Tests
717// ---------------------------------------------------------------------------
718
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use ndarray::array;

    /// Build a `rows x cols` matrix from row-major data.
    fn matrix(rows: usize, cols: usize, data: &[f64]) -> Array2<f64> {
        Array2::from_shape_vec((rows, cols), data.to_vec()).unwrap()
    }

    /// Build a single-feature design matrix from its column values.
    fn single_feature(values: &[f64]) -> Array2<f64> {
        matrix(values.len(), 1, values)
    }

    // ---- Lars ----

    /// `Lars::new` has no coefficient cap and fits an intercept.
    #[test]
    fn test_lars_defaults() {
        let model = Lars::<f64>::new();
        assert!(model.n_nonzero_coefs.is_none());
        assert!(model.fit_intercept);
    }

    /// Builder methods overwrite the corresponding fields.
    #[test]
    fn test_lars_builder() {
        let model = Lars::<f64>::new()
            .with_n_nonzero_coefs(3)
            .with_fit_intercept(false);
        assert_eq!(model.n_nonzero_coefs, Some(3));
        assert!(!model.fit_intercept);
    }

    /// Exact recovery of `y = 2x + 1` with a single feature.
    #[test]
    fn test_lars_simple_linear() {
        let features = single_feature(&[1.0, 2.0, 3.0, 4.0, 5.0]);
        let targets = array![3.0, 5.0, 7.0, 9.0, 11.0];

        let fitted = Lars::<f64>::new().fit(&features, &targets).unwrap();
        assert_relative_eq!(fitted.coefficients()[0], 2.0, epsilon = 1e-6);
        assert_relative_eq!(fitted.intercept(), 1.0, epsilon = 1e-6);
    }

    /// With `n_nonzero_coefs = 1` exactly one coefficient is non-zero.
    #[test]
    fn test_lars_sparsity() {
        let features = matrix(
            10,
            3,
            &[
                1.0, 0.1, 0.01, 2.0, 0.2, 0.02, 3.0, 0.3, 0.03, 4.0, 0.4, 0.04,
                5.0, 0.5, 0.05, 6.0, 0.6, 0.06, 7.0, 0.7, 0.07, 8.0, 0.8, 0.08,
                9.0, 0.9, 0.09, 10.0, 1.0, 0.10,
            ],
        );
        let targets = array![2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0, 18.0, 20.0];

        let fitted = Lars::<f64>::new()
            .with_n_nonzero_coefs(1)
            .fit(&features, &targets)
            .unwrap();
        let active_count = fitted
            .coefficients()
            .iter()
            .filter(|&&c| c.abs() > 1e-10)
            .count();
        assert_eq!(active_count, 1);
    }

    /// Predictions have one entry per input row.
    #[test]
    fn test_lars_predict() {
        let features = single_feature(&[1.0, 2.0, 3.0, 4.0]);
        let targets = array![2.0, 4.0, 6.0, 8.0];

        let fitted = Lars::<f64>::new().fit(&features, &targets).unwrap();
        let predictions = fitted.predict(&features).unwrap();
        assert_eq!(predictions.len(), 4);
    }

    /// Mismatched `x` / `y` lengths are rejected at fit time.
    #[test]
    fn test_lars_shape_mismatch() {
        let features = single_feature(&[1.0, 2.0, 3.0]);
        let targets = array![1.0, 2.0];
        assert!(Lars::<f64>::new().fit(&features, &targets).is_err());
    }

    /// Predicting with the wrong number of columns is rejected.
    #[test]
    fn test_lars_predict_feature_mismatch() {
        let features = matrix(3, 2, &[1.0, 0.0, 2.0, 0.0, 3.0, 0.0]);
        let targets = array![1.0, 2.0, 3.0];
        let fitted = Lars::<f64>::new().fit(&features, &targets).unwrap();

        let wrong_width = single_feature(&[1.0, 2.0, 3.0]);
        assert!(fitted.predict(&wrong_width).is_err());
    }

    /// A coefficient cap larger than the feature count is rejected.
    #[test]
    fn test_lars_n_nonzero_exceeds_features() {
        let features = matrix(3, 2, &[1.0, 0.0, 2.0, 0.0, 3.0, 0.0]);
        let targets = array![1.0, 2.0, 3.0];
        let outcome = Lars::<f64>::new()
            .with_n_nonzero_coefs(5)
            .fit(&features, &targets);
        assert!(outcome.is_err());
    }

    /// Without intercept fitting the intercept is zero.
    #[test]
    fn test_lars_no_intercept() {
        let features = single_feature(&[1.0, 2.0, 3.0, 4.0]);
        let targets = array![2.0, 4.0, 6.0, 8.0];

        let fitted = Lars::<f64>::new()
            .with_fit_intercept(false)
            .fit(&features, &targets)
            .unwrap();
        assert_relative_eq!(fitted.intercept(), 0.0, epsilon = 1e-10);
    }

    /// The coefficient vector has one entry per feature.
    #[test]
    fn test_lars_has_coefficients() {
        let features = matrix(3, 2, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let targets = array![1.0, 2.0, 3.0];
        let fitted = Lars::<f64>::new().fit(&features, &targets).unwrap();
        assert_eq!(fitted.coefficients().len(), 2);
    }

    /// The pipeline adapters fit and predict end to end.
    #[test]
    fn test_lars_pipeline() {
        let features = single_feature(&[1.0, 2.0, 3.0, 4.0]);
        let targets = array![3.0, 5.0, 7.0, 9.0];

        let estimator = Lars::<f64>::new();
        let fitted = estimator.fit_pipeline(&features, &targets).unwrap();
        let predictions = fitted.predict_pipeline(&features).unwrap();
        assert_eq!(predictions.len(), 4);
    }

    // ---- LassoLars ----

    /// `LassoLars::new` uses alpha 1.0, 500 iterations and an intercept.
    #[test]
    fn test_lasso_lars_defaults() {
        let model = LassoLars::<f64>::new();
        assert_relative_eq!(model.alpha, 1.0);
        assert_eq!(model.max_iter, 500);
        assert!(model.fit_intercept);
    }

    /// Builder methods overwrite the corresponding fields.
    #[test]
    fn test_lasso_lars_builder() {
        let model = LassoLars::<f64>::new()
            .with_alpha(0.5)
            .with_max_iter(100)
            .with_fit_intercept(false);
        assert_relative_eq!(model.alpha, 0.5);
        assert_eq!(model.max_iter, 100);
        assert!(!model.fit_intercept);
    }

    /// With `alpha = 0` the slope of `y = 2x + 1` is recovered.
    #[test]
    fn test_lasso_lars_simple() {
        let features = single_feature(&[1.0, 2.0, 3.0, 4.0, 5.0]);
        let targets = array![3.0, 5.0, 7.0, 9.0, 11.0];

        let fitted = LassoLars::<f64>::new()
            .with_alpha(0.0)
            .fit(&features, &targets)
            .unwrap();
        assert_relative_eq!(fitted.coefficients()[0], 2.0, epsilon = 0.1);
    }

    /// All-zero (irrelevant) features never receive a coefficient.
    #[test]
    fn test_lasso_lars_sparsity() {
        let features = matrix(
            10,
            3,
            &[
                1.0, 0.0, 0.0, 2.0, 0.0, 0.0, 3.0, 0.0, 0.0, 4.0, 0.0, 0.0,
                5.0, 0.0, 0.0, 6.0, 0.0, 0.0, 7.0, 0.0, 0.0, 8.0, 0.0, 0.0,
                9.0, 0.0, 0.0, 10.0, 0.0, 0.0,
            ],
        );
        let targets = array![2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0, 18.0, 20.0];

        let fitted = LassoLars::<f64>::new()
            .with_alpha(5.0)
            .fit(&features, &targets)
            .unwrap();
        assert_relative_eq!(fitted.coefficients()[1], 0.0, epsilon = 1e-10);
        assert_relative_eq!(fitted.coefficients()[2], 0.0, epsilon = 1e-10);
    }

    /// A negative `alpha` is rejected at fit time.
    #[test]
    fn test_lasso_lars_negative_alpha() {
        let features = single_feature(&[1.0, 2.0, 3.0]);
        let targets = array![1.0, 2.0, 3.0];
        let outcome = LassoLars::<f64>::new()
            .with_alpha(-1.0)
            .fit(&features, &targets);
        assert!(outcome.is_err());
    }

    /// Mismatched `x` / `y` lengths are rejected at fit time.
    #[test]
    fn test_lasso_lars_shape_mismatch() {
        let features = single_feature(&[1.0, 2.0, 3.0]);
        let targets = array![1.0, 2.0];
        assert!(LassoLars::<f64>::new().fit(&features, &targets).is_err());
    }

    /// Predictions have one entry per input row.
    #[test]
    fn test_lasso_lars_predict() {
        let features = single_feature(&[1.0, 2.0, 3.0, 4.0]);
        let targets = array![2.0, 4.0, 6.0, 8.0];

        let fitted = LassoLars::<f64>::new()
            .with_alpha(0.01)
            .fit(&features, &targets)
            .unwrap();
        let predictions = fitted.predict(&features).unwrap();
        assert_eq!(predictions.len(), 4);
    }

    /// The coefficient vector has one entry per feature.
    #[test]
    fn test_lasso_lars_has_coefficients() {
        let features = matrix(3, 2, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let targets = array![1.0, 2.0, 3.0];

        let fitted = LassoLars::<f64>::new()
            .with_alpha(0.01)
            .fit(&features, &targets)
            .unwrap();
        assert_eq!(fitted.coefficients().len(), 2);
    }

    /// The pipeline adapters fit and predict end to end.
    #[test]
    fn test_lasso_lars_pipeline() {
        let features = single_feature(&[1.0, 2.0, 3.0, 4.0]);
        let targets = array![3.0, 5.0, 7.0, 9.0];

        let estimator = LassoLars::<f64>::new().with_alpha(0.01);
        let fitted = estimator.fit_pipeline(&features, &targets).unwrap();
        let predictions = fitted.predict_pipeline(&features).unwrap();
        assert_eq!(predictions.len(), 4);
    }
}
935}