rs_ml/regression/
linear.rs

//! Linear regression models.
2
3use ndarray::{Array1, Array2, Axis};
4use ndarray_linalg::Inverse;
5
6use crate::Estimator;
7
8use super::Regressor;
9
/// Estimator which fits an [`OrdinaryLeastSquaresRegressor`].
///
/// The fitted model always includes an intercept term, so the feature matrix
/// does not need to contain a constant column of its own.
///
/// ```
/// # use ndarray::{arr1, arr2};
/// # use rs_ml::regression::linear::OrdinaryLeastSquaresEstimator;
/// # use rs_ml::Estimator;
/// # use rs_ml::regression::Regressor;
/// # fn test() -> Option<()> {
/// let x = arr2(&[[0.], [1.], [2.], [3.]]);
/// let y = arr1(&[0.98, 3.06, 4.89, 7.1]); // y ~ 2x + 1
/// let future_x = arr2(&[[4.], [5.], [6.], [7.]]);
///
/// let model = OrdinaryLeastSquaresEstimator.fit(&(&x, &y))?;
/// let predictions = model.predict(&future_x)?;
///
/// assert_eq!(predictions.len(), 4);
/// assert!((predictions[0] - 9.055).abs() < 1e-6); // 2.019 * 4 + 0.979
/// # Some(())
/// # }
/// # fn main() {
/// #   test().expect("fitting and predicting the example data should succeed");
/// # }
/// ```
#[derive(Debug, Clone, Copy)]
pub struct OrdinaryLeastSquaresEstimator;
32
33/// Ordinary least squares regression model fitted by [`OrdinaryLeastSquaresEstimator`].
34#[derive(Debug, Clone)]
35pub struct OrdinaryLeastSquaresRegressor {
36    beta: Array2<f64>,
37}
38
39impl Estimator<(&Array2<f64>, &Array1<f64>)> for OrdinaryLeastSquaresEstimator {
40    type Estimator = OrdinaryLeastSquaresRegressor;
41
42    fn fit(&self, input: &(&Array2<f64>, &Array1<f64>)) -> Option<Self::Estimator> {
43        let (x, y) = input;
44
45        let nrows = x.nrows();
46        let mut x_added_one = x.to_owned().clone();
47        x_added_one.push_column(Array1::ones(nrows).view()).ok()?;
48
49        let binding = y.view().insert_axis(Axis(0));
50        let transformed_y = binding.t();
51        let inv_gram_matrix: Array2<f64> = x_added_one.t().dot(&x_added_one).inv().ok()?;
52
53        let beta = inv_gram_matrix.dot(&x_added_one.t().dot(&transformed_y));
54
55        Some(OrdinaryLeastSquaresRegressor { beta })
56    }
57}
58
59impl Regressor<Array2<f64>, Array1<f64>> for OrdinaryLeastSquaresRegressor {
60    fn predict(&self, input: &Array2<f64>) -> Option<Array1<f64>> {
61        let nrows = input.nrows();
62        let mut x_added_one = input.to_owned().clone();
63        x_added_one.push_column(Array1::ones(nrows).view()).ok()?;
64
65        let y = x_added_one.dot(&self.beta);
66
67        let binding = y.t().remove_axis(Axis(0));
68        Some(binding.to_owned())
69    }
70}