use argmin::core::{CostFunction, Error, Executor, Gradient, Hessian, State};
use argmin::solver::conjugategradient::beta::PolakRibiere;
use argmin::solver::conjugategradient::NonlinearConjugateGradient;
use argmin::solver::linesearch::condition::ArmijoCondition;
use argmin::solver::linesearch::{BacktrackingLineSearch, MoreThuenteLineSearch};
use argmin::solver::newton::NewtonCG;
use argmin::solver::quasinewton::LBFGS;
use argmin::solver::trustregion::{Steihaug, TrustRegion};
use nalgebra::DVector;
use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};
use crate::spec::PairwiseRelations;
#[cfg(test)]
/// Test convenience: [`compute_initial_layout_with_solver`] using the default
/// [`MdsSolver`].
pub(crate) fn compute_initial_layout(
    distances: &Vec<Vec<f64>>,
    relationships: &PairwiseRelations,
    set_areas: &[f64],
    rng: &mut dyn rand::RngCore,
) -> Result<Vec<f64>, Error> {
    let default_solver = MdsSolver::default();
    compute_initial_layout_with_solver(distances, relationships, set_areas, rng, default_solver)
}
/// Computes starting positions for every set by minimising the pairwise-distance
/// objective with `solver`.
///
/// The result holds `distances.len()` x coordinates followed by the same number of
/// y coordinates. Random start points are drawn from `[0, scale)`, where `scale` is
/// the square root of the summed `set_areas` (or `10.0` when that sum is not
/// positive). Exactly one `u64` is drawn from `rng` to seed the attempt.
///
/// # Errors
/// Propagates any error reported by the underlying optimisation.
pub(crate) fn compute_initial_layout_with_solver(
    distances: &Vec<Vec<f64>>,
    relationships: &PairwiseRelations,
    set_areas: &[f64],
    rng: &mut dyn rand::RngCore,
    solver: MdsSolver,
) -> Result<Vec<f64>, Error> {
    let total_area = set_areas.iter().sum::<f64>();
    // A non-positive (or NaN) total falls back to a fixed spread.
    let scale = match total_area {
        area if area > 0.0 => area.sqrt(),
        _ => 10.0,
    };
    let seed = rng.random::<u64>();
    run_attempt(distances, relationships, distances.len(), scale, seed, solver)
        .map(|(_loss, params)| params)
}
/// Optimisation algorithm used to minimise the initial-layout objective.
///
/// `Hash` is derived alongside `Eq` so the choice can be used as a map or set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum MdsSolver {
    /// Limited-memory BFGS (the default).
    #[default]
    Lbfgs,
    /// Nonlinear conjugate gradient.
    ConjugateGradient,
    /// Trust-region method.
    TrustRegion,
    /// Newton method with a conjugate-gradient inner solve.
    NewtonCg,
}
fn run_attempt(
distances: &Vec<Vec<f64>>,
relationships: &PairwiseRelations,
n_sets: usize,
scale: f64,
seed: u64,
solver: MdsSolver,
) -> Result<(f64, Vec<f64>), Error> {
let mut local_rng = StdRng::seed_from_u64(seed);
let mut initial_values = vec![0.0; n_sets * 2];
for value in &mut initial_values {
*value = local_rng.random_range(0.0..scale);
}
let initial_param = DVector::from_vec(initial_values);
let cost_function = MdsCost {
distances,
relationships,
};
match solver {
MdsSolver::Lbfgs => {
let line_search = MoreThuenteLineSearch::new();
let solver = LBFGS::new(line_search, 10);
let result = Executor::new(cost_function, solver)
.configure(|state| state.param(initial_param).max_iters(200))
.run()?;
Ok((
result.state().get_cost(),
result.state().get_best_param().unwrap().as_slice().to_vec(),
))
}
MdsSolver::ConjugateGradient => {
let line_search = BacktrackingLineSearch::new(ArmijoCondition::new(0.01)?);
let solver = NonlinearConjugateGradient::new(line_search, PolakRibiere::new());
let result = Executor::new(cost_function, solver)
.configure(|state| state.param(initial_param).max_iters(200))
.run()?;
Ok((
result.state().get_cost(),
result.state().get_best_param().unwrap().as_slice().to_vec(),
))
}
MdsSolver::TrustRegion => {
let vec_cost = VecMdsCost {
inner: cost_function,
};
let subproblem: Steihaug<Vec<f64>, f64> = Steihaug::new().with_max_iters(200);
let solver = TrustRegion::new(subproblem);
let initial_param_vec = initial_param.as_slice().to_vec();
let result = Executor::new(vec_cost, solver)
.configure(|state| state.param(initial_param_vec).max_iters(200))
.run()?;
Ok((
result.state().get_cost(),
result.state().get_best_param().unwrap().clone(),
))
}
MdsSolver::NewtonCg => {
let vec_cost = VecMdsCost {
inner: cost_function,
};
let line_search = MoreThuenteLineSearch::new();
let solver: NewtonCG<_, f64> = NewtonCG::new(line_search);
let initial_param_vec = initial_param.as_slice().to_vec();
let result = Executor::new(vec_cost, solver)
.configure(|state| state.param(initial_param_vec).max_iters(200))
.run()?;
Ok((
result.state().get_cost(),
result.state().get_best_param().unwrap().clone(),
))
}
}
}
/// Adapter that exposes [`MdsCost`] with `Vec<f64>` parameters instead of `DVector<f64>`.
///
/// `run_attempt` uses it for the trust-region and Newton-CG solvers, which are
/// driven with plain vectors. Each call converts the input to a `DVector` and
/// delegates to the wrapped objective.
struct VecMdsCost<'a> {
    /// The wrapped objective, evaluated after conversion to `DVector`.
    inner: MdsCost<'a>,
}
impl<'a> CostFunction for VecMdsCost<'a> {
    type Param = Vec<f64>;
    type Output = f64;

    /// Evaluates the wrapped objective at a `DVector` copy of `p`.
    fn cost(&self, p: &Vec<f64>) -> Result<f64, Error> {
        let point = DVector::from_column_slice(p.as_slice());
        self.inner.cost(&point)
    }
}
impl<'a> Gradient for VecMdsCost<'a> {
    type Param = Vec<f64>;
    type Gradient = Vec<f64>;

    /// Evaluates the wrapped gradient at a `DVector` copy of `p` and returns it as a `Vec`.
    fn gradient(&self, p: &Vec<f64>) -> Result<Vec<f64>, Error> {
        let point = DVector::from_column_slice(p.as_slice());
        let grad = self.inner.gradient(&point)?;
        Ok(grad.iter().copied().collect())
    }
}
impl<'a> Hessian for VecMdsCost<'a> {
    type Param = Vec<f64>;
    type Hessian = Vec<Vec<f64>>;

    /// Evaluates the wrapped (row-vector) Hessian at a `DVector` copy of `p`.
    fn hessian(&self, p: &Vec<f64>) -> Result<Vec<Vec<f64>>, Error> {
        let point = DVector::from_column_slice(p.as_slice());
        self.inner.hessian(&point)
    }
}
/// Pairwise-distance objective used to compute the initial layout.
///
/// The parameter vector stores all x coordinates followed by all y coordinates.
/// The cost, gradient and Hessian impls read both fields below.
struct MdsCost<'a> {
    /// Target distance between the points of sets `i` and `j`, indexed `[i][j]`.
    distances: &'a Vec<Vec<f64>>,
    /// Subset/disjoint flags. A flagged pair is only penalised on one side of its target.
    relationships: &'a PairwiseRelations,
}
impl<'a> CostFunction for MdsCost<'a> {
    type Param = DVector<f64>;
    type Output = f64;

    /// Loss summed over every ordered pair `(i, j)` with `i != j`.
    ///
    /// Each pair contributes `r²`, where `r = |p_i - p_j|² - distances[i][j]²`. Two
    /// relaxations apply: a disjoint pair that is already far enough apart (`r >= 0`)
    /// contributes nothing, and so does a subset pair (in either direction) that is
    /// already close enough (`r <= 0`).
    fn cost(&self, param: &Self::Param) -> Result<Self::Output, Error> {
        let count = param.len() / 2;
        let xs = param.rows(0, count);
        let ys = param.rows(count, count);
        let rel = self.relationships;

        let total = (0..count)
            .flat_map(|a| (0..count).filter(move |&b| b != a).map(move |b| (a, b)))
            .filter_map(|(a, b)| {
                let dx = xs[a] - xs[b];
                let dy = ys[a] - ys[b];
                let residual = dx.powi(2) + dy.powi(2) - self.distances[a][b].powi(2);
                let relaxed = (rel.is_disjoint(a, b) && residual >= 0.0)
                    || ((rel.is_subset(a, b) || rel.is_subset(b, a)) && residual <= 0.0);
                (!relaxed).then(|| residual.powi(2))
            })
            .fold(0.0, |acc, term| acc + term);
        Ok(total)
    }
}
impl<'a> Gradient for MdsCost<'a> {
    type Param = DVector<f64>;
    type Gradient = DVector<f64>;

    /// Analytic gradient of the loss computed by `cost`.
    ///
    /// Entries `i` (x) and `i + n_sets` (y) accumulate `8 · r · Δ` over every active
    /// partner `j`, where `Δ` is the coordinate difference `p_i - p_j` and `r` the
    /// squared-distance residual. The factor 8 covers both ordered pairs `(i, j)` and
    /// `(j, i)`, which assumes `distances` and the relationship flags are symmetric.
    fn gradient(&self, param: &Self::Param) -> Result<Self::Gradient, Error> {
        let count = param.len() / 2;
        let xs = param.rows(0, count);
        let ys = param.rows(count, count);
        let rel = self.relationships;
        let mut out = DVector::<f64>::zeros(param.len());

        for a in 0..count {
            for b in (0..count).filter(|&b| b != a) {
                let dx = xs[a] - xs[b];
                let dy = ys[a] - ys[b];
                let residual = dx.powi(2) + dy.powi(2) - self.distances[a][b].powi(2);
                let relaxed = (rel.is_disjoint(a, b) && residual >= 0.0)
                    || ((rel.is_subset(a, b) || rel.is_subset(b, a)) && residual <= 0.0);
                if relaxed {
                    continue;
                }
                out[a] += 8.0 * residual * dx;
                out[a + count] += 8.0 * residual * dy;
            }
        }
        Ok(out)
    }
}
impl<'a> Hessian for MdsCost<'a> {
    type Param = DVector<f64>;
    type Hessian = Vec<Vec<f64>>;

    /// Dense analytic Hessian of the loss computed by `cost`, returned as rows.
    ///
    /// Each active ordered pair `(i, j)` adds the exact second derivatives of its `r²`
    /// term. For relaxed pairs the loss is only once differentiable at `r == 0`. At
    /// that point the pair is skipped, matching the relaxed (zero) side.
    fn hessian(&self, param: &Self::Param) -> Result<Self::Hessian, Error> {
        /// Adds `v` at `(rows.0, cols.0)` and `(rows.1, cols.1)` and subtracts it at the
        /// two mixed positions. This is the 2×2 sign pattern every block of a pair term shares.
        fn stamp(h: &mut [Vec<f64>], rows: (usize, usize), cols: (usize, usize), v: f64) {
            let (r0, r1) = rows;
            let (c0, c1) = cols;
            h[r0][c0] += v;
            h[r0][c1] -= v;
            h[r1][c0] -= v;
            h[r1][c1] += v;
        }

        let count = param.len() / 2;
        let dim = 2 * count;
        let xs = param.rows(0, count);
        let ys = param.rows(count, count);
        let rel = self.relationships;
        let mut h = vec![vec![0.0; dim]; dim];

        for a in 0..count {
            for b in (0..count).filter(|&b| b != a) {
                let dx = xs[a] - xs[b];
                let dy = ys[a] - ys[b];
                let residual = dx.powi(2) + dy.powi(2) - self.distances[a][b].powi(2);
                let relaxed = (rel.is_disjoint(a, b) && residual >= 0.0)
                    || ((rel.is_subset(a, b) || rel.is_subset(b, a)) && residual <= 0.0);
                if relaxed {
                    continue;
                }
                let xx = 8.0 * dx * dx + 4.0 * residual;
                let yy = 8.0 * dy * dy + 4.0 * residual;
                let xy = 8.0 * dx * dy;
                // Row/column indices of the y coordinates of sets a and b.
                let (ya, yb) = (a + count, b + count);
                stamp(&mut h, (a, b), (a, b), xx);
                stamp(&mut h, (ya, yb), (ya, yb), yy);
                stamp(&mut h, (a, b), (ya, yb), xy);
                stamp(&mut h, (ya, yb), (a, b), xy);
            }
        }
        Ok(h)
    }
}
#[cfg(test)]
mod gradient_check {
    //! Finite-difference checks of the analytic gradient and Hessian of [`MdsCost`].
    use super::*;
    use nalgebra::DVector;

    /// Central finite-difference gradient of `cost` at `p` (step `1e-6`).
    fn fd_gradient(cost: &MdsCost, p: &DVector<f64>) -> DVector<f64> {
        let h = 1e-6;
        let mut g = DVector::from_element(p.len(), 0.0);
        for i in 0..p.len() {
            let mut pp = p.clone();
            let mut pm = p.clone();
            pp[i] += h;
            pm[i] -= h;
            g[i] = (cost.cost(&pp).unwrap() - cost.cost(&pm).unwrap()) / (2.0 * h);
        }
        g
    }

    /// Central finite-difference Hessian of `cost` at `p`, built from analytic
    /// gradients (step `1e-5`). Column `j` is the derivative along coordinate `j`.
    fn fd_hessian(cost: &MdsCost, p: &DVector<f64>) -> Vec<Vec<f64>> {
        let h = 1e-5;
        let n = p.len();
        let mut hessian = vec![vec![0.0; n]; n];
        for j in 0..n {
            let mut pp = p.clone();
            let mut pm = p.clone();
            pp[j] += h;
            pm[j] -= h;
            let gp = cost.gradient(&pp).unwrap();
            let gm = cost.gradient(&pm).unwrap();
            for i in 0..n {
                hessian[i][j] = (gp[i] - gm[i]) / (2.0 * h);
            }
        }
        hessian
    }

    /// Four-set fixture: set 0 has the `subset` flag set (both directions) with
    /// each of sets 1-3, and no pair is disjoint.
    fn fixture() -> (Vec<Vec<f64>>, PairwiseRelations) {
        let distances = vec![
            vec![0.0, 2.232, 2.232, 2.232],
            vec![2.232, 0.0, 1.642, 1.642],
            vec![2.232, 1.642, 0.0, 1.642],
            vec![2.232, 1.642, 1.642, 0.0],
        ];
        let mut relations = PairwiseRelations {
            n_sets: 4,
            subset: vec![vec![false; 4]; 4],
            disjoint: vec![vec![false; 4]; 4],
            overlap_areas: vec![vec![0.0; 4]; 4],
        };
        for i in 1..4 {
            relations.subset[0][i] = true;
            relations.subset[i][0] = true;
        }
        (distances, relations)
    }

    /// Evaluation point (x coordinates, then y coordinates). Every pair residual
    /// here is at least ~0.19 away from zero, so finite differences never straddle a
    /// relaxation kink.
    fn sample_point() -> DVector<f64> {
        DVector::from_vec(vec![0.5, 3.0, 3.5, 4.5, 0.3, 1.0, 2.5, 0.8])
    }

    #[test]
    fn analytic_gradient_matches_finite_difference() {
        let (distances, relations) = fixture();
        let cost = MdsCost {
            distances: &distances,
            relationships: &relations,
        };
        let p = sample_point();
        let analytic = cost.gradient(&p).unwrap();
        let numeric = fd_gradient(&cost, &p);
        let max_abs_diff = analytic
            .iter()
            .zip(numeric.iter())
            .map(|(a, n)| (a - n).abs())
            .fold(0.0_f64, f64::max);
        let max_abs = numeric.iter().map(|x| x.abs()).fold(0.0_f64, f64::max);
        // Previously this test only printed diagnostics and could never fail.
        assert!(
            max_abs_diff < 1e-4 * max_abs.max(1.0),
            "analytic gradient disagrees with FD reference: max abs diff {:.4e}, max abs {:.4e}\n  analytic = {:?}\n  numeric  = {:?}",
            max_abs_diff,
            max_abs,
            analytic.as_slice(),
            numeric.as_slice(),
        );
    }

    #[test]
    fn analytic_hessian_matches_finite_difference() {
        let (distances, relations) = fixture();
        let cost = MdsCost {
            distances: &distances,
            relationships: &relations,
        };
        let p = sample_point();
        let analytic = cost.hessian(&p).unwrap();
        let numeric = fd_hessian(&cost, &p);
        let n = p.len();
        let mut max_abs_diff = 0.0_f64;
        let mut max_abs = 0.0_f64;
        for i in 0..n {
            for j in 0..n {
                let d = (analytic[i][j] - numeric[i][j]).abs();
                max_abs_diff = max_abs_diff.max(d);
                max_abs = max_abs.max(numeric[i][j].abs());
            }
        }
        assert!(
            max_abs_diff < 1e-3 * max_abs.max(1.0),
            "analytic Hessian disagrees with FD reference: max abs diff {:.4e}, max abs {:.4e}",
            max_abs_diff,
            max_abs,
        );
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use rand::rngs::StdRng;
    use rand::SeedableRng;

    /// `true` when `a` and `b` differ by strictly less than `tol`.
    fn approx_eq(a: f64, b: f64, tol: f64) -> bool {
        (a - b).abs() < tol
    }

    /// Relations for `n_sets` sets with every subset/disjoint flag cleared and every
    /// overlap area zero.
    fn create_test_relationships(n_sets: usize) -> PairwiseRelations {
        let flags = vec![vec![false; n_sets]; n_sets];
        PairwiseRelations {
            n_sets,
            subset: flags.clone(),
            disjoint: flags,
            overlap_areas: vec![vec![0.0; n_sets]; n_sets],
        }
    }

    /// Lays out `distances` with the default solver, a fresh RNG seeded with 42 and
    /// no set areas. Panics if the solver fails.
    fn seeded_layout(distances: &Vec<Vec<f64>>, relationships: &PairwiseRelations) -> Vec<f64> {
        let mut rng = StdRng::seed_from_u64(42);
        compute_initial_layout(distances, relationships, &[], &mut rng).unwrap()
    }

    /// Euclidean distance between points `i` and `j` of a layout that stores `n`
    /// x coordinates followed by `n` y coordinates.
    fn separation(layout: &[f64], n: usize, i: usize, j: usize) -> f64 {
        let (xs, ys) = layout.split_at(n);
        ((xs[i] - xs[j]).powi(2) + (ys[i] - ys[j]).powi(2)).sqrt()
    }

    #[test]
    fn test_compute_initial_layout_two_sets_touching() {
        let distances = vec![vec![0.0, 2.0], vec![2.0, 0.0]];
        let result = seeded_layout(&distances, &create_test_relationships(2));
        assert_eq!(result.len(), 4);
        let actual_distance = separation(&result, 2, 0, 1);
        assert!(
            approx_eq(actual_distance, 2.0, 0.1),
            "Distance {} should be close to 2.0",
            actual_distance
        );
    }

    #[test]
    fn test_compute_initial_layout_two_sets_overlapping() {
        let distances = vec![vec![0.0, 1.0], vec![1.0, 0.0]];
        let result = seeded_layout(&distances, &create_test_relationships(2));
        assert_eq!(result.len(), 4);
        let actual_distance = separation(&result, 2, 0, 1);
        assert!(
            approx_eq(actual_distance, 1.0, 0.1),
            "Distance {} should be close to 1.0",
            actual_distance
        );
    }

    #[test]
    fn test_compute_initial_layout_two_sets_separated() {
        let distances = vec![vec![0.0, 5.0], vec![5.0, 0.0]];
        let result = seeded_layout(&distances, &create_test_relationships(2));
        assert_eq!(result.len(), 4);
        let actual_distance = separation(&result, 2, 0, 1);
        assert!(
            approx_eq(actual_distance, 5.0, 0.1),
            "Distance {} should be close to 5.0",
            actual_distance
        );
    }

    #[test]
    fn test_compute_initial_layout_two_sets_disjoint() {
        let distances = vec![vec![0.0, 3.0], vec![3.0, 0.0]];
        let mut relationships = create_test_relationships(2);
        relationships.disjoint[0][1] = true;
        relationships.disjoint[1][0] = true;
        let result = seeded_layout(&distances, &relationships);
        assert_eq!(result.len(), 4);
        let actual_distance = separation(&result, 2, 0, 1);
        assert!(
            actual_distance >= 3.0 - 0.1,
            "Distance {} should be at least 3.0",
            actual_distance
        );
    }

    #[test]
    fn test_compute_initial_layout_three_sets_triangle() {
        let d = 2.0;
        let distances = vec![vec![0.0, d, d], vec![d, 0.0, d], vec![d, d, 0.0]];
        let result = seeded_layout(&distances, &create_test_relationships(3));
        assert_eq!(result.len(), 6);
        for (i, j) in [(0, 1), (0, 2), (1, 2)] {
            let actual_distance = separation(&result, 3, i, j);
            assert!(
                approx_eq(actual_distance, d, 0.2),
                "Distance between {} and {} is {}, should be close to {}",
                i,
                j,
                actual_distance,
                d
            );
        }
    }

    #[test]
    fn test_compute_initial_layout_three_sets_collinear() {
        let distances = vec![
            vec![0.0, 1.0, 2.0],
            vec![1.0, 0.0, 1.0],
            vec![2.0, 1.0, 0.0],
        ];
        let result = seeded_layout(&distances, &create_test_relationships(3));
        assert_eq!(result.len(), 6);
        let d01 = separation(&result, 3, 0, 1);
        let d12 = separation(&result, 3, 1, 2);
        let d02 = separation(&result, 3, 0, 2);
        assert!(approx_eq(d01, 1.0, 0.2), "Distance 0-1: {}", d01);
        assert!(approx_eq(d12, 1.0, 0.2), "Distance 1-2: {}", d12);
        assert!(approx_eq(d02, 2.0, 0.2), "Distance 0-2: {}", d02);
    }

    #[test]
    fn test_compute_initial_layout_four_sets_square() {
        let side = 1.0;
        let diag = side * 2.0_f64.sqrt();
        let distances = vec![
            vec![0.0, side, diag, side],
            vec![side, 0.0, side, diag],
            vec![diag, side, 0.0, side],
            vec![side, diag, side, 0.0],
        ];
        let result = seeded_layout(&distances, &create_test_relationships(4));
        assert_eq!(result.len(), 8);
        let d01 = separation(&result, 4, 0, 1);
        let d02 = separation(&result, 4, 0, 2);
        assert!(d01 < d02);
    }

    #[test]
    fn test_compute_initial_layout_with_restarts() {
        // Both runs share one RNG, so they start from different random points.
        let mut rng = StdRng::seed_from_u64(42);
        let distances = vec![vec![0.0, 1.5], vec![1.5, 0.0]];
        let relationships = create_test_relationships(2);
        let result1 = compute_initial_layout(&distances, &relationships, &[], &mut rng).unwrap();
        let result2 = compute_initial_layout(&distances, &relationships, &[], &mut rng).unwrap();
        assert_eq!(result1.len(), 4);
        assert_eq!(result2.len(), 4);
        let d1 = separation(&result1, 2, 0, 1);
        let d2 = separation(&result2, 2, 0, 1);
        assert!(approx_eq(d1, 1.5, 0.2), "Distance with first run: {}", d1);
        assert!(approx_eq(d2, 1.5, 0.2), "Distance with second run: {}", d2);
    }

    #[test]
    fn test_compute_initial_layout_zero_distance() {
        let distances = vec![vec![0.0, 0.0], vec![0.0, 0.0]];
        let result = seeded_layout(&distances, &create_test_relationships(2));
        assert_eq!(result.len(), 4);
        let actual_distance = separation(&result, 2, 0, 1);
        assert!(
            actual_distance < 0.1,
            "Distance {} should be close to 0.0",
            actual_distance
        );
    }

    #[test]
    fn test_compute_initial_layout_asymmetric_distances() {
        let mut rng = StdRng::seed_from_u64(42);
        let distances = vec![
            vec![0.0, 1.0, 2.0],
            vec![1.0, 0.0, 1.5],
            vec![2.0, 1.5, 0.0],
        ];
        let relationships = create_test_relationships(3);
        let result = compute_initial_layout(&distances, &relationships, &[], &mut rng);
        assert!(result.is_ok());
        assert_eq!(result.unwrap().len(), 6);
    }
}