use nalgebra::{DMatrix, DVector};
use RustQuant_error::RustQuantError;
/// Training data for an ordinary-least-squares linear regression.
///
/// `fit` prepends its own column of ones for the intercept, so `x` should
/// hold only the feature columns.
// Lint allowed because the type name deliberately echoes its module's name.
#[allow(clippy::module_name_repetitions)]
#[derive(Clone, Debug)]
pub struct LinearRegressionInput<T> {
/// Design matrix: one row per observation, one column per feature.
pub x: DMatrix<T>,
/// Response vector: one entry per row of `x`.
pub y: DVector<T>,
}
/// Fitted parameters of a linear regression, as returned by `fit`.
// Lint allowed because the type name deliberately echoes its module's name.
#[allow(clippy::module_name_repetitions)]
#[derive(Clone, Debug)]
pub struct LinearRegressionOutput<T> {
/// Fitted intercept; `fit` sets this to `coefficients[0]`.
pub intercept: T,
/// All fitted coefficients, intercept first: index 0 is the intercept and
/// index `i + 1` is the weight of feature column `i`. `predict` relies on
/// this layout when it drops row 0.
pub coefficients: DVector<T>,
}
/// Strategy used by `LinearRegressionInput::fit` to solve the least-squares problem.
///
/// Note: the `None` variant shares its name with `Option::None`; refer to it as
/// `Decomposition::None` rather than glob-importing the variants.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Decomposition {
    /// No factorisation: solve the normal equations by inverting `XᵀX` directly.
    None,
    /// Solve via a QR factorisation of the design matrix.
    QR,
    /// Solve via the SVD-based pseudo-inverse of the design matrix.
    SVD,
}
impl LinearRegressionInput<f64> {
    /// Creates a regression input from a design matrix `x` (one row per
    /// observation, features only — `fit` adds the intercept column) and a
    /// response vector `y` with one entry per row of `x`.
    #[must_use]
    pub fn new(x: DMatrix<f64>, y: DVector<f64>) -> Self {
        Self { x, y }
    }

    /// Fits the ordinary-least-squares model `y ≈ b0 + x · b` with the chosen `method`.
    ///
    /// A column of ones is prepended to `x`, so the returned `coefficients`
    /// start with the intercept, and `intercept` equals `coefficients[0]`.
    ///
    /// # Errors
    ///
    /// - `MatrixInversionFailed`: for `Decomposition::None`, `XᵀX` could not be
    ///   inverted. For `Decomposition::QR`, either there are fewer observations
    ///   than parameters (features + 1) or `R` has a zero on its diagonal.
    /// - `SvdDecompositionFailed` / `SvdDecompositionFailedOnU`: the SVD did not
    ///   produce `Vᵀ` / `U`.
    ///
    /// # Panics
    ///
    /// nalgebra's matrix products panic on a dimension mismatch, i.e. when
    /// `y.len() != x.nrows()`.
    pub fn fit(
        &self,
        method: Decomposition,
    ) -> Result<LinearRegressionOutput<f64>, RustQuantError> {
        // Prepend a column of ones so that coefficient 0 is the intercept.
        let x = self.x.clone().insert_column(0, 1.);
        let y = &self.y;

        let coefficients = match method {
            Decomposition::None => Self::solve_normal_equations(x, y)?,
            Decomposition::QR => Self::solve_qr(x, y)?,
            Decomposition::SVD => Self::solve_svd(x, y)?,
        };

        // `coefficients` has one entry per column of `x` (>= 1 thanks to the
        // intercept column), so index 0 always exists.
        Ok(LinearRegressionOutput {
            intercept: coefficients[0],
            coefficients,
        })
    }

    /// Solves the normal equations `(XᵀX) β = Xᵀy` by explicitly inverting `XᵀX`.
    ///
    /// Fast for small problems, but forming `XᵀX` squares the condition number of `X`.
    fn solve_normal_equations(
        x: DMatrix<f64>,
        y: &DVector<f64>,
    ) -> Result<DVector<f64>, RustQuantError> {
        let x_t = x.transpose();
        let x_t_x = x_t.clone() * x;
        let x_t_x_inv = x_t_x
            .try_inverse()
            .ok_or(RustQuantError::MatrixInversionFailed)?;
        let x_t_y = x_t * y;
        Ok(x_t_x_inv * x_t_y)
    }

    /// Solves the least-squares problem through a thin QR factorisation `X = QR`.
    ///
    /// Uses back-substitution on `R β = Qᵀy` instead of forming `R⁻¹`.
    fn solve_qr(x: DMatrix<f64>, y: &DVector<f64>) -> Result<DVector<f64>, RustQuantError> {
        // With fewer observations than parameters the thin `R` is not square and
        // the system is under-determined. Report it instead of letting nalgebra
        // hit its "non-square matrix" assertion.
        if x.nrows() < x.ncols() {
            return Err(RustQuantError::MatrixInversionFailed);
        }

        let qr = x.qr();
        let q_t_y = qr.q().transpose() * y;
        qr.r()
            .solve_upper_triangular(&q_t_y)
            .ok_or(RustQuantError::MatrixInversionFailed)
    }

    /// Solves the least-squares problem with the Moore–Penrose pseudo-inverse
    /// `X⁺ = V Σ⁺ Uᵀ`.
    ///
    /// Returns the minimum-norm solution when `X` is rank-deficient.
    fn solve_svd(x: DMatrix<f64>, y: &DVector<f64>) -> Result<DVector<f64>, RustQuantError> {
        // Matrix dimensions are far below 2^53, so this conversion is exact.
        #[allow(clippy::cast_precision_loss)]
        let scale = x.nrows().max(x.ncols()) as f64;

        let svd = x.svd(true, true);
        let v = svd
            .v_t
            .ok_or(RustQuantError::SvdDecompositionFailed)?
            .transpose();

        // Singular values at or below this cutoff are numerically zero. Inverting
        // them would inject inf/NaN, so Σ⁺ maps them to 0 instead (same default
        // tolerance as NumPy's `matrix_rank`).
        let sigma_max = svd
            .singular_values
            .iter()
            .copied()
            .fold(0.0_f64, f64::max);
        let cutoff = scale * f64::EPSILON * sigma_max;
        let s_inv = DMatrix::from_diagonal(
            &svd
                .singular_values
                .map(|sigma| if sigma > cutoff { 1.0 / sigma } else { 0.0 }),
        );

        let u = svd.u.ok_or(RustQuantError::SvdDecompositionFailedOnU)?;
        let pseudo_inverse = v * s_inv * u.transpose();
        Ok(&pseudo_inverse * y)
    }
}
impl LinearRegressionOutput<f64> {
pub fn predict(&self, input: DMatrix<f64>) -> Result<DVector<f64>, RustQuantError> {
let intercept = DVector::from_element(input.nrows(), self.intercept);
let coefficients = self.coefficients.clone().remove_row(0);
let predictions = input * coefficients + intercept;
Ok(predictions)
}
}
#[cfg(test)]
mod tests_linear_regression {
    use super::*;
    use std::time::Instant;
    use RustQuant_utils::assert_approx_equal;

    /// Tolerance for the reference values below, which were produced by the
    /// normal-equations (`Decomposition::None`) path.
    const EPS: f64 = f64::EPSILON;

    /// Tolerance for comparing the three solvers with each other.
    ///
    /// They use different floating-point operations, so they agree only up to rounding.
    const CROSS_METHOD_TOL: f64 = 1e-8;

    #[test]
    fn test_linear_regression() -> Result<(), RustQuantError> {
        #[rustfmt::skip]
        let x_train = DMatrix::from_row_slice(
            4, 3, &[-0.083_784_355, -0.633_485_70, -0.399_266_60,
                    -0.982_943_745,  1.090_797_46, -0.468_123_05,
                    -1.875_067_321, -0.913_727_27,  0.326_962_08,
                    -0.186_144_661,  1.001_639_71, -0.412_746_90],
        );
        #[rustfmt::skip]
        let x_test = DMatrix::from_row_slice(
            4, 3, &[ 0.562_036_47, 0.595_846_45, -0.411_653_01,
                     0.663_358_26, 0.452_091_83, -0.294_327_15,
                    -0.602_897_28, 0.896_743_96,  1.218_573_96,
                     0.698_377_69, 0.572_216_51,  0.244_111_43],
        );
        let response =
            DVector::from_row_slice(&[-0.445_151_96, -1.847_803_64, -0.628_825_31, -0.861_080_69]);
        let input = LinearRegressionInput {
            x: x_train,
            y: response,
        };

        // Time each fit on its own, excluding any post-processing, so the three
        // timings are comparable.
        let start_none = Instant::now();
        let output = input.fit(Decomposition::None)?;
        let elapsed_none = start_none.elapsed();

        let start_qr = Instant::now();
        let output_qr = input.fit(Decomposition::QR)?;
        let elapsed_qr = start_qr.elapsed();

        let start_svd = Instant::now();
        let output_svd = input.fit(Decomposition::SVD)?;
        let elapsed_svd = start_svd.elapsed();

        println!("None: time {:?}, Coefs: {:?}\n", elapsed_none, output.coefficients);
        println!("QR: time {:?}, Coefs: {:?}\n", elapsed_qr, output_qr.coefficients);
        println!(
            "SVD: time {:?}, Coefs: {:?}\n",
            elapsed_svd, output_svd.coefficients
        );

        let preds = output.predict(x_test)?;

        assert_approx_equal!(output.intercept, 0.453_267_356_085_818_9, EPS);

        for (i, coefficient) in output.coefficients.iter().enumerate() {
            assert_approx_equal!(
                coefficient,
                &[
                    0.453_267_356_085_818_9,
                    1.059_866_132_317_468_5,
                    -0.169_093_464_601_759_45,
                    2.296_053_332_765_449_5
                ][i],
                EPS
            );
        }

        for (i, pred) in preds.iter().enumerate() {
            assert_approx_equal!(
                pred,
                &[
                    0.003_019_769_611_493_972,
                    0.404_101_701_919_158_5,
                    2.460_554_206_769_587_4,
                    1.657_189_007_522_339_4
                ][i],
                EPS
            );
        }

        // QR and SVD must reproduce the normal-equations fit up to rounding.
        for other in [&output_qr, &output_svd] {
            assert!(
                (other.intercept - output.intercept).abs() < CROSS_METHOD_TOL,
                "intercept mismatch: {} vs {}",
                other.intercept,
                output.intercept
            );
            assert_eq!(other.coefficients.len(), output.coefficients.len());
            for (a, b) in other.coefficients.iter().zip(output.coefficients.iter()) {
                assert!(
                    (a - b).abs() < CROSS_METHOD_TOL,
                    "coefficient mismatch: {} vs {}",
                    a,
                    b
                );
            }
        }

        Ok(())
    }
}