use crate::distance::euclidean;
use crate::error::SpatialResult;
use scirs2_core::ndarray::{Array2, ArrayView2};
use scirs2_core::numeric::Float;
use scirs2_core::random::{rngs::StdRng, SeedableRng};
use scirs2_core::SliceRandomExt;
#[allow(dead_code)]
pub fn directed_hausdorff<T: Float + Send + Sync>(
set1: &ArrayView2<T>,
set2: &ArrayView2<T>,
seed: Option<u64>,
) -> (T, usize, usize) {
let n1 = set1.shape()[0];
let n2 = set2.shape()[0];
let dims = set1.shape()[1];
if n1 == 0 || n2 == 0 {
return (T::infinity(), 0, 0);
}
if set2.shape()[1] != dims {
return (T::infinity(), 0, 0);
}
let mut rng = match seed {
Some(s) => StdRng::seed_from_u64(s),
None => StdRng::try_from_rng(&mut rand::rngs::SysRng)
.unwrap_or_else(|_| StdRng::seed_from_u64(42)),
};
let mut indices1: Vec<usize> = (0..n1).collect();
let mut indices2: Vec<usize> = (0..n2).collect();
indices1.shuffle(&mut rng);
indices2.shuffle(&mut rng);
let mut set1_shuffled = Array2::zeros((n1, dims));
let mut set2_shuffled = Array2::zeros((n2, dims));
for (i, &idx) in indices1.iter().enumerate() {
set1_shuffled.row_mut(i).assign(&set1.row(idx));
}
for (i, &idx) in indices2.iter().enumerate() {
set2_shuffled.row_mut(i).assign(&set2.row(idx));
}
let mut cmax = T::zero();
let mut i_ret = 0;
let mut j_ret = 0;
for i in 0..n1 {
let mut cmin = T::infinity();
let mut j_store = 0;
let mut d_early_break = T::infinity();
for j in 0..n2 {
let mut d = T::zero();
for k in 0..dims {
let diff = set1_shuffled[[i, k]] - set2_shuffled[[j, k]];
d = d + diff * diff;
}
if d < cmax {
d_early_break = d;
break;
}
if d < cmin {
cmin = d;
j_store = j;
}
}
if d_early_break < cmax {
continue;
}
if cmin >= cmax {
cmax = cmin;
i_ret = i;
j_ret = j_store;
}
}
let i_original = indices1[i_ret];
let j_original = indices2[j_ret];
(cmax.sqrt(), i_original, j_original)
}
/// Computes the symmetric Hausdorff distance between two point sets.
///
/// This is the larger of the two directed distances, `set1 -> set2` and
/// `set2 -> set1`, each obtained from [`directed_hausdorff`]. The same `seed`
/// is passed to both directed computations.
///
/// Because it delegates to [`directed_hausdorff`], the result is
/// `T::infinity()` when either set has no rows or the column counts differ.
#[allow(dead_code)]
pub fn hausdorff_distance<T: Float + Send + Sync>(
    set1: &ArrayView2<T>,
    set2: &ArrayView2<T>,
    seed: Option<u64>,
) -> T {
    let forward = directed_hausdorff(set1, set2, seed).0;
    let backward = directed_hausdorff(set2, set1, seed).0;

    // Ties resolve to the backward value.
    if forward > backward {
        forward
    } else {
        backward
    }
}
/// Heuristic Wasserstein-style distance between two point sets.
///
/// Each row of the inputs is one point. This function does **not** solve an
/// optimal-transport problem; it uses one of two cheap approximations:
///
/// * **Equal number of points** – points of `set1` are processed in row order
///   and each is greedily paired with the closest point of `set2` that has not
///   been paired yet. The result is the mean distance of those pairs, so it can
///   depend on the row order of `set1`.
/// * **Different number of points** – for every point of either set the
///   distance to its nearest neighbour in the other set is taken, and the
///   result is the sum of all those distances divided by `n1 + n2`.
///
/// # Returns
///
/// `Ok(T::infinity())` if either set has no rows (this check comes before the
/// dimension check), otherwise `Ok(distance)`.
///
/// # Errors
///
/// Returns `SpatialError::DimensionError` when both sets are non-empty and
/// their column counts differ.
///
/// # Panics
///
/// Panics if a point count cannot be converted to `T`.
#[allow(dead_code)]
pub fn wasserstein_distance<T: Float + Send + Sync>(
    set1: &ArrayView2<T>,
    set2: &ArrayView2<T>,
) -> SpatialResult<T> {
    let n1 = set1.shape()[0];
    let n2 = set2.shape()[0];
    if n1 == 0 || n2 == 0 {
        return Ok(T::infinity());
    }
    if set2.shape()[1] != set1.shape()[1] {
        return Err(crate::error::SpatialError::DimensionError(
            "Dimension mismatch: sets must have the same number of dimensions".to_string(),
        ));
    }

    // Copy every row out exactly once. Doing the `to_vec()` inside the
    // pairwise loops would allocate two vectors per distance evaluation.
    let points1: Vec<Vec<T>> = (0..n1).map(|i| set1.row(i).to_vec()).collect();
    let points2: Vec<Vec<T>> = (0..n2).map(|j| set2.row(j).to_vec()).collect();

    if n1 == n2 {
        // Greedy one-to-one pairing in the row order of `set1`.
        let mut used = vec![false; n2];
        let mut total_distance = T::zero();
        for p1 in &points1 {
            let mut min_dist = T::infinity();
            let mut best_j = 0;
            for (j, p2) in points2.iter().enumerate() {
                if used[j] {
                    continue;
                }
                let dist = euclidean(p1, p2);
                if dist < min_dist {
                    min_dist = dist;
                    best_j = j;
                }
            }
            used[best_j] = true;
            total_distance = total_distance + min_dist;
        }
        return Ok(total_distance / T::from(n1).expect("Operation failed"));
    }

    // Distance from `query` to its nearest neighbour among `candidates`.
    let nearest = |query: &Vec<T>, candidates: &[Vec<T>]| -> T {
        candidates.iter().fold(T::infinity(), |best, candidate| {
            let dist = euclidean(query, candidate);
            if dist < best {
                dist
            } else {
                best
            }
        })
    };

    // Accumulate set1 -> set2 first, then set2 -> set1, in row order.
    let forward = points1
        .iter()
        .fold(T::zero(), |sum, p| sum + nearest(p, &points2));
    let total_distance = points2
        .iter()
        .fold(forward, |sum, p| sum + nearest(p, &points1));

    let avg_n = T::from(n1 + n2).expect("Operation failed");
    Ok(total_distance / avg_n)
}
/// Diameter-based estimate of the Gromov-Hausdorff distance between two point sets.
///
/// Each row of the inputs is one point. The function computes the diameter
/// (largest pairwise Euclidean distance) of each set and returns half of the
/// absolute difference of the two diameters. That quantity is a lower bound
/// on the true Gromov-Hausdorff distance, not the distance itself.
///
/// Only distances *within* each set are used, so the two sets are not
/// required to have the same number of columns.
///
/// # Returns
///
/// `T::infinity()` if either set has no rows, otherwise
/// `|diam(set1) - diam(set2)| / 2`.
///
/// # Panics
///
/// Panics if the constant `2` cannot be converted to `T`.
#[allow(dead_code)]
pub fn gromov_hausdorff_distance<T: Float + Send + Sync>(
    set1: &ArrayView2<T>,
    set2: &ArrayView2<T>,
) -> T {
    if set1.shape()[0] == 0 || set2.shape()[0] == 0 {
        return T::infinity();
    }

    // Largest pairwise distance inside one set. The maximum is tracked on the
    // fly, so no n x n distance matrix is materialised and each row is copied
    // out only once.
    let diameter = |set: &ArrayView2<T>| -> T {
        let points: Vec<Vec<T>> = (0..set.shape()[0])
            .map(|i| set.row(i).to_vec())
            .collect();

        let mut widest = T::neg_infinity();
        for (i, p) in points.iter().enumerate() {
            // Only pairs (i, j) with j >= i are visited; this assumes
            // `euclidean(a, b) == euclidean(b, a)`, as expected of a metric.
            for q in &points[i..] {
                widest = widest.max(euclidean(p, q));
            }
        }
        widest
    };

    (diameter(set1) - diameter(set2)).abs() / T::from(2).expect("Operation failed")
}
#[cfg(test)]
mod tests {
    use super::{
        directed_hausdorff, gromov_hausdorff_distance, hausdorff_distance, wasserstein_distance,
    };
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::{array, Array2};

    /// Unit right triangle and a second three-point set used by several tests.
    fn triangle_pair() -> (Array2<f64>, Array2<f64>) {
        (
            array![[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]],
            array![[0.0, 0.5], [1.0, 0.5], [0.5, 1.0]],
        )
    }

    /// Every triangle vertex has a neighbour at distance 0.5 in the other set.
    #[test]
    fn test_directed_hausdorff() {
        let (triangle, other) = triangle_pair();
        let (dist, ..) = directed_hausdorff(&triangle.view(), &other.view(), Some(42));
        assert_relative_eq!(dist, 0.5, epsilon = 1e-6);
    }

    /// Symmetric distance on the shared fixture and on two parallel segments.
    #[test]
    fn test_hausdorff_distance() {
        let (triangle, other) = triangle_pair();
        let dist = hausdorff_distance(&triangle.view(), &other.view(), Some(42));
        assert_relative_eq!(dist, 0.5, epsilon = 1e-6);

        // Two horizontal segments two units apart.
        let lower = array![[0.0, 0.0], [1.0, 0.0]];
        let upper = array![[0.0, 2.0], [1.0, 2.0]];
        let dist = hausdorff_distance(&lower.view(), &upper.view(), Some(42));
        assert_relative_eq!(dist, 2.0, epsilon = 1e-6);
    }

    /// Only checks that the result lies strictly between 0 and 1.
    #[test]
    fn test_wasserstein_distance() {
        let (triangle, other) = triangle_pair();
        let dist =
            wasserstein_distance(&triangle.view(), &other.view()).expect("Operation failed");
        assert!(dist > 0.0);
        assert!(dist < 1.0);
    }

    /// A triangle and the same triangle scaled by two must give a positive value.
    #[test]
    fn test_gromov_hausdorff_distance() {
        let small = array![[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]];
        let large = array![[0.0, 0.0], [2.0, 0.0], [0.0, 2.0]];
        let dist = gromov_hausdorff_distance(&small.view(), &large.view());
        assert!(dist > 0.0);
    }

    /// An empty set on either side yields an infinite Hausdorff distance.
    #[test]
    fn test_empty_sets() {
        let points = array![[0.0, 0.0], [1.0, 0.0]];
        let empty: Array2<f64> = Array2::zeros((0, 2));

        let argument_orders = [
            (points.view(), empty.view()),
            (empty.view(), points.view()),
        ];
        for (first, second) in &argument_orders {
            let dist = hausdorff_distance(first, second, None);
            assert!(dist.is_infinite());
        }
    }
}