use std::collections::HashMap;
use scirs2_core::ndarray::{Array1, Array2, ArrayView2};
use crate::error::{ClusteringError, Result};
/// Disjoint-set forest over the indices `0..n`, using union by rank and full
/// path compression.
struct UnionFind {
    /// `parent[x] == x` marks `x` as the root of its set.
    parent: Vec<usize>,
    /// Rank of each root (an upper bound on its tree height).
    rank: Vec<usize>,
}
impl UnionFind {
    /// Creates `n` singleton sets `{0}, {1}, …, {n - 1}`.
    fn new(n: usize) -> Self {
        let parent: Vec<usize> = (0..n).collect();
        let rank = vec![0; n];
        UnionFind { parent, rank }
    }

    /// Returns the root of the set containing `x`, re-pointing every node on
    /// the walked path directly at that root.
    fn find(&mut self, x: usize) -> usize {
        // First pass: locate the root.
        let mut root = x;
        while self.parent[root] != root {
            root = self.parent[root];
        }
        // Second pass: compress the path behind us.
        let mut node = x;
        while node != root {
            let next = self.parent[node];
            self.parent[node] = root;
            node = next;
        }
        root
    }

    /// Merges the sets containing `x` and `y`. On equal ranks the root of
    /// `x`'s set becomes the parent of the root of `y`'s set.
    fn union(&mut self, x: usize, y: usize) {
        let (rx, ry) = (self.find(x), self.find(y));
        if rx == ry {
            return;
        }
        match self.rank[rx].cmp(&self.rank[ry]) {
            std::cmp::Ordering::Less => self.parent[rx] = ry,
            std::cmp::Ordering::Greater => self.parent[ry] = rx,
            std::cmp::Ordering::Equal => {
                self.parent[ry] = rx;
                self.rank[rx] += 1;
            }
        }
    }
}
/// One interval of a 0-dimensional persistence diagram over density values.
#[derive(Debug, Clone)]
pub struct PersistenceBar {
    /// Filtration value at which the component appeared.
    pub birth: f64,
    /// Filtration value at which the component disappeared; `f64::MAX` marks a
    /// bar that never dies.
    pub death: f64,
    /// Index of a data point associated with the component (which point is
    /// chosen depends on the function that produced the bar).
    pub representative: usize,
}
impl PersistenceBar {
    /// Lifetime of the bar: `|death - birth|`, or `f64::MAX` for a bar whose
    /// `death` is the `f64::MAX` sentinel.
    pub fn persistence(&self) -> f64 {
        if self.death != f64::MAX {
            return (self.death - self.birth).abs();
        }
        f64::MAX
    }
}
/// Computes the 0-dimensional superlevel-set persistence diagram of a k-NN
/// density estimate.
///
/// Every row gets the density `1 / r^d`, where `r` is the Euclidean distance to
/// its `k`-th nearest other row (clamped below at `1e-15`) and `d` is the
/// number of columns. Rows are inserted in decreasing density order. Each row
/// is connected to those of its own `k` nearest neighbours that were already
/// inserted. A row with no inserted neighbour starts a new component.
///
/// When a row links two components, the one whose peak density is lower dies
/// at the density of the linking row (elder rule).
///
/// `k_neighbors` is clamped to `1..=n - 1`.
///
/// # Returns
/// * One bar per component death: `birth` is the peak density of the dying
///   component and `death` is the density of the linking row.
/// * One bar with `death == f64::MAX` for every component that never dies.
///   There is more than one such bar when the k-NN graph is disconnected.
///
/// `representative` is always the row index of the component's peak. A row
/// that merely joins an existing component produces no bar. Bars are sorted by
/// decreasing [`PersistenceBar::persistence`]. The never-dying bars come first,
/// ordered from the densest peak down.
///
/// # Errors
/// `ClusteringError::InvalidInput` when `data` has no rows or no columns, or
/// has only one row (there is no neighbour to measure a k-NN radius against).
pub fn density_persistence(
    data: ArrayView2<f64>,
    k_neighbors: usize,
) -> Result<Vec<PersistenceBar>> {
    let n = data.nrows();
    let d = data.ncols();
    if n == 0 || d == 0 {
        return Err(ClusteringError::InvalidInput(
            "persistence: data must be non-empty".into(),
        ));
    }
    if n < 2 {
        // With a single row there is no k-th neighbour; `dists[k - 1]` below
        // would index an empty vector and panic.
        return Err(ClusteringError::InvalidInput(
            "persistence: at least two points are required".into(),
        ));
    }
    let k = k_neighbors.min(n - 1).max(1);

    // Brute-force k-NN: sorted distances to every other row.
    let mut knn_dist = vec![0.0f64; n];
    let mut knn_graph: Vec<Vec<usize>> = vec![Vec::new(); n];
    for i in 0..n {
        let mut dists: Vec<(usize, f64)> = (0..n)
            .filter(|&j| j != i)
            .map(|j| {
                let d2: f64 = data
                    .row(i)
                    .iter()
                    .zip(data.row(j).iter())
                    .map(|(&a, &b)| (a - b) * (a - b))
                    .sum();
                (j, d2.sqrt())
            })
            .collect();
        dists.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
        // Keep the radius positive so coincident rows do not divide by zero.
        knn_dist[i] = dists[k - 1].1.max(1e-15);
        knn_graph[i].extend(dists.iter().take(k).map(|&(j, _)| j));
    }
    let density: Vec<f64> = knn_dist
        .iter()
        .map(|&r| 1.0 / r.powi(d as i32))
        .collect();

    // Superlevel-set filtration: densest rows first.
    let mut order: Vec<usize> = (0..n).collect();
    order.sort_by(|&a, &b| {
        density[b]
            .partial_cmp(&density[a])
            .unwrap_or(std::cmp::Ordering::Equal)
    });

    let mut uf = UnionFind::new(n);
    let mut inserted = vec![false; n];
    // `peak[r]` is the densest row of the component whose union-find root is
    // `r`; entries for non-root indices are never read.
    let mut peak: Vec<usize> = (0..n).collect();
    let mut bars: Vec<PersistenceBar> = Vec::new();
    for &i in &order {
        inserted[i] = true;
        for &j in &knn_graph[i] {
            if !inserted[j] {
                continue;
            }
            let ri = uf.find(i);
            let rj = uf.find(j);
            if ri == rj {
                continue;
            }
            let (pi, pj) = (peak[ri], peak[rj]);
            // Elder rule: the component with the lower peak dies. Ties favour
            // `j`'s component. Every inserted row is at least as dense as `i`,
            // so the fresh singleton `{i}` always dies on its first merge.
            let (survivor, dying) = if density[pi] > density[pj] {
                (pi, pj)
            } else {
                (pj, pi)
            };
            // `{i}` joining an existing component is not a topological feature;
            // recording it would only add a zero-length bar.
            if dying != i {
                bars.push(PersistenceBar {
                    birth: density[dying],
                    death: density[i],
                    representative: dying,
                });
            }
            uf.union(ri, rj);
            let root = uf.find(ri);
            peak[root] = survivor;
        }
    }

    // Every component still alive at the end never dies. This includes each
    // piece of a disconnected k-NN graph, not just the one holding the global peak.
    for &p in &order {
        if peak[uf.find(p)] == p {
            bars.push(PersistenceBar {
                birth: density[p],
                death: f64::MAX,
                representative: p,
            });
        }
    }

    // Stable sort: never-dying bars (all `f64::MAX`) keep their density order.
    bars.sort_by(|a, b| {
        b.persistence()
            .partial_cmp(&a.persistence())
            .unwrap_or(std::cmp::Ordering::Equal)
    });
    Ok(bars)
}
/// Parameters for [`tomato`].
#[derive(Debug, Clone)]
pub struct TomaToConfig {
    /// Neighbourhood size for the k-NN density estimate and graph (`tomato`
    /// clamps it to `1..=n - 1`).
    pub k_neighbors: usize,
    /// Merge threshold `tau`. `tomato` uses it only when `auto_threshold` is
    /// false and the value is positive.
    pub persistence_threshold: f64,
    /// When true, `tomato` chooses `tau` with [`gap_threshold`].
    pub auto_threshold: bool,
}
impl Default for TomaToConfig {
    /// Five neighbours with automatic threshold selection.
    fn default() -> Self {
        TomaToConfig {
            auto_threshold: true,
            k_neighbors: 5,
            persistence_threshold: 0.0,
        }
    }
}
impl TomaToConfig {
    /// Configuration with a fixed threshold `tau` and automatic selection
    /// disabled.
    pub fn with_threshold(k_neighbors: usize, tau: f64) -> Self {
        TomaToConfig {
            auto_threshold: false,
            persistence_threshold: tau,
            k_neighbors,
        }
    }
}
/// Output of [`tomato`].
#[derive(Debug, Clone)]
pub struct TomaToResult {
    /// Cluster label of every input row, in `0..n_clusters`. Labels are
    /// assigned in order of first appearance when scanning rows by index.
    pub labels: Vec<usize>,
    /// Number of distinct labels.
    pub n_clusters: usize,
    /// Persistence threshold `tau` that was actually applied.
    pub threshold: f64,
    /// Diagram returned by [`density_persistence`] for the same data and `k`.
    pub persistence_diagram: Vec<PersistenceBar>,
}

/// ToMATo clustering (Topological Mode Analysis Tool).
///
/// 1. Estimates density as `1 / r^d`, with `r` the distance to the k-th
///    nearest other row (clamped below at `1e-15`).
/// 2. Assigns each row to a mode by hill-climbing the k-NN graph. Each row
///    steps to its densest strictly denser neighbour until none exists.
/// 3. Sweeps rows in decreasing density. When a row links two clusters, they
///    are merged if the lower of their two peak densities exceeds the linking
///    row's density by less than `tau`. In other words, the younger cluster is
///    merged when its persistence is below `tau`; the merged cluster keeps the
///    higher peak.
///
/// `tau` is `config.persistence_threshold` when `config.auto_threshold` is
/// false and the threshold is positive. Otherwise [`gap_threshold`] picks it
/// from the persistence diagram.
///
/// # Errors
/// `ClusteringError::InvalidInput` for data with no rows or no columns, or only
/// one row. Errors from [`density_persistence`] are propagated.
pub fn tomato(data: ArrayView2<f64>, config: &TomaToConfig) -> Result<TomaToResult> {
    let n = data.nrows();
    let d = data.ncols();
    if n == 0 || d == 0 {
        return Err(ClusteringError::InvalidInput(
            "tomato: data must be non-empty".into(),
        ));
    }
    if n < 2 {
        // No neighbour exists for a single row; `dists[k - 1]` would panic.
        return Err(ClusteringError::InvalidInput(
            "tomato: at least two points are required".into(),
        ));
    }
    let k = config.k_neighbors.min(n - 1).max(1);

    // Brute-force k-NN graph and density estimate.
    let mut knn_dist = vec![0.0f64; n];
    let mut knn_graph: Vec<Vec<usize>> = vec![Vec::new(); n];
    for i in 0..n {
        let mut dists: Vec<(usize, f64)> = (0..n)
            .filter(|&j| j != i)
            .map(|j| {
                let d2: f64 = data
                    .row(i)
                    .iter()
                    .zip(data.row(j).iter())
                    .map(|(&a, &b)| (a - b) * (a - b))
                    .sum();
                (j, d2.sqrt())
            })
            .collect();
        dists.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
        knn_dist[i] = dists[k - 1].1.max(1e-15);
        knn_graph[i].extend(dists.iter().take(k).map(|&(j, _)| j));
    }
    let density: Vec<f64> = knn_dist
        .iter()
        .map(|&r| 1.0 / r.powi(d as i32))
        .collect();

    // Hill-climbing: only strictly denser neighbours qualify, so parent chains
    // strictly increase in density and cannot cycle.
    let mut gradient_parent: Vec<usize> = (0..n).collect();
    for i in 0..n {
        let mut best = i;
        let mut best_dens = density[i];
        for &j in &knn_graph[i] {
            if density[j] > best_dens {
                best = j;
                best_dens = density[j];
            }
        }
        gradient_parent[i] = best;
    }
    let mode_of: Vec<usize> = (0..n)
        .map(|i| follow_gradient(&gradient_parent, i))
        .collect();

    let bars = density_persistence(data, k)?;
    let tau = if config.auto_threshold || config.persistence_threshold <= 0.0 {
        gap_threshold(&bars)
    } else {
        config.persistence_threshold
    };

    // Persistence-based merging of mode clusters.
    let mut order: Vec<usize> = (0..n).collect();
    order.sort_by(|&a, &b| {
        density[b]
            .partial_cmp(&density[a])
            .unwrap_or(std::cmp::Ordering::Equal)
    });
    let mut mode_uf = UnionFind::new(n);
    // `peak[r]` is the densest mode in the cluster rooted at `r` in `mode_uf`.
    let mut peak: Vec<usize> = (0..n).collect();
    let mut swept = vec![false; n];
    for &i in &order {
        swept[i] = true;
        for &j in &knn_graph[i] {
            if !swept[j] {
                continue;
            }
            let ci = mode_uf.find(mode_of[i]);
            let cj = mode_uf.find(mode_of[j]);
            if ci == cj {
                continue;
            }
            let (pi, pj) = (peak[ci], peak[cj]);
            // Persistence of the cluster that would die at this link.
            let younger_persistence = density[pi].min(density[pj]) - density[i];
            if younger_persistence < tau {
                let survivor = if density[pi] >= density[pj] { pi } else { pj };
                mode_uf.union(ci, cj);
                let root = mode_uf.find(ci);
                peak[root] = survivor;
            }
        }
    }

    // Compact labels in order of first appearance by row index.
    let mut root_to_label: HashMap<usize, usize> = HashMap::new();
    let mut labels = Vec::with_capacity(n);
    for &mode in &mode_of {
        let root = mode_uf.find(mode);
        let next = root_to_label.len();
        labels.push(*root_to_label.entry(root).or_insert(next));
    }
    let n_clusters = root_to_label.len();

    Ok(TomaToResult {
        labels,
        n_clusters,
        threshold: tau,
        persistence_diagram: bars,
    })
}
/// Follows `parent` pointers from `start` until reaching a node that is its own
/// parent, and returns that node.
///
/// At most `parent.len()` steps are taken, so a cyclic `parent` cannot loop
/// forever. In that case the node reached after the last step is returned.
fn follow_gradient(parent: &[usize], start: usize) -> usize {
    let mut node = start;
    let mut steps_left = parent.len();
    while steps_left > 0 {
        let up = parent[node];
        if up == node {
            break;
        }
        node = up;
        steps_left -= 1;
    }
    node
}
/// Labels the representatives of the bars that survive `threshold`.
///
/// A bar survives when its persistence is at least `threshold` or it never
/// dies (`death == f64::MAX`). The k-th surviving bar, counted in slice order,
/// gives label `k` to its `representative` point. Representatives outside
/// `0..n` are ignored, and a later bar overwrites an earlier one sharing the
/// same representative.
///
/// NOTE(review): every point that is not a surviving representative is labelled
/// `0`. Bars carry no membership information, so this is not a full partition
/// of the `n` points.
pub fn flat_clustering_from_persistence(
    bars: &[PersistenceBar],
    n: usize,
    threshold: f64,
) -> Vec<usize> {
    let mut labels = vec![0usize; n];
    let kept = bars
        .iter()
        .filter(|bar| bar.persistence() >= threshold || bar.death == f64::MAX);
    for (label, bar) in kept.enumerate() {
        if let Some(slot) = labels.get_mut(bar.representative) {
            *slot = label;
        }
    }
    labels
}
pub fn gap_threshold(bars: &[PersistenceBar]) -> f64 {
let mut finite_perss: Vec<f64> = bars
.iter()
.filter(|b| b.death < f64::MAX)
.map(|b| b.persistence())
.collect();
finite_perss.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
if finite_perss.is_empty() {
return 0.0;
}
if finite_perss.len() == 1 {
return finite_perss[0] * 0.5;
}
let mut best_gap = 0.0f64;
let mut best_tau = finite_perss[0] * 0.5;
for i in 1..finite_perss.len() {
let gap = finite_perss[i] - finite_perss[i - 1];
if gap > best_gap {
best_gap = gap;
best_tau = (finite_perss[i] + finite_perss[i - 1]) / 2.0;
}
}
best_tau
}
/// Runs [`tomato`] with a threshold chosen to target `n_clusters` clusters.
///
/// Never-dying bars of the [`density_persistence`] diagram always survive. So
/// `keep = n_clusters - (number of never-dying bars)` finite bars are kept, and
/// `tau` is placed halfway between the `keep`-th and `(keep + 1)`-th largest
/// finite persistences.
/// * If `keep` is 0, every finite bar is merged (`tau = f64::MAX`).
/// * If there are at most `keep` finite bars, nothing of positive persistence
///   is merged.
///
/// NOTE(review): the final label count equals `n_clusters` only insofar as
/// `tomato`'s merge step agrees with the diagram. It can never fall below the
/// number of never-dying bars.
///
/// # Errors
/// `ClusteringError::InvalidInput` if `n_clusters == 0`. Errors from
/// [`density_persistence`] and [`tomato`] are propagated.
pub fn tomato_n_clusters(
    data: ArrayView2<f64>,
    n_clusters: usize,
    k_neighbors: usize,
) -> Result<Vec<usize>> {
    if n_clusters == 0 {
        return Err(ClusteringError::InvalidInput(
            "n_clusters must be > 0".into(),
        ));
    }
    let bars = density_persistence(data, k_neighbors)?;
    let n_essential = bars.iter().filter(|b| b.death == f64::MAX).count();
    let mut perss: Vec<f64> = bars
        .iter()
        .filter(|b| b.death < f64::MAX)
        .map(|b| b.persistence())
        .collect();
    // Largest persistence first.
    perss.sort_by(|a, b| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal));

    let keep = n_clusters.saturating_sub(n_essential);
    let tau = if keep == 0 {
        f64::MAX
    } else if keep < perss.len() {
        // Midpoint between the weakest kept bar and the strongest merged one.
        // Halving each term first keeps huge densities from overflowing.
        perss[keep - 1] * 0.5 + perss[keep] * 0.5
    } else {
        0.0
    };
    // `tomato` treats a non-positive threshold as "choose automatically",
    // which would override the requested count. The smallest positive normal
    // float merges only bars of (essentially) zero persistence.
    let tau = tau.max(f64::MIN_POSITIVE);

    let config = TomaToConfig {
        k_neighbors,
        persistence_threshold: tau,
        auto_threshold: false,
    };
    let result = tomato(data, &config)?;
    Ok(result.labels)
}
#[derive(Debug, Clone)]
pub struct ClusterStats {
pub label: usize,
pub size: usize,
pub centroid: Array1<f64>,
pub avg_intra_dist: f64,
}
pub fn cluster_stats(
data: ArrayView2<f64>,
labels: &[usize],
) -> Result<Vec<ClusterStats>> {
let n = data.nrows();
let d = data.ncols();
if labels.len() != n {
return Err(ClusteringError::InvalidInput(
"cluster_stats: labels length must match data rows".into(),
));
}
let n_clusters = labels.iter().copied().max().unwrap_or(0) + 1;
let mut members: Vec<Vec<usize>> = vec![Vec::new(); n_clusters];
for (i, &l) in labels.iter().enumerate() {
members[l].push(i);
}
let mut stats = Vec::with_capacity(n_clusters);
for (label, pts) in members.iter().enumerate() {
if pts.is_empty() {
continue;
}
let m = pts.len();
let mut centroid = Array1::zeros(d);
for &i in pts {
for j in 0..d {
centroid[j] += data[[i, j]];
}
}
for j in 0..d {
centroid[j] /= m as f64;
}
let mut sum_dist = 0.0f64;
let mut count = 0usize;
for ii in 0..m {
for jj in (ii + 1)..m {
let d_val: f64 = data
.row(pts[ii])
.iter()
.zip(data.row(pts[jj]).iter())
.map(|(&a, &b)| (a - b) * (a - b))
.sum::<f64>()
.sqrt();
sum_dist += d_val;
count += 1;
}
}
let avg_intra_dist = if count > 0 { sum_dist / count as f64 } else { 0.0 };
stats.push(ClusterStats {
label,
size: m,
centroid,
avg_intra_dist,
});
}
Ok(stats)
}
#[cfg(test)]
mod tests {
    // Smoke tests on a tiny dataset of two well-separated 2-D blobs.
    use super::*;
    use scirs2_core::ndarray::Array2;

    /// Eight 2-D points: four near the origin, then four near (6, 6).
    fn two_blobs() -> Array2<f64> {
        let near_origin = [0.0, 0.0, 0.2, 0.1, -0.1, 0.1, 0.05, -0.05];
        let near_six = [6.0, 6.0, 6.2, 5.9, 5.9, 6.1, 6.1, 6.0];
        let flat: Vec<f64> = near_origin.iter().chain(near_six.iter()).copied().collect();
        Array2::from_shape_vec((8, 2), flat).expect("shape ok")
    }

    #[test]
    fn test_density_persistence_two_blobs() {
        let points = two_blobs();
        let diagram = density_persistence(points.view(), 3).expect("ok");
        assert!(!diagram.is_empty(), "expected at least one persistence bar");
        let has_infinite = diagram.iter().any(|bar| bar.death == f64::MAX);
        assert!(has_infinite, "no infinite bar found");
    }

    #[test]
    fn test_tomato_two_blobs() {
        let points = two_blobs();
        let cfg = TomaToConfig {
            auto_threshold: true,
            k_neighbors: 3,
            ..TomaToConfig::default()
        };
        let outcome = tomato(points.view(), &cfg).expect("tomato ok");
        assert_eq!(outcome.labels.len(), 8);
        assert!(
            outcome.n_clusters >= 1,
            "expected at least 1 cluster, got {}",
            outcome.n_clusters
        );
    }

    #[test]
    fn test_tomato_n_clusters() {
        let points = two_blobs();
        let assigned = tomato_n_clusters(points.view(), 2, 3).expect("ok");
        assert_eq!(assigned.len(), 8);
    }

    #[test]
    fn test_gap_threshold_empty() {
        let no_bars: Vec<PersistenceBar> = Vec::new();
        assert_eq!(gap_threshold(&no_bars), 0.0);
    }

    #[test]
    fn test_cluster_stats() {
        let points = two_blobs();
        // First four rows -> cluster 0, last four -> cluster 1.
        let assignment: Vec<usize> = (0..8).map(|row| row / 4).collect();
        let summary = cluster_stats(points.view(), &assignment).expect("ok");
        assert_eq!(summary.len(), 2);
        for entry in &summary {
            assert_eq!(entry.size, 4);
        }
    }

    #[test]
    fn test_flat_clustering_from_persistence() {
        let diagram = [
            PersistenceBar { birth: 10.0, death: f64::MAX, representative: 0 },
            PersistenceBar { birth: 8.0, death: 9.0, representative: 5 },
            PersistenceBar { birth: 3.0, death: 3.5, representative: 2 },
        ];
        let flat = flat_clustering_from_persistence(&diagram, 8, 0.8);
        assert_eq!(flat.len(), 8);
    }
}