use num_traits::{Float, FromPrimitive};
use rayon::prelude::*;
use thousands::*;
use crate::data::structures::*;
use crate::prelude::*;
use crate::utils::bh_tree::*;
#[cfg(feature = "fft_tsne")]
use crate::utils::fft::*;
/// Epoch index at which the optimisers switch from `TSNE_INITIAL_MOMENTUM`
/// to `TSNE_FINAL_MOMENTUM`.
const TSNE_MOMENTUM_SWITCH_ITER: usize = 250;
/// Momentum applied to the previous update for epochs before `TSNE_MOMENTUM_SWITCH_ITER`.
const TSNE_INITIAL_MOMENTUM: f64 = 0.5;
/// Momentum applied to the previous update from `TSNE_MOMENTUM_SWITCH_ITER` onwards.
const TSNE_FINAL_MOMENTUM: f64 = 0.8;
/// Lower bound for the per-coordinate adaptive gains.
const TSNE_MIN_GAIN: f64 = 0.01;
/// Threshold at or below which the normalisation term Z is treated as degenerate
/// (the optimisers then fall back instead of dividing by it).
const TSNE_EPS: f64 = 1e-12;
/// Per-point step-length cap as a fraction of the learning rate.
const TSNE_MAX_STEP_FRACTION: f64 = 0.025;
/// Smallest allowed per-point step-length cap, regardless of the learning rate.
const TSNE_MAX_STEP_FLOOR: f64 = 5.0;
/// Automatic learning rate is `n_samples / TSNE_LR_DIVISOR` ...
const TSNE_LR_DIVISOR: f64 = 12.0;
/// ... but never less than `TSNE_LR_FLOOR`.
const TSNE_LR_FLOOR: f64 = 200.0;
/// Upper bound on the FFT grid box count; beyond it the box width grows instead.
#[cfg(feature = "fft_tsne")]
const TSNE_FFT_MAX_BOXES: usize = 140;
/// Minimum width of a single FFT grid box, in embedding coordinates.
#[cfg(feature = "fft_tsne")]
const TSNE_FFT_MIN_BOX_WIDTH: f64 = 1.0;
/// Extra fractional margin added to the grid half-width when the box count is capped.
#[cfg(feature = "fft_tsne")]
const TSNE_FFT_GRID_MARGIN: f64 = 0.3;
/// Hyper-parameters shared by the t-SNE optimisers (`optimise_bh_tsne` and, with the
/// `fft_tsne` feature, `optimise_fft_tsne`).
#[derive(Clone, Debug)]
pub struct TsneOptimParams<T> {
    /// Number of optimisation epochs.
    pub n_epochs: usize,
    /// Learning rate; `None` derives one from the sample count (see `get_lr`).
    pub lr: Option<T>,
    /// Number of initial epochs that use `early_exag_factor`.
    pub early_exag_iter: usize,
    /// Multiplier on the attractive forces during the first `early_exag_iter` epochs.
    pub early_exag_factor: T,
    /// Multiplier on the attractive forces after the early phase; `None` means 1.
    pub late_exag_factor: Option<T>,
    /// Barnes-Hut approximation parameter forwarded to `compute_repulsive_force`
    /// (presumably the opening-angle threshold - confirm against `BarnesHutTree`).
    pub theta: T,
    /// Interpolation-point count forwarded to `FftGrid::new` by the FFT optimiser
    /// (presumably per grid box - confirm against `FftGrid`).
    pub n_interp_points: usize,
}
impl<T> TsneOptimParams<T>
where
    T: Float + FromPrimitive,
{
    /// Creates a parameter set from explicit values.
    ///
    /// `n_interp_points` falls back to 3 when `None`; every other argument is stored as given.
    pub fn new(
        n_epochs: usize,
        lr: Option<T>,
        early_exag_iter: usize,
        early_exag_factor: T,
        late_exag_factor: Option<T>,
        theta: T,
        n_interp_points: Option<usize>,
    ) -> Self {
        Self {
            n_epochs,
            lr,
            early_exag_iter,
            early_exag_factor,
            late_exag_factor,
            theta,
            n_interp_points: match n_interp_points {
                Some(points) => points,
                None => 3,
            },
        }
    }

    /// Returns the learning rate for a data set of `n_samples` points.
    ///
    /// An explicitly configured rate wins. Otherwise the rate is
    /// `max(n_samples / TSNE_LR_DIVISOR, TSNE_LR_FLOOR)`.
    ///
    /// # Panics
    ///
    /// Panics if the derived rate cannot be represented as `T`.
    pub fn get_lr(&self, n_samples: usize) -> T {
        match self.lr {
            Some(explicit) => explicit,
            None => {
                let per_sample = n_samples as f64 / TSNE_LR_DIVISOR;
                T::from_f64(per_sample.max(TSNE_LR_FLOOR)).unwrap()
            }
        }
    }

    /// Returns the exaggeration factor used after the early phase (1 when unset).
    pub fn get_late_exag_factor(&self) -> T {
        match self.late_exag_factor {
            Some(factor) => factor,
            None => T::one(),
        }
    }
}
impl<T> Default for TsneOptimParams<T>
where
    T: Float + FromPrimitive,
{
    /// Default schedule:
    /// - 1000 epochs with an automatic learning rate;
    /// - 250 early-exaggeration epochs at factor 12, then no late exaggeration;
    /// - `theta = 0.5` and 3 interpolation points.
    fn default() -> Self {
        let early_factor = T::from_f64(12.0).unwrap();
        let bh_theta = T::from_f64(0.5).unwrap();
        Self::new(1000, None, 250, early_factor, None, bh_theta, Some(3))
    }
}
/// Repulsive-force approximation strategy for t-SNE optimisation.
///
/// Defaults to `Fft`.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub enum TsneOpt {
    /// FFT-accelerated grid interpolation (the `optimise_fft_tsne` path, which is gated
    /// behind the `fft_tsne` feature).
    #[default]
    Fft,
    /// Barnes-Hut tree approximation (the `optimise_bh_tsne` path).
    BarnesHut,
}
/// Parses a user-supplied optimiser name into a [`TsneOpt`].
///
/// Matching ignores case and surrounding whitespace. Accepted spellings:
///
/// * Barnes-Hut: `"bh"`, `"barnes hut"`, `"barnes_hut"`, `"barnes-hut"`, `"barneshut"`
/// * FFT: `"fft"`
///
/// Returns `None` for any other input.
pub fn parse_tsne_optimiser(s: &str) -> Option<TsneOpt> {
    match s.trim().to_lowercase().as_str() {
        "bh" | "barnes hut" | "barnes_hut" | "barnes-hut" | "barneshut" => {
            Some(TsneOpt::BarnesHut)
        }
        "fft" => Some(TsneOpt::Fft),
        _ => None,
    }
}
/// Applies one momentum + adaptive-gain step to a single embedding coordinate.
///
/// Gain update:
/// - If `grad` and the previous `update` disagree in sign (compared via `> 0`), `gain`
///   grows by 0.2.
/// - Otherwise `gain` shrinks by a factor of 0.8.
/// - Either way the result is clamped from below by `min_gain`.
///
/// The new `update` is `momentum * update - lr * gain * grad`. It is stored back and added
/// to `val`.
#[inline(always)]
fn update_parameter<T>(
    val: &mut T,
    update: &mut T,
    gain: &mut T,
    grad: T,
    lr: T,
    momentum: T,
    min_gain: T,
) where
    T: ManifoldsFloat,
{
    let grad_is_positive = grad > T::zero();
    let prev_is_positive = *update > T::zero();
    let adjusted_gain = if grad_is_positive == prev_is_positive {
        // Same direction as last time: damp the gain.
        *gain * T::from_f64(0.8).unwrap()
    } else {
        // Direction flipped relative to the previous step: boost the gain.
        *gain + T::from_f64(0.2).unwrap()
    };
    *gain = adjusted_gain.max(min_gain);
    let step = momentum * *update - lr * *gain * grad;
    *update = step;
    *val += step;
}
/// Caps the Euclidean length of a 2-D update at `max_step_norm`.
///
/// When `(u0, u1)` is longer than the cap, both components are rescaled onto the cap and
/// `point[0..2]` is recomputed as `(prev_x, prev_y)` plus the shortened update. Shorter
/// (or NaN-length) updates are left untouched.
#[inline(always)]
fn clip_step<T>(point: &mut [T], u0: &mut T, u1: &mut T, prev_x: T, prev_y: T, max_step_norm: T)
where
    T: ManifoldsFloat,
{
    let (dx, dy) = (*u0, *u1);
    let length_sq = dx * dx + dy * dy;
    if length_sq > max_step_norm * max_step_norm {
        let factor = max_step_norm / length_sq.sqrt();
        let (clipped_x, clipped_y) = (dx * factor, dy * factor);
        *u0 = clipped_x;
        *u1 = clipped_y;
        point[0] = prev_x + clipped_x;
        point[1] = prev_y + clipped_y;
    }
}
/// Derives the per-point step-length cap from the learning rate:
/// `max(lr * TSNE_MAX_STEP_FRACTION, TSNE_MAX_STEP_FLOOR)`.
///
/// The arithmetic is done in `f64`. Panics if `lr` cannot be converted to `f64` or the
/// result back to `T`.
#[inline]
fn step_cap_from_lr<T: ManifoldsFloat>(lr: T) -> T {
    let proportional = TSNE_MAX_STEP_FRACTION * lr.to_f64().unwrap();
    let cap = f64::max(proportional, TSNE_MAX_STEP_FLOOR);
    T::from_f64(cap).unwrap()
}
/// Translates the embedding so the centroid of its first two coordinates sits at the origin.
///
/// The centroid is accumulated in `f64` and converted back to `T`. The shift itself is
/// applied in parallel. Empty input is a no-op.
fn recentre_embedding<T: ManifoldsFloat>(embd: &mut [Vec<T>]) {
    if embd.is_empty() {
        return;
    }
    let (total_x, total_y) = embd.iter().fold((0.0_f64, 0.0_f64), |(acc_x, acc_y), point| {
        (
            acc_x + point[0].to_f64().unwrap(),
            acc_y + point[1].to_f64().unwrap(),
        )
    });
    let count = embd.len() as f64;
    let centre_x = T::from_f64(total_x / count).unwrap();
    let centre_y = T::from_f64(total_y / count).unwrap();
    embd.par_iter_mut().for_each(|point| {
        point[0] -= centre_x;
        point[1] -= centre_y;
    });
}
/// Optimises a t-SNE embedding in place, approximating repulsion with a Barnes-Hut tree
/// rebuilt every epoch.
///
/// Each epoch:
/// 1. snapshots the current x/y coordinates;
/// 2. queries the tree for each point's repulsive force and its contribution to the
///    normalisation term Z;
/// 3. accumulates the attractive force from the sparse affinities in `graph`, scaled by the
///    current exaggeration factor;
/// 4. applies a momentum + adaptive-gain update per coordinate and caps the step length;
/// 5. recentres the embedding at the origin.
///
/// Only coordinates 0 and 1 of each point are read or updated. An empty embedding is
/// returned unchanged.
///
/// # Arguments
///
/// * `embd` - Embedding to optimise in place (one `Vec` per sample).
/// * `params` - Optimisation hyper-parameters; `params.theta` is forwarded to the tree.
/// * `graph` - Sparse affinities in coordinate format. Entries whose row or column index is
///   not a valid point index are skipped.
/// * `verbose` - Verbosity level. At normal verbosity, progress is printed every 50 epochs
///   and on the final epoch.
///
/// # Panics
///
/// Panics if a point has fewer than two coordinates.
pub fn optimise_bh_tsne<T>(
    embd: &mut [Vec<T>],
    params: &TsneOptimParams<T>,
    graph: &CoordinateList<T>,
    verbose: usize,
) where
    T: ManifoldsFloat,
{
    let verbosity = parse_verbosity_level(verbose);
    let n = embd.len();
    // Nothing to optimise; returning here also avoids indexing `embd[0]` on empty input.
    if n == 0 {
        return;
    }
    let n_dim = embd[0].len();
    let lr = params.get_lr(n);
    let initial_momentum = T::from_f64(TSNE_INITIAL_MOMENTUM).unwrap();
    let final_momentum = T::from_f64(TSNE_FINAL_MOMENTUM).unwrap();
    let min_gain = T::from_f64(TSNE_MIN_GAIN).unwrap();
    let max_step_norm = step_cap_from_lr(lr);

    // Optimiser state stored flat with stride `n_dim` so it can be chunked per point.
    let mut update_flat = vec![T::zero(); n * n_dim];
    let mut gains_flat = vec![T::one(); n * n_dim];

    // Per-epoch coordinate snapshots read by every point during the parallel update.
    let mut xs = vec![T::zero(); n];
    let mut ys = vec![T::zero(); n];
    let mut rep_forces: Vec<(T, T, T)> = vec![(T::zero(), T::zero(), T::zero()); n];

    // Row-wise adjacency lists, built once from the COO affinities.
    let mut adj: Vec<Vec<(usize, T)>> = vec![Vec::new(); n];
    for ((&i, &j), &w) in graph
        .row_indices
        .iter()
        .zip(&graph.col_indices)
        .zip(&graph.values)
    {
        // `i` indexes `adj` and `j` indexes `xs`/`ys`; skip entries that do not refer to
        // points of this embedding rather than panicking on an out-of-bounds index.
        if i < n && j < n {
            adj[i].push((j, w));
        }
    }

    for epoch in 0..params.n_epochs {
        let bh_tree = BarnesHutTree::new(embd);
        let momentum = if epoch < TSNE_MOMENTUM_SWITCH_ITER {
            initial_momentum
        } else {
            final_momentum
        };
        let exag_factor = if epoch < params.early_exag_iter {
            params.early_exag_factor
        } else {
            params.get_late_exag_factor()
        };

        // Snapshot positions so every point reads the same pre-update coordinates.
        embd.par_iter()
            .zip(xs.par_iter_mut())
            .zip(ys.par_iter_mut())
            .for_each(|((p, x), y)| {
                *x = p[0];
                *y = p[1];
            });

        // Each worker reuses one traversal stack across its points.
        rep_forces.par_iter_mut().enumerate().for_each_init(
            || Vec::with_capacity(256),
            |stack, (i, slot)| {
                *slot = bh_tree.compute_repulsive_force(i, xs[i], ys[i], params.theta, stack);
            },
        );

        // Z is accumulated in f64 to limit round-off for large `n`.
        let z_total: f64 = rep_forces
            .iter()
            .map(|r| r.2.to_f64().unwrap())
            .sum::<f64>();
        // A degenerate Z disables the repulsive term instead of dividing by ~0.
        let z_inv = if z_total > TSNE_EPS {
            T::from_f64(1.0 / z_total).unwrap()
        } else {
            T::zero()
        };

        embd.par_iter_mut()
            .zip(update_flat.par_chunks_mut(n_dim))
            .zip(gains_flat.par_chunks_mut(n_dim))
            .enumerate()
            .for_each(|(i, ((point, u_i), g_i))| {
                let px = xs[i];
                let py = ys[i];
                let (rep_x, rep_y, _) = rep_forces[i];
                let mut attr_x = T::zero();
                let mut attr_y = T::zero();
                for &(j, p_val) in &adj[i] {
                    let dx = px - xs[j];
                    let dy = py - ys[j];
                    let dist_sq = dx * dx + dy * dy;
                    let q = T::one() / (T::one() + dist_sq);
                    let force = p_val * exag_factor * q;
                    attr_x += force * dx;
                    attr_y += force * dy;
                }
                let grad_x = attr_x - rep_x * z_inv;
                let grad_y = attr_y - rep_y * z_inv;
                update_parameter(
                    &mut point[0],
                    &mut u_i[0],
                    &mut g_i[0],
                    grad_x,
                    lr,
                    momentum,
                    min_gain,
                );
                update_parameter(
                    &mut point[1],
                    &mut u_i[1],
                    &mut g_i[1],
                    grad_y,
                    lr,
                    momentum,
                    min_gain,
                );
                let (u0, u1) = u_i.split_at_mut(1);
                clip_step(point, &mut u0[0], &mut u1[0], px, py, max_step_norm);
            });

        recentre_embedding(embd);

        if verbosity.normal_verbosity() && (epoch % 50 == 0 || epoch == params.n_epochs - 1) {
            println!(
                " Epoch {}/{} | Z = {}",
                epoch,
                params.n_epochs,
                (z_total.round() as i64).separate_with_underscores()
            );
        }
    }
}
/// Chooses the FFT grid for an embedding whose coordinates lie in `[-half_span, half_span]`.
///
/// Returns `(n_boxes, box_width, grid_half_width)`.
/// - Boxes of `TSNE_FFT_MIN_BOX_WIDTH` are used while `choose_grid_size` asks for at most
///   `TSNE_FFT_MAX_BOXES` of them over the padded span.
/// - Otherwise the count is capped at `TSNE_FFT_MAX_BOXES`, and the boxes are widened to
///   cover the span plus `TSNE_FFT_GRID_MARGIN`.
#[cfg(feature = "fft_tsne")]
fn fft_grid_geometry(half_span: f64, min_intervals: usize) -> (usize, f64, f64) {
    // 5% padding around the occupied extent.
    const PADDING: f64 = 1.05;
    let padded_span = 2.0 * half_span * PADDING;
    let requested = choose_grid_size(0.0, padded_span, TSNE_FFT_MIN_BOX_WIDTH, min_intervals);
    if requested > TSNE_FFT_MAX_BOXES {
        // Too many minimum-width boxes: cap the count and widen every box instead.
        let widened_half = half_span * (PADDING + TSNE_FFT_GRID_MARGIN);
        let width = (widened_half * 2.0 / TSNE_FFT_MAX_BOXES as f64).max(TSNE_FFT_MIN_BOX_WIDTH);
        return (
            TSNE_FFT_MAX_BOXES,
            width,
            TSNE_FFT_MAX_BOXES as f64 * width / 2.0,
        );
    }
    (
        requested,
        TSNE_FFT_MIN_BOX_WIDTH,
        requested as f64 * TSNE_FFT_MIN_BOX_WIDTH / 2.0,
    )
}
/// Optimises a 2-D t-SNE embedding in place, evaluating the repulsive forces with the
/// FFT-based N-body routine `n_body_fft_2d` on an adaptive grid.
///
/// Each epoch:
/// - snapshots the coordinates and sizes the grid from their symmetric extent;
/// - rebuilds the grid only when the box count changes or points approach its edge;
/// - computes the kernel sums for the charges `[1, x, y, x^2 + y^2]` and derives Z and the
///   repulsive forces from them;
/// - applies the same momentum + adaptive-gain update, step cap and recentring as the
///   Barnes-Hut optimiser.
///
/// An empty embedding is returned unchanged. COO entries whose row or column index is not
/// a valid point index are skipped.
///
/// # Errors
///
/// Returns `ManifoldsError::IncorrectDim` if the first point is not two-dimensional.
#[cfg(feature = "fft_tsne")]
pub fn optimise_fft_tsne<T>(
    embd: &mut [Vec<T>],
    params: &TsneOptimParams<T>,
    graph: &CoordinateList<T>,
    verbose: usize,
) -> Result<(), ManifoldsError>
where
    T: FftwFloat + ManifoldsFloat,
{
    let verbosity = parse_verbosity_level(verbose);
    let n = embd.len();
    // Nothing to optimise; returning here also avoids indexing `embd[0]` / `xs[0]` below.
    if n == 0 {
        return Ok(());
    }
    let n_dim = embd[0].len();
    if n_dim != 2 {
        return Err(ManifoldsError::IncorrectDim { n_dim });
    }
    let lr = params.get_lr(n);
    // Charges per point: [1, x, y, x^2 + y^2].
    let n_terms = 4;
    let initial_momentum = T::from_f64(TSNE_INITIAL_MOMENTUM).unwrap();
    let final_momentum = T::from_f64(TSNE_FINAL_MOMENTUM).unwrap();
    let min_gain = T::from_f64(TSNE_MIN_GAIN).unwrap();
    let max_step_norm = step_cap_from_lr(lr);
    let mut uy = vec![vec![T::zero(); n_dim]; n];
    let mut gains = vec![vec![T::one(); n_dim]; n];

    // Row-wise adjacency lists, built once from the COO affinities.
    let mut adj: Vec<Vec<(usize, T)>> = vec![Vec::new(); n];
    for ((&i, &j), &w) in graph
        .row_indices
        .iter()
        .zip(&graph.col_indices)
        .zip(&graph.values)
    {
        // `i` indexes `adj` and `j` indexes `xs`/`ys`; both must be valid point indices.
        if i < n && j < n {
            adj[i].push((j, w));
        }
    }

    let mut charges = vec![T::zero(); n * n_terms];
    let mut potentials = vec![T::zero(); n * n_terms];
    let mut xs = vec![T::zero(); n];
    let mut ys = vec![T::zero(); n];
    // Minimum number of grid intervals requested from `choose_grid_size`.
    let min_intervals = 50;
    let mut cached_n_boxes: usize = 0;
    let mut grid: Option<FftGrid<T>> = None;
    let mut workspace: Option<FftWorkspace<T>> = None;

    for epoch in 0..params.n_epochs {
        // Snapshot positions so every point reads the same pre-update coordinates.
        embd.par_iter()
            .zip(xs.par_iter_mut())
            .zip(ys.par_iter_mut())
            .for_each(|((p, x), y)| {
                *x = p[0];
                *y = p[1];
            });

        // Symmetric extent of the embedding: the largest |coordinate| over both axes.
        let mut min_val = xs[0];
        let mut max_val = xs[0];
        for v in xs.iter().chain(ys.iter()) {
            if *v < min_val {
                min_val = *v;
            }
            if *v > max_val {
                max_val = *v;
            }
        }
        let half_span = min_val
            .to_f64()
            .unwrap()
            .abs()
            .max(max_val.to_f64().unwrap().abs());
        let (n_boxes, _box_width, grid_half) = fft_grid_geometry(half_span, min_intervals);

        // Rebuild when the box count changes or points come within one box of the edge.
        let needs_rebuild = match grid.as_ref() {
            None => true,
            Some(_) if cached_n_boxes != n_boxes => true,
            Some(g) => {
                let coord_max =
                    g.coord_min + g.box_width * T::from_usize(g.n_boxes_per_dim).unwrap();
                let safe_max = coord_max - g.box_width;
                let safe_min = g.coord_min + g.box_width;
                let max_abs = T::from_f64(half_span).unwrap();
                max_abs >= safe_max || -max_abs <= safe_min
            }
        };
        if needs_rebuild {
            let half = T::from_f64(grid_half).unwrap();
            let new_grid = FftGrid::new(-half, half, n_boxes, params.n_interp_points);
            // Recreate the workspace when the box count changes, and whenever none exists
            // yet, so the `expect` below cannot fire.
            if cached_n_boxes != n_boxes || workspace.is_none() {
                workspace = Some(FftWorkspace::new(new_grid.n_fft));
            }
            grid = Some(new_grid);
            cached_n_boxes = n_boxes;
        }
        let grid_ref = grid
            .as_ref()
            .expect("FFT grid is always built on the first epoch");
        let ws = workspace
            .as_mut()
            .expect("FFT workspace is always built together with the first grid");

        let momentum = if epoch < TSNE_MOMENTUM_SWITCH_ITER {
            initial_momentum
        } else {
            final_momentum
        };
        let exag_factor = if epoch < params.early_exag_iter {
            params.early_exag_factor
        } else {
            params.get_late_exag_factor()
        };

        charges
            .par_chunks_mut(n_terms)
            .enumerate()
            .for_each(|(i, chunk)| {
                let x = xs[i];
                let y = ys[i];
                chunk[0] = T::one();
                chunk[1] = x;
                chunk[2] = y;
                chunk[3] = x * x + y * y;
            });
        potentials.fill(T::zero());
        n_body_fft_2d(&xs, &ys, &charges, n_terms, grid_ref, ws, &mut potentials);

        // Z = sum_i [(1 + |y_i|^2) phi1 - 2 y_i . (phi2, phi3) + phi4] - n. The `- n`
        // presumably removes the j == i self terms (kernel value 1 at zero distance).
        let sum_q: f64 = (0..n)
            .map(|i| {
                let idx = i * n_terms;
                let phi1 = potentials[idx].to_f64().unwrap();
                let phi2 = potentials[idx + 1].to_f64().unwrap();
                let phi3 = potentials[idx + 2].to_f64().unwrap();
                let phi4 = potentials[idx + 3].to_f64().unwrap();
                let x = xs[i].to_f64().unwrap();
                let y = ys[i].to_f64().unwrap();
                (1.0 + x * x + y * y) * phi1 - 2.0 * (x * phi2 + y * phi3) + phi4
            })
            .sum::<f64>()
            - n as f64;
        // Guard the division below against a degenerate Z.
        let sum_q_safe = if sum_q > TSNE_EPS { sum_q } else { 1.0 };

        embd.par_iter_mut()
            .zip(uy.par_iter_mut())
            .zip(gains.par_iter_mut())
            .enumerate()
            .for_each(|(i, ((point, u_i), gains_i))| {
                let x = xs[i];
                let y = ys[i];
                let mut attr_x = T::zero();
                let mut attr_y = T::zero();
                for &(j, p_val) in &adj[i] {
                    let other_x = xs[j];
                    let other_y = ys[j];
                    let dx = x - other_x;
                    let dy = y - other_y;
                    let dist_sq = dx * dx + dy * dy;
                    let q_ij = T::one() / (T::one() + dist_sq);
                    let force = p_val * exag_factor * q_ij;
                    attr_x += force * dx;
                    attr_y += force * dy;
                }
                // Repulsion: (x_i * phi1 - phi2, y_i * phi1 - phi3) / Z.
                let pot_idx = i * n_terms;
                let phi1 = potentials[pot_idx].to_f64().unwrap();
                let phi2 = potentials[pot_idx + 1].to_f64().unwrap();
                let phi3 = potentials[pot_idx + 2].to_f64().unwrap();
                let xf = x.to_f64().unwrap();
                let yf = y.to_f64().unwrap();
                let rep_x = T::from_f64((xf * phi1 - phi2) / sum_q_safe).unwrap();
                let rep_y = T::from_f64((yf * phi1 - phi3) / sum_q_safe).unwrap();
                let grad_x = attr_x - rep_x;
                let grad_y = attr_y - rep_y;
                update_parameter(
                    &mut point[0],
                    &mut u_i[0],
                    &mut gains_i[0],
                    grad_x,
                    lr,
                    momentum,
                    min_gain,
                );
                update_parameter(
                    &mut point[1],
                    &mut u_i[1],
                    &mut gains_i[1],
                    grad_y,
                    lr,
                    momentum,
                    min_gain,
                );
                let (u0, u1) = u_i.split_at_mut(1);
                clip_step(point, &mut u0[0], &mut u1[0], x, y, max_step_norm);
            });

        recentre_embedding(embd);

        if verbosity.normal_verbosity() && (epoch % 50 == 0 || epoch == params.n_epochs - 1) {
            println!(
                " Epoch {}/{} | Z = {} | n_boxes = {}",
                epoch,
                params.n_epochs,
                (sum_q.round() as i64).separate_with_underscores(),
                n_boxes,
            );
        }
    }
    Ok(())
}
#[cfg(test)]
mod test_tsne_optimiser {
    use super::*;
    use approx::assert_relative_eq;

    /// Builds a symmetric COO graph with `n` samples from undirected `(u, v, w)` edges.
    /// Self-loops are inserted once.
    fn create_coo_graph(n: usize, edges: &[(usize, usize, f64)]) -> CoordinateList<f64> {
        let mut row_indices = Vec::new();
        let mut col_indices = Vec::new();
        let mut values = Vec::new();
        for &(u, v, w) in edges {
            row_indices.push(u);
            col_indices.push(v);
            values.push(w);
            if u != v {
                row_indices.push(v);
                col_indices.push(u);
                values.push(w);
            }
        }
        CoordinateList {
            row_indices,
            col_indices,
            values,
            n_samples: n,
        }
    }

    #[test]
    fn test_tsne_params_defaults() {
        let params = TsneOptimParams::<f64>::default();
        assert_eq!(params.n_epochs, 1000);
        assert_eq!(params.early_exag_iter, 250);
        assert_relative_eq!(params.early_exag_factor, 12.0);
        assert_relative_eq!(params.theta, 0.5);
    }

    #[test]
    fn test_new_defaults_interp_points() {
        let params = TsneOptimParams::<f64>::new(500, None, 100, 4.0, None, 0.5, None);
        assert_eq!(params.n_interp_points, 3);
        let params = TsneOptimParams::<f64>::new(500, None, 100, 4.0, None, 0.5, Some(5));
        assert_eq!(params.n_interp_points, 5);
    }

    #[test]
    fn test_get_lr_floor_and_scaling() {
        let params = TsneOptimParams::<f64>::default();
        assert_relative_eq!(params.get_lr(100), 200.0);
        assert_relative_eq!(params.get_lr(120_000), 10_000.0);
        let fixed = TsneOptimParams {
            lr: Some(50.0),
            ..TsneOptimParams::default()
        };
        assert_relative_eq!(fixed.get_lr(1_000_000), 50.0);
    }

    #[test]
    fn test_get_late_exag_factor() {
        let params = TsneOptimParams::<f64>::default();
        assert_relative_eq!(params.get_late_exag_factor(), 1.0);
        let custom = TsneOptimParams {
            late_exag_factor: Some(2.5),
            ..TsneOptimParams::default()
        };
        assert_relative_eq!(custom.get_late_exag_factor(), 2.5);
    }

    #[test]
    fn test_parse_tsne_optimiser() {
        assert!(matches!(parse_tsne_optimiser("bh"), Some(TsneOpt::BarnesHut)));
        assert!(matches!(
            parse_tsne_optimiser("Barnes Hut"),
            Some(TsneOpt::BarnesHut)
        ));
        assert!(matches!(parse_tsne_optimiser("FFT"), Some(TsneOpt::Fft)));
        assert!(parse_tsne_optimiser("umap").is_none());
    }

    #[test]
    fn test_step_cap_scales_with_lr() {
        let cap_small: f64 = step_cap_from_lr(200.0);
        assert_relative_eq!(cap_small, 5.0);
        let cap_large: f64 = step_cap_from_lr(40_000.0);
        assert_relative_eq!(cap_large, 40_000.0 * TSNE_MAX_STEP_FRACTION);
    }

    #[test]
    fn test_update_parameter_gain_rules() {
        let (mut val, mut update, mut gain) = (1.0_f64, 0.0_f64, 1.0_f64);
        // Gradient and previous update disagree in sign -> additive gain boost.
        update_parameter(&mut val, &mut update, &mut gain, 2.0, 10.0, 0.5, 0.01);
        assert_relative_eq!(gain, 1.2, epsilon = 1e-9);
        assert_relative_eq!(update, -24.0, epsilon = 1e-9);
        assert_relative_eq!(val, -23.0, epsilon = 1e-9);
        // Both non-positive -> multiplicative gain decay.
        update_parameter(&mut val, &mut update, &mut gain, -1.0, 10.0, 0.5, 0.01);
        assert_relative_eq!(gain, 0.96, epsilon = 1e-9);
        assert_relative_eq!(update, -2.4, epsilon = 1e-9);
        assert_relative_eq!(val, -25.4, epsilon = 1e-9);
    }

    #[test]
    fn test_clip_step_limits_step_length() {
        let mut point = vec![6.0_f64, 8.0];
        let (mut u0, mut u1) = (6.0_f64, 8.0_f64);
        clip_step(&mut point, &mut u0, &mut u1, 0.0, 0.0, 5.0);
        assert_relative_eq!(u0, 3.0, epsilon = 1e-12);
        assert_relative_eq!(u1, 4.0, epsilon = 1e-12);
        assert_relative_eq!(point[0], 3.0, epsilon = 1e-12);
        assert_relative_eq!(point[1], 4.0, epsilon = 1e-12);

        // Steps within the cap are left untouched.
        let mut point = vec![1.0_f64, 2.0];
        let (mut u0, mut u1) = (1.0_f64, 1.0_f64);
        clip_step(&mut point, &mut u0, &mut u1, 0.0, 1.0, 5.0);
        assert_relative_eq!(point[0], 1.0);
        assert_relative_eq!(point[1], 2.0);
        assert_relative_eq!(u0, 1.0);
        assert_relative_eq!(u1, 1.0);
    }

    #[test]
    fn test_recentre_embedding_zero_mean() {
        let mut embd = vec![vec![1.0_f64, 2.0], vec![3.0, 6.0], vec![5.0, 10.0]];
        recentre_embedding(&mut embd);
        let expected = [[-2.0, -4.0], [0.0, 0.0], [2.0, 4.0]];
        for (p, e) in embd.iter().zip(expected.iter()) {
            assert_relative_eq!(p[0], e[0], epsilon = 1e-12);
            assert_relative_eq!(p[1], e[1], epsilon = 1e-12);
        }
        let mut empty: Vec<Vec<f64>> = Vec::new();
        recentre_embedding(&mut empty);
        assert!(empty.is_empty());
    }

    #[test]
    fn test_bh_tsne_basic_convergence() {
        let edges = vec![(0, 1, 1.0), (1, 2, 1.0), (2, 0, 1.0)];
        let graph = create_coo_graph(3, &edges);
        let mut embd = vec![vec![0.0, 0.0], vec![1.0, 1.0], vec![2.0, 2.0]];
        let initial_embd = embd.clone();
        let params = TsneOptimParams {
            n_epochs: 50,
            lr: Some(50.0),
            ..TsneOptimParams::default()
        };
        optimise_bh_tsne(&mut embd, &params, &graph, 0);
        for point in &embd {
            for val in point {
                assert!(val.is_finite(), "Embedding contains non-finite values");
            }
        }
        let total_movement: f64 = embd
            .iter()
            .zip(initial_embd.iter())
            .map(|(n, o)| (n[0] - o[0]).powi(2) + (n[1] - o[1]).powi(2))
            .sum();
        assert!(
            total_movement > 0.01,
            "Barnes-Hut t-SNE failed to move points significantly"
        );
    }

    #[test]
    #[cfg(feature = "fft_tsne")]
    fn test_fft_tsne_basic_convergence() {
        let edges = vec![(0, 1, 1.0), (1, 2, 1.0), (2, 0, 1.0)];
        let graph = create_coo_graph(3, &edges);
        let mut embd = vec![vec![0.0, 0.0], vec![1.0, 1.0], vec![2.0, 2.0]];
        let initial_embd = embd.clone();
        let params = TsneOptimParams {
            n_epochs: 50,
            lr: Some(50.0),
            n_interp_points: 3,
            ..TsneOptimParams::default()
        };
        // Surface optimiser errors directly instead of failing later on the movement check.
        optimise_fft_tsne(&mut embd, &params, &graph, 0)
            .expect("FFT t-SNE optimisation returned an error");
        for point in &embd {
            for val in point {
                assert!(val.is_finite(), "Embedding contains non-finite values");
            }
        }
        let total_movement: f64 = embd
            .iter()
            .zip(initial_embd.iter())
            .map(|(n, o)| (n[0] - o[0]).powi(2) + (n[1] - o[1]).powi(2))
            .sum();
        assert!(
            total_movement > 0.01,
            "FFT t-SNE failed to move points significantly"
        );
    }

    #[test]
    fn test_bh_tsne_determinism() {
        let edges = vec![(0, 1, 1.0), (1, 2, 1.0)];
        let graph = create_coo_graph(3, &edges);
        let mut embd1 = vec![vec![0.0, 0.0], vec![1.0, 0.0], vec![0.0, 1.0]];
        let mut embd2 = embd1.clone();
        let params = TsneOptimParams {
            n_epochs: 50,
            ..TsneOptimParams::default()
        };
        optimise_bh_tsne(&mut embd1, &params, &graph, 0);
        optimise_bh_tsne(&mut embd2, &params, &graph, 0);
        for (p1, p2) in embd1.iter().zip(embd2.iter()) {
            assert_relative_eq!(p1[0], p2[0]);
            assert_relative_eq!(p1[1], p2[1]);
        }
    }
}