mod config;
mod get_distance_by_metric;
mod train_sparse;
use crate::{
backend::AutodiffBackend,
chart::{self, plot_loss, ChartConfigBuilder},
format_duration,
model::UMAPModel,
normalize_data,
utils::convert_vector_to_tensor,
};
use burn::{
module::{AutodiffModule, Module},
optim::{decay::WeightDecayConfig, AdamConfig, GradientsParams, Optimizer},
record::{BinFileRecorder, FullPrecisionSettings},
tensor::{cast::ToElement, Device, Int, IndexingUpdateOp, Tensor, TensorData},
};
pub use config::*;
pub use train_sparse::{train_sparse, EpochProgress};
use crossbeam_channel::Receiver;
use get_distance_by_metric::*;
use indicatif::{ProgressBar, ProgressStyle};
use num::{Float, FromPrimitive};
use std::path::PathBuf;
use std::time::Duration;
use std::{thread, time::Instant};
/// Trains a [`UMAPModel`] with a mini-batch UMAP-style attraction/repulsion loss.
///
/// The input `data` (row-major, `num_samples` × `num_features`) is normalized in
/// place, a batch-local k-NN graph is precomputed per mini-batch, and the model
/// is optimized with Adam for up to `config.epochs` epochs. Early stopping is
/// supported via `config.timeout` (wall clock), `config.patience` (epochs
/// without improvement), `config.min_desired_loss`, and the `exit_rx` channel
/// (cooperative cancellation). The best checkpoint (lowest mean epoch loss) is
/// saved to the OS temp directory and reloaded before returning.
///
/// Returns `(best_model, per_epoch_losses, best_loss)`.
///
/// # Panics
/// * if `config.batch_size <= 1` or `config.k_neighbors >= batch_size`;
/// * if the best checkpoint cannot be saved or reloaded.
pub fn train<B: AutodiffBackend, F: Float>(
    name: &str,
    mut model: UMAPModel<B>,
    num_samples: usize,
    num_features: usize,
    mut data: Vec<F>,
    config: &TrainingConfig,
    device: Device<B>,
    exit_rx: Receiver<()>,
    labels: Option<Vec<String>>,
) -> (UMAPModel<B>, Vec<F>, F)
where
    F: FromPrimitive + Send + Sync + burn::tensor::Element,
{
    // Where loss / embedding plots go; falls back to the OS temp dir.
    let figures_dir: PathBuf = config
        .figures_dir
        .clone()
        .unwrap_or_else(std::env::temp_dir);
    let can_plot = match std::fs::create_dir_all(&figures_dir) {
        Ok(()) => true,
        Err(e) => {
            eprintln!(
                "[fast-umap] Warning: could not create figures directory '{}': {}. \
                 Plot output will be disabled for this run. \
                 Set `figures_dir` in your config to a writable path.",
                figures_dir.display(),
                e
            );
            false
        }
    };

    let recorder = BinFileRecorder::<FullPrecisionSettings>::new();
    // Checkpoint in the OS temp dir (portable; a hard-coded "/tmp" would break
    // on Windows and disagree with the figures_dir fallback above).
    let model_path = std::env::temp_dir().join(format!("{name}.bin"));

    let batch_size = config.batch_size;
    let k = config.k_neighbors;
    let repulsion_strength = config.repulsion_strength;
    let kernel_a = config.kernel_a;
    let kernel_b = config.kernel_b;
    let verbose = config.verbose;

    if verbose {
        println!("[fast-umap] Configuration:");
        println!("[fast-umap] samples={num_samples} features={num_features} k_neighbors={k}");
        println!(
            "[fast-umap] epochs={} lr={:.0e} repulsion_strength={repulsion_strength}",
            config.epochs, config.learning_rate
        );
        println!("[fast-umap] kernel: a={kernel_a:.4} b={kernel_b:.4} (q = 1 / (1 + a·d^(2b)))");
    }

    if batch_size <= 1 {
        panic!("batch_size must be > 1");
    }
    // Each point needs k distinct neighbors within its own batch.
    if k >= batch_size {
        panic!(
            "k_neighbors ({k}) must be < batch_size ({batch_size}). \
             Increase batch_size or decrease k_neighbors."
        );
    }

    normalize_data(&mut data, num_samples, num_features);

    if verbose {
        println!(
            "[fast-umap] Computing batch-local k-NN graph (k={k}) …",
        );
    }

    // Per-batch precomputation: input tensor, dense k-NN indicator matrix, and
    // the mask of non-neighbor pairs used by the repulsion term.
    let mut batches_start: Vec<usize> = Vec::new();
    let mut tensor_batches: Vec<Tensor<B, 2>> = Vec::new();
    let mut batch_artefacts: Vec<(Tensor<B, 2>, usize, Tensor<B, 2>)> = Vec::new();
    for batch_start in (0..num_samples).step_by(batch_size) {
        let batch_end = (batch_start + batch_size).min(num_samples);
        let bs = batch_end - batch_start;
        // Copy this batch's rows into one contiguous buffer.
        let mut batch_vec: Vec<F> = Vec::with_capacity(bs * num_features);
        for i in batch_start..batch_end {
            let s = i * num_features;
            batch_vec.extend_from_slice(&data[s..s + num_features]);
        }
        let tensor_batch: Tensor<B, 2> =
            convert_vector_to_tensor(batch_vec, bs, num_features, &device);
        // k-NN indices are computed on the CPU from high-dimensional pairwise
        // distances (pulled back from the device as f32).
        let hd_pairwise = pairwise_distances(tensor_batch.clone());
        let flat: Vec<f32> = hd_pairwise.to_data().to_vec::<f32>().unwrap();
        let (idx_flat, _) = knn_from_pairwise_cpu(&flat, bs, k);
        // Dense bs×bs indicator: 1.0 where j is one of i's k nearest neighbors.
        let mut knn_ind_flat = vec![0.0f32; bs * bs];
        for local_i in 0..bs {
            for j in 0..k {
                let local_j = idx_flat[local_i * k + j] as usize;
                knn_ind_flat[local_i * bs + local_j] = 1.0;
            }
        }
        let knn_indicator: Tensor<B, 2> =
            Tensor::from_data(TensorData::new(knn_ind_flat, [bs, bs]), &device);
        let within_knn_count = bs * k;
        // Identity matrix built via scatter, so self-pairs (the diagonal) can
        // be excluded from the repulsion mask.
        let diag_idx: Tensor<B, 1, Int> = Tensor::arange(0i64..(bs as i64), &device);
        let eye: Tensor<B, 2> = Tensor::<B, 2>::zeros([bs, bs], &device).scatter(
            1,
            diag_idx.reshape([bs, 1]),
            Tensor::ones([bs, 1], &device),
            IndexingUpdateOp::Add,
        );
        // 1.0 exactly for pairs that are neither k-NN neighbors nor self-pairs.
        let non_neighbor_mask = (Tensor::ones([bs, bs], &device) - knn_indicator.clone() - eye)
            .clamp_min(0.0f32);
        batches_start.push(batch_start);
        tensor_batches.push(tensor_batch);
        batch_artefacts.push((knn_indicator, within_knn_count, non_neighbor_mask));
    }
    let num_batches = batches_start.len();
    if verbose {
        println!(
            "[fast-umap] Precomputation done ({num_batches} batch(es)). Training started …"
        );
    }
    // All batches concatenated once; slices are taken per step below.
    let tensor_batches_all = Tensor::<B, 2>::cat(tensor_batches, 0);

    let config_optimizer = AdamConfig::new()
        .with_weight_decay(Some(WeightDecayConfig::new(config.penalty)))
        .with_beta_1(config.beta1 as f32)
        .with_beta_2(config.beta2 as f32);
    let mut optim = config_optimizer.init();

    let start_time = Instant::now();
    let pb = if config.verbose {
        let pb = ProgressBar::new(config.epochs as u64);
        pb.set_style(
            ProgressStyle::default_bar()
                .template("{bar:40} | {msg}")
                .unwrap()
                .progress_chars("=>-"),
        );
        Some(pb)
    } else {
        None
    };

    let mut epoch = 0usize;
    let mut losses: Vec<F> = vec![];
    let mut best_loss = F::infinity();
    let mut epochs_without_improvement = 0i32;
    // True once at least one checkpoint has been written; guards the reload at
    // the end so an early abort doesn't panic on (or load) a missing/stale file.
    let mut checkpoint_saved = false;

    'main: loop {
        let mut epoch_loss_sum = F::zero();
        let mut num_batches_seen = 0usize;
        for batch_idx in 0..num_batches {
            // Cooperative cancellation requested by the caller.
            if exit_rx.try_recv().is_ok() {
                break 'main;
            }
            let batch_start = batches_start[batch_idx];
            let batch_end = (batch_start + batch_size).min(num_samples);
            let bs = batch_end - batch_start;
            let tensor_batch = tensor_batches_all
                .clone()
                .slice([batch_start..batch_end, 0..num_features]);
            let (knn_indicator, within_knn_count, non_neighbor_mask) =
                batch_artefacts[batch_idx].clone();
            let local = model.forward(tensor_batch);
            let local_pairwise = pairwise_distances(local);
            let local_sq = local_pairwise.clone() * local_pairwise;
            // UMAP low-dimensional kernel: q = 1 / (1 + a·d^(2b)), with
            // d^(2b) = (d²)^b; clamps keep powf/log numerically safe.
            let dist_pow = local_sq.clone().clamp_min(1e-8f32).powf_scalar(kernel_b);
            let a_dpow = dist_pow.clone() * kernel_a;
            let q = (a_dpow.clone() + 1.0f32).recip();
            let log_q = q.clone().clamp_min(1e-6f32).log();
            // Attraction: −log q averaged over the k-NN pairs.
            let attraction = (log_q.neg() * knn_indicator).sum() / within_knn_count as f32;
            // Repulsion: −log(1 − q) over non-neighbor pairs, where
            // 1 − q = a·d^(2b) / (1 + a·d^(2b)).
            let one_minus_q = a_dpow.clone() / (a_dpow + 1.0f32);
            let repulsion_per_pair = one_minus_q.clamp_min(1e-6f32).log().neg();
            let num_non_neighbors = (bs * (bs - 1)).saturating_sub(within_knn_count) as f32;
            let repulsion =
                (repulsion_per_pair * non_neighbor_mask).sum() / num_non_neighbors.max(1.0);
            let loss = attraction + repulsion_strength * repulsion;
            let current_loss = F::from(loss.clone().into_scalar().to_f64()).unwrap();
            // Divergence guard: stop instead of propagating NaN/Inf gradients.
            if current_loss.is_nan() || current_loss.is_infinite() {
                eprintln!(
                    "[fast-umap] WARNING: loss is {current_loss:.4} at epoch {epoch}, \
                     batch {batch_idx} — stopping early."
                );
                break 'main;
            }
            epoch_loss_sum = epoch_loss_sum + current_loss;
            num_batches_seen += 1;
            let grads = loss.backward();
            let grads = GradientsParams::from_grads(grads, &model);
            model = optim.step(config.learning_rate, model, grads);
        }
        let epoch_loss = if num_batches_seen > 0 {
            epoch_loss_sum / F::from_usize(num_batches_seen).unwrap()
        } else {
            F::infinity()
        };
        losses.push(epoch_loss);
        let elapsed = start_time.elapsed();
        if let Some(pb) = &pb {
            pb.inc(1);
            pb.set_message(format!(
                "Elapsed: {} | Epoch: {epoch}/{} | Loss: {epoch_loss:.6} | Best: {best_loss:.6}",
                format_duration(elapsed),
                config.epochs,
            ));
        }
        // Wall-clock timeout.
        if let Some(timeout) = config.timeout {
            if elapsed >= Duration::from_secs(timeout) {
                break;
            }
        }
        // Checkpoint whenever this epoch ties or improves on the best loss.
        if epoch_loss <= best_loss {
            best_loss = epoch_loss;
            epochs_without_improvement = 0;
            model
                .clone()
                .save_file(model_path.clone(), &recorder)
                .expect("Could not save model checkpoint");
            checkpoint_saved = true;
        } else {
            epochs_without_improvement += 1;
        }
        // Early stopping: patience exhausted or target loss reached.
        if let Some(patience) = config.patience {
            if epochs_without_improvement >= patience {
                break;
            }
        }
        if let Some(min_desired_loss) = config.min_desired_loss {
            if epoch_loss < F::from(min_desired_loss).unwrap() {
                break;
            }
        }
        // Count the completed epoch *before* checking the budget so exactly
        // `config.epochs` epochs run (checking the 0-based counter first ran
        // one epoch too many and overflowed the progress bar).
        epoch += 1;
        if epoch >= config.epochs {
            break;
        }
        #[allow(unused_variables)]
        let loss_plot_path = figures_dir
            .join(format!("losses_{name}.png"))
            .to_string_lossy()
            .into_owned();
        // Periodic embedding + loss snapshots, rendered on a background thread
        // so plotting never blocks training (feature-gated debug output).
        #[cfg(feature = "verbose")]
        if can_plot {
            const STEP: usize = 100;
            if epoch > 0 && epoch % STEP == 0 {
                let losses_snap = losses.clone();
                let model_snap = model.valid();
                let tensor_data =
                    convert_vector_to_tensor(data.clone(), num_samples, num_features, &device);
                let embeddings = model_snap.forward(tensor_data);
                let lpath = loss_plot_path.clone();
                let caption = format!("{name}_{epoch}");
                let fig_path = figures_dir
                    .join(format!("{name}_{epoch}.png"))
                    .to_string_lossy()
                    .into_owned();
                let snap_labels = labels.clone();
                thread::spawn(move || {
                    let chart_config = ChartConfigBuilder::default()
                        .caption(&caption)
                        .path(&fig_path)
                        .build();
                    chart::chart_tensor(embeddings, snap_labels, Some(chart_config));
                    if losses_snap.len() > STEP {
                        plot_loss(losses_snap[STEP..].to_vec(), &lpath).unwrap();
                    }
                });
            }
        }
        if can_plot && config.verbose {
            plot_loss(losses.clone(), &loss_plot_path).unwrap();
        }
    }

    #[cfg(feature = "verbose")]
    if can_plot {
        let path = figures_dir
            .join(format!("losses_{name}.png"))
            .to_string_lossy()
            .into_owned();
        plot_loss(losses.clone(), &path).unwrap();
    }
    if let Some(pb) = pb {
        pb.finish();
    }
    // Reload the best checkpoint — but only if this run actually wrote one.
    // An abort before the first checkpoint would otherwise panic on a missing
    // file (or silently load a stale checkpoint from a previous run).
    if checkpoint_saved {
        model = model
            .load_file(model_path, &recorder, &device)
            .expect("Could not load best model checkpoint");
    }
    let total_elapsed = start_time.elapsed();
    if verbose {
        println!(
            "[fast-umap] Training complete — {epoch} epochs in {}, best loss: {best_loss:.6}",
            format_duration(total_elapsed),
        );
    }
    (model, losses, best_loss)
}