use crate::{
classify,
decision_tree::{self, ClassifierModel},
trainer_builders::*,
ClassDecode, ClassesMapping, Trainset,
};
use argminmax::ArgMinMax;
use serde::{Deserialize, Serialize};
/// A trained decision-tree classifier that predicts the caller's original `i64` labels.
///
/// The tree model is trained on the encoded labels produced by
/// `ClassesMapping::with_encode`. The mapping is stored next to the model so that
/// predicted class indices can be decoded back into the original label values.
/// Both fields are (de)serialized by serde, so a trained classifier can be persisted.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct Classifier {
    /// The underlying decision-tree model, trained on encoded class labels.
    classifier: ClassifierModel,
    /// Mapping between the original `i64` labels and the model's class indices.
    classes_map: ClassesMapping,
}
/// Trainer that produces a [`Classifier`].
///
/// Holds the decision-tree training configuration used by `Trainer::train`.
/// `Default` yields a trainer with the default `decision_tree::TrainConfig`.
/// NOTE(review): configuration is presumably adjusted through the builder methods of
/// `CommonTrainerBuilder` (via `TrainConfigProvider`); confirm in `trainer_builders`.
#[derive(Clone, PartialEq, Debug, Default)]
pub struct Trainer {
    /// Decision-tree hyper-parameters passed to `ClassifierModel::train`.
    config: decision_tree::TrainConfig,
}
/// Exposes this trainer's decision-tree configuration through `TrainConfigProvider`,
/// returning a mutable reference so callers can modify it in place.
impl TrainConfigProvider for Trainer {
    fn train_config(&mut self) -> &mut decision_tree::TrainConfig {
        &mut self.config
    }
}
// Opts `Trainer` into `CommonTrainerBuilder` without overriding anything, so it only
// gets the trait's provided methods.
// NOTE(review): those methods presumably work through `TrainConfigProvider::train_config`;
// confirm in `trainer_builders`.
impl CommonTrainerBuilder for Trainer {}
impl Trainer {
    /// Trains a [`Classifier`] on `data` with class `labels`, using this trainer's config.
    ///
    /// The labels are first encoded with `ClassesMapping::with_encode`. The tree is then
    /// trained on the encoded labels, and the mapping is kept inside the returned classifier
    /// so that its predictions can be decoded back into the original `i64` labels.
    ///
    /// NOTE(review): `data` is passed unchanged to `Trainset::with_transposed`. Its expected
    /// layout (presumably row-major, one row per label) is defined there; confirm.
    pub fn train(&self, data: &[f32], labels: &[i64]) -> Classifier {
        let (mapping, codes) = ClassesMapping::with_encode(labels);
        let trainset = Trainset::with_transposed(data, &codes);
        let n_classes = mapping.num_classes();
        let model = ClassifierModel::train(&trainset, n_classes, &self.config);
        Classifier {
            classifier: model,
            classes_map: mapping,
        }
    }
}
impl Classifier {
    /// Predicts a label for each sample in `dataset`.
    ///
    /// Computes the model scores with [`Classifier::proba`] and passes them, together with
    /// the label mapping, to the crate's `classify` helper, which returns the `i64` labels.
    pub fn predict(&self, dataset: &[f32]) -> Vec<i64> {
        let scores = self.proba(dataset);
        classify(&scores, &self.classes_map)
    }

    /// Predicts the label of a single `sample`.
    ///
    /// Takes the index of the highest score returned by [`Classifier::proba`] and decodes
    /// it back into the original label space.
    ///
    /// NOTE(review): assumes `sample` holds exactly one sample. With more rows, the argmax
    /// would range over several samples' scores. Confirm that callers respect this.
    pub fn predict_one(&self, sample: &[f32]) -> i64 {
        let scores = self.proba(sample);
        let best_class = scores.argmax();
        self.classes_map.decode(best_class)
    }

    /// Returns the raw scores produced by the underlying model for `dataset`.
    ///
    /// For a single sample there is one score per class, since `predict_one` decodes the
    /// argmax index as a class.
    /// NOTE(review): for several samples the scores are presumably concatenated sample
    /// after sample; confirm against `ClassifierModel::predict`.
    pub fn proba(&self, dataset: &[f32]) -> Vec<f32> {
        self.classifier.predict(dataset)
    }

    /// Returns a [`Trainer`] initialised with the default training configuration.
    pub fn trainer() -> Trainer {
        Default::default()
    }

    /// Number of features as reported by the underlying trained model.
    pub fn num_features(&self) -> usize {
        self.classifier.num_features()
    }
}
/// Gives `ClassDecode` users access to this classifier's decode table.
/// It delegates to the stored `ClassesMapping`.
/// NOTE(review): the table presumably maps class index -> original `i64` label; confirm.
impl ClassDecode for Classifier {
    fn get_decode_table(&self) -> &[i64] {
        self.classes_map.get_decode_table()
    }
}