use crate::{
classify, decision_tree,
decision_tree::ClassifierModel,
ensemble_predictor,
ensemble_trainer::{self, EnsembleConfig},
trainer_builders::*,
ClassDecode, ClassTarget, ClassesMapping, MaxFeaturesPolicy, Trainset,
};
use serde::{Deserialize, Serialize};
/// A trained ensemble classifier.
///
/// Holds the per-member models produced by [`Trainer::train`] together with
/// the class mapping that was built from the training labels.
///
/// The derived `Default` value has an empty `ensemble`.
#[derive(Default, Clone, PartialEq, Debug, Serialize, Deserialize)]
pub struct Classifier {
/// Models of the ensemble members, in the order returned by training.
ensemble: Vec<ClassifierModel>,
/// Mapping created by `ClassesMapping::with_encode` from the training labels.
classes_map: ClassesMapping,
}
/// Trainer that produces a [`Classifier`] from data and labels.
///
/// All tunable settings live in the public `config` field.
#[derive(Clone, PartialEq, Debug)]
pub struct Trainer {
/// Ensemble-level configuration; its `tree_config_proto` field carries the
/// per-tree training settings.
pub config: EnsembleConfig,
}
impl Default for Trainer {
    /// Builds a trainer from `EnsembleConfig::default()`, overriding the
    /// per-tree `max_features` policy with `MaxFeaturesPolicy::SQRT`.
    fn default() -> Self {
        let mut trainer = Self {
            config: EnsembleConfig::default(),
        };
        trainer.config.tree_config_proto.max_features = MaxFeaturesPolicy::SQRT;
        trainer
    }
}
/// Private adapter handed to `ensemble_trainer::fit`.
///
/// Pairs a model slot with the number of classes needed by
/// `ClassifierModel::train`.
#[derive(Clone)]
struct Trainee {
/// The model; overwritten by each call to `fit`.
tree: ClassifierModel,
/// Number of classes, taken from the class mapping at training time.
num_classes: usize,
}
impl ensemble_trainer::Trainable<ClassTarget> for Trainee {
    /// Trains a new model on `ts` with the given `config` and stores it in
    /// `self.tree`, replacing whatever model was held before.
    fn fit(&mut self, ts: &Trainset<ClassTarget>, config: decision_tree::TrainConfig) {
        let fitted = ClassifierModel::train(ts, self.num_classes, &config);
        self.tree = fitted;
    }
}
// Lets `ClassifierModel` be used by `ensemble_predictor::predict` by
// implementing its `Predictor` trait.
impl ensemble_predictor::Predictor for ClassifierModel {
// Forwards to `predict` on the model itself.
// NOTE(review): this call is expected to resolve to an inherent
// `ClassifierModel::predict` (inherent methods take precedence over trait
// methods during method lookup). That method is defined outside this file
// view; confirm it exists, otherwise this call would recurse into itself.
fn predict(&self, dataset: &[f32]) -> Vec<f32> {
self.predict(dataset)
}
}
impl Trainer {
pub fn train(&self, data: &[f32], labels: &[i64]) -> Classifier {
let (classes_map, labels_enc) = ClassesMapping::with_encode(labels);
let proto = Trainee {
tree: ClassifierModel::default(),
num_classes: classes_map.num_classes(),
};
let trainset = Trainset::with_transposed(data, &labels_enc);
let ens = ensemble_trainer::fit(proto, &trainset, &self.config);
Classifier {
ensemble: ens.into_iter().map(|t| t.tree).collect(),
classes_map,
}
}
}
impl Classifier {
    /// Predicts class labels for `dataset`.
    ///
    /// The output of [`Classifier::proba`] is converted to labels by
    /// `classify`, using the class mapping stored at training time.
    /// `num_threads` is forwarded unchanged to the ensemble predictor.
    pub fn predict(&self, dataset: &[f32], num_threads: usize) -> Vec<i64> {
        classify(&self.proba(dataset, num_threads), &self.classes_map)
    }

    /// Predicts the class label of a single `sample`, using one thread.
    ///
    /// # Panics
    ///
    /// Panics if `classify` returns no labels for `sample`, because the first
    /// element of its result is indexed directly.
    pub fn predict_one(&self, sample: &[f32]) -> i64 {
        classify(&self.proba(sample, 1), &self.classes_map)[0]
    }

    /// Returns the raw output of `ensemble_predictor::predict` for `dataset`.
    ///
    /// NOTE(review): presumably per-class probabilities for each sample; the
    /// exact layout is defined by `ensemble_predictor` — confirm there.
    pub fn proba(&self, dataset: &[f32], num_threads: usize) -> Vec<f32> {
        ensemble_predictor::predict(&self.ensemble, dataset, num_threads)
    }

    /// Returns the number of features reported by the first ensemble member.
    ///
    /// Returns `0` when the ensemble is empty (for example for
    /// `Classifier::default()`), instead of panicking on an out-of-bounds
    /// index.
    pub fn num_features(&self) -> usize {
        self.ensemble
            .first()
            .map_or(0, |model| model.num_features())
    }

    /// Convenience constructor for a [`Trainer`] with default settings.
    pub fn trainer() -> Trainer {
        Trainer::default()
    }
}
impl ClassDecode for Classifier {
    /// Exposes the decode table of the stored class mapping.
    fn get_decode_table(&self) -> &[i64] {
        let mapping = &self.classes_map;
        mapping.get_decode_table()
    }
}
impl TrainConfigProvider for Trainer {
    /// Gives mutable access to the per-tree training configuration nested
    /// inside the ensemble configuration.
    fn train_config(&mut self) -> &mut decision_tree::TrainConfig {
        let Self { config } = self;
        &mut config.tree_config_proto
    }
}
impl EnsembleConfigProvider for Trainer {
    /// Gives mutable access to the ensemble-level configuration.
    fn ensemble_config(&mut self) -> &mut EnsembleConfig {
        let Self { config } = self;
        config
    }
}
// Empty impl: `Trainer` opts into `CommonTrainerBuilder` and uses only the
// items that trait supplies by default.
// NOTE(review): presumably built on `TrainConfigProvider::train_config`;
// confirm in `trainer_builders`.
impl CommonTrainerBuilder for Trainer {}
// Empty impl: `Trainer` opts into `EnsembleTrainerBuilder` and uses only the
// items that trait supplies by default.
// NOTE(review): presumably built on `EnsembleConfigProvider::ensemble_config`;
// confirm in `trainer_builders`.
impl EnsembleTrainerBuilder for Trainer {}