Skip to main content

oxiphysics_core/
persistent_homology.rs

1#![allow(clippy::needless_range_loop)]
2// Copyright 2026 COOLJAPAN OU (Team KitaSan)
3// SPDX-License-Identifier: Apache-2.0
4
5//! Persistent homology and topological data analysis.
6//!
7//! Provides simplicial complexes, filtered complexes, persistence diagrams,
8//! Vietoris-Rips filtration, persistence images, and barcode statistics.
9
10#![allow(dead_code)]
11
12use std::collections::HashMap;
13
14// ─────────────────────────────────────────────────────────────────────────────
15// Simplex
16// ─────────────────────────────────────────────────────────────────────────────
17
/// A simplex: a set of vertex indices plus the filtration value at which it appears.
///
/// A simplex on `k + 1` vertices has dimension `k`: a single vertex has
/// dimension 0, an edge dimension 1, a triangle dimension 2, and so on.
#[derive(Clone, Debug, PartialEq)]
pub struct Simplex {
    /// Vertex indices, stored in ascending order by [`Simplex::new`].
    pub vertices: Vec<usize>,
    /// Dimension of the simplex (number of vertices minus one).
    pub dimension: usize,
    /// Filtration value at which the simplex enters the complex.
    pub filtration: f64,
}

impl Simplex {
    /// Build a simplex from `vertices` (given in any order) and a filtration value.
    ///
    /// The vertex list is sorted ascending, and the dimension is set to
    /// `vertices.len() - 1`. An empty vertex list yields dimension 0. Duplicate
    /// vertices are kept as given.
    pub fn new(vertices: Vec<usize>, filtration: f64) -> Self {
        let mut sorted = vertices;
        sorted.sort_unstable();
        let dimension = match sorted.len() {
            0 => 0,
            len => len - 1,
        };
        Simplex {
            vertices: sorted,
            dimension,
            filtration,
        }
    }

    /// Return the codimension-1 faces, one per removed vertex.
    ///
    /// Face `i` omits `self.vertices[i]` and inherits this simplex's filtration
    /// value. A simplex whose `dimension` is 0 has no faces.
    pub fn faces(&self) -> Vec<Simplex> {
        if self.dimension == 0 {
            return Vec::new();
        }
        let mut result = Vec::with_capacity(self.vertices.len());
        for skip in 0..self.vertices.len() {
            let mut remaining = self.vertices.clone();
            remaining.remove(skip);
            result.push(Simplex::new(remaining, self.filtration));
        }
        result
    }

    /// Report whether every vertex of `other` is also a vertex of `self`.
    pub fn contains(&self, other: &Simplex) -> bool {
        for v in &other.vertices {
            if !self.vertices.contains(v) {
                return false;
            }
        }
        true
    }
}
73
74// ─────────────────────────────────────────────────────────────────────────────
75// SimplicialComplex
76// ─────────────────────────────────────────────────────────────────────────────
77
78/// A finite simplicial complex with boundary operators and topological invariants.
79#[derive(Debug, Clone)]
80pub struct SimplicialComplex {
81    /// All simplices grouped by dimension.
82    pub simplices: Vec<Vec<Simplex>>,
83    /// Maximum dimension present in the complex.
84    pub max_dimension: usize,
85}
86
87impl SimplicialComplex {
88    /// Create an empty simplicial complex.
89    pub fn new() -> Self {
90        Self {
91            simplices: vec![],
92            max_dimension: 0,
93        }
94    }
95
96    /// Add a simplex (and all its faces recursively) to the complex.
97    pub fn add_simplex(&mut self, s: Simplex) {
98        // First ensure faces are present
99        let faces = s.faces();
100        for face in faces {
101            self.add_simplex(face);
102        }
103        let dim = s.dimension;
104        while self.simplices.len() <= dim {
105            self.simplices.push(vec![]);
106        }
107        if dim > self.max_dimension {
108            self.max_dimension = dim;
109        }
110        // Avoid duplicates
111        if !self.simplices[dim].iter().any(|x| x.vertices == s.vertices) {
112            self.simplices[dim].push(s);
113        }
114    }
115
116    /// Return the number of simplices of dimension `k`.
117    pub fn count(&self, k: usize) -> usize {
118        self.simplices.get(k).map_or(0, |v| v.len())
119    }
120
121    /// Compute the Euler characteristic χ = Σ (-1)^k * f_k.
122    ///
123    /// Where `f_k` is the number of `k`-simplices.
124    pub fn euler_characteristic(&self) -> i64 {
125        self.simplices
126            .iter()
127            .enumerate()
128            .map(|(k, s)| {
129                if k % 2 == 0 {
130                    s.len() as i64
131                } else {
132                    -(s.len() as i64)
133                }
134            })
135            .sum()
136    }
137
138    /// Compute the boundary matrix for dimension `k`.
139    ///
140    /// Entry `[i][j]` is 1 if the `j`-th `(k-1)`-simplex is a face of the
141    /// `i`-th `k`-simplex (with sign ignored for simplicity), else 0.
142    pub fn boundary_matrix(&self, k: usize) -> Vec<Vec<i32>> {
143        if k == 0 || k > self.max_dimension {
144            return vec![];
145        }
146        let k_simplices = &self.simplices[k];
147        let km1_simplices = &self.simplices[k - 1];
148        let rows = k_simplices.len();
149        let cols = km1_simplices.len();
150        let mut mat = vec![vec![0i32; cols]; rows];
151        for (i, sigma) in k_simplices.iter().enumerate() {
152            for (j, tau) in km1_simplices.iter().enumerate() {
153                if sigma.contains(tau) {
154                    mat[i][j] = 1;
155                }
156            }
157        }
158        mat
159    }
160}
161
162impl Default for SimplicialComplex {
163    fn default() -> Self {
164        Self::new()
165    }
166}
167
168// ─────────────────────────────────────────────────────────────────────────────
169// FilteredComplex
170// ─────────────────────────────────────────────────────────────────────────────
171
172/// A simplicial complex with a filtration ordering.
173///
174/// Simplices are ordered by their filtration values, and then by dimension.
175/// The boundary matrix algorithm operates on this ordering.
176#[derive(Debug, Clone)]
177pub struct FilteredComplex {
178    /// All simplices in filtration order (sorted by `filtration` then `dimension`).
179    pub ordered_simplices: Vec<Simplex>,
180}
181
182impl FilteredComplex {
183    /// Create a filtered complex from a list of simplices.
184    ///
185    /// Simplices are sorted by filtration value then dimension.
186    pub fn new(mut simplices: Vec<Simplex>) -> Self {
187        simplices.sort_by(|a, b| {
188            a.filtration
189                .partial_cmp(&b.filtration)
190                .unwrap_or(std::cmp::Ordering::Equal)
191                .then(a.dimension.cmp(&b.dimension))
192        });
193        Self {
194            ordered_simplices: simplices,
195        }
196    }
197
198    /// Return the number of simplices in this complex.
199    pub fn len(&self) -> usize {
200        self.ordered_simplices.len()
201    }
202
203    /// Return `true` if the complex is empty.
204    pub fn is_empty(&self) -> bool {
205        self.ordered_simplices.is_empty()
206    }
207
208    /// Compute the boundary matrix column for simplex at index `i`.
209    ///
210    /// Returns a sorted vector of indices `j < i` such that
211    /// `ordered_simplices[j]` is a codimension-1 face of `ordered_simplices[i]`.
212    pub fn boundary_column(&self, i: usize) -> Vec<usize> {
213        let sigma = &self.ordered_simplices[i];
214        if sigma.dimension == 0 {
215            return vec![];
216        }
217        let faces = sigma.faces();
218        let mut col = Vec::new();
219        for face in &faces {
220            if let Some(j) = self.ordered_simplices[..i]
221                .iter()
222                .rposition(|s| s.vertices == face.vertices)
223            {
224                col.push(j);
225            }
226        }
227        col.sort_unstable();
228        col
229    }
230
231    /// Run the standard persistence algorithm (column reduction).
232    ///
233    /// Returns the pivot pairs `(i, j)` where column `j` has pivot row `i`,
234    /// meaning simplex `j` dies when simplex `i` is born.
235    pub fn reduce(&self) -> Vec<(usize, usize)> {
236        let n = self.ordered_simplices.len();
237        // Reduced columns stored as sorted Vec<usize>
238        let mut columns: Vec<Vec<usize>> = (0..n).map(|i| self.boundary_column(i)).collect();
239        // low[j] = lowest set bit index in column j, or None
240        let low = |col: &Vec<usize>| col.last().copied();
241        let mut pivot_col: HashMap<usize, usize> = HashMap::new();
242        let mut pairs = Vec::new();
243
244        for j in 0..n {
245            loop {
246                if columns[j].is_empty() {
247                    break;
248                }
249                let l = low(&columns[j]).expect("column j is non-empty");
250                if let Some(&k) = pivot_col.get(&l) {
251                    // Add column k to column j (XOR-style over Z/2Z)
252                    let col_k = columns[k].clone();
253                    let col_j = std::mem::take(&mut columns[j]);
254                    columns[j] = xor_sorted_vecs(col_j, col_k);
255                } else {
256                    pivot_col.insert(l, j);
257                    pairs.push((l, j));
258                    break;
259                }
260            }
261        }
262        pairs
263    }
264}
265
/// Merge two index lists into their symmetric difference (addition over Z/2Z).
///
/// Both inputs are sorted first, so the result is ascending. When the merge
/// meets equal values in both inputs, that pair is dropped; anything left in
/// either input afterwards is appended unchanged.
fn xor_sorted_vecs(mut a: Vec<usize>, mut b: Vec<usize>) -> Vec<usize> {
    a.sort_unstable();
    b.sort_unstable();
    let mut merged = Vec::with_capacity(a.len() + b.len());
    let mut left = a.into_iter().peekable();
    let mut right = b.into_iter().peekable();
    while let (Some(x), Some(y)) = (left.peek().copied(), right.peek().copied()) {
        match x.cmp(&y) {
            std::cmp::Ordering::Less => {
                merged.push(x);
                left.next();
            }
            std::cmp::Ordering::Greater => {
                merged.push(y);
                right.next();
            }
            std::cmp::Ordering::Equal => {
                // Present in both: cancels out over Z/2Z.
                left.next();
                right.next();
            }
        }
    }
    merged.extend(left);
    merged.extend(right);
    merged
}
293
294// ─────────────────────────────────────────────────────────────────────────────
295// PersistenceDiagram
296// ─────────────────────────────────────────────────────────────────────────────
297
298/// A persistence diagram: collection of birth-death pairs per dimension.
299#[derive(Debug, Clone)]
300pub struct PersistenceDiagram {
301    /// Birth-death pairs `(birth, death)` per homological dimension.
302    pub pairs: Vec<(f64, f64, usize)>,
303    /// Essential classes (never destroyed): `(birth, dimension)`.
304    pub essential: Vec<(f64, usize)>,
305}
306
307impl PersistenceDiagram {
308    /// Create a `PersistenceDiagram` from a `FilteredComplex` using standard reduction.
309    pub fn from_filtered_complex(fc: &FilteredComplex) -> Self {
310        let pairs_idx = fc.reduce();
311        let n = fc.ordered_simplices.len();
312        let mut paired = vec![false; n];
313        let mut pairs = Vec::new();
314        let mut essential = Vec::new();
315
316        for (birth_idx, death_idx) in &pairs_idx {
317            let b = fc.ordered_simplices[*birth_idx].filtration;
318            let d = fc.ordered_simplices[*death_idx].filtration;
319            let dim = fc.ordered_simplices[*birth_idx].dimension;
320            if (d - b).abs() > 1e-15 {
321                pairs.push((b, d, dim));
322            }
323            paired[*birth_idx] = true;
324            paired[*death_idx] = true;
325        }
326        // Unpaired simplices → essential classes
327        for (i, s) in fc.ordered_simplices.iter().enumerate() {
328            if !paired[i] {
329                essential.push((s.filtration, s.dimension));
330            }
331        }
332        Self { pairs, essential }
333    }
334
335    /// Return Betti numbers up to dimension `max_dim` at filtration value `t`.
336    ///
337    /// β_k(t) = number of classes born at or before `t` that are still alive at `t`.
338    pub fn betti_numbers(&self, t: f64, max_dim: usize) -> Vec<usize> {
339        let mut betti = vec![0usize; max_dim + 1];
340        for &(b, d, dim) in &self.pairs {
341            if dim <= max_dim && b <= t && d > t {
342                betti[dim] += 1;
343            }
344        }
345        for &(b, dim) in &self.essential {
346            if dim <= max_dim && b <= t {
347                betti[dim] += 1;
348            }
349        }
350        betti
351    }
352
353    /// Return total persistence (sum of lifetimes for all finite pairs).
354    pub fn total_persistence(&self) -> f64 {
355        self.pairs.iter().map(|(b, d, _)| d - b).sum()
356    }
357
358    /// Return all finite pairs for dimension `k`.
359    pub fn pairs_for_dim(&self, k: usize) -> Vec<(f64, f64)> {
360        self.pairs
361            .iter()
362            .filter(|&&(_, _, d)| d == k)
363            .map(|&(b, d, _)| (b, d))
364            .collect()
365    }
366}
367
368// ─────────────────────────────────────────────────────────────────────────────
369// RipsFiltration
370// ─────────────────────────────────────────────────────────────────────────────
371
372/// Vietoris-Rips filtration built from a point cloud or distance matrix.
373#[derive(Debug, Clone)]
374pub struct RipsFiltration {
375    /// Distance matrix (symmetric, zero diagonal).
376    pub distances: Vec<Vec<f64>>,
377    /// Number of points.
378    pub n_points: usize,
379    /// Maximum filtration radius.
380    pub max_radius: f64,
381    /// Maximum simplex dimension to build.
382    pub max_dim: usize,
383}
384
385impl RipsFiltration {
386    /// Build a Rips filtration from a 2D point cloud (Euclidean distances).
387    ///
388    /// `max_radius` controls how far to grow the filtration.
389    /// `max_dim` controls the maximum simplex dimension built.
390    pub fn from_point_cloud(points: &[[f64; 2]], max_radius: f64, max_dim: usize) -> Self {
391        let n = points.len();
392        let mut distances = vec![vec![0.0_f64; n]; n];
393        for i in 0..n {
394            for j in (i + 1)..n {
395                let dx = points[i][0] - points[j][0];
396                let dy = points[i][1] - points[j][1];
397                let d = (dx * dx + dy * dy).sqrt();
398                distances[i][j] = d;
399                distances[j][i] = d;
400            }
401        }
402        Self {
403            distances,
404            n_points: n,
405            max_radius,
406            max_dim,
407        }
408    }
409
410    /// Build a Rips filtration from a precomputed distance matrix.
411    pub fn from_distance_matrix(distances: Vec<Vec<f64>>, max_radius: f64, max_dim: usize) -> Self {
412        let n = distances.len();
413        Self {
414            distances,
415            n_points: n,
416            max_radius,
417            max_dim,
418        }
419    }
420
421    /// Build a `FilteredComplex` from this filtration.
422    pub fn build_complex(&self) -> FilteredComplex {
423        let mut simplices = Vec::new();
424        let n = self.n_points;
425
426        // 0-simplices (vertices) — born at filtration 0
427        for i in 0..n {
428            simplices.push(Simplex::new(vec![i], 0.0));
429        }
430
431        // 1-simplices (edges)
432        for i in 0..n {
433            for j in (i + 1)..n {
434                let d = self.distances[i][j];
435                if d <= self.max_radius {
436                    simplices.push(Simplex::new(vec![i, j], d));
437                }
438            }
439        }
440
441        // 2-simplices (triangles) if max_dim >= 2
442        if self.max_dim >= 2 {
443            for i in 0..n {
444                for j in (i + 1)..n {
445                    for k in (j + 1)..n {
446                        let d = self.distances[i][j]
447                            .max(self.distances[i][k])
448                            .max(self.distances[j][k]);
449                        if d <= self.max_radius {
450                            simplices.push(Simplex::new(vec![i, j, k], d));
451                        }
452                    }
453                }
454            }
455        }
456
457        // 3-simplices (tetrahedra) if max_dim >= 3
458        if self.max_dim >= 3 && n >= 4 {
459            for i in 0..n {
460                for j in (i + 1)..n {
461                    for k in (j + 1)..n {
462                        for l in (k + 1)..n {
463                            let d = self.distances[i][j]
464                                .max(self.distances[i][k])
465                                .max(self.distances[i][l])
466                                .max(self.distances[j][k])
467                                .max(self.distances[j][l])
468                                .max(self.distances[k][l]);
469                            if d <= self.max_radius {
470                                simplices.push(Simplex::new(vec![i, j, k, l], d));
471                            }
472                        }
473                    }
474                }
475            }
476        }
477
478        FilteredComplex::new(simplices)
479    }
480
481    /// Compute persistence diagram for this Rips filtration.
482    pub fn persistence_diagram(&self) -> PersistenceDiagram {
483        let fc = self.build_complex();
484        PersistenceDiagram::from_filtered_complex(&fc)
485    }
486}
487
488// ─────────────────────────────────────────────────────────────────────────────
489// PersistenceImage
490// ─────────────────────────────────────────────────────────────────────────────
491
492/// Persistence image: a pixelated representation of a persistence diagram.
493///
494/// Points are represented in birth-persistence coordinates, then convolved
495/// with a Gaussian kernel and discretized onto a pixel grid.
496#[derive(Debug, Clone)]
497pub struct PersistenceImage {
498    /// Pixel grid (row-major, rows = persistence axis, cols = birth axis).
499    pub pixels: Vec<Vec<f64>>,
500    /// Number of grid rows (persistence resolution).
501    pub n_rows: usize,
502    /// Number of grid cols (birth resolution).
503    pub n_cols: usize,
504    /// Minimum birth value.
505    pub birth_min: f64,
506    /// Maximum birth value.
507    pub birth_max: f64,
508    /// Maximum persistence value.
509    pub pers_max: f64,
510    /// Gaussian bandwidth (sigma).
511    pub sigma: f64,
512}
513
514impl PersistenceImage {
515    /// Build a persistence image from a persistence diagram.
516    ///
517    /// `n_rows` and `n_cols` control resolution.
518    /// `sigma` is the Gaussian kernel bandwidth.
519    pub fn from_diagram(
520        diagram: &PersistenceDiagram,
521        n_rows: usize,
522        n_cols: usize,
523        sigma: f64,
524    ) -> Self {
525        // Collect finite pairs in birth-persistence coords
526        let pts: Vec<(f64, f64)> = diagram
527            .pairs
528            .iter()
529            .filter(|&&(b, d, _)| d.is_finite() && d > b)
530            .map(|&(b, d, _)| (b, d - b))
531            .collect();
532
533        if pts.is_empty() {
534            return Self {
535                pixels: vec![vec![0.0; n_cols]; n_rows],
536                n_rows,
537                n_cols,
538                birth_min: 0.0,
539                birth_max: 1.0,
540                pers_max: 1.0,
541                sigma,
542            };
543        }
544
545        let birth_min = pts.iter().map(|(b, _)| *b).fold(f64::INFINITY, f64::min);
546        let birth_max = pts
547            .iter()
548            .map(|(b, _)| *b)
549            .fold(f64::NEG_INFINITY, f64::max);
550        let pers_max = pts
551            .iter()
552            .map(|(_, p)| *p)
553            .fold(f64::NEG_INFINITY, f64::max);
554
555        let birth_range = (birth_max - birth_min).max(1e-10);
556        let pers_range = pers_max.max(1e-10);
557
558        let mut pixels = vec![vec![0.0_f64; n_cols]; n_rows];
559        let two_sigma_sq = 2.0 * sigma * sigma;
560
561        for (bx, py) in &pts {
562            // Ramp weight: persistence / max_persistence
563            let weight = py / pers_range;
564            for row in 0..n_rows {
565                // row 0 = high persistence
566                let p_grid =
567                    pers_range * (n_rows - 1 - row) as f64 / (n_rows as f64 - 1.0).max(1.0);
568                for col in 0..n_cols {
569                    let b_grid =
570                        birth_min + birth_range * col as f64 / (n_cols as f64 - 1.0).max(1.0);
571                    let db = bx - b_grid;
572                    let dp = py - p_grid;
573                    let gauss = (-(db * db + dp * dp) / two_sigma_sq).exp();
574                    pixels[row][col] += weight * gauss;
575                }
576            }
577        }
578
579        Self {
580            pixels,
581            n_rows,
582            n_cols,
583            birth_min,
584            birth_max,
585            pers_max,
586            sigma,
587        }
588    }
589
590    /// Flatten the pixel grid to a 1D feature vector (row-major).
591    pub fn to_vector(&self) -> Vec<f64> {
592        self.pixels
593            .iter()
594            .flat_map(|row| row.iter().copied())
595            .collect()
596    }
597
598    /// Return the maximum pixel value.
599    pub fn max_value(&self) -> f64 {
600        self.pixels
601            .iter()
602            .flat_map(|row| row.iter())
603            .cloned()
604            .fold(f64::NEG_INFINITY, f64::max)
605    }
606}
607
608// ─────────────────────────────────────────────────────────────────────────────
609// BarcodeStatistics
610// ─────────────────────────────────────────────────────────────────────────────
611
612/// Summary statistics computed from a persistence barcode.
613#[derive(Debug, Clone)]
614pub struct BarcodeStatistics {
615    /// Sum of finite bar lengths.
616    pub total_persistence: f64,
617    /// Mean finite bar length.
618    pub mean_persistence: f64,
619    /// Median finite bar length.
620    pub median_persistence: f64,
621    /// Maximum finite bar length.
622    pub max_persistence: f64,
623    /// Number of finite bars.
624    pub n_finite: usize,
625    /// Number of essential (infinite) classes.
626    pub n_essential: usize,
627    /// Bottleneck distance approximation to the empty diagram.
628    ///
629    /// Equals the half-persistence of the most persistent finite bar.
630    pub bottleneck_to_empty: f64,
631}
632
633impl BarcodeStatistics {
634    /// Compute statistics from a `PersistenceDiagram`.
635    pub fn from_diagram(diagram: &PersistenceDiagram) -> Self {
636        let mut lifetimes: Vec<f64> = diagram
637            .pairs
638            .iter()
639            .filter(|&&(b, d, _)| d.is_finite() && d > b)
640            .map(|&(b, d, _)| d - b)
641            .collect();
642        lifetimes.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
643
644        let n_finite = lifetimes.len();
645        let total_persistence = lifetimes.iter().sum::<f64>();
646        let mean_persistence = if n_finite > 0 {
647            total_persistence / n_finite as f64
648        } else {
649            0.0
650        };
651        let median_persistence = if n_finite > 0 {
652            if n_finite % 2 == 1 {
653                lifetimes[n_finite / 2]
654            } else {
655                (lifetimes[n_finite / 2 - 1] + lifetimes[n_finite / 2]) / 2.0
656            }
657        } else {
658            0.0
659        };
660        let max_persistence = lifetimes.last().copied().unwrap_or(0.0);
661        let bottleneck_to_empty = max_persistence / 2.0;
662        let n_essential = diagram.essential.len();
663
664        Self {
665            total_persistence,
666            mean_persistence,
667            median_persistence,
668            max_persistence,
669            n_finite,
670            n_essential,
671            bottleneck_to_empty,
672        }
673    }
674
675    /// Approximate bottleneck distance between two diagrams.
676    ///
677    /// Uses greedy matching by L∞ distance in birth-persistence coordinates.
678    pub fn bottleneck_distance(a: &PersistenceDiagram, b: &PersistenceDiagram) -> f64 {
679        let pts_a: Vec<(f64, f64)> = a
680            .pairs
681            .iter()
682            .filter(|&&(b_, d, _)| d.is_finite() && d > b_)
683            .map(|&(b_, d, _)| (b_, d - b_))
684            .collect();
685        let pts_b: Vec<(f64, f64)> = b
686            .pairs
687            .iter()
688            .filter(|&&(b_, d, _)| d.is_finite() && d > b_)
689            .map(|&(b_, d, _)| (b_, d - b_))
690            .collect();
691
692        let linf = |pa: (f64, f64), pb: (f64, f64)| -> f64 {
693            (pa.0 - pb.0).abs().max((pa.1 - pb.1).abs())
694        };
695        let diag_dist = |p: (f64, f64)| -> f64 { p.1 / 2.0 };
696
697        let mut max_dist = 0.0_f64;
698        let mut used = vec![false; pts_b.len()];
699
700        for &pa in &pts_a {
701            let best = pts_b
702                .iter()
703                .enumerate()
704                .filter(|(idx, _)| !used[*idx])
705                .map(|(idx, &pb)| (idx, linf(pa, pb)))
706                .min_by(|x, y| x.1.partial_cmp(&y.1).unwrap_or(std::cmp::Ordering::Equal));
707            match best {
708                Some((idx, d)) if d < diag_dist(pa) => {
709                    used[idx] = true;
710                    max_dist = max_dist.max(d);
711                }
712                _ => {
713                    max_dist = max_dist.max(diag_dist(pa));
714                }
715            }
716        }
717        for (idx, &pb) in pts_b.iter().enumerate() {
718            if !used[idx] {
719                max_dist = max_dist.max(diag_dist(pb));
720            }
721        }
722        max_dist
723    }
724}
725
726// ─────────────────────────────────────────────────────────────────────────────
727// Tests
728// ─────────────────────────────────────────────────────────────────────────────
729
730#[cfg(test)]
731mod tests {
732    use super::*;
733
734    // ── Simplex ───────────────────────────────────────────────────────────────
735
736    #[test]
737    fn test_simplex_vertex() {
738        let s = Simplex::new(vec![3], 0.5);
739        assert_eq!(s.dimension, 0);
740        assert_eq!(s.vertices, vec![3]);
741        assert!((s.filtration - 0.5_f64).abs() < 1e-12);
742    }
743
744    #[test]
745    fn test_simplex_edge_dimension() {
746        let s = Simplex::new(vec![0, 1], 1.0);
747        assert_eq!(s.dimension, 1);
748    }
749
750    #[test]
751    fn test_simplex_triangle_dimension() {
752        let s = Simplex::new(vec![0, 1, 2], 2.0);
753        assert_eq!(s.dimension, 2);
754    }
755
756    #[test]
757    fn test_simplex_vertices_sorted() {
758        let s = Simplex::new(vec![3, 1, 0, 2], 0.0);
759        assert_eq!(s.vertices, vec![0, 1, 2, 3]);
760    }
761
762    #[test]
763    fn test_simplex_vertex_has_no_faces() {
764        let s = Simplex::new(vec![5], 0.0);
765        assert!(s.faces().is_empty());
766    }
767
768    #[test]
769    fn test_simplex_edge_has_two_faces() {
770        let s = Simplex::new(vec![0, 1], 1.0);
771        let faces = s.faces();
772        assert_eq!(faces.len(), 2);
773        assert!(faces.iter().all(|f| f.dimension == 0));
774    }
775
776    #[test]
777    fn test_simplex_triangle_has_three_faces() {
778        let s = Simplex::new(vec![0, 1, 2], 1.5);
779        let faces = s.faces();
780        assert_eq!(faces.len(), 3);
781        assert!(faces.iter().all(|f| f.dimension == 1));
782    }
783
784    #[test]
785    fn test_simplex_contains() {
786        let edge = Simplex::new(vec![0, 1], 1.0);
787        let v0 = Simplex::new(vec![0], 0.0);
788        let v2 = Simplex::new(vec![2], 0.0);
789        assert!(edge.contains(&v0));
790        assert!(!edge.contains(&v2));
791    }
792
793    // ── SimplicialComplex ─────────────────────────────────────────────────────
794
795    #[test]
796    fn test_simplicial_complex_empty() {
797        let sc = SimplicialComplex::new();
798        assert_eq!(sc.count(0), 0);
799        assert_eq!(sc.euler_characteristic(), 0);
800    }
801
802    #[test]
803    fn test_simplicial_complex_add_triangle() {
804        let mut sc = SimplicialComplex::new();
805        sc.add_simplex(Simplex::new(vec![0, 1, 2], 1.0));
806        assert_eq!(sc.count(0), 3); // 3 vertices
807        assert_eq!(sc.count(1), 3); // 3 edges
808        assert_eq!(sc.count(2), 1); // 1 triangle
809    }
810
811    #[test]
812    fn test_simplicial_complex_euler_triangle() {
813        let mut sc = SimplicialComplex::new();
814        sc.add_simplex(Simplex::new(vec![0, 1, 2], 1.0));
815        // χ = V - E + F = 3 - 3 + 1 = 1
816        assert_eq!(sc.euler_characteristic(), 1);
817    }
818
819    #[test]
820    fn test_simplicial_complex_no_duplicates() {
821        let mut sc = SimplicialComplex::new();
822        sc.add_simplex(Simplex::new(vec![0, 1], 1.0));
823        sc.add_simplex(Simplex::new(vec![0, 1], 2.0)); // duplicate
824        assert_eq!(sc.count(1), 1);
825    }
826
827    #[test]
828    fn test_simplicial_complex_boundary_matrix_edge() {
829        let mut sc = SimplicialComplex::new();
830        sc.add_simplex(Simplex::new(vec![0, 1], 1.0));
831        sc.add_simplex(Simplex::new(vec![1, 2], 1.0));
832        let mat = sc.boundary_matrix(1);
833        assert_eq!(mat.len(), 2); // 2 edges
834    }
835
836    #[test]
837    fn test_simplicial_complex_default() {
838        let sc = SimplicialComplex::default();
839        assert_eq!(sc.count(0), 0);
840    }
841
842    // ── FilteredComplex ───────────────────────────────────────────────────────
843
844    #[test]
845    fn test_filtered_complex_ordering() {
846        let simplices = vec![
847            Simplex::new(vec![0, 1], 2.0),
848            Simplex::new(vec![0], 0.0),
849            Simplex::new(vec![1], 0.0),
850        ];
851        let fc = FilteredComplex::new(simplices);
852        // Vertices first (dim 0, filt 0) then edge (dim 1, filt 2)
853        assert_eq!(fc.ordered_simplices[0].dimension, 0);
854        assert_eq!(fc.ordered_simplices[2].dimension, 1);
855    }
856
857    #[test]
858    fn test_filtered_complex_len() {
859        let simplices = vec![
860            Simplex::new(vec![0], 0.0),
861            Simplex::new(vec![1], 0.0),
862            Simplex::new(vec![0, 1], 1.0),
863        ];
864        let fc = FilteredComplex::new(simplices);
865        assert_eq!(fc.len(), 3);
866        assert!(!fc.is_empty());
867    }
868
869    #[test]
870    fn test_filtered_complex_empty() {
871        let fc = FilteredComplex::new(vec![]);
872        assert_eq!(fc.len(), 0);
873        assert!(fc.is_empty());
874    }
875
876    #[test]
877    fn test_filtered_complex_boundary_column_vertex() {
878        let simplices = vec![Simplex::new(vec![0], 0.0)];
879        let fc = FilteredComplex::new(simplices);
880        assert!(fc.boundary_column(0).is_empty());
881    }
882
883    #[test]
884    fn test_filtered_complex_reduce_two_points() {
885        let simplices = vec![
886            Simplex::new(vec![0], 0.0),
887            Simplex::new(vec![1], 0.0),
888            Simplex::new(vec![0, 1], 1.0),
889        ];
890        let fc = FilteredComplex::new(simplices);
891        let pairs = fc.reduce();
892        // The edge should pair with one of the vertices
893        assert!(!pairs.is_empty());
894    }
895
896    // ── PersistenceDiagram ────────────────────────────────────────────────────
897
898    #[test]
899    fn test_persistence_diagram_two_points() {
900        let simplices = vec![
901            Simplex::new(vec![0], 0.0),
902            Simplex::new(vec![1], 0.0),
903            Simplex::new(vec![0, 1], 1.5),
904        ];
905        let fc = FilteredComplex::new(simplices);
906        let diag = PersistenceDiagram::from_filtered_complex(&fc);
907        // H0: one component dies when edge is added
908        assert!(!diag.pairs.is_empty() || !diag.essential.is_empty());
909    }
910
911    #[test]
912    fn test_persistence_diagram_betti_numbers() {
913        let simplices = vec![
914            Simplex::new(vec![0], 0.0),
915            Simplex::new(vec![1], 0.0),
916            Simplex::new(vec![0, 1], 1.0),
917        ];
918        let fc = FilteredComplex::new(simplices);
919        let diag = PersistenceDiagram::from_filtered_complex(&fc);
920        // At t=0.5: 2 components, at t=2.0: 1 component
921        let b0 = diag.betti_numbers(2.0, 1);
922        assert_eq!(b0[0], 1);
923    }
924
925    #[test]
926    fn test_persistence_diagram_total_persistence() {
927        let simplices = vec![
928            Simplex::new(vec![0], 0.0),
929            Simplex::new(vec![1], 0.0),
930            Simplex::new(vec![0, 1], 1.0),
931        ];
932        let fc = FilteredComplex::new(simplices);
933        let diag = PersistenceDiagram::from_filtered_complex(&fc);
934        assert!(diag.total_persistence() >= 0.0);
935    }
936
937    // ── RipsFiltration ────────────────────────────────────────────────────────
938
939    #[test]
940    fn test_rips_from_point_cloud_distances() {
941        let pts = [[0.0_f64, 0.0_f64], [1.0, 0.0], [0.0, 1.0]];
942        let rips = RipsFiltration::from_point_cloud(&pts, 2.0, 2);
943        assert_eq!(rips.n_points, 3);
944        assert!((rips.distances[0][1] - 1.0_f64).abs() < 1e-12);
945    }
946
947    #[test]
948    fn test_rips_build_complex_vertices() {
949        let pts = [[0.0_f64, 0.0_f64], [1.0, 0.0], [2.0, 0.0]];
950        let rips = RipsFiltration::from_point_cloud(&pts, 5.0, 1);
951        let fc = rips.build_complex();
952        let n_vertices = fc
953            .ordered_simplices
954            .iter()
955            .filter(|s| s.dimension == 0)
956            .count();
957        assert_eq!(n_vertices, 3);
958    }
959
960    #[test]
961    fn test_rips_build_complex_edges_within_radius() {
962        let pts = [[0.0_f64, 0.0_f64], [1.0, 0.0], [10.0, 0.0]];
963        let rips = RipsFiltration::from_point_cloud(&pts, 1.5, 1);
964        let fc = rips.build_complex();
965        let n_edges = fc
966            .ordered_simplices
967            .iter()
968            .filter(|s| s.dimension == 1)
969            .count();
970        // Only edge [0,1] has dist 1.0 ≤ 1.5
971        assert_eq!(n_edges, 1);
972    }
973
974    #[test]
975    fn test_rips_from_distance_matrix() {
976        let d = vec![
977            vec![0.0_f64, 1.0, 2.0],
978            vec![1.0, 0.0, 1.5],
979            vec![2.0, 1.5, 0.0],
980        ];
981        let rips = RipsFiltration::from_distance_matrix(d, 2.0, 2);
982        assert_eq!(rips.n_points, 3);
983    }
984
985    #[test]
986    fn test_rips_persistence_diagram_single_point() {
987        let pts = [[0.0_f64, 0.0_f64]];
988        let rips = RipsFiltration::from_point_cloud(&pts, 1.0, 0);
989        let diag = rips.persistence_diagram();
990        // Single point: one essential H0 class
991        assert!(!diag.essential.is_empty());
992    }
993
994    #[test]
995    fn test_rips_triangle_has_triangles() {
996        let pts = [[0.0_f64, 0.0_f64], [1.0, 0.0], [0.5, 1.0]];
997        let rips = RipsFiltration::from_point_cloud(&pts, 3.0, 2);
998        let fc = rips.build_complex();
999        let n_tri = fc
1000            .ordered_simplices
1001            .iter()
1002            .filter(|s| s.dimension == 2)
1003            .count();
1004        assert_eq!(n_tri, 1);
1005    }
1006
1007    // ── PersistenceImage ──────────────────────────────────────────────────────
1008
1009    #[test]
1010    fn test_persistence_image_shape() {
1011        let diag = PersistenceDiagram {
1012            pairs: vec![(0.0, 1.0, 0), (0.5, 2.0, 0)],
1013            essential: vec![],
1014        };
1015        let img = PersistenceImage::from_diagram(&diag, 10, 10, 0.1);
1016        assert_eq!(img.pixels.len(), 10);
1017        assert_eq!(img.pixels[0].len(), 10);
1018    }
1019
1020    #[test]
1021    fn test_persistence_image_non_negative() {
1022        let diag = PersistenceDiagram {
1023            pairs: vec![(0.0, 1.0, 0), (0.5, 2.0, 1)],
1024            essential: vec![],
1025        };
1026        let img = PersistenceImage::from_diagram(&diag, 8, 8, 0.2);
1027        for row in &img.pixels {
1028            for &v in row {
1029                assert!(v >= 0.0);
1030            }
1031        }
1032    }
1033
1034    #[test]
1035    fn test_persistence_image_empty_diagram() {
1036        let diag = PersistenceDiagram {
1037            pairs: vec![],
1038            essential: vec![],
1039        };
1040        let img = PersistenceImage::from_diagram(&diag, 5, 5, 0.1);
1041        let total: f64 = img.to_vector().iter().sum();
1042        assert!(total.abs() < 1e-12);
1043    }
1044
1045    #[test]
1046    fn test_persistence_image_to_vector_length() {
1047        let diag = PersistenceDiagram {
1048            pairs: vec![(0.0, 1.0, 0)],
1049            essential: vec![],
1050        };
1051        let img = PersistenceImage::from_diagram(&diag, 6, 7, 0.1);
1052        assert_eq!(img.to_vector().len(), 42);
1053    }
1054
1055    #[test]
1056    fn test_persistence_image_max_value_positive() {
1057        let diag = PersistenceDiagram {
1058            pairs: vec![(0.0, 1.0, 0)],
1059            essential: vec![],
1060        };
1061        let img = PersistenceImage::from_diagram(&diag, 5, 5, 0.5);
1062        assert!(img.max_value() > 0.0);
1063    }
1064
1065    // ── BarcodeStatistics ─────────────────────────────────────────────────────
1066
1067    #[test]
1068    fn test_barcode_stats_empty() {
1069        let diag = PersistenceDiagram {
1070            pairs: vec![],
1071            essential: vec![],
1072        };
1073        let stats = BarcodeStatistics::from_diagram(&diag);
1074        assert_eq!(stats.n_finite, 0);
1075        assert!((stats.total_persistence).abs() < 1e-12);
1076    }
1077
1078    #[test]
1079    fn test_barcode_stats_single_bar() {
1080        let diag = PersistenceDiagram {
1081            pairs: vec![(0.0, 2.0, 0)],
1082            essential: vec![],
1083        };
1084        let stats = BarcodeStatistics::from_diagram(&diag);
1085        assert_eq!(stats.n_finite, 1);
1086        assert!((stats.total_persistence - 2.0_f64).abs() < 1e-12);
1087        assert!((stats.mean_persistence - 2.0_f64).abs() < 1e-12);
1088        assert!((stats.max_persistence - 2.0_f64).abs() < 1e-12);
1089    }
1090
1091    #[test]
1092    fn test_barcode_stats_median_odd() {
1093        let diag = PersistenceDiagram {
1094            pairs: vec![(0.0, 1.0, 0), (0.0, 2.0, 0), (0.0, 3.0, 0)],
1095            essential: vec![],
1096        };
1097        let stats = BarcodeStatistics::from_diagram(&diag);
1098        assert!((stats.median_persistence - 2.0_f64).abs() < 1e-12);
1099    }
1100
1101    #[test]
1102    fn test_barcode_stats_median_even() {
1103        let diag = PersistenceDiagram {
1104            pairs: vec![(0.0, 1.0, 0), (0.0, 3.0, 0)],
1105            essential: vec![],
1106        };
1107        let stats = BarcodeStatistics::from_diagram(&diag);
1108        assert!((stats.median_persistence - 2.0_f64).abs() < 1e-12);
1109    }
1110
1111    #[test]
1112    fn test_barcode_stats_essential_counted() {
1113        let diag = PersistenceDiagram {
1114            pairs: vec![],
1115            essential: vec![(0.0, 0), (0.0, 1)],
1116        };
1117        let stats = BarcodeStatistics::from_diagram(&diag);
1118        assert_eq!(stats.n_essential, 2);
1119    }
1120
1121    #[test]
1122    fn test_barcode_bottleneck_identical() {
1123        let diag = PersistenceDiagram {
1124            pairs: vec![(0.0, 1.0, 0), (1.0, 3.0, 1)],
1125            essential: vec![],
1126        };
1127        let d = BarcodeStatistics::bottleneck_distance(&diag, &diag);
1128        assert!(d < 1e-12, "identical diagrams have bottleneck distance 0");
1129    }
1130
1131    #[test]
1132    fn test_barcode_bottleneck_nonneg() {
1133        let a = PersistenceDiagram {
1134            pairs: vec![(0.0, 1.0, 0)],
1135            essential: vec![],
1136        };
1137        let b = PersistenceDiagram {
1138            pairs: vec![(0.1, 1.2, 0)],
1139            essential: vec![],
1140        };
1141        assert!(BarcodeStatistics::bottleneck_distance(&a, &b) >= 0.0);
1142    }
1143
1144    #[test]
1145    fn test_barcode_bottleneck_empty_vs_nonempty() {
1146        let empty = PersistenceDiagram {
1147            pairs: vec![],
1148            essential: vec![],
1149        };
1150        let nonempty = PersistenceDiagram {
1151            pairs: vec![(0.0, 2.0, 0)],
1152            essential: vec![],
1153        };
1154        let d = BarcodeStatistics::bottleneck_distance(&empty, &nonempty);
1155        // Should equal diag_dist of (0, 2) = 1.0
1156        assert!((d - 1.0_f64).abs() < 1e-12);
1157    }
1158
1159    #[test]
1160    fn test_xor_sorted_vecs_cancel() {
1161        let a = vec![0, 1, 2];
1162        let b = vec![1, 2, 3];
1163        let result = xor_sorted_vecs(a, b);
1164        assert_eq!(result, vec![0, 3]);
1165    }
1166
1167    #[test]
1168    fn test_xor_sorted_vecs_empty() {
1169        let result = xor_sorted_vecs(vec![], vec![1, 2]);
1170        assert_eq!(result, vec![1, 2]);
1171    }
1172}