use crate::{
decision_tree::{self, RegressorModel},
ensemble_predictor,
ensemble_trainer::{self, EnsembleConfig},
trainer_builders::*,
FloatTarget, Trainset,
};
use serde::{Deserialize, Serialize};
/// A trained regression ensemble.
///
/// Stores the individual [`RegressorModel`]s that make up the ensemble.
/// The type is serializable via serde, so a trained model can be persisted
/// and restored.
#[derive(Clone, Debug, Default, PartialEq, Deserialize, Serialize)]
pub struct Regressor {
    /// Member models handed to `ensemble_predictor::predict` at prediction time.
    ensemble: Vec<RegressorModel>,
}
/// Configurable trainer that produces a [`Regressor`].
#[derive(Debug, Clone, PartialEq)]
pub struct Trainer {
    /// Ensemble-level settings; also holds the per-tree configuration
    /// (`tree_config_proto`) exposed through the builder traits.
    config: EnsembleConfig,
}
impl Default for Trainer {
    /// Builds a trainer whose configuration is `EnsembleConfig::default()`.
    fn default() -> Self {
        let config = EnsembleConfig::default();
        Self { config }
    }
}
/// Private adapter that lets a single [`RegressorModel`] be fitted through
/// the `ensemble_trainer::Trainable` interface.
#[derive(Default, Clone)]
struct Trainee {
    /// The model produced by the most recent `fit` call (default-constructed
    /// until then).
    tree: RegressorModel,
}
impl ensemble_trainer::Trainable<FloatTarget> for Trainee {
    /// Trains a new model on `ts` using `config` and stores it in `self`,
    /// replacing whatever model was held before.
    fn fit(&mut self, ts: &Trainset<FloatTarget>, config: decision_tree::TrainConfig) {
        let fitted = RegressorModel::train(ts, &config);
        self.tree = fitted;
    }
}
// Adapts a single `RegressorModel` to the `ensemble_predictor::Predictor`
// interface so that a slice of models can be passed to
// `ensemble_predictor::predict`.
impl ensemble_predictor::Predictor for RegressorModel {
/// Forwards `dataset` to the model's own `predict` and returns the result
/// unchanged.
fn predict(&self, dataset: &[f32]) -> Vec<f32> {
// NOTE(review): this call is only correct if `RegressorModel` has an
// inherent `predict(&self, &[f32]) -> Vec<f32>` (defined outside this
// file — confirm). Inherent methods take precedence over trait methods
// during method-call resolution; without one, this line would resolve to
// the trait method being defined here and recurse forever.
self.predict(dataset)
}
}
impl Trainer {
pub fn train(&self, data: &[f32], targets: &[FloatTarget]) -> Regressor {
let trainset = Trainset::with_transposed(data, targets);
let trainee = Trainee::default();
let ens = ensemble_trainer::fit(trainee, &trainset, &self.config);
Regressor {
ensemble: ens.into_iter().map(|t| t.tree).collect(),
}
}
}
impl Regressor {
    /// Runs the whole ensemble over `dataset` and returns the predictions.
    ///
    /// `num_threads` is passed straight through to
    /// `ensemble_predictor::predict`.
    pub fn predict(&self, dataset: &[f32], num_threads: usize) -> Vec<FloatTarget> {
        let models = &self.ensemble;
        ensemble_predictor::predict(models, dataset, num_threads)
    }

    /// Predicts the target for a single `sample`.
    ///
    /// Calls `ensemble_predictor::predict` with one thread and returns the
    /// first element of its output.
    ///
    /// # Panics
    ///
    /// Panics if the predictor returns an empty result.
    pub fn predict_one(&self, sample: &[f32]) -> FloatTarget {
        let predictions = ensemble_predictor::predict(&self.ensemble, sample, 1);
        predictions[0]
    }

    /// Returns a [`Trainer`] with the default configuration, as the entry
    /// point for building and training a regressor.
    pub fn trainer() -> Trainer {
        let fresh = Trainer::default();
        fresh
    }
}
impl TrainConfigProvider for Trainer {
    /// Gives mutable access to the per-tree training configuration stored
    /// inside the ensemble configuration (`tree_config_proto`).
    fn train_config(&mut self) -> &mut decision_tree::TrainConfig {
        let ensemble_cfg = &mut self.config;
        &mut ensemble_cfg.tree_config_proto
    }
}
impl EnsembleConfigProvider for Trainer {
    /// Gives mutable access to the trainer's ensemble configuration.
    fn ensemble_config(&mut self) -> &mut EnsembleConfig {
        let Self { config } = self;
        config
    }
}
// Empty impl: `Trainer` takes every item of `CommonTrainerBuilder` from the
// trait's provided defaults.
// NOTE(review): presumably these are builder-style setters layered on
// `TrainConfigProvider::train_config` — confirm in `trainer_builders`.
impl CommonTrainerBuilder for Trainer {}
// Empty impl: `Trainer` takes every item of `EnsembleTrainerBuilder` from the
// trait's provided defaults.
// NOTE(review): presumably these are builder-style setters layered on
// `EnsembleConfigProvider::ensemble_config` — confirm in `trainer_builders`.
impl EnsembleTrainerBuilder for Trainer {}