Skip to main content

ferrolearn_linear/
elastic_net_cv.rs

//! ElasticNet regression with built-in cross-validation for alpha and
//! l1_ratio selection.
//!
//! This module provides [`ElasticNetCV`], which automatically selects the
//! best `(alpha, l1_ratio)` pair using k-fold cross-validation. For each
//! candidate `l1_ratio`, a log-spaced alpha grid is generated, and the
//! combination that minimises mean squared error is selected.
//!
//! # Examples
//!
//! ```
//! use ferrolearn_linear::ElasticNetCV;
//! use ferrolearn_core::{Fit, Predict};
//! use ndarray::{Array1, Array2};
//!
//! let model = ElasticNetCV::<f64>::new();
//! let x = Array2::from_shape_vec((10, 1), (1..=10).map(|i| i as f64).collect()).unwrap();
//! let y = Array1::from_iter((1..=10).map(|i| 2.0 * i as f64 + 1.0));
//!
//! let fitted = model.fit(&x, &y).unwrap();
//! let preds = fitted.predict(&x).unwrap();
//! assert_eq!(preds.len(), 10);
//! ```
24
25use ferrolearn_core::error::FerroError;
26use ferrolearn_core::introspection::HasCoefficients;
27use ferrolearn_core::traits::{Fit, Predict};
28use ndarray::{Array1, Array2, Axis, ScalarOperand};
29use num_traits::{Float, FromPrimitive};
30
31use crate::ElasticNet;
32
33/// ElasticNet regression with built-in cross-validation for joint
34/// `(alpha, l1_ratio)` selection.
35///
36/// For each candidate `l1_ratio`, the module generates a log-spaced alpha
37/// grid (from `alpha_max` down to `alpha_max * 1e-3`) or uses the
38/// user-supplied grid, runs k-fold CV, and selects the combination that
39/// minimises mean squared error.
40///
41/// # Type Parameters
42///
43/// - `F`: The floating-point type (`f32` or `f64`).
44#[derive(Debug, Clone)]
45pub struct ElasticNetCV<F> {
46    /// Candidate L1/L2 mixing ratios.
47    l1_ratios: Vec<F>,
48    /// Number of alphas to generate per l1_ratio when no explicit grid
49    /// is supplied.
50    n_alphas: usize,
51    /// Number of cross-validation folds.
52    cv: usize,
53    /// Maximum coordinate descent iterations per ElasticNet fit.
54    max_iter: usize,
55    /// Convergence tolerance for coordinate descent.
56    tol: F,
57    /// Whether to fit an intercept (bias) term.
58    fit_intercept: bool,
59}
60
61impl<F: Float + FromPrimitive> ElasticNetCV<F> {
62    /// Create a new `ElasticNetCV` with default settings.
63    ///
64    /// Defaults:
65    /// - `l1_ratios = [0.1, 0.5, 0.7, 0.9, 0.95, 0.99, 1.0]`
66    /// - `n_alphas = 100`
67    /// - `cv = 5`
68    /// - `max_iter = 1000`
69    /// - `tol = 1e-4`
70    /// - `fit_intercept = true`
71    #[must_use]
72    pub fn new() -> Self {
73        Self {
74            l1_ratios: vec![
75                F::from(0.1).unwrap(),
76                F::from(0.5).unwrap(),
77                F::from(0.7).unwrap(),
78                F::from(0.9).unwrap(),
79                F::from(0.95).unwrap(),
80                F::from(0.99).unwrap(),
81                F::one(),
82            ],
83            n_alphas: 100,
84            cv: 5,
85            max_iter: 1000,
86            tol: F::from(1e-4).unwrap(),
87            fit_intercept: true,
88        }
89    }
90
91    /// Set the candidate L1/L2 mixing ratios.
92    ///
93    /// Each value must be in `[0.0, 1.0]`.
94    #[must_use]
95    pub fn with_l1_ratios(mut self, l1_ratios: Vec<F>) -> Self {
96        self.l1_ratios = l1_ratios;
97        self
98    }
99
100    /// Set the number of alphas generated per `l1_ratio`.
101    #[must_use]
102    pub fn with_n_alphas(mut self, n_alphas: usize) -> Self {
103        self.n_alphas = n_alphas;
104        self
105    }
106
107    /// Set the number of cross-validation folds.
108    ///
109    /// Must be at least 2.
110    #[must_use]
111    pub fn with_cv(mut self, cv: usize) -> Self {
112        self.cv = cv;
113        self
114    }
115
116    /// Set the maximum number of coordinate descent iterations.
117    #[must_use]
118    pub fn with_max_iter(mut self, max_iter: usize) -> Self {
119        self.max_iter = max_iter;
120        self
121    }
122
123    /// Set the convergence tolerance.
124    #[must_use]
125    pub fn with_tol(mut self, tol: F) -> Self {
126        self.tol = tol;
127        self
128    }
129
130    /// Set whether to fit an intercept term.
131    #[must_use]
132    pub fn with_fit_intercept(mut self, fit_intercept: bool) -> Self {
133        self.fit_intercept = fit_intercept;
134        self
135    }
136}
137
138impl<F: Float + FromPrimitive> Default for ElasticNetCV<F> {
139    fn default() -> Self {
140        Self::new()
141    }
142}
143
144/// Fitted ElasticNet model with cross-validated `(alpha, l1_ratio)`.
145///
146/// Stores the selected hyperparameters, learned coefficients, and
147/// intercept.
148#[derive(Debug, Clone)]
149pub struct FittedElasticNetCV<F> {
150    /// The alpha that achieved the lowest CV error.
151    best_alpha: F,
152    /// The l1_ratio that achieved the lowest CV error.
153    best_l1_ratio: F,
154    /// Learned coefficient vector (some may be exactly zero).
155    coefficients: Array1<F>,
156    /// Learned intercept (bias) term.
157    intercept: F,
158}
159
160impl<F: Float> FittedElasticNetCV<F> {
161    /// Returns the alpha value selected by cross-validation.
162    #[must_use]
163    pub fn best_alpha(&self) -> F {
164        self.best_alpha
165    }
166
167    /// Returns the l1_ratio selected by cross-validation.
168    #[must_use]
169    pub fn best_l1_ratio(&self) -> F {
170        self.best_l1_ratio
171    }
172}
173
/// Deal the sample indices `0..n_samples` round-robin into `k` folds.
///
/// Sample `i` lands in fold `i % k`, so fold sizes differ by at most one
/// and each fold lists its indices in increasing order.
fn kfold_indices(n_samples: usize, k: usize) -> Vec<Vec<usize>> {
    let mut buckets: Vec<Vec<usize>> = vec![Vec::new(); k];
    (0..n_samples).for_each(|sample| buckets[sample % k].push(sample));
    buckets
}
182
183/// Compute mean squared error between two arrays.
184fn mse<F: Float + FromPrimitive + 'static>(y_true: &Array1<F>, y_pred: &Array1<F>) -> F {
185    let n = F::from(y_true.len()).unwrap();
186    let diff = y_true - y_pred;
187    diff.dot(&diff) / n
188}
189
190/// Gather rows from a 2-D array by index.
191fn select_rows<F: Float>(x: &Array2<F>, indices: &[usize]) -> Array2<F> {
192    let ncols = x.ncols();
193    let mut out = Array2::<F>::zeros((indices.len(), ncols));
194    for (out_row, &idx) in indices.iter().enumerate() {
195        out.row_mut(out_row).assign(&x.row(idx));
196    }
197    out
198}
199
200/// Gather elements from a 1-D array by index.
201fn select_elements<F: Float>(y: &Array1<F>, indices: &[usize]) -> Array1<F> {
202    Array1::from_iter(indices.iter().map(|&i| y[i]))
203}
204
205/// Compute `alpha_max` for ElasticNet given a specific `l1_ratio`.
206///
207/// `alpha_max = max(|X^T y_centered|) / (n_samples * l1_ratio)`.
208/// When `l1_ratio == 0`, falls back to a large default.
209fn compute_alpha_max_enet<F: Float + FromPrimitive + ScalarOperand>(
210    x: &Array2<F>,
211    y: &Array1<F>,
212    l1_ratio: F,
213    fit_intercept: bool,
214) -> F {
215    let n = F::from(x.nrows()).unwrap();
216
217    let y_work = if fit_intercept {
218        let y_mean = y.mean().unwrap_or_else(F::zero);
219        y - y_mean
220    } else {
221        y.clone()
222    };
223
224    let x_work = if fit_intercept {
225        let x_mean = x.mean_axis(Axis(0)).unwrap();
226        x - &x_mean
227    } else {
228        x.clone()
229    };
230
231    let xty = x_work.t().dot(&y_work);
232    let mut max_abs = F::zero();
233    for &v in &xty {
234        let abs_v = v.abs();
235        if abs_v > max_abs {
236            max_abs = abs_v;
237        }
238    }
239
240    if l1_ratio > F::zero() {
241        max_abs / (n * l1_ratio)
242    } else {
243        // Pure Ridge case — use a reasonable default.
244        max_abs / n
245    }
246}
247
248/// Generate `n` log-spaced values from `high` down to `high * eps_ratio`.
249fn logspace<F: Float + FromPrimitive>(high: F, eps_ratio: F, n: usize) -> Vec<F> {
250    if n == 0 {
251        return Vec::new();
252    }
253    if n == 1 {
254        return vec![high];
255    }
256
257    let log_high = high.ln();
258    let log_low = (high * eps_ratio).ln();
259    let step = (log_low - log_high) / F::from(n - 1).unwrap();
260
261    (0..n)
262        .map(|i| (log_high + step * F::from(i).unwrap()).exp())
263        .collect()
264}
265
266impl<F: Float + Send + Sync + ScalarOperand + FromPrimitive + 'static> Fit<Array2<F>, Array1<F>>
267    for ElasticNetCV<F>
268{
269    type Fitted = FittedElasticNetCV<F>;
270    type Error = FerroError;
271
272    /// Fit the `ElasticNetCV` model.
273    ///
274    /// For each candidate `l1_ratio`, generates an alpha grid, runs k-fold
275    /// CV for every `(alpha, l1_ratio)` pair, then refits on the full data
276    /// using the best combination.
277    ///
278    /// # Errors
279    ///
280    /// - [`FerroError::ShapeMismatch`] if `x` and `y` sizes differ.
281    /// - [`FerroError::InvalidParameter`] if `l1_ratios` is empty, any ratio
282    ///   is outside `[0, 1]`, `cv < 2`, or `n_alphas == 0`.
283    /// - [`FerroError::InsufficientSamples`] if `n_samples < cv`.
284    fn fit(&self, x: &Array2<F>, y: &Array1<F>) -> Result<FittedElasticNetCV<F>, FerroError> {
285        let (n_samples, _n_features) = x.dim();
286
287        if n_samples != y.len() {
288            return Err(FerroError::ShapeMismatch {
289                expected: vec![n_samples],
290                actual: vec![y.len()],
291                context: "y length must match number of samples in X".into(),
292            });
293        }
294
295        if self.l1_ratios.is_empty() {
296            return Err(FerroError::InvalidParameter {
297                name: "l1_ratios".into(),
298                reason: "must contain at least one candidate".into(),
299            });
300        }
301
302        for &r in &self.l1_ratios {
303            if r < F::zero() || r > F::one() {
304                return Err(FerroError::InvalidParameter {
305                    name: "l1_ratios".into(),
306                    reason: "all l1_ratio values must be in [0, 1]".into(),
307                });
308            }
309        }
310
311        if self.cv < 2 {
312            return Err(FerroError::InvalidParameter {
313                name: "cv".into(),
314                reason: "number of folds must be at least 2".into(),
315            });
316        }
317
318        if n_samples < self.cv {
319            return Err(FerroError::InsufficientSamples {
320                required: self.cv,
321                actual: n_samples,
322                context: "ElasticNetCV requires at least as many samples as folds".into(),
323            });
324        }
325
326        if self.n_alphas == 0 {
327            return Err(FerroError::InvalidParameter {
328                name: "n_alphas".into(),
329                reason: "must be at least 1".into(),
330            });
331        }
332
333        let folds = kfold_indices(n_samples, self.cv);
334
335        let mut best_alpha = F::one();
336        let mut best_l1_ratio = self.l1_ratios[0];
337        let mut best_mse = F::infinity();
338
339        for &l1_ratio in &self.l1_ratios {
340            // Generate alpha grid for this l1_ratio.
341            let alpha_max = compute_alpha_max_enet(x, y, l1_ratio, self.fit_intercept);
342            let alpha_grid = if alpha_max <= F::zero() {
343                vec![F::from(1e-6).unwrap(); self.n_alphas]
344            } else {
345                logspace(alpha_max, F::from(1e-3).unwrap(), self.n_alphas)
346            };
347
348            for &alpha in &alpha_grid {
349                let mut total_mse = F::zero();
350
351                for fold_idx in 0..self.cv {
352                    let test_indices = &folds[fold_idx];
353                    let train_indices: Vec<usize> = folds
354                        .iter()
355                        .enumerate()
356                        .filter(|&(i, _)| i != fold_idx)
357                        .flat_map(|(_, v)| v.iter().copied())
358                        .collect();
359
360                    let x_train = select_rows(x, &train_indices);
361                    let y_train = select_elements(y, &train_indices);
362                    let x_test = select_rows(x, test_indices);
363                    let y_test = select_elements(y, test_indices);
364
365                    let model = ElasticNet::<F>::new()
366                        .with_alpha(alpha)
367                        .with_l1_ratio(l1_ratio)
368                        .with_max_iter(self.max_iter)
369                        .with_tol(self.tol)
370                        .with_fit_intercept(self.fit_intercept);
371
372                    let fitted = model.fit(&x_train, &y_train)?;
373                    let preds = fitted.predict(&x_test)?;
374                    total_mse = total_mse + mse(&y_test, &preds);
375                }
376
377                let avg_mse = total_mse / F::from(self.cv).unwrap();
378
379                if avg_mse < best_mse {
380                    best_mse = avg_mse;
381                    best_alpha = alpha;
382                    best_l1_ratio = l1_ratio;
383                }
384            }
385        }
386
387        // Refit on full data with the best hyperparameters.
388        let final_model = ElasticNet::<F>::new()
389            .with_alpha(best_alpha)
390            .with_l1_ratio(best_l1_ratio)
391            .with_max_iter(self.max_iter)
392            .with_tol(self.tol)
393            .with_fit_intercept(self.fit_intercept);
394        let final_fitted = final_model.fit(x, y)?;
395
396        Ok(FittedElasticNetCV {
397            best_alpha,
398            best_l1_ratio,
399            coefficients: final_fitted.coefficients().clone(),
400            intercept: final_fitted.intercept(),
401        })
402    }
403}
404
405impl<F: Float + Send + Sync + ScalarOperand + 'static> Predict<Array2<F>>
406    for FittedElasticNetCV<F>
407{
408    type Output = Array1<F>;
409    type Error = FerroError;
410
411    /// Predict target values for the given feature matrix.
412    ///
413    /// Computes `X @ coefficients + intercept`.
414    ///
415    /// # Errors
416    ///
417    /// Returns [`FerroError::ShapeMismatch`] if the number of features
418    /// does not match the fitted model.
419    fn predict(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
420        let n_features = x.ncols();
421        if n_features != self.coefficients.len() {
422            return Err(FerroError::ShapeMismatch {
423                expected: vec![self.coefficients.len()],
424                actual: vec![n_features],
425                context: "number of features must match fitted model".into(),
426            });
427        }
428
429        let preds = x.dot(&self.coefficients) + self.intercept;
430        Ok(preds)
431    }
432}
433
434impl<F: Float + Send + Sync + ScalarOperand + 'static> HasCoefficients<F>
435    for FittedElasticNetCV<F>
436{
437    fn coefficients(&self) -> &Array1<F> {
438        &self.coefficients
439    }
440
441    fn intercept(&self) -> F {
442        self.intercept
443    }
444}
445
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use ndarray::array;

    /// `new()` must expose every documented default, including `tol`.
    #[test]
    fn test_elastic_net_cv_default_builder() {
        let m = ElasticNetCV::<f64>::new();
        assert_eq!(m.l1_ratios.len(), 7);
        assert_eq!(m.n_alphas, 100);
        assert_eq!(m.cv, 5);
        assert_eq!(m.max_iter, 1000);
        assert_relative_eq!(m.tol, 1e-4);
        assert!(m.fit_intercept);
    }

    /// Each builder setter must overwrite exactly its own field.
    #[test]
    fn test_elastic_net_cv_builder_setters() {
        let m = ElasticNetCV::<f64>::new()
            .with_l1_ratios(vec![0.5, 0.9])
            .with_n_alphas(20)
            .with_cv(3)
            .with_max_iter(500)
            .with_tol(1e-6)
            .with_fit_intercept(false);
        assert_eq!(m.l1_ratios.len(), 2);
        assert_eq!(m.n_alphas, 20);
        assert_eq!(m.cv, 3);
        assert_eq!(m.max_iter, 500);
        assert_relative_eq!(m.tol, 1e-6);
        assert!(!m.fit_intercept);
    }

    /// Fitting clean linear data selects a positive alpha and an in-range ratio.
    #[test]
    fn test_elastic_net_cv_fit_selects_params() {
        let x = Array2::from_shape_vec((20, 1), (1..=20).map(f64::from).collect()).unwrap();
        let y = Array1::from_iter((1..=20).map(|i| 2.0 * f64::from(i) + 1.0));

        let model = ElasticNetCV::<f64>::new()
            .with_l1_ratios(vec![0.5, 0.9, 1.0])
            .with_n_alphas(10)
            .with_cv(3);

        let fitted = model.fit(&x, &y).unwrap();

        assert!(fitted.best_alpha() > 0.0);
        assert!(fitted.best_l1_ratio() >= 0.0);
        assert!(fitted.best_l1_ratio() <= 1.0);
    }

    /// Predictions on the training data stay close to the targets.
    #[test]
    fn test_elastic_net_cv_predict() {
        let x = Array2::from_shape_vec((10, 1), (1..=10).map(f64::from).collect()).unwrap();
        let y = Array1::from_iter((1..=10).map(|i| 2.0 * f64::from(i) + 1.0));

        let model = ElasticNetCV::<f64>::new()
            .with_l1_ratios(vec![0.5, 0.9])
            .with_n_alphas(10)
            .with_cv(3);
        let fitted = model.fit(&x, &y).unwrap();

        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 10);

        for (pred, target) in preds.iter().zip(y.iter()) {
            assert_relative_eq!(*pred, *target, epsilon = 2.0);
        }
    }

    /// The fitted model exposes one coefficient per feature.
    #[test]
    fn test_elastic_net_cv_has_coefficients() {
        let x = Array2::from_shape_vec((10, 2), (0..20).map(f64::from).collect()).unwrap();
        let y = Array1::from_iter((0..10).map(f64::from));

        let model = ElasticNetCV::<f64>::new()
            .with_l1_ratios(vec![0.5])
            .with_n_alphas(5)
            .with_cv(3);
        let fitted = model.fit(&x, &y).unwrap();

        assert_eq!(fitted.coefficients().len(), 2);
    }

    /// An empty candidate list is rejected as an invalid parameter.
    #[test]
    fn test_elastic_net_cv_empty_l1_ratios_error() {
        let x = Array2::from_shape_vec((5, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0, 5.0];

        let model = ElasticNetCV::<f64>::new().with_l1_ratios(vec![]);
        let result = model.fit(&x, &y);
        assert!(matches!(result, Err(FerroError::InvalidParameter { .. })));
    }

    /// A ratio above 1 is rejected as an invalid parameter.
    #[test]
    fn test_elastic_net_cv_invalid_l1_ratio_error() {
        let x = Array2::from_shape_vec((5, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0, 5.0];

        let model = ElasticNetCV::<f64>::new().with_l1_ratios(vec![0.5, 1.5]);
        let result = model.fit(&x, &y);
        assert!(matches!(result, Err(FerroError::InvalidParameter { .. })));
    }

    /// Mismatched `x`/`y` lengths are reported as a shape mismatch.
    #[test]
    fn test_elastic_net_cv_shape_mismatch() {
        let x = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        let y = array![1.0, 2.0];

        let model = ElasticNetCV::<f64>::new();
        let result = model.fit(&x, &y);
        assert!(matches!(result, Err(FerroError::ShapeMismatch { .. })));
    }

    /// Fewer samples than folds is reported as insufficient samples.
    #[test]
    fn test_elastic_net_cv_insufficient_samples() {
        let x = Array2::from_shape_vec((2, 1), vec![1.0, 2.0]).unwrap();
        let y = array![1.0, 2.0];

        let model = ElasticNetCV::<f64>::new().with_cv(5);
        let result = model.fit(&x, &y);
        assert!(matches!(
            result,
            Err(FerroError::InsufficientSamples { .. })
        ));
    }

    /// A single fold is rejected as an invalid parameter.
    #[test]
    fn test_elastic_net_cv_cv_too_small() {
        let x = Array2::from_shape_vec((10, 1), (1..=10).map(f64::from).collect()).unwrap();
        let y = Array1::from_iter((1..=10).map(f64::from));

        let model = ElasticNetCV::<f64>::new().with_cv(1);
        let result = model.fit(&x, &y);
        assert!(matches!(result, Err(FerroError::InvalidParameter { .. })));
    }

    /// Requesting zero alphas is rejected as an invalid parameter.
    #[test]
    fn test_elastic_net_cv_zero_n_alphas_error() {
        let x = Array2::from_shape_vec((10, 1), (1..=10).map(f64::from).collect()).unwrap();
        let y = Array1::from_iter((1..=10).map(f64::from));

        let model = ElasticNetCV::<f64>::new().with_n_alphas(0);
        let result = model.fit(&x, &y);
        assert!(matches!(result, Err(FerroError::InvalidParameter { .. })));
    }

    /// Predicting with the wrong feature count is reported as a shape mismatch.
    #[test]
    fn test_elastic_net_cv_predict_feature_mismatch() {
        let x_train = Array2::from_shape_vec((10, 2), (0..20).map(f64::from).collect()).unwrap();
        let y = Array1::from_iter((0..10).map(f64::from));

        let fitted = ElasticNetCV::<f64>::new()
            .with_l1_ratios(vec![0.5])
            .with_n_alphas(5)
            .with_cv(3)
            .fit(&x_train, &y)
            .unwrap();

        let x_bad = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        let result = fitted.predict(&x_bad);
        assert!(matches!(result, Err(FerroError::ShapeMismatch { .. })));
    }

    /// Fitting without an intercept still produces one prediction per row.
    #[test]
    fn test_elastic_net_cv_no_intercept() {
        let x = Array2::from_shape_vec((10, 1), (1..=10).map(f64::from).collect()).unwrap();
        let y = Array1::from_iter((1..=10).map(|i| 2.0 * f64::from(i)));

        let model = ElasticNetCV::<f64>::new()
            .with_l1_ratios(vec![0.5])
            .with_n_alphas(5)
            .with_cv(3)
            .with_fit_intercept(false);
        let fitted = model.fit(&x, &y).unwrap();

        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 10);
    }

    /// `l1_ratio = 0` (pure Ridge-like behaviour) is an accepted candidate.
    #[test]
    fn test_elastic_net_cv_pure_ridge_l1_ratio_zero() {
        let x = Array2::from_shape_vec((10, 1), (1..=10).map(f64::from).collect()).unwrap();
        let y = Array1::from_iter((1..=10).map(|i| 2.0 * f64::from(i) + 1.0));

        let model = ElasticNetCV::<f64>::new()
            .with_l1_ratios(vec![0.0, 0.5, 1.0])
            .with_n_alphas(5)
            .with_cv(3);
        let fitted = model.fit(&x, &y).unwrap();

        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 10);
    }
}