use crate::{
benchmarking::{EpochEvaluation, ModelEvaluation, TrainingEvaluation},
linalg::Scalar,
model::Model,
monitor::TM,
network::params::NetworkParams,
vec_utils::r2_score_vector2,
};
#[cfg(feature = "data")]
use crate::datatable::DataTable;
#[cfg(not(feature = "data"))]
use rand::thread_rng;
#[cfg(not(feature = "data"))]
use rand::seq::SliceRandom;
/// Callback type for per-epoch progress reporting.
///
/// `SplitTraining::run` invokes it once per epoch with the zero-based epoch
/// index and a copy of that epoch's `EpochEvaluation`.
pub type ReporterClosure = dyn FnMut(usize, EpochEvaluation);
/// Trainer that fits a model on one part of the data and validates it on the
/// remaining part.
pub struct SplitTraining {
/// Fraction of the samples used for training; the rest is used for validation.
pub ratio: Scalar,
/// Optional callback invoked by `run` after each epoch with the epoch index
/// and that epoch's evaluation.
pub real_time_reporter: Option<Box<ReporterClosure>>,
/// Parameters of the trained network, stored by `run` and handed out by
/// `take_model`. `None` until `run` has completed.
pub model: Option<NetworkParams>,
/// When `true`, the validation loss is computed after every epoch instead of
/// only after the last one (skipped epochs record `-1.0`).
pub all_epochs_validation: bool,
/// When `true`, the R2 score is computed after every epoch instead of only
/// after the last one (skipped epochs record `-1.0`). `run` asserts that
/// `all_epochs_validation` is also set when this flag is set.
pub all_epochs_r2: bool,
}
impl SplitTraining {
/// Creates a trainer that uses `ratio` of the samples for training.
///
/// No reporter is attached, no trained model is stored yet, and validation
/// loss and R2 are only computed on the final epoch.
pub fn new(ratio: Scalar) -> Self {
    let (all_epochs_validation, all_epochs_r2) = (false, false);
    Self {
        ratio,
        real_time_reporter: None,
        model: None,
        all_epochs_validation,
        all_epochs_r2,
    }
}
/// Moves the trained network parameters out of the trainer, leaving `None`
/// in their place.
///
/// # Panics
///
/// Panics if no trained model is stored, i.e. `run` has not completed yet or
/// the model was already taken.
pub fn take_model(&mut self) -> NetworkParams {
    self.model
        .take()
        .expect("SplitTraining::take_model: no trained model available; call `run` first (the model can only be taken once per run)")
}
/// Requests that the R2 score be computed after every epoch rather than only
/// after the last one.
///
/// `run` asserts that per-epoch validation is enabled as well whenever this
/// option is set. Returns `self` so that configuration calls can be chained.
pub fn all_epochs_r2(&mut self) -> &mut Self {
    let trainer = self;
    trainer.all_epochs_r2 = true;
    trainer
}
/// Requests that the validation loss be computed after every epoch rather
/// than only after the last one.
///
/// Returns `self` so that configuration calls can be chained.
pub fn all_epochs_validation(&mut self) -> &mut Self {
    let trainer = self;
    trainer.all_epochs_validation = true;
    trainer
}
pub fn attach_real_time_reporter<F>(&mut self, reporter: F) -> &mut Self
where
F: FnMut(usize, EpochEvaluation) -> () + 'static,
{
self.real_time_reporter = Some(Box::new(reporter));
self
}
/// Trains `model` on one part of `data` and evaluates it on the held-out part.
///
/// The table is split with `self.ratio`; the network is trained for
/// `model.epochs` epochs on the training part. Validation loss and R2 are
/// always computed on the final epoch, and on every epoch when the
/// corresponding `all_epochs_*` option is enabled (skipped epochs record
/// `-1.0`). The attached reporter, if any, is called after every epoch.
///
/// On completion the trained parameters are stored in `self.model`.
///
/// # Returns
///
/// A table holding the final-epoch predictions for the validation rows
/// together with their id column, and a `ModelEvaluation` containing a single
/// fold with the per-epoch evaluations.
///
/// # Panics
///
/// Panics if `all_epochs_r2` is enabled without `all_epochs_validation`, or
/// if the model's dataset configuration does not declare an id column.
#[cfg(feature = "data")]
pub fn run(&mut self, model: &Model, data: &DataTable) -> (DataTable, ModelEvaluation) {
    // R2 is computed from the validation predictions, so per-epoch R2 needs
    // per-epoch validation.
    assert!(!self.all_epochs_r2 || self.all_epochs_validation);
    TM::start("split");
    TM::start("init");
    let mut preds_and_ids = DataTable::new_empty();
    let mut model_eval = ModelEvaluation::new_empty();
    let predicted_features = model.dataset_config.predicted_features_names();
    let id_column = model
        .dataset_config
        .get_id_column()
        .expect("One feature must be specified as an id in the dataset config.");
    let mut network = model.to_network();
    let (train_table, validation) = data.split_ratio(self.ratio);
    let (validation_x_table, validation_y_table) =
        validation.random_order_in_out(&predicted_features);
    // The id column is dropped from the network inputs; it is re-attached to
    // the predictions after the last epoch.
    let validation_x = validation_x_table.drop_column(id_column).to_vectors();
    let validation_y = validation_y_table.to_vectors();
    TM::end_with_message(format!(
        "Initialized training with {} samples\nInitialized validation with {} samples",
        train_table.num_rows(),
        validation_x_table.num_rows()
    ));
    TM::start("epochs");
    let mut eval = TrainingEvaluation::new_empty();
    let epochs = model.epochs;
    for e in 0..epochs {
        let is_last_epoch = e == epochs - 1;
        TM::start(&format!("{}/{}", e + 1, epochs));
        let train_loss = model.train_epoch(e, &mut network, &train_table, id_column);
        let loss_fn = model.loss.to_loss();
        let (preds, loss_avg, loss_std) = if is_last_epoch || self.all_epochs_validation {
            network.predict_evaluate_many(
                &validation_x,
                &validation_y,
                &loss_fn,
                model.batch_size.unwrap_or(validation_x.len()),
            )
        } else {
            // Validation skipped for this epoch: -1.0 marks "not computed".
            (vec![], -1.0, -1.0)
        };
        let r2 = if is_last_epoch || self.all_epochs_r2 {
            TM::start("r2");
            let r2 = r2_score_vector2(&validation_y, &preds);
            TM::end_with_message(format!("R2: {}", r2));
            r2
        } else {
            -1.0
        };
        let epoch_eval = EpochEvaluation::new(train_loss, loss_avg, loss_std, r2);
        if let Some(reporter) = self.real_time_reporter.as_mut() {
            reporter(e, epoch_eval.clone());
        }
        if is_last_epoch {
            preds_and_ids = preds_and_ids.apppend(
                &DataTable::from_vectors(&predicted_features, &preds)
                    .add_column_from(&validation_x_table, id_column),
            );
        }
        TM::end_with_message(format!("Training Loss: {}\n ", train_loss));
        eval.add_epoch(epoch_eval);
    }
    TM::end_with_message(format!("Final performance: {:#?}", eval.get_final_epoch()));
    model_eval.add_fold(eval);
    self.model = Some(network.get_params());
    TM::end();
    (preds_and_ids, model_eval)
}
#[cfg(not(feature = "data"))]
pub fn run(&mut self, model: &Model, data_x: &Vec<Vec<Scalar>>, data_y: &Vec<Vec<Scalar>>) -> (Vec<Vec<Scalar>>, ModelEvaluation) {
assert!(!self.all_epochs_r2 || self.all_epochs_validation);
assert!(data_x.len() == data_y.len());
TM::start("split");
TM::start("init");
let mut model_eval = ModelEvaluation::new_empty();
let mut network = model.to_network(data_x[0].len());
let split_at = (self.ratio * data_x.len() as Scalar) as usize;
let mut ids = (0..data_x.len()).map(|x| x as Scalar).collect::<Vec<_>>();
ids.shuffle(&mut thread_rng());
let data_x = ids.iter().map(|&i| data_x[i as usize].clone()).collect::<Vec<_>>();
let data_y = ids.iter().map(|&i| data_y[i as usize].clone()).collect::<Vec<_>>();
let (train_x, validation_x) = data_x.split_at(split_at);
let (train_y, validation_y) = data_y.split_at(split_at);
let train_x = train_x.to_vec();
let train_y = train_y.to_vec();
let validation_x = validation_x.to_vec();
let validation_y = validation_y.to_vec();
TM::end_with_message(format!(
"Initialized training with {} samples\nInitialized validation with {} samples",
train_x.len(),
validation_x.len()
));
TM::start("epochs");
let mut final_predictions = vec![];
let mut eval = TrainingEvaluation::new_empty();
let epochs = model.epochs;
for e in 0..epochs {
TM::start(&format!("{}/{}", e + 1, epochs));
let train_loss = model.train_epoch(e, &mut network, &train_x, &train_y);
let loss_fn = model.loss.to_loss();
let (preds, loss_avg, loss_std) = if e == model.epochs - 1 || self.all_epochs_validation
{
let vloss = network.predict_evaluate_many(
&validation_x,
&validation_y,
&loss_fn,
model.batch_size.unwrap_or(validation_x.len()),
);
vloss
} else {
(vec![], -1.0, -1.0)
};
let r2 = if e == model.epochs - 1 || self.all_epochs_r2 {
TM::start("r2");
let r2 = r2_score_vector2(&validation_y, &preds);
TM::end_with_message(format!("R2: {}", r2));
r2
} else {
-1.0
};
let epoch_eval = EpochEvaluation::new(train_loss, loss_avg, loss_std, r2);
if let Some(reporter) = self.real_time_reporter.as_mut() {
reporter(e, epoch_eval.clone());
}
if e == model.epochs - 1 {
final_predictions = preds.clone();
};
TM::end_with_message(format!("Training Loss: {}\n ", train_loss));
eval.add_epoch(epoch_eval);
}
TM::end_with_message(format!("Final performance: {:#?}", eval.get_final_epoch()));
model_eval.add_fold(eval);
self.model = Some(network.get_params());
let mut reordered_predictions = Vec::with_capacity(data_x.len());
for i in 0..data_x.len() {
reordered_predictions[ids[i] as usize] = final_predictions[i].clone();
}
TM::end();
(reordered_predictions, model_eval)
}
}