#![allow(non_snake_case)]
use ndarray::{Array1, Array2, ArrayBase, Axis, Data, Ix1, Ix2, ScalarOperand};
use ndarray_linalg::{Lapack, Scalar, Solve};
use ndarray_stats::SummaryStatisticsExt;
use num_traits::float::Float;
/// Hyper-parameter builder for an ordinary least-squares linear regression.
///
/// Configure with [`LinearRegression::with_intercept`] or
/// [`LinearRegression::with_intercept_and_normalize`], then call
/// [`LinearRegression::fit`] to obtain a `FittedLinearRegression`.
pub struct LinearRegression {
// Selects intercept fitting / feature normalization behavior in `fit`.
options: Options,
}
/// Fitting mode selected through the `LinearRegression` builder methods.
#[derive(Clone, Copy, PartialEq, Eq)]
enum Options {
// No intercept: solve the normal equation on the raw data.
None,
// Center X and y before solving, then recover the intercept afterwards.
WithIntercept,
// Like `WithIntercept`, but additionally scale the centered feature columns.
WithInterceptAndNormalize,
}
impl Options {
    /// Returns `true` when an intercept term should be fitted.
    fn should_use_intercept(&self) -> bool {
        matches!(
            self,
            Options::WithIntercept | Options::WithInterceptAndNormalize
        )
    }

    /// Returns `true` when feature columns should be scaled before solving.
    fn should_normalize(&self) -> bool {
        matches!(self, Options::WithInterceptAndNormalize)
    }
}
/// A fitted linear model produced by [`LinearRegression::fit`].
///
/// Predictions are computed as `X · params + intercept` (see `predict`).
pub struct FittedLinearRegression<A> {
// Intercept term; zero when the model was fitted without an intercept.
intercept: A,
// One slope coefficient per feature column of the training matrix.
params: Array1<A>,
}
impl Default for LinearRegression {
fn default() -> Self {
LinearRegression::new()
}
}
impl LinearRegression {
    /// Creates a `LinearRegression` that fits an intercept (the default).
    pub fn new() -> LinearRegression {
        LinearRegression {
            options: Options::WithIntercept,
        }
    }

    /// Enables (`true`) or disables (`false`) fitting of an intercept term.
    pub fn with_intercept(mut self, with_intercept: bool) -> Self {
        self.options = if with_intercept {
            Options::WithIntercept
        } else {
            Options::None
        };
        self
    }

    /// Fits an intercept and additionally scales each centered feature
    /// column before solving (see `compute_params`).
    pub fn with_intercept_and_normalize(mut self) -> Self {
        self.options = Options::WithInterceptAndNormalize;
        self
    }

    /// Fits the linear model to the feature matrix `X` (one sample per row)
    /// and target vector `y` by solving the normal equations.
    ///
    /// # Errors
    /// Returns `Err` when a mean cannot be computed (e.g. empty input) or
    /// when the linear solve fails.
    ///
    /// # Panics
    /// Panics if `y` does not have one entry per row of `X`.
    pub fn fit<A, B, C>(
        &self,
        X: &ArrayBase<B, Ix2>,
        y: &ArrayBase<C, Ix1>,
    ) -> Result<FittedLinearRegression<A>, String>
    where
        A: Lapack + Scalar + ScalarOperand + Float,
        B: Data<Elem = A>,
        C: Data<Elem = A>,
    {
        let (n_samples, _) = X.dim();
        // The target must provide exactly one value per sample row.
        assert_eq!(y.dim(), n_samples);
        if self.options.should_use_intercept() {
            // Center X and y so the slope parameters can be solved for on
            // centered data and the intercept recovered separately below.
            let X_offset: Array1<A> = X
                .mean_axis(Axis(0))
                .ok_or_else(|| String::from("cannot compute mean of X"))?;
            let X_centered: Array2<A> = X - &X_offset;
            let y_offset: A = y
                .mean()
                .ok_or_else(|| String::from("cannot compute mean of y"))?;
            let y_centered: Array1<A> = y - y_offset;
            let params: Array1<A> =
                compute_params(&X_centered, &y_centered, self.options.should_normalize())?;
            // BUG FIX: this line previously read `X_offset.dot(¶ms)` —
            // an encoding corruption of `&params` that does not compile.
            let intercept: A = y_offset - X_offset.dot(&params);
            Ok(FittedLinearRegression { intercept, params })
        } else {
            Ok(FittedLinearRegression {
                // `A::zero()` (from `Float`) replaces the fallible
                // `A::from(0).unwrap()` — zero always exists for `A`.
                intercept: A::zero(),
                params: solve_normal_equation(X, y)?,
            })
        }
    }
}
/// Computes the slope parameters for already-centered data.
///
/// When `normalize` is true, each feature column is divided by its second
/// central moment before solving, and the resulting parameters are divided
/// by the same factors so they apply to the original (unscaled) features.
fn compute_params<A, B, C>(
    X: &ArrayBase<B, Ix2>,
    y: &ArrayBase<C, Ix1>,
    normalize: bool,
) -> Result<Array1<A>, String>
where
    A: Scalar + Lapack + Float,
    B: Data<Elem = A>,
    C: Data<Elem = A>,
{
    if !normalize {
        return solve_normal_equation(X, y);
    }
    // Per-column scale factor: the column's second central moment.
    // NOTE(review): `unwrap` panics on degenerate columns (e.g. empty X) —
    // confirm callers only pass non-empty data.
    let scale: Array1<A> = X.map_axis(Axis(0), |column| column.central_moment(2).unwrap());
    let rescaled: Array2<A> = X / &scale;
    // Undo the column scaling on the solved parameters.
    solve_normal_equation(&rescaled, y).map(|params| params / &scale)
}
/// Solves the normal equations `(Xᵀ X) · params = Xᵀ y` for the
/// least-squares parameters.
///
/// # Errors
/// Returns the LAPACK error as a `String` when the linear solve fails
/// (e.g. a singular `Xᵀ X`).
fn solve_normal_equation<A, B, C>(
    X: &ArrayBase<B, Ix2>,
    y: &ArrayBase<C, Ix1>,
) -> Result<Array1<A>, String>
where
    A: Lapack + Scalar,
    B: Data<Elem = A>,
    C: Data<Elem = A>,
{
    let rhs = X.t().dot(y);
    let linear_operator = X.t().dot(X);
    linear_operator
        .solve_into(rhs)
        // Idiom fix: `err.to_string()` instead of `format! {"{}", err}`.
        .map_err(|err| err.to_string())
}
impl<A: Scalar + ScalarOperand> FittedLinearRegression<A> {
/// Predicts targets for the rows of `X` as `X · params + intercept`.
pub fn predict(&self, X: &Array2<A>) -> Array1<A> {
X.dot(&self.params) + self.intercept
}
/// Returns the fitted slope parameters (one per feature column).
pub fn params(&self) -> &Array1<A> {
&self.params
}
/// Returns the fitted intercept (zero when fitted without an intercept).
pub fn intercept(&self) -> A {
self.intercept
}
}
#[cfg(test)]
mod tests {
    extern crate openblas_src;
    use super::*;
    // BUG FIX: the tests used `abs_diff_eq!`, which only *returns* a bool
    // that was then discarded — none of the numeric comparisons could ever
    // fail a test. Use the asserting form `assert_abs_diff_eq!` throughout.
    use approx::assert_abs_diff_eq;
    use ndarray::{array, s, Array1, Array2};

    #[test]
    fn fits_a_line_through_two_dots() {
        let lin_reg = LinearRegression::new();
        let A: Array2<f64> = array![[0.], [1.]];
        let b: Array1<f64> = array![1., 2.];
        let model = lin_reg.fit(&A, &b).unwrap();
        let result = model.predict(&A);
        assert_abs_diff_eq!(result, array![1., 2.], epsilon = 1e-12);
    }

    #[test]
    fn without_intercept_fits_line_through_origin() {
        let lin_reg = LinearRegression::new().with_intercept(false);
        let A: Array2<f64> = array![[1.]];
        let b: Array1<f64> = array![1.];
        let model = lin_reg.fit(&A, &b).unwrap();
        let result = model.predict(&array![[0.], [1.]]);
        assert_abs_diff_eq!(result, array![0., 1.], epsilon = 1e-12);
    }

    #[test]
    fn fits_least_squares_line_through_two_dots() {
        let lin_reg = LinearRegression::new().with_intercept(false);
        let A: Array2<f64> = array![[-1.], [1.]];
        let b: Array1<f64> = array![1., 1.];
        let model = lin_reg.fit(&A, &b).unwrap();
        let result = model.predict(&A);
        assert_abs_diff_eq!(result, array![0., 0.], epsilon = 1e-12);
    }

    #[test]
    fn fits_least_squares_line_through_three_dots() {
        let lin_reg = LinearRegression::new();
        let A: Array2<f64> = array![[0.], [1.], [2.]];
        let b: Array1<f64> = array![0., 0., 2.];
        let model = lin_reg.fit(&A, &b).unwrap();
        let actual = model.predict(&A);
        assert_abs_diff_eq!(actual, array![-1. / 3., 2. / 3., 5. / 3.], epsilon = 1e-12);
    }

    #[test]
    fn fits_three_parameters_through_three_dots() {
        let lin_reg = LinearRegression::new();
        let A: Array2<f64> = array![[0., 0.], [1., 1.], [2., 4.]];
        let b: Array1<f64> = array![1., 4., 9.];
        let model = lin_reg.fit(&A, &b).unwrap();
        assert_abs_diff_eq!(model.params(), &array![2., 1.], epsilon = 1e-12);
        // FIX: `intercept()` returns `f64`; compare against `1.`, not `&1.`.
        assert_abs_diff_eq!(model.intercept(), 1., epsilon = 1e-12);
    }

    #[test]
    fn fits_four_parameters_through_four_dots() {
        let lin_reg = LinearRegression::new();
        let A: Array2<f64> = array![[0., 0., 0.], [1., 1., 1.], [2., 4., 8.], [3., 9., 27.]];
        let b: Array1<f64> = array![1., 8., 27., 64.];
        let model = lin_reg.fit(&A, &b).unwrap();
        assert_abs_diff_eq!(model.params(), &array![3., 3., 1.], epsilon = 1e-12);
        assert_abs_diff_eq!(model.intercept(), 1., epsilon = 1e-12);
    }

    #[test]
    fn fits_three_parameters_through_three_dots_f32() {
        let lin_reg = LinearRegression::new();
        let A: Array2<f32> = array![[0., 0.], [1., 1.], [2., 4.]];
        let b: Array1<f32> = array![1., 4., 9.];
        let model = lin_reg.fit(&A, &b).unwrap();
        assert_abs_diff_eq!(model.params(), &array![2., 1.], epsilon = 1e-4);
        assert_abs_diff_eq!(model.intercept(), 1., epsilon = 1e-6);
    }

    #[test]
    fn fits_four_parameters_through_four_dots_with_normalization() {
        let lin_reg = LinearRegression::new().with_intercept_and_normalize();
        let A: Array2<f64> = array![[0., 0., 0.], [1., 1., 1.], [2., 4., 8.], [3., 9., 27.]];
        let b: Array1<f64> = array![1., 8., 27., 64.];
        let model = lin_reg.fit(&A, &b).unwrap();
        assert_abs_diff_eq!(model.params(), &array![3., 3., 1.], epsilon = 1e-12);
        assert_abs_diff_eq!(model.intercept(), 1., epsilon = 1e-12);
    }

    #[test]
    fn works_with_viewed_and_owned_representations() {
        let lin_reg = LinearRegression::new().with_intercept_and_normalize();
        let A: Array2<f64> = array![[0., 0., 0.], [1., 1., 1.], [2., 4., 8.], [3., 9., 27.]];
        let b: Array1<f64> = array![1., 8., 27., 64.];
        let A_view = A.slice(s![.., ..]);
        let b_view = b.slice(s![..]);
        // All four owned/view combinations must produce the same model.
        let model1 = lin_reg.fit(&A, &b).expect("can't fit owned arrays");
        let model2 = lin_reg
            .fit(&A_view, &b)
            .expect("can't fit feature view with owned target");
        let model3 = lin_reg
            .fit(&A, &b_view)
            .expect("can't fit owned features with target view");
        let model4 = lin_reg
            .fit(&A_view, &b_view)
            .expect("can't fit viewed arrays");
        assert_eq!(model1.params(), model2.params());
        assert_eq!(model2.params(), model3.params());
        assert_eq!(model3.params(), model4.params());
        assert_abs_diff_eq!(model1.intercept(), model2.intercept());
        assert_abs_diff_eq!(model2.intercept(), model3.intercept());
        assert_abs_diff_eq!(model3.intercept(), model4.intercept());
    }
}