use crate::{
backend::AutodiffBackend,
format_duration,
model::UMAPModel,
normalize_data,
utils::convert_vector_to_tensor,
};
#[cfg(feature = "plotters")]
use crate::chart::{self, plot_loss, ChartConfigBuilder};
use burn::{
module::{AutodiffModule, Module},
optim::{decay::WeightDecayConfig, AdamConfig, GradientsParams, Optimizer},
tensor::{cast::ToElement, Device, Int, Tensor, TensorData},
};
use super::config::*;
use super::get_distance_by_metric::*;
use crossbeam_channel::Receiver;
use indicatif::{ProgressBar, ProgressStyle};
use num::{Float, FromPrimitive, ToPrimitive};
use rand::seq::SliceRandom;
use rand::Rng;
use std::path::PathBuf;
use std::time::Duration;
use std::{thread, time::Instant};
/// Number of pre-shuffled edge batches built up front and cycled per epoch.
const EDGE_BATCH_COUNT: usize = 16;
/// Epochs between device-to-host loss readbacks (each readback syncs the device).
const LOSS_READBACK_INTERVAL: usize = 5;
/// Epochs between intermediate loss-curve renders.
#[cfg(feature = "plotters")]
const PLOT_INTERVAL: usize = 25;
/// Positive-edge budget per epoch; larger k-NN edge sets are subsampled.
const MAX_POS_EDGES_PER_EPOCH: usize = 50_000;
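/// Training progress for one epoch, passed to the optional `on_progress` callback.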
#[derive(Debug, Clone)]
pub struct EpochProgress {
pub epoch: usize,
pub total_epochs: usize,
pub loss: f64,
pub best_loss: f64,
pub elapsed_secs: f64,
pub epoch_ms: f64,
}
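/// Trains a UMAP model with sparse edge sampling: attraction over k-NN
/// (positive) edges and repulsion over uniformly sampled negative edges.
///
/// Returns the model restored to its best-loss parameters, the per-epoch loss
/// history, and the best loss observed. Training ends at the epoch budget or
/// earlier on timeout, patience exhaustion, reaching `min_desired_loss`, a
/// non-finite loss, or a message on `exit_rx`.
///
/// # Panics
/// Panics if `num_samples <= config.k_neighbors`.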
pub fn train_sparse<B: AutodiffBackend, F: Float>(
name: &str,
mut model: UMAPModel<B>,
num_samples: usize,
num_features: usize,
mut data: Vec<F>,
config: &TrainingConfig,
device: Device<B>,
exit_rx: Receiver<()>,
labels: Option<Vec<String>>,
on_progress: Option<Box<dyn Fn(EpochProgress) + Send>>,
) -> (UMAPModel<B>, Vec<F>, F)
where
F: FromPrimitive + Send + Sync + burn::tensor::Element,
{
#[cfg(feature = "plotters")]
let figures_dir: PathBuf = config
.figures_dir
.clone()
.unwrap_or_else(std::env::temp_dir);
#[cfg(feature = "plotters")]
let can_plot = match std::fs::create_dir_all(&figures_dir) {
Ok(()) => true,
Err(e) => {
eprintln!(
"[fast-umap] Warning: could not create figures directory '{}': {}. \
Plot output will be disabled for this run. \
Set `figures_dir` in your config to a writable path.",
figures_dir.display(),
e
);
false
}
};
let k = config.k_neighbors;
let repulsion_strength = config.repulsion_strength;
let kernel_a = config.kernel_a;
let kernel_b = config.kernel_b;
let neg_rate = config.neg_sample_rate;
let verbose = config.verbose;
if num_samples <= k {
panic!("num_samples ({num_samples}) must be > k_neighbors ({k}).");
}
if verbose {
println!("[fast-umap] Configuration:");
println!("[fast-umap] samples={num_samples} features={num_features} k_neighbors={k}");
println!("[fast-umap] epochs={} lr={:.0e} repulsion_strength={repulsion_strength}",
config.epochs, config.learning_rate);
println!("[fast-umap] kernel: a={kernel_a:.4} b={kernel_b:.4} (q = 1 / (1 + a·d^(2b)))");
}
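    // Normalize the raw input in place before any distances are computed.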
normalize_data(&mut data, num_samples, num_features);
if verbose {
println!("[fast-umap] Computing global k-NN graph (k={k}) …");
}
let knn_start = Instant::now();
let all_data_tensor: Tensor<B, 2> =
convert_vector_to_tensor(data.clone(), num_samples, num_features, &device);
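    // One-off k-NN construction: compute the full pairwise-distance matrix on
    // the device, copy it to the host, and take each row's k nearest
    // neighbours on the CPU. The dense n×n matrix makes this step O(n²) in memory.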
    let knn_indices: Vec<i32> = {
        let hd_pairwise = pairwise_distances(all_data_tensor.clone());
        let flat: Vec<f32> = hd_pairwise.to_data().to_vec::<f32>().unwrap();
        let (idx, _dist) = knn_from_pairwise_cpu(&flat, num_samples, k);
        idx
    };
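    // Flatten the k-NN graph into directed positive edges (point, neighbour).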
let n_all_edges = num_samples * k;
let mut all_pos_edges: Vec<(i64, i64)> = Vec::with_capacity(n_all_edges);
for i in 0..num_samples {
for j in 0..k {
all_pos_edges.push((i as i64, knn_indices[i * k + j] as i64));
}
}
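    // Cap positive edges per epoch and size the negative set from the
    // configured rate, never dropping below `num_samples` negatives.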
let n_pos = n_all_edges.min(MAX_POS_EDGES_PER_EPOCH);
let subsampling = n_pos < n_all_edges;
let n_neg = (n_pos * neg_rate).max(num_samples);
let n_total = n_pos + n_neg;
let mut rng = rand::rng();
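    // Pre-build a fixed pool of shuffled edge batches as device-resident index
    // tensors; the training loop cycles through them via
    // `epoch % EDGE_BATCH_COUNT`, so no host-side shuffling or index upload
    // happens inside the hot loop. Both positive and negative samples are
    // baked in, so each batch recurs every EDGE_BATCH_COUNT epochs.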
let fused_batches: Vec<(Tensor<B, 1, Int>, Tensor<B, 1, Int>, usize, usize)> =
(0..EDGE_BATCH_COUNT)
.map(|_| {
let pos_sample: Vec<&(i64, i64)> = if subsampling {
let mut indices: Vec<usize> = (0..n_all_edges).collect();
indices.shuffle(&mut rng);
indices.iter().take(n_pos).map(|&i| &all_pos_edges[i]).collect()
} else {
all_pos_edges.iter().collect()
};
let actual_n_pos = pos_sample.len();
let mut all_head: Vec<i64> = Vec::with_capacity(n_total);
let mut all_tail: Vec<i64> = Vec::with_capacity(n_total);
for &(h, t) in &pos_sample {
all_head.push(*h);
all_tail.push(*t);
}
let actual_n_neg = n_neg;
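            // Uniform negative sampling that never produces a self-pair: draw
            // j from 0..num_samples-1 and shift it past i, which maps the
            // draw uniformly onto {0..num_samples} \ {i}.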
for _ in 0..actual_n_neg {
let i = rng.random_range(0..num_samples);
let mut j = rng.random_range(0..num_samples - 1);
if j >= i {
j += 1;
}
all_head.push(i as i64);
all_tail.push(j as i64);
}
let total = actual_n_pos + actual_n_neg;
let h = Tensor::from_data(
TensorData::new(all_head, [total]),
&device,
);
let t = Tensor::from_data(
TensorData::new(all_tail, [total]),
&device,
);
(h, t, actual_n_pos, total)
})
.collect();
let knn_elapsed = knn_start.elapsed();
if verbose {
println!(
"[fast-umap] k-NN done in {:.2}s — {n_all_edges} total edges{}",
knn_elapsed.as_secs_f64(),
if subsampling {
format!(
", subsampling {n_pos} positive + {n_neg} negative per epoch \
(neg_rate={neg_rate}, {EDGE_BATCH_COUNT} pre-batched shuffles)"
)
} else {
format!(
", {n_pos} positive + {n_neg} negative edges per epoch \
(neg_rate={neg_rate})"
)
}
);
println!("[fast-umap] Training started …");
}
let config_optimizer = AdamConfig::new()
.with_weight_decay(Some(WeightDecayConfig::new(config.penalty)))
.with_beta_1(config.beta1 as f32)
.with_beta_2(config.beta2 as f32);
let mut optim = config_optimizer.init();
let start_time = Instant::now();
let pb = if verbose {
let pb = ProgressBar::new(config.epochs as u64);
pb.set_style(
ProgressStyle::default_bar()
.template("{bar:40} | {msg}")
.unwrap()
.progress_chars("=>-"),
);
Some(pb)
} else {
None
};
let mut epoch = 0usize;
let mut losses: Vec<F> = vec![];
let mut best_loss = F::infinity();
let mut last_read_loss = F::infinity();
let mut epochs_without_improvement = 0i32;
let mut best_record = model.clone().into_record();
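    // Main training loop: runs until the epoch budget is spent or a stop
    // condition fires (interrupt, non-finite loss, timeout, patience, or the
    // desired loss target).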
'main: loop {
if exit_rx.try_recv().is_ok() {
if verbose {
eprintln!("[fast-umap] Interrupted — restoring best model (epoch {epoch}, loss {best_loss:.6})");
}
break 'main;
}
        // Forward pass over the full dataset, then gather embeddings at the
        // head/tail indices of this epoch's pre-built edge batch.
        let embeddings = model.forward(all_data_tensor.clone());
        let (fused_h, fused_t, batch_n_pos, batch_n_total) =
            &fused_batches[epoch % EDGE_BATCH_COUNT];
        let batch_n_pos = *batch_n_pos;
        let batch_n_total = *batch_n_total;
        let all_head_emb = embeddings.clone().select(0, fused_h.clone());
        let all_tail_emb = embeddings.select(0, fused_t.clone());
        // Squared Euclidean distance per edge; positive edges occupy the first
        // `batch_n_pos` rows, negatives the rest.
        let diff = all_head_emb - all_tail_emb;
        let dist_sq = (diff.clone() * diff).sum_dim(1);
        let dist_sq_pos = dist_sq.clone().slice([0..batch_n_pos]);
        let dist_sq_neg = dist_sq.slice([batch_n_pos..batch_n_total]);
        // UMAP low-dimensional kernel: q = 1 / (1 + a·d^(2b)). `dist_sq` is
        // already d², so raising it to the power b yields d^(2b) directly.
        // Attraction pulls positive pairs together via -mean(log q).
        let dist_pow_pos = dist_sq_pos.clamp_min(1e-8f32).powf_scalar(kernel_b);
        let q_pos = (dist_pow_pos * kernel_a + 1.0f32).recip();
        let attraction = q_pos.clamp_min(1e-6f32).log().neg().mean();
        // Repulsion pushes negative pairs apart via -mean(log(1 - q)), with
        // 1 - q = a·d^(2b) / (1 + a·d^(2b)); the clamps guard against log(0).
        let dist_pow_neg = dist_sq_neg.clamp_min(1e-8f32).powf_scalar(kernel_b);
        let a_dpow_neg = dist_pow_neg * kernel_a;
        let one_minus_q_neg = a_dpow_neg.clone() / (a_dpow_neg + 1.0f32);
        let repulsion = one_minus_q_neg.clamp_min(1e-6f32).log().neg().mean();
        let loss = attraction + repulsion * repulsion_strength;
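        // `into_scalar()` forces a device sync, so the loss is only read back
        // every LOSS_READBACK_INTERVAL epochs; between readbacks the last
        // value is reused for logging and bookkeeping.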
        let should_read =
            epoch % LOSS_READBACK_INTERVAL == 0 || epoch >= config.epochs;
let current_loss = if should_read {
let v = F::from(loss.clone().into_scalar().to_f64()).unwrap();
last_read_loss = v;
v
} else {
last_read_loss
};
if should_read && (current_loss.is_nan() || current_loss.is_infinite()) {
eprintln!(
"[fast-umap] WARNING: loss became {current_loss:.4} at epoch {epoch} — stopping early."
);
break 'main;
}
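        // Backpropagate and take one Adam step on the full model.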
let grads = loss.backward();
let grads = GradientsParams::from_grads(grads, &model);
model = optim.step(config.learning_rate, model, grads);
losses.push(current_loss);
let elapsed = start_time.elapsed();
let epoch_ms = if epoch > 0 {
elapsed.as_secs_f64() * 1000.0 / epoch as f64
} else {
0.0
};
if let Some(pb) = &pb {
pb.inc(1);
if should_read {
pb.set_message(format!(
"Elapsed: {} | Epoch: {epoch}/{} | Loss: {current_loss:.6} | Best: {best_loss:.6}",
format_duration(elapsed),
config.epochs,
));
}
}
if should_read {
if let Some(ref cb) = on_progress {
cb(EpochProgress {
epoch,
total_epochs: config.epochs,
                    loss: ToPrimitive::to_f64(&current_loss).unwrap_or(0.0),
best_loss: ToPrimitive::to_f64(&best_loss).unwrap_or(0.0),
elapsed_secs: elapsed.as_secs_f64(),
epoch_ms,
});
}
}
if let Some(timeout) = config.timeout {
if elapsed >= Duration::from_secs(timeout) {
if verbose {
println!(
"[fast-umap] Timeout ({timeout}s) reached at epoch {epoch} — stopping."
);
}
break;
}
}
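        // Best-model snapshots and patience tick only on readback epochs, so
        // the no-improvement counter advances by the readback interval.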
if should_read && current_loss <= best_loss {
best_loss = current_loss;
epochs_without_improvement = 0;
best_record = model.clone().into_record();
} else if should_read {
epochs_without_improvement += LOSS_READBACK_INTERVAL as i32;
}
if let Some(patience) = config.patience {
if epochs_without_improvement >= patience {
if verbose {
println!(
"[fast-umap] Early stopping — no improvement for {patience} epochs (best loss: {best_loss:.6})."
);
}
break;
}
}
if let Some(min_desired_loss) = config.min_desired_loss {
if should_read && current_loss < F::from(min_desired_loss).unwrap() {
if verbose {
println!(
"[fast-umap] Desired loss {min_desired_loss:.6} reached at epoch {epoch} (loss: {current_loss:.6})."
);
}
break;
}
}
if epoch >= config.epochs {
break;
}
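        // Optional diagnostics: embedding snapshots render on a background
        // thread so plotting never blocks training; the periodic loss curve
        // below is rendered inline.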
#[cfg(feature = "plotters")]
let loss_plot_path = figures_dir
.join(format!("losses_{name}.png"))
.to_string_lossy()
.into_owned();
#[cfg(all(feature = "plotters", feature = "verbose"))]
if can_plot {
const STEP: usize = 100;
if epoch > 0 && epoch % STEP == 0 {
let losses_snap = losses.clone();
let model_snap = model.valid();
let tensor_data = convert_vector_to_tensor(
data.clone(),
num_samples,
num_features,
&device,
);
let embeddings = model_snap.forward(tensor_data);
let lpath = loss_plot_path.clone();
let caption = format!("{name}_{epoch}");
let fig_path = figures_dir
.join(format!("{name}_{epoch}.png"))
.to_string_lossy()
.into_owned();
let snap_labels = labels.clone();
thread::spawn(move || {
let chart_config = ChartConfigBuilder::default()
.caption(&caption)
.path(&fig_path)
.build();
chart::chart_tensor(embeddings, snap_labels, Some(chart_config));
if losses_snap.len() > STEP {
plot_loss(losses_snap[STEP..].to_vec(), &lpath).unwrap();
}
});
}
}
#[cfg(feature = "plotters")]
if can_plot && verbose && epoch % PLOT_INTERVAL == 0 {
plot_loss(losses.clone(), &loss_plot_path).unwrap();
}
if config.cooldown_ms > 0 {
thread::sleep(Duration::from_millis(config.cooldown_ms));
}
epoch += 1;
}
    // Render the final loss curve once; verbosity may come from either the
    // `verbose` cargo feature or the runtime flag.
    #[cfg(feature = "plotters")]
    if can_plot && (verbose || cfg!(feature = "verbose")) {
        let path = figures_dir
            .join(format!("losses_{name}.png"))
            .to_string_lossy()
            .into_owned();
        plot_loss(losses.clone(), &path).unwrap();
    }
if let Some(pb) = pb {
pb.finish();
}
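    // Roll the model back to the best parameters observed during training.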
model = model.load_record(best_record);
let total_elapsed = start_time.elapsed();
if verbose {
println!(
"[fast-umap] Training complete — {epoch} epochs in {}, best loss: {best_loss:.6}",
format_duration(total_elapsed),
);
}
(model, losses, best_loss)
}