linfa_elasticnet/
algorithm.rs

1use approx::{abs_diff_eq, abs_diff_ne};
2use linfa_linalg::norm::Norm;
3#[cfg(not(feature = "blas"))]
4use linfa_linalg::qr::QRInto;
5use ndarray::linalg::general_mat_mul;
6use ndarray::{
7    s, Array, Array1, Array2, ArrayBase, ArrayView, ArrayView1, ArrayView2, Axis, CowArray, Data,
8    Dimension, Ix2, RemoveAxis,
9};
10#[cfg(feature = "blas")]
11use ndarray_linalg::InverseHInto;
12
13use linfa::dataset::{WithLapack, WithoutLapack};
14use linfa::traits::{Fit, PredictInplace};
15use linfa::{
16    dataset::{AsMultiTargets, AsSingleTargets, AsTargets, Records},
17    DatasetBase, Float,
18};
19
20use super::{
21    hyperparams::{ElasticNetValidParams, MultiTaskElasticNetValidParams},
22    ElasticNet, ElasticNetError, MultiTaskElasticNet, Result,
23};
24
25impl<F, D, T> Fit<ArrayBase<D, Ix2>, T, ElasticNetError> for ElasticNetValidParams<F>
26where
27    F: Float,
28    D: Data<Elem = F>,
29    T: AsSingleTargets<Elem = F>,
30{
31    type Object = ElasticNet<F>;
32
33    /// Fit an elastic net model given a feature matrix `x` and a target
34    /// variable `y`.
35    ///
36    /// The feature matrix `x` must have shape `(n_samples, n_features)`
37    ///
38    /// The target variable `y` must have shape `(n_samples)`
39    ///
40    /// Returns a `FittedElasticNet` object which contains the fitted
41    /// parameters and can be used to `predict` values of the target variable
42    /// for new feature values.
43    fn fit(&self, dataset: &DatasetBase<ArrayBase<D, Ix2>, T>) -> Result<Self::Object> {
44        let target = dataset.as_single_targets();
45
46        let (intercept, y) = compute_intercept(self.with_intercept(), target);
47        let (hyperplane, duality_gap, n_steps) = coordinate_descent(
48            dataset.records().view(),
49            y.view(),
50            self.tolerance(),
51            self.max_iterations(),
52            self.l1_ratio(),
53            self.penalty(),
54        );
55        let intercept = intercept.into_scalar();
56
57        let y_est = dataset.records().dot(&hyperplane) + intercept;
58
59        // try to calculate the variance
60        let variance = variance_params(dataset, y_est.view());
61
62        Ok(ElasticNet {
63            hyperplane,
64            intercept,
65            duality_gap,
66            n_steps,
67            variance,
68        })
69    }
70}
71
72impl<F, D, T> Fit<ArrayBase<D, Ix2>, T, ElasticNetError> for MultiTaskElasticNetValidParams<F>
73where
74    F: Float,
75    T: AsMultiTargets<Elem = F>,
76    D: Data<Elem = F>,
77{
78    type Object = MultiTaskElasticNet<F>;
79
80    /// Fit a multi-task Elastic Net model given a feature matrix `x` and a target
81    /// matrix `y`.
82    ///
83    /// The feature matrix `x` must have shape `(n_samples, n_features)`
84    ///
85    /// The target variable `y` must have shape `(n_samples, n_tasks)`
86    ///
87    /// Returns a `FittedMultiTaskElasticNet` object which contains the fitted
88    /// parameters and can be used to `predict` values of the target variables
89    /// for new feature values.
90    fn fit(&self, dataset: &DatasetBase<ArrayBase<D, Ix2>, T>) -> Result<Self::Object> {
91        let targets = dataset.targets().as_multi_targets();
92        let (intercept, y) = compute_intercept(self.with_intercept(), targets);
93
94        let (hyperplane, duality_gap, n_steps) = block_coordinate_descent(
95            dataset.records().view(),
96            y.view(),
97            self.tolerance(),
98            self.max_iterations(),
99            self.l1_ratio(),
100            self.penalty(),
101        );
102
103        let y_est = dataset.records().dot(&hyperplane) + &intercept;
104
105        // try to calculate the variance
106        let variance = variance_params(dataset, y_est.view());
107
108        Ok(MultiTaskElasticNet {
109            hyperplane,
110            intercept,
111            duality_gap,
112            n_steps,
113            variance,
114        })
115    }
116}
117
118impl<F: Float, D: Data<Elem = F>> PredictInplace<ArrayBase<D, Ix2>, Array1<F>> for ElasticNet<F> {
119    /// Given an input matrix `X`, with shape `(n_samples, n_features)`,
120    /// `predict` returns the target variable according to elastic net
121    /// learned from the training data distribution.
122    fn predict_inplace(&self, x: &ArrayBase<D, Ix2>, y: &mut Array1<F>) {
123        assert_eq!(
124            x.nrows(),
125            y.len(),
126            "The number of data points must match the number of output targets."
127        );
128
129        *y = x.dot(&self.hyperplane) + self.intercept;
130    }
131
132    fn default_target(&self, x: &ArrayBase<D, Ix2>) -> Array1<F> {
133        Array1::zeros(x.nrows())
134    }
135}
136
137impl<F: Float, D: Data<Elem = F>> PredictInplace<ArrayBase<D, Ix2>, Array2<F>>
138    for MultiTaskElasticNet<F>
139{
140    /// Given an input matrix `X`, with shape `(n_samples, n_features)`,
141    /// `predict` returns the target variable according to elastic net
142    /// learned from the training data distribution.
143    fn predict_inplace(&self, x: &ArrayBase<D, Ix2>, y: &mut Array2<F>) {
144        assert_eq!(
145            x.nrows(),
146            y.nrows(),
147            "The number of data points must match the number of output targets."
148        );
149
150        *y = x.dot(&self.hyperplane) + &self.intercept;
151    }
152
153    fn default_target(&self, x: &ArrayBase<D, Ix2>) -> Array2<F> {
154        // TODO: fix, should be (x.nrows(), y.ncols())
155        Array2::zeros((x.nrows(), x.nrows()))
156    }
157}
158
159/// View the fitted parameters and make predictions with a fitted
160/// elastic net model
161impl<F: Float> ElasticNet<F> {
162    /// Get the fitted hyperplane
163    pub fn hyperplane(&self) -> &Array1<F> {
164        &self.hyperplane
165    }
166
167    /// Get the fitted intercept, 0. if no intercept was fitted
168    pub fn intercept(&self) -> F {
169        self.intercept
170    }
171
172    /// Get the number of steps taken in optimization algorithm
173    pub fn n_steps(&self) -> u32 {
174        self.n_steps
175    }
176
177    /// Get the duality gap at the end of the optimization algorithm
178    pub fn duality_gap(&self) -> F {
179        self.duality_gap
180    }
181
182    /// Calculate the Z score
183    pub fn z_score(&self) -> Result<Array1<F>> {
184        self.variance
185            .as_ref()
186            .map(|variance| {
187                self.hyperplane
188                    .iter()
189                    .zip(variance.iter())
190                    .map(|(a, b)| *a / b.sqrt())
191                    .collect()
192            })
193            .map_err(|err| err.clone())
194    }
195
196    /// Calculate the confidence level
197    pub fn confidence_95th(&self) -> Result<Array1<(F, F)>> {
198        // the 95th percentile of our confidence level
199        let p = F::cast(1.645);
200
201        self.variance
202            .as_ref()
203            .map(|variance| {
204                self.hyperplane
205                    .iter()
206                    .zip(variance.iter())
207                    .map(|(a, b)| (*a - p * b.sqrt(), *a + p * b.sqrt()))
208                    .collect()
209            })
210            .map_err(|err| err.clone())
211    }
212}
213
214/// View the fitted parameters and make predictions with a fitted
215/// elastic net model
216impl<F: Float> MultiTaskElasticNet<F> {
217    /// Get the fitted hyperplane
218    pub fn hyperplane(&self) -> &Array2<F> {
219        &self.hyperplane
220    }
221
222    /// Get the fitted intercept, [0., ..., 0.] if no intercept was fitted
223    /// Note that there are as many intercepts as tasks
224    pub fn intercept(&self) -> &Array1<F> {
225        &self.intercept
226    }
227
228    /// Get the number of steps taken in optimization algorithm
229    pub fn n_steps(&self) -> u32 {
230        self.n_steps
231    }
232
233    /// Get the duality gap at the end of the optimization algorithm
234    pub fn duality_gap(&self) -> F {
235        self.duality_gap
236    }
237
238    /// Calculate the Z score
239    pub fn z_score(&self) -> Result<Array2<F>> {
240        self.variance
241            .as_ref()
242            .map(|variance| {
243                ndarray::Zip::from(&self.hyperplane)
244                    .and_broadcast(variance)
245                    .map_collect(|a, b| *a / b.sqrt())
246            })
247            .map_err(|err| err.clone())
248    }
249
250    /// Calculate the confidence level
251    pub fn confidence_95th(&self) -> Result<Array2<(F, F)>> {
252        // the 95th percentile of our confidence level
253        let p = F::cast(1.645);
254
255        self.variance
256            .as_ref()
257            .map(|variance| {
258                ndarray::Zip::from(&self.hyperplane)
259                    .and_broadcast(variance)
260                    .map_collect(|a, b| (*a - p * b.sqrt(), *a + p * b.sqrt()))
261            })
262            .map_err(|err| err.clone())
263    }
264}
265
266fn coordinate_descent<'a, F: Float>(
267    x: ArrayView2<'a, F>,
268    y: ArrayView1<'a, F>,
269    tol: F,
270    max_steps: u32,
271    l1_ratio: F,
272    penalty: F,
273) -> (Array1<F>, F, u32) {
274    let n_samples = F::cast(x.nrows());
275    let n_features = x.ncols();
276    // the parameters of the model
277    let mut w = Array1::<F>::zeros(n_features);
278    // the residuals: `y - X*w` (since w=0, this is just `y` for now),
279    // the residuals are updated during the algorithm as the parameters change
280    let mut r = y.to_owned();
281    let mut n_steps = 0u32;
282    let norm_cols_x = x.map_axis(Axis(0), |col| col.dot(&col));
283    let mut gap = F::one() + tol;
284    let d_w_tol = tol;
285    let tol = tol * y.dot(&y);
286    while n_steps < max_steps {
287        let mut w_max = F::zero();
288        let mut d_w_max = F::zero();
289        for j in 0..n_features {
290            if abs_diff_eq!(norm_cols_x[j], F::zero()) {
291                continue;
292            }
293            let old_w_j = w[j];
294            let x_j: ArrayView1<F> = x.slice(s![.., j]);
295            if abs_diff_ne!(old_w_j, F::zero()) {
296                r.scaled_add(old_w_j, &x_j);
297            }
298            let tmp: F = x_j.dot(&r);
299            w[j] = tmp.signum() * F::max(tmp.abs() - n_samples * l1_ratio * penalty, F::zero())
300                / (norm_cols_x[j] + n_samples * (F::one() - l1_ratio) * penalty);
301            if abs_diff_ne!(w[j], F::zero()) {
302                r.scaled_add(-w[j], &x_j);
303            }
304            let d_w_j = (w[j] - old_w_j).abs();
305            d_w_max = F::max(d_w_max, d_w_j);
306            w_max = F::max(w_max, w[j].abs());
307        }
308        n_steps += 1;
309
310        if n_steps == max_steps - 1 || abs_diff_eq!(w_max, F::zero()) || d_w_max / w_max < d_w_tol {
311            // We've hit one potential stopping criteria
312            // check duality gap for ultimate stopping criterion
313            gap = duality_gap(x.view(), y.view(), w.view(), r.view(), l1_ratio, penalty);
314            if gap < tol {
315                break;
316            }
317        }
318    }
319    (w, gap, n_steps)
320}
321
322fn block_coordinate_descent<'a, F: Float>(
323    x: ArrayView2<'a, F>,
324    y: ArrayView2<'a, F>,
325    tol: F,
326    max_steps: u32,
327    l1_ratio: F,
328    penalty: F,
329) -> (Array2<F>, F, u32) {
330    let n_samples = F::cast(x.nrows());
331    let n_features = x.ncols();
332    let n_tasks = y.ncols();
333    // the parameters of the model
334    let mut w = Array2::<F>::zeros((n_features, n_tasks));
335    // the residuals: `Y - XW` (since W=0, this is just `Y` for now),
336    // the residuals are updated during the algorithm as the parameters change
337    let mut r = y.to_owned();
338    let mut n_steps = 0u32;
339    let norm_cols_x = x.map_axis(Axis(0), |col| col.dot(&col));
340    let mut gap = F::one() + tol;
341    let d_w_tol = tol;
342    let tol = tol * y.iter().map(|&y_ij| y_ij * y_ij).sum();
343    while n_steps < max_steps {
344        let mut w_max = F::zero();
345        let mut d_w_max = F::zero();
346        for j in 0..n_features {
347            if abs_diff_eq!(norm_cols_x[j], F::zero()) {
348                continue;
349            }
350            let mut old_w_j = w.slice_mut(s![j, ..]);
351            let x_j = x.slice(s![.., j]);
352            let norm_old_w_j = old_w_j.dot(&old_w_j).sqrt();
353            if abs_diff_ne!(norm_old_w_j, F::zero()) {
354                // r += outer(x_j, old_w_j)
355                general_mat_mul(
356                    F::one(),
357                    &x_j.view().insert_axis(Axis(1)),
358                    &old_w_j.view().insert_axis(Axis(0)),
359                    F::one(),
360                    &mut r,
361                );
362            }
363            let tmp = x_j.dot(&r);
364            old_w_j.assign(
365                &(block_soft_thresholding(tmp.view(), n_samples * l1_ratio * penalty)
366                    / (norm_cols_x[j] + n_samples * (F::one() - l1_ratio) * penalty)),
367            );
368            let norm_w_j = old_w_j.dot(&old_w_j).sqrt();
369            if abs_diff_ne!(norm_w_j, F::zero()) {
370                // r -= outer(x_j, old_w_j)
371                general_mat_mul(
372                    -F::one(),
373                    &x_j.insert_axis(Axis(1)),
374                    &old_w_j.insert_axis(Axis(0)),
375                    F::one(),
376                    &mut r,
377                );
378            }
379            let d_w_j = (norm_w_j - norm_old_w_j).abs();
380            d_w_max = F::max(d_w_max, d_w_j);
381            w_max = F::max(w_max, norm_w_j);
382        }
383        n_steps += 1;
384
385        if n_steps == max_steps - 1 || abs_diff_eq!(w_max, F::zero()) || d_w_max / w_max < d_w_tol {
386            // We've hit one potential stopping criteria
387            // check duality gap for ultimate stopping criterion
388            gap = duality_gap_mtl(x.view(), y.view(), w.view(), r.view(), l1_ratio, penalty);
389            if gap < tol {
390                break;
391            }
392        }
393    }
394
395    (w, gap, n_steps)
396}
397
398// Algorithm based off of this post: https://math.stackexchange.com/questions/2045579/deriving-block-soft-threshold-from-l-2-norm-prox-operator
399fn block_soft_thresholding<F: Float>(x: ArrayView1<F>, threshold: F) -> Array1<F> {
400    let norm_x = x.dot(&x).sqrt();
401    if norm_x < threshold {
402        return Array1::<F>::zeros(x.len());
403    }
404    let scale = F::one() - threshold / norm_x;
405    &x * scale
406}
407
408fn duality_gap<'a, F: Float>(
409    x: ArrayView2<'a, F>,
410    y: ArrayView1<'a, F>,
411    w: ArrayView1<'a, F>,
412    r: ArrayView1<'a, F>,
413    l1_ratio: F,
414    penalty: F,
415) -> F {
416    let half = F::cast(0.5);
417    let n_samples = F::cast(x.nrows());
418    let l1_reg = l1_ratio * penalty * n_samples;
419    let l2_reg = (F::one() - l1_ratio) * penalty * n_samples;
420    let xta = x.t().dot(&r) - &w * l2_reg;
421
422    let dual_norm_xta = xta.norm_max();
423    let r_norm2 = r.dot(&r);
424    let w_norm2 = w.dot(&w);
425    let (const_, mut gap) = if dual_norm_xta > l1_reg {
426        let const_ = l1_reg / dual_norm_xta;
427        let a_norm2 = r_norm2 * const_ * const_;
428        (const_, half * (r_norm2 + a_norm2))
429    } else {
430        (F::one(), r_norm2)
431    };
432    let l1_norm = w.norm_l1();
433    gap += l1_reg * l1_norm - const_ * r.dot(&y)
434        + half * l2_reg * (F::one() + const_ * const_) * w_norm2;
435    gap
436}
437
438fn duality_gap_mtl<'a, F: Float>(
439    x: ArrayView2<'a, F>,
440    y: ArrayView2<'a, F>,
441    w: ArrayView2<'a, F>,
442    r: ArrayView2<'a, F>,
443    l1_ratio: F,
444    penalty: F,
445) -> F {
446    let half = F::cast(0.5);
447    let n_samples = F::cast(x.nrows());
448    let l1_reg = l1_ratio * penalty * n_samples;
449    let l2_reg = (F::one() - l1_ratio) * penalty * n_samples;
450    let xta = x.t().dot(&r) - &w * l2_reg;
451
452    let dual_norm_xta = xta.map_axis(Axis(1), |x| x.dot(&x).sqrt()).norm_max();
453    let r_norm2 = r.iter().map(|&rij| rij * rij).sum();
454    let w_norm2 = w.iter().map(|&wij| wij * wij).sum();
455    let (const_, mut gap) = if dual_norm_xta > l1_reg {
456        let const_ = l1_reg / dual_norm_xta;
457        let a_norm2 = r_norm2 * const_ * const_;
458        (const_, half * (r_norm2 + a_norm2))
459    } else {
460        (F::one(), r_norm2)
461    };
462    let rty = r.t().dot(&y);
463    let trace_rty = rty.diag().sum();
464    let l21_norm = w.map_axis(Axis(1), |wj| (wj.dot(&wj)).sqrt()).sum();
465    gap += l1_reg * l21_norm - const_ * trace_rty
466        + half * l2_reg * (F::one() + const_ * const_) * w_norm2;
467    gap
468}
469
470fn variance_params<F: Float, T: AsTargets<Elem = F>, D: Data<Elem = F>>(
471    ds: &DatasetBase<ArrayBase<D, Ix2>, T>,
472    y_est: ArrayView<F, T::Ix>,
473) -> Result<Array1<F>> {
474    let nfeatures = ds.nfeatures();
475    let nsamples = ds.nsamples();
476
477    let target = ds.targets().as_targets();
478    let ndim = target.ndim();
479
480    let ntasks: usize = match ndim {
481        1 => 1,
482        2 => *target.shape().last().unwrap(),
483        _ => {
484            return Err(ElasticNetError::IncorrectTargetShape);
485        }
486    };
487
488    let y_est = y_est.as_targets();
489
490    // check that we have enough samples
491    if nsamples < nfeatures + 1 {
492        return Err(ElasticNetError::NotEnoughSamples);
493    }
494
495    let var_target =
496        (&target - &y_est).mapv(|x| x * x).sum() / F::cast(ntasks * (nsamples - nfeatures));
497
498    // `A.t * A` always produces a symmetric matrix
499    let ds2 = ds.records().t().dot(ds.records()).with_lapack();
500    #[cfg(feature = "blas")]
501    let inv_cov = ds2.invh_into();
502    #[cfg(not(feature = "blas"))]
503    let inv_cov = (|| ds2.qr_into()?.inverse())();
504
505    match inv_cov {
506        Ok(inv_cov) => Ok(inv_cov.without_lapack().diag().mapv(|x| var_target * x)),
507        Err(_) => Err(ElasticNetError::IllConditioned),
508    }
509}
510
511/// Compute the intercept as the mean of `y` along each column and center `y` if an intercept
512/// should be used, use 0 as intercept and leave `y` unchanged otherwise.
513/// If `y` is 2D, mean is 1D and center is 2D. If `y` is 1D, mean is a number and center is 1D.
514fn compute_intercept<F: Float, I: RemoveAxis>(
515    with_intercept: bool,
516    y: ArrayView<F, I>,
517) -> (Array<F, I::Smaller>, CowArray<F, I>)
518where
519    I::Smaller: Dimension<Larger = I>,
520{
521    if with_intercept {
522        let y_mean = y
523            // Take the mean of each column (1D array counts as 1 column)
524            .mean_axis(Axis(0))
525            .expect("Axis 0 length of 0");
526        // Subtract y_mean from each "row" of y
527        let y_centered = &y - &y_mean.view().insert_axis(Axis(0));
528        (y_mean, y_centered.into())
529    } else {
530        (Array::zeros(y.raw_dim().remove_axis(Axis(0))), y.into())
531    }
532}
533
534#[cfg(test)]
535mod tests {
536    use super::{block_coordinate_descent, coordinate_descent, ElasticNet, MultiTaskElasticNet};
537    use approx::assert_abs_diff_eq;
538    use ndarray::{array, s, Array, Array1, Array2, Axis};
539    use ndarray_rand::rand::SeedableRng;
540    use ndarray_rand::rand_distr::Uniform;
541    use ndarray_rand::RandomExt;
542    use rand_xoshiro::Xoshiro256Plus;
543
544    use crate::{ElasticNetError, ElasticNetParams, ElasticNetValidParams};
545    use linfa::{
546        metrics::SingleTargetRegression,
547        traits::{Fit, Predict},
548        Dataset,
549    };
550
551    #[test]
552    fn autotraits() {
553        fn has_autotraits<T: Send + Sync + Sized + Unpin>() {}
554        has_autotraits::<ElasticNet<f64>>();
555        has_autotraits::<ElasticNetParams<f64>>();
556        has_autotraits::<ElasticNetValidParams<f64>>();
557        has_autotraits::<ElasticNetError>();
558    }
559
560    fn elastic_net_objective(
561        x: &Array2<f64>,
562        y: &Array1<f64>,
563        intercept: f64,
564        beta: &Array1<f64>,
565        alpha: f64,
566        lambda: f64,
567    ) -> f64 {
568        squared_error(x, y, intercept, beta) + lambda * elastic_net_penalty(beta, alpha)
569    }
570
571    fn elastic_net_multi_task_objective(
572        x: &Array2<f64>,
573        y: &Array2<f64>,
574        intercept: &Array1<f64>,
575        beta: &Array2<f64>,
576        alpha: f64,
577        lambda: f64,
578    ) -> f64 {
579        squared_error_mtl(x, y, intercept, beta) + lambda * elastic_net_mtl_penalty(beta, alpha)
580    }
581
582    fn squared_error(x: &Array2<f64>, y: &Array1<f64>, intercept: f64, beta: &Array1<f64>) -> f64 {
583        let mut resid = -x.dot(beta);
584        resid -= intercept;
585        resid += y;
586        let mut result = 0.0;
587        for r in &resid {
588            result += r * r;
589        }
590        result /= 2.0 * y.len() as f64;
591        result
592    }
593
594    fn squared_error_mtl(
595        x: &Array2<f64>,
596        y: &Array2<f64>,
597        intercept: &Array1<f64>,
598        beta: &Array2<f64>,
599    ) -> f64 {
600        let mut resid = x.dot(beta);
601        resid = &resid * -1.;
602        resid = &resid - intercept + y;
603        let mut datafit = resid.iter().map(|rij| rij * rij).sum();
604        datafit /= 2.0 * x.nrows() as f64;
605        datafit
606    }
607
608    fn elastic_net_penalty(beta: &Array1<f64>, alpha: f64) -> f64 {
609        let mut penalty = 0.0;
610        for beta_j in beta {
611            penalty += (1.0 - alpha) / 2.0 * beta_j * beta_j + alpha * beta_j.abs();
612        }
613        penalty
614    }
615
616    fn elastic_net_mtl_penalty(beta: &Array2<f64>, alpha: f64) -> f64 {
617        let frob_norm: f64 = beta.iter().map(|beta_ij| beta_ij * beta_ij).sum();
618        let l21_norm = beta
619            .map_axis(Axis(1), |beta_j| (beta_j.dot(&beta_j)).sqrt())
620            .sum();
621        (1.0 - alpha) / 2.0 * frob_norm + alpha * l21_norm
622    }
623
624    #[test]
625    fn elastic_net_penalty_works() {
626        let beta = array![-2.0, 1.0];
627        assert_abs_diff_eq!(
628            elastic_net_penalty(&beta, 0.8),
629            0.4 + 0.1 + 1.6 + 0.8,
630            epsilon = 1e-12
631        );
632        assert_abs_diff_eq!(elastic_net_penalty(&beta, 1.0), 3.0);
633        assert_abs_diff_eq!(elastic_net_penalty(&beta, 0.0), 2.5);
634
635        let beta2 = array![0.0, 0.0];
636        assert_abs_diff_eq!(elastic_net_penalty(&beta2, 0.8), 0.0);
637        assert_abs_diff_eq!(elastic_net_penalty(&beta2, 1.0), 0.0);
638        assert_abs_diff_eq!(elastic_net_penalty(&beta2, 0.0), 0.0);
639    }
640
641    #[test]
642    fn elastic_net_mtl_penalty_works() {
643        let beta = array![[-2.0, 1.0, 3.0], [3.0, 1.5, -1.7]];
644        assert_abs_diff_eq!(
645            elastic_net_mtl_penalty(&beta, 0.7),
646            9.472383565516601,
647            epsilon = 1e-12
648        );
649        assert_abs_diff_eq!(
650            elastic_net_mtl_penalty(&beta, 1.0),
651            7.501976522166574,
652            epsilon = 1e-12
653        );
654        assert_abs_diff_eq!(
655            elastic_net_mtl_penalty(&beta, 0.2),
656            12.756395304433315,
657            epsilon = 1e-12
658        );
659
660        let beta2 = array![[0., 0.], [0., 0.], [0., 0.]];
661        assert_abs_diff_eq!(elastic_net_mtl_penalty(&beta2, 0.8), 0.0);
662        assert_abs_diff_eq!(elastic_net_mtl_penalty(&beta2, 1.2), 0.0);
663        assert_abs_diff_eq!(elastic_net_mtl_penalty(&beta2, 0.8), 0.0);
664    }
665
666    #[test]
667    fn squared_error_works() {
668        let x = array![[2.0, 1.0], [-1.0, 2.0]];
669        let y = array![1.0, 1.0];
670        let beta = array![0.0, 1.0];
671        assert_abs_diff_eq!(squared_error(&x, &y, 0.0, &beta), 0.25);
672    }
673
674    #[test]
675    fn squared_error_mtl_works() {
676        let x = array![[1.2, 2.3], [-1.3, 0.3], [-1.3, 0.1]];
677        let y = array![
678            [0.2, 1.0, 0.0, 1.],
679            [-0.3, 0.7, 0.1, 2.],
680            [-0.3, 0.7, 2.3, 3.]
681        ];
682        let beta = array![[2.3, 4.5, 1.2, -3.4], [1.2, -3.4, 0.7, -1.2]];
683        assert_abs_diff_eq!(
684            squared_error_mtl(&x, &y, &array![0., 0., 0., 0.], &beta),
685            41.66298333333333
686        );
687        let intercept = array![1., 3., 2., 0.3];
688        assert_abs_diff_eq!(
689            squared_error_mtl(&x, &y, &intercept, &beta),
690            29.059983333333335
691        );
692    }
693
694    #[test]
695    fn coordinate_descent_lowers_objective() {
696        let x = array![[1.0, 0.0], [0.0, 1.0]];
697        let y = array![1.0, -1.0];
698        let beta = array![0.0, 0.0];
699        let intercept = 0.0;
700        let alpha = 0.8;
701        let lambda = 0.001;
702        let objective_start = elastic_net_objective(&x, &y, intercept, &beta, alpha, lambda);
703        let opt_result = coordinate_descent(x.view(), y.view(), 1e-4, 3, alpha, lambda);
704        let objective_end = elastic_net_objective(&x, &y, intercept, &opt_result.0, alpha, lambda);
705        assert!(objective_start > objective_end);
706    }
707
708    #[test]
709    fn block_coordinate_descent_lowers_objective() {
710        let x = array![[1.0, 0., -0.3, 3.2], [0.3, 1.2, -0.6, 1.2]];
711        let y = array![[0.3, -1.2, 0.7], [1.4, -3.2, 0.2]];
712        let beta = array![[0., 0., 0.], [0., 0., 0.], [0., 0., 0.], [0., 0., 0.]];
713        let intercept = array![0., 0., 0.];
714        let alpha = 0.4;
715        let lambda = 0.002;
716        let objective_start =
717            elastic_net_multi_task_objective(&x, &y, &intercept, &beta, alpha, lambda);
718        let opt_result = block_coordinate_descent(x.view(), y.view(), 1e-4, 3, alpha, lambda);
719        let objective_end =
720            elastic_net_multi_task_objective(&x, &y, &intercept, &opt_result.0, alpha, lambda);
721        assert!(objective_start > objective_end);
722    }
723
724    #[test]
725    fn lasso_zero_works() {
726        let dataset = Dataset::from((array![[0.], [0.], [0.]], array![0., 0., 0.]));
727
728        let model = ElasticNet::params()
729            .l1_ratio(1.0)
730            .penalty(0.1)
731            .fit(&dataset)
732            .unwrap();
733
734        assert_abs_diff_eq!(model.intercept(), 0.);
735        assert_abs_diff_eq!(model.hyperplane(), &array![0.]);
736    }
737
738    #[test]
739    fn mtl_lasso_zero_works() {
740        let dataset = Dataset::from((array![[0.], [0.], [0.]], array![[0.], [0.], [0.]]));
741
742        let model = MultiTaskElasticNet::params()
743            .l1_ratio(1.0)
744            .penalty(0.1)
745            .fit(&dataset)
746            .unwrap();
747
748        assert_abs_diff_eq!(model.intercept(), &array![0.]);
749        assert_abs_diff_eq!(model.hyperplane(), &array![[0.]]);
750    }
751
752    #[test]
753    fn lasso_toy_example_works() {
754        // Test Lasso on a toy example for various values of alpha.
755        // When validating this against glmnet notice that glmnet divides it
756        // against n_samples.
757        let dataset = Dataset::new(array![[-1.0], [0.0], [1.0]], array![-1.0, 0.0, 1.0]);
758
759        // input for prediction
760        let t = array![[2.0], [3.0], [4.0]];
761        let model = ElasticNet::lasso().penalty(1e-8).fit(&dataset).unwrap();
762        assert_abs_diff_eq!(model.intercept(), 0.0);
763        assert_abs_diff_eq!(model.hyperplane(), &array![1.0], epsilon = 1e-6);
764        assert_abs_diff_eq!(model.predict(&t), array![2.0, 3.0, 4.0], epsilon = 1e-6);
765        assert_abs_diff_eq!(model.duality_gap(), 0.0);
766
767        let model = ElasticNet::lasso().penalty(0.1).fit(&dataset).unwrap();
768        assert_abs_diff_eq!(model.intercept(), 0.0);
769        assert_abs_diff_eq!(model.hyperplane(), &array![0.85], epsilon = 1e-6);
770        assert_abs_diff_eq!(model.predict(&t), array![1.7, 2.55, 3.4], epsilon = 1e-6);
771        assert_abs_diff_eq!(model.duality_gap(), 0.0);
772
773        let model = ElasticNet::lasso().penalty(0.5).fit(&dataset).unwrap();
774        assert_abs_diff_eq!(model.intercept(), 0.0);
775        assert_abs_diff_eq!(model.hyperplane(), &array![0.25], epsilon = 1e-6);
776        assert_abs_diff_eq!(model.predict(&t), array![0.5, 0.75, 1.0], epsilon = 1e-6);
777        assert_abs_diff_eq!(model.duality_gap(), 0.0);
778
779        let model = ElasticNet::lasso().penalty(1.0).fit(&dataset).unwrap();
780        assert_abs_diff_eq!(model.intercept(), 0.0);
781        assert_abs_diff_eq!(model.hyperplane(), &array![0.0], epsilon = 1e-6);
782        assert_abs_diff_eq!(model.predict(&t), array![0.0, 0.0, 0.0], epsilon = 1e-6);
783        assert_abs_diff_eq!(model.duality_gap(), 0.0);
784    }
785
786    #[test]
787    fn multitask_lasso_toy_example_works() {
788        // Test MultiTaskLasso on a toy example for various values of alpha.
789        // When validating this against sklearn notice that sklearn divides it
790        // against n_samples.
791        let dataset = Dataset::new(
792            array![[-1.0], [0.0], [1.0]],
793            array![[-1.0, 1.0], [0.0, -1.5], [1.0, 1.3]],
794        );
795
796        // no intercept fitting
797        let t = array![[2.0], [3.0], [4.0]];
798        let model = MultiTaskElasticNet::lasso()
799            .with_intercept(false)
800            .penalty(0.01)
801            .fit(&dataset)
802            .unwrap();
803        assert_abs_diff_eq!(model.intercept(), &array![0., 0.]);
804        assert_abs_diff_eq!(
805            model.hyperplane(),
806            &array![[0.9851659, 0.1477748]],
807            epsilon = 1e-6
808        );
809        assert_abs_diff_eq!(
810            model.predict(&t),
811            array![
812                [1.9703319, 0.2955497],
813                [2.9554978, 0.4433246],
814                [3.9406638, 0.5910995]
815            ],
816            epsilon = 1e-6
817        );
818        assert_abs_diff_eq!(model.duality_gap(), 0.0, epsilon = 1e-9);
819
820        // input for prediction
821        let t = array![[2.0], [3.0], [4.0]];
822        let model = MultiTaskElasticNet::lasso()
823            .penalty(1e-8)
824            .fit(&dataset)
825            .unwrap();
826        assert_abs_diff_eq!(model.intercept(), &array![0., 0.2666666667], epsilon = 1e-6);
827        assert_abs_diff_eq!(model.hyperplane(), &array![[1., 0.15]], epsilon = 1e-6);
828        assert_abs_diff_eq!(
829            model.predict(&t),
830            array![
831                [1.99999997, 0.56666666],
832                [2.99999996, 0.71666666],
833                [3.99999994, 0.86666666]
834            ],
835            epsilon = 1e-6
836        );
837        assert_abs_diff_eq!(model.duality_gap(), 0.0, epsilon = 1e-9);
838
839        let model = MultiTaskElasticNet::lasso()
840            .penalty(0.1)
841            .fit(&dataset)
842            .unwrap();
843        assert_abs_diff_eq!(model.intercept(), &array![0., 0.2666666667], epsilon = 1e-6);
844        assert_abs_diff_eq!(
845            model.hyperplane(),
846            &array![[0.851659, 0.127749]],
847            epsilon = 1e-6
848        );
849        assert_abs_diff_eq!(
850            model.predict(&t),
851            &array![
852                [1.70331909, 0.52216453],
853                [2.55497864, 0.64991346],
854                [3.40663819, 0.77766239]
855            ],
856            epsilon = 1e-6
857        );
858        assert_abs_diff_eq!(model.duality_gap(), 0.0, epsilon = 1e-9);
859
860        let model = MultiTaskElasticNet::lasso()
861            .penalty(0.5)
862            .fit(&dataset)
863            .unwrap();
864        assert_abs_diff_eq!(model.intercept(), &array![0., 0.2666666667], epsilon = 1e-6);
865        assert_abs_diff_eq!(
866            model.hyperplane(),
867            &array![[0.258298, 0.038744]],
868            epsilon = 1e-6
869        );
870        assert_abs_diff_eq!(
871            model.predict(&t),
872            &array![
873                [0.51659547, 0.34415599],
874                [0.77489321, 0.38290065],
875                [1.03319094, 0.42164531]
876            ],
877            epsilon = 1e-6
878        );
879        assert_abs_diff_eq!(model.duality_gap(), 0.0, epsilon = 1e-6);
880
881        let model = MultiTaskElasticNet::lasso()
882            .penalty(1.0)
883            .fit(&dataset)
884            .unwrap();
885        assert_abs_diff_eq!(model.intercept(), &array![0., 0.2666666667], epsilon = 1e-6);
886        assert_abs_diff_eq!(model.hyperplane(), &array![[0.0, 0.0]], epsilon = 1e-6);
887        assert_abs_diff_eq!(
888            model.predict(&t),
889            &array![[0., 0.2666666667], [0., 0.2666666667], [0., 0.2666666667]],
890            epsilon = 1e-6
891        );
892        assert_abs_diff_eq!(model.duality_gap(), 0.0, epsilon = 1e-6);
893    }
894
895    #[test]
896    fn elastic_net_toy_example_works() {
897        let dataset = Dataset::new(array![[-1.0], [0.0], [1.0]], array![-1.0, 0.0, 1.0]);
898
899        // for predictions
900        let t = array![[2.0], [3.0], [4.0]];
901        let model = ElasticNet::params()
902            .l1_ratio(0.3)
903            .penalty(0.5)
904            .fit(&dataset)
905            .unwrap();
906
907        assert_abs_diff_eq!(model.intercept(), 0.0);
908        assert_abs_diff_eq!(model.hyperplane(), &array![0.50819], epsilon = 1e-3);
909        assert_abs_diff_eq!(
910            model.predict(&t),
911            array![1.0163, 1.5245, 2.0327],
912            epsilon = 1e-3
913        );
914        assert_abs_diff_eq!(model.duality_gap(), 0.0);
915
916        let model = ElasticNet::params()
917            .l1_ratio(0.5)
918            .penalty(0.5)
919            .fit(&dataset)
920            .unwrap();
921
922        assert_abs_diff_eq!(model.intercept(), 0.0);
923        assert_abs_diff_eq!(model.hyperplane(), &array![0.45454], epsilon = 1e-3);
924        assert_abs_diff_eq!(
925            model.predict(&t),
926            array![0.9090, 1.3636, 1.8181],
927            epsilon = 1e-3
928        );
929        assert_abs_diff_eq!(model.duality_gap(), 0.0);
930    }
931
932    #[test]
933    fn multitask_elasticnet_toy_example_works() {
934        // Test MultiTaskElasticNet on a toy example for various values of alpha
935        // and l1_ratio. When validating this against sklearn notice that sklearn
936        // divides it against n_samples.
937        let dataset = Dataset::new(
938            array![[-1.0], [0.0], [1.0]],
939            array![[-1.0, 1.0], [0.0, -1.5], [1.0, 1.3]],
940        );
941
942        // no intercept fitting
943        let t = array![[2.0], [3.0], [4.0]];
944        let model = MultiTaskElasticNet::params()
945            .with_intercept(false)
946            .l1_ratio(0.3)
947            .penalty(0.1)
948            .fit(&dataset)
949            .unwrap();
950        assert_abs_diff_eq!(model.intercept(), &array![0., 0.]);
951        assert_abs_diff_eq!(
952            model.hyperplane(),
953            &array![[0.86470395, 0.12970559]],
954            epsilon = 1e-6
955        );
956        assert_abs_diff_eq!(
957            model.predict(&t),
958            array![
959                [1.7294079, 0.25941118],
960                [2.59411185, 0.38911678],
961                [3.4588158, 0.51882237]
962            ],
963            epsilon = 1e-6
964        );
965        assert_abs_diff_eq!(model.duality_gap(), 0.0, epsilon = 1e-12);
966
967        // input for prediction
968        let t = array![[2.0], [3.0], [4.0]];
969        let model = MultiTaskElasticNet::params()
970            .l1_ratio(0.3)
971            .penalty(0.1)
972            .fit(&dataset)
973            .unwrap();
974        assert_abs_diff_eq!(model.intercept(), &array![0., 0.26666666], epsilon = 1e-6);
975        assert_abs_diff_eq!(
976            model.hyperplane(),
977            &array![[0.86470395, 0.12970559]],
978            epsilon = 1e-6
979        );
980        assert_abs_diff_eq!(
981            model.predict(&t),
982            array![
983                [1.7294079, 0.52607785],
984                [2.59411185, 0.65578344],
985                [3.4588158, 0.78548904]
986            ],
987            epsilon = 1e-6
988        );
989        assert_abs_diff_eq!(model.duality_gap(), 0.0, epsilon = 1e-12);
990
991        let model = MultiTaskElasticNet::params()
992            .l1_ratio(0.5)
993            .penalty(0.1)
994            .fit(&dataset)
995            .unwrap();
996        assert_abs_diff_eq!(model.intercept(), &array![0., 0.2666666], epsilon = 1e-6);
997        assert_abs_diff_eq!(
998            model.hyperplane(),
999            &array![[0.861237, 0.12918555]],
1000            epsilon = 1e-6
1001        );
1002        assert_abs_diff_eq!(
1003            model.predict(&t),
1004            &array![
1005                [1.722474, 0.52503777],
1006                [2.583711, 0.65422332],
1007                [3.44494799, 0.78340887]
1008            ],
1009            epsilon = 1e-6
1010        );
1011        assert_abs_diff_eq!(model.duality_gap(), 0.0, epsilon = 1e-12);
1012    }
1013
1014    #[test]
1015    fn elastic_net_2d_toy_example_works() {
1016        let dataset = Dataset::new(array![[1.0, 0.0], [0.0, 1.0]], array![3.0, 2.0]);
1017
1018        let model = ElasticNet::params().penalty(0.0).fit(&dataset).unwrap();
1019        assert_abs_diff_eq!(model.intercept(), 2.5);
1020        assert_abs_diff_eq!(model.hyperplane(), &array![0.5, -0.5], epsilon = 0.001);
1021    }
1022
1023    #[test]
1024    #[allow(clippy::excessive_precision)]
1025    fn elastic_net_diabetes_1_works_like_sklearn() {
1026        // test that elastic net implementation gives very similar results to
1027        // sklearn implementation for the first 20 lines taken from the diabetes
1028        // dataset in linfa/datasets/diabetes_(data|target).csv.gz
1029        #[rustfmt::skip]
1030        let x = array![
1031            [3.807590643342410180e-02, 5.068011873981870252e-02, 6.169620651868849837e-02, 2.187235499495579841e-02, -4.422349842444640161e-02, -3.482076283769860309e-02, -4.340084565202689815e-02, -2.592261998182820038e-03, 1.990842087631829876e-02, -1.764612515980519894e-02],
1032            [-1.882016527791040067e-03, -4.464163650698899782e-02, -5.147406123880610140e-02, -2.632783471735180084e-02, -8.448724111216979540e-03, -1.916333974822199970e-02, 7.441156407875940126e-02, -3.949338287409189657e-02, -6.832974362442149896e-02, -9.220404962683000083e-02],
1033            [8.529890629667830071e-02, 5.068011873981870252e-02, 4.445121333659410312e-02, -5.670610554934250001e-03, -4.559945128264750180e-02, -3.419446591411950259e-02, -3.235593223976569732e-02, -2.592261998182820038e-03, 2.863770518940129874e-03, -2.593033898947460017e-02],
1034            [-8.906293935226029801e-02, -4.464163650698899782e-02, -1.159501450521270051e-02, -3.665644679856060184e-02, 1.219056876180000040e-02, 2.499059336410210108e-02, -3.603757004385269719e-02, 3.430885887772629900e-02, 2.269202256674450122e-02, -9.361911330135799444e-03],
1035            [5.383060374248070309e-03, -4.464163650698899782e-02, -3.638469220447349689e-02, 2.187235499495579841e-02, 3.934851612593179802e-03, 1.559613951041610019e-02, 8.142083605192099172e-03, -2.592261998182820038e-03, -3.199144494135589684e-02, -4.664087356364819692e-02],
1036            [-9.269547780327989928e-02, -4.464163650698899782e-02, -4.069594049999709917e-02, -1.944209332987930153e-02, -6.899064987206669775e-02, -7.928784441181220555e-02, 4.127682384197570165e-02, -7.639450375000099436e-02, -4.118038518800790082e-02, -9.634615654166470144e-02],
1037            [-4.547247794002570037e-02, 5.068011873981870252e-02, -4.716281294328249912e-02, -1.599922263614299983e-02, -4.009563984984299695e-02, -2.480001206043359885e-02, 7.788079970179680352e-04, -3.949338287409189657e-02, -6.291294991625119570e-02, -3.835665973397880263e-02],
1038            [6.350367559056099842e-02, 5.068011873981870252e-02, -1.894705840284650021e-03, 6.662967401352719310e-02, 9.061988167926439408e-02, 1.089143811236970016e-01, 2.286863482154040048e-02, 1.770335448356720118e-02, -3.581672810154919867e-02, 3.064409414368320182e-03],
1039            [4.170844488444359899e-02, 5.068011873981870252e-02, 6.169620651868849837e-02, -4.009931749229690007e-02, -1.395253554402150001e-02, 6.201685656730160021e-03, -2.867429443567860031e-02, -2.592261998182820038e-03, -1.495647502491130078e-02, 1.134862324403770016e-02],
1040            [-7.090024709716259699e-02, -4.464163650698899782e-02, 3.906215296718960200e-02, -3.321357610482440076e-02, -1.257658268582039982e-02, -3.450761437590899733e-02, -2.499265663159149983e-02, -2.592261998182820038e-03, 6.773632611028609918e-02, -1.350401824497050006e-02],
1041            [-9.632801625429950054e-02, -4.464163650698899782e-02, -8.380842345523309422e-02, 8.100872220010799790e-03, -1.033894713270950005e-01, -9.056118903623530669e-02, -1.394774321933030074e-02, -7.639450375000099436e-02, -6.291294991625119570e-02, -3.421455281914410201e-02],
1042            [2.717829108036539862e-02, 5.068011873981870252e-02, 1.750591148957160101e-02, -3.321357610482440076e-02, -7.072771253015849857e-03, 4.597154030400080194e-02, -6.549067247654929980e-02, 7.120997975363539678e-02, -9.643322289178400675e-02, -5.906719430815229877e-02],
1043            [1.628067572730669890e-02, -4.464163650698899782e-02, -2.884000768730720157e-02, -9.113481248670509197e-03, -4.320865536613589623e-03, -9.768885894535990141e-03, 4.495846164606279866e-02, -3.949338287409189657e-02, -3.075120986455629965e-02, -4.249876664881350324e-02],
1044            [5.383060374248070309e-03, 5.068011873981870252e-02, -1.894705840284650021e-03, 8.100872220010799790e-03, -4.320865536613589623e-03, -1.571870666853709964e-02, -2.902829807069099918e-03, -2.592261998182820038e-03, 3.839324821169769891e-02, -1.350401824497050006e-02],
1045            [4.534098333546320025e-02, -4.464163650698899782e-02, -2.560657146566450160e-02, -1.255635194240680048e-02, 1.769438019460449832e-02, -6.128357906048329537e-05, 8.177483968693349814e-02, -3.949338287409189657e-02, -3.199144494135589684e-02, -7.563562196749110123e-02],
1046            [-5.273755484206479882e-02, 5.068011873981870252e-02, -1.806188694849819934e-02, 8.040115678847230274e-02, 8.924392882106320368e-02, 1.076617872765389949e-01, -3.971920784793980114e-02, 1.081111006295440019e-01, 3.605579008983190309e-02, -4.249876664881350324e-02],
1047            [-5.514554978810590376e-03, -4.464163650698899782e-02, 4.229558918883229851e-02, 4.941532054484590319e-02, 2.457414448561009990e-02, -2.386056667506489953e-02, 7.441156407875940126e-02, -3.949338287409189657e-02, 5.227999979678119719e-02, 2.791705090337660150e-02],
1048            [7.076875249260000666e-02, 5.068011873981870252e-02, 1.211685112016709989e-02, 5.630106193231849965e-02, 3.420581449301800248e-02, 4.941617338368559792e-02, -3.971920784793980114e-02, 3.430885887772629900e-02, 2.736770754260900093e-02, -1.077697500466389974e-03],
1049            [-3.820740103798660192e-02, -4.464163650698899782e-02, -1.051720243133190055e-02, -3.665644679856060184e-02, -3.734373413344069942e-02, -1.947648821001150138e-02, -2.867429443567860031e-02, -2.592261998182820038e-03, -1.811826730789670159e-02, -1.764612515980519894e-02],
1050            [-2.730978568492789874e-02, -4.464163650698899782e-02, -1.806188694849819934e-02, -4.009931749229690007e-02, -2.944912678412469915e-03, -1.133462820348369975e-02, 3.759518603788870178e-02, -3.949338287409189657e-02, -8.944018957797799166e-03, -5.492508739331759815e-02]
1051        ];
1052        #[rustfmt::skip]
1053        let y = array![1.51e+02, 7.5e+01, 1.41e+02, 2.06e+02, 1.35e+02, 9.7e+01, 1.38e+02, 6.3e+01, 1.1e+02, 3.1e+02, 1.01e+02, 6.9e+01, 1.79e+02, 1.85e+02, 1.18e+02, 1.71e+02, 1.66e+02, 1.44e+02, 9.7e+01, 1.68e+02];
1054        let model = ElasticNet::params()
1055            .l1_ratio(0.2)
1056            .penalty(0.5)
1057            .fit(&Dataset::new(x, y))
1058            .unwrap();
1059
1060        assert_abs_diff_eq!(
1061            model.hyperplane(),
1062            &array![
1063                -2.00558969,
1064                -0.92208413,
1065                1.27586213,
1066                -0.06617076,
1067                0.26484338,
1068                -0.48702845,
1069                -0.60274235,
1070                0.3975141,
1071                4.33229135,
1072                1.11981207
1073            ],
1074            epsilon = 0.01
1075        );
1076        assert_abs_diff_eq!(model.intercept(), 141.283952, epsilon = 1e-1);
1077        assert!(
1078            f64::abs(model.duality_gap()) < 1e-4,
1079            "Duality gap too large"
1080        );
1081    }
1082
1083    #[test]
1084    #[allow(clippy::excessive_precision)]
1085    fn elastic_net_diabetes_2_works_like_sklearn() {
1086        // test that elastic net implementation gives very similar results to
1087        // sklearn implementation for the last 20 lines taken from the diabetes
1088        // dataset in linfa/datasets/diabetes_(data|target).csv.gz
1089        #[rustfmt::skip]
1090        let x = array![
1091            [-7.816532399920170238e-02,5.068011873981870252e-02,7.786338762690199478e-02,5.285819123858220142e-02,7.823630595545419397e-02,6.444729954958319795e-02,2.655027262562750096e-02,-2.592261998182820038e-03,4.067226371449769728e-02,-9.361911330135799444e-03],
1092            [9.015598825267629943e-03,5.068011873981870252e-02,-3.961812842611620034e-02,2.875809638242839833e-02,3.833367306762140020e-02,7.352860494147960002e-02,-7.285394808472339667e-02,1.081111006295440019e-01,1.556684454070180086e-02,-4.664087356364819692e-02],
1093            [1.750521923228520000e-03,5.068011873981870252e-02,1.103903904628619932e-02,-1.944209332987930153e-02,-1.670444126042380101e-02,-3.819065120534880214e-03,-4.708248345611389801e-02,3.430885887772629900e-02,2.405258322689299982e-02,2.377494398854190089e-02],
1094            [-7.816532399920170238e-02,-4.464163650698899782e-02,-4.069594049999709917e-02,-8.141376581713200000e-02,-1.006375656106929944e-01,-1.127947298232920004e-01,2.286863482154040048e-02,-7.639450375000099436e-02,-2.028874775162960165e-02,-5.078298047848289754e-02],
1095            [3.081082953138499989e-02,5.068011873981870252e-02,-3.422906805671169922e-02,4.367720260718979675e-02,5.759701308243719842e-02,6.883137801463659611e-02,-3.235593223976569732e-02,5.755656502954899917e-02,3.546193866076970125e-02,8.590654771106250032e-02],
1096            [-3.457486258696700065e-02,5.068011873981870252e-02,5.649978676881649634e-03,-5.670610554934250001e-03,-7.311850844667000526e-02,-6.269097593696699999e-02,-6.584467611156170040e-03,-3.949338287409189657e-02,-4.542095777704099890e-02,3.205915781821130212e-02],
1097            [4.897352178648269744e-02,5.068011873981870252e-02,8.864150836571099701e-02,8.728689817594480205e-02,3.558176735121919981e-02,2.154596028441720101e-02,-2.499265663159149983e-02,3.430885887772629900e-02,6.604820616309839409e-02,1.314697237742440128e-01],
1098            [-4.183993948900609910e-02,-4.464163650698899782e-02,-3.315125598283080038e-02,-2.288496402361559975e-02,4.658939021682820258e-02,4.158746183894729970e-02,5.600337505832399948e-02,-2.473293452372829840e-02,-2.595242443518940012e-02,-3.835665973397880263e-02],
1099            [-9.147093429830140468e-03,-4.464163650698899782e-02,-5.686312160821060252e-02,-5.042792957350569760e-02,2.182223876920789951e-02,4.534524338042170144e-02,-2.867429443567860031e-02,3.430885887772629900e-02,-9.918957363154769225e-03,-1.764612515980519894e-02],
1100            [7.076875249260000666e-02,5.068011873981870252e-02,-3.099563183506899924e-02,2.187235499495579841e-02,-3.734373413344069942e-02,-4.703355284749029946e-02,3.391354823380159783e-02,-3.949338287409189657e-02,-1.495647502491130078e-02,-1.077697500466389974e-03],
1101            [9.015598825267629943e-03,-4.464163650698899782e-02,5.522933407540309841e-02,-5.670610554934250001e-03,5.759701308243719842e-02,4.471894645684260094e-02,-2.902829807069099918e-03,2.323852261495349888e-02,5.568354770267369691e-02,1.066170822852360034e-01],
1102            [-2.730978568492789874e-02,-4.464163650698899782e-02,-6.009655782985329903e-02,-2.977070541108809906e-02,4.658939021682820258e-02,1.998021797546959896e-02,1.222728555318910032e-01,-3.949338287409189657e-02,-5.140053526058249722e-02,-9.361911330135799444e-03],
1103            [1.628067572730669890e-02,-4.464163650698899782e-02,1.338730381358059929e-03,8.100872220010799790e-03,5.310804470794310353e-03,1.089891258357309975e-02,3.023191042971450082e-02,-3.949338287409189657e-02,-4.542095777704099890e-02,3.205915781821130212e-02],
1104            [-1.277963188084970010e-02,-4.464163650698899782e-02,-2.345094731790270046e-02,-4.009931749229690007e-02,-1.670444126042380101e-02,4.635943347782499856e-03,-1.762938102341739949e-02,-2.592261998182820038e-03,-3.845911230135379971e-02,-3.835665973397880263e-02],
1105            [-5.637009329308430294e-02,-4.464163650698899782e-02,-7.410811479030500470e-02,-5.042792957350569760e-02,-2.496015840963049931e-02,-4.703355284749029946e-02,9.281975309919469896e-02,-7.639450375000099436e-02,-6.117659509433449883e-02,-4.664087356364819692e-02],
1106            [4.170844488444359899e-02,5.068011873981870252e-02,1.966153563733339868e-02,5.974393262605470073e-02,-5.696818394814720174e-03,-2.566471273376759888e-03,-2.867429443567860031e-02,-2.592261998182820038e-03,3.119299070280229930e-02,7.206516329203029904e-03],
1107            [-5.514554978810590376e-03,5.068011873981870252e-02,-1.590626280073640167e-02,-6.764228304218700139e-02,4.934129593323050011e-02,7.916527725369119917e-02,-2.867429443567860031e-02,3.430885887772629900e-02,-1.811826730789670159e-02,4.448547856271539702e-02],
1108            [4.170844488444359899e-02,5.068011873981870252e-02,-1.590626280073640167e-02,1.728186074811709910e-02,-3.734373413344069942e-02,-1.383981589779990050e-02,-2.499265663159149983e-02,-1.107951979964190078e-02,-4.687948284421659950e-02,1.549073015887240078e-02],
1109            [-4.547247794002570037e-02,-4.464163650698899782e-02,3.906215296718960200e-02,1.215130832538269907e-03,1.631842733640340160e-02,1.528299104862660025e-02,-2.867429443567860031e-02,2.655962349378539894e-02,4.452837402140529671e-02,-2.593033898947460017e-02],
1110            [-4.547247794002570037e-02,-4.464163650698899782e-02,-7.303030271642410587e-02,-8.141376581713200000e-02,8.374011738825870577e-02,2.780892952020790065e-02,1.738157847891100005e-01,-3.949338287409189657e-02,-4.219859706946029777e-03,3.064409414368320182e-03]
1111        ];
1112        #[rustfmt::skip]
1113        let y = array![2.33e+02, 9.1e+01, 1.11e+02, 1.52e+02, 1.2e+02, 6.70e+01, 3.1e+02, 9.4e+01, 1.83e+02, 6.6e+01, 1.73e+02, 7.2e+01, 4.9e+01, 6.4e+01, 4.8e+01, 1.78e+02, 1.04e+02, 1.32e+02, 2.20e+02, 5.7e+01];
1114        let model = ElasticNet::params()
1115            .l1_ratio(0.2)
1116            .penalty(0.5)
1117            .fit(&Dataset::new(x, y))
1118            .unwrap();
1119
1120        assert_abs_diff_eq!(
1121            model.hyperplane(),
1122            &array![
1123                0.19879313,
1124                1.46970138,
1125                5.58097318,
1126                3.80089794,
1127                1.46466565,
1128                1.42327857,
1129                -3.86944632,
1130                2.60836423,
1131                4.79584768,
1132                3.03232988
1133            ],
1134            epsilon = 0.01
1135        );
1136        assert_abs_diff_eq!(model.intercept(), 126.279, epsilon = 1e-1);
1137        assert_abs_diff_eq!(model.duality_gap(), 0.00011079, epsilon = 1e-4);
1138    }
1139
1140    #[test]
1141    fn select_subset() {
1142        let mut rng = Xoshiro256Plus::seed_from_u64(42);
1143
1144        // check that we are selecting the subsect of informative features
1145        let mut w = Array::random_using(50, Uniform::new(1., 2.), &mut rng);
1146        w.slice_mut(s![10..]).fill(0.0);
1147
1148        let x = Array::random_using((100, 50), Uniform::new(-1., 1.), &mut rng);
1149        let y = x.dot(&w);
1150        let train = Dataset::new(x, y);
1151
1152        let model = ElasticNet::lasso()
1153            .penalty(0.1)
1154            .max_iterations(1000)
1155            .tolerance(1e-10)
1156            .fit(&train)
1157            .unwrap();
1158
1159        // check that we set the last 40 parameters to zero
1160        let num_zeros = model
1161            .hyperplane()
1162            .into_iter()
1163            .filter(|x| **x < 1e-5)
1164            .count();
1165
1166        assert_eq!(num_zeros, 40);
1167
1168        // predict a small testing dataset
1169        let x = Array::random_using((100, 50), Uniform::new(-1., 1.), &mut rng);
1170        let y = x.dot(&w);
1171
1172        let predicted = model.predict(&x);
1173        let rms = y.mean_squared_error(&predicted);
1174        assert!(rms.unwrap() < 0.67);
1175    }
1176
1177    #[test]
1178    fn diabetes_z_score() {
1179        let dataset = linfa_datasets::diabetes();
1180        let model = ElasticNet::params().penalty(0.0).fit(&dataset).unwrap();
1181
1182        // BMI and BP (blood pressure) should be relevant
1183        let z_score = model.z_score().unwrap();
1184        assert!(z_score[2] > 2.0);
1185        assert!(z_score[3] > 2.0);
1186
1187        // confidence level
1188        let confidence_level = model.confidence_95th().unwrap();
1189        assert!(confidence_level[2].0 < 416.);
1190        assert!(confidence_level[3].0 < 220.);
1191    }
1192}