use faer::Mat;
use rand::prelude::*;
use rand::rngs::StdRng;
use rand::SeedableRng;
use rand_distr::{Distribution, StandardNormal};
use rayon::prelude::*;
use thousands::*;
use crate::prelude::*;
use crate::utils::math::*;
/// Base learning rate used by [`MdsOptimParams::new`] when no `lr` is supplied.
pub const DEFAULT_MDS_LR: f64 = 1e-3;
/// Number of optimisation iterations used by [`MdsOptimParams::new`] when no
/// `n_iter` is supplied.
pub const DEFAULT_MDS_ITER: usize = 800;
/// Algorithm used to compute an MDS embedding.
///
/// The enum carries no data, so it is `Copy` and can be compared with `==`.
#[derive(Default, Clone, Copy, Debug, PartialEq, Eq)]
pub enum MdsMethod {
    /// Stochastic-gradient metric MDS (the default).
    #[default]
    SgdDense,
    /// Classical (Torgerson) MDS via SVD of the double-centred squared distances.
    ClassicMds,
}
/// Parses a user-supplied MDS method name, ignoring case.
///
/// Accepted spellings:
/// * `"sgd_dense"` or `"dense"` → [`MdsMethod::SgdDense`]
/// * `"classic"` → [`MdsMethod::ClassicMds`]
///
/// Any other string yields `None`.
pub fn parse_mds_method(s: &str) -> Option<MdsMethod> {
    let normalised = s.to_lowercase();
    let method = match normalised.as_str() {
        "classic" => MdsMethod::ClassicMds,
        "sgd_dense" | "dense" => MdsMethod::SgdDense,
        _ => return None,
    };
    Some(method)
}
/// Optimiser settings for [`sgd_mds`].
///
/// Usually built with [`MdsOptimParams::new`], which fills in defaults.
#[derive(Clone, Debug)]
pub struct MdsOptimParams<T> {
    /// Whether a randomised solver should be used (stored for callers; not read
    /// by `sgd_mds` itself).
    pub randomised: bool,
    /// Maximum number of SGD iterations.
    pub n_iter: usize,
    /// Number of point pairs sampled (with replacement) per iteration.
    pub pairs_per_iter: usize,
    /// Base learning rate before subsampling scaling and decay.
    pub lr: T,
}
impl<T> MdsOptimParams<T>
where
    T: ManifoldsFloat,
{
    /// Builds optimiser settings for a dataset of `n` points.
    ///
    /// * `n` - number of points; `pairs_per_iter` becomes `⌊2 · n · ln n⌋`,
    ///   which is 0 for `n <= 1`
    /// * `randomised` - stored unchanged
    /// * `n_iter` - iteration count, defaulting to [`DEFAULT_MDS_ITER`]
    /// * `lr` - base learning rate, defaulting to [`DEFAULT_MDS_LR`]
    pub fn new(n: usize, randomised: bool, n_iter: Option<usize>, lr: Option<T>) -> Self {
        let fallback_lr = T::from_f64(DEFAULT_MDS_LR).unwrap();
        let n_points = n as f64;
        Self {
            randomised,
            n_iter: n_iter.unwrap_or(DEFAULT_MDS_ITER),
            // Sample O(n log n) pairs per iteration instead of all n² pairs.
            pairs_per_iter: (n_points * n_points.ln() * 2.0) as usize,
            lr: lr.unwrap_or(fallback_lr),
        }
    }
}
/// Root-mean-square of `embedding`'s values, measured about zero.
///
/// Note that the mean is NOT subtracted first, so this equals the standard
/// deviation only when the values are already centred. An empty slice gives 0.
fn compute_std<T>(embedding: &[T]) -> T
where
    T: ManifoldsFloat,
{
    match embedding.len() {
        0 => T::zero(),
        count => {
            let sum_of_squares: T = embedding.iter().map(|&value| value * value).sum();
            let mean_square = sum_of_squares / T::from(count).unwrap();
            mean_square.sqrt()
        }
    }
}
pub fn classic_mds<T>(
dist: &[Vec<T>],
n_components: usize,
randomised: bool,
seed: usize,
) -> Result<Vec<Vec<T>>, ManifoldsError>
where
T: ManifoldsFloat,
StandardNormal: Distribution<T>,
{
let n = dist.len();
let mut d_sq = Mat::zeros(n, n);
for i in 0..n {
for j in 0..n {
d_sq[(i, j)] = dist[i][j] * dist[i][j];
}
}
let mean_row: Vec<T> = (0..n)
.map(|j| {
let sum: T = (0..n).map(|i| d_sq[(i, j)]).sum();
sum / T::from(n).unwrap()
})
.collect();
let mean_col: Vec<T> = (0..n)
.map(|i| {
let sum: T = (0..n).map(|j| d_sq[(i, j)]).sum();
sum / T::from(n).unwrap()
})
.collect();
let mean_total: T = mean_row.iter().copied().sum::<T>() / T::from(n).unwrap();
for i in 0..n {
for j in 0..n {
let val = d_sq[(i, j)];
let centred = val - mean_row[j] - mean_col[i] + mean_total;
d_sq[(i, j)] = -centred / T::from(2.0).unwrap();
}
}
let mut embedding = vec![vec![T::zero(); n_components]; n];
if randomised {
let rsvd = randomised_svd(d_sq.as_ref(), n_components, seed, None, None)?;
for i in 0..n {
for k in 0..n_components {
let singular_val = rsvd.s[k];
if singular_val > T::zero() {
embedding[i][k] = rsvd.u[(i, k)] * singular_val.sqrt();
}
}
}
} else {
let svd = d_sq.svd().unwrap();
let s = svd.S();
let u = svd.U();
for i in 0..n {
for k in 0..n_components {
let singular_val = s[k];
if singular_val > T::zero() {
embedding[i][k] = u[(i, k)] * singular_val.sqrt();
}
}
}
}
Ok(embedding)
}
/// Metric MDS by stochastic gradient descent on the raw stress.
///
/// Distances are divided by their maximum so the optimisation works on values in
/// `[0, 1]`; the final coordinates are multiplied back by that maximum.
///
/// Each iteration:
/// * samples `pairs_per_iter` pairs `(i, j)` with `i != j`, with replacement;
/// * accumulates the gradient of `(d_ij - ‖y_i - y_j‖)²` into both endpoints;
/// * takes one step.
///
/// The step size decays exponentially from `eta_max` to `eta_min = eta_max / 100`
/// over `n_iter` iterations. Both are scaled by `sqrt(total_pairs / pairs_per_iter)`
/// to compensate for sub-sampling. The loop stops early once the relative change
/// of the sampled stress drops below `1e-6`, but only after iteration 50.
///
/// # Params
/// * `dist` - square distance matrix, one row per point
/// * `n_dim` - number of embedding dimensions
/// * `params` - optimiser settings; a `pairs_per_iter` of 0 is treated as 1
/// * `init` - optional row-major initial embedding (`n * n_dim` values). It is
///   divided by its root-mean-square before optimisation. When `None`, a
///   randomised [`classic_mds`] embedding is used instead.
/// * `seed` - seed for pair sampling and for the classical-MDS initialisation
/// * `verbose` - verbosity level passed to `parse_verbosity_level`
///
/// # Returns
/// `n` rows of `n_dim` coordinates. A single point is returned unchanged from a
/// correctly sized `init`, or at the origin.
///
/// # Errors
/// * `ManifoldsError::NoData` if `dist` is empty.
/// * `ManifoldsError::NotSquareMatrix` if any row does not have `n` entries.
/// * Errors from [`classic_mds`] when `init` is `None`.
///
/// # Panics
/// If `init` holds fewer than `n * n_dim` values (for `n >= 2`).
pub fn sgd_mds<T>(
    dist: &[Vec<T>],
    n_dim: usize,
    params: &MdsOptimParams<T>,
    init: Option<Vec<T>>,
    seed: usize,
    verbose: usize,
) -> Result<Vec<Vec<T>>, ManifoldsError>
where
    T: ManifoldsFloat,
    StandardNormal: Distribution<T>,
{
    let n = dist.len();
    if n == 0 {
        return Err(ManifoldsError::NoData);
    }
    // Check every row, not just the first: `d_norm` is indexed as a flat `n × n` buffer.
    if dist.iter().any(|row| row.len() != n) {
        return Err(ManifoldsError::NotSquareMatrix);
    }
    // A single point has no distinct pairs. The `i != j` sampler below would never
    // yield one (spinning forever), and `total_pairs` would be 0, so the step-size
    // schedule would be NaN. The placement is arbitrary: keep a correctly sized
    // `init`, otherwise use the origin.
    if n == 1 {
        let point = init
            .filter(|values| values.len() == n_dim)
            .unwrap_or_else(|| vec![T::zero(); n_dim]);
        return Ok(vec![point]);
    }
    let verbosity = parse_verbosity_level(verbose);
    let mut rng = StdRng::seed_from_u64(seed as u64);

    // Largest distance (0 if none is positive); used to normalise and later rescale.
    let d_max = dist
        .iter()
        .flat_map(|row| row.iter())
        .copied()
        .fold(T::zero(), |acc, x| if x > acc { x } else { acc });
    // Row-major flattened distances, normalised to [0, 1] when possible.
    let d_norm: Vec<T> = if d_max > T::zero() {
        dist.iter()
            .flat_map(|row| row.iter().map(|&d| d / d_max))
            .collect()
    } else {
        dist.iter().flat_map(|row| row.iter().copied()).collect()
    };

    // Row-major working embedding: point i occupies y[i * n_dim .. (i + 1) * n_dim].
    let mut y = if let Some(init_y) = init {
        let y_std = compute_std(&init_y);
        if y_std > T::zero() {
            init_y.iter().map(|&v| v / y_std).collect()
        } else {
            init_y
        }
    } else {
        let init_embedding = classic_mds(dist, n_dim, true, seed)?;
        let flat: Vec<T> = init_embedding.into_iter().flatten().collect();
        let y_std = compute_std(&flat);
        if y_std > T::zero() {
            flat.iter().map(|&v| v / y_std).collect()
        } else {
            flat
        }
    };

    let n_iter = params.n_iter;
    // Zero pairs would make the sampling ratio 0 (an infinite step size) and the
    // stress 0/0, turning the whole embedding into NaN.
    let pairs_per_iter = params.pairs_per_iter.max(1);
    let total_pairs = T::from(n * (n - 1) / 2).unwrap();
    let sampling_ratio = T::from(pairs_per_iter).unwrap() / total_pairs;
    // Larger steps when fewer pairs are sampled per iteration.
    let batch_scale = (T::one() / sampling_ratio).sqrt();
    let eta_max = params.lr * batch_scale;
    let eta_min = params.lr * T::from(0.01).unwrap() * batch_scale;
    // Exponential decay rate so that the last iteration uses exactly `eta_min`.
    let lambda = if n_iter > 1 {
        (eta_max / eta_min).ln() / T::from(n_iter - 1).unwrap()
    } else {
        T::zero()
    };
    if verbosity.normal_verbosity() {
        println!(
            "SGD-MDS: n={}, pairs_per_iter={}, n_iter={}, eta_max={:.6}, eta_min={:.6}, batch_scale={:.2}",
            n.separate_with_underscores(),
            pairs_per_iter.separate_with_underscores(),
            n_iter.separate_with_underscores(),
            eta_max.to_f64().unwrap(),
            eta_min.to_f64().unwrap(),
            batch_scale.to_f64().unwrap(),
        );
    }

    let mut prev_stress = None;
    for iteration in 0..n_iter {
        let lr_i = eta_max * (-lambda * T::from(iteration).unwrap()).exp();
        // Sampling stays sequential so results are reproducible for a given seed.
        let pairs: Vec<(usize, usize)> =
            std::iter::from_fn(|| Some((rng.random_range(0..n), rng.random_range(0..n))))
                .filter(|(i, j)| i != j)
                .take(pairs_per_iter)
                .collect();
        // Per-pair gradient contribution w.r.t. y_i (y_j receives the negation)
        // and the squared residual, computed in parallel against the current y.
        let contribs: Vec<(usize, usize, Vec<T>, T)> = pairs
            .par_iter()
            .map(|&(i, j)| {
                let target_dist = d_norm[i * n + j];
                let mut dist_sq = T::zero();
                for k in 0..n_dim {
                    let diff = y[i * n_dim + k] - y[j * n_dim + k];
                    dist_sq += diff * diff;
                }
                // Floor avoids division by zero for coincident points.
                let current_dist = dist_sq.sqrt().max(T::from(1e-10).unwrap());
                let error = target_dist - current_dist;
                // d/dy_i (d - ‖y_i - y_j‖)² = -2 (d - ‖·‖) (y_i - y_j) / ‖·‖
                let weight = T::from(-2.0).unwrap() * error / current_dist;
                let contrib: Vec<T> = (0..n_dim)
                    .map(|k| (y[i * n_dim + k] - y[j * n_dim + k]) * weight)
                    .collect();
                (i, j, contrib, error * error)
            })
            .collect();
        let mut gradients = vec![T::zero(); n * n_dim];
        let mut total_err = T::zero();
        for (i, j, contrib, sq_err) in &contribs {
            for k in 0..n_dim {
                gradients[i * n_dim + k] += contrib[k];
                gradients[j * n_dim + k] -= contrib[k];
            }
            total_err += *sq_err;
        }
        for idx in 0..n * n_dim {
            y[idx] -= lr_i * gradients[idx];
        }
        // Mean squared residual over this iteration's sampled pairs.
        let stress = total_err / T::from(contribs.len()).unwrap();
        if verbosity.normal_verbosity() && iteration % 100 == 0 {
            println!(
                " Iter {}: stress={:.6}, lr={:.6}",
                iteration.separate_with_underscores(),
                stress.to_f64().unwrap(),
                lr_i.to_f64().unwrap(),
            );
        }
        if let Some(prev) = prev_stress {
            let rel_change = ((stress - prev) / (prev + T::from(1e-10).unwrap())).abs();
            if rel_change < T::from(1e-6).unwrap() && iteration > 50 {
                if verbosity.normal_verbosity() {
                    println!(
                        " Converged at iteration {} (rel_change={:.2e})",
                        iteration,
                        rel_change.to_f64().unwrap()
                    );
                }
                break;
            }
        }
        prev_stress = Some(stress);
    }

    // Undo the distance normalisation so coordinates are in the input's units.
    if d_max > T::zero() {
        y.iter_mut().for_each(|v| *v *= d_max);
    }
    let embedding = (0..n)
        .map(|i| y[i * n_dim..(i + 1) * n_dim].to_vec())
        .collect();
    Ok(embedding)
}
#[cfg(test)]
mod test_mds {
    use super::*;
    use approx::assert_relative_eq;
    use num_traits::Float;

    /// Euclidean distance between two embedded points.
    fn euclidean(a: &[f64], b: &[f64]) -> f64 {
        a.iter()
            .zip(b)
            .map(|(x, y)| {
                let diff = x - y;
                diff * diff
            })
            .sum::<f64>()
            .sqrt()
    }

    /// Distance matrix of three mutually equidistant points (unit equilateral triangle).
    fn unit_triangle() -> Vec<Vec<f64>> {
        vec![
            vec![0.0, 1.0, 1.0],
            vec![1.0, 0.0, 1.0],
            vec![1.0, 1.0, 0.0],
        ]
    }

    #[test]
    fn test_sgd_mds_identity_distances() {
        let distances = unit_triangle();
        let mds_params = MdsOptimParams::new(distances.len(), true, None, None);
        let embedding = sgd_mds(&distances, 2, &mds_params, None, 42, 0).unwrap();
        assert_eq!(embedding.len(), 3);
        assert_eq!(embedding[0].len(), 2);
        for (i, p) in embedding.iter().enumerate() {
            for (j, q) in embedding.iter().enumerate() {
                let dist = euclidean(p, q);
                if i == j {
                    assert_relative_eq!(dist, 0.0, epsilon = 1e-2);
                } else {
                    assert_relative_eq!(dist, 1.0, epsilon = 0.35);
                }
            }
        }
    }

    #[test]
    fn test_sgd_mds_converges() {
        // Unit square: adjacent corners at distance 1, diagonals at ~sqrt(2).
        let distances = vec![
            vec![0.0, 1.0, 1.414, 1.0],
            vec![1.0, 0.0, 1.0, 1.414],
            vec![1.414, 1.0, 0.0, 1.0],
            vec![1.0, 1.414, 1.0, 0.0],
        ];
        let mds_params = MdsOptimParams::new(distances.len(), true, None, None);
        let embedding = sgd_mds(&distances, 2, &mds_params, None, 42, 0).unwrap();
        for (i, p) in embedding.iter().enumerate() {
            for (j, q) in embedding.iter().enumerate() {
                let dist = euclidean(p, q);
                assert_relative_eq!(dist, distances[i][j], epsilon = 0.5);
            }
        }
    }

    #[test]
    fn test_classic_mds_identity() {
        let distances = unit_triangle();
        let embedding = classic_mds(&distances, 2, true, 42).unwrap();
        assert_eq!(embedding.len(), 3);
        assert_eq!(embedding[0].len(), 2);
        for (i, p) in embedding.iter().enumerate() {
            for (j, q) in embedding.iter().enumerate() {
                let dist = euclidean(p, q);
                if i == j {
                    assert_relative_eq!(dist, 0.0, epsilon = 1e-6);
                } else {
                    assert_relative_eq!(dist, 1.0, epsilon = 0.3);
                }
            }
        }
    }
}