use super::{BettiNumbers, Filtration, Simplex};
use std::collections::{HashMap, HashSet};
/// One feature of a persistence diagram: a class of dimension `dimension` that is
/// born at filtration value `birth` and, unless essential, dies at `death`.
#[derive(Debug, Clone, PartialEq)]
pub struct BirthDeathPair {
    /// Homological dimension of the feature.
    pub dimension: usize,
    /// Filtration value at which the feature appears.
    pub birth: f64,
    /// Filtration value at which the feature disappears; `None` for an essential
    /// feature that never dies.
    pub death: Option<f64>,
}

impl BirthDeathPair {
    /// Builds a pair that is born at `birth` and dies at `death`.
    pub fn finite(dimension: usize, birth: f64, death: f64) -> Self {
        let death = Some(death);
        Self {
            dimension,
            birth,
            death,
        }
    }

    /// Builds an essential pair: born at `birth`, never dies.
    pub fn essential(dimension: usize, birth: f64) -> Self {
        Self {
            death: None,
            birth,
            dimension,
        }
    }

    /// Lifetime `death - birth`, or `f64::INFINITY` for an essential pair.
    pub fn persistence(&self) -> f64 {
        self.death
            .map_or(f64::INFINITY, |death_value| death_value - self.birth)
    }

    /// `true` when the pair has no death value.
    pub fn is_essential(&self) -> bool {
        self.death.is_none()
    }

    /// Midpoint `(birth + death) / 2`, or `f64::INFINITY` for an essential pair.
    pub fn midpoint(&self) -> f64 {
        self.death
            .map_or(f64::INFINITY, |death_value| (self.birth + death_value) / 2.0)
    }
}
#[derive(Debug, Clone)]
pub struct PersistenceDiagram {
pub pairs: Vec<BirthDeathPair>,
pub max_dim: usize,
}
impl PersistenceDiagram {
pub fn new() -> Self {
Self {
pairs: Vec::new(),
max_dim: 0,
}
}
pub fn add(&mut self, pair: BirthDeathPair) {
self.max_dim = self.max_dim.max(pair.dimension);
self.pairs.push(pair);
}
pub fn pairs_of_dim(&self, d: usize) -> impl Iterator<Item = &BirthDeathPair> {
self.pairs.iter().filter(move |p| p.dimension == d)
}
pub fn betti_at(&self, t: f64) -> BettiNumbers {
let mut b0 = 0;
let mut b1 = 0;
let mut b2 = 0;
for pair in &self.pairs {
let alive = pair.birth <= t && pair.death.map(|d| d > t).unwrap_or(true);
if alive {
match pair.dimension {
0 => b0 += 1,
1 => b1 += 1,
2 => b2 += 1,
_ => {}
}
}
}
BettiNumbers::new(b0, b1, b2)
}
pub fn total_persistence(&self) -> f64 {
self.pairs
.iter()
.filter(|p| !p.is_essential())
.map(|p| p.persistence())
.sum()
}
pub fn average_persistence(&self) -> f64 {
let finite: Vec<f64> = self
.pairs
.iter()
.filter(|p| !p.is_essential())
.map(|p| p.persistence())
.collect();
if finite.is_empty() {
0.0
} else {
finite.iter().sum::<f64>() / finite.len() as f64
}
}
pub fn filter_by_persistence(&self, min_persistence: f64) -> Self {
Self {
pairs: self
.pairs
.iter()
.filter(|p| p.persistence() >= min_persistence)
.cloned()
.collect(),
max_dim: self.max_dim,
}
}
pub fn feature_counts(&self) -> Vec<usize> {
let mut counts = vec![0; self.max_dim + 1];
for pair in &self.pairs {
if pair.dimension <= self.max_dim {
counts[pair.dimension] += 1;
}
}
counts
}
}
impl Default for PersistenceDiagram {
fn default() -> Self {
Self::new()
}
}
/// Computes persistent homology of a [`Filtration`] by reducing its boundary matrix
/// over Z/2: a column is a set of row indices, and adding two columns takes their
/// symmetric difference.
///
/// Row and column `i` both correspond to `filtration.simplices[i]`. Call
/// [`PersistentHomology::compute`]; this struct only holds scratch state for one run.
pub struct PersistentHomology {
    /// Boundary-matrix columns in filtration order; `None` is an all-zero column.
    columns: Vec<Option<HashSet<usize>>>,
    /// Maps each pivot (largest row index of a reduced non-zero column) to the
    /// column that owns it.
    pivot_to_col: HashMap<usize, usize>,
    /// `birth` of each simplex, indexed like `columns`.
    birth_times: Vec<f64>,
    /// `dim()` of each simplex, indexed like `columns`.
    dimensions: Vec<usize>,
}

impl PersistentHomology {
    /// Computes the persistence diagram of `filtration`.
    ///
    /// How pairs are read off the reduced matrix:
    /// - A non-zero reduced column `j` with pivot `i` yields the pair
    ///   (`birth_times[i]`, `birth_times[j]`) in dimension `dimensions[i]`. The pair is
    ///   reported only if the death value is strictly greater than the birth value.
    /// - A zero column that is never some column's pivot yields an essential pair.
    ///
    /// Finite pairs are emitted in the order of their killing column, then essential
    /// pairs in filtration order, so the output is deterministic.
    ///
    /// NOTE(review): assumes every face of a simplex appears earlier in
    /// `filtration.simplices` (valid filtration order). Faces that are not present in
    /// the filtration are silently skipped when building boundaries.
    pub fn compute(filtration: &Filtration) -> PersistenceDiagram {
        let mut ph = Self {
            columns: Vec::new(),
            pivot_to_col: HashMap::new(),
            birth_times: Vec::new(),
            dimensions: Vec::new(),
        };
        ph.run(filtration)
    }

    /// Builds the boundary matrix for `filtration`, reduces it, and extracts the pairs.
    fn run(&mut self, filtration: &Filtration) -> PersistenceDiagram {
        let n = filtration.simplices.len();
        if n == 0 {
            return PersistenceDiagram::new();
        }
        let simplex_to_idx: HashMap<&Simplex, usize> = filtration
            .simplices
            .iter()
            .enumerate()
            .map(|(i, fs)| (&fs.simplex, i))
            .collect();
        // Start from a clean slate so the result never depends on an earlier run.
        self.pivot_to_col.clear();
        self.columns = Vec::with_capacity(n);
        self.birth_times = filtration.simplices.iter().map(|fs| fs.birth).collect();
        self.dimensions = filtration
            .simplices
            .iter()
            .map(|fs| fs.simplex.dim())
            .collect();
        for fs in &filtration.simplices {
            let boundary = self.boundary(&fs.simplex, &simplex_to_idx);
            // Store an empty boundary as an explicit zero column.
            self.columns.push(if boundary.is_empty() {
                None
            } else {
                Some(boundary)
            });
        }
        self.reduce();
        self.extract_pairs()
    }

    /// Returns the row indices of the faces of `simplex` that are present in `idx_map`.
    fn boundary(&self, simplex: &Simplex, idx_map: &HashMap<&Simplex, usize>) -> HashSet<usize> {
        let mut boundary = HashSet::new();
        for face in simplex.faces() {
            if let Some(&idx) = idx_map.get(&face) {
                boundary.insert(idx);
            }
        }
        boundary
    }

    /// Standard left-to-right column reduction.
    ///
    /// For each column `j`: while its pivot is already owned by an earlier column, add
    /// that column into `j`. If `j` becomes zero, it claims nothing; otherwise it claims
    /// its pivot.
    fn reduce(&mut self) {
        let n = self.columns.len();
        for j in 0..n {
            while let Some(pivot) = self.get_pivot(j) {
                if let Some(&other) = self.pivot_to_col.get(&pivot) {
                    self.add_columns(j, other);
                } else {
                    self.pivot_to_col.insert(pivot, j);
                    break;
                }
            }
        }
    }

    /// Largest row index in column `col`, or `None` for a zero column.
    fn get_pivot(&self, col: usize) -> Option<usize> {
        self.columns[col]
            .as_ref()
            .and_then(|c| c.iter().max().copied())
    }

    /// Adds column `src` into column `dst` over Z/2 (symmetric difference of row sets).
    ///
    /// An empty result is stored as `None`. Adding a zero column is a no-op.
    fn add_columns(&mut self, dst: usize, src: usize) {
        // Compute the sum from two shared borrows; no need to clone `src` first.
        let sum: HashSet<usize> = match (&self.columns[dst], &self.columns[src]) {
            (_, None) => return,
            (None, Some(src_col)) => src_col.clone(),
            (Some(dst_col), Some(src_col)) => {
                dst_col.symmetric_difference(src_col).copied().collect()
            }
        };
        self.columns[dst] = if sum.is_empty() { None } else { Some(sum) };
    }

    /// Reads the birth/death pairs off the reduced matrix.
    fn extract_pairs(&self) -> PersistenceDiagram {
        let n = self.columns.len();
        let mut diagram = PersistenceDiagram::new();
        let mut paired = vec![false; n];

        // HashMap iteration order is randomized. Sort by the killing column so that
        // `diagram.pairs` has the same order on every run. Each column owns at most one
        // pivot, so the sort key is unique.
        let mut pivots: Vec<(usize, usize)> = self
            .pivot_to_col
            .iter()
            .map(|(&pivot, &col)| (pivot, col))
            .collect();
        pivots.sort_unstable_by_key(|&(_, col)| col);

        for (pivot, col) in pivots {
            let birth = self.birth_times[pivot];
            let death = self.birth_times[col];
            // Pairs with no positive lifetime are dropped from the diagram.
            if death > birth {
                diagram.add(BirthDeathPair::finite(self.dimensions[pivot], birth, death));
            }
            paired[pivot] = true;
            paired[col] = true;
        }

        // A zero column that no other column killed creates a class that never dies.
        for (j, column) in self.columns.iter().enumerate() {
            if column.is_none() && !paired[j] {
                diagram.add(BirthDeathPair::essential(
                    self.dimensions[j],
                    self.birth_times[j],
                ));
            }
        }
        diagram
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::homology::{PointCloud, VietorisRips};

    /// Builds a diagram by passing each pair, in order, to `PersistenceDiagram::add`.
    fn diagram_from(pairs: Vec<BirthDeathPair>) -> PersistenceDiagram {
        let mut diagram = PersistenceDiagram::new();
        pairs.into_iter().for_each(|pair| diagram.add(pair));
        diagram
    }

    /// Builds the filtration of `cloud` with `rips`, then runs persistent homology on it.
    fn diagram_of(cloud: &PointCloud, rips: VietorisRips) -> PersistenceDiagram {
        let filtration = rips.build(cloud);
        PersistentHomology::compute(&filtration)
    }

    #[test]
    fn test_birth_death_pair() {
        let finite = BirthDeathPair::finite(0, 0.0, 1.0);
        assert!(!finite.is_essential());
        assert_eq!(finite.persistence(), 1.0);

        let essential = BirthDeathPair::essential(0, 0.0);
        assert_eq!(essential.persistence(), f64::INFINITY);
        assert!(essential.is_essential());
    }

    #[test]
    fn test_persistence_diagram() {
        let diagram = diagram_from(vec![
            BirthDeathPair::essential(0, 0.0),
            BirthDeathPair::finite(0, 0.0, 1.0),
            BirthDeathPair::finite(1, 0.5, 2.0),
        ]);
        assert_eq!(diagram.pairs.len(), 3);

        let betti = diagram.betti_at(0.75);
        assert_eq!(betti.b0, 2);
        assert_eq!(betti.b1, 1);
    }

    #[test]
    fn test_persistent_homology_simple() {
        let cloud = PointCloud::from_flat(&[0.0, 0.0, 1.0, 0.0], 2);
        let diagram = diagram_of(&cloud, VietorisRips::new(1, 2.0));
        assert!(diagram.pairs_of_dim(0).next().is_some());
    }

    #[test]
    fn test_persistent_homology_triangle() {
        let cloud = PointCloud::from_flat(&[0.0, 0.0, 1.0, 0.0, 0.5, 0.866], 2);
        let diagram = diagram_of(&cloud, VietorisRips::new(2, 2.0));
        assert!(diagram.pairs_of_dim(0).count() > 0);
    }

    #[test]
    fn test_filter_by_persistence() {
        let diagram = diagram_from(vec![
            BirthDeathPair::finite(0, 0.0, 0.1),
            BirthDeathPair::finite(0, 0.0, 1.0),
            BirthDeathPair::essential(0, 0.0),
        ]);
        // The short-lived pair is dropped; the long-lived and essential pairs remain.
        let kept = diagram.filter_by_persistence(0.5);
        assert_eq!(kept.pairs.len(), 2);
    }

    #[test]
    fn test_feature_counts() {
        let diagram = diagram_from(vec![
            BirthDeathPair::finite(0, 0.0, 1.0),
            BirthDeathPair::finite(0, 0.0, 1.0),
            BirthDeathPair::finite(1, 0.0, 1.0),
        ]);
        let counts = diagram.feature_counts();
        assert_eq!((counts[0], counts[1]), (2, 1));
    }
}