use crate::error::{Result, TransformError};
use crate::tda::{PersistenceDiagram, VietorisRips};
use scirs2_core::ndarray::Array2;
/// A Vietoris–Rips simplicial complex over a point cloud at one fixed scale.
///
/// `VietorisRipsComplex::new` fills this structure with every vertex, every
/// edge whose endpoints are within `epsilon` of each other, and every triangle
/// whose three edges all satisfy that bound. All fields are public, so code
/// that mutates them is responsible for keeping them consistent.
#[derive(Debug, Clone)]
pub struct VietorisRipsComplex {
/// Copy of the input points; simplices refer to points by index into this list.
pub points: Vec<Vec<f64>>,
/// Distance threshold: two points are joined by an edge when their
/// Euclidean distance is `<= epsilon`.
pub epsilon: f64,
/// Simplices as ascending vertex-index lists (length 1 = vertex,
/// 2 = edge, 3 = triangle), ordered by length and then lexicographically.
pub simplices: Vec<Vec<usize>>,
}
impl VietorisRipsComplex {
    /// Builds the Vietoris–Rips complex of `points` at scale `epsilon`,
    /// containing simplices up to dimension 2 (vertices, edges, triangles).
    ///
    /// Distances are Euclidean. The coordinate count is taken from the first
    /// point; the distance helper tolerates shorter points by comparing only
    /// the coordinates both points have.
    ///
    /// An empty `points` slice yields an empty complex without validating
    /// `epsilon`.
    ///
    /// # Errors
    ///
    /// Returns `TransformError::InvalidInput` when `points` is non-empty and
    /// `epsilon` is negative or NaN.
    pub fn new(points: &[Vec<f64>], epsilon: f64) -> Result<Self> {
        if points.is_empty() {
            return Ok(Self {
                points: Vec::new(),
                epsilon,
                simplices: Vec::new(),
            });
        }
        // A NaN threshold compares false against everything, which would
        // silently produce a complex of isolated vertices; reject it up front.
        if epsilon.is_nan() || epsilon < 0.0 {
            return Err(TransformError::InvalidInput(
                "epsilon must be non-negative".to_string(),
            ));
        }

        let n = points.len();
        let dist = pairwise_distances(points, points[0].len());
        // Single source of truth for "is there an edge": a NaN distance is
        // never `<= epsilon`, so it consistently means "not connected" for
        // both the edge pass and the triangle pass.
        let linked = |a: usize, b: usize| dist[a][b] <= epsilon;

        // 0-simplices.
        let mut simplices: Vec<Vec<usize>> = (0..n).map(|v| vec![v]).collect();

        // 1-simplices.
        for i in 0..n {
            for j in (i + 1)..n {
                if linked(i, j) {
                    simplices.push(vec![i, j]);
                }
            }
        }

        // 2-simplices: all three edges must be present, so every triangle's
        // faces are guaranteed to be in the complex.
        for i in 0..n {
            for j in (i + 1)..n {
                if !linked(i, j) {
                    continue;
                }
                for k in (j + 1)..n {
                    if linked(i, k) && linked(j, k) {
                        simplices.push(vec![i, j, k]);
                    }
                }
            }
        }

        // Canonical order: by simplex size, then lexicographically by vertices.
        simplices.sort_by(|a, b| a.len().cmp(&b.len()).then_with(|| a.cmp(b)));

        Ok(Self {
            points: points.to_vec(),
            epsilon,
            simplices,
        })
    }

    /// Returns the number of simplices of dimension `dim`
    /// (i.e. with `dim + 1` vertices).
    pub fn n_simplices(&self, dim: usize) -> usize {
        self.simplices.iter().filter(|s| s.len() == dim + 1).count()
    }

    /// Returns the Euler characteristic: the alternating sum of simplex
    /// counts, `+1` for each even-dimensional simplex and `-1` for each
    /// odd-dimensional one.
    pub fn euler_characteristic(&self) -> i64 {
        self.simplices
            .iter()
            // Odd vertex count <=> even dimension.
            .map(|s| if s.len() % 2 == 1 { 1_i64 } else { -1_i64 })
            .sum()
    }

    /// Returns `true` when the complex contains the edge between vertices
    /// `u` and `v` (argument order does not matter). A vertex is never
    /// reported as connected to itself.
    pub fn are_connected(&self, u: usize, v: usize) -> bool {
        // Edges are stored with ascending vertex indices.
        let edge = if u < v { [u, v] } else { [v, u] };
        self.simplices
            .iter()
            .any(|s| s.as_slice() == edge.as_slice())
    }

    /// Returns every edge of the complex as an `(i, j)` pair, in the order
    /// the edges appear in `simplices`.
    pub fn edges(&self) -> Vec<(usize, usize)> {
        self.simplices
            .iter()
            .filter(|s| s.len() == 2)
            .map(|s| (s[0], s[1]))
            .collect()
    }
}
/// Computes Vietoris–Rips persistence diagrams from a pairwise distance matrix.
///
/// The filtration contains every vertex at scale `0`, and every clique of the
/// threshold graph (`distance <= max_epsilon`) with up to `max_dim + 2`
/// vertices, entering at the largest pairwise distance among its vertices.
/// Building simplices one dimension above `max_dim` is what allows classes of
/// dimension `max_dim` to die. The boundary matrix is then reduced over Z/2
/// with the standard column algorithm.
///
/// # Arguments
///
/// * `distance_matrix` - square matrix of pairwise distances. Only entries
///   above the diagonal (`[i][j]` with `i < j`) are read. A NaN entry is
///   treated as "no edge". Entries are expected to be non-negative; a
///   negative entry is clamped to `0` so an edge never precedes its endpoints.
/// * `max_dim` - highest homology dimension to report. The number of
///   simplices grows combinatorially with this value.
/// * `max_epsilon` - largest scale included in the filtration.
///
/// # Returns
///
/// A vector of `max_dim + 1` diagrams; entry `d` holds the dimension-`d`
/// features. Pairs whose birth and death differ by at most `1e-15` are
/// dropped, and unpaired simplices are reported with an infinite death.
/// An empty matrix yields `max_dim + 1` empty diagrams.
///
/// # Errors
///
/// Returns `TransformError::InvalidInput` when the matrix is not square or
/// when `max_epsilon` is negative or NaN.
pub fn compute_persistence(
    distance_matrix: &[Vec<f64>],
    max_dim: usize,
    max_epsilon: f64,
) -> Result<Vec<PersistenceDiagram>> {
    /// A simplex (ascending vertex list) and the scale at which it appears.
    struct FiltSimplex {
        vertices: Vec<usize>,
        filtration: f64,
    }

    let n = distance_matrix.len();
    if n == 0 {
        return Ok((0..=max_dim).map(|d| PersistenceDiagram::new(d)).collect());
    }
    if distance_matrix.iter().any(|row| row.len() != n) {
        return Err(TransformError::InvalidInput(
            "distance_matrix must be square".to_string(),
        ));
    }
    // NaN compares false against everything and would silently remove every
    // edge, so it is rejected together with negative thresholds.
    if max_epsilon.is_nan() || max_epsilon < 0.0 {
        return Err(TransformError::InvalidInput(
            "max_epsilon must be non-negative".to_string(),
        ));
    }

    // --- Build the filtered complex by clique expansion, one dimension at a
    // time. Each simplex is generated exactly once, from its lexicographic
    // prefix, so every dimension is produced in lexicographic order.
    let mut simplices: Vec<FiltSimplex> = (0..n)
        .map(|v| FiltSimplex {
            vertices: vec![v],
            filtration: 0.0,
        })
        .collect();
    let mut level_start = 0;
    for _ in 0..max_dim.saturating_add(1) {
        let mut next_level: Vec<FiltSimplex> = Vec::new();
        for face in &simplices[level_start..] {
            // Only extend with vertices larger than the current last vertex.
            let first_candidate = face.vertices.last().map_or(n, |&v| v + 1);
            for w in first_candidate..n {
                // `w` extends the clique only if it is within range of every
                // vertex; the new filtration value is the running maximum.
                // A NaN distance fails `<=` and therefore blocks the simplex,
                // which keeps the complex closed under taking faces.
                let extended = face.vertices.iter().try_fold(face.filtration, |acc, &v| {
                    let d = distance_matrix[v][w];
                    if d <= max_epsilon {
                        Some(acc.max(d))
                    } else {
                        None
                    }
                });
                if let Some(filtration) = extended {
                    let mut vertices = Vec::with_capacity(face.vertices.len() + 1);
                    vertices.extend_from_slice(&face.vertices);
                    vertices.push(w);
                    next_level.push(FiltSimplex {
                        vertices,
                        filtration,
                    });
                }
            }
        }
        if next_level.is_empty() {
            // No cliques of this size, hence none of any larger size either.
            break;
        }
        level_start = simplices.len();
        simplices.append(&mut next_level);
    }

    // --- Filtration order: by scale, then by dimension, so every face comes
    // before its cofaces. The sort is stable, keeping ties in insertion order.
    simplices.sort_by(|a, b| {
        a.filtration
            .partial_cmp(&b.filtration)
            .unwrap_or(std::cmp::Ordering::Equal)
            .then_with(|| a.vertices.len().cmp(&b.vertices.len()))
    });
    let total = simplices.len();

    // Vertex list -> position in the filtration order.
    let simplex_idx: std::collections::HashMap<&[usize], usize> = simplices
        .iter()
        .enumerate()
        .map(|(i, s)| (s.vertices.as_slice(), i))
        .collect();

    // --- Sparse Z/2 boundary matrix: column `j` lists, in ascending order,
    // the filtration indices of the codimension-1 faces of simplex `j`.
    let mut boundary: Vec<Vec<usize>> = Vec::with_capacity(total);
    let mut face: Vec<usize> = Vec::new();
    for simplex in &simplices {
        let size = simplex.vertices.len();
        let mut column = Vec::new();
        if size > 1 {
            column.reserve(size);
            for omit in 0..size {
                face.clear();
                face.extend(
                    simplex
                        .vertices
                        .iter()
                        .enumerate()
                        .filter(|&(pos, _)| pos != omit)
                        .map(|(_, &v)| v),
                );
                if let Some(&row) = simplex_idx.get(face.as_slice()) {
                    column.push(row);
                }
            }
            column.sort_unstable();
        }
        boundary.push(column);
    }

    // --- Standard column reduction. `low[j]` is the pivot row of column `j`
    // once reduced; `pivot_col[r]` is the column that owns pivot row `r`.
    let mut low: Vec<Option<usize>> = vec![None; total];
    let mut pivot_col: Vec<Option<usize>> = vec![None; total];
    for j in 0..total {
        while let Some(&r) = boundary[j].last() {
            match pivot_col[r] {
                Some(k) => {
                    // The owning column was reduced earlier, so `k < j`;
                    // splitting lets us read it while mutating column `j`
                    // without cloning.
                    let (earlier, rest) = boundary.split_at_mut(j);
                    sym_diff_inplace(&mut rest[0], &earlier[k]);
                }
                None => {
                    low[j] = Some(r);
                    pivot_col[r] = Some(j);
                    break;
                }
            }
        }
    }

    // --- Read off the persistence pairs.
    let mut diagrams: Vec<PersistenceDiagram> =
        (0..=max_dim).map(|d| PersistenceDiagram::new(d)).collect();
    let mut paired: Vec<bool> = vec![false; total];
    for (j, &pivot) in low.iter().enumerate() {
        if let Some(r) = pivot {
            // Simplex `r` creates the class, simplex `j` destroys it.
            let birth = simplices[r].filtration;
            let death = simplices[j].filtration;
            let feature_dim = simplices[r].vertices.len() - 1;
            if feature_dim <= max_dim && (death - birth).abs() > 1e-15 {
                diagrams[feature_dim].add_point(birth, death, feature_dim);
            }
            paired[r] = true;
            paired[j] = true;
        }
    }
    // Unpaired simplices create classes that never die within the filtration.
    // Those of dimension `max_dim + 1` are skipped: they are only present to
    // kill lower-dimensional classes.
    for (i, simplex) in simplices.iter().enumerate() {
        if !paired[i] {
            let dim = simplex.vertices.len() - 1;
            if dim <= max_dim {
                diagrams[dim].add_point(simplex.filtration, f64::INFINITY, dim);
            }
        }
    }
    Ok(diagrams)
}
/// Evaluates the first `n_layers` persistence-landscape functions of `dgm`
/// at the sample positions `x`.
///
/// Every diagram point with a finite death contributes a tent
/// `max(0, min(t - birth, death - t))`. At each sample, layer `k` holds the
/// `k`-th largest tent value, or `0.0` when fewer than `k + 1` tents exist.
///
/// Returns `n_layers` rows of `x.len()` values each.
pub fn persistence_landscape_fn(
    dgm: &PersistenceDiagram,
    n_layers: usize,
    x: &[f64],
) -> Vec<Vec<f64>> {
    // Start from all zeros; only the entries backed by a tent are overwritten.
    let mut layers = vec![vec![0.0_f64; x.len()]; n_layers];
    if n_layers == 0 || x.is_empty() {
        return layers;
    }

    // Points that never die have no tent and are ignored.
    let bars: Vec<(f64, f64)> = dgm
        .points
        .iter()
        .filter_map(|p| {
            if p.death.is_finite() {
                Some((p.birth, p.death))
            } else {
                None
            }
        })
        .collect();

    let mut heights: Vec<f64> = Vec::with_capacity(bars.len());
    for (col, &t) in x.iter().enumerate() {
        heights.clear();
        heights.extend(bars.iter().map(|&(birth, death)| {
            let rise = (t - birth).min(death - t);
            if rise < 0.0 {
                0.0
            } else {
                rise
            }
        }));
        // Largest tent first.
        heights.sort_by(|lhs, rhs| rhs.partial_cmp(lhs).unwrap_or(std::cmp::Ordering::Equal));
        for (layer, &height) in layers.iter_mut().zip(heights.iter()) {
            layer[col] = height;
        }
    }
    layers
}
/// Rasterises `dgm` into a persistence image on a `grid = (rows, cols)` raster.
///
/// Each diagram point with finite death and `death > birth` is mapped to
/// `(birth, persistence)` and contributes a 2-D isotropic Gaussian of
/// standard deviation `bandwidth` (floored at `1e-10`), weighted by its
/// persistence. Rows sample persistence over `[0, max_persistence]` and
/// columns sample birth over `[0, max_birth]`; an axis with a single cell is
/// sampled at the midpoint of its range.
///
/// Returns `rows` vectors of `cols` pixels, or an empty vector when either
/// grid dimension is zero.
pub fn persistence_image_fn(
    dgm: &PersistenceDiagram,
    bandwidth: f64,
    grid: (usize, usize),
    max_birth: f64,
    max_persistence: f64,
) -> Vec<Vec<f64>> {
    let (n_rows, n_cols) = grid;
    if n_rows == 0 || n_cols == 0 {
        return Vec::new();
    }

    let sigma = bandwidth.max(1e-10);
    let denom = 2.0 * sigma * sigma;
    let scale = 1.0 / (std::f64::consts::TAU * sigma * sigma);

    // Evenly spaced pixel centres covering `[0, extent]`.
    let axis_centers = |count: usize, extent: f64| -> Vec<f64> {
        if count == 1 {
            vec![extent * 0.5]
        } else {
            (0..count)
                .map(|idx| extent * idx as f64 / (count - 1) as f64)
                .collect()
        }
    };
    let persistence_axis = axis_centers(n_rows, max_persistence);
    let birth_axis = axis_centers(n_cols, max_birth);

    // (birth, persistence) pairs; persistence doubles as the weight.
    let features: Vec<(f64, f64)> = dgm
        .points
        .iter()
        .filter(|p| p.death.is_finite() && p.death > p.birth)
        .map(|p| (p.birth, p.death - p.birth))
        .collect();

    persistence_axis
        .iter()
        .map(|&pers_center| {
            birth_axis
                .iter()
                .map(|&birth_center| {
                    features.iter().fold(0.0_f64, |acc, &(birth, pers)| {
                        let db = birth_center - birth;
                        let dp = pers_center - pers;
                        let exponent = -(db * db + dp * dp) / denom;
                        acc + pers * scale * exponent.exp()
                    })
                })
                .collect()
        })
        .collect()
}
/// Returns the symmetric matrix of Euclidean distances between `points`.
///
/// At most `dim` leading coordinates are compared; when a point has fewer
/// coordinates, only the coordinates both points have are used. The diagonal
/// is zero.
fn pairwise_distances(points: &[Vec<f64>], dim: usize) -> Vec<Vec<f64>> {
    let count = points.len();
    let mut table = vec![vec![0.0_f64; count]; count];
    for (row, p) in points.iter().enumerate() {
        // Only the upper triangle is computed; each value is mirrored.
        for (col, q) in points.iter().enumerate().skip(row + 1) {
            let squared = p
                .iter()
                .zip(q.iter())
                .take(dim)
                .fold(0.0_f64, |acc, (a, b)| {
                    let delta = a - b;
                    acc + delta * delta
                });
            let length = squared.sqrt();
            table[row][col] = length;
            table[col][row] = length;
        }
    }
    table
}
/// Replaces `a` with the symmetric difference of `a` and `b`.
///
/// Both inputs are merged as ascending sequences: values present in exactly
/// one input are kept in order, values present in both cancel out (addition
/// of sparse columns over Z/2).
fn sym_diff_inplace(a: &mut Vec<usize>, b: &[usize]) {
    use std::cmp::Ordering;

    let mut merged = Vec::with_capacity(a.len() + b.len());
    let mut left: &[usize] = a.as_slice();
    let mut right: &[usize] = b;

    // Walk both sequences in lockstep, always consuming the smaller head.
    while let (Some((&x, left_rest)), Some((&y, right_rest))) =
        (left.split_first(), right.split_first())
    {
        match x.cmp(&y) {
            Ordering::Less => {
                merged.push(x);
                left = left_rest;
            }
            Ordering::Greater => {
                merged.push(y);
                right = right_rest;
            }
            Ordering::Equal => {
                // Shared element: cancels out.
                left = left_rest;
                right = right_rest;
            }
        }
    }

    // At most one of the two tails is non-empty.
    merged.extend_from_slice(left);
    merged.extend_from_slice(right);
    *a = merged;
}
// Unit tests for the Vietoris–Rips complex, persistence computation,
// landscape / image vectorisations and the sparse-column helper.
#[cfg(test)]
mod tests {
use super::*;
use crate::tda::PersistenceDiagram;
// Distance matrix for four points with side length 1.0 and diagonal 1.414
// (the unit-square corners, diagonal rounded to three decimals).
fn square_dist() -> Vec<Vec<f64>> {
vec![
vec![0.0, 1.0, 1.414, 1.0],
vec![1.0, 0.0, 1.0, 1.414],
vec![1.414, 1.0, 0.0, 1.0],
vec![1.0, 1.414, 1.0, 0.0],
]
}
// Corners of the unit square, listed around the perimeter.
fn square_points() -> Vec<Vec<f64>> {
vec![
vec![0.0, 0.0],
vec![1.0, 0.0],
vec![1.0, 1.0],
vec![0.0, 1.0],
]
}
// Every input point becomes a 0-simplex.
#[test]
fn test_vrc_vertices() {
let pts = square_points();
let vrc = VietorisRipsComplex::new(&pts, 1.5).expect("new");
assert_eq!(vrc.n_simplices(0), 4, "Should have 4 vertices");
}
// At eps = 1 only the four sides are edges; the diagonals are too long,
// so no triangle can form.
#[test]
fn test_vrc_edges_unit_square() {
let pts = square_points();
let vrc = VietorisRipsComplex::new(&pts, 1.0).expect("new");
assert_eq!(vrc.n_simplices(1), 4, "Unit square at eps=1 has 4 edges");
assert_eq!(vrc.n_simplices(2), 0, "No triangles at eps=1");
}
// At eps = 2 every pair is connected, giving the complete graph K4.
#[test]
fn test_vrc_complete_graph() {
let pts = square_points();
let vrc = VietorisRipsComplex::new(&pts, 2.0).expect("new");
assert_eq!(
vrc.n_simplices(1),
6,
"Complete graph on 4 vertices has 6 edges"
);
assert_eq!(vrc.n_simplices(2), 4, "4 triangles in K4");
}
// 4 vertices - 4 edges + 0 triangles = 0.
#[test]
fn test_vrc_euler_characteristic() {
let pts = square_points();
let vrc = VietorisRipsComplex::new(&pts, 1.0).expect("new");
assert_eq!(vrc.euler_characteristic(), 0);
}
// An empty point set is accepted and yields an empty complex.
#[test]
fn test_vrc_empty_input() {
let vrc = VietorisRipsComplex::new(&[], 1.0).expect("empty ok");
assert_eq!(vrc.n_simplices(0), 0);
assert_eq!(vrc.euler_characteristic(), 0);
}
// A negative threshold is rejected.
#[test]
fn test_vrc_negative_epsilon_error() {
let pts = square_points();
assert!(VietorisRipsComplex::new(&pts, -0.1).is_err());
}
// Sides of the square are connected at eps = 1, the diagonal is not.
#[test]
fn test_vrc_are_connected() {
let pts = square_points();
let vrc = VietorisRipsComplex::new(&pts, 1.0).expect("new");
assert!(vrc.are_connected(0, 1));
assert!(vrc.are_connected(1, 2));
assert!(!vrc.are_connected(0, 2));
}
// One diagram per dimension 0..=max_dim, and H0 is populated.
#[test]
fn test_compute_persistence_h0_square() {
let dist = square_dist();
let diagrams = compute_persistence(&dist, 1, 2.0).expect("persistence");
assert_eq!(diagrams.len(), 2);
let h0 = &diagrams[0];
assert!(!h0.is_empty(), "H0 should not be empty");
}
// An empty matrix still yields max_dim + 1 (empty) diagrams.
#[test]
fn test_compute_persistence_empty() {
let diagrams = compute_persistence(&[], 1, 1.0).expect("empty");
assert_eq!(diagrams.len(), 2); assert!(diagrams[0].is_empty());
assert!(diagrams[1].is_empty());
}
// Rows of differing length are rejected as a non-square matrix.
#[test]
fn test_compute_persistence_non_square_error() {
let dist = vec![vec![0.0, 1.0], vec![1.0, 0.0, 2.0]];
assert!(compute_persistence(&dist, 1, 2.0).is_err());
}
// Births are finite and non-negative; finite deaths never precede births.
#[test]
fn test_compute_persistence_returns_finite_pairs() {
let dist = square_dist();
let diagrams = compute_persistence(&dist, 1, 2.0).expect("persistence");
for dgm in &diagrams {
for pt in &dgm.points {
assert!(pt.birth.is_finite());
assert!(pt.birth >= 0.0);
if pt.death.is_finite() {
assert!(pt.death >= pt.birth);
}
}
}
}
// Output shape is n_layers rows by x.len() columns, even when more layers
// are requested than the diagram has points.
#[test]
fn test_landscape_fn_shape() {
let mut dgm = PersistenceDiagram::new(0);
dgm.add_point(0.0, 2.0, 0);
dgm.add_point(0.5, 1.5, 0);
let x: Vec<f64> = (0..20).map(|i| i as f64 * 0.1).collect();
let l = persistence_landscape_fn(&dgm, 3, &x);
assert_eq!(l.len(), 3);
assert_eq!(l[0].len(), 20);
}
// Landscape values are clamped at zero, including samples past the death.
#[test]
fn test_landscape_fn_non_negative() {
let mut dgm = PersistenceDiagram::new(0);
dgm.add_point(0.0, 1.0, 0);
let x: Vec<f64> = (0..10).map(|i| i as f64 * 0.15).collect();
let l = persistence_landscape_fn(&dgm, 2, &x);
for row in &l {
for &v in row {
assert!(v >= 0.0, "landscape must be non-negative, got {v}");
}
}
}
// A single (0, 1) point gives a tent peaking at 0.5 in the middle and
// vanishing at both endpoints.
#[test]
fn test_landscape_fn_tent_shape() {
let mut dgm = PersistenceDiagram::new(0);
dgm.add_point(0.0, 1.0, 0);
let x = vec![0.0, 0.25, 0.5, 0.75, 1.0];
let l = persistence_landscape_fn(&dgm, 1, &x);
assert!((l[0][2] - 0.5).abs() < 1e-10, "peak should be 0.5");
assert!(l[0][0] < 1e-10);
assert!(l[0][4] < 1e-10);
}
// With no points every layer is identically zero.
#[test]
fn test_landscape_fn_empty_diagram() {
let dgm = PersistenceDiagram::new(0);
let x = vec![0.0, 1.0, 2.0];
let l = persistence_landscape_fn(&dgm, 2, &x);
assert_eq!(l.len(), 2);
for row in &l {
assert!(row.iter().all(|&v| v == 0.0));
}
}
// The image has exactly the requested number of rows and columns.
#[test]
fn test_persistence_image_fn_shape() {
let mut dgm = PersistenceDiagram::new(0);
dgm.add_point(0.0, 1.0, 0);
dgm.add_point(0.2, 0.8, 0);
let img = persistence_image_fn(&dgm, 0.1, (5, 5), 1.0, 1.0);
assert_eq!(img.len(), 5);
assert_eq!(img[0].len(), 5);
}
// Pixels are sums of positively weighted Gaussians, hence non-negative.
#[test]
fn test_persistence_image_fn_non_negative() {
let mut dgm = PersistenceDiagram::new(0);
dgm.add_point(0.0, 1.0, 0);
let img = persistence_image_fn(&dgm, 0.1, (4, 4), 1.0, 1.0);
for row in &img {
for &v in row {
assert!(v >= 0.0, "image pixel must be non-negative, got {v}");
}
}
}
// A non-empty diagram must light up at least one pixel.
#[test]
fn test_persistence_image_fn_has_signal() {
let mut dgm = PersistenceDiagram::new(0);
dgm.add_point(0.0, 1.0, 0);
let img = persistence_image_fn(&dgm, 0.15, (6, 6), 1.5, 1.5);
let has_positive = img.iter().flat_map(|row| row.iter()).any(|&v| v > 0.0);
assert!(
has_positive,
"image should have nonzero pixels for a nonempty diagram"
);
}
// An empty diagram yields an all-zero image of the requested height.
#[test]
fn test_persistence_image_fn_empty_diagram() {
let dgm = PersistenceDiagram::new(0);
let img = persistence_image_fn(&dgm, 0.1, (4, 4), 1.0, 1.0);
assert_eq!(img.len(), 4);
for row in &img {
assert!(row.iter().all(|&v| v == 0.0));
}
}
// The shared element 3 cancels; the remaining elements are merged in order.
#[test]
fn test_sym_diff_inplace() {
let mut a = vec![1_usize, 3, 5];
let b = vec![2, 3, 4];
sym_diff_inplace(&mut a, &b);
assert_eq!(a, vec![1, 2, 4, 5]);
}
}