use crate::{
benchmarking::{EpochEvaluation, ModelEvaluation, TrainingEvaluation},
datatable::DataTable,
linalg::Scalar,
model::Model,
monitor::TM,
network::params::NetworkParams,
vec_utils::r2_score_matrix,
};
/// Callback type used for per-epoch progress reporting.
///
/// [`SplitTraining::run`] invokes it once per epoch with the zero-based epoch
/// index and a copy of that epoch's [`EpochEvaluation`].
pub type ReporterClosure = dyn FnMut(usize, EpochEvaluation);
/// Trains a model on a single train/validation split of a data table.
///
/// Configure it with the builder-style methods, then call `run`.
pub struct SplitTraining {
    /// Ratio handed to `DataTable::split_ratio` to divide the data into a
    /// training part and a validation part.
    // NOTE(review): which side of the split the ratio measures depends on
    // `split_ratio` — confirm against its definition.
    pub ratio: Scalar,
    /// Optional callback invoked after every epoch with the epoch index and
    /// that epoch's evaluation.
    pub real_time_reporter: Option<Box<ReporterClosure>>,
    /// Network parameters stored at the end of `run`; `None` until then and
    /// again after `take_model` has moved them out.
    pub model: Option<NetworkParams>,
    /// When `true`, the validation loss is computed on every epoch instead of
    /// only on the last one.
    pub all_epochs_validation: bool,
    /// When `true`, the R2 score is computed on every epoch instead of only on
    /// the last one. `run` asserts that `all_epochs_validation` is also set.
    pub all_epochs_r2: bool,
}
impl SplitTraining {
/// Creates a split-training configuration with the given split `ratio`.
///
/// No reporter is attached, no model is stored, and both per-epoch flags
/// (`all_epochs_validation`, `all_epochs_r2`) start disabled.
pub fn new(ratio: Scalar) -> Self {
    Self {
        ratio,
        real_time_reporter: None,
        model: None,
        all_epochs_validation: false,
        all_epochs_r2: false,
    }
}
/// Moves the stored network parameters out of this trainer and returns them.
///
/// `run` stores the parameters of the trained network in `self.model`; this
/// method takes them, leaving `self.model` as `None`.
///
/// # Panics
///
/// Panics if no parameters are stored, i.e. `self.model` is `None` — for
/// example when `run` has not been called yet or the model was already taken.
pub fn take_model(&mut self) -> NetworkParams {
    self.model.take().expect(
        "SplitTraining::take_model called with no stored model \
         (call `run` first, and take the model only once)",
    )
}
/// Enables computing the R2 score on every epoch rather than only the last.
///
/// `run` asserts that per-epoch validation is enabled whenever this flag is
/// set, so combine it with `all_epochs_validation`.
///
/// Returns `self` so calls can be chained.
pub fn all_epochs_r2(&mut self) -> &mut Self {
    self.all_epochs_r2 = true;
    self
}
/// Enables computing the validation loss on every epoch rather than only the
/// last.
///
/// Returns `self` so calls can be chained.
pub fn all_epochs_validation(&mut self) -> &mut Self {
    self.all_epochs_validation = true;
    self
}
/// Registers a callback that `run` invokes after every epoch.
///
/// The callback receives the zero-based epoch index and a copy of that
/// epoch's evaluation. Any previously attached reporter is replaced.
///
/// Returns `self` so calls can be chained.
pub fn attach_real_time_reporter<F>(&mut self, reporter: F) -> &mut Self
where
    F: FnMut(usize, EpochEvaluation) + 'static,
{
    self.real_time_reporter = Some(Box::new(reporter));
    self
}
/// Trains `model` on one train/validation split of `data` and evaluates it.
///
/// The data is divided with `DataTable::split_ratio(self.ratio)`. A fresh
/// network is built from `model` and trained for `model.epochs` epochs on the
/// training part. On the last epoch — or on every epoch when the matching
/// flag is set — the validation loss and the R2 score are computed on the
/// validation part. Metrics that are skipped for an epoch are recorded as
/// `-1.0`. If a reporter is attached it is called after every epoch.
///
/// When finished, the trained network's parameters are stored in
/// `self.model`.
///
/// # Returns
///
/// A tuple of:
/// * a table holding the last epoch's validation predictions together with
///   the id column taken from the validation inputs, and
/// * a [`ModelEvaluation`] containing one fold with the per-epoch metrics.
///
/// # Panics
///
/// Panics if `all_epochs_r2` is set while `all_epochs_validation` is not, or
/// if the model's dataset specification has no id column.
pub fn run(&mut self, model: &Model, data: &DataTable) -> (DataTable, ModelEvaluation) {
    // R2 is computed from the validation predictions, so asking for it on
    // every epoch requires validation on every epoch too.
    assert!(!self.all_epochs_r2 || self.all_epochs_validation);

    TM::start("split");
    TM::start("init");
    let mut predictions_with_ids = DataTable::new_empty();
    let mut evaluation = ModelEvaluation::new_empty();
    let target_names = model.dataset.predicted_features_names();
    let id_column = model
        .dataset
        .get_id_column()
        .expect("One feature must be specified as an id in the dataset specification.");
    let mut network = model.to_network();

    let (train_table, validation_table) = data.split_ratio(self.ratio);
    let (validation_inputs_table, validation_targets_table) =
        validation_table.random_order_in_out(&target_names);
    // The id column is kept in the table (it is re-attached to the final
    // predictions) but is dropped from the vectors fed to the network.
    let validation_inputs = validation_inputs_table.drop_column(id_column).to_vectors();
    let validation_targets = validation_targets_table.to_vectors();
    TM::end_with_message(format!(
        "Initialized training with {} samples\nInitialized validation with {} samples",
        train_table.num_rows(),
        validation_inputs_table.num_rows()
    ));

    TM::start("epochs");
    let mut history = TrainingEvaluation::new_empty();
    let epochs = model.epochs;
    for epoch in 0..epochs {
        TM::start(&format!("{}/{}", epoch + 1, epochs));
        let is_last_epoch = epoch + 1 == epochs;
        let validate_now = is_last_epoch || self.all_epochs_validation;
        let score_now = is_last_epoch || self.all_epochs_r2;

        let train_loss = model.train_epoch(epoch, &mut network, &train_table, id_column);
        let loss_fn = model.loss.to_loss();

        let (preds, loss_avg, loss_std) = if validate_now {
            network.predict_evaluate_many(
                &validation_inputs,
                &validation_targets,
                &loss_fn,
                model.batch_size.unwrap_or(validation_inputs.len()),
            )
        } else {
            // Validation skipped for this epoch: no predictions, sentinel losses.
            (vec![], -1.0, -1.0)
        };

        let r2 = if score_now {
            TM::start("r2");
            let score = r2_score_matrix(&validation_targets, &preds);
            TM::end_with_message(format!("R2: {}", score));
            score
        } else {
            // R2 skipped for this epoch.
            -1.0
        };

        let epoch_eval = EpochEvaluation::new(train_loss, loss_avg, loss_std, r2);
        if let Some(report) = &mut self.real_time_reporter {
            report(epoch, epoch_eval.clone());
        }

        // Only the final epoch's predictions are returned to the caller.
        if is_last_epoch {
            predictions_with_ids = predictions_with_ids.apppend(
                &DataTable::from_vectors(&target_names, &preds)
                    .add_column_from(&validation_inputs_table, id_column),
            );
        }

        TM::end_with_message(format!("Training Loss: {}\n ", train_loss));
        history.add_epoch(epoch_eval);
    }
    TM::end_with_message(format!("Final performance: {:#?}", history.get_final_epoch()));

    evaluation.add_fold(history);
    self.model = Some(network.get_params());
    TM::end();
    (predictions_with_ids, evaluation)
}
}