use crate::{
error::Error,
geometry::{helper::*, Axis, Point},
models::{variance, Model},
};
use core::iter::Sum;
use num_traits::{cast::FromPrimitive, Float};
#[derive(Debug, Clone)]
/// Training data container that is handed to a `Model` via `Trainer::train`.
///
/// `F` is the numeric type of the stored values; the methods in this file
/// require it to implement `num_traits::Float`.
pub struct Trainer<F> {
    /// Model inputs. `preprocess` and `with_points` fill this with the point
    /// coordinates along `axis`, taken after the points have been sorted.
    train_x: Vec<F>,
    /// Model targets. `preprocess` and `with_points` fill this with the
    /// positions `0, 1, 2, ...` of the sorted points, converted to `F`.
    train_y: Vec<F>,
    /// Axis the inputs in `train_x` were taken from (`Axis::X` by default).
    axis: Axis,
}
impl<F> Default for Trainer<F> {
fn default() -> Self {
Self {
train_x: Vec::<F>::new(),
train_y: Vec::<F>::new(),
axis: Axis::X,
}
}
}
impl<F> Trainer<F>
where
    F: Float + Sized,
{
    /// Creates an empty trainer; identical to `Trainer::default()`.
    pub fn new() -> Self {
        Default::default()
    }

    /// Borrows the stored model inputs.
    pub fn train_x(&self) -> &Vec<F> {
        &self.train_x
    }

    /// Borrows the stored model targets.
    pub fn train_y(&self) -> &Vec<F> {
        &self.train_y
    }

    /// Borrows the axis currently recorded on the trainer.
    pub fn axis(&self) -> &Axis {
        &self.axis
    }

    /// Replaces the stored model inputs with `xs`.
    pub fn set_train_x(&mut self, xs: Vec<F>) {
        self.train_x = xs;
    }

    /// Replaces the stored model targets with `ys`.
    pub fn set_train_y(&mut self, ys: Vec<F>) {
        self.train_y = ys;
    }

    /// Replaces the recorded axis with `axis`.
    pub fn set_axis(&mut self, axis: Axis) {
        self.axis = axis;
    }

    /// Fits `model` on the stored inputs and targets.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `Model::fit`.
    pub fn train<'a, M: Model<F = F> + 'a>(&self, model: &'a mut M) -> Result<(), Error> {
        let (inputs, targets) = (&self.train_x, &self.train_y);
        model.fit(inputs, targets)?;
        Ok(())
    }
}
impl<F> Trainer<F>
where
F: Float + Sum + FromPrimitive,
{
pub fn with_data(xs: Vec<F>, ys: Vec<F>) -> Result<(Self, Vec<Point<F>>), Error> {
assert_empty!(xs);
assert_eq_len!(xs, ys);
let mut trainer = Trainer::new();
let data = trainer.preprocess(xs, ys)?;
Ok((trainer, data))
}
pub fn preprocess(&mut self, xs: Vec<F>, ys: Vec<F>) -> Result<Vec<Point<F>>, Error> {
assert_empty!(xs);
assert_eq_len!(xs, ys);
let mut ps: Vec<Point<F>> = xs
.iter()
.zip(ys.iter())
.map(|(&x, &y)| Point { x, y })
.collect();
if variance(&xs) > variance(&ys) {
sort_by_x(&mut ps);
self.set_axis(Axis::X);
self.set_train_x(extract_x(&ps));
} else {
sort_by_y(&mut ps);
self.set_axis(Axis::Y);
self.set_train_x(extract_y(&ps));
};
let train_y: Vec<F> = (0..ps.len()).map(|id| F::from_usize(id).unwrap()).collect();
self.set_train_y(train_y);
Ok(ps)
}
pub fn with_points(ps: &mut [Point<F>]) -> Result<Self, Error> {
let px: Vec<F> = extract_x(ps);
let py: Vec<F> = extract_y(ps);
assert_eq_len!(px, py);
let x_variance = variance(&px);
let y_variance = variance(&py);
let (axis, train_x) = if x_variance > y_variance {
sort_by_x(ps);
(Axis::X, extract_x(ps))
} else {
sort_by_y(ps);
(Axis::Y, extract_y(ps))
};
let train_y: Vec<F> = (0..ps.len()).map(|id| F::from_usize(id).unwrap()).collect();
Ok(Self {
train_x,
train_y,
axis,
})
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Turns `(x, y)` pairs into points, preserving their order.
    fn points_from(coords: &[(f64, f64)]) -> Vec<Point<f64>> {
        coords.iter().map(|&(x, y)| Point { x, y }).collect()
    }

    /// Fixture shared by both tests: five points that are not sorted by `x`.
    fn unsorted_points() -> Vec<Point<f64>> {
        points_from(&[(1., 1.), (3., 1.), (2., 1.), (3., 2.), (5., 1.)])
    }

    /// `sort_by_x` orders the fixture by ascending `x`.
    #[test]
    fn sort_by() {
        let mut points = unsorted_points();
        let expected = points_from(&[(1., 1.), (2., 1.), (3., 1.), (3., 2.), (5., 1.)]);

        sort_by_x(&mut points);

        assert_eq!(expected, points);
    }

    /// `with_points` stores the sorted x coordinates as inputs and the
    /// positions of the sorted points as targets.
    #[test]
    fn train() {
        let mut points = unsorted_points();

        let trainer = Trainer::with_points(&mut points).unwrap();

        let expected_x = vec![1., 2., 3., 3., 5.];
        let expected_y = vec![0., 1., 2., 3., 4.];
        assert_eq!(&expected_x, trainer.train_x());
        assert_eq!(&expected_y, trainer.train_y());
    }
}