use std::time::Instant;
use crate::{
chart::plot_loss,
format_duration,
loss::{pairwise_distance, umap_loss},
model::UMAPModel,
utils::convert_vector_to_tensor,
};
use burn::{
optim::{decay::WeightDecayConfig, AdamConfig, GradientsParams, Optimizer},
tensor::{backend::AutodiffBackend, cast::ToElement, Device},
};
use indicatif::{ProgressBar, ProgressStyle};
/// Hyper-parameters for a UMAP training run.
///
/// Build one with [`TrainingConfig::builder`]: `epochs`, `batch_size`
/// and `device` are required there, the rest fall back to defaults.
#[derive(Debug)]
pub struct TrainingConfig<B: AutodiffBackend> {
    /// Number of training epochs (with `patience` set, training may run longer).
    pub epochs: usize,
    /// Requested batch size. NOTE(review): `train` currently optimizes
    /// full-batch and never reads this field — confirm intended use.
    pub batch_size: usize,
    /// Step size handed to the Adam optimizer each epoch.
    pub learning_rate: f64,
    /// Device on which training tensors are allocated.
    pub device: Device<B>,
    /// Adam first-moment decay coefficient.
    pub beta1: f64,
    /// Adam second-moment decay coefficient.
    pub beta2: f64,
    /// Weight-decay (L2) penalty applied via the optimizer.
    pub penalty: f64,
    /// When true, show a progress bar and write a loss plot at the end.
    pub verbose: bool,
    /// Early-stopping patience (epochs without improvement); `None` disables it.
    pub patience: Option<i32>,
}

impl<B: AutodiffBackend> TrainingConfig<B> {
    /// Start assembling a `TrainingConfig` field by field.
    pub fn builder() -> TrainingConfigBuilder<B> {
        TrainingConfigBuilder::default()
    }
}
/// Incremental builder for [`TrainingConfig`], obtained from
/// [`TrainingConfig::builder`] and finished with `build()`.
///
/// `epochs`, `batch_size` and `device` must be set or `build()` returns
/// `None`; every other field has a default supplied by `build()`.
#[derive(Default)]
pub struct TrainingConfigBuilder<B: AutodiffBackend> {
epochs: Option<usize>,
batch_size: Option<usize>,
learning_rate: Option<f64>,
device: Option<Device<B>>,
beta1: Option<f64>,
beta2: Option<f64>,
penalty: Option<f64>,
verbose: Option<bool>,
patience: Option<i32>,
}
impl<B: AutodiffBackend> TrainingConfigBuilder<B> {
    /// Set the number of training epochs (required).
    pub fn with_epochs(mut self, value: usize) -> Self {
        self.epochs = Some(value);
        self
    }

    /// Set the batch size (required).
    pub fn with_batch_size(mut self, value: usize) -> Self {
        self.batch_size = Some(value);
        self
    }

    /// Set the optimizer learning rate (defaults to `0.001`).
    pub fn with_learning_rate(mut self, value: f64) -> Self {
        self.learning_rate = Some(value);
        self
    }

    /// Set the device tensors are allocated on (required).
    pub fn with_device(mut self, value: Device<B>) -> Self {
        self.device = Some(value);
        self
    }

    /// Set Adam's first-moment decay (defaults to `0.9`).
    pub fn with_beta1(mut self, value: f64) -> Self {
        self.beta1 = Some(value);
        self
    }

    /// Set Adam's second-moment decay (defaults to `0.999`).
    pub fn with_beta2(mut self, value: f64) -> Self {
        self.beta2 = Some(value);
        self
    }

    /// Set the weight-decay penalty (defaults to `5e-5`).
    pub fn with_penalty(mut self, value: f64) -> Self {
        self.penalty = Some(value);
        self
    }

    /// Enable or disable progress reporting (defaults to `false`).
    pub fn with_verbose(mut self, value: bool) -> Self {
        self.verbose = Some(value);
        self
    }

    /// Set the early-stopping patience in epochs (unset by default).
    pub fn with_patience(mut self, value: i32) -> Self {
        self.patience = Some(value);
        self
    }

    /// Finalize the configuration.
    ///
    /// Returns `None` if any required field (`epochs`, `batch_size`,
    /// `device`) was never provided; optional fields take the defaults
    /// documented on their setters.
    pub fn build(self) -> Option<TrainingConfig<B>> {
        // Pull the required fields out first so a missing one
        // short-circuits before any struct construction.
        let epochs = self.epochs?;
        let batch_size = self.batch_size?;
        let device = self.device?;
        Some(TrainingConfig {
            epochs,
            batch_size,
            device,
            learning_rate: self.learning_rate.unwrap_or(0.001),
            beta1: self.beta1.unwrap_or(0.9),
            beta2: self.beta2.unwrap_or(0.999),
            penalty: self.penalty.unwrap_or(5e-5),
            verbose: self.verbose.unwrap_or(false),
            patience: self.patience,
        })
    }
}
/// Fit `model` to `data` by minimizing the UMAP loss with Adam.
///
/// `data` is a row-major flat buffer of `num_samples * num_features`
/// values. Optimization is full-batch: every epoch forwards the whole
/// dataset, compares local pairwise distances against the (precomputed)
/// global ones, and takes one Adam step. Runs `config.epochs` epochs;
/// when `config.patience` is `Some(p)`, training continues past that
/// point until the loss has failed to improve for `p` consecutive
/// epochs. Returns the trained model.
///
/// With `config.verbose`, a progress bar is shown and the loss curve is
/// written to `losses.png` (panics if the plot cannot be written).
pub fn train<B: AutodiffBackend>(
    mut model: UMAPModel<B>,
    num_samples: usize,
    num_features: usize,
    data: Vec<f64>,
    config: &TrainingConfig<B>,
) -> UMAPModel<B> {
    // `data` is owned and never used again: move it straight in. The
    // previous `data.clone()` copied the entire buffer for nothing.
    let tensor_data =
        convert_vector_to_tensor::<B>(data, num_samples, num_features, &config.device);

    let config_optimizer = AdamConfig::new()
        .with_weight_decay(Some(WeightDecayConfig::new(config.penalty)))
        .with_beta_1(config.beta1 as f32)
        .with_beta_2(config.beta2 as f32);
    let mut optim = config_optimizer.init();

    let start_time = Instant::now();
    let pb = if config.verbose {
        let pb = ProgressBar::new(config.epochs as u64);
        pb.set_style(
            ProgressStyle::default_bar()
                .template("{bar:40} {pos}/{len} Epochs | {msg}")
                .unwrap()
                .progress_chars("=>-"),
        );
        Some(pb)
    } else {
        None
    };

    // Input-space distances are constant across epochs; compute once.
    let global_distances = pairwise_distance(tensor_data.clone());

    let mut epoch = 0;
    let mut losses: Vec<f64> = Vec::new();
    let mut best_loss = f64::INFINITY;
    let mut epochs_without_improvement = 0;

    loop {
        // Stop condition is checked *before* the epoch's work. The old
        // post-work check incremented `epoch` afterwards, so the loop
        // ran `config.epochs + 1` iterations (and over-ran the bar).
        if epoch >= config.epochs {
            match config.patience {
                // With patience: keep training past `epochs` while the
                // loss is still improving often enough.
                Some(p) if epochs_without_improvement < p => {}
                _ => break,
            }
        }

        let local = model.forward(tensor_data.clone());
        let loss = umap_loss(global_distances.clone(), local);
        let grads = loss.backward();
        // `loss` is not needed after this, so consume it directly
        // instead of cloning the tensor first.
        let current_loss = loss.into_scalar().to_f64();
        losses.push(current_loss);

        let grads = GradientsParams::from_grads(grads, &model);
        model = optim.step(config.learning_rate, model, grads);

        if let Some(ref bar) = pb {
            bar.set_message(format!(
                "Elapsed: {} | Loss: {:.3} | Best loss: {:.3}",
                format_duration(start_time.elapsed()),
                current_loss,
                best_loss,
            ));
            bar.inc(1);
        }

        if current_loss <= best_loss {
            best_loss = current_loss;
            epochs_without_improvement = 0;
        } else {
            epochs_without_improvement += 1;
        }

        epoch += 1;
    }

    if config.verbose {
        plot_loss(losses, "losses.png").unwrap();
    }
    if let Some(pb) = pb {
        pb.finish();
    }
    model
}