use num_traits::{Float, FromPrimitive};
use rayon::prelude::*;
use std::ops::{AddAssign, SubAssign};
use crate::prelude::*;
use crate::training::*;
/// Default first epoch index that no longer belongs to phase 1
/// (fallback for `PacmapOptimParams::phase1_end`).
const PHASE1_END: usize = 100;
/// Default first epoch index that no longer belongs to phase 2
/// (fallback for `PacmapOptimParams::phase2_end`).
const PHASE2_END: usize = 200;
/// Default total number of epochs (fallback for `PacmapOptimParams::n_epochs`).
const PHASE3_END: usize = 450;
/// Near-pair weight while `epoch < phase1_end`.
const W_NB_PHASE1: f64 = 1.0;
/// Near-pair weight while `phase1_end <= epoch < phase2_end`.
const W_NB_PHASE2: f64 = 3.0;
/// Near-pair weight once `epoch >= phase2_end`.
const W_NB_PHASE3: f64 = 1.0;
/// Mid-near pair weight during phase 1; also the starting value of the
/// linear decay towards zero that `phase_weights` applies across phase 2.
const W_MN_PHASE1: f64 = 1000.0;
/// Further-pair weight, identical in all three phases.
const W_FP: f64 = 1.0;
/// Selects which PaCMAP optimiser implementation to run.
///
/// The names accepted on the command/config side are mapped to these variants
/// by `parse_pacmap_optimiser`.
///
/// NOTE(review): the dispatch site is not visible in this file; presumably
/// `AdamParallel` maps to `optimise_pacmap_parallel` and `Adam` to
/// `optimise_pacmap` — confirm against the caller.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum PacMapOptimiser {
    /// Adam optimiser, parallel variant (the default choice).
    #[default]
    AdamParallel,
    /// Adam optimiser, non-parallel variant.
    Adam,
}
/// Translates an optimiser name into a [`PacMapOptimiser`].
///
/// The comparison is case-insensitive. Recognised names are `"adam"` and
/// `"adam_parallel"`.
///
/// # Returns
///
/// `Some(variant)` for a recognised name, `None` for anything else.
pub fn parse_pacmap_optimiser(s: &str) -> Option<PacMapOptimiser> {
    let name = s.to_lowercase();
    if name == "adam" {
        return Some(PacMapOptimiser::Adam);
    }
    if name == "adam_parallel" {
        return Some(PacMapOptimiser::AdamParallel);
    }
    None
}
/// Hyper-parameters consumed by the PaCMAP optimisation loops in this file.
///
/// Build it with `PacmapOptimParams::new`, which substitutes a default for
/// every argument passed as `None`.
#[derive(Clone, Debug)]
pub struct PacmapOptimParams<T> {
/// Total number of epochs the optimisation loop runs.
pub n_epochs: usize,
/// Learning rate; scales every Adam step.
pub lr: T,
/// Decay factor of the first-moment (mean) estimate.
pub beta1: T,
/// Decay factor of the second-moment (squared gradient) estimate.
pub beta2: T,
/// Small constant added to the denominator of the Adam step.
pub eps: T,
/// Epochs with index below this value use the phase-1 pair weights.
pub phase1_end: usize,
/// Epochs from `phase1_end` up to (excluding) this value use the phase-2
/// weights; later epochs use the phase-3 weights.
pub phase2_end: usize,
}
impl<T> PacmapOptimParams<T>
where
    T: Float + FromPrimitive,
{
    /// Builds a parameter set, replacing every `None` with its default.
    ///
    /// # Params
    ///
    /// * `n_epochs` - Number of epochs; defaults to `PHASE3_END`.
    /// * `lr` - Learning rate; defaults to `0.01`.
    /// * `beta1` - First-moment decay; defaults to `BETA1`.
    /// * `beta2` - Second-moment decay; defaults to `BETA2`.
    /// * `eps` - Denominator stabiliser; defaults to `EPS`.
    /// * `phase1_end` - End of phase 1; defaults to `PHASE1_END`.
    /// * `phase2_end` - End of phase 2; defaults to `PHASE2_END`.
    ///
    /// # Panics
    ///
    /// Panics if `T::from_f64` cannot represent one of the floating-point
    /// defaults. The defaults are converted even when an explicit value is
    /// supplied.
    pub fn new(
        n_epochs: Option<usize>,
        lr: Option<T>,
        beta1: Option<T>,
        beta2: Option<T>,
        eps: Option<T>,
        phase1_end: Option<usize>,
        phase2_end: Option<usize>,
    ) -> Self {
        // Convert the floating-point fallbacks first, in field order.
        let fallback_lr = T::from_f64(0.01).unwrap();
        let fallback_beta1 = T::from_f64(BETA1).unwrap();
        let fallback_beta2 = T::from_f64(BETA2).unwrap();
        let fallback_eps = T::from_f64(EPS).unwrap();

        Self {
            n_epochs: match n_epochs {
                Some(count) => count,
                None => PHASE3_END,
            },
            lr: match lr {
                Some(value) => value,
                None => fallback_lr,
            },
            beta1: match beta1 {
                Some(value) => value,
                None => fallback_beta1,
            },
            beta2: match beta2 {
                Some(value) => value,
                None => fallback_beta2,
            },
            eps: match eps {
                Some(value) => value,
                None => fallback_eps,
            },
            phase1_end: match phase1_end {
                Some(end) => end,
                None => PHASE1_END,
            },
            phase2_end: match phase2_end {
                Some(end) => end,
                None => PHASE2_END,
            },
        }
    }
}
impl<T> Default for PacmapOptimParams<T>
where
    T: Float + FromPrimitive,
{
    /// Returns the parameter set obtained when no value is overridden,
    /// i.e. `new` called with `None` for every argument.
    fn default() -> Self {
        let unset_count: Option<usize> = None;
        let unset_value: Option<T> = None;
        Self::new(
            unset_count,
            unset_value,
            unset_value,
            unset_value,
            unset_value,
            unset_count,
            unset_count,
        )
    }
}
/// Pair weights for the given epoch.
///
/// # Params
///
/// * `epoch` - Zero-based epoch index.
/// * `phase1_end` - First epoch index outside phase 1.
/// * `phase2_end` - First epoch index outside phase 2.
///
/// # Returns
///
/// `(w_nb, w_mn, w_fp)`: the weights for near, mid-near and further pairs.
///
/// * Phase 1 (`epoch < phase1_end`): constant weights.
/// * Phase 2 (`phase1_end <= epoch < phase2_end`): the mid-near weight decays
///   linearly from `W_MN_PHASE1` towards zero.
/// * Phase 3 (otherwise): the mid-near weight is exactly zero.
#[inline]
fn phase_weights<T>(epoch: usize, phase1_end: usize, phase2_end: usize) -> (T, T, T)
where
    T: Float + FromPrimitive,
{
    let lift = |value: f64| T::from_f64(value).unwrap();

    if epoch < phase1_end {
        return (lift(W_NB_PHASE1), lift(W_MN_PHASE1), lift(W_FP));
    }
    if epoch >= phase2_end {
        return (lift(W_NB_PHASE3), T::zero(), lift(W_FP));
    }

    // Here phase1_end <= epoch < phase2_end, so neither subtraction can
    // underflow and the span is non-zero.
    let elapsed = (epoch - phase1_end) as f64;
    let span = (phase2_end - phase1_end) as f64;
    let progress = elapsed / span;
    let w_mn = lift(W_MN_PHASE1 * (1.0 - progress));
    (lift(W_NB_PHASE2), w_mn, lift(W_FP))
}
/// Scalar multiplier of the attractive update between two embedded points.
///
/// With `d = sqrt(dist_sq + 1e-10)` the result is `2 * c / ((d + c)^2 * d)`.
/// For non-negative `dist_sq` the `1e-10` offset keeps `d` strictly positive,
/// so coincident points do not cause a division by zero.
///
/// # Params
///
/// * `dist_sq` - Squared distance between the two points.
/// * `c` - Constant of the attractive term.
#[inline(always)]
fn attract_grad_coeff<T>(dist_sq: T, c: T) -> T
where
    T: Float + FromPrimitive,
{
    let offset = T::from_f64(1e-10).unwrap();
    let dist = (dist_sq + offset).sqrt();
    let shifted = dist + c;
    let ratio = c / (shifted * shifted * dist);
    ratio * T::from_f64(2.0).unwrap()
}
/// Scalar multiplier of the repulsive update between two embedded points.
///
/// With `d = sqrt(dist_sq + 1e-10)` the result is `2 / ((1 + d)^2 * d)`.
/// For non-negative `dist_sq` the `1e-10` offset keeps `d` strictly positive,
/// so coincident points do not cause a division by zero.
///
/// # Params
///
/// * `dist_sq` - Squared distance between the two points.
#[inline(always)]
fn repel_grad_coeff<T>(dist_sq: T) -> T
where
    T: Float + FromPrimitive,
{
    let offset = T::from_f64(1e-10).unwrap();
    let dist = (dist_sq + offset).sqrt();
    let shifted = T::one() + dist;
    let inverse = T::one() / (shifted * shifted * dist);
    inverse * T::from_f64(2.0).unwrap()
}
/// Adds the update contribution of every pair in `pair_list` to `grads`.
///
/// For each pair `(i, j)` the squared Euclidean distance between the two
/// points is passed to `signed_coeff`; the returned scalar `s` is applied as
/// `grads[i] += s * (y_i - y_j)` and `grads[j] -= s * (y_i - y_j)`.
/// Because `grads` is later *added* to the embedding, a negative `s` pulls the
/// pair together and a positive `s` pushes it apart.
fn accumulate_pair_updates<T, F>(
    grads: &mut [T],
    embd_flat: &[T],
    pair_list: &[(usize, usize)],
    n_dim: usize,
    signed_coeff: F,
) where
    T: ManifoldsFloat,
    F: Fn(T) -> T,
{
    for &(i, j) in pair_list {
        let base_i = i * n_dim;
        let base_j = j * n_dim;

        let mut dist_sq = T::zero();
        for d in 0..n_dim {
            let diff = embd_flat[base_i + d] - embd_flat[base_j + d];
            dist_sq += diff * diff;
        }

        let coeff = signed_coeff(dist_sq);
        for d in 0..n_dim {
            let step = coeff * (embd_flat[base_i + d] - embd_flat[base_j + d]);
            grads[base_i + d] += step;
            grads[base_j + d] -= step;
        }
    }
}

/// Optimises a PaCMAP embedding in place with a single-threaded Adam loop.
///
/// Each epoch accumulates the attractive updates of the near pairs, the
/// attractive updates of the mid-near pairs (skipped once their weight is
/// zero) and the repulsive updates of the further pairs, then applies one
/// bias-corrected Adam step to every coordinate. Pair weights per epoch come
/// from `phase_weights`.
///
/// # Params
///
/// * `embd` - Embedding to optimise, one `Vec` per point; updated in place.
/// * `pairs` - Near, mid-near and further index pairs into `embd`.
/// * `params` - Epoch count, Adam hyper-parameters and phase boundaries.
/// * `verbose` - Verbosity level as understood by `parse_verbosity_level`;
///   at normal verbosity a progress line is printed every 50 epochs and after
///   the final epoch.
///
/// # Errors
///
/// Returns `ManifoldsError::NoData` when `embd` is empty.
///
/// # Panics
///
/// The dimensionality is taken from `embd[0]`. Rows of a different length or
/// pair indices that do not refer to a row of `embd` lead to an out-of-bounds
/// or length-mismatch panic.
pub fn optimise_pacmap<T>(
    embd: &mut [Vec<T>],
    pairs: &PacmapPairs,
    params: &PacmapOptimParams<T>,
    verbose: usize,
) -> Result<(), ManifoldsError>
where
    T: ManifoldsFloat,
{
    let n = embd.len();
    if n == 0 {
        return Err(ManifoldsError::NoData);
    }
    let verbosity = parse_verbosity_level(verbose);
    let n_dim = embd[0].len();
    let n_coords = n * n_dim;

    // Work on one contiguous buffer; results are copied back at the end.
    let mut embd_flat: Vec<T> = embd.iter().flatten().copied().collect();
    let mut m = vec![T::zero(); n_coords];
    let mut v = vec![T::zero(); n_coords];
    // Allocated once and cleared at the start of every epoch.
    let mut grads = vec![T::zero(); n_coords];

    let one_minus_b1 = T::one() - params.beta1;
    let one_minus_b2 = T::one() - params.beta2;
    // Running powers beta^t used for the Adam bias correction (t starts at 1).
    let mut beta1t = params.beta1;
    let mut beta2t = params.beta2;
    let c_near = T::from_f64(10.0).unwrap();
    let c_mn = T::from_f64(10_000.0).unwrap();

    for epoch in 0..params.n_epochs {
        let (w_nb, w_mn, w_fp) = phase_weights::<T>(epoch, params.phase1_end, params.phase2_end);

        // Bias correction folded into the step size and epsilon.
        let sqrt_b2t1 = (T::one() - beta2t).sqrt();
        let ad_scale = params.lr * sqrt_b2t1 / (T::one() - beta1t);
        let epsc = sqrt_b2t1 * params.eps;

        grads.fill(T::zero());

        // Attraction: negated coefficient moves both points towards each other.
        accumulate_pair_updates(&mut grads, &embd_flat, &pairs.near, n_dim, |dist_sq| {
            -(w_nb * attract_grad_coeff(dist_sq, c_near))
        });
        if w_mn > T::zero() {
            accumulate_pair_updates(&mut grads, &embd_flat, &pairs.mid_near, n_dim, |dist_sq| {
                -(w_mn * attract_grad_coeff(dist_sq, c_mn))
            });
        }
        // Repulsion: positive coefficient moves the points apart.
        accumulate_pair_updates(&mut grads, &embd_flat, &pairs.further, n_dim, |dist_sq| {
            w_fp * repel_grad_coeff(dist_sq)
        });

        // `grads` already holds the update direction, so the step is added.
        for idx in 0..n_coords {
            let g = grads[idx];
            let m_old = m[idx];
            let v_old = v[idx];
            m[idx] = m_old + one_minus_b1 * (g - m_old);
            v[idx] = v_old + one_minus_b2 * (g * g - v_old);
            embd_flat[idx] += ad_scale * m[idx] / (v[idx].sqrt() + epsc);
        }

        beta1t *= params.beta1;
        beta2t *= params.beta2;

        if verbosity.normal_verbosity() && ((epoch + 1) % 50 == 0 || epoch + 1 == params.n_epochs) {
            println!(" Completed epoch {}/{}", epoch + 1, params.n_epochs);
        }
    }

    for (i, point) in embd.iter_mut().enumerate() {
        let base = i * n_dim;
        point.copy_from_slice(&embd_flat[base..base + n_dim]);
    }
    Ok(())
}
/// Adds to `node_grad` the update contribution of all `neighbours` of one node.
///
/// `node_grad` is the node's slice of the update buffer; its length is the
/// embedding dimensionality. `base_i` is the node's offset into `embd_flat`.
/// For each neighbour `j` the squared Euclidean distance is passed to
/// `signed_coeff` and the returned scalar `s` is applied as
/// `node_grad += s * (y_i - y_j)`. A negative `s` attracts, a positive `s`
/// repels, because the buffer is later added to the embedding.
fn accumulate_node_updates<T, F>(
    node_grad: &mut [T],
    embd_flat: &[T],
    base_i: usize,
    neighbours: &[usize],
    signed_coeff: F,
) where
    T: Float + AddAssign,
    F: Fn(T) -> T,
{
    let n_dim = node_grad.len();
    for &j in neighbours {
        let base_j = j * n_dim;

        let mut dist_sq = T::zero();
        for d in 0..n_dim {
            let diff = embd_flat[base_i + d] - embd_flat[base_j + d];
            dist_sq += diff * diff;
        }

        let coeff = signed_coeff(dist_sq);
        for (d, slot) in node_grad.iter_mut().enumerate() {
            *slot += coeff * (embd_flat[base_i + d] - embd_flat[base_j + d]);
        }
    }
}

/// Optimises a PaCMAP embedding in place, computing per-node updates in
/// parallel with rayon.
///
/// The pair lists are first turned into per-node neighbour lists (each pair is
/// registered on both of its endpoints). Every epoch each node then sums the
/// attractive updates of its near neighbours, of its mid-near neighbours
/// (skipped once their weight is zero) and the repulsive updates of its
/// further neighbours, writing only to its own slice of the update buffer.
/// A bias-corrected Adam step is applied afterwards. Pair weights per epoch
/// come from `phase_weights`.
///
/// # Params
///
/// * `embd` - Embedding to optimise, one `Vec` per point; updated in place.
/// * `pairs` - Near, mid-near and further index pairs into `embd`.
/// * `params` - Epoch count, Adam hyper-parameters and phase boundaries.
/// * `verbose` - Verbosity level as understood by `parse_verbosity_level`;
///   at normal verbosity a progress line is printed every 50 epochs and after
///   the final epoch.
///
/// # Errors
///
/// Returns `ManifoldsError::NoData` when `embd` is empty.
///
/// # Panics
///
/// The dimensionality is taken from `embd[0]`. Rows of a different length or
/// pair indices that do not refer to a row of `embd` lead to an out-of-bounds
/// or length-mismatch panic.
pub fn optimise_pacmap_parallel<T>(
    embd: &mut [Vec<T>],
    pairs: &PacmapPairs,
    params: &PacmapOptimParams<T>,
    verbose: usize,
) -> Result<(), ManifoldsError>
where
    T: Float + FromPrimitive + AddAssign + Send + Sync + SubAssign,
{
    let n = embd.len();
    if n == 0 {
        return Err(ManifoldsError::NoData);
    }
    let verbosity = parse_verbosity_level(verbose);
    let n_dim = embd[0].len();
    let n_coords = n * n_dim;

    // Work on one contiguous buffer; results are copied back at the end.
    let mut embd_flat: Vec<T> = embd.iter().flatten().copied().collect();
    let mut m = vec![T::zero(); n_coords];
    let mut v = vec![T::zero(); n_coords];
    // One reusable update buffer instead of a fresh Vec per node per epoch.
    let mut grads = vec![T::zero(); n_coords];
    // `par_chunks_mut` rejects a chunk size of zero; with `n_dim == 0` the
    // buffer is empty, so a chunk size of 1 simply yields no chunks.
    let chunk_len = n_dim.max(1);

    let one_minus_b1 = T::one() - params.beta1;
    let one_minus_b2 = T::one() - params.beta2;
    // Running powers beta^t used for the Adam bias correction (t starts at 1).
    let mut beta1t = params.beta1;
    let mut beta2t = params.beta2;
    let c_near = T::from_f64(10.0).unwrap();
    let c_mn = T::from_f64(10_000.0).unwrap();

    // Per-node neighbour lists: every pair is visible from both endpoints so
    // that each node can compute its own update without touching another's.
    let mut node_near: Vec<Vec<usize>> = vec![vec![]; n];
    let mut node_mn: Vec<Vec<usize>> = vec![vec![]; n];
    let mut node_fp: Vec<Vec<usize>> = vec![vec![]; n];
    for &(i, j) in &pairs.near {
        node_near[i].push(j);
        node_near[j].push(i);
    }
    for &(i, j) in &pairs.mid_near {
        node_mn[i].push(j);
        node_mn[j].push(i);
    }
    for &(i, j) in &pairs.further {
        node_fp[i].push(j);
        node_fp[j].push(i);
    }

    for epoch in 0..params.n_epochs {
        let (w_nb, w_mn, w_fp) = phase_weights::<T>(epoch, params.phase1_end, params.phase2_end);

        // Bias correction folded into the step size and epsilon.
        let sqrt_b2t1 = (T::one() - beta2t).sqrt();
        let ad_scale = params.lr * sqrt_b2t1 / (T::one() - beta1t);
        let epsc = sqrt_b2t1 * params.eps;

        grads
            .par_chunks_mut(chunk_len)
            .enumerate()
            .for_each(|(i, node_grad)| {
                let base_i = i * n_dim;
                node_grad.fill(T::zero());

                // Attraction: negated coefficient moves the node towards its neighbour.
                accumulate_node_updates(node_grad, &embd_flat, base_i, &node_near[i], |dist_sq| {
                    -(w_nb * attract_grad_coeff(dist_sq, c_near))
                });
                if w_mn > T::zero() {
                    accumulate_node_updates(node_grad, &embd_flat, base_i, &node_mn[i], |dist_sq| {
                        -(w_mn * attract_grad_coeff(dist_sq, c_mn))
                    });
                }
                // Repulsion: positive coefficient moves the node away.
                accumulate_node_updates(node_grad, &embd_flat, base_i, &node_fp[i], |dist_sq| {
                    w_fp * repel_grad_coeff(dist_sq)
                });
            });

        // `grads` already holds the update direction, so the step is added.
        for idx in 0..n_coords {
            let g = grads[idx];
            let m_old = m[idx];
            let v_old = v[idx];
            m[idx] = m_old + one_minus_b1 * (g - m_old);
            v[idx] = v_old + one_minus_b2 * (g * g - v_old);
            embd_flat[idx] += ad_scale * m[idx] / (v[idx].sqrt() + epsc);
        }

        beta1t = beta1t * params.beta1;
        beta2t = beta2t * params.beta2;

        if verbosity.normal_verbosity() && ((epoch + 1) % 50 == 0 || epoch + 1 == params.n_epochs) {
            println!(" Completed epoch {}/{}", epoch + 1, params.n_epochs);
        }
    }

    for (i, point) in embd.iter_mut().enumerate() {
        let base = i * n_dim;
        point.copy_from_slice(&embd_flat[base..base + n_dim]);
    }
    Ok(())
}
#[cfg(test)]
mod test_pacmap_optimiser {
    use super::*;
    use crate::data::pacmap_pairs::PacmapPairs;

    /// Ring-shaped fixture: near pairs link `i` to `i + 1`, mid-near pairs
    /// link `i` to `i + 2`, further pairs link `i` to `i + n / 2` (all mod `n`).
    fn simple_pairs(n: usize) -> PacmapPairs {
        let near = (0..n).map(|i| (i, (i + 1) % n)).collect();
        let mid_near = (0..n).map(|i| (i, (i + 2) % n)).collect();
        let further = (0..n).map(|i| (i, (i + n / 2) % n)).collect();
        PacmapPairs {
            near,
            mid_near,
            further,
        }
    }

    /// Two-dimensional start layout: points spaced 2.0 apart on the x-axis.
    fn simple_embd(n: usize) -> Vec<Vec<f64>> {
        (0..n).map(|i| vec![i as f64 * 2.0, 0.0]).collect()
    }

    /// Default hyper-parameters, shortened to 50 epochs to keep tests fast.
    fn default_params() -> PacmapOptimParams<f64> {
        PacmapOptimParams::new(Some(50), None, None, None, None, None, None)
    }

    /// Sum of absolute per-coordinate differences between two embeddings.
    fn total_movement(before: &[Vec<f64>], after: &[Vec<f64>]) -> f64 {
        before
            .iter()
            .zip(after.iter())
            .flat_map(|(b, a)| b.iter().zip(a.iter()).map(|(&x, &y)| (x - y).abs()))
            .sum()
    }

    #[test]
    fn test_sequential_moves_points() {
        let pairs = simple_pairs(6);
        let mut embd = simple_embd(6);
        let initial = embd.clone();
        assert!(optimise_pacmap(&mut embd, &pairs, &default_params(), 0).is_ok());
        assert!(total_movement(&initial, &embd) > 0.01);
    }

    #[test]
    fn test_parallel_moves_points() {
        let pairs = simple_pairs(6);
        let mut embd = simple_embd(6);
        let initial = embd.clone();
        assert!(optimise_pacmap_parallel(&mut embd, &pairs, &default_params(), 0).is_ok());
        assert!(total_movement(&initial, &embd) > 0.01);
    }

    #[test]
    fn test_sequential_all_finite() {
        let pairs = simple_pairs(8);
        let mut embd = simple_embd(8);
        assert!(optimise_pacmap(&mut embd, &pairs, &default_params(), 0).is_ok());
        for point in &embd {
            for &coord in point {
                assert!(coord.is_finite(), "non-finite coordinate: {}", coord);
            }
        }
    }

    #[test]
    fn test_parallel_all_finite() {
        let pairs = simple_pairs(8);
        let mut embd = simple_embd(8);
        assert!(optimise_pacmap_parallel(&mut embd, &pairs, &default_params(), 0).is_ok());
        for point in &embd {
            for &coord in point {
                assert!(coord.is_finite(), "non-finite coordinate: {}", coord);
            }
        }
    }

    #[test]
    fn test_sequential_reproducible() {
        let pairs = simple_pairs(6);
        let params = default_params();
        let mut embd1 = simple_embd(6);
        let mut embd2 = simple_embd(6);
        assert!(optimise_pacmap(&mut embd1, &pairs, &params, 0).is_ok());
        assert!(optimise_pacmap(&mut embd2, &pairs, &params, 0).is_ok());
        assert_eq!(embd1, embd2);
    }

    #[test]
    fn test_parallel_reproducible() {
        let pairs = simple_pairs(6);
        let params = default_params();
        let mut embd1 = simple_embd(6);
        let mut embd2 = simple_embd(6);
        assert!(optimise_pacmap_parallel(&mut embd1, &pairs, &params, 0).is_ok());
        assert!(optimise_pacmap_parallel(&mut embd2, &pairs, &params, 0).is_ok());
        assert_eq!(embd1, embd2);
    }

    #[test]
    fn test_empty_embedding_does_not_panic() {
        let pairs = PacmapPairs {
            near: vec![],
            mid_near: vec![],
            further: vec![],
        };
        let mut embd: Vec<Vec<f64>> = vec![];
        // Empty input must be reported as an error rather than panicking.
        assert!(optimise_pacmap(&mut embd, &pairs, &default_params(), 0).is_err());
        assert!(optimise_pacmap_parallel(&mut embd, &pairs, &default_params(), 0).is_err());
    }

    #[test]
    fn test_phase_weights_respected() {
        let n = 10;
        let pairs_full = simple_pairs(n);
        let pairs_no_mn = PacmapPairs {
            near: pairs_full.near.clone(),
            mid_near: vec![],
            further: pairs_full.further.clone(),
        };
        // 50 epochs stay inside phase 1, where the mid-near weight is largest.
        let params = PacmapOptimParams::new(Some(50), None, None, None, None, None, None);
        let mut embd_full = simple_embd(n);
        let mut embd_no_mn = simple_embd(n);
        let initial = simple_embd(n);
        assert!(optimise_pacmap(&mut embd_full, &pairs_full, &params, 0).is_ok());
        assert!(optimise_pacmap(&mut embd_no_mn, &pairs_no_mn, &params, 0).is_ok());
        let movement_full = total_movement(&initial, &embd_full);
        let movement_no_mn = total_movement(&initial, &embd_no_mn);
        assert!(
            (movement_full - movement_no_mn).abs() > 1e-6,
            "mid-near pairs had no effect: full={:.4}, no_mn={:.4}",
            movement_full,
            movement_no_mn
        );
    }

    #[test]
    fn test_sequential_and_parallel_broadly_agree() {
        let pairs = simple_pairs(10);
        let params = PacmapOptimParams::new(Some(200), None, None, None, None, None, None);
        let mut embd_seq = simple_embd(10);
        let mut embd_par = simple_embd(10);
        assert!(optimise_pacmap(&mut embd_seq, &pairs, &params, 0).is_ok());
        assert!(optimise_pacmap_parallel(&mut embd_par, &pairs, &params, 0).is_ok());
        // The two variants sum contributions in a different order, so only
        // approximate agreement is required.
        let diff = total_movement(&embd_seq, &embd_par);
        let scale = total_movement(&simple_embd(10), &embd_seq);
        assert!(
            diff < scale * 0.1,
            "sequential and parallel diverged too much: diff={:.4}, scale={:.4}",
            diff,
            scale
        );
    }
}