#![allow(non_snake_case)]
use crate::error::{LinearError, Result};
#[cfg(feature = "blas")]
use linfa::dataset::{WithLapack, WithoutLapack};
use linfa::Float;
#[cfg(not(feature = "blas"))]
use linfa_linalg::qr::LeastSquaresQrInto;
use ndarray::{concatenate, s, Array, Array1, Array2, ArrayBase, Axis, Data, Ix1, Ix2};
#[cfg(feature = "blas")]
use ndarray_linalg::LeastSquaresSvdInto;
#[cfg(feature = "serde")]
use serde_crate::{Deserialize, Serialize};
use linfa::dataset::{AsSingleTargets, DatasetBase};
use linfa::traits::{Fit, PredictInplace};
/// Ordinary least squares linear regression (unfitted hyperparameters).
///
/// Construct with [`LinearRegression::new`] (or `Default`), optionally adjust
/// with [`LinearRegression::with_intercept`], then call `fit` on a dataset to
/// obtain a [`FittedLinearRegression`].
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(
feature = "serde",
derive(Serialize, Deserialize),
serde(crate = "serde_crate")
)]
pub struct LinearRegression {
/// Whether an intercept term is fitted. `new` sets this to `true`; when
/// `false`, the fitted model's intercept is zero.
fit_intercept: bool,
}
/// A linear regression model produced by fitting [`LinearRegression`].
///
/// Predictions are computed as `x · params + intercept`.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(
feature = "serde",
derive(Serialize, Deserialize),
serde(crate = "serde_crate")
)]
pub struct FittedLinearRegression<F> {
/// Constant term added to every prediction (zero when fitted without an intercept).
intercept: F,
/// Fitted coefficients, one per feature column of the training records.
params: Array1<F>,
}
impl Default for LinearRegression {
fn default() -> Self {
LinearRegression::new()
}
}
impl LinearRegression {
pub fn new() -> LinearRegression {
LinearRegression {
fit_intercept: true,
}
}
pub fn with_intercept(mut self, intercept: bool) -> Self {
self.fit_intercept = intercept;
self
}
}
impl<F: Float, D: Data<Elem = F>, T: AsSingleTargets<Elem = F>>
    Fit<ArrayBase<D, Ix2>, T, LinearError<F>> for LinearRegression
{
    type Object = FittedLinearRegression<F>;

    /// Fit an ordinary least squares model to `dataset`.
    ///
    /// When an intercept is requested, a trailing column of ones is appended to
    /// the records so that the intercept is solved jointly with the
    /// coefficients; the last solved parameter is then split off as the
    /// intercept. Otherwise the intercept is fixed at zero.
    ///
    /// # Panics
    ///
    /// Panics if the number of targets differs from the number of records.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the least squares solver.
    fn fit(&self, dataset: &DatasetBase<ArrayBase<D, Ix2>, T>) -> Result<Self::Object, F> {
        let X = dataset.records();
        let y = dataset.as_single_targets();
        let (n_samples, _) = X.dim();
        assert_eq!(
            y.dim(),
            n_samples,
            "The number of targets must match the number of records."
        );

        if !self.fit_intercept {
            return Ok(FittedLinearRegression {
                intercept: F::cast(0),
                params: solve_least_squares(X.to_owned(), y.to_owned())?,
            });
        }

        // Augment the design matrix with a column of ones; the coefficient
        // solved for that column is the intercept.
        let ones = Array2::ones((n_samples, 1));
        let X = concatenate(Axis(1), &[X.view(), ones.view()])
            .expect("records and the ones column always have the same number of rows");
        let params: Array1<F> = solve_least_squares(X, y.to_owned())?;

        // `params` has at least one entry (the ones column), so this cannot underflow.
        let n_features = params.len() - 1;
        let intercept = params[n_features];
        // Drop the intercept entry without copying the remaining coefficients.
        let params = params.slice_move(s![..n_features]);
        Ok(FittedLinearRegression { intercept, params })
    }
}
/// Solve the least squares problem `X · b ≈ y` for `b`.
///
/// The backend is selected by the `blas` feature:
/// - without `blas`, the QR-based `LeastSquaresQrInto` solver from
///   `linfa_linalg` is used;
/// - with `blas`, the `LeastSquaresSvdInto` solver from `ndarray_linalg` is
///   used on LAPACK-compatible element types.
///
/// Both inputs are taken by value and only mutable views of them are handed
/// to the `*_into` solvers. NOTE(review): presumably those solvers may
/// overwrite the buffers — confirm against the solver docs.
///
/// # Errors
///
/// Returns the solver's error, converted by `?` into this crate's error type
/// (presumably via a `From` impl in `crate::error` — confirm).
fn solve_least_squares<F>(mut X: Array<F, Ix2>, mut y: Array<F, Ix1>) -> Result<Array1<F>, F>
where
F: Float,
{
// Mutable views: the `*_into` solvers below are invoked on these.
let (X, y) = (X.view_mut(), y.view_mut());
// QR path: the right-hand side is passed as a single-column matrix, so the
// column axis is removed from the result to get back a 1-D solution.
#[cfg(not(feature = "blas"))]
let out = X
.least_squares_into(y.insert_axis(Axis(1)))?
.remove_axis(Axis(1));
// LAPACK path: convert to LAPACK element types, keep only the `solution`
// field of the solver's result, then convert back.
#[cfg(feature = "blas")]
let out = X
.with_lapack()
.least_squares_into(y.with_lapack())
.map(|x| x.solution)?
.without_lapack();
Ok(out)
}
impl<F: Float> FittedLinearRegression<F> {
    /// Fitted coefficients, one per feature column of the training records.
    pub fn params(&self) -> &Array1<F> {
        let Self { params, .. } = self;
        params
    }

    /// Fitted intercept; zero when the model was fitted without one.
    pub fn intercept(&self) -> F {
        let Self { intercept, .. } = self;
        *intercept
    }
}
impl<F: Float, D: Data<Elem = F>> PredictInplace<ArrayBase<D, Ix2>, Array1<F>>
for FittedLinearRegression<F>
{
fn predict_inplace(&self, x: &ArrayBase<D, Ix2>, y: &mut Array1<F>) {
assert_eq!(
x.nrows(),
y.len(),
"The number of data points must match the number of output targets."
);
*y = x.dot(&self.params) + self.intercept;
}
fn default_target(&self, x: &ArrayBase<D, Ix2>) -> Array1<F> {
Array1::zeros(x.nrows())
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use linfa::{traits::Predict, Dataset};
    use ndarray::array;

    /// Model, fitted model and error types must be shareable across threads.
    #[test]
    fn autotraits() {
        fn has_autotraits<T: Send + Sync + Sized + Unpin>() {}
        has_autotraits::<FittedLinearRegression<f64>>();
        has_autotraits::<LinearRegression>();
        has_autotraits::<LinearError<f64>>();
    }

    /// With an intercept, two points are interpolated exactly.
    #[test]
    fn fits_a_line_through_two_dots() {
        let lin_reg = LinearRegression::new();
        let dataset = Dataset::new(array![[0f64], [1.]], array![1., 2.]);
        let model = lin_reg.fit(&dataset).unwrap();
        let result = model.predict(dataset.records());
        assert_abs_diff_eq!(result, &array![1., 2.], epsilon = 1e-12);
    }

    /// Without an intercept, the fitted line must pass through the origin.
    #[test]
    fn without_intercept_fits_line_through_origin() {
        let lin_reg = LinearRegression::new().with_intercept(false);
        let dataset = Dataset::new(array![[1.]], array![1.]);
        let model = lin_reg.fit(&dataset).unwrap();
        let result = model.predict(&array![[0.], [1.]]);
        assert_abs_diff_eq!(result, &array![0., 1.], epsilon = 1e-12);
    }

    /// Without an intercept, points symmetric about the origin with equal
    /// targets give a zero slope.
    #[test]
    fn fits_least_squares_line_through_two_dots() {
        let lin_reg = LinearRegression::new().with_intercept(false);
        let dataset = Dataset::new(array![[-1.], [1.]], array![1., 1.]);
        let model = lin_reg.fit(&dataset).unwrap();
        let result = model.predict(dataset.records());
        assert_abs_diff_eq!(result, &array![0., 0.], epsilon = 1e-12);
    }

    /// Overdetermined system: three points, two parameters (slope, intercept).
    #[test]
    fn fits_least_squares_line_through_three_dots() {
        let lin_reg = LinearRegression::new();
        let dataset = Dataset::new(array![[0.], [1.], [2.]], array![0., 0., 2.]);
        let model = lin_reg.fit(&dataset).unwrap();
        let actual = model.predict(dataset.records());
        assert_abs_diff_eq!(actual, array![-1. / 3., 2. / 3., 5. / 3.], epsilon = 1e-12);
    }

    /// Recovers y = 1 + 2x + x^2 from three exact samples.
    #[test]
    fn fits_three_parameters_through_three_dots() {
        let lin_reg = LinearRegression::new();
        let dataset = Dataset::new(array![[0f64, 0.], [1., 1.], [2., 4.]], array![1., 4., 9.]);
        let model = lin_reg.fit(&dataset).unwrap();
        assert_abs_diff_eq!(model.params(), &array![2., 1.], epsilon = 1e-12);
        assert_abs_diff_eq!(model.intercept(), &1., epsilon = 1e-12);
    }

    /// Recovers y = (1 + x)^3 from four exact samples.
    #[test]
    fn fits_four_parameters_through_four_dots() {
        let lin_reg = LinearRegression::new();
        let dataset = Dataset::new(
            array![[0f64, 0., 0.], [1., 1., 1.], [2., 4., 8.], [3., 9., 27.]],
            array![1., 8., 27., 64.],
        );
        let model = lin_reg.fit(&dataset).unwrap();
        assert_abs_diff_eq!(model.params(), &array![3., 3., 1.], epsilon = 1e-12);
        assert_abs_diff_eq!(model.intercept(), &1., epsilon = 1e-12);
    }

    /// Same problem as `fits_three_parameters_through_three_dots`, but in
    /// single precision (the dataset must really be `f32`).
    #[test]
    fn fits_three_parameters_through_three_dots_f32() {
        let lin_reg = LinearRegression::new();
        let dataset = Dataset::new(array![[0f32, 0.], [1., 1.], [2., 4.]], array![1., 4., 9.]);
        let model = lin_reg.fit(&dataset).unwrap();
        // f32 carries roughly 7 significant digits, so use a loose tolerance.
        assert_abs_diff_eq!(model.params(), &array![2., 1.], epsilon = 1e-4);
        assert_abs_diff_eq!(model.intercept(), &1., epsilon = 1e-4);
    }

    /// `predict_inplace` fills the caller's buffer with `x · params + intercept`.
    #[test]
    fn predict_inplace_fills_buffer() {
        let model = LinearRegression::new()
            .fit(&Dataset::new(array![[0f64], [1.]], array![1., 2.]))
            .unwrap();
        let x = array![[2.], [3.]];
        let mut y = Array1::zeros(2);
        model.predict_inplace(&x, &mut y);
        assert_abs_diff_eq!(y, array![3., 4.], epsilon = 1e-12);
    }
}