1use std::fmt::Debug;
60use std::marker::PhantomData;
61
62#[cfg(feature = "serde")]
63use serde::{Deserialize, Serialize};
64
65use crate::api::{Predictor, SupervisedEstimator};
66use crate::error::Failed;
67use crate::linalg::basic::arrays::{Array1, Array2};
68use crate::linalg::traits::cholesky::CholeskyDecomposable;
69use crate::linalg::traits::svd::SVDDecomposable;
70use crate::numbers::basenum::Number;
71use crate::numbers::realnum::RealNumber;
72
/// Approach used to solve the regularized normal equations.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone, Eq, PartialEq, Default)]
pub enum RidgeRegressionSolverName {
    /// Cholesky decomposition (default); fast for symmetric positive-definite systems.
    #[default]
    Cholesky,
    /// Singular value decomposition; typically slower but more numerically robust.
    SVD,
}
83
/// Hyperparameters for fitting a ridge regression model.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct RidgeRegressionParameters<T: Number + RealNumber> {
    /// Solver used for the regularized normal equations.
    pub solver: RidgeRegressionSolverName,
    /// Regularization strength; larger values shrink the coefficients more.
    pub alpha: T,
    /// If `true`, columns of X are rescaled before fitting and the learned
    /// coefficients/intercept are mapped back to the raw feature scale.
    pub normalize: bool,
}
96
/// Grid of candidate hyperparameter values for ridge regression search.
///
/// Iterating the grid yields every combination of `solver` x `alpha` x
/// `normalize` as a [`RidgeRegressionParameters`] value.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct RidgeRegressionSearchParameters<T: Number + RealNumber> {
    /// Candidate solvers to try.
    #[cfg_attr(feature = "serde", serde(default))]
    pub solver: Vec<RidgeRegressionSolverName>,
    /// Candidate regularization strengths to try.
    #[cfg_attr(feature = "serde", serde(default))]
    pub alpha: Vec<T>,
    /// Candidate normalization settings to try.
    #[cfg_attr(feature = "serde", serde(default))]
    pub normalize: Vec<bool>,
}
112
/// Iterator over all parameter combinations of a
/// [`RidgeRegressionSearchParameters`] grid.
pub struct RidgeRegressionSearchParametersIterator<T: Number + RealNumber> {
    // Owned copy of the grid being iterated.
    ridge_regression_search_parameters: RidgeRegressionSearchParameters<T>,
    // Cursor into `solver` (middle-speed dimension).
    current_solver: usize,
    // Cursor into `alpha` (fastest-changing dimension).
    current_alpha: usize,
    // Cursor into `normalize` (slowest-changing dimension).
    current_normalize: usize,
}
120
121impl<T: Number + RealNumber> IntoIterator for RidgeRegressionSearchParameters<T> {
122 type Item = RidgeRegressionParameters<T>;
123 type IntoIter = RidgeRegressionSearchParametersIterator<T>;
124
125 fn into_iter(self) -> Self::IntoIter {
126 RidgeRegressionSearchParametersIterator {
127 ridge_regression_search_parameters: self,
128 current_solver: 0,
129 current_alpha: 0,
130 current_normalize: 0,
131 }
132 }
133}
134
135impl<T: Number + RealNumber> Iterator for RidgeRegressionSearchParametersIterator<T> {
136 type Item = RidgeRegressionParameters<T>;
137
138 fn next(&mut self) -> Option<Self::Item> {
139 if self.current_alpha == self.ridge_regression_search_parameters.alpha.len()
140 && self.current_solver == self.ridge_regression_search_parameters.solver.len()
141 {
142 return None;
143 }
144
145 let next = RidgeRegressionParameters {
146 solver: self.ridge_regression_search_parameters.solver[self.current_solver].clone(),
147 alpha: self.ridge_regression_search_parameters.alpha[self.current_alpha],
148 normalize: self.ridge_regression_search_parameters.normalize[self.current_normalize],
149 };
150
151 if self.current_alpha + 1 < self.ridge_regression_search_parameters.alpha.len() {
152 self.current_alpha += 1;
153 } else if self.current_solver + 1 < self.ridge_regression_search_parameters.solver.len() {
154 self.current_alpha = 0;
155 self.current_solver += 1;
156 } else if self.current_normalize + 1
157 < self.ridge_regression_search_parameters.normalize.len()
158 {
159 self.current_alpha = 0;
160 self.current_solver = 0;
161 self.current_normalize += 1;
162 } else {
163 self.current_alpha += 1;
164 self.current_solver += 1;
165 self.current_normalize += 1;
166 }
167
168 Some(next)
169 }
170}
171
172impl<T: Number + RealNumber> Default for RidgeRegressionSearchParameters<T> {
173 fn default() -> Self {
174 let default_params = RidgeRegressionParameters::default();
175
176 RidgeRegressionSearchParameters {
177 solver: vec![default_params.solver],
178 alpha: vec![default_params.alpha],
179 normalize: vec![default_params.normalize],
180 }
181 }
182}
183
/// Ridge (L2-regularized) linear regression model.
///
/// Created unfitted via [`SupervisedEstimator::new`] or fitted directly with
/// [`RidgeRegression::fit`]; predictions come from [`RidgeRegression::predict`].
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
pub struct RidgeRegression<
    TX: Number + RealNumber,
    TY: Number,
    X: Array2<TX> + CholeskyDecomposable<TX> + SVDDecomposable<TX>,
    Y: Array1<TY>,
> {
    // Learned weights as a p x 1 matrix; `None` until fitted.
    coefficients: Option<X>,
    // Learned intercept; `None` until fitted.
    intercept: Option<TX>,
    // Markers tying the otherwise-unused TY/Y type parameters to the struct.
    _phantom_ty: PhantomData<TY>,
    _phantom_y: PhantomData<Y>,
}
198
199impl<T: Number + RealNumber> RidgeRegressionParameters<T> {
200 pub fn with_alpha(mut self, alpha: T) -> Self {
202 self.alpha = alpha;
203 self
204 }
205 pub fn with_solver(mut self, solver: RidgeRegressionSolverName) -> Self {
207 self.solver = solver;
208 self
209 }
210 pub fn with_normalize(mut self, normalize: bool) -> Self {
212 self.normalize = normalize;
213 self
214 }
215}
216
217impl<T: Number + RealNumber> Default for RidgeRegressionParameters<T> {
218 fn default() -> Self {
219 RidgeRegressionParameters {
220 solver: RidgeRegressionSolverName::default(),
221 alpha: T::from_f64(1.0).unwrap(),
222 normalize: true,
223 }
224 }
225}
226
impl<
        TX: Number + RealNumber,
        TY: Number,
        X: Array2<TX> + CholeskyDecomposable<TX> + SVDDecomposable<TX>,
        Y: Array1<TY>,
    > PartialEq for RidgeRegression<TX, TY, X, Y>
{
    /// Two fitted models compare equal when their intercepts match, their
    /// coefficient matrices have the same shape, and every coefficient pair
    /// differs by at most `TX::epsilon()`.
    ///
    /// NOTE(review): intercepts are compared exactly while coefficients get an
    /// epsilon tolerance — confirm this asymmetry is intended.
    ///
    /// Panics if either model has not been fitted, because the accessors
    /// unwrap the `Option` fields.
    fn eq(&self, other: &Self) -> bool {
        self.intercept() == other.intercept()
            && self.coefficients().shape() == other.coefficients().shape()
            && self
                .coefficients()
                .iterator(0)
                .zip(other.coefficients().iterator(0))
                .all(|(&a, &b)| (a - b).abs() <= TX::epsilon())
    }
}
244
245impl<
246 TX: Number + RealNumber,
247 TY: Number,
248 X: Array2<TX> + CholeskyDecomposable<TX> + SVDDecomposable<TX>,
249 Y: Array1<TY>,
250 > SupervisedEstimator<X, Y, RidgeRegressionParameters<TX>> for RidgeRegression<TX, TY, X, Y>
251{
252 fn new() -> Self {
253 Self {
254 coefficients: Option::None,
255 intercept: Option::None,
256 _phantom_ty: PhantomData,
257 _phantom_y: PhantomData,
258 }
259 }
260
261 fn fit(x: &X, y: &Y, parameters: RidgeRegressionParameters<TX>) -> Result<Self, Failed> {
262 RidgeRegression::fit(x, y, parameters)
263 }
264}
265
266impl<
267 TX: Number + RealNumber,
268 TY: Number,
269 X: Array2<TX> + CholeskyDecomposable<TX> + SVDDecomposable<TX>,
270 Y: Array1<TY>,
271 > Predictor<X, Y> for RidgeRegression<TX, TY, X, Y>
272{
273 fn predict(&self, x: &X) -> Result<Y, Failed> {
274 self.predict(x)
275 }
276}
277
278impl<
279 TX: Number + RealNumber,
280 TY: Number,
281 X: Array2<TX> + CholeskyDecomposable<TX> + SVDDecomposable<TX>,
282 Y: Array1<TY>,
283 > RidgeRegression<TX, TY, X, Y>
284{
285 pub fn fit(
290 x: &X,
291 y: &Y,
292 parameters: RidgeRegressionParameters<TX>,
293 ) -> Result<RidgeRegression<TX, TY, X, Y>, Failed> {
294 let (n, p) = x.shape();
297
298 if n <= p {
299 return Err(Failed::fit(
300 "Number of rows in X should be >= number of columns in X",
301 ));
302 }
303
304 if y.shape() != n {
305 return Err(Failed::fit("Number of rows in X should = len(y)"));
306 }
307
308 let y_column = X::from_iterator(
309 y.iterator(0).map(|&v| TX::from(v).unwrap()),
310 y.shape(),
311 1,
312 0,
313 );
314
315 let (w, b) = if parameters.normalize {
316 let (scaled_x, col_mean, col_std) = Self::rescale_x(x)?;
317 let x_t = scaled_x.transpose();
318 let x_t_y = x_t.matmul(&y_column);
319 let mut x_t_x = x_t.matmul(&scaled_x);
320
321 for i in 0..p {
322 x_t_x.add_element_mut((i, i), parameters.alpha);
323 }
324
325 let mut w = match parameters.solver {
326 RidgeRegressionSolverName::Cholesky => x_t_x.cholesky_solve_mut(x_t_y)?,
327 RidgeRegressionSolverName::SVD => x_t_x.svd_solve_mut(x_t_y)?,
328 };
329
330 for (i, col_std_i) in col_std.iter().enumerate().take(p) {
331 w.set((i, 0), *w.get((i, 0)) / *col_std_i);
332 }
333
334 let mut b = TX::zero();
335
336 for (i, col_mean_i) in col_mean.iter().enumerate().take(p) {
337 b += *w.get((i, 0)) * *col_mean_i;
338 }
339
340 let b = TX::from_f64(y.mean_by()).unwrap() - b;
341
342 (w, b)
343 } else {
344 let x_t = x.transpose();
345 let x_t_y = x_t.matmul(&y_column);
346 let mut x_t_x = x_t.matmul(x);
347
348 for i in 0..p {
349 x_t_x.add_element_mut((i, i), parameters.alpha);
350 }
351
352 let w = match parameters.solver {
353 RidgeRegressionSolverName::Cholesky => x_t_x.cholesky_solve_mut(x_t_y)?,
354 RidgeRegressionSolverName::SVD => x_t_x.svd_solve_mut(x_t_y)?,
355 };
356
357 (w, TX::zero())
358 };
359
360 Ok(RidgeRegression {
361 intercept: Some(b),
362 coefficients: Some(w),
363 _phantom_ty: PhantomData,
364 _phantom_y: PhantomData,
365 })
366 }
367
368 fn rescale_x(x: &X) -> Result<(X, Vec<TX>, Vec<TX>), Failed> {
369 let col_mean: Vec<TX> = x
370 .mean_by(0)
371 .iter()
372 .map(|&v| TX::from_f64(v).unwrap())
373 .collect();
374 let col_std: Vec<TX> = x
375 .std_dev(0)
376 .iter()
377 .map(|&v| TX::from_f64(v).unwrap())
378 .collect();
379
380 for (i, col_std_i) in col_std.iter().enumerate() {
381 if (*col_std_i - TX::zero()).abs() < TX::epsilon() {
382 return Err(Failed::fit(&format!("Cannot rescale constant column {i}")));
383 }
384 }
385
386 let mut scaled_x = x.clone();
387 scaled_x.scale_mut(&col_mean, &col_std, 0);
388 Ok((scaled_x, col_mean, col_std))
389 }
390
391 pub fn predict(&self, x: &X) -> Result<Y, Failed> {
394 let (nrows, _) = x.shape();
395 let mut y_hat = x.matmul(self.coefficients());
396 y_hat.add_mut(&X::fill(nrows, 1, self.intercept.unwrap()));
397 Ok(Y::from_iterator(
398 y_hat.iterator(0).map(|&v| TY::from(v).unwrap()),
399 nrows,
400 ))
401 }
402
403 pub fn coefficients(&self) -> &X {
405 self.coefficients.as_ref().unwrap()
406 }
407
408 pub fn intercept(&self) -> &TX {
410 self.intercept.as_ref().unwrap()
411 }
412}
413
#[cfg(test)]
mod tests {
    use super::*;
    use crate::linalg::basic::matrix::DenseMatrix;
    use crate::metrics::mean_absolute_error;

    /// The search grid iterates `alpha` fastest: with default (length-one)
    /// solver/normalize lists and two alphas, exactly two parameter sets are
    /// produced and then the iterator is exhausted.
    #[test]
    fn search_parameters() {
        let parameters = RidgeRegressionSearchParameters {
            alpha: vec![0., 1.],
            ..Default::default()
        };
        let mut iter = parameters.into_iter();
        assert_eq!(iter.next().unwrap().alpha, 0.);
        assert_eq!(
            iter.next().unwrap().solver,
            RidgeRegressionSolverName::Cholesky
        );
        assert!(iter.next().is_none());
    }

    /// Fits a small 16x6 dataset with both solvers (with and without
    /// normalization) and checks the in-sample mean absolute error is small.
    #[cfg_attr(
        all(target_arch = "wasm32", not(target_os = "wasi")),
        wasm_bindgen_test::wasm_bindgen_test
    )]
    #[test]
    fn ridge_fit_predict() {
        let x = DenseMatrix::from_2d_array(&[
            &[234.289, 235.6, 159.0, 107.608, 1947., 60.323],
            &[259.426, 232.5, 145.6, 108.632, 1948., 61.122],
            &[258.054, 368.2, 161.6, 109.773, 1949., 60.171],
            &[284.599, 335.1, 165.0, 110.929, 1950., 61.187],
            &[328.975, 209.9, 309.9, 112.075, 1951., 63.221],
            &[346.999, 193.2, 359.4, 113.270, 1952., 63.639],
            &[365.385, 187.0, 354.7, 115.094, 1953., 64.989],
            &[363.112, 357.8, 335.0, 116.219, 1954., 63.761],
            &[397.469, 290.4, 304.8, 117.388, 1955., 66.019],
            &[419.180, 282.2, 285.7, 118.734, 1956., 67.857],
            &[442.769, 293.6, 279.8, 120.445, 1957., 68.169],
            &[444.546, 468.1, 263.7, 121.950, 1958., 66.513],
            &[482.704, 381.3, 255.2, 123.366, 1959., 68.655],
            &[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
            &[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
            &[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
        ])
        .unwrap();

        let y: Vec<f64> = vec![
            83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
            114.2, 115.7, 116.9,
        ];

        // Cholesky solver with column normalization.
        let y_hat_cholesky = RidgeRegression::fit(
            &x,
            &y,
            RidgeRegressionParameters {
                solver: RidgeRegressionSolverName::Cholesky,
                alpha: 0.1,
                normalize: true,
            },
        )
        .and_then(|lr| lr.predict(&x))
        .unwrap();

        assert!(mean_absolute_error(&y_hat_cholesky, &y) < 2.0);

        // SVD solver without normalization (zero intercept path).
        let y_hat_svd = RidgeRegression::fit(
            &x,
            &y,
            RidgeRegressionParameters {
                solver: RidgeRegressionSolverName::SVD,
                alpha: 0.1,
                normalize: false,
            },
        )
        .and_then(|lr| lr.predict(&x))
        .unwrap();

        assert!(mean_absolute_error(&y_hat_svd, &y) < 2.0);
    }
}