use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
use crate::error::{ClusteringError, Result};
/// A distance function between two points given as one-dimensional `f64` views.
///
/// Implementors must be `Send + Sync` (required by the supertrait bounds), and
/// the trait is used through `&dyn Metric` by the functions in this file.
pub trait Metric: Send + Sync {
/// Returns the distance between points `a` and `b`.
fn distance(&self, a: ArrayView1<f64>, b: ArrayView1<f64>) -> f64;
}
/// Euclidean (L2) distance: square root of the summed squared coordinate differences.
#[derive(Debug, Clone, Copy)]
pub struct EuclideanMetric;

impl Metric for EuclideanMetric {
    /// Returns `sqrt(sum_k (a[k] - b[k])^2)`.
    ///
    /// Coordinates are paired positionally with `zip`, so if the two views
    /// differ in length the trailing coordinates of the longer one are ignored.
    fn distance(&self, a: ArrayView1<f64>, b: ArrayView1<f64>) -> f64 {
        let squared_sum: f64 = a
            .iter()
            .zip(b.iter())
            .map(|(x, y)| {
                let diff = x - y;
                diff * diff
            })
            .sum();
        squared_sum.sqrt()
    }
}
/// Manhattan (L1) distance: sum of the absolute coordinate differences.
#[derive(Debug, Clone, Copy)]
pub struct ManhattanMetric;

impl Metric for ManhattanMetric {
    /// Returns `sum_k |a[k] - b[k]|`.
    ///
    /// Coordinates are paired positionally with `zip`, so if the two views
    /// differ in length the trailing coordinates of the longer one are ignored.
    fn distance(&self, a: ArrayView1<f64>, b: ArrayView1<f64>) -> f64 {
        let coordinate_pairs = a.iter().zip(b.iter());
        coordinate_pairs
            .map(|pair| (pair.0 - pair.1).abs())
            .sum::<f64>()
    }
}
/// Cosine distance: `1 - cos(angle between a and b)`, clamped to `[0, 2]`.
#[derive(Debug, Clone, Copy)]
pub struct CosineMetric;

impl Metric for CosineMetric {
    /// Returns `1 - (a . b) / (|a| * |b|)`, clamped to the range `[0.0, 2.0]`.
    ///
    /// When the product of the two norms is below `1e-15` (for example when
    /// either vector is all zeros) the quotient is not computed and `1.0` is
    /// returned instead.
    fn distance(&self, a: ArrayView1<f64>, b: ArrayView1<f64>) -> f64 {
        /// Euclidean norm of a single vector.
        fn l2_norm(v: &ArrayView1<f64>) -> f64 {
            v.iter().map(|&c| c * c).sum::<f64>().sqrt()
        }

        let dot = a
            .iter()
            .zip(b.iter())
            .map(|(&x, &y)| x * y)
            .sum::<f64>();
        let denom = l2_norm(&a) * l2_norm(&b);

        // Guard against dividing by a (near-)zero norm product.
        if denom < 1e-15 {
            return 1.0;
        }
        (1.0 - dot / denom).clamp(0.0, 2.0)
    }
}
/// Disjoint-set forest (union by rank + path compression) that additionally
/// tracks a "birth time" per component, following the elder rule: when two
/// components merge, the one with the smaller birth time survives.
struct UnionFind {
    /// Parent pointer of every element; a root is its own parent.
    parent: Vec<usize>,
    /// Rank of every root, used to keep the trees shallow.
    rank: Vec<usize>,
    /// For a current root: the birth time of its component.
    /// For a former root: the birth time of the component that died when that
    /// root was merged away.
    birth: Vec<f64>,
}

impl UnionFind {
    /// Creates `n` singleton sets, each with birth time `0.0`.
    fn new(n: usize) -> Self {
        Self {
            parent: (0..n).collect(),
            rank: vec![0; n],
            birth: vec![0.0; n],
        }
    }

    /// Returns the root of the set containing `x`, compressing the path from
    /// `x` to that root. Implemented iteratively so the call depth does not
    /// depend on the tree height.
    ///
    /// Panics if `x` is out of range.
    fn find(&mut self, x: usize) -> usize {
        let mut root = x;
        while self.parent[root] != root {
            root = self.parent[root];
        }
        // Second pass: re-point every node on the path directly at the root.
        let mut current = x;
        while self.parent[current] != root {
            let next = self.parent[current];
            self.parent[current] = root;
            current = next;
        }
        root
    }

    /// Merges the sets containing `x` and `y`.
    ///
    /// Returns `None` when both already share a root. Otherwise returns
    /// `Some((merged, survivor, time))` where `survivor` is the root of the
    /// combined set, `merged` is the root that stopped being a root, and
    /// `time` is passed through unchanged. After the call,
    /// `birth[survivor]` is the smaller (older) of the two birth times and
    /// `birth[merged]` is the larger one, i.e. the birth of the component that
    /// just died. On equal birth times the component of `x` counts as older.
    fn union(&mut self, x: usize, y: usize, time: f64) -> Option<(usize, usize, f64)> {
        let rx = self.find(x);
        let ry = self.find(y);
        if rx == ry {
            return None;
        }

        // Elder rule: decide which component survives by birth time.
        let (older, younger) = if self.birth[rx] <= self.birth[ry] {
            (rx, ry)
        } else {
            (ry, rx)
        };
        let older_birth = self.birth[older];
        let younger_birth = self.birth[younger];

        // Union by rank decides which index physically becomes the root; this
        // is independent of which component is logically the survivor.
        let (root, child) = if self.rank[older] >= self.rank[younger] {
            (older, younger)
        } else {
            (younger, older)
        };
        self.parent[child] = root;
        if self.rank[root] == self.rank[child] {
            self.rank[root] += 1;
        }

        // Keep the survivor's birth on the actual root and record the dying
        // component's birth on the node that was merged away, so callers that
        // read `birth[merged]` see the birth of the component that died.
        self.birth[root] = older_birth;
        self.birth[child] = younger_birth;

        Some((child, root, time))
    }

    /// Overwrites the birth time stored for element `x`.
    fn set_birth(&mut self, x: usize, t: f64) {
        self.birth[x] = t;
    }
}
/// Computes the 0-dimensional persistence barcode of the point cloud `data`
/// (one point per row) under `metric`.
///
/// Every point is born at time `0.0`. All pairwise edges are processed in
/// order of increasing distance; each edge that joins two distinct components
/// kills one of them and contributes a `(birth, death)` bar with
/// `death` equal to that edge's distance. Components still alive at the end
/// contribute `(0.0, f64::INFINITY)` bars. Because every pair of points is
/// connected by an edge, a non-empty input ends with a single such bar.
///
/// # Returns
///
/// The bars sorted by decreasing persistence (`death - birth`), with infinite
/// bars first. An empty input yields an empty barcode.
///
/// # Notes
///
/// Distances are ordered with [`f64::total_cmp`], which is a total order even
/// when the metric yields NaN. A comparator built from `partial_cmp` is not
/// total in that case, and the standard sort is allowed to panic on one.
pub fn persistent_homology_0d(
    data: ArrayView2<f64>,
    metric: &dyn Metric,
) -> Result<Vec<(f64, f64)>> {
    let n = data.shape()[0];
    if n == 0 {
        return Ok(Vec::new());
    }

    // Build the complete graph on the points, weighted by distance.
    let mut edges: Vec<(f64, usize, usize)> = Vec::with_capacity(n * (n - 1) / 2);
    for i in 0..n {
        for j in (i + 1)..n {
            edges.push((metric.distance(data.row(i), data.row(j)), i, j));
        }
    }
    edges.sort_by(|a, b| a.0.total_cmp(&b.0));

    // `UnionFind::new` already starts every element with birth time 0.0.
    let mut uf = UnionFind::new(n);
    let mut barcode: Vec<(f64, f64)> = Vec::with_capacity(n);
    for &(dist, u, v) in &edges {
        if let Some((merged, _survivor, death_time)) = uf.union(u, v, dist) {
            barcode.push((uf.birth[merged], death_time));
        }
    }

    // One immortal bar per remaining root.
    let n_components = (0..n).filter(|&i| uf.find(i) == i).count();
    barcode.extend(std::iter::repeat((0.0, f64::INFINITY)).take(n_components));

    // Sort key: persistence, with infinite bars pinned to the largest finite value.
    let persistence = |bar: &(f64, f64)| {
        if bar.1.is_infinite() {
            f64::MAX
        } else {
            bar.1 - bar.0
        }
    };
    barcode.sort_by(|a, b| persistence(b).total_cmp(&persistence(a)));

    Ok(barcode)
}
/// Produces `n_points` cluster labels whose number of distinct values is
/// derived from `barcode` at the given `threshold`.
///
/// The cluster count is the number of bars that are alive at `threshold`
/// (born at or before it and either infinite or dying strictly after it),
/// with a minimum of one. Labels are then handed out round-robin as
/// `i % n_clusters`; the barcode carries no point membership, so the labels
/// reflect only the cluster count, not which points belong together.
///
/// Returns an empty vector when `n_points` is zero.
pub fn single_linkage_from_barcode(
    barcode: &[(f64, f64)],
    threshold: f64,
    n_points: usize,
) -> Vec<usize> {
    if n_points == 0 {
        return Vec::new();
    }

    let mut alive = 0_usize;
    for &(birth, death) in barcode {
        let already_born = birth <= threshold;
        let not_yet_dead = death > threshold || death.is_infinite();
        if already_born && not_yet_dead {
            alive += 1;
        }
    }
    let n_clusters = alive.max(1);

    let mut labels = Vec::with_capacity(n_points);
    for point in 0..n_points {
        labels.push(point % n_clusters);
    }
    labels
}
/// Single-linkage clustering of the rows of `data` at a fixed distance
/// `threshold`.
///
/// Two points end up in the same cluster exactly when they are connected by a
/// chain of points whose consecutive distances (under `metric`) are all
/// `<= threshold`. A NaN distance never satisfies that comparison and so never
/// links two points.
///
/// # Returns
///
/// One label per row. Labels are consecutive integers starting at `0`,
/// numbered in order of the first row of each cluster. An empty input yields
/// an empty vector.
///
/// # Notes
///
/// The resulting partition does not depend on the order in which qualifying
/// pairs are merged, so pairs are merged as they are visited instead of being
/// buffered and sorted first. This avoids an O(n^2) edge allocation.
pub fn single_linkage_from_data(
    data: ArrayView2<f64>,
    metric: &dyn Metric,
    threshold: f64,
) -> Result<Vec<usize>> {
    let n = data.shape()[0];
    if n == 0 {
        return Ok(Vec::new());
    }

    let mut uf = UnionFind::new(n);
    for i in 0..n {
        for j in (i + 1)..n {
            if metric.distance(data.row(i), data.row(j)) <= threshold {
                // The merge time is irrelevant here; only connectivity matters.
                uf.union(i, j, 0.0);
            }
        }
    }

    // Relabel roots to 0, 1, 2, ... in order of first appearance.
    let mut root_to_label: std::collections::HashMap<usize, usize> =
        std::collections::HashMap::new();
    let mut labels = Vec::with_capacity(n);
    for i in 0..n {
        let root = uf.find(i);
        let next_label = root_to_label.len();
        labels.push(*root_to_label.entry(root).or_insert(next_label));
    }
    Ok(labels)
}
/// One node of a Mapper graph: a cluster of points found inside a single
/// cover patch.
#[derive(Debug, Clone)]
pub struct MapperNode {
/// Index of the cover interval (patch) this node was clustered from, as set
/// by `mapper_graph`.
pub patch_index: usize,
/// Row indices (into the original data) of the points in this node, as set
/// by `mapper_graph`.
pub members: Vec<usize>,
}
/// An edge of a Mapper graph connecting two nodes.
///
/// When produced by `mapper_graph`, `source` and `target` are indices into
/// `MapperGraph::nodes` with `source < target`, and the two nodes share at
/// least one data point.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct MapperEdge {
/// Index of the first endpoint node.
pub source: usize,
/// Index of the second endpoint node.
pub target: usize,
}
/// Result of the Mapper construction: a list of nodes and the edges between them.
#[derive(Debug, Clone)]
pub struct MapperGraph {
/// All nodes of the graph; edges refer to nodes by their position in this list.
pub nodes: Vec<MapperNode>,
/// All edges of the graph.
pub edges: Vec<MapperEdge>,
}
impl MapperGraph {
/// Number of nodes in the graph.
pub fn n_nodes(&self) -> usize {
self.nodes.len()
}
/// Number of edges in the graph.
pub fn n_edges(&self) -> usize {
self.edges.len()
}
/// Returns the member point indices of the node at position `node`.
///
/// # Panics
///
/// Panics if `node` is not a valid index into `nodes`.
pub fn members_of(&self, node: usize) -> &[usize] {
&self.nodes[node].members
}
}
/// A closed interval `[lo, hi]` on the real line, used as one patch of a cover.
#[derive(Debug, Clone, Copy)]
pub struct CoverInterval {
    /// Lower bound (inclusive).
    pub lo: f64,
    /// Upper bound (inclusive).
    pub hi: f64,
}

impl CoverInterval {
    /// Builds the interval `[lo, hi]`. The bounds are stored as given; no
    /// ordering between them is enforced.
    pub fn new(lo: f64, hi: f64) -> Self {
        CoverInterval { lo, hi }
    }

    /// Returns `true` when `lo <= v <= hi`. NaN is never contained, and an
    /// interval with `lo > hi` contains nothing.
    pub fn contains(&self, v: f64) -> bool {
        (self.lo..=self.hi).contains(&v)
    }
}
/// Splits `[lo, hi]` into `n_intervals` equal-width intervals and widens each
/// one on both sides by `overlap * width / 2`, where `width` is
/// `(hi - lo) / n_intervals`.
///
/// The widened bounds are limited to `lo - overlap * width / 2` below and
/// `hi + overlap * width / 2` above. Returns an empty vector when
/// `n_intervals` is zero.
pub fn uniform_cover(lo: f64, hi: f64, n_intervals: usize, overlap: f64) -> Vec<CoverInterval> {
    if n_intervals == 0 {
        return Vec::new();
    }

    let step = (hi - lo) / n_intervals as f64;
    let half_overlap = overlap * step / 2.0;
    let lowest_allowed = lo - half_overlap;
    let highest_allowed = hi + half_overlap;

    let mut cover = Vec::with_capacity(n_intervals);
    for index in 0..n_intervals {
        let centre = lo + step * (index as f64 + 0.5);
        let start = (centre - step / 2.0 - half_overlap).max(lowest_allowed);
        let end = (centre + step / 2.0 + half_overlap).min(highest_allowed);
        cover.push(CoverInterval::new(start, end));
    }
    cover
}
/// Builds a Mapper graph of the rows of `data`.
///
/// Steps:
/// 1. `filter_fn` is evaluated on every row to obtain a scalar filter value.
/// 2. For every interval in `cover_intervals`, the points whose filter value
///    lies in that interval form a patch.
/// 3. Each non-empty patch is clustered with `single_linkage_from_data` using
///    `metric` and `cluster_threshold`; every cluster becomes a node. A patch
///    holding a single point becomes a single node without clustering.
/// 4. Two nodes from different patches are joined by an edge when they share
///    at least one point.
///
/// # Errors
///
/// Returns `ClusteringError::InvalidInput` when `cover_intervals` is empty, and
/// propagates any error from `single_linkage_from_data`.
///
/// # Returns
///
/// Nodes are ordered by patch and, within a patch, by cluster label. Edges
/// satisfy `source < target` and are returned in ascending
/// `(source, target)` order, so the output is deterministic for a given input.
pub fn mapper_graph(
    data: ArrayView2<f64>,
    filter_fn: &dyn Fn(&[f64]) -> f64,
    cover_intervals: &[CoverInterval],
    cluster_threshold: f64,
    metric: &dyn Metric,
) -> Result<MapperGraph> {
    let n = data.shape()[0];
    let d = data.shape()[1];
    if cover_intervals.is_empty() {
        return Err(ClusteringError::InvalidInput(
            "cover_intervals must be non-empty".to_string(),
        ));
    }

    // Evaluate the filter on each row, reusing one scratch buffer for the row.
    let mut row = vec![0.0_f64; d];
    let mut filter_values: Vec<f64> = Vec::with_capacity(n);
    for i in 0..n {
        for f in 0..d {
            row[f] = data[[i, f]];
        }
        filter_values.push(filter_fn(&row));
    }

    // point_to_nodes[p] lists every node (across patches) that contains point p.
    let mut point_to_nodes: Vec<Vec<usize>> = vec![Vec::new(); n];
    let mut nodes: Vec<MapperNode> = Vec::new();

    for (patch_idx, interval) in cover_intervals.iter().enumerate() {
        let members: Vec<usize> = (0..n)
            .filter(|&i| interval.contains(filter_values[i]))
            .collect();
        if members.is_empty() {
            continue;
        }

        // A lone point is its own cluster; no need to run the clustering.
        if members.len() == 1 {
            point_to_nodes[members[0]].push(nodes.len());
            nodes.push(MapperNode {
                patch_index: patch_idx,
                members,
            });
            continue;
        }

        // Copy the patch's rows into a dense sub-matrix for clustering.
        let sub_n = members.len();
        let mut sub_data = Array2::<f64>::zeros((sub_n, d));
        for (si, &gi) in members.iter().enumerate() {
            for f in 0..d {
                sub_data[[si, f]] = data[[gi, f]];
            }
        }
        let sub_labels = single_linkage_from_data(sub_data.view(), metric, cluster_threshold)?;

        // Group the patch's points (by global index) per cluster label.
        let max_label = sub_labels.iter().max().copied().unwrap_or(0);
        let mut cluster_members: Vec<Vec<usize>> = vec![Vec::new(); max_label + 1];
        for (si, &label) in sub_labels.iter().enumerate() {
            cluster_members[label].push(members[si]);
        }

        for cluster in cluster_members {
            if cluster.is_empty() {
                continue;
            }
            let node_idx = nodes.len();
            for &gi in &cluster {
                point_to_nodes[gi].push(node_idx);
            }
            nodes.push(MapperNode {
                patch_index: patch_idx,
                members: cluster,
            });
        }
    }

    // An ordered set both de-duplicates edges and fixes their output order.
    let mut edge_set: std::collections::BTreeSet<(usize, usize)> =
        std::collections::BTreeSet::new();
    for node_list in &point_to_nodes {
        for (pos, &first) in node_list.iter().enumerate() {
            for &second in &node_list[pos + 1..] {
                let (a, b) = (first.min(second), first.max(second));
                if nodes[a].patch_index != nodes[b].patch_index {
                    edge_set.insert((a, b));
                }
            }
        }
    }
    let edges: Vec<MapperEdge> = edge_set
        .into_iter()
        .map(|(source, target)| MapperEdge { source, target })
        .collect();

    Ok(MapperGraph { nodes, edges })
}
/// Estimates a cluster count from a 0-dimensional barcode.
///
/// Counts the bars that are either infinite or whose persistence
/// (`death - birth`) is at least `min_persistence`. The result is never less
/// than one, even for an empty barcode.
pub fn n_clusters_from_barcode(barcode: &[(f64, f64)], min_persistence: f64) -> usize {
    let mut significant = 0_usize;
    for &(birth, death) in barcode {
        let long_lived = (death - birth) >= min_persistence;
        if death.is_infinite() || long_lived {
            significant += 1;
        }
    }
    significant.max(1)
}
/// Computes the full symmetric matrix of pairwise distances between the rows
/// of `data` under `metric`.
///
/// `metric` is evaluated once per unordered pair `(i, j)` with `i < j` and the
/// value is mirrored to `(j, i)`. Diagonal entries are left at `0.0`.
pub fn pairwise_distance_matrix(
    data: ArrayView2<f64>,
    metric: &dyn Metric,
) -> Array2<f64> {
    let n_points = data.shape()[0];
    let mut matrix = Array2::<f64>::zeros((n_points, n_points));

    let upper_pairs =
        (0..n_points).flat_map(|row| ((row + 1)..n_points).map(move |col| (row, col)));
    for (row, col) in upper_pairs {
        let value = metric.distance(data.row(row), data.row(col));
        matrix[[row, col]] = value;
        matrix[[col, row]] = value;
    }
    matrix
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    /// Eight 2-D points: four near the origin and four near `(5, 5)`.
    fn two_cluster_data() -> Array2<f64> {
        Array2::from_shape_vec(
            (8, 2),
            vec![
                0.0, 0.0, 0.1, 0.1, 0.05, 0.05, -0.05, 0.05,
                5.0, 5.0, 5.1, 4.9, 4.9, 5.1, 5.05, 4.95,
            ],
        )
        .expect("data")
    }

    #[test]
    fn test_ph0_two_clusters() {
        let data = two_cluster_data();
        let barcode = persistent_homology_0d(data.view(), &EuclideanMetric)
            .expect("ph0");

        // Every pairwise edge enters the filtration, so all points eventually
        // merge into one component: exactly one bar is immortal.
        let n_inf = barcode.iter().filter(|&&(_, d)| d.is_infinite()).count();
        assert_eq!(n_inf, 1, "expected exactly 1 immortal component");
        assert_eq!(barcode.len(), data.shape()[0]);

        // The two-cluster structure shows up as a single long finite bar: the
        // merge of the two groups, which are several units apart.
        let n_long = barcode
            .iter()
            .filter(|&&(b, d)| d.is_finite() && d - b > 1.0)
            .count();
        assert_eq!(n_long, 1, "expected exactly 1 long finite bar");
        assert_eq!(n_clusters_from_barcode(&barcode, 1.0), 2);
    }

    #[test]
    fn test_ph0_single_cluster() {
        let data = Array2::from_shape_vec(
            (4, 1),
            vec![0.0, 0.1, 0.2, 0.3],
        )
        .expect("data");
        let barcode = persistent_homology_0d(data.view(), &EuclideanMetric)
            .expect("ph0 single");
        let n_inf = barcode.iter().filter(|&&(_, d)| d.is_infinite()).count();
        assert_eq!(n_inf, 1);
    }

    #[test]
    fn test_ph0_empty() {
        let data = Array2::<f64>::zeros((0, 2));
        let barcode = persistent_homology_0d(data.view(), &EuclideanMetric)
            .expect("ph0 empty");
        assert!(barcode.is_empty());
    }

    #[test]
    fn test_single_linkage_from_barcode_basic() {
        let barcode = vec![(0.0, f64::INFINITY), (0.0, f64::INFINITY), (0.0, 0.2)];
        let labels = single_linkage_from_barcode(&barcode, 1.0, 8);
        assert_eq!(labels.len(), 8);
        let unique: std::collections::HashSet<_> = labels.iter().copied().collect();
        assert_eq!(unique.len(), 2);
    }

    #[test]
    fn test_single_linkage_from_data() {
        let data = two_cluster_data();
        let labels = single_linkage_from_data(data.view(), &EuclideanMetric, 0.5)
            .expect("sl data");
        assert_eq!(labels.len(), 8);
        let unique: std::collections::HashSet<_> = labels.iter().copied().collect();
        assert_eq!(unique.len(), 2, "expected 2 clusters, got {:?}", unique);
    }

    #[test]
    fn test_single_linkage_threshold_zero() {
        let data = two_cluster_data();
        let labels = single_linkage_from_data(data.view(), &EuclideanMetric, 0.0)
            .expect("sl zero");
        let unique: std::collections::HashSet<_> = labels.iter().copied().collect();
        assert_eq!(unique.len(), data.shape()[0]);
    }

    #[test]
    fn test_n_clusters_from_barcode() {
        let barcode = vec![
            (0.0, f64::INFINITY),
            (0.0, f64::INFINITY),
            (0.0, 0.05),
        ];
        assert_eq!(n_clusters_from_barcode(&barcode, 0.5), 2);
        assert_eq!(n_clusters_from_barcode(&barcode, 0.01), 3);
    }

    #[test]
    fn test_uniform_cover() {
        let intervals = uniform_cover(0.0, 1.0, 4, 0.3);
        assert_eq!(intervals.len(), 4);
        assert!(intervals[0].contains(0.125));
        assert!(intervals[3].contains(0.875));
    }

    #[test]
    fn test_mapper_graph_line() {
        let data = Array2::from_shape_vec(
            (8, 1),
            vec![0.0, 0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5],
        )
        .expect("data");
        let filter_fn = |row: &[f64]| row[0];
        let cover = uniform_cover(0.0, 3.5, 4, 0.4);
        let graph = mapper_graph(data.view(), &filter_fn, &cover, 0.6, &EuclideanMetric)
            .expect("mapper");
        assert!(graph.n_nodes() >= 4, "expected >= 4 nodes, got {}", graph.n_nodes());
        assert!(graph.n_edges() >= 1, "expected at least 1 edge, got {}", graph.n_edges());
    }

    #[test]
    fn test_mapper_graph_two_clusters() {
        let data = two_cluster_data();
        let filter_fn = |row: &[f64]| row[0];
        let cover = uniform_cover(-0.1, 5.2, 6, 0.2);
        let graph = mapper_graph(data.view(), &filter_fn, &cover, 0.3, &EuclideanMetric)
            .expect("mapper 2cluster");
        assert!(graph.n_nodes() >= 2);
    }

    #[test]
    fn test_mapper_empty_cover_error() {
        let data = two_cluster_data();
        let filter_fn = |row: &[f64]| row[0];
        let result = mapper_graph(data.view(), &filter_fn, &[], 0.5, &EuclideanMetric);
        assert!(result.is_err());
    }

    #[test]
    fn test_pairwise_distance_matrix() {
        let data = Array2::from_shape_vec(
            (3, 2),
            vec![0.0, 0.0, 1.0, 0.0, 0.0, 1.0],
        )
        .expect("data");
        let dm = pairwise_distance_matrix(data.view(), &EuclideanMetric);
        assert_eq!(dm.shape(), [3, 3]);
        assert!((dm[[0, 0]]).abs() < 1e-10);
        assert!((dm[[0, 1]] - 1.0).abs() < 1e-10);
        assert!((dm[[1, 0]] - dm[[0, 1]]).abs() < 1e-10);
    }

    #[test]
    fn test_manhattan_metric() {
        let a = Array2::from_shape_vec((1, 2), vec![0.0, 0.0]).expect("a");
        let b = Array2::from_shape_vec((1, 2), vec![1.0, 2.0]).expect("b");
        let d = ManhattanMetric.distance(a.row(0), b.row(0));
        assert!((d - 3.0).abs() < 1e-10);
    }

    #[test]
    fn test_cosine_metric_orthogonal() {
        let a = Array2::from_shape_vec((1, 2), vec![1.0, 0.0]).expect("a");
        let b = Array2::from_shape_vec((1, 2), vec![0.0, 1.0]).expect("b");
        let d = CosineMetric.distance(a.row(0), b.row(0));
        assert!((d - 1.0).abs() < 1e-10);
    }
}