use super::SinkhornSolver;
use crate::error::{MathError, Result};
use crate::utils::EPS;
/// Configuration for an entropic Gromov–Wasserstein solver.
///
/// Construct it with [`GromovWasserstein::new`] and adjust it through the
/// `with_*` builder methods.
#[derive(Clone, Debug)]
pub struct GromovWasserstein {
    /// Entropic regularization forwarded to the inner Sinkhorn solver.
    regularization: f64,
    /// Upper bound on the number of outer (Frank–Wolfe style) iterations.
    max_iterations: usize,
    /// Relative loss change below which the outer loop stops.
    threshold: f64,
    /// Second argument handed to `SinkhornSolver::new` (presumably its iteration cap).
    inner_iterations: usize,
}
impl GromovWasserstein {
    /// Creates a solver with the given entropic `regularization`.
    ///
    /// The regularization is clamped to at least `1e-6`. Defaults: 100 outer
    /// iterations, a relative-change threshold of `1e-5`, and an inner Sinkhorn
    /// iteration budget of 50.
    pub fn new(regularization: f64) -> Self {
        Self {
            regularization: regularization.max(1e-6),
            max_iterations: 100,
            threshold: 1e-5,
            inner_iterations: 50,
        }
    }

    /// Sets the maximum number of outer iterations (clamped to at least 1).
    pub fn with_max_iterations(mut self, max_iter: usize) -> Self {
        self.max_iterations = max_iter.max(1);
        self
    }

    /// Sets the relative loss-change threshold of the convergence test
    /// (clamped to at least `1e-12`).
    pub fn with_threshold(mut self, threshold: f64) -> Self {
        self.threshold = threshold.max(1e-12);
        self
    }

    /// Sets the inner iteration budget passed to [`SinkhornSolver::new`]
    /// (clamped to at least 1).
    pub fn with_inner_iterations(mut self, inner_iter: usize) -> Self {
        self.inner_iterations = inner_iter.max(1);
        self
    }

    /// Pairwise Euclidean distance matrix of `points` (symmetric, zero diagonal).
    ///
    /// Coordinates are paired with `zip`, so all points are expected to share one
    /// dimension; trailing coordinates of a longer point are ignored.
    fn distance_matrix(points: &[Vec<f64>]) -> Vec<Vec<f64>> {
        let n = points.len();
        let mut dist = vec![vec![0.0; n]; n];
        for i in 0..n {
            for j in (i + 1)..n {
                let d = points[i]
                    .iter()
                    .zip(&points[j])
                    .map(|(&a, &b)| (a - b).powi(2))
                    .sum::<f64>()
                    .sqrt();
                dist[i][j] = d;
                dist[j][i] = d;
            }
        }
        dist
    }

    /// Row sums `p` and column sums `q` of `gamma`, where `m` is its column count.
    fn marginals(gamma: &[Vec<f64>], m: usize) -> (Vec<f64>, Vec<f64>) {
        let p: Vec<f64> = gamma.iter().map(|row| row.iter().sum::<f64>()).collect();
        let mut q = vec![0.0; m];
        for row in gamma {
            for (q_j, &g) in q.iter_mut().zip(row) {
                *q_j += g;
            }
        }
        (p, q)
    }

    /// For every row `i`, returns `Σ_k dist[i][k]² · w[k]`.
    fn squared_dist_times(dist: &[Vec<f64>], w: &[f64]) -> Vec<f64> {
        dist.iter()
            .map(|row| row.iter().zip(w).map(|(&d, &w_k)| d * d * w_k).sum::<f64>())
            .collect()
    }

    /// Dense matrix product `a · b`, where `b` has `cols` columns.
    ///
    /// Uses i-k-j loop order so the innermost loop walks a contiguous row of `b`.
    fn mat_mul(a: &[Vec<f64>], b: &[Vec<f64>], cols: usize) -> Vec<Vec<f64>> {
        a.iter()
            .map(|a_row| {
                let mut out = vec![0.0; cols];
                for (&a_ik, b_row) in a_row.iter().zip(b) {
                    for (o, &b_kj) in out.iter_mut().zip(b_row) {
                        *o += a_ik * b_kj;
                    }
                }
                out
            })
            .collect()
    }

    /// Euclidean inner product of two equally long vectors.
    fn dot(a: &[f64], b: &[f64]) -> f64 {
        a.iter().zip(b).map(|(x, y)| x * y).sum()
    }

    /// Square-loss Gromov–Wasserstein objective of the plan `gamma`:
    ///
    /// `L(γ) = Σ_{i,j,k,l} (Dx[i][k] − Dy[j][l])² γ[i][j] γ[k][l]`.
    ///
    /// It is evaluated through the exact expansion
    /// `pᵀ(Dx∘Dx)p + qᵀ(Dy∘Dy)q − 2⟨Dx·γ, γ·Dy⟩`, where `p` and `q` are the row and
    /// column sums of `γ`. This costs `O(n²m + nm²)` instead of the `O(n²m²)` of the
    /// definition.
    fn compute_gw_loss(dist_x: &[Vec<f64>], dist_y: &[Vec<f64>], gamma: &[Vec<f64>]) -> f64 {
        let m = dist_y.len();
        let (p, q) = Self::marginals(gamma, m);
        let term1 = Self::dot(&p, &Self::squared_dist_times(dist_x, &p));
        let term2 = Self::dot(&q, &Self::squared_dist_times(dist_y, &q));
        let dx_gamma = Self::mat_mul(dist_x, gamma, m);
        let gamma_dy = Self::mat_mul(gamma, dist_y, m);
        let cross: f64 = dx_gamma
            .iter()
            .zip(&gamma_dy)
            .map(|(a, b)| Self::dot(a, b))
            .sum();
        term1 + term2 - 2.0 * cross
    }

    /// Gradient of [`Self::compute_gw_loss`] with respect to `gamma`:
    ///
    /// `∇L[i][j] = 2 Σ_{k,l} (Dx[i][k] − Dy[j][l])² γ[k][l]
    ///          = 2 (Σ_k Dx[i][k]² p[k] + Σ_l Dy[j][l]² q[l] − 2 (Dx·γ·Dy)[i][j])`.
    ///
    /// The closed form assumes `dist_x` and `dist_y` are symmetric, which holds for
    /// matrices produced by [`Self::distance_matrix`]. `Dx·γ·Dy` is formed as two
    /// matrix products (`O(n²m + nm²)`) rather than a four-fold `O(n²m²)` loop.
    fn compute_gradient(
        dist_x: &[Vec<f64>],
        dist_y: &[Vec<f64>],
        gamma: &[Vec<f64>],
    ) -> Vec<Vec<f64>> {
        let m = dist_y.len();
        let (p, q) = Self::marginals(gamma, m);
        let dx2_p = Self::squared_dist_times(dist_x, &p);
        let dy2_q = Self::squared_dist_times(dist_y, &q);
        let dx_gamma_dy = Self::mat_mul(&Self::mat_mul(dist_x, gamma, m), dist_y, m);
        dx2_p
            .iter()
            .zip(&dx_gamma_dy)
            .map(|(&a_i, cross_row)| {
                dy2_q
                    .iter()
                    .zip(cross_row)
                    .map(|(&b_j, &c_ij)| 2.0 * (a_i + b_j - 2.0 * c_ij))
                    .collect()
            })
            .collect()
    }

    /// Approximates an entropic Gromov–Wasserstein coupling between `source` and `target`.
    ///
    /// The solver starts from the uniform product coupling and runs at most
    /// `max_iterations` outer steps. Each step does three things:
    ///
    /// 1. Linearizes the loss at the current plan.
    /// 2. Solves the linearized problem with [`SinkhornSolver`] under uniform marginals.
    /// 3. Moves toward that plan by the step in `{0.1, 0.2, …, 1.0}` with the lowest
    ///    loss, or stays put if no step improves the loss.
    ///
    /// The loop stops early, and `converged` is set, once the relative loss change
    /// falls below `threshold`. Only intra-set distances are used, so `source` and
    /// `target` points may have different dimensions.
    ///
    /// # Errors
    ///
    /// Returns an error if either point set is empty. Errors from the inner
    /// Sinkhorn solve are propagated.
    pub fn solve(
        &self,
        source: &[Vec<f64>],
        target: &[Vec<f64>],
    ) -> Result<GromovWassersteinResult> {
        if source.is_empty() || target.is_empty() {
            return Err(MathError::empty_input("points"));
        }
        let n = source.len();
        let m = target.len();
        let dist_x = Self::distance_matrix(source);
        let dist_y = Self::distance_matrix(target);
        // Independent (product) coupling of the two uniform marginals.
        let mut gamma: Vec<Vec<f64>> = vec![vec![1.0 / (n * m) as f64; m]; n];
        let sinkhorn = SinkhornSolver::new(self.regularization, self.inner_iterations);
        let source_weights = vec![1.0 / n as f64; n];
        let target_weights = vec![1.0 / m as f64; m];
        let mut loss = Self::compute_gw_loss(&dist_x, &dist_y, &gamma);
        let mut converged = false;

        for _ in 0..self.max_iterations {
            let gradient = Self::compute_gradient(&dist_x, &dist_y, &gamma);
            let direction = sinkhorn
                .solve(&gradient, &source_weights, &target_weights)?
                .plan;

            // Grid line search: keep the strictly best candidate plan, if any improves.
            let mut best_plan: Option<Vec<Vec<f64>>> = None;
            let mut best_loss = loss;
            for k in 1..=10 {
                let alpha = k as f64 / 10.0;
                let candidate: Vec<Vec<f64>> = (0..n)
                    .map(|i| {
                        (0..m)
                            .map(|j| (1.0 - alpha) * gamma[i][j] + alpha * direction[i][j])
                            .collect()
                    })
                    .collect();
                let candidate_loss = Self::compute_gw_loss(&dist_x, &dist_y, &candidate);
                if candidate_loss < best_loss {
                    best_loss = candidate_loss;
                    best_plan = Some(candidate);
                }
            }
            if let Some(plan) = best_plan {
                gamma = plan;
            }

            // With no improving step, best_loss == loss, so the change is 0 and we stop.
            let loss_change = (loss - best_loss).abs() / (loss.abs() + EPS);
            loss = best_loss;
            if loss_change < self.threshold {
                converged = true;
                break;
            }
        }

        Ok(GromovWassersteinResult {
            transport_plan: gamma,
            loss,
            converged,
        })
    }

    /// Gromov–Wasserstein distance, reported as the square root of the loss from
    /// [`Self::solve`].
    ///
    /// For a non-negative plan the exact loss is a sum of non-negative terms. The
    /// expanded formula can still cancel to a tiny negative value in floating point,
    /// and such values are clamped to `0.0` so the result is not `NaN`. A `NaN` loss
    /// is passed through unchanged rather than being masked.
    ///
    /// # Errors
    ///
    /// Same as [`Self::solve`].
    pub fn distance(&self, source: &[Vec<f64>], target: &[Vec<f64>]) -> Result<f64> {
        let loss = self.solve(source, target)?.loss;
        let loss = if loss < 0.0 { 0.0 } else { loss };
        Ok(loss.sqrt())
    }
}
/// Output of [`GromovWasserstein::solve`].
#[derive(Clone, Debug)]
pub struct GromovWassersteinResult {
    /// Final coupling; `transport_plan[i][j]` pairs source point `i` with target point `j`.
    pub transport_plan: Vec<Vec<f64>>,
    /// Gromov–Wasserstein loss of `transport_plan`.
    pub loss: f64,
    /// `true` when the relative loss change fell below the solver threshold
    /// before the iteration limit was reached.
    pub converged: bool,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Reference GW loss straight from the definition:
    /// `Σ_{i,j,k,l} (Dx[i][k] − Dy[j][l])² γ[i][j] γ[k][l]`.
    fn brute_force_loss(dx: &[Vec<f64>], dy: &[Vec<f64>], gamma: &[Vec<f64>]) -> f64 {
        let mut total = 0.0;
        for (i, dx_i) in dx.iter().enumerate() {
            for (j, dy_j) in dy.iter().enumerate() {
                for (k, &dx_ik) in dx_i.iter().enumerate() {
                    for (l, &dy_jl) in dy_j.iter().enumerate() {
                        total += (dx_ik - dy_jl).powi(2) * gamma[i][j] * gamma[k][l];
                    }
                }
            }
        }
        total
    }

    /// Reference gradient: `∂L/∂γ[i][j] = 2 Σ_{k,l} (Dx[i][k] − Dy[j][l])² γ[k][l]`
    /// (valid for symmetric distance matrices).
    fn brute_force_gradient(
        dx: &[Vec<f64>],
        dy: &[Vec<f64>],
        gamma: &[Vec<f64>],
    ) -> Vec<Vec<f64>> {
        dx.iter()
            .map(|dx_i| {
                dy.iter()
                    .map(|dy_j| {
                        let mut s = 0.0;
                        for (k, &dx_ik) in dx_i.iter().enumerate() {
                            for (l, &dy_jl) in dy_j.iter().enumerate() {
                                s += (dx_ik - dy_jl).powi(2) * gamma[k][l];
                            }
                        }
                        2.0 * s
                    })
                    .collect()
            })
            .collect()
    }

    /// A 3×4 problem with a deliberately non-uniform plan, so that agreement
    /// with the references is not an accident of symmetry.
    fn sample_problem() -> (Vec<Vec<f64>>, Vec<Vec<f64>>, Vec<Vec<f64>>) {
        let source = vec![vec![0.0, 0.0], vec![1.0, 0.5], vec![-0.5, 2.0]];
        let target = vec![
            vec![0.0, 0.0, 1.0],
            vec![2.0, 0.0, 0.0],
            vec![0.5, 1.5, 0.0],
            vec![1.0, 1.0, 1.0],
        ];
        let dx = GromovWasserstein::distance_matrix(&source);
        let dy = GromovWasserstein::distance_matrix(&target);
        let gamma: Vec<Vec<f64>> = (0..3)
            .map(|i| (0..4).map(|j| (1 + i + 2 * j) as f64 / 60.0).collect())
            .collect();
        (dx, dy, gamma)
    }

    #[test]
    fn test_gw_identical() {
        let gw = GromovWasserstein::new(0.1);
        let points = vec![vec![0.0, 0.0], vec![1.0, 0.0], vec![0.0, 1.0]];
        let dist = gw.distance(&points, &points).unwrap();
        assert!(
            dist < 1.0,
            "Identical structures should have low GW: {}",
            dist
        );
    }

    #[test]
    fn test_gw_scaled() {
        let gw = GromovWasserstein::new(0.1);
        let source = vec![vec![0.0, 0.0], vec![1.0, 0.0], vec![0.0, 1.0]];
        let target: Vec<Vec<f64>> = source
            .iter()
            .map(|p| vec![p[0] * 2.0, p[1] * 2.0])
            .collect();
        let dist = gw.distance(&source, &target).unwrap();
        assert!(dist > 0.0, "Scaled structure should have some GW distance");
    }

    #[test]
    fn test_gw_different_structures() {
        let gw = GromovWasserstein::new(0.1);
        let triangle = vec![vec![0.0, 0.0], vec![1.0, 0.0], vec![0.5, 0.866]];
        let line = vec![vec![0.0, 0.0], vec![1.0, 0.0], vec![2.0, 0.0]];
        let dist = gw.distance(&triangle, &line).unwrap();
        assert!(
            dist > 0.1,
            "Different structures should have high GW: {}",
            dist
        );
    }

    #[test]
    fn test_distance_matrix() {
        let points = vec![vec![0.0, 0.0], vec![3.0, 4.0]];
        let dist = GromovWasserstein::distance_matrix(&points);
        assert!((dist[0][1] - 5.0).abs() < 1e-10);
        assert!((dist[1][0] - 5.0).abs() < 1e-10);
        assert!(dist[0][0].abs() < 1e-10);
    }

    #[test]
    fn test_gw_loss_matches_definition() {
        let (dx, dy, gamma) = sample_problem();
        let fast = GromovWasserstein::compute_gw_loss(&dx, &dy, &gamma);
        let slow = brute_force_loss(&dx, &dy, &gamma);
        assert!(
            (fast - slow).abs() < 1e-9 * (1.0 + slow.abs()),
            "{} vs {}",
            fast,
            slow
        );
    }

    #[test]
    fn test_gw_gradient_matches_definition() {
        let (dx, dy, gamma) = sample_problem();
        let fast = GromovWasserstein::compute_gradient(&dx, &dy, &gamma);
        let slow = brute_force_gradient(&dx, &dy, &gamma);
        assert_eq!(fast.len(), slow.len());
        for (f_row, s_row) in fast.iter().zip(&slow) {
            assert_eq!(f_row.len(), s_row.len());
            for (&f, &s) in f_row.iter().zip(s_row) {
                assert!((f - s).abs() < 1e-9 * (1.0 + s.abs()), "{} vs {}", f, s);
            }
        }
    }

    #[test]
    fn test_gw_plan_shape_with_mismatched_sizes() {
        let gw = GromovWasserstein::new(0.1).with_max_iterations(5);
        let source = vec![vec![0.0], vec![1.0], vec![3.0]];
        let target = vec![vec![0.0, 0.0], vec![2.0, 0.0]];
        let result = gw.solve(&source, &target).unwrap();
        assert_eq!(result.transport_plan.len(), 3);
        assert!(result.transport_plan.iter().all(|row| row.len() == 2));
        assert!(result.loss.is_finite());
    }

    #[test]
    fn test_gw_empty_input() {
        let gw = GromovWasserstein::new(0.1);
        let points = vec![vec![0.0]];
        assert!(gw.solve(&[], &points).is_err());
        assert!(gw.solve(&points, &[]).is_err());
    }
}