1use approx::{abs_diff_eq, abs_diff_ne};
2use linfa_linalg::norm::Norm;
3#[cfg(not(feature = "blas"))]
4use linfa_linalg::qr::QRInto;
5use ndarray::linalg::general_mat_mul;
6use ndarray::{
7 s, Array, Array1, Array2, ArrayBase, ArrayView, ArrayView1, ArrayView2, Axis, CowArray, Data,
8 Dimension, Ix2, RemoveAxis,
9};
10#[cfg(feature = "blas")]
11use ndarray_linalg::InverseHInto;
12
13use linfa::dataset::{WithLapack, WithoutLapack};
14use linfa::traits::{Fit, PredictInplace};
15use linfa::{
16 dataset::{AsMultiTargets, AsSingleTargets, AsTargets, Records},
17 DatasetBase, Float,
18};
19
20use super::{
21 hyperparams::{ElasticNetValidParams, MultiTaskElasticNetValidParams},
22 ElasticNet, ElasticNetError, MultiTaskElasticNet, Result,
23};
24
25impl<F, D, T> Fit<ArrayBase<D, Ix2>, T, ElasticNetError> for ElasticNetValidParams<F>
26where
27 F: Float,
28 D: Data<Elem = F>,
29 T: AsSingleTargets<Elem = F>,
30{
31 type Object = ElasticNet<F>;
32
33 fn fit(&self, dataset: &DatasetBase<ArrayBase<D, Ix2>, T>) -> Result<Self::Object> {
44 let target = dataset.as_single_targets();
45
46 let (intercept, y) = compute_intercept(self.with_intercept(), target);
47 let (hyperplane, duality_gap, n_steps) = coordinate_descent(
48 dataset.records().view(),
49 y.view(),
50 self.tolerance(),
51 self.max_iterations(),
52 self.l1_ratio(),
53 self.penalty(),
54 );
55 let intercept = intercept.into_scalar();
56
57 let y_est = dataset.records().dot(&hyperplane) + intercept;
58
59 let variance = variance_params(dataset, y_est.view());
61
62 Ok(ElasticNet {
63 hyperplane,
64 intercept,
65 duality_gap,
66 n_steps,
67 variance,
68 })
69 }
70}
71
72impl<F, D, T> Fit<ArrayBase<D, Ix2>, T, ElasticNetError> for MultiTaskElasticNetValidParams<F>
73where
74 F: Float,
75 T: AsMultiTargets<Elem = F>,
76 D: Data<Elem = F>,
77{
78 type Object = MultiTaskElasticNet<F>;
79
80 fn fit(&self, dataset: &DatasetBase<ArrayBase<D, Ix2>, T>) -> Result<Self::Object> {
91 let targets = dataset.targets().as_multi_targets();
92 let (intercept, y) = compute_intercept(self.with_intercept(), targets);
93
94 let (hyperplane, duality_gap, n_steps) = block_coordinate_descent(
95 dataset.records().view(),
96 y.view(),
97 self.tolerance(),
98 self.max_iterations(),
99 self.l1_ratio(),
100 self.penalty(),
101 );
102
103 let y_est = dataset.records().dot(&hyperplane) + &intercept;
104
105 let variance = variance_params(dataset, y_est.view());
107
108 Ok(MultiTaskElasticNet {
109 hyperplane,
110 intercept,
111 duality_gap,
112 n_steps,
113 variance,
114 })
115 }
116}
117
118impl<F: Float, D: Data<Elem = F>> PredictInplace<ArrayBase<D, Ix2>, Array1<F>> for ElasticNet<F> {
119 fn predict_inplace(&self, x: &ArrayBase<D, Ix2>, y: &mut Array1<F>) {
123 assert_eq!(
124 x.nrows(),
125 y.len(),
126 "The number of data points must match the number of output targets."
127 );
128
129 *y = x.dot(&self.hyperplane) + self.intercept;
130 }
131
132 fn default_target(&self, x: &ArrayBase<D, Ix2>) -> Array1<F> {
133 Array1::zeros(x.nrows())
134 }
135}
136
137impl<F: Float, D: Data<Elem = F>> PredictInplace<ArrayBase<D, Ix2>, Array2<F>>
138 for MultiTaskElasticNet<F>
139{
140 fn predict_inplace(&self, x: &ArrayBase<D, Ix2>, y: &mut Array2<F>) {
144 assert_eq!(
145 x.nrows(),
146 y.nrows(),
147 "The number of data points must match the number of output targets."
148 );
149
150 *y = x.dot(&self.hyperplane) + &self.intercept;
151 }
152
153 fn default_target(&self, x: &ArrayBase<D, Ix2>) -> Array2<F> {
154 Array2::zeros((x.nrows(), x.nrows()))
156 }
157}
158
159impl<F: Float> ElasticNet<F> {
162 pub fn hyperplane(&self) -> &Array1<F> {
164 &self.hyperplane
165 }
166
167 pub fn intercept(&self) -> F {
169 self.intercept
170 }
171
172 pub fn n_steps(&self) -> u32 {
174 self.n_steps
175 }
176
177 pub fn duality_gap(&self) -> F {
179 self.duality_gap
180 }
181
182 pub fn z_score(&self) -> Result<Array1<F>> {
184 self.variance
185 .as_ref()
186 .map(|variance| {
187 self.hyperplane
188 .iter()
189 .zip(variance.iter())
190 .map(|(a, b)| *a / b.sqrt())
191 .collect()
192 })
193 .map_err(|err| err.clone())
194 }
195
196 pub fn confidence_95th(&self) -> Result<Array1<(F, F)>> {
198 let p = F::cast(1.645);
200
201 self.variance
202 .as_ref()
203 .map(|variance| {
204 self.hyperplane
205 .iter()
206 .zip(variance.iter())
207 .map(|(a, b)| (*a - p * b.sqrt(), *a + p * b.sqrt()))
208 .collect()
209 })
210 .map_err(|err| err.clone())
211 }
212}
213
214impl<F: Float> MultiTaskElasticNet<F> {
217 pub fn hyperplane(&self) -> &Array2<F> {
219 &self.hyperplane
220 }
221
222 pub fn intercept(&self) -> &Array1<F> {
225 &self.intercept
226 }
227
228 pub fn n_steps(&self) -> u32 {
230 self.n_steps
231 }
232
233 pub fn duality_gap(&self) -> F {
235 self.duality_gap
236 }
237
238 pub fn z_score(&self) -> Result<Array2<F>> {
240 self.variance
241 .as_ref()
242 .map(|variance| {
243 ndarray::Zip::from(&self.hyperplane)
244 .and_broadcast(variance)
245 .map_collect(|a, b| *a / b.sqrt())
246 })
247 .map_err(|err| err.clone())
248 }
249
250 pub fn confidence_95th(&self) -> Result<Array2<(F, F)>> {
252 let p = F::cast(1.645);
254
255 self.variance
256 .as_ref()
257 .map(|variance| {
258 ndarray::Zip::from(&self.hyperplane)
259 .and_broadcast(variance)
260 .map_collect(|a, b| (*a - p * b.sqrt(), *a + p * b.sqrt()))
261 })
262 .map_err(|err| err.clone())
263 }
264}
265
266fn coordinate_descent<'a, F: Float>(
267 x: ArrayView2<'a, F>,
268 y: ArrayView1<'a, F>,
269 tol: F,
270 max_steps: u32,
271 l1_ratio: F,
272 penalty: F,
273) -> (Array1<F>, F, u32) {
274 let n_samples = F::cast(x.nrows());
275 let n_features = x.ncols();
276 let mut w = Array1::<F>::zeros(n_features);
278 let mut r = y.to_owned();
281 let mut n_steps = 0u32;
282 let norm_cols_x = x.map_axis(Axis(0), |col| col.dot(&col));
283 let mut gap = F::one() + tol;
284 let d_w_tol = tol;
285 let tol = tol * y.dot(&y);
286 while n_steps < max_steps {
287 let mut w_max = F::zero();
288 let mut d_w_max = F::zero();
289 for j in 0..n_features {
290 if abs_diff_eq!(norm_cols_x[j], F::zero()) {
291 continue;
292 }
293 let old_w_j = w[j];
294 let x_j: ArrayView1<F> = x.slice(s![.., j]);
295 if abs_diff_ne!(old_w_j, F::zero()) {
296 r.scaled_add(old_w_j, &x_j);
297 }
298 let tmp: F = x_j.dot(&r);
299 w[j] = tmp.signum() * F::max(tmp.abs() - n_samples * l1_ratio * penalty, F::zero())
300 / (norm_cols_x[j] + n_samples * (F::one() - l1_ratio) * penalty);
301 if abs_diff_ne!(w[j], F::zero()) {
302 r.scaled_add(-w[j], &x_j);
303 }
304 let d_w_j = (w[j] - old_w_j).abs();
305 d_w_max = F::max(d_w_max, d_w_j);
306 w_max = F::max(w_max, w[j].abs());
307 }
308 n_steps += 1;
309
310 if n_steps == max_steps - 1 || abs_diff_eq!(w_max, F::zero()) || d_w_max / w_max < d_w_tol {
311 gap = duality_gap(x.view(), y.view(), w.view(), r.view(), l1_ratio, penalty);
314 if gap < tol {
315 break;
316 }
317 }
318 }
319 (w, gap, n_steps)
320}
321
322fn block_coordinate_descent<'a, F: Float>(
323 x: ArrayView2<'a, F>,
324 y: ArrayView2<'a, F>,
325 tol: F,
326 max_steps: u32,
327 l1_ratio: F,
328 penalty: F,
329) -> (Array2<F>, F, u32) {
330 let n_samples = F::cast(x.nrows());
331 let n_features = x.ncols();
332 let n_tasks = y.ncols();
333 let mut w = Array2::<F>::zeros((n_features, n_tasks));
335 let mut r = y.to_owned();
338 let mut n_steps = 0u32;
339 let norm_cols_x = x.map_axis(Axis(0), |col| col.dot(&col));
340 let mut gap = F::one() + tol;
341 let d_w_tol = tol;
342 let tol = tol * y.iter().map(|&y_ij| y_ij * y_ij).sum();
343 while n_steps < max_steps {
344 let mut w_max = F::zero();
345 let mut d_w_max = F::zero();
346 for j in 0..n_features {
347 if abs_diff_eq!(norm_cols_x[j], F::zero()) {
348 continue;
349 }
350 let mut old_w_j = w.slice_mut(s![j, ..]);
351 let x_j = x.slice(s![.., j]);
352 let norm_old_w_j = old_w_j.dot(&old_w_j).sqrt();
353 if abs_diff_ne!(norm_old_w_j, F::zero()) {
354 general_mat_mul(
356 F::one(),
357 &x_j.view().insert_axis(Axis(1)),
358 &old_w_j.view().insert_axis(Axis(0)),
359 F::one(),
360 &mut r,
361 );
362 }
363 let tmp = x_j.dot(&r);
364 old_w_j.assign(
365 &(block_soft_thresholding(tmp.view(), n_samples * l1_ratio * penalty)
366 / (norm_cols_x[j] + n_samples * (F::one() - l1_ratio) * penalty)),
367 );
368 let norm_w_j = old_w_j.dot(&old_w_j).sqrt();
369 if abs_diff_ne!(norm_w_j, F::zero()) {
370 general_mat_mul(
372 -F::one(),
373 &x_j.insert_axis(Axis(1)),
374 &old_w_j.insert_axis(Axis(0)),
375 F::one(),
376 &mut r,
377 );
378 }
379 let d_w_j = (norm_w_j - norm_old_w_j).abs();
380 d_w_max = F::max(d_w_max, d_w_j);
381 w_max = F::max(w_max, norm_w_j);
382 }
383 n_steps += 1;
384
385 if n_steps == max_steps - 1 || abs_diff_eq!(w_max, F::zero()) || d_w_max / w_max < d_w_tol {
386 gap = duality_gap_mtl(x.view(), y.view(), w.view(), r.view(), l1_ratio, penalty);
389 if gap < tol {
390 break;
391 }
392 }
393 }
394
395 (w, gap, n_steps)
396}
397
398fn block_soft_thresholding<F: Float>(x: ArrayView1<F>, threshold: F) -> Array1<F> {
400 let norm_x = x.dot(&x).sqrt();
401 if norm_x < threshold {
402 return Array1::<F>::zeros(x.len());
403 }
404 let scale = F::one() - threshold / norm_x;
405 &x * scale
406}
407
/// Duality gap of the single-target elastic-net problem at the iterate `w`.
///
/// * `x`, `y`: records and the target the solver was run on.
/// * `w`: current coefficients.
/// * `r`: residual; the caller in this file (`coordinate_descent`) passes its running
///   residual `y - X w`.
/// * `l1_ratio`, `penalty`: elastic-net mixing and strength. They are scaled by the number of
///   samples to obtain `l1_reg` and `l2_reg`, the same constants the coordinate update uses.
///
/// A dual point is formed from the residual. It is rescaled by `const_` whenever
/// `Xᵀ r - l2_reg · w` exceeds `l1_reg` in max-absolute-value norm. The returned value is the
/// primal objective minus the dual objective at that point; the solver stops once it falls
/// below its tolerance.
fn duality_gap<'a, F: Float>(
    x: ArrayView2<'a, F>,
    y: ArrayView1<'a, F>,
    w: ArrayView1<'a, F>,
    r: ArrayView1<'a, F>,
    l1_ratio: F,
    penalty: F,
) -> F {
    let half = F::cast(0.5);
    let n_samples = F::cast(x.nrows());
    let l1_reg = l1_ratio * penalty * n_samples;
    let l2_reg = (F::one() - l1_ratio) * penalty * n_samples;
    // Gradient-like term `Xᵀ r - l2_reg · w`.
    let xta = x.t().dot(&r) - &w * l2_reg;

    // Dual norm of the L1 norm: the largest absolute entry.
    let dual_norm_xta = xta.norm_max();
    let r_norm2 = r.dot(&r);
    let w_norm2 = w.dot(&w);
    // Rescale the dual point only when it would violate the L1 constraint.
    let (const_, mut gap) = if dual_norm_xta > l1_reg {
        let const_ = l1_reg / dual_norm_xta;
        let a_norm2 = r_norm2 * const_ * const_;
        (const_, half * (r_norm2 + a_norm2))
    } else {
        (F::one(), r_norm2)
    };
    let l1_norm = w.norm_l1();
    gap += l1_reg * l1_norm - const_ * r.dot(&y)
        + half * l2_reg * (F::one() + const_ * const_) * w_norm2;
    gap
}
437
438fn duality_gap_mtl<'a, F: Float>(
439 x: ArrayView2<'a, F>,
440 y: ArrayView2<'a, F>,
441 w: ArrayView2<'a, F>,
442 r: ArrayView2<'a, F>,
443 l1_ratio: F,
444 penalty: F,
445) -> F {
446 let half = F::cast(0.5);
447 let n_samples = F::cast(x.nrows());
448 let l1_reg = l1_ratio * penalty * n_samples;
449 let l2_reg = (F::one() - l1_ratio) * penalty * n_samples;
450 let xta = x.t().dot(&r) - &w * l2_reg;
451
452 let dual_norm_xta = xta.map_axis(Axis(1), |x| x.dot(&x).sqrt()).norm_max();
453 let r_norm2 = r.iter().map(|&rij| rij * rij).sum();
454 let w_norm2 = w.iter().map(|&wij| wij * wij).sum();
455 let (const_, mut gap) = if dual_norm_xta > l1_reg {
456 let const_ = l1_reg / dual_norm_xta;
457 let a_norm2 = r_norm2 * const_ * const_;
458 (const_, half * (r_norm2 + a_norm2))
459 } else {
460 (F::one(), r_norm2)
461 };
462 let rty = r.t().dot(&y);
463 let trace_rty = rty.diag().sum();
464 let l21_norm = w.map_axis(Axis(1), |wj| (wj.dot(&wj)).sqrt()).sum();
465 gap += l1_reg * l21_norm - const_ * trace_rty
466 + half * l2_reg * (F::one() + const_ * const_) * w_norm2;
467 gap
468}
469
/// Estimate the variance of every fitted coefficient.
///
/// The residual variance is pooled over all tasks as
/// `sum((target - y_est)²) / (ntasks · (nsamples - nfeatures))`. It is then multiplied with
/// the diagonal of `(Xᵀ X)⁻¹`, where `X` is the record matrix exactly as stored in `ds` (it is
/// not centred here). The result has one entry per feature.
///
/// NOTE(review): `Xᵀ X` uses uncentred records and the degrees of freedom do not account
/// for a fitted intercept. Confirm this matches the intended estimator.
///
/// # Errors
/// * `IncorrectTargetShape` if the targets are neither 1- nor 2-dimensional.
/// * `NotEnoughSamples` if `nsamples < nfeatures + 1`.
/// * `IllConditioned` if `Xᵀ X` cannot be inverted.
fn variance_params<F: Float, T: AsTargets<Elem = F>, D: Data<Elem = F>>(
    ds: &DatasetBase<ArrayBase<D, Ix2>, T>,
    y_est: ArrayView<F, T::Ix>,
) -> Result<Array1<F>> {
    let nfeatures = ds.nfeatures();
    let nsamples = ds.nsamples();

    let target = ds.targets().as_targets();
    let ndim = target.ndim();

    // A 1-D target is a single task; for 2-D targets the last axis enumerates the tasks.
    let ntasks: usize = match ndim {
        1 => 1,
        2 => *target.shape().last().unwrap(),
        _ => {
            return Err(ElasticNetError::IncorrectTargetShape);
        }
    };

    let y_est = y_est.as_targets();

    // At least one residual degree of freedom is needed for the divisor below.
    if nsamples < nfeatures + 1 {
        return Err(ElasticNetError::NotEnoughSamples);
    }

    // Pooled residual variance over all samples and tasks.
    let var_target =
        (&target - &y_est).mapv(|x| x * x).sum() / F::cast(ntasks * (nsamples - nfeatures));

    let ds2 = ds.records().t().dot(ds.records()).with_lapack();
    // Invert `Xᵀ X`: with the `blas` feature via a Hermitian inverse, otherwise via QR.
    #[cfg(feature = "blas")]
    let inv_cov = ds2.invh_into();
    #[cfg(not(feature = "blas"))]
    let inv_cov = (|| ds2.qr_into()?.inverse())();

    match inv_cov {
        Ok(inv_cov) => Ok(inv_cov.without_lapack().diag().mapv(|x| var_target * x)),
        Err(_) => Err(ElasticNetError::IllConditioned),
    }
}
510
511fn compute_intercept<F: Float, I: RemoveAxis>(
515 with_intercept: bool,
516 y: ArrayView<F, I>,
517) -> (Array<F, I::Smaller>, CowArray<F, I>)
518where
519 I::Smaller: Dimension<Larger = I>,
520{
521 if with_intercept {
522 let y_mean = y
523 .mean_axis(Axis(0))
525 .expect("Axis 0 length of 0");
526 let y_centered = &y - &y_mean.view().insert_axis(Axis(0));
528 (y_mean, y_centered.into())
529 } else {
530 (Array::zeros(y.raw_dim().remove_axis(Axis(0))), y.into())
531 }
532}
533
534#[cfg(test)]
535mod tests {
536 use super::{block_coordinate_descent, coordinate_descent, ElasticNet, MultiTaskElasticNet};
537 use approx::assert_abs_diff_eq;
538 use ndarray::{array, s, Array, Array1, Array2, Axis};
539 use ndarray_rand::rand::SeedableRng;
540 use ndarray_rand::rand_distr::Uniform;
541 use ndarray_rand::RandomExt;
542 use rand_xoshiro::Xoshiro256Plus;
543
544 use crate::{ElasticNetError, ElasticNetParams, ElasticNetValidParams};
545 use linfa::{
546 metrics::SingleTargetRegression,
547 traits::{Fit, Predict},
548 Dataset,
549 };
550
551 #[test]
552 fn autotraits() {
553 fn has_autotraits<T: Send + Sync + Sized + Unpin>() {}
554 has_autotraits::<ElasticNet<f64>>();
555 has_autotraits::<ElasticNetParams<f64>>();
556 has_autotraits::<ElasticNetValidParams<f64>>();
557 has_autotraits::<ElasticNetError>();
558 }
559
560 fn elastic_net_objective(
561 x: &Array2<f64>,
562 y: &Array1<f64>,
563 intercept: f64,
564 beta: &Array1<f64>,
565 alpha: f64,
566 lambda: f64,
567 ) -> f64 {
568 squared_error(x, y, intercept, beta) + lambda * elastic_net_penalty(beta, alpha)
569 }
570
571 fn elastic_net_multi_task_objective(
572 x: &Array2<f64>,
573 y: &Array2<f64>,
574 intercept: &Array1<f64>,
575 beta: &Array2<f64>,
576 alpha: f64,
577 lambda: f64,
578 ) -> f64 {
579 squared_error_mtl(x, y, intercept, beta) + lambda * elastic_net_mtl_penalty(beta, alpha)
580 }
581
582 fn squared_error(x: &Array2<f64>, y: &Array1<f64>, intercept: f64, beta: &Array1<f64>) -> f64 {
583 let mut resid = -x.dot(beta);
584 resid -= intercept;
585 resid += y;
586 let mut result = 0.0;
587 for r in &resid {
588 result += r * r;
589 }
590 result /= 2.0 * y.len() as f64;
591 result
592 }
593
594 fn squared_error_mtl(
595 x: &Array2<f64>,
596 y: &Array2<f64>,
597 intercept: &Array1<f64>,
598 beta: &Array2<f64>,
599 ) -> f64 {
600 let mut resid = x.dot(beta);
601 resid = &resid * -1.;
602 resid = &resid - intercept + y;
603 let mut datafit = resid.iter().map(|rij| rij * rij).sum();
604 datafit /= 2.0 * x.nrows() as f64;
605 datafit
606 }
607
608 fn elastic_net_penalty(beta: &Array1<f64>, alpha: f64) -> f64 {
609 let mut penalty = 0.0;
610 for beta_j in beta {
611 penalty += (1.0 - alpha) / 2.0 * beta_j * beta_j + alpha * beta_j.abs();
612 }
613 penalty
614 }
615
616 fn elastic_net_mtl_penalty(beta: &Array2<f64>, alpha: f64) -> f64 {
617 let frob_norm: f64 = beta.iter().map(|beta_ij| beta_ij * beta_ij).sum();
618 let l21_norm = beta
619 .map_axis(Axis(1), |beta_j| (beta_j.dot(&beta_j)).sqrt())
620 .sum();
621 (1.0 - alpha) / 2.0 * frob_norm + alpha * l21_norm
622 }
623
624 #[test]
625 fn elastic_net_penalty_works() {
626 let beta = array![-2.0, 1.0];
627 assert_abs_diff_eq!(
628 elastic_net_penalty(&beta, 0.8),
629 0.4 + 0.1 + 1.6 + 0.8,
630 epsilon = 1e-12
631 );
632 assert_abs_diff_eq!(elastic_net_penalty(&beta, 1.0), 3.0);
633 assert_abs_diff_eq!(elastic_net_penalty(&beta, 0.0), 2.5);
634
635 let beta2 = array![0.0, 0.0];
636 assert_abs_diff_eq!(elastic_net_penalty(&beta2, 0.8), 0.0);
637 assert_abs_diff_eq!(elastic_net_penalty(&beta2, 1.0), 0.0);
638 assert_abs_diff_eq!(elastic_net_penalty(&beta2, 0.0), 0.0);
639 }
640
641 #[test]
642 fn elastic_net_mtl_penalty_works() {
643 let beta = array![[-2.0, 1.0, 3.0], [3.0, 1.5, -1.7]];
644 assert_abs_diff_eq!(
645 elastic_net_mtl_penalty(&beta, 0.7),
646 9.472383565516601,
647 epsilon = 1e-12
648 );
649 assert_abs_diff_eq!(
650 elastic_net_mtl_penalty(&beta, 1.0),
651 7.501976522166574,
652 epsilon = 1e-12
653 );
654 assert_abs_diff_eq!(
655 elastic_net_mtl_penalty(&beta, 0.2),
656 12.756395304433315,
657 epsilon = 1e-12
658 );
659
660 let beta2 = array![[0., 0.], [0., 0.], [0., 0.]];
661 assert_abs_diff_eq!(elastic_net_mtl_penalty(&beta2, 0.8), 0.0);
662 assert_abs_diff_eq!(elastic_net_mtl_penalty(&beta2, 1.2), 0.0);
663 assert_abs_diff_eq!(elastic_net_mtl_penalty(&beta2, 0.8), 0.0);
664 }
665
666 #[test]
667 fn squared_error_works() {
668 let x = array![[2.0, 1.0], [-1.0, 2.0]];
669 let y = array![1.0, 1.0];
670 let beta = array![0.0, 1.0];
671 assert_abs_diff_eq!(squared_error(&x, &y, 0.0, &beta), 0.25);
672 }
673
674 #[test]
675 fn squared_error_mtl_works() {
676 let x = array![[1.2, 2.3], [-1.3, 0.3], [-1.3, 0.1]];
677 let y = array![
678 [0.2, 1.0, 0.0, 1.],
679 [-0.3, 0.7, 0.1, 2.],
680 [-0.3, 0.7, 2.3, 3.]
681 ];
682 let beta = array![[2.3, 4.5, 1.2, -3.4], [1.2, -3.4, 0.7, -1.2]];
683 assert_abs_diff_eq!(
684 squared_error_mtl(&x, &y, &array![0., 0., 0., 0.], &beta),
685 41.66298333333333
686 );
687 let intercept = array![1., 3., 2., 0.3];
688 assert_abs_diff_eq!(
689 squared_error_mtl(&x, &y, &intercept, &beta),
690 29.059983333333335
691 );
692 }
693
694 #[test]
695 fn coordinate_descent_lowers_objective() {
696 let x = array![[1.0, 0.0], [0.0, 1.0]];
697 let y = array![1.0, -1.0];
698 let beta = array![0.0, 0.0];
699 let intercept = 0.0;
700 let alpha = 0.8;
701 let lambda = 0.001;
702 let objective_start = elastic_net_objective(&x, &y, intercept, &beta, alpha, lambda);
703 let opt_result = coordinate_descent(x.view(), y.view(), 1e-4, 3, alpha, lambda);
704 let objective_end = elastic_net_objective(&x, &y, intercept, &opt_result.0, alpha, lambda);
705 assert!(objective_start > objective_end);
706 }
707
708 #[test]
709 fn block_coordinate_descent_lowers_objective() {
710 let x = array![[1.0, 0., -0.3, 3.2], [0.3, 1.2, -0.6, 1.2]];
711 let y = array![[0.3, -1.2, 0.7], [1.4, -3.2, 0.2]];
712 let beta = array![[0., 0., 0.], [0., 0., 0.], [0., 0., 0.], [0., 0., 0.]];
713 let intercept = array![0., 0., 0.];
714 let alpha = 0.4;
715 let lambda = 0.002;
716 let objective_start =
717 elastic_net_multi_task_objective(&x, &y, &intercept, &beta, alpha, lambda);
718 let opt_result = block_coordinate_descent(x.view(), y.view(), 1e-4, 3, alpha, lambda);
719 let objective_end =
720 elastic_net_multi_task_objective(&x, &y, &intercept, &opt_result.0, alpha, lambda);
721 assert!(objective_start > objective_end);
722 }
723
724 #[test]
725 fn lasso_zero_works() {
726 let dataset = Dataset::from((array![[0.], [0.], [0.]], array![0., 0., 0.]));
727
728 let model = ElasticNet::params()
729 .l1_ratio(1.0)
730 .penalty(0.1)
731 .fit(&dataset)
732 .unwrap();
733
734 assert_abs_diff_eq!(model.intercept(), 0.);
735 assert_abs_diff_eq!(model.hyperplane(), &array![0.]);
736 }
737
738 #[test]
739 fn mtl_lasso_zero_works() {
740 let dataset = Dataset::from((array![[0.], [0.], [0.]], array![[0.], [0.], [0.]]));
741
742 let model = MultiTaskElasticNet::params()
743 .l1_ratio(1.0)
744 .penalty(0.1)
745 .fit(&dataset)
746 .unwrap();
747
748 assert_abs_diff_eq!(model.intercept(), &array![0.]);
749 assert_abs_diff_eq!(model.hyperplane(), &array![[0.]]);
750 }
751
752 #[test]
753 fn lasso_toy_example_works() {
754 let dataset = Dataset::new(array![[-1.0], [0.0], [1.0]], array![-1.0, 0.0, 1.0]);
758
759 let t = array![[2.0], [3.0], [4.0]];
761 let model = ElasticNet::lasso().penalty(1e-8).fit(&dataset).unwrap();
762 assert_abs_diff_eq!(model.intercept(), 0.0);
763 assert_abs_diff_eq!(model.hyperplane(), &array![1.0], epsilon = 1e-6);
764 assert_abs_diff_eq!(model.predict(&t), array![2.0, 3.0, 4.0], epsilon = 1e-6);
765 assert_abs_diff_eq!(model.duality_gap(), 0.0);
766
767 let model = ElasticNet::lasso().penalty(0.1).fit(&dataset).unwrap();
768 assert_abs_diff_eq!(model.intercept(), 0.0);
769 assert_abs_diff_eq!(model.hyperplane(), &array![0.85], epsilon = 1e-6);
770 assert_abs_diff_eq!(model.predict(&t), array![1.7, 2.55, 3.4], epsilon = 1e-6);
771 assert_abs_diff_eq!(model.duality_gap(), 0.0);
772
773 let model = ElasticNet::lasso().penalty(0.5).fit(&dataset).unwrap();
774 assert_abs_diff_eq!(model.intercept(), 0.0);
775 assert_abs_diff_eq!(model.hyperplane(), &array![0.25], epsilon = 1e-6);
776 assert_abs_diff_eq!(model.predict(&t), array![0.5, 0.75, 1.0], epsilon = 1e-6);
777 assert_abs_diff_eq!(model.duality_gap(), 0.0);
778
779 let model = ElasticNet::lasso().penalty(1.0).fit(&dataset).unwrap();
780 assert_abs_diff_eq!(model.intercept(), 0.0);
781 assert_abs_diff_eq!(model.hyperplane(), &array![0.0], epsilon = 1e-6);
782 assert_abs_diff_eq!(model.predict(&t), array![0.0, 0.0, 0.0], epsilon = 1e-6);
783 assert_abs_diff_eq!(model.duality_gap(), 0.0);
784 }
785
786 #[test]
787 fn multitask_lasso_toy_example_works() {
788 let dataset = Dataset::new(
792 array![[-1.0], [0.0], [1.0]],
793 array![[-1.0, 1.0], [0.0, -1.5], [1.0, 1.3]],
794 );
795
796 let t = array![[2.0], [3.0], [4.0]];
798 let model = MultiTaskElasticNet::lasso()
799 .with_intercept(false)
800 .penalty(0.01)
801 .fit(&dataset)
802 .unwrap();
803 assert_abs_diff_eq!(model.intercept(), &array![0., 0.]);
804 assert_abs_diff_eq!(
805 model.hyperplane(),
806 &array![[0.9851659, 0.1477748]],
807 epsilon = 1e-6
808 );
809 assert_abs_diff_eq!(
810 model.predict(&t),
811 array![
812 [1.9703319, 0.2955497],
813 [2.9554978, 0.4433246],
814 [3.9406638, 0.5910995]
815 ],
816 epsilon = 1e-6
817 );
818 assert_abs_diff_eq!(model.duality_gap(), 0.0, epsilon = 1e-9);
819
820 let t = array![[2.0], [3.0], [4.0]];
822 let model = MultiTaskElasticNet::lasso()
823 .penalty(1e-8)
824 .fit(&dataset)
825 .unwrap();
826 assert_abs_diff_eq!(model.intercept(), &array![0., 0.2666666667], epsilon = 1e-6);
827 assert_abs_diff_eq!(model.hyperplane(), &array![[1., 0.15]], epsilon = 1e-6);
828 assert_abs_diff_eq!(
829 model.predict(&t),
830 array![
831 [1.99999997, 0.56666666],
832 [2.99999996, 0.71666666],
833 [3.99999994, 0.86666666]
834 ],
835 epsilon = 1e-6
836 );
837 assert_abs_diff_eq!(model.duality_gap(), 0.0, epsilon = 1e-9);
838
839 let model = MultiTaskElasticNet::lasso()
840 .penalty(0.1)
841 .fit(&dataset)
842 .unwrap();
843 assert_abs_diff_eq!(model.intercept(), &array![0., 0.2666666667], epsilon = 1e-6);
844 assert_abs_diff_eq!(
845 model.hyperplane(),
846 &array![[0.851659, 0.127749]],
847 epsilon = 1e-6
848 );
849 assert_abs_diff_eq!(
850 model.predict(&t),
851 &array![
852 [1.70331909, 0.52216453],
853 [2.55497864, 0.64991346],
854 [3.40663819, 0.77766239]
855 ],
856 epsilon = 1e-6
857 );
858 assert_abs_diff_eq!(model.duality_gap(), 0.0, epsilon = 1e-9);
859
860 let model = MultiTaskElasticNet::lasso()
861 .penalty(0.5)
862 .fit(&dataset)
863 .unwrap();
864 assert_abs_diff_eq!(model.intercept(), &array![0., 0.2666666667], epsilon = 1e-6);
865 assert_abs_diff_eq!(
866 model.hyperplane(),
867 &array![[0.258298, 0.038744]],
868 epsilon = 1e-6
869 );
870 assert_abs_diff_eq!(
871 model.predict(&t),
872 &array![
873 [0.51659547, 0.34415599],
874 [0.77489321, 0.38290065],
875 [1.03319094, 0.42164531]
876 ],
877 epsilon = 1e-6
878 );
879 assert_abs_diff_eq!(model.duality_gap(), 0.0, epsilon = 1e-6);
880
881 let model = MultiTaskElasticNet::lasso()
882 .penalty(1.0)
883 .fit(&dataset)
884 .unwrap();
885 assert_abs_diff_eq!(model.intercept(), &array![0., 0.2666666667], epsilon = 1e-6);
886 assert_abs_diff_eq!(model.hyperplane(), &array![[0.0, 0.0]], epsilon = 1e-6);
887 assert_abs_diff_eq!(
888 model.predict(&t),
889 &array![[0., 0.2666666667], [0., 0.2666666667], [0., 0.2666666667]],
890 epsilon = 1e-6
891 );
892 assert_abs_diff_eq!(model.duality_gap(), 0.0, epsilon = 1e-6);
893 }
894
895 #[test]
896 fn elastic_net_toy_example_works() {
897 let dataset = Dataset::new(array![[-1.0], [0.0], [1.0]], array![-1.0, 0.0, 1.0]);
898
899 let t = array![[2.0], [3.0], [4.0]];
901 let model = ElasticNet::params()
902 .l1_ratio(0.3)
903 .penalty(0.5)
904 .fit(&dataset)
905 .unwrap();
906
907 assert_abs_diff_eq!(model.intercept(), 0.0);
908 assert_abs_diff_eq!(model.hyperplane(), &array![0.50819], epsilon = 1e-3);
909 assert_abs_diff_eq!(
910 model.predict(&t),
911 array![1.0163, 1.5245, 2.0327],
912 epsilon = 1e-3
913 );
914 assert_abs_diff_eq!(model.duality_gap(), 0.0);
915
916 let model = ElasticNet::params()
917 .l1_ratio(0.5)
918 .penalty(0.5)
919 .fit(&dataset)
920 .unwrap();
921
922 assert_abs_diff_eq!(model.intercept(), 0.0);
923 assert_abs_diff_eq!(model.hyperplane(), &array![0.45454], epsilon = 1e-3);
924 assert_abs_diff_eq!(
925 model.predict(&t),
926 array![0.9090, 1.3636, 1.8181],
927 epsilon = 1e-3
928 );
929 assert_abs_diff_eq!(model.duality_gap(), 0.0);
930 }
931
932 #[test]
933 fn multitask_elasticnet_toy_example_works() {
934 let dataset = Dataset::new(
938 array![[-1.0], [0.0], [1.0]],
939 array![[-1.0, 1.0], [0.0, -1.5], [1.0, 1.3]],
940 );
941
942 let t = array![[2.0], [3.0], [4.0]];
944 let model = MultiTaskElasticNet::params()
945 .with_intercept(false)
946 .l1_ratio(0.3)
947 .penalty(0.1)
948 .fit(&dataset)
949 .unwrap();
950 assert_abs_diff_eq!(model.intercept(), &array![0., 0.]);
951 assert_abs_diff_eq!(
952 model.hyperplane(),
953 &array![[0.86470395, 0.12970559]],
954 epsilon = 1e-6
955 );
956 assert_abs_diff_eq!(
957 model.predict(&t),
958 array![
959 [1.7294079, 0.25941118],
960 [2.59411185, 0.38911678],
961 [3.4588158, 0.51882237]
962 ],
963 epsilon = 1e-6
964 );
965 assert_abs_diff_eq!(model.duality_gap(), 0.0, epsilon = 1e-12);
966
967 let t = array![[2.0], [3.0], [4.0]];
969 let model = MultiTaskElasticNet::params()
970 .l1_ratio(0.3)
971 .penalty(0.1)
972 .fit(&dataset)
973 .unwrap();
974 assert_abs_diff_eq!(model.intercept(), &array![0., 0.26666666], epsilon = 1e-6);
975 assert_abs_diff_eq!(
976 model.hyperplane(),
977 &array![[0.86470395, 0.12970559]],
978 epsilon = 1e-6
979 );
980 assert_abs_diff_eq!(
981 model.predict(&t),
982 array![
983 [1.7294079, 0.52607785],
984 [2.59411185, 0.65578344],
985 [3.4588158, 0.78548904]
986 ],
987 epsilon = 1e-6
988 );
989 assert_abs_diff_eq!(model.duality_gap(), 0.0, epsilon = 1e-12);
990
991 let model = MultiTaskElasticNet::params()
992 .l1_ratio(0.5)
993 .penalty(0.1)
994 .fit(&dataset)
995 .unwrap();
996 assert_abs_diff_eq!(model.intercept(), &array![0., 0.2666666], epsilon = 1e-6);
997 assert_abs_diff_eq!(
998 model.hyperplane(),
999 &array![[0.861237, 0.12918555]],
1000 epsilon = 1e-6
1001 );
1002 assert_abs_diff_eq!(
1003 model.predict(&t),
1004 &array![
1005 [1.722474, 0.52503777],
1006 [2.583711, 0.65422332],
1007 [3.44494799, 0.78340887]
1008 ],
1009 epsilon = 1e-6
1010 );
1011 assert_abs_diff_eq!(model.duality_gap(), 0.0, epsilon = 1e-12);
1012 }
1013
1014 #[test]
1015 fn elastic_net_2d_toy_example_works() {
1016 let dataset = Dataset::new(array![[1.0, 0.0], [0.0, 1.0]], array![3.0, 2.0]);
1017
1018 let model = ElasticNet::params().penalty(0.0).fit(&dataset).unwrap();
1019 assert_abs_diff_eq!(model.intercept(), 2.5);
1020 assert_abs_diff_eq!(model.hyperplane(), &array![0.5, -0.5], epsilon = 0.001);
1021 }
1022
1023 #[test]
1024 #[allow(clippy::excessive_precision)]
1025 fn elastic_net_diabetes_1_works_like_sklearn() {
1026 #[rustfmt::skip]
1030 let x = array![
1031 [3.807590643342410180e-02, 5.068011873981870252e-02, 6.169620651868849837e-02, 2.187235499495579841e-02, -4.422349842444640161e-02, -3.482076283769860309e-02, -4.340084565202689815e-02, -2.592261998182820038e-03, 1.990842087631829876e-02, -1.764612515980519894e-02],
1032 [-1.882016527791040067e-03, -4.464163650698899782e-02, -5.147406123880610140e-02, -2.632783471735180084e-02, -8.448724111216979540e-03, -1.916333974822199970e-02, 7.441156407875940126e-02, -3.949338287409189657e-02, -6.832974362442149896e-02, -9.220404962683000083e-02],
1033 [8.529890629667830071e-02, 5.068011873981870252e-02, 4.445121333659410312e-02, -5.670610554934250001e-03, -4.559945128264750180e-02, -3.419446591411950259e-02, -3.235593223976569732e-02, -2.592261998182820038e-03, 2.863770518940129874e-03, -2.593033898947460017e-02],
1034 [-8.906293935226029801e-02, -4.464163650698899782e-02, -1.159501450521270051e-02, -3.665644679856060184e-02, 1.219056876180000040e-02, 2.499059336410210108e-02, -3.603757004385269719e-02, 3.430885887772629900e-02, 2.269202256674450122e-02, -9.361911330135799444e-03],
1035 [5.383060374248070309e-03, -4.464163650698899782e-02, -3.638469220447349689e-02, 2.187235499495579841e-02, 3.934851612593179802e-03, 1.559613951041610019e-02, 8.142083605192099172e-03, -2.592261998182820038e-03, -3.199144494135589684e-02, -4.664087356364819692e-02],
1036 [-9.269547780327989928e-02, -4.464163650698899782e-02, -4.069594049999709917e-02, -1.944209332987930153e-02, -6.899064987206669775e-02, -7.928784441181220555e-02, 4.127682384197570165e-02, -7.639450375000099436e-02, -4.118038518800790082e-02, -9.634615654166470144e-02],
1037 [-4.547247794002570037e-02, 5.068011873981870252e-02, -4.716281294328249912e-02, -1.599922263614299983e-02, -4.009563984984299695e-02, -2.480001206043359885e-02, 7.788079970179680352e-04, -3.949338287409189657e-02, -6.291294991625119570e-02, -3.835665973397880263e-02],
1038 [6.350367559056099842e-02, 5.068011873981870252e-02, -1.894705840284650021e-03, 6.662967401352719310e-02, 9.061988167926439408e-02, 1.089143811236970016e-01, 2.286863482154040048e-02, 1.770335448356720118e-02, -3.581672810154919867e-02, 3.064409414368320182e-03],
1039 [4.170844488444359899e-02, 5.068011873981870252e-02, 6.169620651868849837e-02, -4.009931749229690007e-02, -1.395253554402150001e-02, 6.201685656730160021e-03, -2.867429443567860031e-02, -2.592261998182820038e-03, -1.495647502491130078e-02, 1.134862324403770016e-02],
1040 [-7.090024709716259699e-02, -4.464163650698899782e-02, 3.906215296718960200e-02, -3.321357610482440076e-02, -1.257658268582039982e-02, -3.450761437590899733e-02, -2.499265663159149983e-02, -2.592261998182820038e-03, 6.773632611028609918e-02, -1.350401824497050006e-02],
1041 [-9.632801625429950054e-02, -4.464163650698899782e-02, -8.380842345523309422e-02, 8.100872220010799790e-03, -1.033894713270950005e-01, -9.056118903623530669e-02, -1.394774321933030074e-02, -7.639450375000099436e-02, -6.291294991625119570e-02, -3.421455281914410201e-02],
1042 [2.717829108036539862e-02, 5.068011873981870252e-02, 1.750591148957160101e-02, -3.321357610482440076e-02, -7.072771253015849857e-03, 4.597154030400080194e-02, -6.549067247654929980e-02, 7.120997975363539678e-02, -9.643322289178400675e-02, -5.906719430815229877e-02],
1043 [1.628067572730669890e-02, -4.464163650698899782e-02, -2.884000768730720157e-02, -9.113481248670509197e-03, -4.320865536613589623e-03, -9.768885894535990141e-03, 4.495846164606279866e-02, -3.949338287409189657e-02, -3.075120986455629965e-02, -4.249876664881350324e-02],
1044 [5.383060374248070309e-03, 5.068011873981870252e-02, -1.894705840284650021e-03, 8.100872220010799790e-03, -4.320865536613589623e-03, -1.571870666853709964e-02, -2.902829807069099918e-03, -2.592261998182820038e-03, 3.839324821169769891e-02, -1.350401824497050006e-02],
1045 [4.534098333546320025e-02, -4.464163650698899782e-02, -2.560657146566450160e-02, -1.255635194240680048e-02, 1.769438019460449832e-02, -6.128357906048329537e-05, 8.177483968693349814e-02, -3.949338287409189657e-02, -3.199144494135589684e-02, -7.563562196749110123e-02],
1046 [-5.273755484206479882e-02, 5.068011873981870252e-02, -1.806188694849819934e-02, 8.040115678847230274e-02, 8.924392882106320368e-02, 1.076617872765389949e-01, -3.971920784793980114e-02, 1.081111006295440019e-01, 3.605579008983190309e-02, -4.249876664881350324e-02],
1047 [-5.514554978810590376e-03, -4.464163650698899782e-02, 4.229558918883229851e-02, 4.941532054484590319e-02, 2.457414448561009990e-02, -2.386056667506489953e-02, 7.441156407875940126e-02, -3.949338287409189657e-02, 5.227999979678119719e-02, 2.791705090337660150e-02],
1048 [7.076875249260000666e-02, 5.068011873981870252e-02, 1.211685112016709989e-02, 5.630106193231849965e-02, 3.420581449301800248e-02, 4.941617338368559792e-02, -3.971920784793980114e-02, 3.430885887772629900e-02, 2.736770754260900093e-02, -1.077697500466389974e-03],
1049 [-3.820740103798660192e-02, -4.464163650698899782e-02, -1.051720243133190055e-02, -3.665644679856060184e-02, -3.734373413344069942e-02, -1.947648821001150138e-02, -2.867429443567860031e-02, -2.592261998182820038e-03, -1.811826730789670159e-02, -1.764612515980519894e-02],
1050 [-2.730978568492789874e-02, -4.464163650698899782e-02, -1.806188694849819934e-02, -4.009931749229690007e-02, -2.944912678412469915e-03, -1.133462820348369975e-02, 3.759518603788870178e-02, -3.949338287409189657e-02, -8.944018957797799166e-03, -5.492508739331759815e-02]
1051 ];
1052 #[rustfmt::skip]
1053 let y = array![1.51e+02, 7.5e+01, 1.41e+02, 2.06e+02, 1.35e+02, 9.7e+01, 1.38e+02, 6.3e+01, 1.1e+02, 3.1e+02, 1.01e+02, 6.9e+01, 1.79e+02, 1.85e+02, 1.18e+02, 1.71e+02, 1.66e+02, 1.44e+02, 9.7e+01, 1.68e+02];
1054 let model = ElasticNet::params()
1055 .l1_ratio(0.2)
1056 .penalty(0.5)
1057 .fit(&Dataset::new(x, y))
1058 .unwrap();
1059
1060 assert_abs_diff_eq!(
1061 model.hyperplane(),
1062 &array![
1063 -2.00558969,
1064 -0.92208413,
1065 1.27586213,
1066 -0.06617076,
1067 0.26484338,
1068 -0.48702845,
1069 -0.60274235,
1070 0.3975141,
1071 4.33229135,
1072 1.11981207
1073 ],
1074 epsilon = 0.01
1075 );
1076 assert_abs_diff_eq!(model.intercept(), 141.283952, epsilon = 1e-1);
1077 assert!(
1078 f64::abs(model.duality_gap()) < 1e-4,
1079 "Duality gap too large"
1080 );
1081 }
1082
/// Fits an elastic net (`l1_ratio = 0.2`, `penalty = 0.5`) on 20 rows of the diabetes data and
/// compares the coefficients and intercept against reference values (per the test name, taken
/// from scikit-learn's `ElasticNet`).
#[test]
#[allow(clippy::excessive_precision)]
fn elastic_net_diabetes_2_works_like_sklearn() {
    #[rustfmt::skip]
    let x = array![
        [-7.816532399920170238e-02,5.068011873981870252e-02,7.786338762690199478e-02,5.285819123858220142e-02,7.823630595545419397e-02,6.444729954958319795e-02,2.655027262562750096e-02,-2.592261998182820038e-03,4.067226371449769728e-02,-9.361911330135799444e-03],
        [9.015598825267629943e-03,5.068011873981870252e-02,-3.961812842611620034e-02,2.875809638242839833e-02,3.833367306762140020e-02,7.352860494147960002e-02,-7.285394808472339667e-02,1.081111006295440019e-01,1.556684454070180086e-02,-4.664087356364819692e-02],
        [1.750521923228520000e-03,5.068011873981870252e-02,1.103903904628619932e-02,-1.944209332987930153e-02,-1.670444126042380101e-02,-3.819065120534880214e-03,-4.708248345611389801e-02,3.430885887772629900e-02,2.405258322689299982e-02,2.377494398854190089e-02],
        [-7.816532399920170238e-02,-4.464163650698899782e-02,-4.069594049999709917e-02,-8.141376581713200000e-02,-1.006375656106929944e-01,-1.127947298232920004e-01,2.286863482154040048e-02,-7.639450375000099436e-02,-2.028874775162960165e-02,-5.078298047848289754e-02],
        [3.081082953138499989e-02,5.068011873981870252e-02,-3.422906805671169922e-02,4.367720260718979675e-02,5.759701308243719842e-02,6.883137801463659611e-02,-3.235593223976569732e-02,5.755656502954899917e-02,3.546193866076970125e-02,8.590654771106250032e-02],
        [-3.457486258696700065e-02,5.068011873981870252e-02,5.649978676881649634e-03,-5.670610554934250001e-03,-7.311850844667000526e-02,-6.269097593696699999e-02,-6.584467611156170040e-03,-3.949338287409189657e-02,-4.542095777704099890e-02,3.205915781821130212e-02],
        [4.897352178648269744e-02,5.068011873981870252e-02,8.864150836571099701e-02,8.728689817594480205e-02,3.558176735121919981e-02,2.154596028441720101e-02,-2.499265663159149983e-02,3.430885887772629900e-02,6.604820616309839409e-02,1.314697237742440128e-01],
        [-4.183993948900609910e-02,-4.464163650698899782e-02,-3.315125598283080038e-02,-2.288496402361559975e-02,4.658939021682820258e-02,4.158746183894729970e-02,5.600337505832399948e-02,-2.473293452372829840e-02,-2.595242443518940012e-02,-3.835665973397880263e-02],
        [-9.147093429830140468e-03,-4.464163650698899782e-02,-5.686312160821060252e-02,-5.042792957350569760e-02,2.182223876920789951e-02,4.534524338042170144e-02,-2.867429443567860031e-02,3.430885887772629900e-02,-9.918957363154769225e-03,-1.764612515980519894e-02],
        [7.076875249260000666e-02,5.068011873981870252e-02,-3.099563183506899924e-02,2.187235499495579841e-02,-3.734373413344069942e-02,-4.703355284749029946e-02,3.391354823380159783e-02,-3.949338287409189657e-02,-1.495647502491130078e-02,-1.077697500466389974e-03],
        [9.015598825267629943e-03,-4.464163650698899782e-02,5.522933407540309841e-02,-5.670610554934250001e-03,5.759701308243719842e-02,4.471894645684260094e-02,-2.902829807069099918e-03,2.323852261495349888e-02,5.568354770267369691e-02,1.066170822852360034e-01],
        [-2.730978568492789874e-02,-4.464163650698899782e-02,-6.009655782985329903e-02,-2.977070541108809906e-02,4.658939021682820258e-02,1.998021797546959896e-02,1.222728555318910032e-01,-3.949338287409189657e-02,-5.140053526058249722e-02,-9.361911330135799444e-03],
        [1.628067572730669890e-02,-4.464163650698899782e-02,1.338730381358059929e-03,8.100872220010799790e-03,5.310804470794310353e-03,1.089891258357309975e-02,3.023191042971450082e-02,-3.949338287409189657e-02,-4.542095777704099890e-02,3.205915781821130212e-02],
        [-1.277963188084970010e-02,-4.464163650698899782e-02,-2.345094731790270046e-02,-4.009931749229690007e-02,-1.670444126042380101e-02,4.635943347782499856e-03,-1.762938102341739949e-02,-2.592261998182820038e-03,-3.845911230135379971e-02,-3.835665973397880263e-02],
        [-5.637009329308430294e-02,-4.464163650698899782e-02,-7.410811479030500470e-02,-5.042792957350569760e-02,-2.496015840963049931e-02,-4.703355284749029946e-02,9.281975309919469896e-02,-7.639450375000099436e-02,-6.117659509433449883e-02,-4.664087356364819692e-02],
        [4.170844488444359899e-02,5.068011873981870252e-02,1.966153563733339868e-02,5.974393262605470073e-02,-5.696818394814720174e-03,-2.566471273376759888e-03,-2.867429443567860031e-02,-2.592261998182820038e-03,3.119299070280229930e-02,7.206516329203029904e-03],
        [-5.514554978810590376e-03,5.068011873981870252e-02,-1.590626280073640167e-02,-6.764228304218700139e-02,4.934129593323050011e-02,7.916527725369119917e-02,-2.867429443567860031e-02,3.430885887772629900e-02,-1.811826730789670159e-02,4.448547856271539702e-02],
        [4.170844488444359899e-02,5.068011873981870252e-02,-1.590626280073640167e-02,1.728186074811709910e-02,-3.734373413344069942e-02,-1.383981589779990050e-02,-2.499265663159149983e-02,-1.107951979964190078e-02,-4.687948284421659950e-02,1.549073015887240078e-02],
        [-4.547247794002570037e-02,-4.464163650698899782e-02,3.906215296718960200e-02,1.215130832538269907e-03,1.631842733640340160e-02,1.528299104862660025e-02,-2.867429443567860031e-02,2.655962349378539894e-02,4.452837402140529671e-02,-2.593033898947460017e-02],
        [-4.547247794002570037e-02,-4.464163650698899782e-02,-7.303030271642410587e-02,-8.141376581713200000e-02,8.374011738825870577e-02,2.780892952020790065e-02,1.738157847891100005e-01,-3.949338287409189657e-02,-4.219859706946029777e-03,3.064409414368320182e-03]
    ];
    #[rustfmt::skip]
    let y = array![2.33e+02, 9.1e+01, 1.11e+02, 1.52e+02, 1.2e+02, 6.70e+01, 3.1e+02, 9.4e+01, 1.83e+02, 6.6e+01, 1.73e+02, 7.2e+01, 4.9e+01, 6.4e+01, 4.8e+01, 1.78e+02, 1.04e+02, 1.32e+02, 2.20e+02, 5.7e+01];
    let model = ElasticNet::params()
        .l1_ratio(0.2)
        .penalty(0.5)
        .fit(&Dataset::new(x, y))
        .unwrap();

    assert_abs_diff_eq!(
        model.hyperplane(),
        &array![
            0.19879313,
            1.46970138,
            5.58097318,
            3.80089794,
            1.46466565,
            1.42327857,
            -3.86944632,
            2.60836423,
            4.79584768,
            3.03232988
        ],
        epsilon = 0.01
    );
    assert_abs_diff_eq!(model.intercept(), 126.279, epsilon = 1e-1);
    // The reference duality gap is 0.00011079. Only bound it from above (reference value plus a
    // 1e-4 margin), so a solution that converges more tightly than the reference still passes.
    // `f64::abs` (not `.abs()`) pins the otherwise-ambiguous float literal type to f64.
    assert!(
        f64::abs(model.duality_gap()) < 0.00011079 + 1e-4,
        "Duality gap too large"
    );
}
1139
/// Lasso on noiseless data generated from a sparse weight vector must zero out exactly the
/// irrelevant features and generalize to a fresh sample from the same distribution.
#[test]
fn select_subset() {
    let mut rng = Xoshiro256Plus::seed_from_u64(42);

    // Ground-truth weights: the first 10 are drawn from [1, 2), the remaining 40 are zeroed.
    let mut w = Array::random_using(50, Uniform::new(1., 2.), &mut rng);
    w.slice_mut(s![10..]).fill(0.0);

    let x = Array::random_using((100, 50), Uniform::new(-1., 1.), &mut rng);
    let y = x.dot(&w);
    let train = Dataset::new(x, y);

    let model = ElasticNet::lasso()
        .penalty(0.1)
        .max_iterations(1000)
        .tolerance(1e-10)
        .fit(&train)
        .unwrap();

    // Count coefficients that are numerically zero. Compare the magnitude: a bare `< 1e-5`
    // would also count large negative coefficients as zeros. `f64::abs` (not `.abs()`) pins
    // the otherwise-ambiguous float literal type to f64.
    let num_zeros = model
        .hyperplane()
        .iter()
        .filter(|coef| f64::abs(**coef) < 1e-5)
        .count();

    assert_eq!(num_zeros, 40);

    // Evaluate on a fresh sample drawn from the same distribution (same RNG stream).
    let x = Array::random_using((100, 50), Uniform::new(-1., 1.), &mut rng);
    let y = x.dot(&w);

    let predicted = model.predict(&x);
    let mse = y.mean_squared_error(&predicted);
    assert!(mse.unwrap() < 0.67);
}
1176
/// Fits an unpenalized model on the full diabetes dataset and sanity-checks the per-feature
/// z-scores and 95% confidence intervals for features 2 and 3.
#[test]
fn diabetes_z_score() {
    let model = ElasticNet::params()
        .penalty(0.0)
        .fit(&linfa_datasets::diabetes())
        .unwrap();

    // Features 2 and 3 are expected to have z-scores above 2.0.
    let z_scores = model.z_score().unwrap();
    assert!(z_scores[2] > 2.0);
    assert!(z_scores[3] > 2.0);

    // `.0` is the first element of each interval tuple (presumably the lower bound).
    let intervals = model.confidence_95th().unwrap();
    assert!(intervals[2].0 < 416.);
    assert!(intervals[3].0 < 220.);
}
1192}