1use ferrolearn_core::error::FerroError;
35use ferrolearn_core::introspection::HasCoefficients;
36use ferrolearn_core::pipeline::{FittedPipelineEstimator, PipelineEstimator};
37use ferrolearn_core::traits::{Fit, Predict};
38use ndarray::{Array1, Array2, Axis, ScalarOperand};
39use num_traits::{Float, FromPrimitive};
40
/// Linear regression with a combined L1 + L2 (elastic-net) penalty, fitted by
/// cyclic coordinate descent.
///
/// Build one with [`ElasticNet::new`] (or `Default`) and tune it with the
/// `with_*` builder methods before calling `fit`.
#[derive(Debug, Clone)]
pub struct ElasticNet<F> {
    /// Overall regularisation strength; must be non-negative. The L1 and L2
    /// penalty weights used during fitting are `alpha * l1_ratio` and
    /// `alpha * (1 - l1_ratio)`. Default: `1.0`.
    pub alpha: F,
    /// Mixing parameter in `[0, 1]`: `1` gives a pure L1 (lasso-style)
    /// penalty, `0` a pure L2 (ridge-style) penalty. Default: `0.5`.
    pub l1_ratio: F,
    /// Maximum number of full coordinate-descent sweeps over the features.
    /// Default: `1000`.
    pub max_iter: usize,
    /// Convergence threshold: fitting stops once the largest absolute
    /// coefficient change within a single sweep is below this value.
    /// Default: `1e-4`.
    pub tol: F,
    /// Whether to fit an intercept (done by centering `X` and `y` before the
    /// coordinate descent). Default: `true`.
    pub fit_intercept: bool,
}
67
68impl<F: Float + FromPrimitive> ElasticNet<F> {
69 #[must_use]
74 pub fn new() -> Self {
75 Self {
76 alpha: F::one(),
77 l1_ratio: F::from(0.5).unwrap(),
78 max_iter: 1000,
79 tol: F::from(1e-4).unwrap(),
80 fit_intercept: true,
81 }
82 }
83
84 #[must_use]
86 pub fn with_alpha(mut self, alpha: F) -> Self {
87 self.alpha = alpha;
88 self
89 }
90
91 #[must_use]
96 pub fn with_l1_ratio(mut self, l1_ratio: F) -> Self {
97 self.l1_ratio = l1_ratio;
98 self
99 }
100
101 #[must_use]
103 pub fn with_max_iter(mut self, max_iter: usize) -> Self {
104 self.max_iter = max_iter;
105 self
106 }
107
108 #[must_use]
110 pub fn with_tol(mut self, tol: F) -> Self {
111 self.tol = tol;
112 self
113 }
114
115 #[must_use]
117 pub fn with_fit_intercept(mut self, fit_intercept: bool) -> Self {
118 self.fit_intercept = fit_intercept;
119 self
120 }
121}
122
123impl<F: Float + FromPrimitive> Default for ElasticNet<F> {
124 fn default() -> Self {
125 Self::new()
126 }
127}
128
/// An elastic-net model returned by fitting an [`ElasticNet`].
///
/// Predictions are computed as `X · coefficients + intercept`.
#[derive(Debug, Clone)]
pub struct FittedElasticNet<F> {
    /// Learned weights, one per feature column of the training matrix.
    coefficients: Array1<F>,
    /// Learned bias term; zero when the model was fitted with
    /// `fit_intercept = false`.
    intercept: F,
}
140
141impl<F: Float> FittedElasticNet<F> {
142 pub fn intercept(&self) -> F {
144 self.intercept
145 }
146}
147
148#[inline]
152fn soft_threshold<F: Float>(x: F, threshold: F) -> F {
153 if x > threshold {
154 x - threshold
155 } else if x < -threshold {
156 x + threshold
157 } else {
158 F::zero()
159 }
160}
161
162impl<F: Float + Send + Sync + ScalarOperand + FromPrimitive + 'static> Fit<Array2<F>, Array1<F>>
163 for ElasticNet<F>
164{
165 type Fitted = FittedElasticNet<F>;
166 type Error = FerroError;
167
168 fn fit(&self, x: &Array2<F>, y: &Array1<F>) -> Result<FittedElasticNet<F>, FerroError> {
181 let (n_samples, n_features) = x.dim();
182
183 if n_samples != y.len() {
184 return Err(FerroError::ShapeMismatch {
185 expected: vec![n_samples],
186 actual: vec![y.len()],
187 context: "y length must match number of samples in X".into(),
188 });
189 }
190
191 if self.alpha < F::zero() {
192 return Err(FerroError::InvalidParameter {
193 name: "alpha".into(),
194 reason: "must be non-negative".into(),
195 });
196 }
197
198 if self.l1_ratio < F::zero() || self.l1_ratio > F::one() {
199 return Err(FerroError::InvalidParameter {
200 name: "l1_ratio".into(),
201 reason: "must be in [0, 1]".into(),
202 });
203 }
204
205 if n_samples == 0 {
206 return Err(FerroError::InsufficientSamples {
207 required: 1,
208 actual: 0,
209 context: "ElasticNet requires at least one sample".into(),
210 });
211 }
212
213 let n_f = F::from(n_samples).unwrap();
214
215 let (x_work, y_work, x_mean, y_mean) = if self.fit_intercept {
217 let x_mean = x
218 .mean_axis(Axis(0))
219 .ok_or_else(|| FerroError::NumericalInstability {
220 message: "failed to compute column means".into(),
221 })?;
222 let y_mean = y.mean().ok_or_else(|| FerroError::NumericalInstability {
223 message: "failed to compute target mean".into(),
224 })?;
225
226 let x_c = x - &x_mean;
227 let y_c = y - y_mean;
228 (x_c, y_c, Some(x_mean), Some(y_mean))
229 } else {
230 (x.clone(), y.clone(), None, None)
231 };
232
233 let col_norms: Vec<F> = (0..n_features)
235 .map(|j| {
236 let col = x_work.column(j);
237 col.dot(&col) / n_f
238 })
239 .collect();
240
241 let alpha_l1 = self.alpha * self.l1_ratio;
243 let alpha_l2 = self.alpha * (F::one() - self.l1_ratio);
244
245 let denominators: Vec<F> = col_norms.iter().map(|&cn| cn + alpha_l2).collect();
247
248 let mut w = Array1::<F>::zeros(n_features);
249 let mut residual = y_work.clone();
250
251 for _iter in 0..self.max_iter {
252 let mut max_change = F::zero();
253
254 for j in 0..n_features {
255 let col_j = x_work.column(j);
256 let w_old = w[j];
257
258 if w_old != F::zero() {
260 for i in 0..n_samples {
261 residual[i] = residual[i] + col_j[i] * w_old;
262 }
263 }
264
265 let rho_j = col_j.dot(&residual) / n_f;
267
268 let w_new = if denominators[j] > F::zero() {
270 soft_threshold(rho_j, alpha_l1) / denominators[j]
271 } else {
272 F::zero()
273 };
274
275 if w_new != F::zero() {
277 for i in 0..n_samples {
278 residual[i] = residual[i] - col_j[i] * w_new;
279 }
280 }
281
282 let change = (w_new - w_old).abs();
283 if change > max_change {
284 max_change = change;
285 }
286
287 w[j] = w_new;
288 }
289
290 if max_change < self.tol {
291 let intercept = compute_intercept(&x_mean, &y_mean, &w);
292 return Ok(FittedElasticNet {
293 coefficients: w,
294 intercept,
295 });
296 }
297 }
298
299 let intercept = compute_intercept(&x_mean, &y_mean, &w);
301 Ok(FittedElasticNet {
302 coefficients: w,
303 intercept,
304 })
305 }
306}
307
308fn compute_intercept<F: Float + 'static>(
310 x_mean: &Option<Array1<F>>,
311 y_mean: &Option<F>,
312 w: &Array1<F>,
313) -> F {
314 if let (Some(xm), Some(ym)) = (x_mean, y_mean) {
315 *ym - xm.dot(w)
316 } else {
317 F::zero()
318 }
319}
320
321impl<F: Float + Send + Sync + ScalarOperand + 'static> Predict<Array2<F>> for FittedElasticNet<F> {
322 type Output = Array1<F>;
323 type Error = FerroError;
324
325 fn predict(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
334 let n_features = x.ncols();
335 if n_features != self.coefficients.len() {
336 return Err(FerroError::ShapeMismatch {
337 expected: vec![self.coefficients.len()],
338 actual: vec![n_features],
339 context: "number of features must match fitted model".into(),
340 });
341 }
342
343 let preds = x.dot(&self.coefficients) + self.intercept;
344 Ok(preds)
345 }
346}
347
348impl<F: Float + Send + Sync + ScalarOperand + 'static> HasCoefficients<F> for FittedElasticNet<F> {
349 fn coefficients(&self) -> &Array1<F> {
351 &self.coefficients
352 }
353
354 fn intercept(&self) -> F {
356 self.intercept
357 }
358}
359
360impl PipelineEstimator for ElasticNet<f64> {
362 fn fit_pipeline(
368 &self,
369 x: &Array2<f64>,
370 y: &Array1<f64>,
371 ) -> Result<Box<dyn FittedPipelineEstimator>, FerroError> {
372 let fitted = self.fit(x, y)?;
373 Ok(Box::new(fitted))
374 }
375}
376
377impl FittedPipelineEstimator for FittedElasticNet<f64> {
378 fn predict_pipeline(&self, x: &Array2<f64>) -> Result<Array1<f64>, FerroError> {
384 self.predict(x)
385 }
386}
387
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use ndarray::array;

    // ---- soft_threshold -------------------------------------------------

    #[test]
    fn test_soft_threshold_positive() {
        assert_relative_eq!(soft_threshold(5.0_f64, 1.0), 4.0);
    }

    #[test]
    fn test_soft_threshold_negative() {
        assert_relative_eq!(soft_threshold(-5.0_f64, 1.0), -4.0);
    }

    #[test]
    fn test_soft_threshold_within_band() {
        assert_relative_eq!(soft_threshold(0.5_f64, 1.0), 0.0);
        assert_relative_eq!(soft_threshold(-0.5_f64, 1.0), 0.0);
        assert_relative_eq!(soft_threshold(0.0_f64, 1.0), 0.0);
    }

    // ---- builder --------------------------------------------------------

    #[test]
    fn test_default_builder() {
        let m = ElasticNet::<f64>::new();
        assert_relative_eq!(m.alpha, 1.0);
        assert_relative_eq!(m.l1_ratio, 0.5);
        assert_eq!(m.max_iter, 1000);
        assert_relative_eq!(m.tol, 1e-4);
        assert!(m.fit_intercept);
    }

    #[test]
    fn test_default_trait_matches_new() {
        let d = ElasticNet::<f64>::default();
        let n = ElasticNet::<f64>::new();
        assert_relative_eq!(d.alpha, n.alpha);
        assert_relative_eq!(d.l1_ratio, n.l1_ratio);
        assert_eq!(d.max_iter, n.max_iter);
        assert_relative_eq!(d.tol, n.tol);
        assert_eq!(d.fit_intercept, n.fit_intercept);
    }

    #[test]
    fn test_builder_setters() {
        let m = ElasticNet::<f64>::new()
            .with_alpha(0.5)
            .with_l1_ratio(0.2)
            .with_max_iter(500)
            .with_tol(1e-6)
            .with_fit_intercept(false);
        assert_relative_eq!(m.alpha, 0.5);
        assert_relative_eq!(m.l1_ratio, 0.2);
        assert_eq!(m.max_iter, 500);
        assert_relative_eq!(m.tol, 1e-6);
        assert!(!m.fit_intercept);
    }

    // ---- input validation -----------------------------------------------

    #[test]
    fn test_negative_alpha_error() {
        let x = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        let y = array![1.0, 2.0, 3.0];
        let result = ElasticNet::<f64>::new().with_alpha(-1.0).fit(&x, &y);
        assert!(result.is_err());
    }

    #[test]
    fn test_l1_ratio_out_of_range_error() {
        let x = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        let y = array![1.0, 2.0, 3.0];
        let result = ElasticNet::<f64>::new().with_l1_ratio(1.5).fit(&x, &y);
        assert!(result.is_err());
    }

    #[test]
    fn test_negative_l1_ratio_error() {
        let x = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        let y = array![1.0, 2.0, 3.0];
        let result = ElasticNet::<f64>::new().with_l1_ratio(-0.1).fit(&x, &y);
        assert!(result.is_err());
    }

    #[test]
    fn test_shape_mismatch_error() {
        let x = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        let y = array![1.0, 2.0];
        let result = ElasticNet::<f64>::new().fit(&x, &y);
        assert!(result.is_err());
    }

    #[test]
    fn test_empty_input_error() {
        let x = Array2::<f64>::zeros((0, 2));
        let y = Array1::<f64>::zeros(0);
        let result = ElasticNet::<f64>::new().fit(&x, &y);
        assert!(matches!(
            result,
            Err(FerroError::InsufficientSamples { .. })
        ));
    }

    // ---- fitting behaviour ----------------------------------------------

    #[test]
    fn test_lasso_limit_l1_ratio_one() {
        let x = Array2::from_shape_vec((5, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0]).unwrap();
        // y = 2x + 1
        let y = array![3.0, 5.0, 7.0, 9.0, 11.0];

        let model = ElasticNet::<f64>::new().with_alpha(0.0).with_l1_ratio(1.0);
        let fitted = model.fit(&x, &y).unwrap();

        assert_relative_eq!(fitted.coefficients()[0], 2.0, epsilon = 1e-4);
        assert_relative_eq!(fitted.intercept(), 1.0, epsilon = 1e-4);
    }

    #[test]
    fn test_ridge_limit_l1_ratio_zero() {
        let x = Array2::from_shape_vec((5, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0]).unwrap();
        // y = 2x + 1
        let y = array![3.0, 5.0, 7.0, 9.0, 11.0];

        let model = ElasticNet::<f64>::new().with_alpha(0.0).with_l1_ratio(0.0);
        let fitted = model.fit(&x, &y).unwrap();

        assert_relative_eq!(fitted.coefficients()[0], 2.0, epsilon = 1e-4);
        assert_relative_eq!(fitted.intercept(), 1.0, epsilon = 1e-4);
    }

    #[test]
    fn test_unregularised_multi_feature_recovers_ols() {
        // y = 1 * x1 + 3 * x2 + 0.5 with mildly correlated features.
        let x = Array2::from_shape_vec(
            (6, 2),
            vec![
                1.0, 1.0, 2.0, -1.0, 3.0, 1.0, 4.0, -1.0, 5.0, 1.0, 6.0, -1.0,
            ],
        )
        .unwrap();
        let y = array![4.5, -0.5, 6.5, 1.5, 8.5, 3.5];

        let fitted = ElasticNet::<f64>::new()
            .with_alpha(0.0)
            .with_tol(1e-12)
            .with_max_iter(10_000)
            .fit(&x, &y)
            .unwrap();

        assert_relative_eq!(fitted.coefficients()[0], 1.0, epsilon = 1e-6);
        assert_relative_eq!(fitted.coefficients()[1], 3.0, epsilon = 1e-6);
        assert_relative_eq!(fitted.intercept(), 0.5, epsilon = 1e-6);
    }

    #[test]
    fn test_sparsity_with_high_l1_ratio() {
        // Only the first feature carries signal; the others are all zero.
        let x = Array2::from_shape_vec(
            (10, 3),
            vec![
                1.0, 0.0, 0.0, 2.0, 0.0, 0.0, 3.0, 0.0, 0.0, 4.0, 0.0, 0.0, 5.0, 0.0, 0.0, 6.0,
                0.0, 0.0, 7.0, 0.0, 0.0, 8.0, 0.0, 0.0, 9.0, 0.0, 0.0, 10.0, 0.0, 0.0,
            ],
        )
        .unwrap();
        let y = array![2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0, 18.0, 20.0];

        let model = ElasticNet::<f64>::new().with_alpha(5.0).with_l1_ratio(1.0);
        let fitted = model.fit(&x, &y).unwrap();

        assert_relative_eq!(fitted.coefficients()[1], 0.0, epsilon = 1e-10);
        assert_relative_eq!(fitted.coefficients()[2], 0.0, epsilon = 1e-10);
    }

    #[test]
    fn test_higher_alpha_shrinks_more() {
        let x = Array2::from_shape_vec((5, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0]).unwrap();
        let y = array![3.0, 5.0, 7.0, 9.0, 11.0];

        let low = ElasticNet::<f64>::new()
            .with_alpha(0.01)
            .with_l1_ratio(0.5)
            .fit(&x, &y)
            .unwrap();
        let high = ElasticNet::<f64>::new()
            .with_alpha(2.0)
            .with_l1_ratio(0.5)
            .fit(&x, &y)
            .unwrap();

        assert!(high.coefficients()[0].abs() <= low.coefficients()[0].abs());
    }

    #[test]
    fn test_no_intercept() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![2.0, 4.0, 6.0, 8.0];

        let fitted = ElasticNet::<f64>::new()
            .with_alpha(0.0)
            .with_l1_ratio(0.5)
            .with_fit_intercept(false)
            .fit(&x, &y)
            .unwrap();

        assert_relative_eq!(fitted.intercept(), 0.0, epsilon = 1e-10);
    }

    // ---- prediction ---------------------------------------------------------

    #[test]
    fn test_predict_correct_length() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![2.0, 4.0, 6.0, 8.0];

        let fitted = ElasticNet::<f64>::new()
            .with_alpha(0.01)
            .fit(&x, &y)
            .unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 4);
    }

    #[test]
    fn test_predict_recovers_linear_target() {
        let x = Array2::from_shape_vec((5, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0]).unwrap();
        let y = array![3.0, 5.0, 7.0, 9.0, 11.0];

        let fitted = ElasticNet::<f64>::new()
            .with_alpha(0.0)
            .fit(&x, &y)
            .unwrap();
        let preds = fitted.predict(&x).unwrap();

        for (pred, target) in preds.iter().zip(y.iter()) {
            assert_relative_eq!(*pred, *target, epsilon = 1e-8);
        }
    }

    #[test]
    fn test_predict_feature_mismatch() {
        let x_train = Array2::from_shape_vec((3, 2), vec![1.0, 0.0, 2.0, 0.0, 3.0, 0.0]).unwrap();
        let y = array![1.0, 2.0, 3.0];
        let fitted = ElasticNet::<f64>::new()
            .with_alpha(0.01)
            .fit(&x_train, &y)
            .unwrap();

        let x_bad = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        let result = fitted.predict(&x_bad);
        assert!(result.is_err());
    }

    #[test]
    fn test_has_coefficients_length() {
        let x = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y = array![1.0, 2.0, 3.0];
        let fitted = ElasticNet::<f64>::new()
            .with_alpha(0.1)
            .fit(&x, &y)
            .unwrap();

        assert_eq!(fitted.coefficients().len(), 2);
    }

    // ---- pipeline integration -----------------------------------------------

    #[test]
    fn test_pipeline_integration() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![3.0, 5.0, 7.0, 9.0];

        let model = ElasticNet::<f64>::new().with_alpha(0.01);
        let fitted_pipe = model.fit_pipeline(&x, &y).unwrap();
        let preds = fitted_pipe.predict_pipeline(&x).unwrap();
        assert_eq!(preds.len(), 4);
    }
}