use crate::{
decision_tree::{self, RegressorModel},
trainer_builders::*,
FloatTarget, Trainset,
};
use serde::{Deserialize, Serialize};
/// A fitted decision-tree regressor.
///
/// This is a thin wrapper around [`RegressorModel`]. Build one with
/// [`Regressor::trainer`] followed by [`Trainer::train`]. It derives serde's
/// `Serialize`/`Deserialize`, so a fitted model can be persisted and reloaded.
#[derive(Debug, Clone, Default, PartialEq)]
#[derive(Serialize, Deserialize)]
pub struct Regressor {
    /// The underlying model that every prediction is delegated to.
    regressor: RegressorModel,
}
/// Trainer that fits a [`Regressor`] from a decision-tree training configuration.
///
/// The configuration is public and can be edited directly. Shared setter
/// methods presumably come from `CommonTrainerBuilder` (TODO confirm in
/// `trainer_builders`).
#[derive(Clone, Debug, Default, PartialEq)]
pub struct Trainer {
    /// Hyperparameters forwarded to `RegressorModel::train`.
    pub config: decision_tree::TrainConfig,
}
impl TrainConfigProvider for Trainer {
    /// Hands out mutable access to this trainer's configuration.
    ///
    /// NOTE(review): presumably this is the hook that `CommonTrainerBuilder`'s
    /// default methods use to modify settings — confirm in `trainer_builders`.
    fn train_config(&mut self) -> &mut decision_tree::TrainConfig {
        let Self { config } = self;
        config
    }
}
// Opts `Trainer` into the shared builder methods from `trainer_builders`.
// The impl is empty: every method uses the trait's defaults, which presumably
// reach the config through the `TrainConfigProvider` impl (TODO confirm).
impl CommonTrainerBuilder for Trainer {}
impl Trainer {
    /// Fits a new [`Regressor`] on the given samples using `self.config`.
    ///
    /// * `data` – flattened feature values for all samples. The exact layout
    ///   expected by `Trainset::with_transposed` (row- vs. column-major) should be
    ///   confirmed there — NOTE(review).
    /// * `targets` – regression targets, presumably one per sample.
    ///
    /// The trainer is only borrowed, so one configuration can be reused for
    /// several fits.
    pub fn train(&self, data: &[f32], targets: &[FloatTarget]) -> Regressor {
        // `targets` is already a slice reference; pass it straight through,
        // just like `data`, instead of re-borrowing it into a `&&[_]`.
        let trainset = Trainset::with_transposed(data, targets);
        Regressor {
            regressor: RegressorModel::train(&trainset, &self.config),
        }
    }
}
impl Regressor {
pub fn predict(&self, dataset: &[f32]) -> Vec<FloatTarget> {
self.regressor.predict(dataset)
}
pub fn predict_one(&self, sample: &[f32]) -> FloatTarget {
self.regressor.predict_one(sample)
}
pub fn trainer() -> Trainer {
Trainer::default()
}
pub fn num_features(&self) -> usize {
self.regressor.num_features()
}
}