use std::{
cmp::Ordering::Equal,
fmt::{Display, Formatter},
};
use crate::model::{
comparison::ComparisonEntry,
error::{ModelError, ModelResult},
preprocessing::Preprocessor,
};
use crate::settings::{
ClassificationSettings, FinalAlgorithm, Metric, RegressionSettings, SupervisedSettings,
};
use comfy_table::{
Attribute, Cell, Table, modifiers::UTF8_SOLID_INNER_BORDERS, presets::UTF8_FULL,
};
use humantime::format_duration;
use smartcore::error::Failed;
use smartcore::linalg::{
basic::arrays::{Array, Array1, Array2, MutArrayView1},
traits::{
cholesky::CholeskyDecomposable, evd::EVDDecomposable, qr::QRDecomposable,
svd::SVDDecomposable,
},
};
use smartcore::numbers::{basenum::Number, floatnum::FloatNumber, realnum::RealNumber};
/// A candidate learning algorithm that can be cross-validated and used for prediction.
///
/// `ASettings` is the settings type handed to [`Algorithm::all_algorithms`] and
/// [`Algorithm::cross_validate_model`].
pub trait Algorithm<ASettings>: Sized {
    /// Scalar type of the feature matrix.
    type Input: FloatNumber + RealNumber;

    /// Scalar type of the target vector.
    type Output: Number;

    /// Feature matrix type; it must support the EVD, SVD, Cholesky and QR
    /// decompositions in addition to 2-D array access.
    type InputArray: Array2<Self::Input>
        + Array<Self::Input, (usize, usize)>
        + EVDDecomposable<Self::Input>
        + SVDDecomposable<Self::Input>
        + CholeskyDecomposable<Self::Input>
        + QRDecomposable<Self::Input>
        + Clone;

    /// Target vector type.
    type OutputArray: Array1<Self::Output> + MutArrayView1<Self::Output> + Clone;

    /// Returns every candidate algorithm that should be evaluated for `settings`.
    fn all_algorithms(settings: &ASettings) -> Vec<Self>;

    /// Consumes this candidate, cross-validates it on `x`/`y` and returns the
    /// resulting comparison entry.
    ///
    /// # Errors
    /// Returns a smartcore `Failed` if fitting or evaluation fails.
    fn cross_validate_model(
        self,
        x: &Self::InputArray,
        y: &Self::OutputArray,
        settings: &ASettings,
    ) -> Result<ComparisonEntry<Self>, Failed>;

    /// Predicts targets for the feature matrix `x`.
    ///
    /// # Errors
    /// Returns a smartcore `Failed` if prediction fails.
    fn predict(&self, x: &Self::InputArray) -> Result<Self::OutputArray, Failed>;
}
/// Settings types that embed a shared [`SupervisedSettings`] section.
pub trait SupervisedLearningSettings {
    /// Borrows the embedded supervised-learning settings.
    fn supervised(&self) -> &SupervisedSettings;
}
/// Exposes the `supervised` section of classification settings.
impl SupervisedLearningSettings for ClassificationSettings {
    fn supervised(&self) -> &SupervisedSettings {
        let Self { supervised, .. } = self;
        supervised
    }
}
/// Exposes the `supervised` section of regression settings.
impl<In, Out, XArray, YArray> SupervisedLearningSettings for RegressionSettings<In, Out, XArray, YArray>
where
    In: RealNumber + FloatNumber,
    Out: FloatNumber,
    XArray: QRDecomposable<In>
        + EVDDecomposable<In>
        + SVDDecomposable<In>
        + CholeskyDecomposable<In>,
    YArray: Array1<Out>,
{
    fn supervised(&self) -> &SupervisedSettings {
        let Self { supervised, .. } = self;
        supervised
    }
}
/// A supervised-learning model that evaluates several candidate algorithms.
///
/// It owns the training data, the settings, a preprocessor for the training
/// features, and one comparison entry per evaluated candidate.
pub struct SupervisedModel<A, S, InputArray, OutputArray>
where
    S: SupervisedLearningSettings,
    A: Algorithm<S, InputArray = InputArray, OutputArray = OutputArray>,
    OutputArray: Array1<A::Output> + MutArrayView1<A::Output> + Clone,
    InputArray: Array2<A::Input>
        + Array<A::Input, (usize, usize)>
        + EVDDecomposable<A::Input>
        + SVDDecomposable<A::Input>
        + CholeskyDecomposable<A::Input>
        + QRDecomposable<A::Input>
        + Clone,
{
    /// Settings controlling preprocessing, candidate selection and ranking.
    pub settings: S,
    /// Training feature matrix.
    x_train: InputArray,
    /// Training targets.
    y_train: OutputArray,
    /// Cross-validation results of each evaluated candidate.
    comparison: Vec<ComparisonEntry<A>>,
    /// Feature preprocessor, trained on `x_train`.
    preprocessor: Preprocessor<A::Input, InputArray>,
}
impl<A, S, InputArray, OutputArray> SupervisedModel<A, S, InputArray, OutputArray>
where
A: Algorithm<S, InputArray = InputArray, OutputArray = OutputArray>,
S: SupervisedLearningSettings,
InputArray: Clone
+ Array<A::Input, (usize, usize)>
+ Array2<A::Input>
+ EVDDecomposable<A::Input>
+ SVDDecomposable<A::Input>
+ CholeskyDecomposable<A::Input>
+ QRDecomposable<A::Input>,
OutputArray: Clone + MutArrayView1<A::Output> + Array1<A::Output>,
{
pub fn new(x: InputArray, y: OutputArray, settings: S) -> Self {
Self {
settings,
x_train: x,
y_train: y,
comparison: Vec::new(),
preprocessor: Preprocessor::new(),
}
}
pub fn train(&mut self) -> Result<(), Failed> {
let sup = self.settings.supervised();
self.preprocessor
.train(&self.x_train.clone(), &sup.preprocessing);
for alg in <A>::all_algorithms(&self.settings) {
let trained = alg.cross_validate_model(&self.x_train, &self.y_train, &self.settings)?;
self.record_trained_model(trained);
}
Ok(())
}
pub fn predict(&self, x: InputArray) -> ModelResult<OutputArray> {
let sup = self.settings.supervised();
let x = self
.preprocessor
.preprocess(x, &sup.preprocessing)
.map_err(|e| ModelError::Inference(e.to_string()))?;
match sup.final_model_approach {
FinalAlgorithm::None => Err(ModelError::NotTrained),
FinalAlgorithm::Best => {
let entry = self.comparison.first().ok_or(ModelError::NotTrained)?;
entry
.algorithm
.predict(&x)
.map_err(|e| ModelError::Inference(e.to_string()))
}
}
}
fn record_trained_model(&mut self, trained_model: ComparisonEntry<A>) {
self.comparison.push(trained_model);
self.sort();
}
fn sort(&mut self) {
let sort_by = &self.settings.supervised().sort_by;
self.comparison.sort_by(|a, b| {
a.result
.mean_test_score()
.partial_cmp(&b.result.mean_test_score())
.unwrap_or(Equal)
});
if matches!(sort_by, Metric::RSquared | Metric::Accuracy) {
self.comparison.reverse();
}
}
}
impl<A, S, InputArray, OutputArray> Display for SupervisedModel<A, S, InputArray, OutputArray>
where
    A: Algorithm<S, InputArray = InputArray, OutputArray = OutputArray> + Display,
    S: SupervisedLearningSettings,
    InputArray: Clone
        + Array<A::Input, (usize, usize)>
        + Array2<A::Input>
        + EVDDecomposable<A::Input>
        + SVDDecomposable<A::Input>
        + CholeskyDecomposable<A::Input>
        + QRDecomposable<A::Input>,
    OutputArray: Clone + MutArrayView1<A::Output> + Array1<A::Output>,
{
    /// Renders the comparison list as a table.
    ///
    /// Each stored entry gets one row, in stored order, showing the algorithm
    /// name, its duration, and its mean training and testing scores. Scores use
    /// two decimals when the absolute midpoint of the two scores lies in
    /// `[0.01, 1000)`; otherwise they use scientific notation with three decimals.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let metric = &self.settings.supervised().sort_by;
        let bold = |label: String| Cell::new(label).add_attribute(Attribute::Bold);

        let mut table = Table::new();
        table
            .load_preset(UTF8_FULL)
            .apply_modifier(UTF8_SOLID_INNER_BORDERS)
            .set_header(vec![
                bold("Model".to_string()),
                bold("Time".to_string()),
                bold(format!("Training {metric}")),
                bold(format!("Testing {metric}")),
            ]);

        for entry in &self.comparison {
            let train = entry.result.mean_train_score();
            let test = entry.result.mean_test_score();
            // Pick one notation per row so both scores line up visually.
            let plain = (0.01..1000.0).contains(&f64::midpoint(train, test).abs());
            let render = |score: f64| {
                if plain {
                    format!("{score:.2}")
                } else {
                    format!("{score:.3e}")
                }
            };
            table.add_row(vec![
                entry.algorithm.to_string(),
                format_duration(entry.duration).to_string(),
                render(train),
                render(test),
            ]);
        }

        write!(f, "{table}")
    }
}