use rand::{
rngs::SmallRng,
{Rng, SeedableRng},
};
use rayon::prelude::*;
use crate::prelude::*;
use crate::training::*;
/// Hyper-parameters shared by the UMAP embedding optimisers in this module.
///
/// `a` and `b` are the coefficients of the low-dimensional similarity curve
/// `1 / (1 + a * x^(2b))`; the remaining fields control the optimisation
/// schedule and, for the Adam variants, the moment estimates.
#[derive(Clone, Debug)]
pub struct UmapOptimParams<T> {
/// Scale coefficient `a` of the curve `1 / (1 + a * x^(2b))`.
pub a: T,
/// Exponent coefficient `b` of the curve `1 / (1 + a * x^(2b))`.
pub b: T,
/// Initial learning rate; the optimisers decay it linearly over the epochs.
pub lr: T,
/// Weight applied to the gradient of negative (repulsive) samples.
pub gamma: T,
/// Number of optimisation epochs.
pub n_epochs: usize,
/// Divisor turning an edge's sampling period into its negative-sampling
/// period, i.e. negative samples drawn per positive sample.
pub neg_sample_rate: usize,
/// Minimum distance used when fitting `a` and `b` from `min_dist`/`spread`.
pub min_dist: T,
/// Decay rate of the first-moment estimate (Adam optimisers only).
pub beta1: T,
/// Decay rate of the second-moment estimate (Adam optimisers only).
pub beta2: T,
/// Stabiliser added to the denominator of the Adam update.
pub eps: T,
}
impl<T> UmapOptimParams<T>
where
    T: ManifoldsFloat,
{
    /// Returns the stock parameter set: fixed curve coefficients
    /// (`a = 1.5`, `b = 0.9`), unit learning rate and repulsion weight,
    /// 500 epochs, 5 negative samples per positive sample, `min_dist = 0.1`
    /// and the crate-level Adam constants.
    pub fn default_2d() -> Self {
        let lit = |value: f64| T::from_f64(value).unwrap();
        Self {
            a: lit(1.5),
            b: lit(0.9),
            lr: T::one(),
            gamma: T::one(),
            n_epochs: 500,
            neg_sample_rate: 5,
            min_dist: lit(0.1),
            beta1: T::from(UMAP_BETA1).unwrap(),
            beta2: T::from(UMAP_BETA2).unwrap(),
            eps: T::from(EPS).unwrap(),
        }
    }

    /// Builds a parameter set whose curve coefficients `a` and `b` are fitted
    /// from `min_dist` and `spread`.
    ///
    /// Every `Option` argument falls back to the same value that
    /// [`UmapOptimParams::default_2d`] uses when it is `None`.
    #[allow(clippy::too_many_arguments)]
    pub fn from_min_dist_spread(
        min_dist: T,
        spread: T,
        lr: Option<T>,
        gamma: Option<T>,
        n_epochs: Option<usize>,
        neg_sample_rate: Option<usize>,
        beta1: Option<T>,
        beta2: Option<T>,
        eps: Option<T>,
    ) -> Self {
        let (a, b) = Self::fit_params(min_dist, spread, None);
        Self {
            a,
            b,
            lr: lr.unwrap_or_else(T::one),
            gamma: gamma.unwrap_or_else(T::one),
            n_epochs: n_epochs.unwrap_or(500),
            neg_sample_rate: neg_sample_rate.unwrap_or(5),
            min_dist,
            beta1: beta1.unwrap_or_else(|| T::from(UMAP_BETA1).unwrap()),
            beta2: beta2.unwrap_or_else(|| T::from(UMAP_BETA2).unwrap()),
            eps: eps.unwrap_or_else(|| T::from(EPS).unwrap()),
        }
    }

    /// Fits `(a, b)` so that `1 / (1 + a * x^(2b))` approximates the target
    /// curve that is `1` below `min_dist` and `exp(-(x - min_dist) / spread)`
    /// above it.
    ///
    /// The fit is a fixed-step gradient descent on the squared error over
    /// 300 evenly spaced points in `[0, 3 * spread]`, run for `n_iter`
    /// iterations (300 when `None`). After every step `a` is clamped to
    /// `[0.001, 10]` and `b` to `[0.1, 2]`.
    fn fit_params(min_dist: T, spread: T, n_iter: Option<usize>) -> (T, T) {
        const N_POINTS: usize = 300;
        let lit = |value: f64| T::from_f64(value).unwrap();
        let iterations = n_iter.unwrap_or(300);

        // Sample the target curve on an even grid.
        let step = spread * lit(3.0) / T::from_usize(N_POINTS - 1).unwrap();
        let samples: Vec<(T, T)> = (0..N_POINTS)
            .map(|i| {
                let x = step * T::from_usize(i).unwrap();
                let y = if x < min_dist {
                    T::one()
                } else {
                    (-(x - min_dist) / spread).exp()
                };
                (x, y)
            })
            .collect();

        // Loop-invariant constants.
        let two = lit(2.0);
        let n_points_t = T::from_usize(N_POINTS).unwrap();
        let (lr_a, lr_b) = (lit(1.0), lit(1.0));
        let (a_min, a_max) = (lit(0.001), lit(10.0));
        let (b_min, b_max) = (lit(0.1), lit(2.0));

        let mut a = T::one();
        let mut b = T::one();
        for _ in 0..iterations {
            let mut grad_a = T::zero();
            let mut grad_b = T::zero();
            for &(x, y_target) in &samples {
                // ln(x) is undefined at the origin, so that sample is skipped.
                if x <= T::zero() {
                    continue;
                }
                let x_2b = x.powf(two * b);
                let denom = T::one() + a * x_2b;
                let err = T::one() / denom - y_target;
                grad_a += err * (-x_2b / (denom * denom));
                grad_b += err * (-two * a * x_2b * x.ln() / (denom * denom));
            }
            // The mean is taken over all grid points, including the skipped one.
            a = (a - lr_a * (grad_a / n_points_t)).max(a_min).min(a_max);
            b = (b - lr_b * (grad_b / n_points_t)).max(b_min).min(b_max);
        }
        (a, b)
    }
}
impl<T: ManifoldsFloat> Default for UmapOptimParams<T> {
    /// Equivalent to [`UmapOptimParams::default_2d`].
    fn default() -> Self {
        Self::default_2d()
    }
}
/// Optimisation strategy used to refine a UMAP embedding.
///
/// The default is [`UmapOptimiser::AdamParallel`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum UmapOptimiser {
    /// Adam with per-node gradient accumulation, parallelised over nodes.
    #[default]
    AdamParallel,
    /// Sequential Adam applied edge by edge.
    Adam,
    /// Sequential stochastic gradient descent.
    Sgd,
}

/// Parses an optimiser name, ignoring ASCII case.
///
/// Recognised names are `"adam"`, `"sgd"` and `"adam_parallel"`.
///
/// # Returns
///
/// `Some(optimiser)` for a recognised name, `None` otherwise.
pub fn parse_umap_optimiser(s: &str) -> Option<UmapOptimiser> {
    const NAMES: [(&str, UmapOptimiser); 3] = [
        ("adam", UmapOptimiser::Adam),
        ("sgd", UmapOptimiser::Sgd),
        ("adam_parallel", UmapOptimiser::AdamParallel),
    ];
    // Comparing in place avoids allocating a lowercase copy of the input.
    NAMES
        .iter()
        .find(|(name, _)| s.eq_ignore_ascii_case(name))
        .map(|&(_, optimiser)| optimiser)
}
/// Values derived once from the curve parameters and reused in every
/// gradient evaluation of the optimisers.
struct OptimConstants<T> {
/// Curve coefficient `a`.
a: T,
/// Curve coefficient `b`.
b: T,
/// Pre-computed `2 * a * b`, the numerator factor of the attractive gradient.
two_a_b: T,
/// Pre-computed `2 * gamma * b`, the numerator of the repulsive gradient.
two_gamma_b: T,
/// Magnitude bound used when clipping gradients.
clip_val: T,
/// Offset added to squared distances of negative samples to avoid a zero
/// denominator.
eps: T,
}
impl<T> OptimConstants<T>
where
    T: ManifoldsFloat,
{
    /// Derives the gradient constants from the curve coefficients `a`, `b`
    /// and the repulsion weight `gamma`. The clip value is fixed at `4` and
    /// the distance offset at `0.001`.
    fn new(a: T, b: T, gamma: T) -> Self {
        let lit = |value: f64| T::from_f64(value).unwrap();
        let two = lit(2.0);
        let two_a_b = two * a * b;
        let two_gamma_b = two * gamma * b;
        Self {
            a,
            b,
            two_a_b,
            two_gamma_b,
            clip_val: lit(4.0),
            eps: lit(0.001),
        }
    }
}
/// Lookup table approximating `x^b` on `[0, max_val)` by linear
/// interpolation between evenly spaced samples.
struct FastPowLut<T> {
/// Exponent the table was built for; used for the exact fallback.
b: T,
/// Upper end of the tabulated range.
max_val: T,
/// Reciprocal of the sample spacing, used to turn `x` into a table position.
inv_step: T,
/// Samples of `x^b` at `x = i * step`.
table: Vec<T>,
}
impl<T> FastPowLut<T>
where
    T: ManifoldsFloat,
{
    /// Tabulates `x^b` at `size` evenly spaced points covering
    /// `[0, max_val]`.
    fn new(b: T, max_val: f64, size: usize) -> Self {
        let max_val = T::from(max_val).unwrap();
        let step = max_val / T::from(size - 1).unwrap();
        let table: Vec<T> = (0..size)
            .map(|i| (step * T::from(i).unwrap()).powf(b))
            .collect();
        Self {
            b,
            max_val,
            inv_step: T::one() / step,
            table,
        }
    }

    /// Returns an approximation of `x^b`.
    ///
    /// Values at or above `max_val` are computed exactly with `powf`;
    /// everything else is linearly interpolated between the two neighbouring
    /// table entries.
    #[inline(always)]
    fn get(&self, x: T) -> T {
        if x >= self.max_val {
            return x.powf(self.b);
        }
        let position = x * self.inv_step;
        // Positions that cannot be converted to an index fall back to slot 0.
        let lower_idx = position.to_usize().unwrap_or(0);
        match self.table.get(lower_idx + 1) {
            Some(&upper) => {
                let lower = self.table[lower_idx];
                let fraction = position - T::from(lower_idx).unwrap();
                lower + fraction * (upper - lower)
            }
            // No right-hand neighbour: clamp to the final sample.
            None => self.table.last().copied().unwrap(),
        }
    }
}
/// Computes `x^b`, short-circuiting the two exponents flagged by the caller:
/// `b == 1` returns `x` unchanged and `b == 0.5` uses `sqrt`. `b_is_one`
/// takes precedence when both flags are set.
#[inline(always)]
fn fast_pow<T: ManifoldsFloat>(x: T, b: T, b_is_one: bool, b_is_half: bool) -> T {
    match (b_is_one, b_is_half) {
        (true, _) => x,
        (false, true) => x.sqrt(),
        (false, false) => x.powf(b),
    }
}
/// Optimises an embedding in place with sequential stochastic gradient
/// descent over the edges of `graph`.
///
/// Each edge is visited with a period inversely proportional to its weight
/// relative to the largest weight. A visit pulls both endpoints together and
/// then pushes the source point away from randomly drawn points (negative
/// samples). The learning rate decays linearly from `params.lr` towards zero.
///
/// # Arguments
///
/// * `embd` - One coordinate vector per point; overwritten with the result.
///   Every row is expected to have the length of the first row.
/// * `graph` - Adjacency list; `graph[i]` holds `(j, weight)` pairs.
/// * `params` - Curve coefficients and schedule settings.
/// * `seed` - Base seed; point `i` draws from a generator seeded with
///   `seed + i` (wrapping).
/// * `verbose` - Verbosity level handed to `parse_verbosity_level`.
///
/// # Errors
///
/// * `ManifoldsError::NoData` if `embd` is empty.
/// * `ManifoldsError::NoGraphEdges` if `graph` contains no edges.
///
/// # Panics
///
/// Panics if `graph` references a point index outside `embd` or if the rows
/// of `embd` differ in length.
pub fn optimise_embedding_sgd<T>(
    embd: &mut [Vec<T>],
    graph: &[Vec<(usize, T)>],
    params: &UmapOptimParams<T>,
    seed: u64,
    verbose: usize,
) -> Result<(), ManifoldsError>
where
    T: ManifoldsFloat,
{
    let n = embd.len();
    if n == 0 {
        return Err(ManifoldsError::NoData);
    }
    let n_dim = embd[0].len();
    let verbosity = parse_verbosity_level(verbose);

    // Work on a flat row-major copy; it is written back at the end.
    let mut embd_flat: Vec<T> = Vec::with_capacity(n * n_dim);
    for point in embd.iter() {
        embd_flat.extend_from_slice(point);
    }

    let consts = OptimConstants::new(params.a, params.b, params.gamma);
    let zero = T::zero();
    let one = T::one();
    let half = T::from(0.5).unwrap();
    let dist_sq_threshold = T::from(1e-8).unwrap();
    let large_epoch = T::from(1e8).unwrap();
    let b_tolerance = T::from(1e-10).unwrap();
    let b_is_one = (consts.b - one).abs() < b_tolerance;
    let b_is_half = (consts.b - half).abs() < b_tolerance;

    let edges: Vec<(usize, usize, T)> = graph
        .iter()
        .enumerate()
        .flat_map(|(i, neighbours)| neighbours.iter().map(move |&(j, w)| (i, j, w)))
        .collect();
    if edges.is_empty() {
        return Err(ManifoldsError::NoGraphEdges);
    }

    // Sampling period per edge: the heaviest edge is visited every epoch,
    // lighter edges proportionally less often.
    let max_weight = edges
        .iter()
        .map(|(_, _, w)| *w)
        .fold(zero, |acc, w| if w > acc { w } else { acc });
    let epochs_per_sample: Vec<T> = edges
        .iter()
        .map(|(_, _, w)| {
            let norm = *w / max_weight;
            if norm > zero {
                one / norm
            } else {
                large_epoch
            }
        })
        .collect();
    let mut epoch_of_next_sample: Vec<T> = epochs_per_sample.clone();

    let neg_sample_rate_t = T::from(params.neg_sample_rate).unwrap();
    let epochs_per_neg_sample: Vec<T> = epochs_per_sample
        .iter()
        .map(|eps| *eps / neg_sample_rate_t)
        .collect();
    let mut epoch_of_next_neg_sample: Vec<T> = epochs_per_neg_sample.clone();

    let n_epochs_f = T::from(params.n_epochs).unwrap();
    let lr_schedule: Vec<T> = (0..params.n_epochs)
        .map(|e| params.lr * (one - T::from(e).unwrap() / n_epochs_f))
        .collect();

    // `wrapping_add` keeps large seeds from overflowing in debug builds.
    let mut rng_states: Vec<SmallRng> = (0..n)
        .map(|i| SmallRng::seed_from_u64(seed.wrapping_add(i as u64)))
        .collect();

    for epoch in 0..params.n_epochs {
        let lr = lr_schedule[epoch];
        let epoch_t = T::from(epoch).unwrap();

        for (edge_idx, &(i, j, _weight)) in edges.iter().enumerate() {
            if epoch_of_next_sample[edge_idx] > epoch_t {
                continue;
            }
            let base_i = i * n_dim;
            let base_j = j * n_dim;

            // Attractive step: move both endpoints towards each other.
            let mut dist_sq = zero;
            for d in 0..n_dim {
                let diff = embd_flat[base_i + d] - embd_flat[base_j + d];
                dist_sq += diff * diff;
            }
            if dist_sq >= dist_sq_threshold {
                let dist_sq_b = fast_pow(dist_sq, consts.b, b_is_one, b_is_half);
                let denom = one + consts.a * dist_sq_b;
                let grad_coeff = consts.two_a_b * dist_sq_b / (dist_sq * denom);
                for d in 0..n_dim {
                    let delta = embd_flat[base_j + d] - embd_flat[base_i + d];
                    let grad_d = (grad_coeff * delta)
                        .max(-consts.clip_val)
                        .min(consts.clip_val);
                    embd_flat[base_i + d] += grad_d * lr;
                    embd_flat[base_j + d] -= grad_d * lr;
                }
            }
            epoch_of_next_sample[edge_idx] += epochs_per_sample[edge_idx];

            // Repulsive steps: catch up on the negative samples that are due.
            let n_neg_samples = ((epoch_t - epoch_of_next_neg_sample[edge_idx])
                / epochs_per_neg_sample[edge_idx])
                .floor()
                .to_usize()
                .unwrap_or(0);
            for _ in 0..n_neg_samples {
                let k = rng_states[i].random_range(0..n);
                if k == i {
                    continue;
                }
                let base_k = k * n_dim;
                let mut dist_sq = zero;
                for d in 0..n_dim {
                    let diff = embd_flat[base_i + d] - embd_flat[base_k + d];
                    dist_sq += diff * diff;
                }
                // The offset keeps the denominator away from zero.
                let dist_sq_safe = dist_sq + consts.eps;
                let dist_sq_b = fast_pow(dist_sq_safe, consts.b, b_is_one, b_is_half);
                let denom = dist_sq_safe * (one + consts.a * dist_sq_b);
                let grad_coeff = (consts.two_gamma_b / denom)
                    .max(-consts.clip_val)
                    .min(consts.clip_val);
                for d in 0..n_dim {
                    let delta = embd_flat[base_i + d] - embd_flat[base_k + d];
                    let grad_d = grad_coeff * delta;
                    embd_flat[base_i + d] += grad_d * lr;
                }
            }
            epoch_of_next_neg_sample[edge_idx] +=
                T::from(n_neg_samples).unwrap() * epochs_per_neg_sample[edge_idx];
        }

        if verbosity.normal_verbosity() && ((epoch + 1) % 50 == 0 || epoch + 1 == params.n_epochs)
        {
            println!(" Completed epoch {}/{}", epoch + 1, params.n_epochs);
        }
    }

    for (i, point) in embd.iter_mut().enumerate() {
        let base = i * n_dim;
        point.copy_from_slice(&embd_flat[base..base + n_dim]);
    }
    Ok(())
}
/// Optimises an embedding in place with a sequential, edge-by-edge Adam
/// update.
///
/// Edges are sampled with a period inversely proportional to their weight
/// relative to the largest weight. Every visit applies an Adam step to both
/// endpoints (attraction) and then to the source point for each negative
/// sample that is due (repulsion). The step size decays linearly from
/// `params.lr` towards zero, and the bias correction is refreshed once per
/// epoch.
///
/// # Arguments
///
/// * `embd` - One coordinate vector per point; overwritten with the result.
///   Every row is expected to have the length of the first row.
/// * `graph` - Adjacency list; `graph[i]` holds `(j, weight)` pairs.
/// * `params` - Curve coefficients, schedule settings and Adam constants.
/// * `seed` - Base seed; point `i` draws from a generator seeded with
///   `seed + i` (wrapping).
/// * `verbose` - Verbosity level handed to `parse_verbosity_level`.
///
/// # Errors
///
/// * `ManifoldsError::NoData` if `embd` is empty.
/// * `ManifoldsError::NoGraphEdges` if `graph` contains no edges.
///
/// # Panics
///
/// Panics if `graph` references a point index outside `embd` or if the rows
/// of `embd` differ in length.
pub fn optimise_embedding_adam<T>(
    embd: &mut [Vec<T>],
    graph: &[Vec<(usize, T)>],
    params: &UmapOptimParams<T>,
    seed: u64,
    verbose: usize,
) -> Result<(), ManifoldsError>
where
    T: ManifoldsFloat,
{
    let n = embd.len();
    if n == 0 {
        return Err(ManifoldsError::NoData);
    }
    let n_dim = embd[0].len();
    let verbosity = parse_verbosity_level(verbose);

    // Work on a flat row-major copy; it is written back at the end.
    let mut embd_flat: Vec<T> = Vec::with_capacity(n * n_dim);
    for point in embd.iter() {
        embd_flat.extend_from_slice(point);
    }

    let consts = OptimConstants::new(params.a, params.b, params.gamma);
    let zero = T::zero();
    let one = T::one();
    let half = T::from(0.5).unwrap();
    let dist_sq_threshold = T::from(1e-8).unwrap();
    let large_epoch = T::from(1e8).unwrap();
    let b_tolerance = T::from(1e-10).unwrap();
    let b_is_one = (consts.b - one).abs() < b_tolerance;
    let b_is_half = (consts.b - half).abs() < b_tolerance;

    let edges: Vec<(usize, usize, T)> = graph
        .iter()
        .enumerate()
        .flat_map(|(i, neighbours)| neighbours.iter().map(move |&(j, w)| (i, j, w)))
        .collect();
    if edges.is_empty() {
        return Err(ManifoldsError::NoGraphEdges);
    }

    // Sampling period per edge: the heaviest edge is visited every epoch,
    // lighter edges proportionally less often.
    let max_weight = edges
        .iter()
        .map(|(_, _, w)| *w)
        .fold(zero, |acc, w| if w > acc { w } else { acc });
    let epochs_per_sample: Vec<T> = edges
        .iter()
        .map(|(_, _, w)| {
            let norm = *w / max_weight;
            if norm > zero {
                one / norm
            } else {
                large_epoch
            }
        })
        .collect();
    let mut epoch_of_next_sample: Vec<T> = epochs_per_sample.clone();

    let neg_sample_rate_t = T::from(params.neg_sample_rate).unwrap();
    let epochs_per_neg_sample: Vec<T> = epochs_per_sample
        .iter()
        .map(|eps| *eps / neg_sample_rate_t)
        .collect();
    let mut epoch_of_next_neg_sample: Vec<T> = epochs_per_neg_sample.clone();

    let n_epochs_f = T::from(params.n_epochs).unwrap();

    // First and second moment estimates, one slot per coordinate.
    let mut m: Vec<T> = vec![zero; n * n_dim];
    let mut v: Vec<T> = vec![zero; n * n_dim];

    // `wrapping_add` keeps large seeds from overflowing in debug builds.
    let mut rng_states: Vec<SmallRng> = (0..n)
        .map(|i| SmallRng::seed_from_u64(seed.wrapping_add(i as u64)))
        .collect();

    let one_minus_beta1 = one - params.beta1;
    let one_minus_beta2 = one - params.beta2;
    // Running powers beta^t, with t = epoch + 1.
    let mut beta1t = params.beta1;
    let mut beta2t = params.beta2;

    for epoch in 0..params.n_epochs {
        let epoch_t = T::from(epoch).unwrap();
        let alpha = params.lr * (one - epoch_t / n_epochs_f);
        // Bias-corrected step size and epsilon for this epoch.
        let sqrt_b2t1 = (one - beta2t).sqrt();
        let ad_scale = alpha * sqrt_b2t1 / (one - beta1t);
        let epsc = sqrt_b2t1 * params.eps;

        for (edge_idx, &(i, j, _weight)) in edges.iter().enumerate() {
            if epoch_of_next_sample[edge_idx] > epoch_t {
                continue;
            }
            let base_i = i * n_dim;
            let base_j = j * n_dim;

            // Attractive step on both endpoints.
            let mut dist_sq = zero;
            for d in 0..n_dim {
                let diff = embd_flat[base_i + d] - embd_flat[base_j + d];
                dist_sq += diff * diff;
            }
            if dist_sq >= dist_sq_threshold {
                let dist_sq_b = fast_pow(dist_sq, consts.b, b_is_one, b_is_half);
                let denom = one + consts.a * dist_sq_b;
                let grad_coeff = consts.two_a_b * dist_sq_b / (dist_sq * denom);
                for d in 0..n_dim {
                    let delta = embd_flat[base_j + d] - embd_flat[base_i + d];
                    let grad = grad_coeff * delta;

                    let idx_i = base_i + d;
                    let v_old = v[idx_i];
                    let m_old = m[idx_i];
                    v[idx_i] = v_old + one_minus_beta2 * (grad * grad - v_old);
                    m[idx_i] = m_old + one_minus_beta1 * (grad - m_old);
                    embd_flat[idx_i] += ad_scale * m[idx_i] / (v[idx_i].sqrt() + epsc);

                    // The other endpoint receives the opposite gradient.
                    let idx_j = base_j + d;
                    let v_old = v[idx_j];
                    let m_old = m[idx_j];
                    v[idx_j] = v_old + one_minus_beta2 * (grad * grad - v_old);
                    m[idx_j] = m_old + one_minus_beta1 * (-grad - m_old);
                    embd_flat[idx_j] += ad_scale * m[idx_j] / (v[idx_j].sqrt() + epsc);
                }
            }
            epoch_of_next_sample[edge_idx] += epochs_per_sample[edge_idx];

            // Repulsive steps: catch up on the negative samples that are due.
            let n_neg_samples = ((epoch_t - epoch_of_next_neg_sample[edge_idx])
                / epochs_per_neg_sample[edge_idx])
                .floor()
                .to_usize()
                .unwrap_or(0);
            for _ in 0..n_neg_samples {
                let k = rng_states[i].random_range(0..n);
                if k == i {
                    continue;
                }
                let base_k = k * n_dim;
                let mut dist_sq = zero;
                for d in 0..n_dim {
                    let diff = embd_flat[base_i + d] - embd_flat[base_k + d];
                    dist_sq += diff * diff;
                }
                // The offset keeps the denominator away from zero.
                let dist_sq_safe = dist_sq + consts.eps;
                let dist_sq_b = fast_pow(dist_sq_safe, consts.b, b_is_one, b_is_half);
                let denom = dist_sq_safe * (one + consts.a * dist_sq_b);
                let grad_coeff = (consts.two_gamma_b / denom)
                    .max(-consts.clip_val)
                    .min(consts.clip_val);
                for d in 0..n_dim {
                    let delta = embd_flat[base_i + d] - embd_flat[base_k + d];
                    let grad = grad_coeff * delta;
                    let idx = base_i + d;
                    let v_old = v[idx];
                    let m_old = m[idx];
                    v[idx] = v_old + one_minus_beta2 * (grad * grad - v_old);
                    m[idx] = m_old + one_minus_beta1 * (grad - m_old);
                    embd_flat[idx] += ad_scale * m[idx] / (v[idx].sqrt() + epsc);
                }
            }
            epoch_of_next_neg_sample[edge_idx] +=
                T::from(n_neg_samples).unwrap() * epochs_per_neg_sample[edge_idx];
        }

        beta1t *= params.beta1;
        beta2t *= params.beta2;

        if verbosity.normal_verbosity() && ((epoch + 1) % 50 == 0 || epoch + 1 == params.n_epochs)
        {
            println!(" Completed epoch {}/{}", epoch + 1, params.n_epochs);
        }
    }

    for (i, point) in embd.iter_mut().enumerate() {
        let base = i * n_dim;
        point.copy_from_slice(&embd_flat[base..base + n_dim]);
    }
    Ok(())
}
/// Optimises an embedding in place with Adam, parallelised over points.
///
/// Each epoch runs three parallel phases:
///
/// 1. every point accumulates the gradient of all its due edges (and, for
///    edges where it is the smaller index, of the due negative samples)
///    against the positions from the start of the epoch;
/// 2. every point that had a due edge applies one Adam step;
/// 3. the sampling schedule of every due edge is advanced.
///
/// Only pairs `(i, j)` with `i < j` are taken from `graph`, so each
/// undirected edge is expected to be listed from its smaller endpoint.
/// Powers `x^b` are read from an interpolated lookup table unless `b` is 1.
///
/// # Arguments
///
/// * `embd` - One coordinate vector per point; overwritten with the result.
///   Every row is expected to have the length of the first row.
/// * `graph` - Adjacency list; `graph[i]` holds `(j, weight)` pairs.
/// * `params` - Curve coefficients, schedule settings and Adam constants.
/// * `seed` - Base seed; point `i` draws from a generator seeded with
///   `seed + i` (wrapping).
/// * `verbose` - Verbosity level handed to `parse_verbosity_level`.
///
/// # Errors
///
/// * `ManifoldsError::NoData` if `embd` is empty or its points have zero
///   dimensions.
/// * `ManifoldsError::NoGraphEdges` if `graph` contains no pair with `i < j`.
///
/// # Panics
///
/// Panics if `graph` references a point index outside `embd` or if the rows
/// of `embd` differ in length.
pub fn optimise_embedding_adam_parallel<T>(
    embd: &mut [Vec<T>],
    graph: &[Vec<(usize, T)>],
    params: &UmapOptimParams<T>,
    seed: u64,
    verbose: usize,
) -> Result<(), ManifoldsError>
where
    T: ManifoldsFloat,
{
    let n = embd.len();
    if n == 0 {
        return Err(ManifoldsError::NoData);
    }
    let n_dim = embd[0].len();
    if n_dim == 0 {
        // The chunked parallel iterators below reject a chunk size of zero.
        return Err(ManifoldsError::NoData);
    }
    let verbosity = parse_verbosity_level(verbose);

    // Work on a flat row-major copy; it is written back at the end.
    let mut embd_flat: Vec<T> = Vec::with_capacity(n * n_dim);
    for point in embd.iter() {
        embd_flat.extend_from_slice(point);
    }

    let consts = OptimConstants::new(params.a, params.b, params.gamma);
    let zero = T::zero();
    let one = T::one();
    let two = T::from(2.0).unwrap();
    let dist_sq_threshold = T::from(1e-8).unwrap();
    let large_epoch = T::from(1e8).unwrap();
    let b_is_one = (consts.b - one).abs() < T::from(1e-10).unwrap();
    let lut = FastPowLut::new(consts.b, 25.0, 65_536);

    // Keep each undirected edge once and count the incident edges per point.
    let mut edges: Vec<(usize, usize, T)> = Vec::new();
    let mut degree = vec![0usize; n];
    for (i, neighbours) in graph.iter().enumerate() {
        for &(j, w) in neighbours {
            if i < j {
                edges.push((i, j, w));
                degree[i] += 1;
                degree[j] += 1;
            }
        }
    }
    if edges.is_empty() {
        return Err(ManifoldsError::NoGraphEdges);
    }

    // Sampling period per edge: the heaviest edge is visited every epoch,
    // lighter edges proportionally less often.
    let max_weight = edges
        .iter()
        .map(|(_, _, w)| *w)
        .fold(zero, |acc, w| if w > acc { w } else { acc });
    let epochs_per_sample: Vec<T> = edges
        .iter()
        .map(|(_, _, w)| {
            let norm = *w / max_weight;
            if norm > zero {
                one / norm
            } else {
                large_epoch
            }
        })
        .collect();
    let mut epoch_of_next_sample: Vec<T> = epochs_per_sample.clone();

    let neg_sample_rate_t = T::from(params.neg_sample_rate).unwrap();
    let epochs_per_neg_sample: Vec<T> = epochs_per_sample
        .iter()
        .map(|eps| *eps / neg_sample_rate_t)
        .collect();
    let mut epoch_of_next_neg_sample: Vec<T> = epochs_per_neg_sample.clone();

    let n_epochs_f = T::from(params.n_epochs).unwrap();
    let lr_schedule: Vec<T> = (0..params.n_epochs)
        .map(|e| params.lr * (one - T::from(e).unwrap() / n_epochs_f))
        .collect();

    // First and second moment estimates, one slot per coordinate.
    let mut m: Vec<T> = vec![zero; n * n_dim];
    let mut v: Vec<T> = vec![zero; n * n_dim];

    // Per-point edge lists stored contiguously: the entries of point `p` live
    // in `csr_edges[node_edge_offsets[p]..node_edge_offsets[p + 1]]` as
    // `(edge index, p is the smaller endpoint, other endpoint)`.
    let mut node_edge_offsets = vec![0usize; n + 1];
    for i in 0..n {
        node_edge_offsets[i + 1] = node_edge_offsets[i] + degree[i];
    }
    let mut csr_edges = vec![(0usize, false, 0usize); edges.len() * 2];
    let mut current_offset = node_edge_offsets.clone();
    for (edge_idx, &(i, j, _)) in edges.iter().enumerate() {
        csr_edges[current_offset[i]] = (edge_idx, true, j);
        current_offset[i] += 1;
        csr_edges[current_offset[j]] = (edge_idx, false, i);
        current_offset[j] += 1;
    }

    // Bias-correction factors for every epoch, with t = epoch + 1.
    let bias_corrections: Vec<(T, T)> = (0..params.n_epochs)
        .map(|epoch| {
            let t = T::from(epoch + 1).unwrap();
            let beta1t = params.beta1.powf(t);
            let beta2t = params.beta2.powf(t);
            let sqrt_b2t1 = (one - beta2t).sqrt();
            let ad_scale = sqrt_b2t1 / (one - beta1t);
            let epsc = sqrt_b2t1 * params.eps;
            (ad_scale, epsc)
        })
        .collect();
    let one_minus_beta1 = one - params.beta1;
    let one_minus_beta2 = one - params.beta2;

    let mut node_gradients_all: Vec<T> = vec![zero; n * n_dim];
    let mut node_has_update: Vec<bool> = vec![false; n];
    // `wrapping_add` keeps large seeds from overflowing in debug builds.
    let mut node_rngs: Vec<SmallRng> = (0..n)
        .map(|i| SmallRng::seed_from_u64(seed.wrapping_add(i as u64)))
        .collect();

    for epoch in 0..params.n_epochs {
        let lr = lr_schedule[epoch];
        let epoch_t = T::from(epoch).unwrap();
        let (ad_scale, epsc) = bias_corrections[epoch];
        node_has_update.fill(false);

        // Phase 1: accumulate per-point gradients from the current positions.
        node_gradients_all
            .par_chunks_exact_mut(n_dim)
            .zip(node_has_update.par_iter_mut())
            .zip(node_rngs.par_iter_mut())
            .enumerate()
            .for_each(|(node_i, ((node_grad, has_update), rng))| {
                node_grad.fill(zero);
                let mut local_has_updates = false;
                let base_i = node_i * n_dim;
                let start_idx = node_edge_offsets[node_i];
                let end_idx = node_edge_offsets[node_i + 1];
                let node_edges = &csr_edges[start_idx..end_idx];

                for &(edge_idx, is_smaller, other_node) in node_edges {
                    if epoch_of_next_sample[edge_idx] > epoch_t {
                        continue;
                    }
                    local_has_updates = true;

                    // Attraction towards the other endpoint.
                    let base_other = other_node * n_dim;
                    let mut dist_sq = zero;
                    for d in 0..n_dim {
                        let diff = embd_flat[base_i + d] - embd_flat[base_other + d];
                        dist_sq += diff * diff;
                    }
                    if dist_sq >= dist_sq_threshold {
                        let dist_sq_b = if b_is_one { dist_sq } else { lut.get(dist_sq) };
                        let denom = one + consts.a * dist_sq_b;
                        let grad_coeff = consts.two_a_b * dist_sq_b / (dist_sq * denom);
                        for d in 0..n_dim {
                            let delta = embd_flat[base_other + d] - embd_flat[base_i + d];
                            node_grad[d] += two * grad_coeff * delta;
                        }
                    }

                    // Negative samples are drawn once per edge, by the
                    // endpoint with the smaller index.
                    if is_smaller {
                        let n_neg_samples = ((epoch_t - epoch_of_next_neg_sample[edge_idx])
                            / epochs_per_neg_sample[edge_idx])
                            .floor()
                            .to_usize()
                            .unwrap_or(0);
                        for _ in 0..n_neg_samples {
                            let k = rng.random_range(0..n);
                            if k == node_i {
                                continue;
                            }
                            let base_k = k * n_dim;
                            let mut dist_sq = zero;
                            for d in 0..n_dim {
                                let diff = embd_flat[base_i + d] - embd_flat[base_k + d];
                                dist_sq += diff * diff;
                            }
                            // The offset keeps the denominator away from zero.
                            let dist_sq_safe = dist_sq + consts.eps;
                            let dist_sq_b = if b_is_one {
                                dist_sq_safe
                            } else {
                                lut.get(dist_sq_safe)
                            };
                            let denom = dist_sq_safe * (one + consts.a * dist_sq_b);
                            let grad_coeff = (consts.two_gamma_b / denom)
                                .max(-consts.clip_val)
                                .min(consts.clip_val);
                            for d in 0..n_dim {
                                let delta = embd_flat[base_i + d] - embd_flat[base_k + d];
                                node_grad[d] += grad_coeff * delta;
                            }
                        }
                    }
                }

                if local_has_updates {
                    *has_update = true;
                }
            });

        // Phase 2: one Adam step for every point that had a due edge.
        node_gradients_all
            .par_chunks_exact_mut(n_dim)
            .zip(m.par_chunks_exact_mut(n_dim))
            .zip(v.par_chunks_exact_mut(n_dim))
            .zip(embd_flat.par_chunks_exact_mut(n_dim))
            .zip(node_has_update.par_iter())
            .for_each(|((((grad, m_node), v_node), embd_node), &has_update)| {
                if !has_update {
                    return;
                }
                for d in 0..n_dim {
                    let g = grad[d];
                    let m_old = m_node[d];
                    m_node[d] += one_minus_beta1 * (g - m_old);
                    let v_old = v_node[d];
                    v_node[d] += one_minus_beta2 * (g * g - v_old);
                    embd_node[d] += lr * ad_scale * m_node[d] / (v_node[d].sqrt() + epsc);
                }
            });

        // Phase 3: advance the schedule of every edge that was due.
        epoch_of_next_sample
            .par_iter_mut()
            .zip(epoch_of_next_neg_sample.par_iter_mut())
            .zip(epochs_per_sample.par_iter())
            .zip(epochs_per_neg_sample.par_iter())
            .for_each(|(((next_sample, next_neg), &per_sample), &per_neg)| {
                if *next_sample <= epoch_t {
                    *next_sample += per_sample;
                    let n_neg_samples = ((epoch_t - *next_neg) / per_neg)
                        .floor()
                        .to_usize()
                        .unwrap_or(0);
                    *next_neg += T::from(n_neg_samples).unwrap() * per_neg;
                }
            });

        if verbosity.normal_verbosity() && ((epoch + 1) % 50 == 0 || epoch + 1 == params.n_epochs)
        {
            println!(" Completed epoch {}/{}", epoch + 1, params.n_epochs);
        }
    }

    embd.par_iter_mut().enumerate().for_each(|(i, point)| {
        let base = i * n_dim;
        point.copy_from_slice(&embd_flat[base..base + n_dim]);
    });
    Ok(())
}
#[cfg(test)]
mod test_umap_optimiser {
    use super::*;
    use approx::assert_relative_eq;
    use num_traits::Float;

    /// Common signature of the three optimiser entry points at `f64`.
    type Optimiser = fn(
        &mut [Vec<f64>],
        &[Vec<(usize, f64)>],
        &UmapOptimParams<f64>,
        u64,
        usize,
    ) -> Result<(), ManifoldsError>;

    /// Squared Euclidean distance between points `i` and `j` of a flat,
    /// row-major embedding with `n_dim` coordinates per point.
    #[inline(always)]
    fn squared_dist_flat<T>(embd: &[T], i: usize, j: usize, n_dim: usize) -> T
    where
        T: Float,
    {
        let mut sum = T::zero();
        let base_i = i * n_dim;
        let base_j = j * n_dim;
        for d in 0..n_dim {
            let diff = embd[base_i + d] - embd[base_j + d];
            sum = sum + diff * diff;
        }
        sum
    }

    /// Fully connected three-point graph with one weaker edge.
    fn triangle_graph() -> Vec<Vec<(usize, f64)>> {
        vec![
            vec![(1, 1.0), (2, 0.5)],
            vec![(0, 1.0), (2, 1.0)],
            vec![(0, 0.5), (1, 1.0)],
        ]
    }

    /// Two points joined by a single symmetric edge.
    fn pair_graph() -> Vec<Vec<(usize, f64)>> {
        vec![vec![(1, 1.0)], vec![(0, 1.0)]]
    }

    /// Two triangles joined by one weak edge between points 2 and 3.
    fn two_clique_graph() -> Vec<Vec<(usize, f64)>> {
        vec![
            vec![(1, 1.0), (2, 1.0)],
            vec![(0, 1.0), (2, 1.0)],
            vec![(0, 1.0), (1, 1.0), (3, 0.1)],
            vec![(2, 0.1), (4, 1.0), (5, 1.0)],
            vec![(3, 1.0), (5, 1.0)],
            vec![(3, 1.0), (4, 1.0)],
        ]
    }

    /// Explicit parameter set with `a = b = 1` for short deterministic runs.
    fn short_run_params(lr: f64, n_epochs: usize) -> UmapOptimParams<f64> {
        UmapOptimParams {
            a: 1.0,
            b: 1.0,
            lr,
            gamma: 1.0,
            n_epochs,
            neg_sample_rate: 2,
            min_dist: 0.1,
            beta1: 0.5,
            beta2: 0.9,
            eps: 1e-7,
        }
    }

    /// Sum of absolute coordinate changes between two embeddings.
    fn total_movement(new: &[Vec<f64>], old: &[Vec<f64>]) -> f64 {
        new.iter()
            .zip(old.iter())
            .flat_map(|(n, o)| n.iter().zip(o.iter()).map(|(&a, &b)| (a - b).abs()))
            .sum()
    }

    fn assert_all_finite(embd: &[Vec<f64>]) {
        for point in embd {
            for &coord in point {
                assert!(coord.is_finite());
            }
        }
    }

    /// Runs `optimise` on the triangle graph and checks that it succeeds,
    /// moves the points by more than `min_movement` and stays finite.
    fn assert_moves_points(
        optimise: Optimiser,
        start: &[Vec<f64>],
        params: &UmapOptimParams<f64>,
        min_movement: f64,
    ) {
        let graph = triangle_graph();
        let mut embd = start.to_vec();
        assert!(optimise(&mut embd, &graph, params, 42, 0).is_ok());
        assert!(total_movement(&embd, start) > min_movement);
        assert_all_finite(&embd);
    }

    /// Runs `optimise` twice with the same seed and checks for equal output.
    fn assert_reproducible(optimise: Optimiser) {
        let graph = pair_graph();
        let params = short_run_params(0.5, 10);
        let mut embd1 = vec![vec![0.0, 0.0], vec![1.0, 0.0]];
        let mut embd2 = embd1.clone();
        assert!(optimise(&mut embd1, &graph, &params, 42, 0).is_ok());
        assert!(optimise(&mut embd2, &graph, &params, 42, 0).is_ok());
        assert_eq!(embd1, embd2);
    }

    /// Runs `optimise` on the two-clique graph and checks that points of
    /// different cliques end up further apart than points of the same clique.
    fn assert_separates_cliques(optimise: Optimiser) {
        let graph = two_clique_graph();
        let mut embd = vec![
            vec![0.0, 0.0],
            vec![10.0, 0.0],
            vec![0.0, 10.0],
            vec![10.0, 10.0],
            vec![-5.0, -5.0],
            vec![15.0, 15.0],
        ];
        let params = UmapOptimParams {
            n_epochs: 200,
            ..UmapOptimParams::default_2d()
        };
        assert!(optimise(&mut embd, &graph, &params, 42, 0).is_ok());

        let dist = |a: &[f64], b: &[f64]| -> f64 {
            ((a[0] - b[0]).powi(2) + (a[1] - b[1]).powi(2)).sqrt()
        };
        let intra_clique1 =
            (dist(&embd[0], &embd[1]) + dist(&embd[0], &embd[2]) + dist(&embd[1], &embd[2])) / 3.0;
        let intra_clique2 =
            (dist(&embd[3], &embd[4]) + dist(&embd[3], &embd[5]) + dist(&embd[4], &embd[5])) / 3.0;
        let avg_intra = (intra_clique1 + intra_clique2) / 2.0;
        let inter_distances = [
            dist(&embd[0], &embd[3]),
            dist(&embd[0], &embd[4]),
            dist(&embd[0], &embd[5]),
            dist(&embd[1], &embd[3]),
            dist(&embd[1], &embd[4]),
            dist(&embd[1], &embd[5]),
        ];
        let avg_inter: f64 = inter_distances.iter().sum::<f64>() / inter_distances.len() as f64;
        assert!(
            avg_inter > avg_intra * 1.5,
            "Inter-clique dist ({:.2}) should be > 1.5x intra-clique dist ({:.2})",
            avg_inter,
            avg_intra
        );
    }

    #[test]
    fn test_optim_params_default_2d() {
        let params = UmapOptimParams::<f64>::default_2d();
        assert_relative_eq!(params.a, 1.5, epsilon = 1e-6);
        assert_relative_eq!(params.b, 0.9, epsilon = 1e-6);
        assert_eq!(params.lr, 1.0);
        assert_eq!(params.gamma, 1.0);
        assert_eq!(params.n_epochs, 500);
        assert_eq!(params.neg_sample_rate, 5);
        assert_relative_eq!(params.min_dist, 0.1, epsilon = 1e-6);
    }

    #[test]
    fn test_optim_params_from_min_dist_spread() {
        let params = UmapOptimParams::<f64>::from_min_dist_spread(
            0.1,
            1.0,
            Some(1.0),
            Some(1.0),
            Some(500),
            Some(5),
            None,
            None,
            None,
        );
        assert!(params.a > 0.0);
        assert!(params.b > 0.0);
        assert_eq!(params.lr, 1.0);
        assert_eq!(params.gamma, 1.0);
        assert_eq!(params.n_epochs, 500);
        assert_eq!(params.neg_sample_rate, 5);
        assert_relative_eq!(params.min_dist, 0.1, epsilon = 1e-6);
    }

    #[test]
    fn test_fit_params_constraints() {
        let (a, b) = UmapOptimParams::<f64>::fit_params(0.1, 1.0, None);
        assert!((0.001..=10.0).contains(&a));
        assert!((0.1..=2.0).contains(&b));
    }

    #[test]
    fn test_fit_params_curve_properties() {
        let min_dist: f64 = 0.1;
        let spread: f64 = 1.0;
        let (a, b) = UmapOptimParams::<f64>::fit_params(min_dist, spread, None);
        let pred_min = 1.0 / (1.0 + a * min_dist.powf(2.0 * b));
        assert!(
            pred_min > 0.9,
            "f(min_dist) = {:.3} should be > 0.9",
            pred_min
        );
        let pred_spread = 1.0 / (1.0 + a * (3.0 * spread).powf(2.0 * b));
        assert!(
            pred_spread < 0.1,
            "f(3*spread) = {:.3} should be < 0.1",
            pred_spread
        );
        let mid_point = 1.5 * spread;
        let pred_mid = 1.0 / (1.0 + a * mid_point.powf(2.0 * b));
        assert!(pred_min > pred_mid && pred_mid > pred_spread);
    }

    #[test]
    fn test_squared_dist_basic() {
        let embd = vec![0.0, 0.0, 3.0, 4.0];
        let dist = squared_dist_flat(&embd, 0, 1, 2);
        assert_relative_eq!(dist, 25.0, epsilon = 1e-6);
    }

    #[test]
    fn test_squared_dist_identical_points() {
        let embd = vec![1.0, 2.0, 3.0, 1.0, 2.0, 3.0];
        let dist = squared_dist_flat(&embd, 0, 1, 3);
        assert_relative_eq!(dist, 0.0, epsilon = 1e-6);
    }

    #[test]
    fn test_optimise_embedding_adam_basic() {
        let start = vec![vec![0.0, 0.0], vec![5.0, 0.0], vec![0.0, 5.0]];
        let params = UmapOptimParams::default_2d();
        assert_moves_points(optimise_embedding_adam::<f64>, &start, &params, 0.01);
    }

    #[test]
    fn test_optimise_embedding_adam_parallel_basic() {
        let start = vec![vec![0.0, 0.0], vec![5.0, 0.0], vec![0.0, 5.0]];
        let params = UmapOptimParams::default_2d();
        assert_moves_points(
            optimise_embedding_adam_parallel::<f64>,
            &start,
            &params,
            0.01,
        );
    }

    #[test]
    fn test_optimise_embedding_empty_graph() {
        let graph: Vec<Vec<(usize, f64)>> = vec![vec![], vec![], vec![]];
        let mut embd = vec![vec![0.0, 0.0], vec![1.0, 0.0], vec![0.0, 1.0]];
        let original = embd.clone();
        let params = UmapOptimParams::default_2d();
        let result = optimise_embedding_adam(&mut embd, &graph, &params, 42, 0);
        // A graph without edges is rejected and the embedding left untouched.
        assert!(result.is_err());
        assert_eq!(embd, original);
        assert_all_finite(&embd);
    }

    #[test]
    fn test_optimise_embedding_adam_reproducibility() {
        assert_reproducible(optimise_embedding_adam::<f64>);
    }

    #[test]
    fn test_optimise_embedding_adam_parallel_reproducibility() {
        assert_reproducible(optimise_embedding_adam_parallel::<f64>);
    }

    #[test]
    fn test_optimise_embedding_convergence() {
        let graph = pair_graph();
        let mut embd = vec![vec![0.0, 0.0], vec![10.0, 0.0]];
        let embd_flat: Vec<f64> = embd.iter().flatten().copied().collect();
        let initial_dist = squared_dist_flat(&embd_flat, 0, 1, 2).sqrt();
        let params = short_run_params(1.0, 100);
        assert!(optimise_embedding_adam(&mut embd, &graph, &params, 42, 0).is_ok());
        let embd_flat: Vec<f64> = embd.iter().flatten().copied().collect();
        let final_dist = squared_dist_flat(&embd_flat, 0, 1, 2).sqrt();
        assert!(final_dist < initial_dist);
    }

    #[test]
    fn test_sgd_vs_adam_both_converge() {
        let start = vec![vec![0.0, 0.0], vec![10.0, 0.0], vec![0.0, 10.0]];
        let params = short_run_params(1.0, 50);
        assert_moves_points(optimise_embedding_sgd::<f64>, &start, &params, 1.0);
        assert_moves_points(optimise_embedding_adam::<f64>, &start, &params, 1.0);
    }

    #[test]
    fn test_sgd_adam_adam_parallel_all_converge() {
        let start = vec![vec![0.0, 0.0], vec![10.0, 0.0], vec![0.0, 10.0]];
        let params = short_run_params(1.0, 50);
        assert_moves_points(optimise_embedding_sgd::<f64>, &start, &params, 1.0);
        assert_moves_points(optimise_embedding_adam::<f64>, &start, &params, 1.0);
        assert_moves_points(
            optimise_embedding_adam_parallel::<f64>,
            &start,
            &params,
            1.0,
        );
    }

    #[test]
    fn test_sgd_reproducibility() {
        assert_reproducible(optimise_embedding_sgd::<f64>);
    }

    #[test]
    fn test_optimisation_preserves_graph_structure_adam() {
        assert_separates_cliques(optimise_embedding_adam::<f64>);
    }

    #[test]
    fn test_optimisation_preserves_graph_structure_adam_parallel() {
        assert_separates_cliques(optimise_embedding_adam_parallel::<f64>);
    }

    #[test]
    fn test_optimisation_preserves_graph_structure_sgd() {
        assert_separates_cliques(optimise_embedding_sgd::<f64>);
    }
}