use burn::tensor::{backend::AutodiffBackend, Tensor};
/// Reduces a `[n_samples, n_features]` tensor to one squared Euclidean
/// distance per sample.
///
/// The full `[n_samples, n_samples, 1]` tensor of squared differences (summed
/// over the feature axis) is built by broadcasting, but only index `0` of the
/// second axis is kept by the final slice. Entry `i` of the result is
/// therefore `sum_k (x[0, k] - x[i, k])^2`, i.e. the *squared* distance
/// between sample `i` and sample `0`; no square root is taken.
///
/// NOTE(review): despite the name, this does not return every pairwise
/// distance. Confirm that "squared distance to the first sample" is what
/// callers expect before changing it.
///
/// # Returns
///
/// A rank-1 tensor with `n_samples` elements.
pub fn pairwise_distance<B: AutodiffBackend>(x: Tensor<B, 2>) -> Tensor<B, 1> {
    let [n_samples, _] = x.dims();

    // [1, n, d] and [n, 1, d]; subtracting broadcasts to [n, n, d].
    // `x` is moved into the second operand, so only one clone is needed.
    let as_rows = x.clone().unsqueeze::<3>();
    let as_cols = x.unsqueeze_dim::<3>(1);

    // `sum_dim` returns a tensor of the same rank, so this is still rank 3.
    let squared = (as_rows - as_cols).powi_scalar(2).sum_dim(2);

    // NOTE(review): on a rank-3 tensor `triu` presumably masks the trailing
    // two axes, which would leave the `[.., 0, 0]` entries selected below
    // untouched. Kept as-is to preserve the original computation — confirm
    // whether it is needed at all.
    let masked = squared.triu(0);

    // Keep index 0 of the second axis for every sample, then flatten to [n].
    masked.slice([0..n_samples, 0..1]).reshape([n_samples])
}
/// Sum-of-squares loss between a precomputed distance vector and the distance
/// vector derived from `local`.
///
/// `local` is passed through [`pairwise_distance`]; both that result and
/// `global_distances` are clamped to `[0, 1e6]`, subtracted element-wise,
/// squared, and summed over all elements.
///
/// # Arguments
///
/// * `global_distances` - rank-1 distance vector. It is subtracted
///   element-wise from the output of `pairwise_distance(local)`, so it
///   presumably needs the same length — verify against callers.
/// * `local` - rank-2 tensor forwarded unchanged to `pairwise_distance`.
///
/// # Returns
///
/// The summed squared difference as a rank-1 tensor.
pub fn umap_loss<B: AutodiffBackend>(
    global_distances: Tensor<B, 1>,
    local: Tensor<B, 2>,
) -> Tensor<B, 1> {
    // Upper clamp bound shared by both distance vectors.
    const MAX_DISTANCE: f64 = 1e6;

    let embedded = pairwise_distance(local);

    let reference = global_distances.clamp(0.0, MAX_DISTANCE);
    let embedded = embedded.clamp(0.0, MAX_DISTANCE);

    let residual = reference - embedded;
    residual.powi_scalar(2).sum()
}