// smartcore/linear/ridge_regression.rs

//! # Ridge Regression
//!
//! [Linear regression](../linear_regression/index.html) is the standard algorithm for predicting a quantitative response \\(y\\) on the basis of a linear combination of explanatory variables \\(X\\)
//! that assumes that there is approximately a linear relationship between \\(X\\) and \\(y\\).
//! Ridge regression is an extension of linear regression that adds an L2 regularization term to the loss function during training.
//! This term encourages simpler models that have smaller coefficient values.
//!
//! In ridge regression the coefficients \\(\beta_0, \beta_1, \dots, \beta_n\\) are estimated by solving
//!
//! \\[\hat{\beta} = (X^TX + \alpha I)^{-1}X^Ty \\]
//!
//! where \\(\alpha \geq 0\\) is a tuning parameter that controls the strength of regularization. When \\(\alpha = 0\\) the penalty term has no effect, and ridge regression produces the least squares estimates.
//! However, as \\(\alpha \rightarrow \infty\\), the impact of the shrinkage penalty grows, and the ridge regression coefficient estimates approach zero.
//!
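//! The estimate above is the closed-form minimizer of the penalized residual sum of squares
//!
//! \\[\hat{\beta} = \arg\min_{\beta} \left( (y - X\beta)^T(y - X\beta) + \alpha \beta^T\beta \right) \\]
//!
//! which makes explicit that the L2 penalty \\(\alpha \beta^T\beta\\) is what shrinks the coefficients.
//!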
//! `smartcore` uses [SVD](../../linalg/svd/index.html) and [Cholesky](../../linalg/cholesky/index.html) matrix decompositions to find the estimate of \\(\hat{\beta}\\).
//! The Cholesky decomposition is more computationally efficient and more numerically stable than calculating the normal equation directly,
//! but it does not work for all data matrices. Unlike the Cholesky decomposition, all matrices have an SVD decomposition.
//!
//! Example:
//!
//! ```
//! use smartcore::linalg::basic::matrix::DenseMatrix;
//! use smartcore::linear::ridge_regression::*;
//!
//! // Longley dataset (https://www.statsmodels.org/stable/datasets/generated/longley.html)
//! let x = DenseMatrix::from_2d_array(&[
//!               &[234.289, 235.6, 159.0, 107.608, 1947., 60.323],
//!               &[259.426, 232.5, 145.6, 108.632, 1948., 61.122],
//!               &[258.054, 368.2, 161.6, 109.773, 1949., 60.171],
//!               &[284.599, 335.1, 165.0, 110.929, 1950., 61.187],
//!               &[328.975, 209.9, 309.9, 112.075, 1951., 63.221],
//!               &[346.999, 193.2, 359.4, 113.270, 1952., 63.639],
//!               &[365.385, 187.0, 354.7, 115.094, 1953., 64.989],
//!               &[363.112, 357.8, 335.0, 116.219, 1954., 63.761],
//!               &[397.469, 290.4, 304.8, 117.388, 1955., 66.019],
//!               &[419.180, 282.2, 285.7, 118.734, 1956., 67.857],
//!               &[442.769, 293.6, 279.8, 120.445, 1957., 68.169],
//!               &[444.546, 468.1, 263.7, 121.950, 1958., 66.513],
//!               &[482.704, 381.3, 255.2, 123.366, 1959., 68.655],
//!               &[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
//!               &[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
//!               &[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
//!          ]).unwrap();
//!
//! let y: Vec<f64> = vec![83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0,
//!           100.0, 101.2, 104.6, 108.4, 110.8, 112.6, 114.2, 115.7, 116.9];
//!
//! let y_hat = RidgeRegression::fit(&x, &y, RidgeRegressionParameters::default().with_alpha(0.1))
//!     .and_then(|lr| lr.predict(&x)).unwrap();
//! ```
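//!
//! The solver can also be chosen explicitly through the builder methods. A minimal sketch (Cholesky is the default):
//!
//! ```
//! use smartcore::linear::ridge_regression::{RidgeRegressionParameters, RidgeRegressionSolverName};
//!
//! let params = RidgeRegressionParameters::default()
//!     .with_solver(RidgeRegressionSolverName::SVD)
//!     .with_alpha(0.5);
//! ```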
//!
//! ## References:
//!
//! * ["An Introduction to Statistical Learning", James G., Witten D., Hastie T., Tibshirani R., 6.2. Shrinkage Methods](http://faculty.marshall.usc.edu/gareth-james/ISL/)
//! * ["Numerical Recipes: The Art of Scientific Computing", Press W.H., Teukolsky S.A., Vetterling W.T., Flannery B.P., 3rd ed., Section 15.4 General Linear Least Squares](http://numerical.recipes/)
//!
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
use std::fmt::Debug;
use std::marker::PhantomData;

#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};

use crate::api::{Predictor, SupervisedEstimator};
use crate::error::Failed;
use crate::linalg::basic::arrays::{Array1, Array2};
use crate::linalg::traits::cholesky::CholeskyDecomposable;
use crate::linalg::traits::svd::SVDDecomposable;
use crate::numbers::basenum::Number;
use crate::numbers::realnum::RealNumber;

#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone, Eq, PartialEq, Default)]
/// Approach to use for estimation of regression coefficients. Cholesky is more efficient but SVD is more stable.
pub enum RidgeRegressionSolverName {
    /// Cholesky decomposition, see [Cholesky](../../linalg/cholesky/index.html)
    #[default]
    Cholesky,
    /// SVD decomposition, see [SVD](../../linalg/svd/index.html)
    SVD,
}

/// Ridge Regression parameters
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct RidgeRegressionParameters<T: Number + RealNumber> {
    /// Solver to use for estimation of regression coefficients.
    pub solver: RidgeRegressionSolverName,
    /// Controls the strength of the penalty to the loss function.
    pub alpha: T,
    /// If true the regressors X will be normalized before regression
    /// by subtracting the mean and dividing by the standard deviation.
    pub normalize: bool,
}

/// Ridge Regression grid search parameters
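///
/// A minimal usage sketch: the iterator produced by `into_iter` visits every combination of the listed values.
///
/// ```
/// use smartcore::linear::ridge_regression::RidgeRegressionSearchParameters;
///
/// let grid = RidgeRegressionSearchParameters {
///     alpha: vec![0.1, 1.0, 10.0],
///     ..Default::default()
/// };
/// assert_eq!(grid.into_iter().count(), 3);
/// ```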
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct RidgeRegressionSearchParameters<T: Number + RealNumber> {
    #[cfg_attr(feature = "serde", serde(default))]
    /// Solver to use for estimation of regression coefficients.
    pub solver: Vec<RidgeRegressionSolverName>,
    #[cfg_attr(feature = "serde", serde(default))]
    /// Regularization parameter.
    pub alpha: Vec<T>,
    #[cfg_attr(feature = "serde", serde(default))]
    /// If true the regressors X will be normalized before regression
    /// by subtracting the mean and dividing by the standard deviation.
    pub normalize: Vec<bool>,
}

/// Ridge Regression grid search iterator
pub struct RidgeRegressionSearchParametersIterator<T: Number + RealNumber> {
    ridge_regression_search_parameters: RidgeRegressionSearchParameters<T>,
    current_solver: usize,
    current_alpha: usize,
    current_normalize: usize,
}

impl<T: Number + RealNumber> IntoIterator for RidgeRegressionSearchParameters<T> {
    type Item = RidgeRegressionParameters<T>;
    type IntoIter = RidgeRegressionSearchParametersIterator<T>;

    fn into_iter(self) -> Self::IntoIter {
        RidgeRegressionSearchParametersIterator {
            ridge_regression_search_parameters: self,
            current_solver: 0,
            current_alpha: 0,
            current_normalize: 0,
        }
    }
}

impl<T: Number + RealNumber> Iterator for RidgeRegressionSearchParametersIterator<T> {
    type Item = RidgeRegressionParameters<T>;

    fn next(&mut self) -> Option<Self::Item> {
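        // Odometer-style traversal of the parameter grid: `alpha` advances fastest,
        // then `solver`, then `normalize`; once every combination has been produced
        // the counters run past the end and iteration stops.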
        if self.current_alpha == self.ridge_regression_search_parameters.alpha.len()
            && self.current_solver == self.ridge_regression_search_parameters.solver.len()
        {
            return None;
        }

        let next = RidgeRegressionParameters {
            solver: self.ridge_regression_search_parameters.solver[self.current_solver].clone(),
            alpha: self.ridge_regression_search_parameters.alpha[self.current_alpha],
            normalize: self.ridge_regression_search_parameters.normalize[self.current_normalize],
        };

        if self.current_alpha + 1 < self.ridge_regression_search_parameters.alpha.len() {
            self.current_alpha += 1;
        } else if self.current_solver + 1 < self.ridge_regression_search_parameters.solver.len() {
            self.current_alpha = 0;
            self.current_solver += 1;
        } else if self.current_normalize + 1
            < self.ridge_regression_search_parameters.normalize.len()
        {
            self.current_alpha = 0;
            self.current_solver = 0;
            self.current_normalize += 1;
        } else {
            self.current_alpha += 1;
            self.current_solver += 1;
            self.current_normalize += 1;
        }

        Some(next)
    }
}

impl<T: Number + RealNumber> Default for RidgeRegressionSearchParameters<T> {
    fn default() -> Self {
        let default_params = RidgeRegressionParameters::default();

        RidgeRegressionSearchParameters {
            solver: vec![default_params.solver],
            alpha: vec![default_params.alpha],
            normalize: vec![default_params.normalize],
        }
    }
}

/// Ridge regression
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
pub struct RidgeRegression<
    TX: Number + RealNumber,
    TY: Number,
    X: Array2<TX> + CholeskyDecomposable<TX> + SVDDecomposable<TX>,
    Y: Array1<TY>,
> {
    coefficients: Option<X>,
    intercept: Option<TX>,
    _phantom_ty: PhantomData<TY>,
    _phantom_y: PhantomData<Y>,
}

impl<T: Number + RealNumber> RidgeRegressionParameters<T> {
    /// Regularization parameter.
    pub fn with_alpha(mut self, alpha: T) -> Self {
        self.alpha = alpha;
        self
    }
    /// Solver to use for estimation of regression coefficients.
    pub fn with_solver(mut self, solver: RidgeRegressionSolverName) -> Self {
        self.solver = solver;
        self
    }
    /// If true, the regressors X will be normalized before regression by subtracting the mean and dividing by the standard deviation.
    pub fn with_normalize(mut self, normalize: bool) -> Self {
        self.normalize = normalize;
        self
    }
}

impl<T: Number + RealNumber> Default for RidgeRegressionParameters<T> {
    fn default() -> Self {
        RidgeRegressionParameters {
            solver: RidgeRegressionSolverName::default(),
            alpha: T::from_f64(1.0).unwrap(),
            normalize: true,
        }
    }
}

impl<
        TX: Number + RealNumber,
        TY: Number,
        X: Array2<TX> + CholeskyDecomposable<TX> + SVDDecomposable<TX>,
        Y: Array1<TY>,
    > PartialEq for RidgeRegression<TX, TY, X, Y>
{
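    // Two fitted models are considered equal when the intercepts match and the
    // coefficient matrices have the same shape and agree element-wise within machine epsilon.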
    fn eq(&self, other: &Self) -> bool {
        self.intercept() == other.intercept()
            && self.coefficients().shape() == other.coefficients().shape()
            && self
                .coefficients()
                .iterator(0)
                .zip(other.coefficients().iterator(0))
                .all(|(&a, &b)| (a - b).abs() <= TX::epsilon())
    }
}

impl<
        TX: Number + RealNumber,
        TY: Number,
        X: Array2<TX> + CholeskyDecomposable<TX> + SVDDecomposable<TX>,
        Y: Array1<TY>,
    > SupervisedEstimator<X, Y, RidgeRegressionParameters<TX>> for RidgeRegression<TX, TY, X, Y>
{
    fn new() -> Self {
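        // An unfitted placeholder model; `coefficients` and `intercept` are populated by `fit`.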
        Self {
            coefficients: Option::None,
            intercept: Option::None,
            _phantom_ty: PhantomData,
            _phantom_y: PhantomData,
        }
    }

    fn fit(x: &X, y: &Y, parameters: RidgeRegressionParameters<TX>) -> Result<Self, Failed> {
        RidgeRegression::fit(x, y, parameters)
    }
}

impl<
        TX: Number + RealNumber,
        TY: Number,
        X: Array2<TX> + CholeskyDecomposable<TX> + SVDDecomposable<TX>,
        Y: Array1<TY>,
    > Predictor<X, Y> for RidgeRegression<TX, TY, X, Y>
{
    fn predict(&self, x: &X) -> Result<Y, Failed> {
        self.predict(x)
    }
}

impl<
        TX: Number + RealNumber,
        TY: Number,
        X: Array2<TX> + CholeskyDecomposable<TX> + SVDDecomposable<TX>,
        Y: Array1<TY>,
    > RidgeRegression<TX, TY, X, Y>
{
    /// Fits ridge regression to your data.
    /// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
    /// * `y` - target values
    /// * `parameters` - other parameters, use `Default::default()` to set parameters to default values.
    pub fn fit(
        x: &X,
        y: &Y,
        parameters: RidgeRegressionParameters<TX>,
    ) -> Result<RidgeRegression<TX, TY, X, Y>, Failed> {
        // w = inv(X^T * X + alpha * I) * X^T * y

        let (n, p) = x.shape();

        if n <= p {
            return Err(Failed::fit(
                "Number of rows in X should be greater than number of columns in X",
            ));
        }

        if y.shape() != n {
            return Err(Failed::fit("Number of rows in X should be equal to the length of y"));
        }

        let y_column = X::from_iterator(
            y.iterator(0).map(|&v| TX::from(v).unwrap()),
            y.shape(),
            1,
            0,
        );

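        // When normalization is requested the system is solved on standardized features,
        // and the coefficients and intercept are mapped back to the original scale below.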
        let (w, b) = if parameters.normalize {
            let (scaled_x, col_mean, col_std) = Self::rescale_x(x)?;
            let x_t = scaled_x.transpose();
            let x_t_y = x_t.matmul(&y_column);
            let mut x_t_x = x_t.matmul(&scaled_x);

            for i in 0..p {
                x_t_x.add_element_mut((i, i), parameters.alpha);
            }

            let mut w = match parameters.solver {
                RidgeRegressionSolverName::Cholesky => x_t_x.cholesky_solve_mut(x_t_y)?,
                RidgeRegressionSolverName::SVD => x_t_x.svd_solve_mut(x_t_y)?,
            };

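            // Undo the column scaling so the coefficients apply to the original, unscaled features.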
            for (i, col_std_i) in col_std.iter().enumerate().take(p) {
                w.set((i, 0), *w.get((i, 0)) / *col_std_i);
            }

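            // Recover the intercept on the original scale: b = mean(y) - sum_i(w_i * mean(x_i)).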
            let mut b = TX::zero();

            for (i, col_mean_i) in col_mean.iter().enumerate().take(p) {
                b += *w.get((i, 0)) * *col_mean_i;
            }

            let b = TX::from_f64(y.mean_by()).unwrap() - b;

            (w, b)
        } else {
            let x_t = x.transpose();
            let x_t_y = x_t.matmul(&y_column);
            let mut x_t_x = x_t.matmul(x);

            for i in 0..p {
                x_t_x.add_element_mut((i, i), parameters.alpha);
            }

            let w = match parameters.solver {
                RidgeRegressionSolverName::Cholesky => x_t_x.cholesky_solve_mut(x_t_y)?,
                RidgeRegressionSolverName::SVD => x_t_x.svd_solve_mut(x_t_y)?,
            };

            (w, TX::zero())
        };

        Ok(RidgeRegression {
            intercept: Some(b),
            coefficients: Some(w),
            _phantom_ty: PhantomData,
            _phantom_y: PhantomData,
        })
    }

    fn rescale_x(x: &X) -> Result<(X, Vec<TX>, Vec<TX>), Failed> {
        let col_mean: Vec<TX> = x
            .mean_by(0)
            .iter()
            .map(|&v| TX::from_f64(v).unwrap())
            .collect();
        let col_std: Vec<TX> = x
            .std_dev(0)
            .iter()
            .map(|&v| TX::from_f64(v).unwrap())
            .collect();

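        // A constant (zero-variance) column cannot be standardized; reject it instead of dividing by zero.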
        for (i, col_std_i) in col_std.iter().enumerate() {
            if (*col_std_i - TX::zero()).abs() < TX::epsilon() {
                return Err(Failed::fit(&format!("Cannot rescale constant column {i}")));
            }
        }

        let mut scaled_x = x.clone();
        scaled_x.scale_mut(&col_mean, &col_std, 0);
        Ok((scaled_x, col_mean, col_std))
    }

    /// Predict target values from `x`
    /// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
    pub fn predict(&self, x: &X) -> Result<Y, Failed> {
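        // y_hat = x * coefficients + intercept, with the intercept broadcast to every row.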
        let (nrows, _) = x.shape();
        let mut y_hat = x.matmul(self.coefficients());
        y_hat.add_mut(&X::fill(nrows, 1, self.intercept.unwrap()));
        Ok(Y::from_iterator(
            y_hat.iterator(0).map(|&v| TY::from(v).unwrap()),
            nrows,
        ))
    }

    /// Get the estimated regression coefficients
    pub fn coefficients(&self) -> &X {
        self.coefficients.as_ref().unwrap()
    }

    /// Get the estimated intercept
    pub fn intercept(&self) -> &TX {
        self.intercept.as_ref().unwrap()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::linalg::basic::matrix::DenseMatrix;
    use crate::metrics::mean_absolute_error;

    #[test]
    fn search_parameters() {
        let parameters = RidgeRegressionSearchParameters {
            alpha: vec![0., 1.],
            ..Default::default()
        };
        let mut iter = parameters.into_iter();
        assert_eq!(iter.next().unwrap().alpha, 0.);
        assert_eq!(
            iter.next().unwrap().solver,
            RidgeRegressionSolverName::Cholesky
        );
        assert!(iter.next().is_none());
    }

    #[cfg_attr(
        all(target_arch = "wasm32", not(target_os = "wasi")),
        wasm_bindgen_test::wasm_bindgen_test
    )]
    #[test]
    fn ridge_fit_predict() {
        let x = DenseMatrix::from_2d_array(&[
            &[234.289, 235.6, 159.0, 107.608, 1947., 60.323],
            &[259.426, 232.5, 145.6, 108.632, 1948., 61.122],
            &[258.054, 368.2, 161.6, 109.773, 1949., 60.171],
            &[284.599, 335.1, 165.0, 110.929, 1950., 61.187],
            &[328.975, 209.9, 309.9, 112.075, 1951., 63.221],
            &[346.999, 193.2, 359.4, 113.270, 1952., 63.639],
            &[365.385, 187.0, 354.7, 115.094, 1953., 64.989],
            &[363.112, 357.8, 335.0, 116.219, 1954., 63.761],
            &[397.469, 290.4, 304.8, 117.388, 1955., 66.019],
            &[419.180, 282.2, 285.7, 118.734, 1956., 67.857],
            &[442.769, 293.6, 279.8, 120.445, 1957., 68.169],
            &[444.546, 468.1, 263.7, 121.950, 1958., 66.513],
            &[482.704, 381.3, 255.2, 123.366, 1959., 68.655],
            &[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
            &[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
            &[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
        ])
        .unwrap();

        let y: Vec<f64> = vec![
            83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
            114.2, 115.7, 116.9,
        ];

        let y_hat_cholesky = RidgeRegression::fit(
            &x,
            &y,
            RidgeRegressionParameters {
                solver: RidgeRegressionSolverName::Cholesky,
                alpha: 0.1,
                normalize: true,
            },
        )
        .and_then(|lr| lr.predict(&x))
        .unwrap();

        assert!(mean_absolute_error(&y_hat_cholesky, &y) < 2.0);

        let y_hat_svd = RidgeRegression::fit(
            &x,
            &y,
            RidgeRegressionParameters {
                solver: RidgeRegressionSolverName::SVD,
                alpha: 0.1,
                normalize: false,
            },
        )
        .and_then(|lr| lr.predict(&x))
        .unwrap();

        assert!(mean_absolute_error(&y_hat_svd, &y) < 2.0);
    }

    // TODO: implement serialization for new DenseMatrix
    // #[cfg_attr(all(target_arch = "wasm32", not(target_os = "wasi")), wasm_bindgen_test::wasm_bindgen_test)]
    // #[test]
    // #[cfg(feature = "serde")]
    // fn serde() {
    //     let x = DenseMatrix::from_2d_array(&[
    //         &[234.289, 235.6, 159.0, 107.608, 1947., 60.323],
    //         &[259.426, 232.5, 145.6, 108.632, 1948., 61.122],
    //         &[258.054, 368.2, 161.6, 109.773, 1949., 60.171],
    //         &[284.599, 335.1, 165.0, 110.929, 1950., 61.187],
    //         &[328.975, 209.9, 309.9, 112.075, 1951., 63.221],
    //         &[346.999, 193.2, 359.4, 113.270, 1952., 63.639],
    //         &[365.385, 187.0, 354.7, 115.094, 1953., 64.989],
    //         &[363.112, 357.8, 335.0, 116.219, 1954., 63.761],
    //         &[397.469, 290.4, 304.8, 117.388, 1955., 66.019],
    //         &[419.180, 282.2, 285.7, 118.734, 1956., 67.857],
    //         &[442.769, 293.6, 279.8, 120.445, 1957., 68.169],
    //         &[444.546, 468.1, 263.7, 121.950, 1958., 66.513],
    //         &[482.704, 381.3, 255.2, 123.366, 1959., 68.655],
    //         &[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
    //         &[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
    //         &[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
    //     ]).unwrap();

    //     let y = vec![
    //         83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
    //         114.2, 115.7, 116.9,
    //     ];

    //     let lr = RidgeRegression::fit(&x, &y, Default::default()).unwrap();

    //     let deserialized_lr: RidgeRegression<f64, f64, DenseMatrix<f64>, Vec<f64>> =
    //         serde_json::from_str(&serde_json::to_string(&lr).unwrap()).unwrap();

    //     assert_eq!(lr, deserialized_lr);
    // }
}