1use ferrolearn_core::error::FerroError;
26use ferrolearn_core::introspection::HasCoefficients;
27use ferrolearn_core::traits::{Fit, Predict};
28use ndarray::{Array1, Array2, Axis, ScalarOperand};
29use num_traits::{Float, FromPrimitive};
30
31use crate::ElasticNet;
32
/// Unfitted elastic-net regressor whose `alpha` and `l1_ratio` are selected
/// by k-fold cross-validation when it is fitted.
///
/// Values are normally created with `ElasticNetCV::new` and customised with
/// the `with_*` builder methods.
#[derive(Clone, Debug)]
pub struct ElasticNetCV<F> {
    /// Candidate L1/L2 mixing ratios to search over; each must lie in `[0, 1]`.
    l1_ratios: Vec<F>,
    /// Number of alpha candidates generated for each `l1_ratio`.
    n_alphas: usize,
    /// Number of cross-validation folds; fitting requires at least 2.
    cv: usize,
    /// Iteration cap passed to every underlying `ElasticNet` fit.
    max_iter: usize,
    /// Convergence tolerance passed to every underlying `ElasticNet` fit.
    tol: F,
    /// Whether the underlying models estimate an intercept term.
    fit_intercept: bool,
}
60
61impl<F: Float + FromPrimitive> ElasticNetCV<F> {
62 #[must_use]
72 pub fn new() -> Self {
73 Self {
74 l1_ratios: vec![
75 F::from(0.1).unwrap(),
76 F::from(0.5).unwrap(),
77 F::from(0.7).unwrap(),
78 F::from(0.9).unwrap(),
79 F::from(0.95).unwrap(),
80 F::from(0.99).unwrap(),
81 F::one(),
82 ],
83 n_alphas: 100,
84 cv: 5,
85 max_iter: 1000,
86 tol: F::from(1e-4).unwrap(),
87 fit_intercept: true,
88 }
89 }
90
91 #[must_use]
95 pub fn with_l1_ratios(mut self, l1_ratios: Vec<F>) -> Self {
96 self.l1_ratios = l1_ratios;
97 self
98 }
99
100 #[must_use]
102 pub fn with_n_alphas(mut self, n_alphas: usize) -> Self {
103 self.n_alphas = n_alphas;
104 self
105 }
106
107 #[must_use]
111 pub fn with_cv(mut self, cv: usize) -> Self {
112 self.cv = cv;
113 self
114 }
115
116 #[must_use]
118 pub fn with_max_iter(mut self, max_iter: usize) -> Self {
119 self.max_iter = max_iter;
120 self
121 }
122
123 #[must_use]
125 pub fn with_tol(mut self, tol: F) -> Self {
126 self.tol = tol;
127 self
128 }
129
130 #[must_use]
132 pub fn with_fit_intercept(mut self, fit_intercept: bool) -> Self {
133 self.fit_intercept = fit_intercept;
134 self
135 }
136}
137
138impl<F: Float + FromPrimitive> Default for ElasticNetCV<F> {
139 fn default() -> Self {
140 Self::new()
141 }
142}
143
144#[derive(Debug, Clone)]
149pub struct FittedElasticNetCV<F> {
150 best_alpha: F,
152 best_l1_ratio: F,
154 coefficients: Array1<F>,
156 intercept: F,
158}
159
160impl<F: Float> FittedElasticNetCV<F> {
161 #[must_use]
163 pub fn best_alpha(&self) -> F {
164 self.best_alpha
165 }
166
167 #[must_use]
169 pub fn best_l1_ratio(&self) -> F {
170 self.best_l1_ratio
171 }
172}
173
/// Splits the sample indices `0..n_samples` into `k` interleaved folds.
///
/// Fold `j` receives every index `i` with `i % k == j`, in increasing order,
/// so fold sizes differ by at most one. Folds beyond `n_samples` are empty.
///
/// Returns an empty vector when `k == 0` (there are no folds to fill) rather
/// than panicking on a remainder-by-zero.
fn kfold_indices(n_samples: usize, k: usize) -> Vec<Vec<usize>> {
    // `step_by(k)` is only reached when `k >= 1`, because `0..k` is empty otherwise.
    (0..k)
        .map(|fold| (fold..n_samples).step_by(k).collect())
        .collect()
}
182
183fn mse<F: Float + FromPrimitive + 'static>(y_true: &Array1<F>, y_pred: &Array1<F>) -> F {
185 let n = F::from(y_true.len()).unwrap();
186 let diff = y_true - y_pred;
187 diff.dot(&diff) / n
188}
189
190fn select_rows<F: Float>(x: &Array2<F>, indices: &[usize]) -> Array2<F> {
192 let ncols = x.ncols();
193 let mut out = Array2::<F>::zeros((indices.len(), ncols));
194 for (out_row, &idx) in indices.iter().enumerate() {
195 out.row_mut(out_row).assign(&x.row(idx));
196 }
197 out
198}
199
200fn select_elements<F: Float>(y: &Array1<F>, indices: &[usize]) -> Array1<F> {
202 Array1::from_iter(indices.iter().map(|&i| y[i]))
203}
204
205fn compute_alpha_max_enet<F: Float + FromPrimitive + ScalarOperand>(
210 x: &Array2<F>,
211 y: &Array1<F>,
212 l1_ratio: F,
213 fit_intercept: bool,
214) -> F {
215 let n = F::from(x.nrows()).unwrap();
216
217 let y_work = if fit_intercept {
218 let y_mean = y.mean().unwrap_or_else(F::zero);
219 y - y_mean
220 } else {
221 y.clone()
222 };
223
224 let x_work = if fit_intercept {
225 let x_mean = x.mean_axis(Axis(0)).unwrap();
226 x - &x_mean
227 } else {
228 x.clone()
229 };
230
231 let xty = x_work.t().dot(&y_work);
232 let mut max_abs = F::zero();
233 for &v in &xty {
234 let abs_v = v.abs();
235 if abs_v > max_abs {
236 max_abs = abs_v;
237 }
238 }
239
240 if l1_ratio > F::zero() {
241 max_abs / (n * l1_ratio)
242 } else {
243 max_abs / n
245 }
246}
247
248fn logspace<F: Float + FromPrimitive>(high: F, eps_ratio: F, n: usize) -> Vec<F> {
250 if n == 0 {
251 return Vec::new();
252 }
253 if n == 1 {
254 return vec![high];
255 }
256
257 let log_high = high.ln();
258 let log_low = (high * eps_ratio).ln();
259 let step = (log_low - log_high) / F::from(n - 1).unwrap();
260
261 (0..n)
262 .map(|i| (log_high + step * F::from(i).unwrap()).exp())
263 .collect()
264}
265
266impl<F: Float + Send + Sync + ScalarOperand + FromPrimitive + 'static> Fit<Array2<F>, Array1<F>>
267 for ElasticNetCV<F>
268{
269 type Fitted = FittedElasticNetCV<F>;
270 type Error = FerroError;
271
272 fn fit(&self, x: &Array2<F>, y: &Array1<F>) -> Result<FittedElasticNetCV<F>, FerroError> {
285 let (n_samples, _n_features) = x.dim();
286
287 if n_samples != y.len() {
288 return Err(FerroError::ShapeMismatch {
289 expected: vec![n_samples],
290 actual: vec![y.len()],
291 context: "y length must match number of samples in X".into(),
292 });
293 }
294
295 if self.l1_ratios.is_empty() {
296 return Err(FerroError::InvalidParameter {
297 name: "l1_ratios".into(),
298 reason: "must contain at least one candidate".into(),
299 });
300 }
301
302 for &r in &self.l1_ratios {
303 if r < F::zero() || r > F::one() {
304 return Err(FerroError::InvalidParameter {
305 name: "l1_ratios".into(),
306 reason: "all l1_ratio values must be in [0, 1]".into(),
307 });
308 }
309 }
310
311 if self.cv < 2 {
312 return Err(FerroError::InvalidParameter {
313 name: "cv".into(),
314 reason: "number of folds must be at least 2".into(),
315 });
316 }
317
318 if n_samples < self.cv {
319 return Err(FerroError::InsufficientSamples {
320 required: self.cv,
321 actual: n_samples,
322 context: "ElasticNetCV requires at least as many samples as folds".into(),
323 });
324 }
325
326 if self.n_alphas == 0 {
327 return Err(FerroError::InvalidParameter {
328 name: "n_alphas".into(),
329 reason: "must be at least 1".into(),
330 });
331 }
332
333 let folds = kfold_indices(n_samples, self.cv);
334
335 let mut best_alpha = F::one();
336 let mut best_l1_ratio = self.l1_ratios[0];
337 let mut best_mse = F::infinity();
338
339 for &l1_ratio in &self.l1_ratios {
340 let alpha_max = compute_alpha_max_enet(x, y, l1_ratio, self.fit_intercept);
342 let alpha_grid = if alpha_max <= F::zero() {
343 vec![F::from(1e-6).unwrap(); self.n_alphas]
344 } else {
345 logspace(alpha_max, F::from(1e-3).unwrap(), self.n_alphas)
346 };
347
348 for &alpha in &alpha_grid {
349 let mut total_mse = F::zero();
350
351 for fold_idx in 0..self.cv {
352 let test_indices = &folds[fold_idx];
353 let train_indices: Vec<usize> = folds
354 .iter()
355 .enumerate()
356 .filter(|&(i, _)| i != fold_idx)
357 .flat_map(|(_, v)| v.iter().copied())
358 .collect();
359
360 let x_train = select_rows(x, &train_indices);
361 let y_train = select_elements(y, &train_indices);
362 let x_test = select_rows(x, test_indices);
363 let y_test = select_elements(y, test_indices);
364
365 let model = ElasticNet::<F>::new()
366 .with_alpha(alpha)
367 .with_l1_ratio(l1_ratio)
368 .with_max_iter(self.max_iter)
369 .with_tol(self.tol)
370 .with_fit_intercept(self.fit_intercept);
371
372 let fitted = model.fit(&x_train, &y_train)?;
373 let preds = fitted.predict(&x_test)?;
374 total_mse = total_mse + mse(&y_test, &preds);
375 }
376
377 let avg_mse = total_mse / F::from(self.cv).unwrap();
378
379 if avg_mse < best_mse {
380 best_mse = avg_mse;
381 best_alpha = alpha;
382 best_l1_ratio = l1_ratio;
383 }
384 }
385 }
386
387 let final_model = ElasticNet::<F>::new()
389 .with_alpha(best_alpha)
390 .with_l1_ratio(best_l1_ratio)
391 .with_max_iter(self.max_iter)
392 .with_tol(self.tol)
393 .with_fit_intercept(self.fit_intercept);
394 let final_fitted = final_model.fit(x, y)?;
395
396 Ok(FittedElasticNetCV {
397 best_alpha,
398 best_l1_ratio,
399 coefficients: final_fitted.coefficients().clone(),
400 intercept: final_fitted.intercept(),
401 })
402 }
403}
404
405impl<F: Float + Send + Sync + ScalarOperand + 'static> Predict<Array2<F>>
406 for FittedElasticNetCV<F>
407{
408 type Output = Array1<F>;
409 type Error = FerroError;
410
411 fn predict(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
420 let n_features = x.ncols();
421 if n_features != self.coefficients.len() {
422 return Err(FerroError::ShapeMismatch {
423 expected: vec![self.coefficients.len()],
424 actual: vec![n_features],
425 context: "number of features must match fitted model".into(),
426 });
427 }
428
429 let preds = x.dot(&self.coefficients) + self.intercept;
430 Ok(preds)
431 }
432}
433
434impl<F: Float + Send + Sync + ScalarOperand + 'static> HasCoefficients<F>
435 for FittedElasticNetCV<F>
436{
437 fn coefficients(&self) -> &Array1<F> {
438 &self.coefficients
439 }
440
441 fn intercept(&self) -> F {
442 self.intercept
443 }
444}
445
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use ndarray::array;

    /// Builds an `(n, 1)` design matrix holding `1.0..=n` together with the
    /// target `slope * x + intercept`.
    fn linear_problem(n: i32, slope: f64, intercept: f64) -> (Array2<f64>, Array1<f64>) {
        let values: Vec<f64> = (1..=n).map(f64::from).collect();
        let y = values
            .iter()
            .map(|&v| slope * v + intercept)
            .collect::<Array1<f64>>();
        let x = Array2::from_shape_vec((values.len(), 1), values).unwrap();
        (x, y)
    }

    /// Builds a `(10, 2)` design matrix holding `0.0..20.0` row-major and the
    /// target `0.0..10.0`.
    fn two_feature_problem() -> (Array2<f64>, Array1<f64>) {
        let x = Array2::from_shape_vec((10, 2), (0..20).map(f64::from).collect()).unwrap();
        let y = Array1::from_iter((0..10).map(f64::from));
        (x, y)
    }

    /// Five-sample single-feature problem with `y == x`.
    fn tiny_problem() -> (Array2<f64>, Array1<f64>) {
        let x = Array2::from_shape_vec((5, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0, 5.0];
        (x, y)
    }

    #[test]
    fn test_elastic_net_cv_default_builder() {
        let m = ElasticNetCV::<f64>::new();
        assert_eq!(m.l1_ratios.len(), 7);
        assert_eq!(m.n_alphas, 100);
        assert_eq!(m.cv, 5);
        assert_eq!(m.max_iter, 1000);
        assert!(m.fit_intercept);
    }

    #[test]
    fn test_elastic_net_cv_builder_setters() {
        let m = ElasticNetCV::<f64>::new()
            .with_l1_ratios(vec![0.5, 0.9])
            .with_n_alphas(20)
            .with_cv(3)
            .with_max_iter(500)
            .with_tol(1e-6)
            .with_fit_intercept(false);
        assert_eq!(m.l1_ratios.len(), 2);
        assert_eq!(m.n_alphas, 20);
        assert_eq!(m.cv, 3);
        assert_eq!(m.max_iter, 500);
        assert!(!m.fit_intercept);
    }

    #[test]
    fn test_elastic_net_cv_fit_selects_params() {
        let (x, y) = linear_problem(20, 2.0, 1.0);

        let model = ElasticNetCV::<f64>::new()
            .with_l1_ratios(vec![0.5, 0.9, 1.0])
            .with_n_alphas(10)
            .with_cv(3);

        let fitted = model.fit(&x, &y).unwrap();

        assert!(fitted.best_alpha() > 0.0);
        assert!(fitted.best_l1_ratio() >= 0.0);
        assert!(fitted.best_l1_ratio() <= 1.0);
    }

    #[test]
    fn test_elastic_net_cv_predict() {
        let (x, y) = linear_problem(10, 2.0, 1.0);

        let model = ElasticNetCV::<f64>::new()
            .with_l1_ratios(vec![0.5, 0.9])
            .with_n_alphas(10)
            .with_cv(3);
        let fitted = model.fit(&x, &y).unwrap();

        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 10);

        for (&pred, &target) in preds.iter().zip(y.iter()) {
            assert_relative_eq!(pred, target, epsilon = 2.0);
        }
    }

    #[test]
    fn test_elastic_net_cv_has_coefficients() {
        let (x, y) = two_feature_problem();

        let model = ElasticNetCV::<f64>::new()
            .with_l1_ratios(vec![0.5])
            .with_n_alphas(5)
            .with_cv(3);
        let fitted = model.fit(&x, &y).unwrap();

        assert_eq!(fitted.coefficients().len(), 2);
    }

    #[test]
    fn test_elastic_net_cv_empty_l1_ratios_error() {
        let (x, y) = tiny_problem();

        let model = ElasticNetCV::<f64>::new().with_l1_ratios(vec![]);
        let result = model.fit(&x, &y);
        // Pin the variant so an unrelated failure cannot satisfy the test.
        assert!(matches!(result, Err(FerroError::InvalidParameter { .. })));
    }

    #[test]
    fn test_elastic_net_cv_invalid_l1_ratio_error() {
        let (x, y) = tiny_problem();

        let model = ElasticNetCV::<f64>::new().with_l1_ratios(vec![0.5, 1.5]);
        let result = model.fit(&x, &y);
        assert!(matches!(result, Err(FerroError::InvalidParameter { .. })));
    }

    #[test]
    fn test_elastic_net_cv_shape_mismatch() {
        let x = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        let y = array![1.0, 2.0];

        let model = ElasticNetCV::<f64>::new();
        let result = model.fit(&x, &y);
        assert!(matches!(result, Err(FerroError::ShapeMismatch { .. })));
    }

    #[test]
    fn test_elastic_net_cv_insufficient_samples() {
        let x = Array2::from_shape_vec((2, 1), vec![1.0, 2.0]).unwrap();
        let y = array![1.0, 2.0];

        let model = ElasticNetCV::<f64>::new().with_cv(5);
        let result = model.fit(&x, &y);
        assert!(matches!(
            result,
            Err(FerroError::InsufficientSamples { .. })
        ));
    }

    #[test]
    fn test_elastic_net_cv_cv_too_small() {
        let (x, y) = linear_problem(10, 1.0, 0.0);

        let model = ElasticNetCV::<f64>::new().with_cv(1);
        let result = model.fit(&x, &y);
        assert!(matches!(result, Err(FerroError::InvalidParameter { .. })));
    }

    #[test]
    fn test_elastic_net_cv_predict_feature_mismatch() {
        let (x_train, y) = two_feature_problem();

        let fitted = ElasticNetCV::<f64>::new()
            .with_l1_ratios(vec![0.5])
            .with_n_alphas(5)
            .with_cv(3)
            .fit(&x_train, &y)
            .unwrap();

        let x_bad = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        let result = fitted.predict(&x_bad);
        assert!(matches!(result, Err(FerroError::ShapeMismatch { .. })));
    }

    #[test]
    fn test_elastic_net_cv_no_intercept() {
        let (x, y) = linear_problem(10, 2.0, 0.0);

        let model = ElasticNetCV::<f64>::new()
            .with_l1_ratios(vec![0.5])
            .with_n_alphas(5)
            .with_cv(3)
            .with_fit_intercept(false);
        let fitted = model.fit(&x, &y).unwrap();

        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 10);
    }

    #[test]
    fn test_elastic_net_cv_pure_ridge_l1_ratio_zero() {
        let (x, y) = linear_problem(10, 2.0, 1.0);

        // An l1_ratio of exactly zero must be accepted as a valid candidate.
        let model = ElasticNetCV::<f64>::new()
            .with_l1_ratios(vec![0.0, 0.5, 1.0])
            .with_n_alphas(5)
            .with_cv(3);
        let fitted = model.fit(&x, &y).unwrap();

        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 10);
    }
}