use crate::error::{Result, TransformError};
use scirs2_core::ndarray::Array2;
/// One birth/death interval of a persistence diagram.
///
/// A pair whose `death` is infinite is treated as *essential*, i.e. a feature
/// that never disappears.
#[derive(Debug, Clone, PartialEq)]
pub struct PersistencePair {
    /// Filtration value at which the feature appears.
    pub birth: f64,
    /// Filtration value at which the feature disappears (infinite if essential).
    pub death: f64,
    /// Homological dimension the feature belongs to.
    pub dimension: usize,
}

impl PersistencePair {
    /// Builds a pair from its birth value, death value and dimension.
    pub fn new(birth: f64, death: f64, dimension: usize) -> Self {
        PersistencePair {
            birth,
            death,
            dimension,
        }
    }

    /// Lifetime `death - birth` of the feature, or `f64::INFINITY` when the
    /// pair is essential.
    pub fn persistence(&self) -> f64 {
        match self.is_essential() {
            true => f64::INFINITY,
            false => self.death - self.birth,
        }
    }

    /// Returns `true` when `death` is infinite.
    pub fn is_essential(&self) -> bool {
        self.death.is_infinite()
    }

    /// Closest point on the diagonal `birth == death`, i.e. the midpoint of the
    /// interval repeated in both coordinates. Essential pairs have no
    /// projection and yield `None`.
    pub fn diagonal_projection(&self) -> Option<(f64, f64)> {
        if self.is_essential() {
            return None;
        }
        let midpoint = (self.birth + self.death) / 2.0;
        Some((midpoint, midpoint))
    }
}
/// Disjoint-set forest whose roots additionally carry a birth value, used to
/// decide which of two merging components is the elder one.
struct UnionFind {
    /// Parent pointer of every element; a root points to itself.
    parent: Vec<usize>,
    /// Union-by-rank bookkeeping, meaningful for roots only.
    rank: Vec<usize>,
    /// Birth value per element; the value stored at a root describes its set.
    birth: Vec<f64>,
}

impl UnionFind {
    /// Creates `n` singleton sets and copies `birth_values` as their births.
    fn new(n: usize, birth_values: &[f64]) -> Self {
        let parent: Vec<usize> = (0..n).collect();
        UnionFind {
            parent,
            rank: vec![0; n],
            birth: birth_values.to_vec(),
        }
    }

    /// Returns the root of the set containing `x`, re-pointing every element
    /// on the walked path directly at that root.
    fn find(&mut self, x: usize) -> usize {
        // Pass 1: walk up to the root.
        let mut root = x;
        while self.parent[root] != root {
            root = self.parent[root];
        }
        // Pass 2: compress the path that was just walked.
        let mut node = x;
        while self.parent[node] != root {
            let next = self.parent[node];
            self.parent[node] = root;
            node = next;
        }
        root
    }

    /// Merges the sets containing `x` and `y`.
    ///
    /// Returns `None` when both already share a root. Otherwise returns the
    /// index of the root that was attached beneath the other one, together
    /// with `edge_weight`. The remaining root always ends up holding the
    /// smaller of the two birth values.
    fn union(&mut self, x: usize, y: usize, edge_weight: f64) -> Option<(usize, f64)> {
        let root_x = self.find(x);
        let root_y = self.find(y);
        if root_x == root_y {
            return None;
        }

        // The root with the smaller (or equal) birth is the elder component.
        let (elder, younger) = if self.birth[root_x] <= self.birth[root_y] {
            (root_x, root_y)
        } else {
            (root_y, root_x)
        };

        let absorbed = if self.rank[elder] < self.rank[younger] {
            // Keep the tree shallow: hang the elder root under the higher-rank
            // younger root, and move the elder birth onto the new root.
            self.parent[elder] = younger;
            self.birth[younger] = self.birth[younger].min(self.birth[elder]);
            elder
        } else {
            self.parent[younger] = elder;
            if self.rank[elder] == self.rank[younger] {
                self.rank[elder] += 1;
            }
            younger
        };
        Some((absorbed, edge_weight))
    }
}
/// Computes the 0-dimensional persistence pairs (connected components) of the
/// complete weighted graph described by `distance_matrix`.
///
/// Only the strict upper triangle (`i < j`) of the matrix is read. Every point
/// is born at `0.0`; edges are processed in increasing weight order and each
/// edge that joins two components contributes one finite pair whose death is
/// that edge weight. A final essential pair `(0.0, +inf)` is appended last.
///
/// # Errors
///
/// Returns `TransformError::InvalidInput` when the matrix has rows but is not
/// square. A matrix with zero rows yields an empty diagram.
pub fn persistence_diagram_0d(distance_matrix: &Array2<f64>) -> Result<Vec<PersistencePair>> {
    let n = distance_matrix.nrows();
    if n == 0 {
        return Ok(Vec::new());
    }
    if distance_matrix.ncols() != n {
        return Err(TransformError::InvalidInput(
            "distance_matrix must be square".to_string(),
        ));
    }

    let mut edges: Vec<(usize, usize, f64)> = Vec::with_capacity(n * (n - 1) / 2);
    for i in 0..n {
        for j in (i + 1)..n {
            edges.push((i, j, distance_matrix[[i, j]]));
        }
    }
    // `total_cmp` is a genuine total order, so NaN entries cannot make the
    // comparator inconsistent (which `sort_by` is allowed to panic on).
    edges.sort_by(|a, b| a.2.total_cmp(&b.2));

    let birth_times = vec![0.0f64; n];
    let mut uf = UnionFind::new(n, &birth_times);
    let mut pairs: Vec<PersistencePair> = Vec::with_capacity(n);
    for &(i, j, w) in &edges {
        if let Some((_absorbed, death)) = uf.union(i, j, w) {
            pairs.push(PersistencePair::new(0.0, death, 0));
            // n points need exactly n - 1 merges; later edges cannot merge anything.
            if pairs.len() + 1 == n {
                break;
            }
        }
    }
    // The component that survives every merge never dies.
    pairs.push(PersistencePair::new(0.0, f64::INFINITY, 0));
    Ok(pairs)
}
/// Bottleneck distance between the finite points of two persistence diagrams.
///
/// Essential pairs are dropped from both diagrams before matching, and points
/// of all dimensions are pooled together.
pub fn bottleneck_distance(diagram1: &[PersistencePair], diagram2: &[PersistencePair]) -> f64 {
    // Keep only (birth, death) of the non-essential pairs.
    let finite_points = |diagram: &[PersistencePair]| -> Vec<(f64, f64)> {
        diagram
            .iter()
            .filter_map(|pair| {
                if pair.is_essential() {
                    None
                } else {
                    Some((pair.birth, pair.death))
                }
            })
            .collect()
    };
    let first = finite_points(diagram1);
    let second = finite_points(diagram2);
    bottleneck_finite(&first, &second)
}
/// Bottleneck distance restricted to pairs of homological dimension `dim`.
///
/// Pairs of any other dimension, as well as essential pairs, are ignored in
/// both diagrams.
pub fn bottleneck_distance_dim(
    diagram1: &[PersistencePair],
    diagram2: &[PersistencePair],
    dim: usize,
) -> f64 {
    let select = |diagram: &[PersistencePair]| -> Vec<(f64, f64)> {
        let mut points = Vec::new();
        for pair in diagram {
            if pair.dimension != dim || pair.is_essential() {
                continue;
            }
            points.push((pair.birth, pair.death));
        }
        points
    };
    let first = select(diagram1);
    let second = select(diagram2);
    bottleneck_finite(&first, &second)
}
/// Bottleneck distance between two finite point sets of `(birth, death)` pairs.
///
/// The answer is always one of a finite set of candidate values: an L-infinity
/// distance between a point of `pts1` and a point of `pts2`, or the distance of
/// some point to the diagonal (`|death - birth| / 2`). The candidates are
/// sorted and binary-searched for the smallest threshold at which
/// `admits_perfect_matching` succeeds. Two empty inputs give `0.0`.
fn bottleneck_finite(pts1: &[(f64, f64)], pts2: &[(f64, f64)]) -> f64 {
    let diag_cost = |b: f64, d: f64| -> f64 { (d - b).abs() / 2.0 };
    let linf = |a: (f64, f64), b: (f64, f64)| -> f64 {
        (a.0 - b.0).abs().max((a.1 - b.1).abs())
    };

    let mut candidates: Vec<f64> =
        Vec::with_capacity(pts1.len() * pts2.len() + pts1.len() + pts2.len());
    for &p1 in pts1 {
        candidates.extend(pts2.iter().map(|&p2| linf(p1, p2)));
        candidates.push(diag_cost(p1.0, p1.1));
    }
    candidates.extend(pts2.iter().map(|&(b, d)| diag_cost(b, d)));
    if candidates.is_empty() {
        return 0.0;
    }

    // `total_cmp` keeps the comparator a total order even if a candidate is
    // NaN; an inconsistent comparator may make `sort_by` panic.
    candidates.sort_by(|a, b| a.total_cmp(b));
    // Collapse numerically identical thresholds so the search sees each once.
    candidates.dedup_by(|a, b| (*a - *b).abs() < 1e-12);

    // Feasibility is monotone in the threshold, so binary search applies.
    let mut lo = 0usize;
    let mut hi = candidates.len();
    while lo < hi {
        let mid = lo + (hi - lo) / 2;
        if admits_perfect_matching(pts1, pts2, candidates[mid], &linf, &diag_cost) {
            hi = mid;
        } else {
            lo = mid + 1;
        }
    }

    // If no candidate was feasible, fall back to the largest one.
    candidates
        .get(lo)
        .or(candidates.last())
        .copied()
        .unwrap_or(0.0)
}
/// Tests whether the two point sets can be perfectly matched using only edges
/// of cost at most `delta` (with a `1e-12` tolerance).
///
/// The bipartite graph has `n1 + n2` vertices per side:
/// * left `0..n1` are the points of `pts1`, left `n1..n1+n2` are diagonal
///   stand-ins for the points of `pts2`;
/// * right `0..n2` are the points of `pts2`, right `n2..n2+n1` are diagonal
///   stand-ins for the points of `pts1`.
///
/// A point may be matched to a point of the other set (cost `linf`) or to its
/// own diagonal stand-in (cost `diag_cost`); diagonal stand-ins match each
/// other unconditionally.
fn admits_perfect_matching(
    pts1: &[(f64, f64)],
    pts2: &[(f64, f64)],
    delta: f64,
    linf: &dyn Fn((f64, f64), (f64, f64)) -> f64,
    diag_cost: &dyn Fn(f64, f64) -> f64,
) -> bool {
    let n1 = pts1.len();
    let n2 = pts2.len();
    let total = n1 + n2;
    let threshold = delta + 1e-12;

    let mut adj: Vec<Vec<usize>> = Vec::with_capacity(total);

    // Left vertices that are real points of `pts1`.
    for (i, &p) in pts1.iter().enumerate() {
        let mut neighbours: Vec<usize> = pts2
            .iter()
            .enumerate()
            .filter(|&(_, &q)| linf(p, q) <= threshold)
            .map(|(j, _)| j)
            .collect();
        if diag_cost(p.0, p.1) <= threshold {
            neighbours.push(n2 + i);
        }
        adj.push(neighbours);
    }

    // Left vertices that are diagonal stand-ins for the points of `pts2`.
    for (j, &q) in pts2.iter().enumerate() {
        let mut neighbours: Vec<usize> = Vec::with_capacity(n1 + 1);
        if diag_cost(q.0, q.1) <= threshold {
            neighbours.push(j);
        }
        neighbours.extend(n2..n2 + n1);
        adj.push(neighbours);
    }

    let mut match_l = vec![usize::MAX; total];
    let mut match_r = vec![usize::MAX; total];
    let mut visited = vec![false; total];
    let mut matched = 0usize;
    for left in 0..total {
        // Each augmentation attempt starts with a clean visited set.
        visited.fill(false);
        if augment(left, &adj, &mut match_l, &mut match_r, &mut visited) {
            matched += 1;
        }
    }
    matched == total
}
/// Depth-first search for an augmenting path starting at left vertex `u`.
///
/// `adj[u]` lists the right vertices adjacent to `u`; `match_l` / `match_r`
/// hold the current partner of each left / right vertex (`usize::MAX` means
/// unmatched); `visited` marks right vertices already explored during this
/// search. On success the matching is updated along the path and `true` is
/// returned; otherwise the matching is left unchanged.
fn augment(
    u: usize,
    adj: &[Vec<usize>],
    match_l: &mut Vec<usize>,
    match_r: &mut Vec<usize>,
    visited: &mut Vec<bool>,
) -> bool {
    for &right in adj[u].iter() {
        // Mark the vertex and skip it if it had been seen before.
        if std::mem::replace(&mut visited[right], true) {
            continue;
        }
        let owner = match_r[right];
        // `right` is usable if it is free or its current owner can move elsewhere.
        let available = match owner {
            usize::MAX => true,
            _ => augment(owner, adj, match_l, match_r, visited),
        };
        if !available {
            continue;
        }
        match_l[u] = right;
        match_r[right] = u;
        return true;
    }
    false
}
/// `p`-Wasserstein distance between the finite points of two persistence
/// diagrams.
///
/// Essential pairs are removed from both diagrams before the distance is
/// computed. An exponent of `p == 0` is not meaningful and yields `f64::NAN`.
pub fn wasserstein_distance_pd(
    diagram1: &[PersistencePair],
    diagram2: &[PersistencePair],
    p: usize,
) -> f64 {
    if p == 0 {
        return f64::NAN;
    }
    let finite_points = |diagram: &[PersistencePair]| -> Vec<(f64, f64)> {
        let mut points = Vec::new();
        for pair in diagram {
            if !pair.is_essential() {
                points.push((pair.birth, pair.death));
            }
        }
        points
    };
    let first = finite_points(diagram1);
    let second = finite_points(diagram2);
    wasserstein_finite(&first, &second, p)
}
/// `p`-Wasserstein distance between two finite sets of `(birth, death)` points,
/// using the L-infinity ground metric.
///
/// The problem is posed as a square assignment of size `n1 + n2`:
/// * rows `0..n1` are the points of `pts1`, rows `n1..` are diagonal stand-ins
///   for the points of `pts2`;
/// * columns `0..n2` are the points of `pts2`, columns `n2..` are diagonal
///   stand-ins for the points of `pts1`.
///
/// A point may be matched to a point of the other set, or to its own diagonal
/// stand-in at cost `|death - birth| / 2`; stand-ins match each other for free.
/// All other combinations are forbidden (infinite cost).
///
/// The assignment is optimised on the costs raised to the power `p`, because
/// the Wasserstein objective is the sum of `distance^p`, not the sum of plain
/// distances. The result is `(sum of distance^p)^(1/p)`; two empty inputs give
/// `0.0`.
fn wasserstein_finite(pts1: &[(f64, f64)], pts2: &[(f64, f64)], p: usize) -> f64 {
    let n1 = pts1.len();
    let n2 = pts2.len();
    let n = n1 + n2;
    if n == 0 {
        return 0.0;
    }

    // Clamp instead of `p as i32`, which would wrap for very large exponents.
    let exponent = p.min(i32::MAX as usize) as i32;

    let diag_cost = |b: f64, d: f64| -> f64 { (d - b).abs() / 2.0 };
    let linf = |a: (f64, f64), b: (f64, f64)| -> f64 {
        (a.0 - b.0).abs().max((a.1 - b.1).abs())
    };

    // Un-powered ground cost of pairing row `r` with column `c`.
    let ground_cost = |r: usize, c: usize| -> f64 {
        match (r < n1, c < n2) {
            (true, true) => linf(pts1[r], pts2[c]),
            // A point of `pts1` may only use its own diagonal stand-in.
            (true, false) if c - n2 == r => diag_cost(pts1[r].0, pts1[r].1),
            // A point of `pts2` may only be used by its own diagonal stand-in.
            (false, true) if r - n1 == c => diag_cost(pts2[c].0, pts2[c].1),
            (false, false) => 0.0,
            _ => f64::INFINITY,
        }
    };

    // The matching must minimise the sum of cost^p; infinite stays infinite.
    let powered_cost = |r: usize, c: usize| -> f64 { ground_cost(r, c).powi(exponent) };
    let assignment = hungarian_assignment(n, &powered_cost);

    let total: f64 = assignment
        .iter()
        .enumerate()
        .map(|(r, &c)| {
            let base = ground_cost(r, c);
            // A forbidden pairing should not be selected; if it is, it
            // contributes nothing rather than making the sum infinite.
            if base.is_infinite() {
                0.0
            } else {
                base.powi(exponent)
            }
        })
        .sum::<f64>();
    total.powf(1.0 / p as f64)
}
/// Solves the square assignment problem of size `n`, returning for every row
/// the column it is assigned to so that the summed cost is minimal.
///
/// `cost_fn(row, col)` supplies the cost of a pairing. Non-finite costs
/// (infinities or NaN) are treated as "forbidden" and replaced by a large
/// finite sentinel (`1e18`), so they are selected only when no alternative
/// exists; genuine costs should therefore stay well below that magnitude.
///
/// The implementation is the O(n^3) Hungarian algorithm with row/column
/// potentials: rows are inserted one at a time and, for each, a shortest
/// augmenting path in reduced costs is found and applied. The returned vector
/// is always a permutation of `0..n`.
fn hungarian_assignment(n: usize, cost_fn: &dyn Fn(usize, usize) -> f64) -> Vec<usize> {
    const SENTINEL: f64 = 1e18;
    if n == 0 {
        return Vec::new();
    }

    let cost: Vec<Vec<f64>> = (0..n)
        .map(|i| {
            (0..n)
                .map(|j| {
                    let c = cost_fn(i, j);
                    if c.is_finite() {
                        c
                    } else {
                        SENTINEL
                    }
                })
                .collect()
        })
        .collect();

    // Rows and columns are 1-based below; index 0 is a virtual column that
    // temporarily holds the row currently being inserted.
    let mut row_potential = vec![0.0f64; n + 1];
    let mut col_potential = vec![0.0f64; n + 1];
    // col_owner[j] = row matched to column j, 0 meaning "free".
    let mut col_owner = vec![0usize; n + 1];
    // way[j] = previous column on the best alternating path reaching column j.
    let mut way = vec![0usize; n + 1];
    let mut min_slack = vec![f64::INFINITY; n + 1];
    let mut used = vec![false; n + 1];

    for row in 1..=n {
        col_owner[0] = row;
        min_slack.fill(f64::INFINITY);
        used.fill(false);
        let mut j0 = 0usize;

        let reached_free_column = loop {
            used[j0] = true;
            let i0 = col_owner[j0];
            let mut delta = f64::INFINITY;
            let mut j1 = 0usize;
            for j in 1..=n {
                if used[j] {
                    continue;
                }
                let reduced = cost[i0 - 1][j - 1] - row_potential[i0] - col_potential[j];
                if reduced < min_slack[j] {
                    min_slack[j] = reduced;
                    way[j] = j0;
                }
                if min_slack[j] < delta {
                    delta = min_slack[j];
                    j1 = j;
                }
            }
            if j1 == 0 {
                // No column could be selected; only possible if the arithmetic
                // degenerated. Leave this row for the fill-in step below.
                break false;
            }
            // Shift the potentials so the selected edge becomes tight.
            for j in 0..=n {
                if used[j] {
                    row_potential[col_owner[j]] += delta;
                    col_potential[j] -= delta;
                } else {
                    min_slack[j] -= delta;
                }
            }
            j0 = j1;
            if col_owner[j0] == 0 {
                break true;
            }
        };

        if reached_free_column {
            // Flip the matching along the augmenting path back to the virtual column.
            while j0 != 0 {
                let prev = way[j0];
                col_owner[j0] = col_owner[prev];
                j0 = prev;
            }
        }
    }

    let mut assignment = vec![usize::MAX; n];
    let mut col_taken = vec![false; n];
    for j in 1..=n {
        let owner = col_owner[j];
        if owner != 0 {
            assignment[owner - 1] = j - 1;
            col_taken[j - 1] = true;
        }
    }

    // Safety net: guarantee a full permutation even if a row was skipped above.
    let mut free_cols = (0..n).filter(|&j| !col_taken[j]);
    for slot in assignment.iter_mut().filter(|slot| **slot == usize::MAX) {
        if let Some(j) = free_cols.next() {
            *slot = j;
        }
    }
    assignment
}
/// Shannon entropy (base 2) of the normalised lifetimes of a diagram.
///
/// Only non-essential pairs with strictly positive persistence take part.
/// Each lifetime is divided by the sum of all lifetimes to obtain a
/// probability, and `-sum(p * log2(p))` is returned. Returns `0.0` when no
/// pair qualifies or the lifetimes do not sum to a positive value.
pub fn persistence_entropy(diagram: &[PersistencePair]) -> f64 {
    let mut lifetimes: Vec<f64> = Vec::new();
    for pair in diagram {
        if pair.is_essential() {
            continue;
        }
        let lifetime = pair.persistence();
        if lifetime > 0.0 {
            lifetimes.push(lifetime);
        }
    }
    if lifetimes.is_empty() {
        return 0.0;
    }

    let total: f64 = lifetimes.iter().sum();
    if total <= 0.0 {
        return 0.0;
    }

    lifetimes
        .iter()
        .map(|&lifetime| lifetime / total)
        // A zero probability contributes nothing (and avoids `0 * log2(0)`).
        .map(|prob| if prob > 0.0 { -prob * prob.log2() } else { 0.0 })
        .sum()
}
/// Counts, per homological dimension, the features alive at `threshold`.
///
/// A pair is alive when `birth <= threshold` and it is either essential or
/// has `death > threshold`. The result has one entry for every dimension from
/// `0` up to the largest dimension present in `diagram`; an empty diagram
/// yields an empty vector.
pub fn betti_numbers(diagram: &[PersistencePair], threshold: f64) -> Vec<usize> {
    let Some(max_dim) = diagram.iter().map(|pair| pair.dimension).max() else {
        return Vec::new();
    };

    let mut betti = vec![0usize; max_dim + 1];
    diagram
        .iter()
        .filter(|pair| pair.birth <= threshold)
        .filter(|pair| pair.is_essential() || pair.death > threshold)
        .for_each(|pair| betti[pair.dimension] += 1);
    betti
}
// Unit tests for the persistence-diagram utilities defined in this module.
#[cfg(test)]
mod tests {
use super::*;
use scirs2_core::ndarray::array;
// Symmetric 3x3 distance matrix with pairwise distances 1, 2 and 3.
fn triangle_dist() -> Array2<f64> {
array![[0.0, 1.0, 2.0], [1.0, 0.0, 3.0], [2.0, 3.0, 0.0]]
}
// Four points forming two tight pairs (0.5 and 0.3 apart) that are roughly
// 5 units away from each other.
fn four_points_dist() -> Array2<f64> {
array![
[0.0, 0.5, 5.0, 5.2],
[0.5, 0.0, 5.1, 5.0],
[5.0, 5.1, 0.0, 0.3],
[5.2, 5.0, 0.3, 0.0],
]
}
// Three points give two finite pairs plus one essential pair.
#[test]
fn test_persistence_diagram_0d_triangle() {
let dist = triangle_dist();
let pairs = persistence_diagram_0d(&dist).expect("should succeed");
assert_eq!(pairs.len(), 3);
let finite: Vec<_> = pairs.iter().filter(|p| !p.is_essential()).collect();
assert_eq!(finite.len(), 2);
let essential: Vec<_> = pairs.iter().filter(|p| p.is_essential()).collect();
assert_eq!(essential.len(), 1);
}
// The last merge joins the two clusters, so the largest finite death must
// be an inter-cluster distance (> 4.0).
#[test]
fn test_persistence_diagram_0d_two_clusters() {
let dist = four_points_dist();
let pairs = persistence_diagram_0d(&dist).expect("should succeed");
let finite: Vec<_> = pairs.iter().filter(|p| !p.is_essential()).collect();
assert_eq!(finite.len(), 3);
let max_death = finite
.iter()
.map(|p| p.death)
.fold(f64::NEG_INFINITY, f64::max);
assert!(max_death > 4.0, "large inter-cluster death expected, got {max_death}");
}
// A diagram is at bottleneck distance zero from itself.
#[test]
fn test_bottleneck_same_diagram() {
let pairs = persistence_diagram_0d(&triangle_dist()).expect("ok");
let d = bottleneck_distance(&pairs, &pairs);
assert!(d < 1e-10, "self-distance should be 0, got {d}");
}
// Two empty diagrams have bottleneck distance exactly 0.
#[test]
fn test_bottleneck_empty_diagrams() {
let d = bottleneck_distance(&[], &[]);
assert_eq!(d, 0.0);
}
// A diagram is at 2-Wasserstein distance (numerically) zero from itself.
#[test]
fn test_wasserstein_same_diagram() {
let pairs = persistence_diagram_0d(&triangle_dist()).expect("ok");
let d = wasserstein_distance_pd(&pairs, &pairs, 2);
assert!(d < 1e-6, "self-distance should be 0, got {d}");
}
// Two empty diagrams have Wasserstein distance exactly 0.
#[test]
fn test_wasserstein_empty() {
let d = wasserstein_distance_pd(&[], &[], 2);
assert_eq!(d, 0.0);
}
// The entropy of an empty diagram is defined as 0.
#[test]
fn test_persistence_entropy_empty() {
assert_eq!(persistence_entropy(&[]), 0.0);
}
// Entropy is never negative.
#[test]
fn test_persistence_entropy_positive() {
let pairs = persistence_diagram_0d(&four_points_dist()).expect("ok");
let h = persistence_entropy(&pairs);
assert!(h >= 0.0);
}
// At a threshold above every finite death only the essential component
// is still alive, so the 0-th Betti number is 1.
#[test]
fn test_betti_numbers_threshold() {
let pairs = persistence_diagram_0d(&four_points_dist()).expect("ok");
let b_start = betti_numbers(&pairs, 0.0);
assert!(!b_start.is_empty());
let b_end = betti_numbers(&pairs, 100.0);
assert_eq!(b_end[0], 1, "one essential component, got {:?}", b_end);
}
// An empty diagram produces no Betti numbers at all.
#[test]
fn test_betti_numbers_empty() {
let b = betti_numbers(&[], 1.0);
assert!(b.is_empty());
}
// Accessors of `PersistencePair` for a finite and an essential pair.
#[test]
fn test_persistence_pair_methods() {
let p = PersistencePair::new(1.0, 4.0, 0);
assert!((p.persistence() - 3.0).abs() < 1e-10);
assert!(!p.is_essential());
let proj = p.diagonal_projection().expect("has projection");
assert!((proj.0 - 2.5).abs() < 1e-10);
let ess = PersistencePair::new(0.0, f64::INFINITY, 0);
assert!(ess.is_essential());
assert_eq!(ess.persistence(), f64::INFINITY);
assert!(ess.diagonal_projection().is_none());
}
// A non-square matrix must be rejected with an error.
#[test]
fn test_non_square_distance_matrix() {
let dist = Array2::<f64>::zeros((3, 4));
assert!(persistence_diagram_0d(&dist).is_err());
}
}