use std::{
fmt,
time::{Duration, Instant},
};
use crate::{
chart::plot_loss, format_duration, loss::*, model::UMAPModel, utils::convert_vector_to_tensor,
};
use burn::{
nn::loss::MseLoss,
optim::{decay::WeightDecayConfig, AdamConfig, GradientsParams, Optimizer},
tensor::{backend::AutodiffBackend, cast::ToElement, Device, Tensor},
};
use indicatif::{ProgressBar, ProgressStyle};
/// How per-element loss terms are combined into a single scalar.
///
/// Derives `Clone`/`Copy`/`PartialEq`/`Eq` so the variant can be freely
/// copied out of a borrowed config and compared in tests.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LossReduction {
    /// Average the per-element losses.
    Mean,
    /// Sum the per-element losses (the builder's default).
    Sum,
}
/// Distance metric used when comparing points in the input and embedded
/// spaces.
///
/// Adds `Clone`/`Copy`/`Eq` to the existing `Debug`/`PartialEq` derives so
/// the metric can be copied out of a borrowed config without cloning.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Metric {
    /// Plain pairwise Euclidean distance.
    Euclidean,
    /// Euclidean distance restricted to each point's k nearest neighbors.
    EuclideanKNN,
}
impl From<&str> for Metric {
    /// Parses a metric name, case-insensitively.
    ///
    /// Accepts `"euclidean"`, `"euclideanknn"` and `"euclidean_knn"` in any
    /// casing.
    ///
    /// # Panics
    /// Panics when `s` does not name a known metric.
    fn from(s: &str) -> Self {
        let lowered = s.to_lowercase();
        if lowered == "euclidean" {
            Metric::Euclidean
        } else if lowered == "euclideanknn" || lowered == "euclidean_knn" {
            Metric::EuclideanKNN
        } else {
            panic!("Invalid metric type: {}", s)
        }
    }
}
impl fmt::Display for Metric {
    /// Writes the human-readable metric name.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let label = match self {
            Metric::Euclidean => "Euclidean",
            Metric::EuclideanKNN => "Euclidean KNN",
        };
        write!(f, "{}", label)
    }
}
/// Hyper-parameters for a single UMAP training run.
///
/// Construct with [`TrainingConfig::builder`]; `epochs`, `batch_size` and
/// `device` are mandatory there, everything else has a default.
#[derive(Debug)]
pub struct TrainingConfig<B: AutodiffBackend> {
    /// Distance metric applied to both the input data and the embedding.
    pub metric: Metric,
    /// Epoch budget; with `patience` set this acts as a minimum, otherwise
    /// it is the exact number of epochs run.
    pub epochs: usize,
    // NOTE(review): not read by `train` in this file — confirm it is
    // consumed elsewhere before documenting its semantics.
    pub batch_size: usize,
    /// Step size passed to the Adam optimizer each epoch.
    pub learning_rate: f64,
    /// Backend device tensors are allocated on.
    pub device: Device<B>,
    /// Adam first-moment decay (cast to f32 before use).
    pub beta1: f64,
    /// Adam second-moment decay (cast to f32 before use).
    pub beta2: f64,
    /// Weight-decay coefficient fed to `WeightDecayConfig`.
    pub penalty: f64,
    /// When true, show a progress bar and write a loss plot at the end.
    pub verbose: bool,
    /// Number of consecutive non-improving epochs tolerated before stopping;
    /// `None` disables patience-based stopping.
    pub patience: Option<i32>,
    /// Requested reduction for the loss term.
    pub loss_reduction: LossReduction,
    /// Neighbor count used by [`Metric::EuclideanKNN`].
    pub k_neighbors: usize,
    /// Stop as soon as the loss drops below this value, if set.
    pub min_desired_loss: Option<f64>,
    /// Wall-clock training limit in seconds, if set.
    pub timeout: Option<u64>,
}
impl<B: AutodiffBackend> TrainingConfig<B> {
    /// Starts a fluent [`TrainingConfigBuilder`] with every field unset.
    pub fn builder() -> TrainingConfigBuilder<B> {
        TrainingConfigBuilder::<B>::default()
    }
}
/// Incremental builder for [`TrainingConfig`].
///
/// Every field starts as `None`; `build` requires `epochs`, `batch_size`
/// and `device` and substitutes defaults for the rest.
#[derive(Default)]
pub struct TrainingConfigBuilder<B: AutodiffBackend> {
    metric: Option<Metric>,
    epochs: Option<usize>,
    batch_size: Option<usize>,
    learning_rate: Option<f64>,
    device: Option<Device<B>>,
    beta1: Option<f64>,
    beta2: Option<f64>,
    penalty: Option<f64>,
    verbose: Option<bool>,
    patience: Option<i32>,
    loss_reduction: Option<LossReduction>,
    k_neighbors: Option<usize>,
    min_desired_loss: Option<f64>,
    timeout: Option<u64>,
}
impl<B: AutodiffBackend> TrainingConfigBuilder<B> {
    /// Sets the distance metric (default: [`Metric::Euclidean`]).
    pub fn with_metric(self, metric: Metric) -> Self {
        Self { metric: Some(metric), ..self }
    }

    /// Sets the epoch budget (mandatory).
    pub fn with_epochs(self, epochs: usize) -> Self {
        Self { epochs: Some(epochs), ..self }
    }

    /// Sets the batch size (mandatory).
    pub fn with_batch_size(self, batch_size: usize) -> Self {
        Self { batch_size: Some(batch_size), ..self }
    }

    /// Sets the learning rate (default: `0.001`).
    pub fn with_learning_rate(self, learning_rate: f64) -> Self {
        Self { learning_rate: Some(learning_rate), ..self }
    }

    /// Sets the backend device (mandatory).
    pub fn with_device(self, device: Device<B>) -> Self {
        Self { device: Some(device), ..self }
    }

    /// Sets Adam's first-moment decay (default: `0.9`).
    pub fn with_beta1(self, beta1: f64) -> Self {
        Self { beta1: Some(beta1), ..self }
    }

    /// Sets Adam's second-moment decay (default: `0.999`).
    pub fn with_beta2(self, beta2: f64) -> Self {
        Self { beta2: Some(beta2), ..self }
    }

    /// Sets the weight-decay penalty (default: `5e-5`).
    pub fn with_penalty(self, penalty: f64) -> Self {
        Self { penalty: Some(penalty), ..self }
    }

    /// Enables or disables progress output (default: `false`).
    pub fn with_verbose(self, verbose: bool) -> Self {
        Self { verbose: Some(verbose), ..self }
    }

    /// Sets the early-stopping patience (default: disabled).
    pub fn with_patience(self, patience: i32) -> Self {
        Self { patience: Some(patience), ..self }
    }

    /// Sets the loss reduction (default: [`LossReduction::Sum`]).
    pub fn with_loss_reduction(self, loss_reduction: LossReduction) -> Self {
        Self { loss_reduction: Some(loss_reduction), ..self }
    }

    /// Sets the KNN neighbor count (default: `15`).
    pub fn with_k_neighbors(self, k_neighbors: usize) -> Self {
        Self { k_neighbors: Some(k_neighbors), ..self }
    }

    /// Sets a target loss that stops training early (default: disabled).
    pub fn with_min_desired_loss(self, min_desired_loss: f64) -> Self {
        Self { min_desired_loss: Some(min_desired_loss), ..self }
    }

    /// Sets a wall-clock limit in seconds (default: disabled).
    pub fn with_timeout(self, timeout: u64) -> Self {
        Self { timeout: Some(timeout), ..self }
    }

    /// Finalizes the configuration.
    ///
    /// Returns `None` when any mandatory field (`epochs`, `batch_size`,
    /// `device`) was never set; every optional field falls back to the
    /// default documented on its setter.
    pub fn build(self) -> Option<TrainingConfig<B>> {
        Some(TrainingConfig {
            metric: self.metric.unwrap_or(Metric::Euclidean),
            epochs: self.epochs?,
            batch_size: self.batch_size?,
            learning_rate: self.learning_rate.unwrap_or(0.001),
            device: self.device?,
            beta1: self.beta1.unwrap_or(0.9),
            beta2: self.beta2.unwrap_or(0.999),
            penalty: self.penalty.unwrap_or(5e-5),
            verbose: self.verbose.unwrap_or(false),
            patience: self.patience,
            loss_reduction: self.loss_reduction.unwrap_or(LossReduction::Sum),
            k_neighbors: self.k_neighbors.unwrap_or(15),
            min_desired_loss: self.min_desired_loss,
            timeout: self.timeout,
        })
    }
}
/// Dispatches to the distance helper selected by `config.metric`, returning
/// a rank-1 tensor of distances for `data` (`euclidean` or `euclidean_knn`
/// with `config.k_neighbors`).
fn get_distance<B: AutodiffBackend>(
    data: Tensor<B, 2>,
    config: &TrainingConfig<B>,
) -> Tensor<B, 1> {
    match config.metric {
        Metric::EuclideanKNN => euclidean_knn(data, config.k_neighbors),
        Metric::Euclidean => euclidean(data),
    }
}
pub fn train<B: AutodiffBackend>(
mut model: UMAPModel<B>,
num_samples: usize, num_features: usize, data: Vec<f64>, config: &TrainingConfig<B>, ) -> (UMAPModel<B>, Vec<f64>) {
if config.metric == Metric::EuclideanKNN && config.k_neighbors > num_samples {
panic!("When using Euclidean KNN distance, k_neighbors should be smaller than number of samples!")
}
let tensor_data =
convert_vector_to_tensor::<B>(data.clone(), num_samples, num_features, &config.device);
let config_optimizer = AdamConfig::new()
.with_weight_decay(Some(WeightDecayConfig::new(config.penalty)))
.with_beta_1(config.beta1 as f32)
.with_beta_2(config.beta2 as f32);
let mut optim = config_optimizer.init();
let start_time = Instant::now();
let pb = match config.verbose {
true => {
let pb = ProgressBar::new(config.epochs as u64);
pb.set_style(
ProgressStyle::default_bar()
.template("{bar:40} {pos}/{len} Epochs | {msg}")
.unwrap()
.progress_chars("=>-"),
);
Some(pb)
}
false => None,
};
let global_distances = get_distance(tensor_data.clone(), config);
let mut epoch = 0;
let mut losses: Vec<f64> = vec![];
let mut best_loss = f64::INFINITY;
let mut epochs_without_improvement = 0;
let mse_loss = MseLoss::new();
loop {
let local = model.forward(tensor_data.clone());
let local_distances = get_distance(local, config);
let loss = mse_loss.forward(
global_distances.clone(),
local_distances,
burn::nn::loss::Reduction::Sum,
);
let grads = loss.backward();
let current_loss = loss.clone().into_scalar().to_f64();
losses.push(current_loss);
let grads = GradientsParams::from_grads(grads, &model);
model = optim.step(config.learning_rate, model, grads);
let elapsed = start_time.elapsed();
if let Some(ref pbb) = pb {
pbb.set_message(format!(
"Elapsed: {} | Loss: {:.3} | Best loss: {:.3}",
format_duration(elapsed),
current_loss,
best_loss,
));
pbb.inc(1);
}
if current_loss <= best_loss {
best_loss = current_loss;
epochs_without_improvement = 0;
} else {
epochs_without_improvement += 1;
}
if let Some(patience) = config.patience {
if epochs_without_improvement >= patience && epoch >= config.epochs {
break; }
} else if epoch >= config.epochs {
break; }
if let Some(min_desired_loss) = config.min_desired_loss {
if current_loss < min_desired_loss {
break;
}
}
if let Some(timeout) = config.timeout {
if elapsed >= Duration::new(timeout, 0) {
println!(
"Training stopped due to timeout after {:.2?} seconds.",
elapsed
);
break; }
}
epoch += 1; }
if config.verbose {
plot_loss(losses.clone(), "losses.png").unwrap();
}
if let Some(pb) = pb {
pb.finish();
}
(model, losses)
}