1use ferrolearn_core::error::FerroError;
35use ferrolearn_core::introspection::HasCoefficients;
36use ferrolearn_core::pipeline::{FittedPipelineEstimator, PipelineEstimator};
37use ferrolearn_core::traits::{Fit, Predict};
38use ndarray::{Array1, Array2, Axis, ScalarOperand};
39use num_traits::{Float, FromPrimitive};
40
/// Linear regression with combined L1 and L2 regularisation (elastic net).
///
/// Fitting runs cyclic coordinate descent. The per-coefficient update it
/// applies,
/// `w_j = S(rho_j, alpha * l1_ratio) / (||x_j||^2 / n + alpha * (1 - l1_ratio))`
/// with `S` the soft-thresholding operator, is the exact minimiser along
/// coordinate `j` of
///
/// ```text
/// (1 / (2 * n)) * ||y - X w||^2
///     + alpha * l1_ratio * ||w||_1
///     + (alpha * (1 - l1_ratio) / 2) * ||w||^2
/// ```
///
/// where `n` is the number of samples. `l1_ratio = 1` gives a pure lasso
/// penalty and `l1_ratio = 0` a pure ridge penalty.
///
/// Use [`ElasticNet::new`] (or `Default`) for the defaults and the `with_*`
/// builder methods to change individual settings.
#[derive(Debug, Clone)]
pub struct ElasticNet<F> {
    /// Overall regularisation strength. Must be non-negative; `fit` rejects
    /// negative values. Default: `1.0`.
    pub alpha: F,
    /// Share of `alpha` given to the L1 penalty; the remaining
    /// `1 - l1_ratio` goes to the L2 penalty. Must lie in `[0, 1]`.
    /// Default: `0.5`.
    pub l1_ratio: F,
    /// Maximum number of full coordinate-descent passes over all features.
    /// Default: `1000`.
    pub max_iter: usize,
    /// Convergence threshold: fitting stops after a pass in which the largest
    /// absolute coefficient change is below this value. Default: `1e-4`.
    pub tol: F,
    /// Whether to fit an intercept by centring `X` and `y` on their means
    /// before the descent. When `false` the intercept is zero.
    /// Default: `true`.
    pub fit_intercept: bool,
}
67
68impl<F: Float + FromPrimitive> ElasticNet<F> {
69 #[must_use]
74 pub fn new() -> Self {
75 Self {
76 alpha: F::one(),
77 l1_ratio: F::from(0.5).unwrap(),
78 max_iter: 1000,
79 tol: F::from(1e-4).unwrap(),
80 fit_intercept: true,
81 }
82 }
83
84 #[must_use]
86 pub fn with_alpha(mut self, alpha: F) -> Self {
87 self.alpha = alpha;
88 self
89 }
90
91 #[must_use]
96 pub fn with_l1_ratio(mut self, l1_ratio: F) -> Self {
97 self.l1_ratio = l1_ratio;
98 self
99 }
100
101 #[must_use]
103 pub fn with_max_iter(mut self, max_iter: usize) -> Self {
104 self.max_iter = max_iter;
105 self
106 }
107
108 #[must_use]
110 pub fn with_tol(mut self, tol: F) -> Self {
111 self.tol = tol;
112 self
113 }
114
115 #[must_use]
117 pub fn with_fit_intercept(mut self, fit_intercept: bool) -> Self {
118 self.fit_intercept = fit_intercept;
119 self
120 }
121}
122
123impl<F: Float + FromPrimitive> Default for ElasticNet<F> {
124 fn default() -> Self {
125 Self::new()
126 }
127}
128
/// A fitted elastic-net model, produced by fitting an [`ElasticNet`].
///
/// Predictions are computed as `x.dot(coefficients) + intercept`.
#[derive(Debug, Clone)]
pub struct FittedElasticNet<F> {
    /// Learned weights, one per feature column of the training matrix.
    coefficients: Array1<F>,
    /// Learned bias term; zero when the model was fitted with
    /// `fit_intercept = false`.
    intercept: F,
}
140
141impl<F: Float> FittedElasticNet<F> {
142 pub fn intercept(&self) -> F {
144 self.intercept
145 }
146}
147
148#[inline]
152fn soft_threshold<F: Float>(x: F, threshold: F) -> F {
153 if x > threshold {
154 x - threshold
155 } else if x < -threshold {
156 x + threshold
157 } else {
158 F::zero()
159 }
160}
161
162impl<F: Float + Send + Sync + ScalarOperand + FromPrimitive + 'static> Fit<Array2<F>, Array1<F>>
163 for ElasticNet<F>
164{
165 type Fitted = FittedElasticNet<F>;
166 type Error = FerroError;
167
168 fn fit(&self, x: &Array2<F>, y: &Array1<F>) -> Result<FittedElasticNet<F>, FerroError> {
181 let (n_samples, n_features) = x.dim();
182
183 if n_samples != y.len() {
184 return Err(FerroError::ShapeMismatch {
185 expected: vec![n_samples],
186 actual: vec![y.len()],
187 context: "y length must match number of samples in X".into(),
188 });
189 }
190
191 if self.alpha < F::zero() {
192 return Err(FerroError::InvalidParameter {
193 name: "alpha".into(),
194 reason: "must be non-negative".into(),
195 });
196 }
197
198 if self.l1_ratio < F::zero() || self.l1_ratio > F::one() {
199 return Err(FerroError::InvalidParameter {
200 name: "l1_ratio".into(),
201 reason: "must be in [0, 1]".into(),
202 });
203 }
204
205 if n_samples == 0 {
206 return Err(FerroError::InsufficientSamples {
207 required: 1,
208 actual: 0,
209 context: "ElasticNet requires at least one sample".into(),
210 });
211 }
212
213 let n_f = F::from(n_samples).unwrap();
214
215 let (x_work, y_work, x_mean, y_mean) = if self.fit_intercept {
217 let x_mean = x
218 .mean_axis(Axis(0))
219 .ok_or_else(|| FerroError::NumericalInstability {
220 message: "failed to compute column means".into(),
221 })?;
222 let y_mean = y.mean().ok_or_else(|| FerroError::NumericalInstability {
223 message: "failed to compute target mean".into(),
224 })?;
225
226 let x_c = x - &x_mean;
227 let y_c = y - y_mean;
228 (x_c, y_c, Some(x_mean), Some(y_mean))
229 } else {
230 (x.clone(), y.clone(), None, None)
231 };
232
233 let col_norms: Vec<F> = (0..n_features)
235 .map(|j| {
236 let col = x_work.column(j);
237 col.dot(&col) / n_f
238 })
239 .collect();
240
241 let alpha_l1 = self.alpha * self.l1_ratio;
243 let alpha_l2 = self.alpha * (F::one() - self.l1_ratio);
244
245 let denominators: Vec<F> = col_norms.iter().map(|&cn| cn + alpha_l2).collect();
247
248 let mut w = Array1::<F>::zeros(n_features);
249 let mut residual = y_work.clone();
250
251 for _iter in 0..self.max_iter {
252 let mut max_change = F::zero();
253
254 for j in 0..n_features {
255 let col_j = x_work.column(j);
256 let w_old = w[j];
257
258 if w_old != F::zero() {
260 for i in 0..n_samples {
261 residual[i] = residual[i] + col_j[i] * w_old;
262 }
263 }
264
265 let rho_j = col_j.dot(&residual) / n_f;
267
268 let w_new = if denominators[j] > F::zero() {
270 soft_threshold(rho_j, alpha_l1) / denominators[j]
271 } else {
272 F::zero()
273 };
274
275 if w_new != F::zero() {
277 for i in 0..n_samples {
278 residual[i] = residual[i] - col_j[i] * w_new;
279 }
280 }
281
282 let change = (w_new - w_old).abs();
283 if change > max_change {
284 max_change = change;
285 }
286
287 w[j] = w_new;
288 }
289
290 if max_change < self.tol {
291 let intercept = compute_intercept(&x_mean, &y_mean, &w);
292 return Ok(FittedElasticNet {
293 coefficients: w,
294 intercept,
295 });
296 }
297 }
298
299 let intercept = compute_intercept(&x_mean, &y_mean, &w);
301 Ok(FittedElasticNet {
302 coefficients: w,
303 intercept,
304 })
305 }
306}
307
308fn compute_intercept<F: Float + 'static>(
310 x_mean: &Option<Array1<F>>,
311 y_mean: &Option<F>,
312 w: &Array1<F>,
313) -> F {
314 if let (Some(xm), Some(ym)) = (x_mean, y_mean) {
315 *ym - xm.dot(w)
316 } else {
317 F::zero()
318 }
319}
320
321impl<F: Float + Send + Sync + ScalarOperand + 'static> Predict<Array2<F>> for FittedElasticNet<F> {
322 type Output = Array1<F>;
323 type Error = FerroError;
324
325 fn predict(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
334 let n_features = x.ncols();
335 if n_features != self.coefficients.len() {
336 return Err(FerroError::ShapeMismatch {
337 expected: vec![self.coefficients.len()],
338 actual: vec![n_features],
339 context: "number of features must match fitted model".into(),
340 });
341 }
342
343 let preds = x.dot(&self.coefficients) + self.intercept;
344 Ok(preds)
345 }
346}
347
348impl<F: Float + Send + Sync + ScalarOperand + 'static> HasCoefficients<F> for FittedElasticNet<F> {
349 fn coefficients(&self) -> &Array1<F> {
351 &self.coefficients
352 }
353
354 fn intercept(&self) -> F {
356 self.intercept
357 }
358}
359
360impl<F> PipelineEstimator<F> for ElasticNet<F>
362where
363 F: Float + FromPrimitive + ScalarOperand + Send + Sync + 'static,
364{
365 fn fit_pipeline(
371 &self,
372 x: &Array2<F>,
373 y: &Array1<F>,
374 ) -> Result<Box<dyn FittedPipelineEstimator<F>>, FerroError> {
375 let fitted = self.fit(x, y)?;
376 Ok(Box::new(fitted))
377 }
378}
379
380impl<F> FittedPipelineEstimator<F> for FittedElasticNet<F>
381where
382 F: Float + ScalarOperand + Send + Sync + 'static,
383{
384 fn predict_pipeline(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
390 self.predict(x)
391 }
392}
393
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use ndarray::array;

    // --- soft_threshold -----------------------------------------------------

    #[test]
    fn test_soft_threshold_positive() {
        assert_relative_eq!(soft_threshold(5.0_f64, 1.0), 4.0);
    }

    #[test]
    fn test_soft_threshold_negative() {
        assert_relative_eq!(soft_threshold(-5.0_f64, 1.0), -4.0);
    }

    #[test]
    fn test_soft_threshold_within_band() {
        assert_relative_eq!(soft_threshold(0.5_f64, 1.0), 0.0);
        assert_relative_eq!(soft_threshold(-0.5_f64, 1.0), 0.0);
        assert_relative_eq!(soft_threshold(0.0_f64, 1.0), 0.0);
    }

    #[test]
    fn test_soft_threshold_at_band_edges() {
        // Values exactly on the threshold are zeroed, not shrunk past zero.
        assert_relative_eq!(soft_threshold(1.0_f64, 1.0), 0.0);
        assert_relative_eq!(soft_threshold(-1.0_f64, 1.0), 0.0);
    }

    // --- builder ------------------------------------------------------------

    #[test]
    fn test_default_builder() {
        let m = ElasticNet::<f64>::new();
        assert_relative_eq!(m.alpha, 1.0);
        assert_relative_eq!(m.l1_ratio, 0.5);
        assert_eq!(m.max_iter, 1000);
        assert_relative_eq!(m.tol, 1e-4);
        assert!(m.fit_intercept);
    }

    #[test]
    fn test_default_trait_matches_new() {
        let d = ElasticNet::<f64>::default();
        let n = ElasticNet::<f64>::new();
        assert_relative_eq!(d.alpha, n.alpha);
        assert_relative_eq!(d.l1_ratio, n.l1_ratio);
        assert_eq!(d.max_iter, n.max_iter);
        assert_relative_eq!(d.tol, n.tol);
        assert_eq!(d.fit_intercept, n.fit_intercept);
    }

    #[test]
    fn test_builder_setters() {
        let m = ElasticNet::<f64>::new()
            .with_alpha(0.5)
            .with_l1_ratio(0.2)
            .with_max_iter(500)
            .with_tol(1e-6)
            .with_fit_intercept(false);
        assert_relative_eq!(m.alpha, 0.5);
        assert_relative_eq!(m.l1_ratio, 0.2);
        assert_eq!(m.max_iter, 500);
        assert_relative_eq!(m.tol, 1e-6);
        assert!(!m.fit_intercept);
    }

    // --- input / parameter validation ---------------------------------------

    #[test]
    fn test_negative_alpha_error() {
        let x = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        let y = array![1.0, 2.0, 3.0];
        let result = ElasticNet::<f64>::new().with_alpha(-1.0).fit(&x, &y);
        assert!(matches!(result, Err(FerroError::InvalidParameter { .. })));
    }

    #[test]
    fn test_l1_ratio_out_of_range_error() {
        let x = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        let y = array![1.0, 2.0, 3.0];
        for &bad in &[1.5_f64, -0.1] {
            let result = ElasticNet::<f64>::new().with_l1_ratio(bad).fit(&x, &y);
            assert!(matches!(result, Err(FerroError::InvalidParameter { .. })));
        }
    }

    #[test]
    fn test_shape_mismatch_error() {
        let x = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        let y = array![1.0, 2.0];
        let result = ElasticNet::<f64>::new().fit(&x, &y);
        assert!(matches!(result, Err(FerroError::ShapeMismatch { .. })));
    }

    #[test]
    fn test_zero_samples_error() {
        let x = Array2::<f64>::zeros((0, 2));
        let y = Array1::<f64>::zeros(0);
        let result = ElasticNet::<f64>::new().fit(&x, &y);
        assert!(matches!(
            result,
            Err(FerroError::InsufficientSamples { .. })
        ));
    }

    // --- fitting behaviour --------------------------------------------------

    #[test]
    fn test_lasso_limit_l1_ratio_one() {
        // y = 2x + 1 exactly; with no regularisation the fit is exact.
        let x = Array2::from_shape_vec((5, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0]).unwrap();
        let y = array![3.0, 5.0, 7.0, 9.0, 11.0];

        let model = ElasticNet::<f64>::new().with_alpha(0.0).with_l1_ratio(1.0);
        let fitted = model.fit(&x, &y).unwrap();

        assert_relative_eq!(fitted.coefficients()[0], 2.0, epsilon = 1e-4);
        assert_relative_eq!(fitted.intercept(), 1.0, epsilon = 1e-4);
    }

    #[test]
    fn test_ridge_limit_l1_ratio_zero() {
        // y = 2x + 1 exactly; with no regularisation the fit is exact.
        let x = Array2::from_shape_vec((5, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0]).unwrap();
        let y = array![3.0, 5.0, 7.0, 9.0, 11.0];

        let model = ElasticNet::<f64>::new().with_alpha(0.0).with_l1_ratio(0.0);
        let fitted = model.fit(&x, &y).unwrap();

        assert_relative_eq!(fitted.coefficients()[0], 2.0, epsilon = 1e-4);
        assert_relative_eq!(fitted.intercept(), 1.0, epsilon = 1e-4);
    }

    #[test]
    fn test_sparsity_with_high_l1_ratio() {
        // Only the first column carries signal; the other two are all zeros.
        let x = Array2::from_shape_vec(
            (10, 3),
            vec![
                1.0, 0.0, 0.0, 2.0, 0.0, 0.0, 3.0, 0.0, 0.0, 4.0, 0.0, 0.0, 5.0, 0.0, 0.0, 6.0,
                0.0, 0.0, 7.0, 0.0, 0.0, 8.0, 0.0, 0.0, 9.0, 0.0, 0.0, 10.0, 0.0, 0.0,
            ],
        )
        .unwrap();
        let y = array![2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0, 18.0, 20.0];

        let model = ElasticNet::<f64>::new().with_alpha(5.0).with_l1_ratio(1.0);
        let fitted = model.fit(&x, &y).unwrap();

        assert_relative_eq!(fitted.coefficients()[1], 0.0, epsilon = 1e-10);
        assert_relative_eq!(fitted.coefficients()[2], 0.0, epsilon = 1e-10);
    }

    #[test]
    fn test_large_alpha_zeroes_coefficient() {
        // A strong enough pure-L1 penalty zeroes the only coefficient, so the
        // intercept falls back to the mean of y.
        let x = Array2::from_shape_vec((5, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0]).unwrap();
        let y = array![3.0, 5.0, 7.0, 9.0, 11.0];

        let fitted = ElasticNet::<f64>::new()
            .with_alpha(100.0)
            .with_l1_ratio(1.0)
            .fit(&x, &y)
            .unwrap();

        assert_relative_eq!(fitted.coefficients()[0], 0.0, epsilon = 1e-12);
        assert_relative_eq!(fitted.intercept(), 7.0, epsilon = 1e-12);
    }

    #[test]
    fn test_higher_alpha_shrinks_more() {
        let x = Array2::from_shape_vec((5, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0]).unwrap();
        let y = array![3.0, 5.0, 7.0, 9.0, 11.0];

        let low = ElasticNet::<f64>::new()
            .with_alpha(0.01)
            .with_l1_ratio(0.5)
            .fit(&x, &y)
            .unwrap();
        let high = ElasticNet::<f64>::new()
            .with_alpha(2.0)
            .with_l1_ratio(0.5)
            .fit(&x, &y)
            .unwrap();

        assert!(high.coefficients()[0].abs() <= low.coefficients()[0].abs());
    }

    #[test]
    fn test_no_intercept() {
        // y = 2x through the origin: the slope is recovered exactly and the
        // intercept stays at zero.
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![2.0, 4.0, 6.0, 8.0];

        let fitted = ElasticNet::<f64>::new()
            .with_alpha(0.0)
            .with_l1_ratio(0.5)
            .with_fit_intercept(false)
            .fit(&x, &y)
            .unwrap();

        assert_relative_eq!(fitted.intercept(), 0.0, epsilon = 1e-10);
        assert_relative_eq!(fitted.coefficients()[0], 2.0, epsilon = 1e-6);
    }

    // --- prediction ---------------------------------------------------------

    #[test]
    fn test_predict_correct_length() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![2.0, 4.0, 6.0, 8.0];

        let fitted = ElasticNet::<f64>::new()
            .with_alpha(0.01)
            .fit(&x, &y)
            .unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 4);
    }

    #[test]
    fn test_predict_values_unregularised() {
        // The model fitted on y = 2x + 1 must extrapolate along the same line.
        let x = Array2::from_shape_vec((5, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0]).unwrap();
        let y = array![3.0, 5.0, 7.0, 9.0, 11.0];

        let fitted = ElasticNet::<f64>::new()
            .with_alpha(0.0)
            .with_l1_ratio(1.0)
            .fit(&x, &y)
            .unwrap();

        let x_new = Array2::from_shape_vec((2, 1), vec![6.0, 0.0]).unwrap();
        let preds = fitted.predict(&x_new).unwrap();
        assert_relative_eq!(preds[0], 13.0, epsilon = 1e-4);
        assert_relative_eq!(preds[1], 1.0, epsilon = 1e-4);
    }

    #[test]
    fn test_predict_feature_mismatch() {
        let x_train = Array2::from_shape_vec((3, 2), vec![1.0, 0.0, 2.0, 0.0, 3.0, 0.0]).unwrap();
        let y = array![1.0, 2.0, 3.0];
        let fitted = ElasticNet::<f64>::new()
            .with_alpha(0.01)
            .fit(&x_train, &y)
            .unwrap();

        let x_bad = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        let result = fitted.predict(&x_bad);
        assert!(matches!(result, Err(FerroError::ShapeMismatch { .. })));
    }

    #[test]
    fn test_has_coefficients_length() {
        let x = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![1.0, 2.0, 3.0];
        let fitted = ElasticNet::<f64>::new()
            .with_alpha(0.1)
            .fit(&x, &y)
            .unwrap();

        assert_eq!(fitted.coefficients().len(), 2);
    }

    // --- pipeline -----------------------------------------------------------

    #[test]
    fn test_pipeline_integration() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![3.0, 5.0, 7.0, 9.0];

        let model = ElasticNet::<f64>::new().with_alpha(0.01);
        let fitted_pipe = model.fit_pipeline(&x, &y).unwrap();
        let preds = fitted_pipe.predict_pipeline(&x).unwrap();
        assert_eq!(preds.len(), 4);

        // The pipeline path must agree with fitting and predicting directly.
        let direct = model.fit(&x, &y).unwrap().predict(&x).unwrap();
        for (p, d) in preds.iter().zip(direct.iter()) {
            assert_relative_eq!(*p, *d);
        }
    }
}