use learning::error::{Error, ErrorKind};
use linalg::{Matrix, Vector, Axes, BaseMatrix, BaseMatrixMut};
use super::{Invertible, Transformer};
use rulinalg::utils;
use libnum::{Float, FromPrimitive};
/// A transformer that standardizes each feature column to a target
/// mean and standard deviation (defaults: mean 0, standard deviation 1).
///
/// Per-column means and variances are learned by `fit` and reused by
/// `transform` and `inv_transform`.
#[derive(Debug)]
pub struct Standardizer<T: Float> {
// Per-column means learned from the fitted data; `None` until fitted.
means: Option<Vector<T>>,
// Per-column variances learned from the fitted data; `None` until fitted.
variances: Option<Vector<T>>,
// Target mean of the transformed output.
scaled_mean: T,
// Target standard deviation of the transformed output.
scaled_stdev: T,
}
impl<T: Float> Default for Standardizer<T> {
    /// Builds a `Standardizer` with the default target distribution:
    /// mean `0` and standard deviation `1`, with no fitted statistics.
    fn default() -> Standardizer<T> {
        Standardizer {
            scaled_mean: T::zero(),
            scaled_stdev: T::one(),
            means: None,
            variances: None,
        }
    }
}
impl<T: Float> Standardizer<T> {
    /// Constructs a new `Standardizer` whose transformed output will
    /// have the given mean and standard deviation.
    ///
    /// The per-column statistics remain unset until `fit` is called.
    pub fn new(mean: T, stdev: T) -> Standardizer<T> {
        Standardizer {
            scaled_mean: mean,
            scaled_stdev: stdev,
            means: None,
            variances: None,
        }
    }
}
impl<T: Float + FromPrimitive> Transformer<Matrix<T>> for Standardizer<T> {
/// Learns the per-column means and variances of `inputs` and stores
/// them on `self` for later use by `transform`.
///
/// # Errors
/// - `InvalidData` if `inputs` has fewer than two rows (variance is
///   undefined for a single sample), if the variance computation
///   fails, or if any column mean is non-finite (NaN / infinity in
///   the data).
fn fit(&mut self, inputs: &Matrix<T>) -> Result<(), Error> {
// Need at least two rows for a meaningful variance.
// NOTE(review): the message says "one row" but this branch also
// covers zero rows.
if inputs.rows() <= 1 {
Err(Error::new(ErrorKind::InvalidData,
"Cannot standardize data with only one row."))
} else {
// `Axes::Row` aggregates across rows, yielding one entry per
// column (confirmed by the `means.size() != inputs.cols()`
// check in `transform`).
let mean = inputs.mean(Axes::Row);
let variance = try!(inputs.variance(Axes::Row).map_err(|_| {
Error::new(ErrorKind::InvalidData, "Cannot compute variance of data.")
}));
// A NaN or infinite data point surfaces as a non-finite mean.
// NOTE(review): variances are not checked for finiteness here.
if mean.data().iter().any(|x| !x.is_finite()) {
return Err(Error::new(ErrorKind::InvalidData, "Some data point is non-finite."));
}
self.means = Some(mean);
self.variances = Some(variance);
Ok(())
}
}
/// Standardizes `inputs` in place (consuming and returning the
/// matrix): each column is centered by the fitted mean, divided by
/// the fitted standard deviation, then rescaled to the configured
/// `scaled_stdev` and shifted by `scaled_mean`.
///
/// If the transformer has not been fitted yet, it is fitted on
/// `inputs` first.
///
/// # Errors
/// - any error from `fit` when auto-fitting;
/// - `InvalidData` if the column count differs from the fitted data;
/// - `InvalidState` if only one of means/variances is set.
fn transform(&mut self, mut inputs: Matrix<T>) -> Result<Matrix<T>, Error> {
// Lazily fit on first use.
if let (&None, &None) = (&self.means, &self.variances) {
try!(self.fit(&inputs));
}
if let (&Some(ref means), &Some(ref variances)) = (&self.means, &self.variances) {
if means.size() != inputs.cols() {
Err(Error::new(ErrorKind::InvalidData,
"Input data has different number of columns from fitted data."))
} else {
for row in inputs.iter_rows_mut() {
// Center: x := x - mean (element-wise per column).
utils::in_place_vec_bin_op(row, means.data(), |x, &y| *x = *x - y);
// Scale and shift: x := x * scaled_stdev / sqrt(var) + scaled_mean.
// NOTE(review): a zero variance yields a division by zero
// here (inf/NaN output) — confirm whether callers rely on
// that or whether it should be rejected in `fit`.
utils::in_place_vec_bin_op(row, variances.data(), |x, &y| {
*x = (*x * self.scaled_stdev / y.sqrt()) + self.scaled_mean
});
}
Ok(inputs)
}
} else {
// Only reachable if exactly one of means/variances is `Some`.
Err(Error::new(ErrorKind::InvalidState, "Transformer has not been fitted."))
}
}
}
impl<T: Float + FromPrimitive> Invertible<Matrix<T>> for Standardizer<T> {
/// Undoes `transform`: maps standardized data back to the original
/// scale using the fitted per-column means and variances.
///
/// # Errors
/// - `InvalidData` if the column count differs from the fitted data;
/// - `InvalidState` if the transformer has not been fitted.
fn inv_transform(&self, mut inputs: Matrix<T>) -> Result<Matrix<T>, Error> {
if let (&Some(ref means), &Some(ref variances)) = (&self.means, &self.variances) {
let features = means.size();
if inputs.cols() != features {
return Err(Error::new(ErrorKind::InvalidData,
"Inputs have different feature count than transformer."));
}
for row in inputs.iter_rows_mut() {
// Exact inverse of the forward op:
// x := (x - scaled_mean) * sqrt(var) / scaled_stdev ...
utils::in_place_vec_bin_op(row, &variances.data(), |x, &y| {
*x = (*x - self.scaled_mean) * y.sqrt() / self.scaled_stdev
});
// ... then re-add the fitted column mean.
utils::in_place_vec_bin_op(row, &means.data(), |x, &y| *x = *x + y);
}
Ok(inputs)
} else {
Err(Error::new(ErrorKind::InvalidState, "Transformer has not been fitted."))
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use super::super::{Transformer, Invertible};
    use linalg::{Axes, Matrix};
    use std::f64;

    /// A single row cannot be standardized (variance undefined).
    #[test]
    fn single_row_test() {
        let inputs = Matrix::new(1, 2, vec![1.0, 2.0]);
        let mut standardizer = Standardizer::default();
        let res = standardizer.transform(inputs);
        assert!(res.is_err());
    }

    /// NaN input data is rejected during the implicit fit.
    #[test]
    fn nan_data_test() {
        let inputs = Matrix::new(2, 2, vec![f64::NAN; 4]);
        let mut standardizer = Standardizer::default();
        let res = standardizer.transform(inputs);
        assert!(res.is_err());
    }

    /// Infinite input data is rejected during the implicit fit.
    #[test]
    fn inf_data_test() {
        let inputs = Matrix::new(2, 2, vec![f64::INFINITY; 4]);
        let mut standardizer = Standardizer::default();
        let res = standardizer.transform(inputs);
        assert!(res.is_err());
    }

    /// Default standardization should produce zero mean and unit variance.
    #[test]
    fn basic_standardize_test() {
        let inputs = Matrix::new(2, 2, vec![-1.0f32, 2.0, 0.0, 3.0]);
        let mut standardizer = Standardizer::default();
        let transformed = standardizer.transform(inputs).unwrap();
        let new_mean = transformed.mean(Axes::Row);
        let new_var = transformed.variance(Axes::Row).unwrap();
        assert!(new_mean.data().iter().all(|x| x.abs() < 1e-5));
        // BUG FIX: take the absolute value of the *difference* — the old
        // `(x.abs() - 1.0) < 1e-5` passed for any variance below 1 (e.g. 0).
        assert!(new_var.data().iter().all(|x| (x - 1.0).abs() < 1e-5));
    }

    /// Custom target mean/stdev: output mean 1.0 and variance 2.0² = 4.0.
    #[test]
    fn custom_standardize_test() {
        let inputs = Matrix::new(2, 2, vec![-1.0f32, 2.0, 0.0, 3.0]);
        let mut standardizer = Standardizer::new(1.0, 2.0);
        let transformed = standardizer.transform(inputs).unwrap();
        let new_mean = transformed.mean(Axes::Row);
        let new_var = transformed.variance(Axes::Row).unwrap();
        // BUG FIX: same one-sided comparison as above — compare the
        // absolute difference from the expected value, not `x.abs() - c`.
        assert!(new_mean.data().iter().all(|x| (x - 1.0).abs() < 1e-5));
        assert!(new_var.data().iter().all(|x| (x - 4.0).abs() < 1e-5));
    }

    /// `inv_transform` after `transform` should recover the original data.
    #[test]
    fn inv_transform_identity_test() {
        let inputs = Matrix::new(2, 2, vec![-1.0f32, 2.0, 0.0, 3.0]);
        let mut standardizer = Standardizer::new(1.0, 3.0);
        let transformed = standardizer.transform(inputs.clone()).unwrap();
        let original = standardizer.inv_transform(transformed).unwrap();
        assert!((inputs - original).data().iter().all(|x| x.abs() < 1e-5));
    }
}