#![allow(non_snake_case)]
use crate::error::{LinearError, Result};
use ndarray::{Array1, Array2, ArrayBase, Axis, Data, Ix1, Ix2};
use ndarray_linalg::{Lapack, Scalar, Solve};
use ndarray_stats::SummaryStatisticsExt;
use serde::{Deserialize, Serialize};
use linfa::dataset::{AsTargets, DatasetBase};
use linfa::traits::{Fit, PredictRef};
/// Floating-point scalar types this module can solve with: combines linfa's
/// float abstraction with the LAPACK-backed traits required by
/// `ndarray_linalg` (only `f32` and `f64` qualify).
pub trait Float: linfa::Float + Lapack + Scalar {}
impl Float for f32 {}
impl Float for f64 {}
/// An ordinary least-squares linear regression model builder.
///
/// Configure intercept fitting / feature normalization with the builder
/// methods, then call `fit` on a dataset to obtain a
/// `FittedLinearRegression`.
#[derive(Serialize, Deserialize)]
pub struct LinearRegression {
    // How the model treats intercept and normalization during fitting.
    options: Options,
}
/// Fitting configuration: whether an intercept is estimated and whether the
/// feature columns are rescaled before solving.
#[derive(Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
enum Options {
    None,
    WithIntercept,
    WithInterceptAndNormalize,
}

impl Options {
    /// True when an intercept term should be estimated during fitting.
    fn should_use_intercept(&self) -> bool {
        matches!(
            self,
            Options::WithIntercept | Options::WithInterceptAndNormalize
        )
    }

    /// True when the feature columns should be rescaled before solving.
    fn should_normalize(&self) -> bool {
        matches!(self, Options::WithInterceptAndNormalize)
    }
}
/// A fitted linear regression model: predictions are
/// `X · params + intercept`.
#[derive(Serialize, Deserialize)]
pub struct FittedLinearRegression<A> {
    // Bias term; zero when the model was fitted without an intercept.
    intercept: A,
    // One coefficient per feature column.
    params: Array1<A>,
}
impl Default for LinearRegression {
fn default() -> Self {
LinearRegression::new()
}
}
impl LinearRegression {
    /// Create a new builder; by default an intercept is fitted and the
    /// features are not normalized.
    pub fn new() -> LinearRegression {
        LinearRegression {
            options: Options::WithIntercept,
        }
    }

    /// Enable or disable intercept fitting. Disabling it also disables
    /// normalization.
    pub fn with_intercept(mut self, with_intercept: bool) -> Self {
        self.options = if with_intercept {
            Options::WithIntercept
        } else {
            Options::None
        };
        self
    }

    /// Fit an intercept and rescale the feature columns before solving.
    pub fn with_intercept_and_normalize(mut self) -> Self {
        self.options = Options::WithInterceptAndNormalize;
        self
    }
}
impl<F: Float, D: Data<Elem = F>, T: AsTargets<Elem = F>> Fit<ArrayBase<D, Ix2>, T, LinearError>
    for LinearRegression
{
    type Object = FittedLinearRegression<F>;

    /// Fit an ordinary least-squares model to `dataset`.
    ///
    /// With an intercept enabled, features and targets are centered first and
    /// the intercept is recovered from the offsets; otherwise the normal
    /// equations are solved directly and the intercept is zero.
    ///
    /// # Errors
    /// Returns an error when the dataset has no single target column, when
    /// means cannot be computed (empty data), or when the normal-equation
    /// system is singular.
    fn fit(&self, dataset: &DatasetBase<ArrayBase<D, Ix2>, T>) -> Result<Self::Object> {
        let X = dataset.records();
        let y = dataset.try_single_target()?;

        let (n_samples, _) = X.dim();
        // Feature matrix and target vector must agree on the sample count.
        assert_eq!(y.dim(), n_samples);

        if self.options.should_use_intercept() {
            // Center features and targets; solving the centered system lets
            // the intercept be recovered as y_offset - X_offset · params.
            let X_offset: Array1<F> = X
                .mean_axis(Axis(0))
                .ok_or_else(|| LinearError::NotEnoughSamples)?;
            let X_centered: Array2<F> = X - &X_offset;
            let y_offset: F = y.mean().ok_or_else(|| LinearError::NotEnoughTargets)?;
            let y_centered: Array1<F> = &y - y_offset;
            let params: Array1<F> =
                compute_params(&X_centered, &y_centered, self.options.should_normalize())?;
            // BUG FIX: this line previously read `X_offset.dot(¶ms)` —
            // mojibake for `&params` (HTML-entity corruption) — which does
            // not compile.
            let intercept: F = y_offset - X_offset.dot(&params);
            Ok(FittedLinearRegression { intercept, params })
        } else {
            Ok(FittedLinearRegression {
                intercept: F::cast(0),
                params: solve_normal_equation(X, &y)?,
            })
        }
    }
}
/// Solve for the regression coefficients of `X` against `y`, optionally
/// rescaling each feature column before solving and undoing the scaling on
/// the returned coefficients.
fn compute_params<F, B, C>(
    X: &ArrayBase<B, Ix2>,
    y: &ArrayBase<C, Ix1>,
    normalize: bool,
) -> Result<Array1<F>>
where
    F: Float,
    B: Data<Elem = F>,
    C: Data<Elem = F>,
{
    if !normalize {
        return solve_normal_equation(X, y);
    }
    // Per-column scale factor: the second central moment of the column.
    // NOTE(review): the `unwrap` panics if the moment cannot be computed
    // (e.g. an empty column) — assumes inputs were validated upstream; TODO
    // confirm.
    let scale: Array1<F> = X.map_axis(Axis(0), |column| column.central_moment(2).unwrap());
    let rescaled: Array2<F> = X / &scale;
    let params = solve_normal_equation(&rescaled, y)?;
    // Undo the column scaling so the coefficients apply to the raw features.
    Ok(params / &scale)
}
/// Solve the normal equations `XᵀX · p = Xᵀy` for the parameter vector `p`.
///
/// # Errors
/// Fails when the Gram matrix `XᵀX` is singular.
fn solve_normal_equation<F, B, C>(X: &ArrayBase<B, Ix2>, y: &ArrayBase<C, Ix1>) -> Result<Array1<F>>
where
    F: Float,
    B: Data<Elem = F>,
    C: Data<Elem = F>,
{
    let Xt = X.t();
    let rhs = Xt.dot(y);
    let gram = Xt.dot(X);
    // `?` converts the linear-algebra error through the existing `From` impl.
    Ok(gram.solve_into(rhs)?)
}
impl<F: Float> FittedLinearRegression<F> {
    /// The fitted coefficient for each feature column.
    pub fn params(&self) -> &Array1<F> {
        &self.params
    }

    /// The fitted bias term; zero when fitted without an intercept.
    pub fn intercept(&self) -> F {
        self.intercept
    }
}
impl<F: Float, D: Data<Elem = F>> PredictRef<ArrayBase<D, Ix2>, Array1<F>>
    for FittedLinearRegression<F>
{
    /// Predict targets for each row of `x` as `x · params + intercept`.
    fn predict_ref<'a>(&'a self, x: &ArrayBase<D, Ix2>) -> Array1<F> {
        let linear_term = x.dot(&self.params);
        linear_term + self.intercept
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use linfa::{traits::Predict, Dataset};
    use ndarray::array;

    #[test]
    fn fits_a_line_through_two_dots() {
        let lin_reg = LinearRegression::new();
        let dataset = Dataset::new(array![[0f64], [1.]], array![1., 2.]);
        let model = lin_reg.fit(&dataset).unwrap();
        let result = model.predict(dataset.records());
        assert_abs_diff_eq!(result, &array![1., 2.], epsilon = 1e-12);
    }

    #[test]
    fn without_intercept_fits_line_through_origin() {
        let lin_reg = LinearRegression::new().with_intercept(false);
        let dataset = Dataset::new(array![[1.]], array![1.]);
        let model = lin_reg.fit(&dataset).unwrap();
        let result = model.predict(&array![[0.], [1.]]);
        assert_abs_diff_eq!(result, &array![0., 1.], epsilon = 1e-12);
    }

    #[test]
    fn fits_least_squares_line_through_two_dots() {
        let lin_reg = LinearRegression::new().with_intercept(false);
        let dataset = Dataset::new(array![[-1.], [1.]], array![1., 1.]);
        let model = lin_reg.fit(&dataset).unwrap();
        let result = model.predict(dataset.records());
        assert_abs_diff_eq!(result, &array![0., 0.], epsilon = 1e-12);
    }

    #[test]
    fn fits_least_squares_line_through_three_dots() {
        let lin_reg = LinearRegression::new();
        let dataset = Dataset::new(array![[0.], [1.], [2.]], array![0., 0., 2.]);
        let model = lin_reg.fit(&dataset).unwrap();
        let actual = model.predict(dataset.records());
        assert_abs_diff_eq!(actual, array![-1. / 3., 2. / 3., 5. / 3.], epsilon = 1e-12);
    }

    #[test]
    fn fits_three_parameters_through_three_dots() {
        let lin_reg = LinearRegression::new();
        let dataset = Dataset::new(array![[0f64, 0.], [1., 1.], [2., 4.]], array![1., 4., 9.]);
        let model = lin_reg.fit(&dataset).unwrap();
        assert_abs_diff_eq!(model.params(), &array![2., 1.], epsilon = 1e-12);
        assert_abs_diff_eq!(model.intercept(), 1., epsilon = 1e-12);
    }

    #[test]
    fn fits_four_parameters_through_four_dots() {
        let lin_reg = LinearRegression::new();
        let dataset = Dataset::new(
            array![[0f64, 0., 0.], [1., 1., 1.], [2., 4., 8.], [3., 9., 27.]],
            array![1., 8., 27., 64.],
        );
        let model = lin_reg.fit(&dataset).unwrap();
        assert_abs_diff_eq!(model.params(), &array![3., 3., 1.], epsilon = 1e-12);
        assert_abs_diff_eq!(model.intercept(), 1., epsilon = 1e-12);
    }

    #[test]
    fn fits_three_parameters_through_three_dots_f32() {
        let lin_reg = LinearRegression::new();
        // BUG FIX: despite its name (and its f32-sized epsilons below), this
        // test previously built an `f64` dataset via `0f64`; it now actually
        // exercises the `f32` code path.
        let dataset = Dataset::new(array![[0f32, 0.], [1., 1.], [2., 4.]], array![1., 4., 9.]);
        let model = lin_reg.fit(&dataset).unwrap();
        assert_abs_diff_eq!(model.params(), &array![2., 1.], epsilon = 1e-4);
        assert_abs_diff_eq!(model.intercept(), 1., epsilon = 1e-6);
    }

    #[test]
    fn fits_four_parameters_through_four_dots_with_normalization() {
        let lin_reg = LinearRegression::new().with_intercept_and_normalize();
        let dataset = Dataset::new(
            array![[0f64, 0., 0.], [1., 1., 1.], [2., 4., 8.], [3., 9., 27.]],
            array![1., 8., 27., 64.],
        );
        let model = lin_reg.fit(&dataset).unwrap();
        assert_abs_diff_eq!(model.params(), &array![3., 3., 1.], epsilon = 1e-12);
        assert_abs_diff_eq!(model.intercept(), 1., epsilon = 1e-12);
    }

    #[test]
    fn works_with_viewed_and_owned_representations() {
        let lin_reg = LinearRegression::new().with_intercept_and_normalize();
        let dataset = Dataset::new(
            array![[0., 0., 0.], [1., 1., 1.], [2., 4., 8.], [3., 9., 27.]],
            array![1., 8., 27., 64.],
        );
        let dataset_view = dataset.view();
        // Both representations must produce bit-identical models.
        let model1 = lin_reg.fit(&dataset).expect("can't fit owned arrays");
        let model2 = lin_reg
            .fit(&dataset_view)
            .expect("can't fit feature view with owned target");
        assert_eq!(model1.params(), model2.params());
        assert_abs_diff_eq!(model1.intercept(), model2.intercept());
    }
}