use argmin::core::{CostFunction, Error, Executor, Gradient, State};
use argmin::solver::linesearch::MoreThuenteLineSearch;
use argmin::solver::quasinewton::LBFGS;
use nalgebra::DVector;
use rand::Rng;
use crate::spec::PairwiseRelations;
/// Tuning knobs for the multi-start search performed by
/// [`compute_initial_layout_with_config`].
#[derive(Debug, Clone, PartialEq)]
pub(crate) struct InitialLayoutConfig {
    /// Upper bound on the number of random restarts.
    pub max_attempts: usize,
    /// Number of consecutive attempts without a significant improvement after
    /// which the search stops early.
    pub patience: usize,
    /// Minimum relative decrease of the best loss, `(best - new) / best`, that
    /// counts as a significant improvement and resets the patience counter.
    pub improvement_threshold: f64,
    /// Loss strictly below which a solution is accepted immediately.
    pub perfect_fit_threshold: f64,
}

impl Default for InitialLayoutConfig {
    /// 100 attempts, a patience of 10, a 0.1% relative improvement threshold
    /// and a perfect-fit threshold of `1e-8`.
    fn default() -> Self {
        Self {
            max_attempts: 100,
            patience: 10,
            improvement_threshold: 0.001,
            perfect_fit_threshold: 1e-8,
        }
    }
}
pub(crate) fn compute_initial_layout(
distances: &Vec<Vec<f64>>,
relationships: &PairwiseRelations,
rng: &mut dyn rand::RngCore,
) -> Result<Vec<f64>, Error> {
compute_initial_layout_with_config(
distances,
relationships,
InitialLayoutConfig::default(),
rng,
)
}
pub(crate) fn compute_initial_layout_with_config(
distances: &Vec<Vec<f64>>,
relationships: &PairwiseRelations,
config: InitialLayoutConfig,
rng: &mut dyn rand::RngCore,
) -> Result<Vec<f64>, Error> {
let n_sets = distances.len();
let mut best_params = Vec::new();
let mut best_loss = f64::INFINITY;
let mut attempts_without_improvement = 0;
let max_distance = distances
.iter()
.flat_map(|row| row.iter())
.copied()
.fold(0.0_f64, f64::max);
let scale = if max_distance > 0.0 {
max_distance
} else {
10.0
};
for _attempt in 0..config.max_attempts {
let mut initial_values = vec![0.0; n_sets * 2];
for value in &mut initial_values {
*value = rng.random_range(-scale..scale);
}
let initial_param = DVector::from_vec(initial_values);
let cost_function = MdsCost {
distances,
relationships,
};
let line_search = MoreThuenteLineSearch::new();
let solver = LBFGS::new(line_search, 10);
let result = Executor::new(cost_function, solver)
.configure(|state| state.param(initial_param).max_iters(200))
.run()?;
let loss = result.state().get_cost();
println!("Attempt loss: {}", loss);
if loss < config.perfect_fit_threshold {
return Ok(result.state().get_best_param().unwrap().as_slice().to_vec());
}
if loss < best_loss {
let relative_improvement = if best_loss.is_finite() && best_loss > 0.0 {
(best_loss - loss) / best_loss
} else {
f64::INFINITY
};
if !best_loss.is_finite() || relative_improvement > config.improvement_threshold {
best_loss = loss;
best_params = result.state().get_best_param().unwrap().as_slice().to_vec();
attempts_without_improvement = 0;
} else {
attempts_without_improvement += 1;
}
} else {
attempts_without_improvement += 1;
}
if attempts_without_improvement >= config.patience {
break;
}
}
Ok(best_params)
}
/// Objective minimised when computing an initial layout.
///
/// Every ordered pair `(i, j)` with `i != j` is penalised by the square of the
/// residual `|p_i - p_j|^2 - distances[i][j]^2`, with two exceptions:
/// - pairs where `is_disjoint(i, j)` holds are not penalised for a
///   non-negative residual;
/// - pairs related by `is_subset` in either direction are not penalised for a
///   non-positive residual.
struct MdsCost<'a> {
    /// Target distances indexed `[i][j]`; presumably an `n × n` matrix for `n`
    /// sets. TODO confirm callers always pass a square matrix.
    distances: &'a [Vec<f64>],
    /// Pairwise relations consulted to relax the penalty for some pairs.
    relationships: &'a PairwiseRelations,
}
impl<'a> CostFunction for MdsCost<'a> {
    type Param = DVector<f64>;
    type Output = f64;

    /// Evaluates the objective at `param`.
    ///
    /// `param` is laid out as `[x_0, .., x_{n-1}, y_0, .., y_{n-1}]` with
    /// `n = param.len() / 2`.
    ///
    /// Each ordered pair `(a, b)`, `a != b`, adds the square of
    /// `r = |p_a - p_b|^2 - distances[a][b]^2`. The term is skipped when
    /// `is_disjoint(a, b)` holds and `r >= 0`, or when either set is a subset of
    /// the other and `r <= 0`.
    fn cost(&self, param: &Self::Param) -> Result<Self::Output, Error> {
        let count = param.len() / 2;
        let xs = param.rows(0, count);
        let ys = param.rows(count, count);
        let rel = self.relationships;

        let mut total = 0.0;
        for a in 0..count {
            for b in (0..count).filter(|&b| b != a) {
                let residual = (xs[a] - xs[b]).powi(2) + (ys[a] - ys[b]).powi(2)
                    - self.distances[a][b].powi(2);
                let satisfied = (rel.is_disjoint(a, b) && residual >= 0.0)
                    || ((rel.is_subset(a, b) || rel.is_subset(b, a)) && residual <= 0.0);
                if !satisfied {
                    total += residual.powi(2);
                }
            }
        }
        Ok(total)
    }
}
impl<'a> Gradient for MdsCost<'a> {
    type Param = DVector<f64>;
    type Gradient = DVector<f64>;

    /// Analytic gradient of [`CostFunction::cost`] for [`MdsCost`].
    ///
    /// A penalised ordered pair `(i, j)` contributes `r^2`, where
    /// `r = |p_i - p_j|^2 - distances[i][j]^2`. Its partial derivatives are:
    /// - `+4 r (p_i - p_j)` with respect to `p_i`;
    /// - `-4 r (p_i - p_j)` with respect to `p_j`.
    ///
    /// Both must be accumulated. Skipped pairs contribute nothing, which matches
    /// the cost because the penalty is identically zero on the skipped side of
    /// `r = 0`.
    fn gradient(&self, param: &Self::Param) -> Result<Self::Gradient, Error> {
        let n_sets = param.len() / 2;
        let x = param.rows(0, n_sets);
        let y = param.rows(n_sets, n_sets);
        let mut grad = DVector::from_element(param.len(), 0.0);
        for i in 0..n_sets {
            for j in 0..n_sets {
                if i == j {
                    continue;
                }
                let xd = x[i] - x[j];
                let yd = y[i] - y[j];
                let d = xd.powi(2) + yd.powi(2) - self.distances[i][j].powi(2);
                if self.relationships.is_disjoint(i, j) && d >= 0.0 {
                    continue;
                }
                if (self.relationships.is_subset(i, j) || self.relationships.is_subset(j, i))
                    && d <= 0.0
                {
                    continue;
                }
                let gx = 4.0 * d * xd;
                let gy = 4.0 * d * yd;
                // Derivative of this pair's term with respect to p_i.
                grad[i] += gx;
                grad[i + n_sets] += gy;
                // p_j also appears in the residual, with the opposite sign.
                grad[j] -= gx;
                grad[j + n_sets] -= gy;
            }
        }
        Ok(grad)
    }
}

#[cfg(test)]
mod gradient_tests {
    use super::*;

    /// The analytic gradient must agree with central finite differences of the cost.
    #[test]
    fn gradient_matches_finite_differences() {
        let n = 3;
        let distances = vec![
            vec![0.0, 1.0, 2.0],
            vec![1.0, 0.0, 1.5],
            vec![2.0, 1.5, 0.0],
        ];
        let relationships = PairwiseRelations {
            n_sets: n,
            subset: vec![vec![false; n]; n],
            disjoint: vec![vec![false; n]; n],
            overlap_areas: vec![vec![0.0; n]; n],
        };
        let objective = MdsCost {
            distances: &distances,
            relationships: &relationships,
        };
        let point = DVector::from_vec(vec![0.3, -0.7, 1.1, 0.2, 0.9, -0.4]);
        let analytic = objective.gradient(&point).unwrap();
        let h = 1e-6;
        for k in 0..point.len() {
            let mut plus = point.clone();
            plus[k] += h;
            let mut minus = point.clone();
            minus[k] -= h;
            let numeric =
                (objective.cost(&plus).unwrap() - objective.cost(&minus).unwrap()) / (2.0 * h);
            assert!(
                (analytic[k] - numeric).abs() < 1e-5 * (1.0 + numeric.abs()),
                "component {}: analytic {} vs numeric {}",
                k,
                analytic[k],
                numeric
            );
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use rand::rngs::StdRng;
    use rand::SeedableRng;

    /// Seed shared by every test so layouts are reproducible.
    const SEED: u64 = 42;

    /// `true` when `a` and `b` differ by strictly less than `tol`.
    fn approx_eq(a: f64, b: f64, tol: f64) -> bool {
        let gap = (a - b).abs();
        gap < tol
    }

    /// Relations for `n_sets` sets with no subset or disjoint pairs and zero overlap areas.
    fn create_test_relationships(n_sets: usize) -> PairwiseRelations {
        let no_pairs = vec![vec![false; n_sets]; n_sets];
        PairwiseRelations {
            n_sets,
            subset: no_pairs.clone(),
            disjoint: no_pairs,
            overlap_areas: vec![vec![0.0; n_sets]; n_sets],
        }
    }

    /// Euclidean distance between points `i` and `j` of a flat `[x.., y..]` layout.
    fn pair_distance(layout: &[f64], i: usize, j: usize) -> f64 {
        let (xs, ys) = layout.split_at(layout.len() / 2);
        let dx = xs[i] - xs[j];
        let dy = ys[i] - ys[j];
        (dx * dx + dy * dy).sqrt()
    }

    #[test]
    fn test_compute_initial_layout_two_sets_touching() {
        let mut rng = StdRng::seed_from_u64(SEED);
        let distances = vec![vec![0.0, 2.0], vec![2.0, 0.0]];
        let relationships = create_test_relationships(2);
        let layout = compute_initial_layout(&distances, &relationships, &mut rng).unwrap();
        assert_eq!(layout.len(), 4);
        let actual_distance = pair_distance(&layout, 0, 1);
        assert!(
            approx_eq(actual_distance, 2.0, 0.1),
            "Distance {} should be close to 2.0",
            actual_distance
        );
    }

    #[test]
    fn test_compute_initial_layout_two_sets_overlapping() {
        let mut rng = StdRng::seed_from_u64(SEED);
        let distances = vec![vec![0.0, 1.0], vec![1.0, 0.0]];
        let relationships = create_test_relationships(2);
        let layout = compute_initial_layout(&distances, &relationships, &mut rng).unwrap();
        assert_eq!(layout.len(), 4);
        let actual_distance = pair_distance(&layout, 0, 1);
        assert!(
            approx_eq(actual_distance, 1.0, 0.1),
            "Distance {} should be close to 1.0",
            actual_distance
        );
    }

    #[test]
    fn test_compute_initial_layout_two_sets_separated() {
        let mut rng = StdRng::seed_from_u64(SEED);
        let distances = vec![vec![0.0, 5.0], vec![5.0, 0.0]];
        let relationships = create_test_relationships(2);
        let layout = compute_initial_layout(&distances, &relationships, &mut rng).unwrap();
        assert_eq!(layout.len(), 4);
        let actual_distance = pair_distance(&layout, 0, 1);
        assert!(
            approx_eq(actual_distance, 5.0, 0.1),
            "Distance {} should be close to 5.0",
            actual_distance
        );
    }

    #[test]
    fn test_compute_initial_layout_two_sets_disjoint() {
        let mut rng = StdRng::seed_from_u64(SEED);
        let distances = vec![vec![0.0, 3.0], vec![3.0, 0.0]];
        let mut relationships = create_test_relationships(2);
        for (a, b) in [(0, 1), (1, 0)] {
            relationships.disjoint[a][b] = true;
        }
        let layout = compute_initial_layout(&distances, &relationships, &mut rng).unwrap();
        assert_eq!(layout.len(), 4);
        let actual_distance = pair_distance(&layout, 0, 1);
        assert!(
            actual_distance >= 3.0 - 0.1,
            "Distance {} should be at least 3.0",
            actual_distance
        );
    }

    #[test]
    fn test_compute_initial_layout_three_sets_triangle() {
        let mut rng = StdRng::seed_from_u64(SEED);
        let d = 2.0;
        let distances = vec![vec![0.0, d, d], vec![d, 0.0, d], vec![d, d, 0.0]];
        let relationships = create_test_relationships(3);
        let layout = compute_initial_layout(&distances, &relationships, &mut rng).unwrap();
        assert_eq!(layout.len(), 6);
        for (i, j) in [(0, 1), (0, 2), (1, 2)] {
            let actual_distance = pair_distance(&layout, i, j);
            assert!(
                approx_eq(actual_distance, d, 0.2),
                "Distance between {} and {} is {}, should be close to {}",
                i,
                j,
                actual_distance,
                d
            );
        }
    }

    #[test]
    fn test_compute_initial_layout_three_sets_collinear() {
        let mut rng = StdRng::seed_from_u64(SEED);
        let distances = vec![
            vec![0.0, 1.0, 2.0],
            vec![1.0, 0.0, 1.0],
            vec![2.0, 1.0, 0.0],
        ];
        let relationships = create_test_relationships(3);
        let layout = compute_initial_layout(&distances, &relationships, &mut rng).unwrap();
        assert_eq!(layout.len(), 6);
        let d01 = pair_distance(&layout, 0, 1);
        let d12 = pair_distance(&layout, 1, 2);
        let d02 = pair_distance(&layout, 0, 2);
        assert!(approx_eq(d01, 1.0, 0.2), "Distance 0-1: {}", d01);
        assert!(approx_eq(d12, 1.0, 0.2), "Distance 1-2: {}", d12);
        assert!(approx_eq(d02, 2.0, 0.2), "Distance 0-2: {}", d02);
    }

    #[test]
    fn test_compute_initial_layout_four_sets_square() {
        let mut rng = StdRng::seed_from_u64(SEED);
        let side = 1.0;
        let diag = side * 2.0_f64.sqrt();
        let distances = vec![
            vec![0.0, side, diag, side],
            vec![side, 0.0, side, diag],
            vec![diag, side, 0.0, side],
            vec![side, diag, side, 0.0],
        ];
        let relationships = create_test_relationships(4);
        let layout = compute_initial_layout(&distances, &relationships, &mut rng).unwrap();
        assert_eq!(layout.len(), 8);
        let d01 = pair_distance(&layout, 0, 1);
        let d02 = pair_distance(&layout, 0, 2);
        assert!(d01 < d02);
    }

    #[test]
    fn test_compute_initial_layout_with_restarts() {
        let mut rng = StdRng::seed_from_u64(SEED);
        let distances = vec![vec![0.0, 1.5], vec![1.5, 0.0]];
        let relationships = create_test_relationships(2);
        let first = compute_initial_layout(&distances, &relationships, &mut rng).unwrap();
        let second = compute_initial_layout(&distances, &relationships, &mut rng).unwrap();
        assert_eq!(first.len(), 4);
        assert_eq!(second.len(), 4);
        let d1 = pair_distance(&first, 0, 1);
        let d2 = pair_distance(&second, 0, 1);
        assert!(approx_eq(d1, 1.5, 0.2), "Distance with first run: {}", d1);
        assert!(approx_eq(d2, 1.5, 0.2), "Distance with second run: {}", d2);
    }

    #[test]
    fn test_compute_initial_layout_zero_distance() {
        let mut rng = StdRng::seed_from_u64(SEED);
        let distances = vec![vec![0.0, 0.0], vec![0.0, 0.0]];
        let relationships = create_test_relationships(2);
        let layout = compute_initial_layout(&distances, &relationships, &mut rng).unwrap();
        assert_eq!(layout.len(), 4);
        let actual_distance = pair_distance(&layout, 0, 1);
        assert!(
            actual_distance < 0.1,
            "Distance {} should be close to 0.0",
            actual_distance
        );
    }

    #[test]
    fn test_compute_initial_layout_asymmetric_distances() {
        let mut rng = StdRng::seed_from_u64(SEED);
        let distances = vec![
            vec![0.0, 1.0, 2.0],
            vec![1.0, 0.0, 1.5],
            vec![2.0, 1.5, 0.0],
        ];
        let relationships = create_test_relationships(3);
        let outcome = compute_initial_layout(&distances, &relationships, &mut rng);
        assert!(outcome.is_ok());
        assert_eq!(outcome.unwrap().len(), 6);
    }

    #[test]
    fn test_patience_based_optimization() {
        let mut rng = StdRng::seed_from_u64(SEED);
        let distances = vec![vec![0.0, 1.5], vec![1.5, 0.0]];
        let relationships = create_test_relationships(2);
        let config = InitialLayoutConfig {
            max_attempts: 100,
            patience: 10,
            improvement_threshold: 0.01,
            perfect_fit_threshold: 1e-8,
        };
        let outcome =
            compute_initial_layout_with_config(&distances, &relationships, config, &mut rng);
        assert!(outcome.is_ok());
        let params = outcome.unwrap();
        assert_eq!(params.len(), 4);
        let d = pair_distance(&params, 0, 1);
        assert!(approx_eq(d, 1.5, 0.2), "Distance: {}", d);
    }

    #[test]
    fn test_patience_with_zero_threshold() {
        let mut rng = StdRng::seed_from_u64(SEED);
        let distances = vec![vec![0.0, 2.0], vec![2.0, 0.0]];
        let relationships = create_test_relationships(2);
        let config = InitialLayoutConfig {
            max_attempts: 10,
            patience: 2,
            improvement_threshold: 0.0,
            perfect_fit_threshold: 1e-8,
        };
        let outcome =
            compute_initial_layout_with_config(&distances, &relationships, config, &mut rng);
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_early_stopping_on_perfect_fit() {
        let mut rng = StdRng::seed_from_u64(SEED);
        let distances = vec![vec![0.0, 1.0], vec![1.0, 0.0]];
        let relationships = create_test_relationships(2);
        let config = InitialLayoutConfig {
            max_attempts: 100,
            patience: 5,
            improvement_threshold: 0.01,
            perfect_fit_threshold: 1e-8,
        };
        let outcome =
            compute_initial_layout_with_config(&distances, &relationships, config, &mut rng);
        assert!(outcome.is_ok());
        let params = outcome.unwrap();
        assert_eq!(params.len(), 4);
        let d = pair_distance(&params, 0, 1);
        assert!(approx_eq(d, 1.0, 0.1), "Distance: {}", d);
    }
}