use linalg::{Matrix, BaseMatrix};
use linalg::Vector;
use learning::{LearningResult, SupModel};
use learning::toolkit::cost_fn::CostFunc;
use learning::toolkit::cost_fn::MeanSqError;
use learning::optim::grad_desc::GradientDesc;
use learning::optim::{OptimAlgorithm, Optimizable};
use learning::error::Error;
/// Linear regression model.
///
/// Fits a linear function (with an intercept term) mapping input rows to
/// target values. Until trained, `parameters` is `None`.
#[derive(Debug)]
pub struct LinRegressor {
// Learned coefficients: intercept first, then one weight per input column.
// `None` until `train` or `train_with_optimization` has been called.
parameters: Option<Vector<f64>>,
}
impl Default for LinRegressor {
    /// Constructs an untrained regressor (no parameters yet).
    fn default() -> LinRegressor {
        LinRegressor {
            parameters: None,
        }
    }
}
impl LinRegressor {
    /// Returns the learned parameter vector, or `None` if the model
    /// has not been trained yet.
    ///
    /// The first entry is the intercept; the remaining entries are the
    /// per-feature weights (matching the ones-column prepended in `train`).
    pub fn parameters(&self) -> Option<&Vector<f64>> {
        match self.parameters {
            Some(ref p) => Some(p),
            None => None,
        }
    }
}
impl SupModel<Matrix<f64>, Vector<f64>> for LinRegressor {
    /// Trains the model by solving the normal equations
    /// (XᵀX)β = Xᵀy, where X is `inputs` with a prepended column of
    /// ones (so an intercept is learned).
    ///
    /// # Errors
    ///
    /// Returns an error when the linear system cannot be solved
    /// (e.g. XᵀX is singular). Previously this path called `expect`,
    /// which panicked inside a `Result`-returning API; the solver
    /// failure is now propagated to the caller instead.
    fn train(&mut self, inputs: &Matrix<f64>, targets: &Vector<f64>) -> LearningResult<()> {
        let ones = Matrix::<f64>::ones(inputs.rows(), 1);
        // Design matrix: intercept column followed by the raw features.
        let full_inputs = ones.hcat(inputs);
        let xt = full_inputs.transpose();
        // NOTE(review): relies on the linalg error converting into
        // `learning::error::Error` via `From` — confirm that impl exists.
        self.parameters = Some((&xt * full_inputs).solve(&xt * targets)?);
        Ok(())
    }

    /// Predicts target values for `inputs` using the trained parameters.
    ///
    /// # Errors
    ///
    /// Returns an "untrained" error if `train` has not been run.
    fn predict(&self, inputs: &Matrix<f64>) -> LearningResult<Vector<f64>> {
        if let Some(ref v) = self.parameters {
            // Prepend a ones column so the intercept parameter applies.
            let ones = Matrix::<f64>::ones(inputs.rows(), 1);
            let full_inputs = ones.hcat(inputs);
            Ok(full_inputs * v)
        } else {
            Err(Error::new_untrained())
        }
    }
}
impl Optimizable for LinRegressor {
    type Inputs = Matrix<f64>;
    type Targets = Vector<f64>;

    /// Computes the mean-squared-error cost at `params` and its gradient,
    /// i.e. Xᵀ(ŷ − y) / n, for use by a gradient-based optimizer.
    fn compute_grad(&self,
                    params: &[f64],
                    inputs: &Matrix<f64>,
                    targets: &Vector<f64>)
                    -> (f64, Vec<f64>) {
        // Current parameter estimate as a vector.
        let weights = Vector::new(params.to_vec());
        // Predictions under the current weights.
        let predictions = inputs * weights;
        let cost = MeanSqError::cost(&predictions, targets);
        // Average the gradient over the number of rows.
        let residuals = predictions - targets;
        let scale = inputs.rows() as f64;
        let grad = (inputs.transpose() * residuals) / scale;
        (cost, grad.into_vec())
    }
}
impl LinRegressor {
    /// Trains the model by gradient descent on the mean squared error,
    /// as an alternative to solving the normal equations directly.
    ///
    /// A ones column is prepended to `inputs` so an intercept is learned,
    /// and the search starts from an all-zero parameter vector.
    pub fn train_with_optimization(&mut self, inputs: &Matrix<f64>, targets: &Vector<f64>) {
        // Design matrix: intercept column followed by the raw features.
        let full_inputs = Matrix::<f64>::ones(inputs.rows(), 1).hcat(inputs);
        // Start the optimizer at the origin.
        let start = vec![0f64; full_inputs.cols()];
        let optimizer = GradientDesc::default();
        let learned = optimizer.optimize(self, &start[..], &full_inputs, targets);
        self.parameters = Some(Vector::new(learned));
    }
}