Skip to main content

ferrolearn_linear/
elastic_net.rs

1//! ElasticNet regression (combined L1 and L2 regularization).
2//!
3//! This module provides [`ElasticNet`], which fits a linear model with a
4//! blended L1/L2 regularization penalty using coordinate descent with
5//! soft-thresholding:
6//!
7//! ```text
8//! minimize (1/(2n)) * ||X @ w - y||^2
9//!        + alpha * l1_ratio * ||w||_1
10//!        + (alpha/2) * (1 - l1_ratio) * ||w||_2^2
11//! ```
12//!
13//! When `l1_ratio = 1`, ElasticNet is equivalent to Lasso. When
14//! `l1_ratio = 0`, it is equivalent to Ridge. Intermediate values produce
15//! solutions that are both sparse (L1) and small in magnitude (L2).
16//!
17//! # Examples
18//!
19//! ```
20//! use ferrolearn_linear::ElasticNet;
21//! use ferrolearn_core::{Fit, Predict};
22//! use ndarray::{array, Array1, Array2};
23//!
24//! let model = ElasticNet::<f64>::new()
25//!     .with_alpha(0.1)
26//!     .with_l1_ratio(0.5);
27//! let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
28//! let y = array![2.0, 4.0, 6.0, 8.0];
29//!
30//! let fitted = model.fit(&x, &y).unwrap();
31//! let preds = fitted.predict(&x).unwrap();
32//! ```
33
34use ferrolearn_core::error::FerroError;
35use ferrolearn_core::introspection::HasCoefficients;
36use ferrolearn_core::pipeline::{FittedPipelineEstimator, PipelineEstimator};
37use ferrolearn_core::traits::{Fit, Predict};
38use ndarray::{Array1, Array2, Axis, ScalarOperand};
39use num_traits::{Float, FromPrimitive};
40
/// Linear regression with a mixed L1/L2 (ElasticNet) penalty.
///
/// This type is the unfitted configuration: it only stores hyperparameters.
/// Fitting it produces a [`FittedElasticNet`]. Because the L1 part of the
/// penalty is not differentiable at zero, the solver uses coordinate descent
/// with a soft-thresholding step.
///
/// # Type Parameters
///
/// - `F`: floating-point scalar type (`f32` or `f64`).
#[derive(Debug, Clone)]
pub struct ElasticNet<F> {
    /// Penalty strength; larger values regularize more heavily.
    pub alpha: F,
    /// Fraction of the penalty assigned to the L1 term.
    ///
    /// `1.0` yields a pure Lasso penalty, `0.0` a pure Ridge penalty, and
    /// values strictly between the two blend them.
    pub l1_ratio: F,
    /// Upper bound on the number of full coordinate-descent passes.
    pub max_iter: usize,
    /// Stopping threshold: fitting ends once the largest absolute coefficient
    /// change made during one pass drops below this value.
    pub tol: F,
    /// When `true`, an intercept (bias) term is estimated as well.
    pub fit_intercept: bool,
}
67
68impl<F: Float + FromPrimitive> ElasticNet<F> {
69    /// Create a new `ElasticNet` with default settings.
70    ///
71    /// Defaults: `alpha = 1.0`, `l1_ratio = 0.5`, `max_iter = 1000`,
72    /// `tol = 1e-4`, `fit_intercept = true`.
73    #[must_use]
74    pub fn new() -> Self {
75        Self {
76            alpha: F::one(),
77            l1_ratio: F::from(0.5).unwrap(),
78            max_iter: 1000,
79            tol: F::from(1e-4).unwrap(),
80            fit_intercept: true,
81        }
82    }
83
84    /// Set the overall regularization strength.
85    #[must_use]
86    pub fn with_alpha(mut self, alpha: F) -> Self {
87        self.alpha = alpha;
88        self
89    }
90
91    /// Set the L1/L2 mixing ratio.
92    ///
93    /// Must be in `[0.0, 1.0]`. Values outside this range will be rejected
94    /// at fit time.
95    #[must_use]
96    pub fn with_l1_ratio(mut self, l1_ratio: F) -> Self {
97        self.l1_ratio = l1_ratio;
98        self
99    }
100
101    /// Set the maximum number of coordinate descent iterations.
102    #[must_use]
103    pub fn with_max_iter(mut self, max_iter: usize) -> Self {
104        self.max_iter = max_iter;
105        self
106    }
107
108    /// Set the convergence tolerance on maximum coefficient change.
109    #[must_use]
110    pub fn with_tol(mut self, tol: F) -> Self {
111        self.tol = tol;
112        self
113    }
114
115    /// Set whether to fit an intercept term.
116    #[must_use]
117    pub fn with_fit_intercept(mut self, fit_intercept: bool) -> Self {
118        self.fit_intercept = fit_intercept;
119        self
120    }
121}
122
123impl<F: Float + FromPrimitive> Default for ElasticNet<F> {
124    fn default() -> Self {
125        Self::new()
126    }
127}
128
129/// Fitted ElasticNet regression model.
130///
131/// Stores the learned (potentially sparse) coefficients and intercept.
132/// Implements [`Predict`] and [`HasCoefficients`].
133#[derive(Debug, Clone)]
134pub struct FittedElasticNet<F> {
135    /// Learned coefficient vector (some may be exactly zero when L1 > 0).
136    coefficients: Array1<F>,
137    /// Learned intercept (bias) term.
138    intercept: F,
139}
140
141impl<F: Float> FittedElasticNet<F> {
142    /// Returns the intercept (bias) term learned during fitting.
143    pub fn intercept(&self) -> F {
144        self.intercept
145    }
146}
147
148/// Soft-thresholding operator used in coordinate descent for L1 penalty.
149///
150/// Returns `sign(x) * max(|x| - threshold, 0)`.
151#[inline]
152fn soft_threshold<F: Float>(x: F, threshold: F) -> F {
153    if x > threshold {
154        x - threshold
155    } else if x < -threshold {
156        x + threshold
157    } else {
158        F::zero()
159    }
160}
161
162impl<F: Float + Send + Sync + ScalarOperand + FromPrimitive + 'static> Fit<Array2<F>, Array1<F>>
163    for ElasticNet<F>
164{
165    type Fitted = FittedElasticNet<F>;
166    type Error = FerroError;
167
168    /// Fit the ElasticNet model using coordinate descent.
169    ///
170    /// Centers the data if `fit_intercept` is `true`, then alternates
171    /// coordinate updates using the soft-threshold rule with L2 scaling.
172    ///
173    /// # Errors
174    ///
175    /// - [`FerroError::ShapeMismatch`] if `x` and `y` have different numbers
176    ///   of samples.
177    /// - [`FerroError::InvalidParameter`] if `alpha` is negative, `l1_ratio`
178    ///   is outside `[0, 1]`, or `tol` is non-positive.
179    /// - [`FerroError::InsufficientSamples`] if `n_samples == 0`.
180    fn fit(&self, x: &Array2<F>, y: &Array1<F>) -> Result<FittedElasticNet<F>, FerroError> {
181        let (n_samples, n_features) = x.dim();
182
183        if n_samples != y.len() {
184            return Err(FerroError::ShapeMismatch {
185                expected: vec![n_samples],
186                actual: vec![y.len()],
187                context: "y length must match number of samples in X".into(),
188            });
189        }
190
191        if self.alpha < F::zero() {
192            return Err(FerroError::InvalidParameter {
193                name: "alpha".into(),
194                reason: "must be non-negative".into(),
195            });
196        }
197
198        if self.l1_ratio < F::zero() || self.l1_ratio > F::one() {
199            return Err(FerroError::InvalidParameter {
200                name: "l1_ratio".into(),
201                reason: "must be in [0, 1]".into(),
202            });
203        }
204
205        if n_samples == 0 {
206            return Err(FerroError::InsufficientSamples {
207                required: 1,
208                actual: 0,
209                context: "ElasticNet requires at least one sample".into(),
210            });
211        }
212
213        let n_f = F::from(n_samples).unwrap();
214
215        // Center data when fitting intercept.
216        let (x_work, y_work, x_mean, y_mean) = if self.fit_intercept {
217            let x_mean = x
218                .mean_axis(Axis(0))
219                .ok_or_else(|| FerroError::NumericalInstability {
220                    message: "failed to compute column means".into(),
221                })?;
222            let y_mean = y.mean().ok_or_else(|| FerroError::NumericalInstability {
223                message: "failed to compute target mean".into(),
224            })?;
225
226            let x_c = x - &x_mean;
227            let y_c = y - y_mean;
228            (x_c, y_c, Some(x_mean), Some(y_mean))
229        } else {
230            (x.clone(), y.clone(), None, None)
231        };
232
233        // Precompute per-column X_j^T X_j / n (used as denominator).
234        let col_norms: Vec<F> = (0..n_features)
235            .map(|j| {
236                let col = x_work.column(j);
237                col.dot(&col) / n_f
238            })
239            .collect();
240
241        // L1 and L2 penalty strengths split from alpha/l1_ratio.
242        let alpha_l1 = self.alpha * self.l1_ratio;
243        let alpha_l2 = self.alpha * (F::one() - self.l1_ratio);
244
245        // Effective denominator per column: (X_j^T X_j / n) + alpha_l2.
246        let denominators: Vec<F> = col_norms.iter().map(|&cn| cn + alpha_l2).collect();
247
248        let mut w = Array1::<F>::zeros(n_features);
249        let mut residual = y_work.clone();
250
251        for _iter in 0..self.max_iter {
252            let mut max_change = F::zero();
253
254            for j in 0..n_features {
255                let col_j = x_work.column(j);
256                let w_old = w[j];
257
258                // Add back contribution of current coefficient j to residual.
259                if w_old != F::zero() {
260                    for i in 0..n_samples {
261                        residual[i] = residual[i] + col_j[i] * w_old;
262                    }
263                }
264
265                // Unpenalized correlation: X_j^T r / n.
266                let rho_j = col_j.dot(&residual) / n_f;
267
268                // Apply soft-threshold for L1, then divide by (col_norm + alpha_l2).
269                let w_new = if denominators[j] > F::zero() {
270                    soft_threshold(rho_j, alpha_l1) / denominators[j]
271                } else {
272                    F::zero()
273                };
274
275                // Update residual with new coefficient.
276                if w_new != F::zero() {
277                    for i in 0..n_samples {
278                        residual[i] = residual[i] - col_j[i] * w_new;
279                    }
280                }
281
282                let change = (w_new - w_old).abs();
283                if change > max_change {
284                    max_change = change;
285                }
286
287                w[j] = w_new;
288            }
289
290            if max_change < self.tol {
291                let intercept = compute_intercept(&x_mean, &y_mean, &w);
292                return Ok(FittedElasticNet {
293                    coefficients: w,
294                    intercept,
295                });
296            }
297        }
298
299        // Return best solution found even without full convergence.
300        let intercept = compute_intercept(&x_mean, &y_mean, &w);
301        Ok(FittedElasticNet {
302            coefficients: w,
303            intercept,
304        })
305    }
306}
307
308/// Compute intercept from the centered means and fitted coefficients.
309fn compute_intercept<F: Float + 'static>(
310    x_mean: &Option<Array1<F>>,
311    y_mean: &Option<F>,
312    w: &Array1<F>,
313) -> F {
314    if let (Some(xm), Some(ym)) = (x_mean, y_mean) {
315        *ym - xm.dot(w)
316    } else {
317        F::zero()
318    }
319}
320
321impl<F: Float + Send + Sync + ScalarOperand + 'static> Predict<Array2<F>> for FittedElasticNet<F> {
322    type Output = Array1<F>;
323    type Error = FerroError;
324
325    /// Predict target values for the given feature matrix.
326    ///
327    /// Computes `X @ coefficients + intercept`.
328    ///
329    /// # Errors
330    ///
331    /// Returns [`FerroError::ShapeMismatch`] if the number of features
332    /// does not match the fitted model.
333    fn predict(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
334        let n_features = x.ncols();
335        if n_features != self.coefficients.len() {
336            return Err(FerroError::ShapeMismatch {
337                expected: vec![self.coefficients.len()],
338                actual: vec![n_features],
339                context: "number of features must match fitted model".into(),
340            });
341        }
342
343        let preds = x.dot(&self.coefficients) + self.intercept;
344        Ok(preds)
345    }
346}
347
348impl<F: Float + Send + Sync + ScalarOperand + 'static> HasCoefficients<F> for FittedElasticNet<F> {
349    /// Returns the learned coefficient vector.
350    fn coefficients(&self) -> &Array1<F> {
351        &self.coefficients
352    }
353
354    /// Returns the learned intercept term.
355    fn intercept(&self) -> F {
356        self.intercept
357    }
358}
359
360// Pipeline integration for f64.
361impl PipelineEstimator for ElasticNet<f64> {
362    /// Fit the model and return it as a boxed pipeline estimator.
363    ///
364    /// # Errors
365    ///
366    /// Propagates any [`FerroError`] from `fit`.
367    fn fit_pipeline(
368        &self,
369        x: &Array2<f64>,
370        y: &Array1<f64>,
371    ) -> Result<Box<dyn FittedPipelineEstimator>, FerroError> {
372        let fitted = self.fit(x, y)?;
373        Ok(Box::new(fitted))
374    }
375}
376
377impl FittedPipelineEstimator for FittedElasticNet<f64> {
378    /// Generate predictions via the pipeline interface.
379    ///
380    /// # Errors
381    ///
382    /// Propagates any [`FerroError`] from `predict`.
383    fn predict_pipeline(&self, x: &Array2<f64>) -> Result<Array1<f64>, FerroError> {
384        self.predict(x)
385    }
386}
387
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use ndarray::array;

    /// Five-sample, single-feature fixture with `y = 2x + 1`.
    ///
    /// After centering, `x_c = [-2, -1, 0, 1, 2]` and `y_c = 2 * x_c`, so
    /// `X^T X / n = 2` and `X^T y / n = 4`. The closed-form tests below rely
    /// on these values.
    fn line_fixture() -> (Array2<f64>, Array1<f64>) {
        let x = Array2::from_shape_vec((5, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0]).unwrap();
        let y = array![3.0, 5.0, 7.0, 9.0, 11.0];
        (x, y)
    }

    // ---- soft_threshold helpers ----

    #[test]
    fn test_soft_threshold_positive() {
        assert_relative_eq!(soft_threshold(5.0_f64, 1.0), 4.0);
    }

    #[test]
    fn test_soft_threshold_negative() {
        assert_relative_eq!(soft_threshold(-5.0_f64, 1.0), -4.0);
    }

    #[test]
    fn test_soft_threshold_within_band() {
        assert_relative_eq!(soft_threshold(0.5_f64, 1.0), 0.0);
        assert_relative_eq!(soft_threshold(-0.5_f64, 1.0), 0.0);
        assert_relative_eq!(soft_threshold(0.0_f64, 1.0), 0.0);
    }

    // ---- Builder ----

    #[test]
    fn test_default_builder() {
        let m = ElasticNet::<f64>::new();
        assert_relative_eq!(m.alpha, 1.0);
        assert_relative_eq!(m.l1_ratio, 0.5);
        assert_eq!(m.max_iter, 1000);
        assert_relative_eq!(m.tol, 1e-4);
        assert!(m.fit_intercept);
    }

    #[test]
    fn test_builder_setters() {
        let m = ElasticNet::<f64>::new()
            .with_alpha(0.5)
            .with_l1_ratio(0.2)
            .with_max_iter(500)
            .with_tol(1e-6)
            .with_fit_intercept(false);
        assert_relative_eq!(m.alpha, 0.5);
        assert_relative_eq!(m.l1_ratio, 0.2);
        assert_eq!(m.max_iter, 500);
        assert_relative_eq!(m.tol, 1e-6);
        assert!(!m.fit_intercept);
    }

    // ---- Validation errors ----

    #[test]
    fn test_negative_alpha_error() {
        let x = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        let y = array![1.0, 2.0, 3.0];
        let result = ElasticNet::<f64>::new().with_alpha(-1.0).fit(&x, &y);
        assert!(result.is_err());
    }

    #[test]
    fn test_l1_ratio_out_of_range_error() {
        let x = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        let y = array![1.0, 2.0, 3.0];
        let result = ElasticNet::<f64>::new().with_l1_ratio(1.5).fit(&x, &y);
        assert!(result.is_err());
    }

    #[test]
    fn test_shape_mismatch_error() {
        let x = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        let y = array![1.0, 2.0];
        let result = ElasticNet::<f64>::new().fit(&x, &y);
        assert!(result.is_err());
    }

    // ---- Correctness ----

    #[test]
    fn test_lasso_limit_l1_ratio_one() {
        // With alpha = 0 the penalty vanishes, so the pure-L1 (Lasso)
        // configuration must recover the OLS solution.
        let (x, y) = line_fixture();

        let model = ElasticNet::<f64>::new().with_alpha(0.0).with_l1_ratio(1.0);
        let fitted = model.fit(&x, &y).unwrap();

        assert_relative_eq!(fitted.coefficients()[0], 2.0, epsilon = 1e-4);
        assert_relative_eq!(fitted.intercept(), 1.0, epsilon = 1e-4);
    }

    #[test]
    fn test_ridge_limit_l1_ratio_zero() {
        // With l1_ratio = 0 and alpha = 0, should recover OLS.
        let (x, y) = line_fixture();

        let model = ElasticNet::<f64>::new().with_alpha(0.0).with_l1_ratio(0.0);
        let fitted = model.fit(&x, &y).unwrap();

        assert_relative_eq!(fitted.coefficients()[0], 2.0, epsilon = 1e-4);
        assert_relative_eq!(fitted.intercept(), 1.0, epsilon = 1e-4);
    }

    #[test]
    fn test_lasso_closed_form_single_feature() {
        // Pure L1 with alpha = 0.5:
        //   w = soft_threshold(4, 0.5) / 2 = 3.5 / 2 = 1.75
        //   b = mean(y) - mean(x) * w = 7 - 3 * 1.75 = 1.75
        let (x, y) = line_fixture();

        let fitted = ElasticNet::<f64>::new()
            .with_alpha(0.5)
            .with_l1_ratio(1.0)
            .fit(&x, &y)
            .unwrap();

        assert_relative_eq!(fitted.coefficients()[0], 1.75, epsilon = 1e-10);
        assert_relative_eq!(fitted.intercept(), 1.75, epsilon = 1e-10);
    }

    #[test]
    fn test_ridge_closed_form_single_feature() {
        // Pure L2 with alpha = 0.5:
        //   w = 4 / (2 + 0.5) = 1.6
        //   b = 7 - 3 * 1.6 = 2.2
        let (x, y) = line_fixture();

        let fitted = ElasticNet::<f64>::new()
            .with_alpha(0.5)
            .with_l1_ratio(0.0)
            .fit(&x, &y)
            .unwrap();

        assert_relative_eq!(fitted.coefficients()[0], 1.6, epsilon = 1e-10);
        assert_relative_eq!(fitted.intercept(), 2.2, epsilon = 1e-10);
    }

    #[test]
    fn test_elastic_net_closed_form_single_feature() {
        // alpha = 1, l1_ratio = 0.5 gives alpha_l1 = alpha_l2 = 0.5:
        //   w = soft_threshold(4, 0.5) / (2 + 0.5) = 3.5 / 2.5 = 1.4
        //   b = 7 - 3 * 1.4 = 2.8
        let (x, y) = line_fixture();

        let fitted = ElasticNet::<f64>::new()
            .with_alpha(1.0)
            .with_l1_ratio(0.5)
            .fit(&x, &y)
            .unwrap();

        assert_relative_eq!(fitted.coefficients()[0], 1.4, epsilon = 1e-10);
        assert_relative_eq!(fitted.intercept(), 2.8, epsilon = 1e-10);
    }

    #[test]
    fn test_sparsity_with_high_l1_ratio() {
        // High alpha with l1_ratio=1 should zero out irrelevant features.
        let x = Array2::from_shape_vec(
            (10, 3),
            vec![
                1.0, 0.0, 0.0, 2.0, 0.0, 0.0, 3.0, 0.0, 0.0, 4.0, 0.0, 0.0, 5.0, 0.0, 0.0, 6.0,
                0.0, 0.0, 7.0, 0.0, 0.0, 8.0, 0.0, 0.0, 9.0, 0.0, 0.0, 10.0, 0.0, 0.0,
            ],
        )
        .unwrap();
        let y = array![2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0, 18.0, 20.0];

        let model = ElasticNet::<f64>::new().with_alpha(5.0).with_l1_ratio(1.0);
        let fitted = model.fit(&x, &y).unwrap();

        assert_relative_eq!(fitted.coefficients()[1], 0.0, epsilon = 1e-10);
        assert_relative_eq!(fitted.coefficients()[2], 0.0, epsilon = 1e-10);
    }

    #[test]
    fn test_higher_alpha_shrinks_more() {
        let (x, y) = line_fixture();

        let low = ElasticNet::<f64>::new()
            .with_alpha(0.01)
            .with_l1_ratio(0.5)
            .fit(&x, &y)
            .unwrap();
        let high = ElasticNet::<f64>::new()
            .with_alpha(2.0)
            .with_l1_ratio(0.5)
            .fit(&x, &y)
            .unwrap();

        assert!(high.coefficients()[0].abs() <= low.coefficients()[0].abs());
    }

    #[test]
    fn test_no_intercept() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![2.0, 4.0, 6.0, 8.0];

        let fitted = ElasticNet::<f64>::new()
            .with_alpha(0.0)
            .with_l1_ratio(0.5)
            .with_fit_intercept(false)
            .fit(&x, &y)
            .unwrap();

        assert_relative_eq!(fitted.intercept(), 0.0, epsilon = 1e-10);
    }

    #[test]
    fn test_predict_correct_length() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![2.0, 4.0, 6.0, 8.0];

        let fitted = ElasticNet::<f64>::new()
            .with_alpha(0.01)
            .fit(&x, &y)
            .unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 4);
    }

    #[test]
    fn test_predict_feature_mismatch() {
        let x_train = Array2::from_shape_vec((3, 2), vec![1.0, 0.0, 2.0, 0.0, 3.0, 0.0]).unwrap();
        let y = array![1.0, 2.0, 3.0];
        let fitted = ElasticNet::<f64>::new()
            .with_alpha(0.01)
            .fit(&x_train, &y)
            .unwrap();

        let x_bad = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        let result = fitted.predict(&x_bad);
        assert!(result.is_err());
    }

    #[test]
    fn test_has_coefficients_length() {
        let x = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![1.0, 2.0, 3.0];
        let fitted = ElasticNet::<f64>::new()
            .with_alpha(0.1)
            .fit(&x, &y)
            .unwrap();

        assert_eq!(fitted.coefficients().len(), 2);
    }

    #[test]
    fn test_pipeline_integration() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![3.0, 5.0, 7.0, 9.0];

        let model = ElasticNet::<f64>::new().with_alpha(0.01);
        let fitted_pipe = model.fit_pipeline(&x, &y).unwrap();
        let preds = fitted_pipe.predict_pipeline(&x).unwrap();
        assert_eq!(preds.len(), 4);
    }
}