// ferrolearn_linear/elastic_net.rs

1//! ElasticNet regression (combined L1 and L2 regularization).
2//!
3//! This module provides [`ElasticNet`], which fits a linear model with a
4//! blended L1/L2 regularization penalty using coordinate descent with
5//! soft-thresholding:
6//!
7//! ```text
8//! minimize (1/(2n)) * ||X @ w - y||^2
9//!        + alpha * l1_ratio * ||w||_1
10//!        + (alpha/2) * (1 - l1_ratio) * ||w||_2^2
11//! ```
12//!
13//! When `l1_ratio = 1`, ElasticNet is equivalent to Lasso. When
14//! `l1_ratio = 0`, it is equivalent to Ridge. Intermediate values produce
15//! solutions that are both sparse (L1) and small in magnitude (L2).
16//!
17//! # Examples
18//!
19//! ```
20//! use ferrolearn_linear::ElasticNet;
21//! use ferrolearn_core::{Fit, Predict};
22//! use ndarray::{array, Array1, Array2};
23//!
24//! let model = ElasticNet::<f64>::new()
25//!     .with_alpha(0.1)
26//!     .with_l1_ratio(0.5);
27//! let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
28//! let y = array![2.0, 4.0, 6.0, 8.0];
29//!
30//! let fitted = model.fit(&x, &y).unwrap();
31//! let preds = fitted.predict(&x).unwrap();
32//! ```
33
34use ferrolearn_core::error::FerroError;
35use ferrolearn_core::introspection::HasCoefficients;
36use ferrolearn_core::pipeline::{FittedPipelineEstimator, PipelineEstimator};
37use ferrolearn_core::traits::{Fit, Predict};
38use ndarray::{Array1, Array2, Axis, ScalarOperand};
39use num_traits::{Float, FromPrimitive};
40
/// ElasticNet regression (L1 + L2 regularized least squares).
///
/// Minimizes the squared error plus a blend of L1 and L2 penalties governed
/// by `alpha` (overall strength) and `l1_ratio` (mix). The non-smooth L1
/// component is handled by coordinate descent with soft-thresholding.
///
/// # Type Parameters
///
/// - `F`: The floating-point type (`f32` or `f64`).
#[derive(Debug, Clone)]
pub struct ElasticNet<F> {
    /// Overall regularization strength; larger values regularize harder.
    pub alpha: F,
    /// Mix between the L1 and L2 penalties:
    /// - `l1_ratio = 1.0` → pure Lasso (L1 only)
    /// - `l1_ratio = 0.0` → pure Ridge (L2 only)
    /// - `0.0 < l1_ratio < 1.0` → ElasticNet blend
    pub l1_ratio: F,
    /// Upper bound on the number of coordinate-descent passes.
    pub max_iter: usize,
    /// Convergence tolerance on the largest coefficient change per pass.
    pub tol: F,
    /// Whether a bias (intercept) term is fitted.
    pub fit_intercept: bool,
}
67
68impl<F: Float + FromPrimitive> ElasticNet<F> {
69    /// Create a new `ElasticNet` with default settings.
70    ///
71    /// Defaults: `alpha = 1.0`, `l1_ratio = 0.5`, `max_iter = 1000`,
72    /// `tol = 1e-4`, `fit_intercept = true`.
73    #[must_use]
74    pub fn new() -> Self {
75        Self {
76            alpha: F::one(),
77            l1_ratio: F::from(0.5).unwrap(),
78            max_iter: 1000,
79            tol: F::from(1e-4).unwrap(),
80            fit_intercept: true,
81        }
82    }
83
84    /// Set the overall regularization strength.
85    #[must_use]
86    pub fn with_alpha(mut self, alpha: F) -> Self {
87        self.alpha = alpha;
88        self
89    }
90
91    /// Set the L1/L2 mixing ratio.
92    ///
93    /// Must be in `[0.0, 1.0]`. Values outside this range will be rejected
94    /// at fit time.
95    #[must_use]
96    pub fn with_l1_ratio(mut self, l1_ratio: F) -> Self {
97        self.l1_ratio = l1_ratio;
98        self
99    }
100
101    /// Set the maximum number of coordinate descent iterations.
102    #[must_use]
103    pub fn with_max_iter(mut self, max_iter: usize) -> Self {
104        self.max_iter = max_iter;
105        self
106    }
107
108    /// Set the convergence tolerance on maximum coefficient change.
109    #[must_use]
110    pub fn with_tol(mut self, tol: F) -> Self {
111        self.tol = tol;
112        self
113    }
114
115    /// Set whether to fit an intercept term.
116    #[must_use]
117    pub fn with_fit_intercept(mut self, fit_intercept: bool) -> Self {
118        self.fit_intercept = fit_intercept;
119        self
120    }
121}
122
123impl<F: Float + FromPrimitive> Default for ElasticNet<F> {
124    fn default() -> Self {
125        Self::new()
126    }
127}
128
129/// Fitted ElasticNet regression model.
130///
131/// Stores the learned (potentially sparse) coefficients and intercept.
132/// Implements [`Predict`] and [`HasCoefficients`].
133#[derive(Debug, Clone)]
134pub struct FittedElasticNet<F> {
135    /// Learned coefficient vector (some may be exactly zero when L1 > 0).
136    coefficients: Array1<F>,
137    /// Learned intercept (bias) term.
138    intercept: F,
139}
140
141impl<F: Float> FittedElasticNet<F> {
142    /// Returns the intercept (bias) term learned during fitting.
143    pub fn intercept(&self) -> F {
144        self.intercept
145    }
146}
147
148/// Soft-thresholding operator used in coordinate descent for L1 penalty.
149///
150/// Returns `sign(x) * max(|x| - threshold, 0)`.
151#[inline]
152fn soft_threshold<F: Float>(x: F, threshold: F) -> F {
153    if x > threshold {
154        x - threshold
155    } else if x < -threshold {
156        x + threshold
157    } else {
158        F::zero()
159    }
160}
161
162impl<F: Float + Send + Sync + ScalarOperand + FromPrimitive + 'static> Fit<Array2<F>, Array1<F>>
163    for ElasticNet<F>
164{
165    type Fitted = FittedElasticNet<F>;
166    type Error = FerroError;
167
168    /// Fit the ElasticNet model using coordinate descent.
169    ///
170    /// Centers the data if `fit_intercept` is `true`, then alternates
171    /// coordinate updates using the soft-threshold rule with L2 scaling.
172    ///
173    /// # Errors
174    ///
175    /// - [`FerroError::ShapeMismatch`] if `x` and `y` have different numbers
176    ///   of samples.
177    /// - [`FerroError::InvalidParameter`] if `alpha` is negative, `l1_ratio`
178    ///   is outside `[0, 1]`, or `tol` is non-positive.
179    /// - [`FerroError::InsufficientSamples`] if `n_samples == 0`.
180    fn fit(&self, x: &Array2<F>, y: &Array1<F>) -> Result<FittedElasticNet<F>, FerroError> {
181        let (n_samples, n_features) = x.dim();
182
183        if n_samples != y.len() {
184            return Err(FerroError::ShapeMismatch {
185                expected: vec![n_samples],
186                actual: vec![y.len()],
187                context: "y length must match number of samples in X".into(),
188            });
189        }
190
191        if self.alpha < F::zero() {
192            return Err(FerroError::InvalidParameter {
193                name: "alpha".into(),
194                reason: "must be non-negative".into(),
195            });
196        }
197
198        if self.l1_ratio < F::zero() || self.l1_ratio > F::one() {
199            return Err(FerroError::InvalidParameter {
200                name: "l1_ratio".into(),
201                reason: "must be in [0, 1]".into(),
202            });
203        }
204
205        if n_samples == 0 {
206            return Err(FerroError::InsufficientSamples {
207                required: 1,
208                actual: 0,
209                context: "ElasticNet requires at least one sample".into(),
210            });
211        }
212
213        let n_f = F::from(n_samples).unwrap();
214
215        // Center data when fitting intercept.
216        let (x_work, y_work, x_mean, y_mean) = if self.fit_intercept {
217            let x_mean = x
218                .mean_axis(Axis(0))
219                .ok_or_else(|| FerroError::NumericalInstability {
220                    message: "failed to compute column means".into(),
221                })?;
222            let y_mean = y.mean().ok_or_else(|| FerroError::NumericalInstability {
223                message: "failed to compute target mean".into(),
224            })?;
225
226            let x_c = x - &x_mean;
227            let y_c = y - y_mean;
228            (x_c, y_c, Some(x_mean), Some(y_mean))
229        } else {
230            (x.clone(), y.clone(), None, None)
231        };
232
233        // Precompute per-column X_j^T X_j / n (used as denominator).
234        let col_norms: Vec<F> = (0..n_features)
235            .map(|j| {
236                let col = x_work.column(j);
237                col.dot(&col) / n_f
238            })
239            .collect();
240
241        // L1 and L2 penalty strengths split from alpha/l1_ratio.
242        let alpha_l1 = self.alpha * self.l1_ratio;
243        let alpha_l2 = self.alpha * (F::one() - self.l1_ratio);
244
245        // Effective denominator per column: (X_j^T X_j / n) + alpha_l2.
246        let denominators: Vec<F> = col_norms.iter().map(|&cn| cn + alpha_l2).collect();
247
248        let mut w = Array1::<F>::zeros(n_features);
249        let mut residual = y_work.clone();
250
251        for _iter in 0..self.max_iter {
252            let mut max_change = F::zero();
253
254            for j in 0..n_features {
255                let col_j = x_work.column(j);
256                let w_old = w[j];
257
258                // Add back contribution of current coefficient j to residual.
259                if w_old != F::zero() {
260                    for i in 0..n_samples {
261                        residual[i] = residual[i] + col_j[i] * w_old;
262                    }
263                }
264
265                // Unpenalized correlation: X_j^T r / n.
266                let rho_j = col_j.dot(&residual) / n_f;
267
268                // Apply soft-threshold for L1, then divide by (col_norm + alpha_l2).
269                let w_new = if denominators[j] > F::zero() {
270                    soft_threshold(rho_j, alpha_l1) / denominators[j]
271                } else {
272                    F::zero()
273                };
274
275                // Update residual with new coefficient.
276                if w_new != F::zero() {
277                    for i in 0..n_samples {
278                        residual[i] = residual[i] - col_j[i] * w_new;
279                    }
280                }
281
282                let change = (w_new - w_old).abs();
283                if change > max_change {
284                    max_change = change;
285                }
286
287                w[j] = w_new;
288            }
289
290            if max_change < self.tol {
291                let intercept = compute_intercept(&x_mean, &y_mean, &w);
292                return Ok(FittedElasticNet {
293                    coefficients: w,
294                    intercept,
295                });
296            }
297        }
298
299        // Return best solution found even without full convergence.
300        let intercept = compute_intercept(&x_mean, &y_mean, &w);
301        Ok(FittedElasticNet {
302            coefficients: w,
303            intercept,
304        })
305    }
306}
307
308/// Compute intercept from the centered means and fitted coefficients.
309fn compute_intercept<F: Float + 'static>(
310    x_mean: &Option<Array1<F>>,
311    y_mean: &Option<F>,
312    w: &Array1<F>,
313) -> F {
314    if let (Some(xm), Some(ym)) = (x_mean, y_mean) {
315        *ym - xm.dot(w)
316    } else {
317        F::zero()
318    }
319}
320
321impl<F: Float + Send + Sync + ScalarOperand + 'static> Predict<Array2<F>> for FittedElasticNet<F> {
322    type Output = Array1<F>;
323    type Error = FerroError;
324
325    /// Predict target values for the given feature matrix.
326    ///
327    /// Computes `X @ coefficients + intercept`.
328    ///
329    /// # Errors
330    ///
331    /// Returns [`FerroError::ShapeMismatch`] if the number of features
332    /// does not match the fitted model.
333    fn predict(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
334        let n_features = x.ncols();
335        if n_features != self.coefficients.len() {
336            return Err(FerroError::ShapeMismatch {
337                expected: vec![self.coefficients.len()],
338                actual: vec![n_features],
339                context: "number of features must match fitted model".into(),
340            });
341        }
342
343        let preds = x.dot(&self.coefficients) + self.intercept;
344        Ok(preds)
345    }
346}
347
348impl<F: Float + Send + Sync + ScalarOperand + 'static> HasCoefficients<F> for FittedElasticNet<F> {
349    /// Returns the learned coefficient vector.
350    fn coefficients(&self) -> &Array1<F> {
351        &self.coefficients
352    }
353
354    /// Returns the learned intercept term.
355    fn intercept(&self) -> F {
356        self.intercept
357    }
358}
359
360// Pipeline integration.
361impl<F> PipelineEstimator<F> for ElasticNet<F>
362where
363    F: Float + FromPrimitive + ScalarOperand + Send + Sync + 'static,
364{
365    /// Fit the model and return it as a boxed pipeline estimator.
366    ///
367    /// # Errors
368    ///
369    /// Propagates any [`FerroError`] from `fit`.
370    fn fit_pipeline(
371        &self,
372        x: &Array2<F>,
373        y: &Array1<F>,
374    ) -> Result<Box<dyn FittedPipelineEstimator<F>>, FerroError> {
375        let fitted = self.fit(x, y)?;
376        Ok(Box::new(fitted))
377    }
378}
379
380impl<F> FittedPipelineEstimator<F> for FittedElasticNet<F>
381where
382    F: Float + ScalarOperand + Send + Sync + 'static,
383{
384    /// Generate predictions via the pipeline interface.
385    ///
386    /// # Errors
387    ///
388    /// Propagates any [`FerroError`] from `predict`.
389    fn predict_pipeline(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
390        self.predict(x)
391    }
392}
393
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use ndarray::array;

    // ---- soft_threshold helpers ----

    #[test]
    fn test_soft_threshold_positive() {
        assert_relative_eq!(soft_threshold(5.0_f64, 1.0), 4.0);
    }

    #[test]
    fn test_soft_threshold_negative() {
        assert_relative_eq!(soft_threshold(-5.0_f64, 1.0), -4.0);
    }

    #[test]
    fn test_soft_threshold_within_band() {
        assert_relative_eq!(soft_threshold(0.5_f64, 1.0), 0.0);
        assert_relative_eq!(soft_threshold(-0.5_f64, 1.0), 0.0);
        assert_relative_eq!(soft_threshold(0.0_f64, 1.0), 0.0);
        // Added coverage: the band is closed, so |x| == threshold must also
        // collapse to exactly zero.
        assert_relative_eq!(soft_threshold(1.0_f64, 1.0), 0.0);
        assert_relative_eq!(soft_threshold(-1.0_f64, 1.0), 0.0);
    }

    // ---- Builder ----

    #[test]
    fn test_default_builder() {
        let m = ElasticNet::<f64>::new();
        assert_relative_eq!(m.alpha, 1.0);
        assert_relative_eq!(m.l1_ratio, 0.5);
        assert_eq!(m.max_iter, 1000);
        // Added coverage: the documented default tolerance.
        assert_relative_eq!(m.tol, 1e-4);
        assert!(m.fit_intercept);
    }

    #[test]
    fn test_builder_setters() {
        let m = ElasticNet::<f64>::new()
            .with_alpha(0.5)
            .with_l1_ratio(0.2)
            .with_max_iter(500)
            .with_tol(1e-6)
            .with_fit_intercept(false);
        assert_relative_eq!(m.alpha, 0.5);
        assert_relative_eq!(m.l1_ratio, 0.2);
        assert_eq!(m.max_iter, 500);
        // Added coverage: `with_tol` was previously set but never asserted.
        assert_relative_eq!(m.tol, 1e-6);
        assert!(!m.fit_intercept);
    }

    // ---- Validation errors ----

    #[test]
    fn test_negative_alpha_error() {
        let x = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        let y = array![1.0, 2.0, 3.0];
        let result = ElasticNet::<f64>::new().with_alpha(-1.0).fit(&x, &y);
        assert!(result.is_err());
    }

    #[test]
    fn test_l1_ratio_out_of_range_error() {
        let x = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        let y = array![1.0, 2.0, 3.0];
        let result = ElasticNet::<f64>::new().with_l1_ratio(1.5).fit(&x, &y);
        assert!(result.is_err());
    }

    #[test]
    fn test_shape_mismatch_error() {
        let x = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        let y = array![1.0, 2.0];
        let result = ElasticNet::<f64>::new().fit(&x, &y);
        assert!(result.is_err());
    }

    // ---- Correctness ----

    #[test]
    fn test_lasso_limit_l1_ratio_one() {
        // With l1_ratio=1, ElasticNet should behave like Lasso.
        let x = Array2::from_shape_vec((5, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0]).unwrap();
        let y = array![3.0, 5.0, 7.0, 9.0, 11.0];

        let model = ElasticNet::<f64>::new().with_alpha(0.0).with_l1_ratio(1.0);
        let fitted = model.fit(&x, &y).unwrap();

        assert_relative_eq!(fitted.coefficients()[0], 2.0, epsilon = 1e-4);
        assert_relative_eq!(fitted.intercept(), 1.0, epsilon = 1e-4);
    }

    #[test]
    fn test_ridge_limit_l1_ratio_zero() {
        // With l1_ratio=0 and alpha=0, should recover OLS.
        let x = Array2::from_shape_vec((5, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0]).unwrap();
        let y = array![3.0, 5.0, 7.0, 9.0, 11.0];

        let model = ElasticNet::<f64>::new().with_alpha(0.0).with_l1_ratio(0.0);
        let fitted = model.fit(&x, &y).unwrap();

        assert_relative_eq!(fitted.coefficients()[0], 2.0, epsilon = 1e-4);
        assert_relative_eq!(fitted.intercept(), 1.0, epsilon = 1e-4);
    }

    #[test]
    fn test_sparsity_with_high_l1_ratio() {
        // High alpha with l1_ratio=1 should zero out irrelevant features.
        let x = Array2::from_shape_vec(
            (10, 3),
            vec![
                1.0, 0.0, 0.0, 2.0, 0.0, 0.0, 3.0, 0.0, 0.0, 4.0, 0.0, 0.0, 5.0, 0.0, 0.0, 6.0,
                0.0, 0.0, 7.0, 0.0, 0.0, 8.0, 0.0, 0.0, 9.0, 0.0, 0.0, 10.0, 0.0, 0.0,
            ],
        )
        .unwrap();
        let y = array![2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0, 18.0, 20.0];

        let model = ElasticNet::<f64>::new().with_alpha(5.0).with_l1_ratio(1.0);
        let fitted = model.fit(&x, &y).unwrap();

        assert_relative_eq!(fitted.coefficients()[1], 0.0, epsilon = 1e-10);
        assert_relative_eq!(fitted.coefficients()[2], 0.0, epsilon = 1e-10);
    }

    #[test]
    fn test_higher_alpha_shrinks_more() {
        let x = Array2::from_shape_vec((5, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0]).unwrap();
        let y = array![3.0, 5.0, 7.0, 9.0, 11.0];

        let low = ElasticNet::<f64>::new()
            .with_alpha(0.01)
            .with_l1_ratio(0.5)
            .fit(&x, &y)
            .unwrap();
        let high = ElasticNet::<f64>::new()
            .with_alpha(2.0)
            .with_l1_ratio(0.5)
            .fit(&x, &y)
            .unwrap();

        assert!(high.coefficients()[0].abs() <= low.coefficients()[0].abs());
    }

    #[test]
    fn test_no_intercept() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![2.0, 4.0, 6.0, 8.0];

        let fitted = ElasticNet::<f64>::new()
            .with_alpha(0.0)
            .with_l1_ratio(0.5)
            .with_fit_intercept(false)
            .fit(&x, &y)
            .unwrap();

        assert_relative_eq!(fitted.intercept(), 0.0, epsilon = 1e-10);
    }

    #[test]
    fn test_predict_correct_length() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![2.0, 4.0, 6.0, 8.0];

        let fitted = ElasticNet::<f64>::new()
            .with_alpha(0.01)
            .fit(&x, &y)
            .unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 4);
    }

    #[test]
    fn test_predict_feature_mismatch() {
        let x_train = Array2::from_shape_vec((3, 2), vec![1.0, 0.0, 2.0, 0.0, 3.0, 0.0]).unwrap();
        let y = array![1.0, 2.0, 3.0];
        let fitted = ElasticNet::<f64>::new()
            .with_alpha(0.01)
            .fit(&x_train, &y)
            .unwrap();

        let x_bad = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        let result = fitted.predict(&x_bad);
        assert!(result.is_err());
    }

    #[test]
    fn test_has_coefficients_length() {
        let x = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![1.0, 2.0, 3.0];
        let fitted = ElasticNet::<f64>::new()
            .with_alpha(0.1)
            .fit(&x, &y)
            .unwrap();

        assert_eq!(fitted.coefficients().len(), 2);
    }

    #[test]
    fn test_pipeline_integration() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![3.0, 5.0, 7.0, 9.0];

        let model = ElasticNet::<f64>::new().with_alpha(0.01);
        let fitted_pipe = model.fit_pipeline(&x, &y).unwrap();
        let preds = fitted_pipe.predict_pipeline(&x).unwrap();
        assert_eq!(preds.len(), 4);
    }
}