use crate::tda::alpha_complex::{sym_diff_sorted, Simplex};
use std::collections::HashMap;
/// Direction of a single step in a zigzag sequence.
///
/// `compute_zigzag` dispatches `Forward` to `ZigzagPersistence::add_simplex`
/// and `Backward` to `ZigzagPersistence::remove_simplex`.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ZigzagDirection {
    /// The step inserts its simplex.
    Forward,
    /// The step deletes its simplex.
    Backward,
}
/// One step of a zigzag sequence: a simplex paired with whether it is
/// inserted or deleted at this point in the sequence.
#[derive(Clone, Debug)]
pub struct ZigzagStep {
    /// Whether `simplex` is added (`Forward`) or removed (`Backward`).
    pub direction: ZigzagDirection,
    /// The simplex affected by this step.
    pub simplex: Simplex,
}
/// Incremental zigzag-persistence state driven by a sequence of simplex
/// insertions (`add_simplex`) and deletions (`remove_simplex`).
#[derive(Clone, Debug)]
pub struct ZigzagPersistence {
    /// Simplices currently present, in order; a simplex's position here is
    /// both its row label and its column index in `columns`.
    filtration: Vec<Simplex>,
    /// One sorted column of row indices per entry of `filtration`.
    columns: Vec<Vec<usize>>,
    /// Maps a pivot (the last, i.e. largest, row index of a column) to the
    /// index of the column that owns it.
    pivot_col: HashMap<usize, usize>,
    /// Intervals not yet closed, keyed by the creating simplex's vertex set;
    /// each value is `(birth value, dimension)`.
    open_intervals: HashMap<Vec<usize>, (f64, usize)>,
    /// Closed intervals as `(birth, death, dimension)`, in closing order.
    completed: Vec<(f64, f64, usize)>,
    /// Number of add/remove calls processed so far.
    time: usize,
}
impl ZigzagPersistence {
    /// Creates an empty tracker with no simplices and no intervals.
    pub fn new() -> Self {
        Self {
            filtration: Vec::new(),
            columns: Vec::new(),
            pivot_col: HashMap::new(),
            open_intervals: HashMap::new(),
            completed: Vec::new(),
            time: 0,
        }
    }

    /// Appends `s` to the current filtration (a forward zigzag step).
    ///
    /// The boundary column of `s` (row indices of its faces that are currently
    /// present; absent faces are ignored) is reduced against the existing
    /// columns. If it reduces to zero, `s` opens a new interval born at
    /// `s.filtration_value`. Otherwise its pivot is the row of the simplex whose
    /// open interval, if any, is closed with death value `s.filtration_value`.
    ///
    /// Adding a simplex whose vertex set is already present is a no-op: a
    /// duplicate row would overwrite the vertex-keyed interval bookkeeping.
    ///
    /// Returns the `(birth, death, dimension)` triples closed by this step.
    pub fn add_simplex(&mut self, s: Simplex) -> Vec<(f64, f64, usize)> {
        self.time += 1;
        // Row index of every simplex currently present, keyed by vertex set.
        let simplex_index: HashMap<Vec<usize>, usize> = self
            .filtration
            .iter()
            .enumerate()
            .map(|(i, sx)| (sx.vertices.clone(), i))
            .collect();
        if simplex_index.contains_key(&s.vertices) {
            return Vec::new();
        }
        let idx = self.filtration.len();
        let mut col: Vec<usize> = s
            .boundary_faces()
            .iter()
            .filter_map(|face| simplex_index.get(face).copied())
            .collect();
        col.sort_unstable();
        // Cancel the lowest entry against the column owning that pivot until
        // the pivot is unclaimed or the column is empty.
        while let Some(&pivot) = col.last() {
            match self.pivot_col.get(&pivot) {
                // The ownership check guarantees progress; a stale map entry
                // would otherwise make this loop spin forever.
                Some(&k) if self.columns[k].last() == Some(&pivot) => {
                    sym_diff_sorted(&mut col, &self.columns[k]);
                }
                _ => break,
            }
        }
        let mut newly_closed = Vec::new();
        match col.last().copied() {
            Some(pivot) => {
                self.pivot_col.insert(pivot, idx);
                let creator = &self.filtration[pivot].vertices;
                if let Some((birth_val, dim)) = self.open_intervals.remove(creator) {
                    let interval = (birth_val, s.filtration_value, dim);
                    newly_closed.push(interval);
                    self.completed.push(interval);
                }
            }
            None => {
                self.open_intervals
                    .insert(s.vertices.clone(), (s.filtration_value, s.dimension()));
            }
        }
        self.columns.push(col);
        self.filtration.push(s);
        newly_closed
    }

    /// Deletes the present simplex whose vertex set equals `s.vertices`
    /// (a backward zigzag step).
    ///
    /// The simplex is moved to the end of the filtration by adjacent
    /// transpositions and then dropped. If it created a still-open interval,
    /// that interval is closed with death value `s.filtration_value`, the
    /// value carried by this removal step. Callers are expected to pass a
    /// value no smaller than the simplex's insertion value.
    ///
    /// Returns an empty vector, leaving the filtration untouched, when no
    /// matching simplex is present or when it is still a face of another
    /// present simplex. In the second case, removing it would leave dangling
    /// row references in the stored columns.
    ///
    /// NOTE(review): removing a simplex whose column had a pivot does not
    /// reopen an interval here; confirm that this is intended.
    pub fn remove_simplex(&mut self, s: Simplex) -> Vec<(f64, f64, usize)> {
        self.time += 1;
        let pos = match self
            .filtration
            .iter()
            .position(|x| x.vertices == s.vertices)
        {
            Some(p) => p,
            None => return Vec::new(),
        };
        if self.has_coface(&s.vertices) {
            return Vec::new();
        }
        let last = self.filtration.len() - 1;
        for i in pos..last {
            self.transpose_adjacent(i);
        }
        let removed = self.filtration.pop();
        self.columns.pop();
        self.rebuild_pivot_map();
        let mut newly_closed = Vec::new();
        if let Some(rem) = removed {
            if let Some((birth_val, dim)) = self.open_intervals.remove(&rem.vertices) {
                let interval = (birth_val, s.filtration_value, dim);
                newly_closed.push(interval);
                self.completed.push(interval);
            }
        }
        newly_closed
    }

    /// Returns every interval closed so far as `(birth, death, dimension)`,
    /// in closing order.
    pub fn pairs(&self) -> &[(f64, f64, usize)] {
        &self.completed
    }

    /// Returns `true` if some present simplex strictly contains `vertices`.
    fn has_coface(&self, vertices: &[usize]) -> bool {
        self.filtration.iter().any(|other| {
            other.vertices.len() > vertices.len()
                && vertices.iter().all(|v| other.vertices.contains(v))
        })
    }

    /// Swaps the simplices at positions `i` and `i + 1` and updates the
    /// stored columns. Row labels `i` and `i + 1` are exchanged in every
    /// column, the two columns are swapped, and `pivot_col` is rebuilt.
    ///
    /// NOTE(review): the two conditional column additions are the original
    /// heuristic, not the full vineyard transposition (no V matrix is kept).
    /// Confirm that the columns stay reduced afterwards.
    fn transpose_adjacent(&mut self, i: usize) {
        let j = i + 1;
        self.filtration.swap(i, j);
        // Both flags describe the columns before any relabelling.
        let col_j_has_row_i = self.columns[j].binary_search(&i).is_ok();
        let col_i_has_row_j = self.columns[i].binary_search(&j).is_ok();
        if col_j_has_row_i {
            let col_i = self.columns[i].clone();
            sym_diff_sorted(&mut self.columns[j], &col_i);
        }
        // Exchange row labels i and j everywhere; a column holding both (or
        // neither) is unchanged as a set.
        for col in self.columns.iter_mut() {
            match (col.binary_search(&i), col.binary_search(&j)) {
                (Ok(p), Err(_)) => {
                    col[p] = j;
                    col.sort_unstable();
                }
                (Err(_), Ok(p)) => {
                    col[p] = i;
                    col.sort_unstable();
                }
                _ => {}
            }
        }
        self.columns.swap(i, j);
        // Rows and columns were permuted, so the pivot map is stale; refresh
        // it before it is consulted below or by the next transposition.
        self.rebuild_pivot_map();
        if col_i_has_row_j {
            let reducer = self.columns[i]
                .last()
                .and_then(|piv| self.pivot_col.get(piv).copied())
                .filter(|&k| k != i);
            if let Some(k) = reducer {
                let col_k = self.columns[k].clone();
                sym_diff_sorted(&mut self.columns[i], &col_k);
                self.rebuild_pivot_map();
            }
        }
    }

    /// Recomputes `pivot_col` from scratch. Each non-empty column's last
    /// (largest) row index maps to that column. When several columns share a
    /// pivot, the highest column index wins.
    fn rebuild_pivot_map(&mut self) {
        self.pivot_col.clear();
        for (j, col) in self.columns.iter().enumerate() {
            if let Some(&pivot) = col.last() {
                self.pivot_col.insert(pivot, j);
            }
        }
    }
}
impl Default for ZigzagPersistence {
fn default() -> Self {
Self::new()
}
}
/// Replays `steps` on a fresh [`ZigzagPersistence`] and returns every interval
/// closed along the way as `(birth, death, dimension)`, in closing order.
///
/// Intervals still open after the final step are not reported.
pub fn compute_zigzag(steps: &[ZigzagStep]) -> Vec<(f64, f64, usize)> {
    let mut tracker = ZigzagPersistence::new();
    steps.iter().for_each(|step| {
        let simplex = step.simplex.clone();
        // Per-step results are discarded; `completed` accumulates them all.
        let _ = match step.direction {
            ZigzagDirection::Backward => tracker.remove_simplex(simplex),
            ZigzagDirection::Forward => tracker.add_simplex(simplex),
        };
    });
    tracker.completed
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a simplex from its vertex list and filtration value.
    fn make_simplex(verts: Vec<usize>, fv: f64) -> Simplex {
        Simplex {
            vertices: verts,
            filtration_value: fv,
        }
    }

    /// Shorthand for a zigzag step.
    fn step(direction: ZigzagDirection, simplex: Simplex) -> ZigzagStep {
        ZigzagStep { direction, simplex }
    }

    #[test]
    fn test_add_vertices_creates_intervals() {
        let mut zz = ZigzagPersistence::new();
        assert!(zz.add_simplex(make_simplex(vec![0], 0.0)).is_empty());
        assert!(zz.add_simplex(make_simplex(vec![1], 1.0)).is_empty());
        assert_eq!(zz.open_intervals.len(), 2);
        assert!(zz.completed.is_empty());
    }

    #[test]
    fn test_add_edge_closes_one_component() {
        let mut zz = ZigzagPersistence::new();
        zz.add_simplex(make_simplex(vec![0], 0.0));
        zz.add_simplex(make_simplex(vec![1], 0.0));
        let closed = zz.add_simplex(make_simplex(vec![0, 1], 1.0));
        // The edge merges two components: one H0 class born at 0.0 dies at 1.0.
        assert_eq!(closed, vec![(0.0, 1.0, 0)]);
        assert_eq!(zz.open_intervals.len(), 1, "one component must survive");
    }

    #[test]
    fn test_remove_simplex_closes_interval() {
        let mut zz = ZigzagPersistence::new();
        let v0 = make_simplex(vec![0], 0.0);
        zz.add_simplex(v0.clone());
        let closed = zz.remove_simplex(v0);
        assert_eq!(closed, vec![(0.0, 0.0, 0)]);
        assert_eq!(zz.pairs(), closed.as_slice());
        assert!(zz.open_intervals.is_empty());
        assert!(zz.filtration.is_empty());
    }

    #[test]
    fn test_remove_absent_simplex_is_noop() {
        let mut zz = ZigzagPersistence::new();
        zz.add_simplex(make_simplex(vec![0], 0.0));
        assert!(zz.remove_simplex(make_simplex(vec![7], 0.0)).is_empty());
        assert_eq!(zz.filtration.len(), 1);
        assert_eq!(zz.open_intervals.len(), 1);
    }

    #[test]
    fn test_zigzag_batch_add_then_remove() {
        let v0 = make_simplex(vec![0], 0.0);
        let v1 = make_simplex(vec![1], 0.5);
        let edge = make_simplex(vec![0, 1], 1.0);
        let steps = vec![
            step(ZigzagDirection::Forward, v0),
            step(ZigzagDirection::Forward, v1),
            step(ZigzagDirection::Forward, edge.clone()),
            step(ZigzagDirection::Backward, edge),
        ];
        let pairs = compute_zigzag(&steps);
        // Only the edge insertion closes an interval (vertex 1, born at 0.5);
        // removing the edge closes nothing.
        assert_eq!(pairs, vec![(0.5, 1.0, 0)]);
    }

    #[test]
    fn test_default_is_empty() {
        let zz = ZigzagPersistence::default();
        assert!(zz.pairs().is_empty());
        assert!(zz.open_intervals.is_empty());
        assert!(zz.filtration.is_empty());
    }

    #[test]
    fn test_direction_equality() {
        assert_eq!(ZigzagDirection::Forward, ZigzagDirection::Forward);
        assert_ne!(ZigzagDirection::Forward, ZigzagDirection::Backward);
    }
}