use exo_core::EntityId;
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
/// An unoriented simplex identified by its vertex list.
///
/// `Simplex::new` sorts the vertices (by the inner value of each `EntityId`) and removes
/// duplicates, so the derived `PartialEq`/`Hash` compare vertex sets regardless of the order
/// the vertices were supplied in. Assigning `vertices` directly, or deserializing, bypasses
/// that normalization.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct Simplex {
/// Vertices spanning the simplex; sorted and deduplicated when built via `Simplex::new`.
pub vertices: Vec<EntityId>,
}
impl Simplex {
pub fn new(mut vertices: Vec<EntityId>) -> Self {
vertices.sort_by_key(|v| v.0);
vertices.dedup();
Self { vertices }
}
pub fn dimension(&self) -> usize {
self.vertices.len().saturating_sub(1)
}
pub fn faces(&self) -> Vec<Simplex> {
if self.vertices.is_empty() {
return vec![];
}
let mut faces = Vec::new();
for i in 0..self.vertices.len() {
let mut face_vertices = self.vertices.clone();
face_vertices.remove(i);
if !face_vertices.is_empty() {
faces.push(Simplex::new(face_vertices));
}
}
faces
}
}
/// A finite simplicial complex whose vertices are `EntityId`s.
///
/// Simplices are bucketed by dimension. `add_simplex` inserts a simplex together with all of
/// its faces. A deserialized value is taken as-is and may not be closed under faces.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct SimplicialComplex {
// Dimension -> set of simplices of that dimension.
simplices: HashMap<usize, HashSet<Simplex>>,
// Largest dimension inserted so far via `add_simplex` (0 for an empty complex).
max_dimension: usize,
}
impl SimplicialComplex {
    /// Creates an empty complex.
    pub fn new() -> Self {
        Self {
            simplices: HashMap::new(),
            max_dimension: 0,
        }
    }

    /// Inserts the simplex spanned by `vertices` together with every one of its faces.
    ///
    /// Repeated vertices are collapsed (see [`Simplex::new`]); an empty slice is ignored.
    /// Each distinct face is processed once per call, so a k-simplex costs 2^(k+1) - 1
    /// insertions. Naive recursion over faces would cost roughly (k+1)!.
    pub fn add_simplex(&mut self, vertices: &[EntityId]) {
        if vertices.is_empty() {
            return;
        }
        let mut pending = vec![Simplex::new(vertices.to_vec())];
        // A face shared by several cofaces is generated several times; handle it only once.
        let mut seen: HashSet<Simplex> = HashSet::new();
        while let Some(simplex) = pending.pop() {
            if !seen.insert(simplex.clone()) {
                continue;
            }
            pending.extend(simplex.faces());
            let dim = simplex.dimension();
            self.max_dimension = self.max_dimension.max(dim);
            self.simplices.entry(dim).or_default().insert(simplex);
        }
    }

    /// Returns clones of all simplices of `dimension` (empty if there are none).
    ///
    /// The order is unspecified because simplices are stored in a `HashSet`.
    pub fn get_simplices(&self, dimension: usize) -> Vec<Simplex> {
        self.simplices
            .get(&dimension)
            .map(|set| set.iter().cloned().collect())
            .unwrap_or_default()
    }

    /// Number of simplices of `dimension` stored in the complex.
    pub fn count_simplices(&self, dimension: usize) -> usize {
        self.simplices
            .get(&dimension)
            .map(|set| set.len())
            .unwrap_or(0)
    }

    /// Returns the Betti number beta_`dimension`, computed with Z/2 coefficients.
    ///
    /// - beta_0 is the number of connected components (union-find over vertices and edges).
    /// - For `dimension >= 1` it is `n_k - rank(d_k) - rank(d_{k+1})`. Here `n_k` counts the
    ///   `k`-simplices and `d_k` is the boundary map from `k`-chains to `(k-1)`-chains.
    ///
    /// Dimensions without simplices yield `0`.
    pub fn betti_number(&self, dimension: usize) -> usize {
        if dimension == 0 {
            return self.count_connected_components();
        }
        let chains = self.count_simplices(dimension);
        if chains == 0 {
            return 0;
        }
        // For complexes closed under faces, d_k o d_{k+1} = 0 keeps this non-negative.
        // Saturate anyway, because a deserialized complex need not be closed.
        chains
            .saturating_sub(self.boundary_rank(dimension))
            .saturating_sub(self.boundary_rank(dimension + 1))
    }

    /// Rank over Z/2 of the boundary map from `dimension`-chains to `(dimension - 1)`-chains.
    ///
    /// Each `dimension`-simplex becomes a column listing the row indices of its faces.
    /// Columns are reduced by their largest index against previously found pivots. The
    /// number of surviving (nonzero) columns is the rank.
    fn boundary_rank(&self, dimension: usize) -> usize {
        if dimension == 0 {
            return 0;
        }
        let columns = match self.simplices.get(&dimension) {
            Some(set) => set,
            None => return 0,
        };
        let rows = match self.simplices.get(&(dimension - 1)) {
            Some(set) => set,
            None => return 0,
        };
        let row_index: HashMap<&Simplex, usize> =
            rows.iter().enumerate().map(|(i, s)| (s, i)).collect();
        // Largest row index of a reduced column -> that column.
        let mut pivots: HashMap<usize, Vec<usize>> = HashMap::new();
        for simplex in columns {
            // Faces missing from the complex (possible only after deserialization) are skipped.
            let mut column: Vec<usize> = simplex
                .faces()
                .iter()
                .filter_map(|face| row_index.get(face).copied())
                .collect();
            column.sort_unstable();
            while let Some(&low) = column.last() {
                match pivots.get(&low) {
                    Some(pivot) => column = Self::xor_sorted(&column, pivot),
                    None => {
                        pivots.insert(low, column);
                        break;
                    }
                }
            }
        }
        pivots.len()
    }

    /// Symmetric difference (addition over Z/2) of two ascending, duplicate-free index lists.
    fn xor_sorted(a: &[usize], b: &[usize]) -> Vec<usize> {
        let mut out = Vec::with_capacity(a.len() + b.len());
        let (mut i, mut j) = (0, 0);
        while i < a.len() && j < b.len() {
            match a[i].cmp(&b[j]) {
                std::cmp::Ordering::Less => {
                    out.push(a[i]);
                    i += 1;
                }
                std::cmp::Ordering::Greater => {
                    out.push(b[j]);
                    j += 1;
                }
                std::cmp::Ordering::Equal => {
                    i += 1;
                    j += 1;
                }
            }
        }
        out.extend_from_slice(&a[i..]);
        out.extend_from_slice(&b[j..]);
        out
    }

    /// Counts connected components using union-find over the 0- and 1-simplices.
    ///
    /// Returns 0 when the complex has no vertices.
    fn count_connected_components(&self) -> usize {
        let vertex_simplices = match self.simplices.get(&0) {
            Some(set) if !set.is_empty() => set,
            _ => return 0,
        };
        let mut parent: HashMap<EntityId, EntityId> = vertex_simplices
            .iter()
            .filter_map(|simplex| simplex.vertices.first().copied())
            .map(|v| (v, v))
            .collect();
        if let Some(edges) = self.simplices.get(&1) {
            for edge in edges {
                if let [a, b] = edge.vertices.as_slice() {
                    Self::union(&mut parent, *a, *b);
                }
            }
        }
        let members: Vec<EntityId> = parent.keys().copied().collect();
        let roots: HashSet<EntityId> = members
            .into_iter()
            .map(|v| Self::find(&mut parent, v))
            .collect();
        roots.len()
    }

    /// Follows parent links from `x` to its root, halving the path as it goes.
    ///
    /// An id with no entry in `parent` is treated as its own root.
    fn find(parent: &mut HashMap<EntityId, EntityId>, mut x: EntityId) -> EntityId {
        while let Some(&p) = parent.get(&x) {
            if p == x {
                break;
            }
            // Path halving: point `x` at its grandparent so later finds are shorter.
            let grandparent = parent.get(&p).copied().unwrap_or(p);
            parent.insert(x, grandparent);
            x = grandparent;
        }
        x
    }

    /// Merges the sets containing `x` and `y` by attaching `x`'s root to `y`'s root.
    fn union(parent: &mut HashMap<EntityId, EntityId>, x: EntityId, y: EntityId) {
        let root_x = Self::find(parent, x);
        let root_y = Self::find(parent, y);
        if root_x != root_y {
            parent.insert(root_x, root_y);
        }
    }

    /// Placeholder: ignores `_epsilon_range` and returns a filtration with no complexes.
    pub fn filtration(&self, _epsilon_range: (f32, f32)) -> Filtration {
        Filtration {
            complexes: vec![],
            epsilon_values: vec![],
        }
    }

    /// Placeholder: ignores its arguments and returns a diagram with no pairs.
    pub fn persistent_homology(
        &self,
        _dimension: usize,
        _epsilon_range: (f32, f32),
    ) -> PersistenceDiagram {
        PersistenceDiagram { pairs: vec![] }
    }
}
impl Default for SimplicialComplex {
fn default() -> Self {
Self::new()
}
}
/// A sequence of simplicial complexes together with scale values.
///
/// NOTE(review): presumably `complexes[i]` is the complex at scale `epsilon_values[i]`.
/// Nothing here enforces that the two vectors have equal length; confirm with producers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Filtration {
/// Complexes making up the filtration.
pub complexes: Vec<SimplicialComplex>,
/// Scale parameters associated with the filtration (see the pairing note above the struct).
pub epsilon_values: Vec<f32>,
}
impl Filtration {
    /// Placeholder: always reports a birth time of `0.0`, whatever index is passed.
    ///
    /// NOTE(review): birth times are not tracked yet, so callers cannot tell this apart
    /// from a genuine birth at scale zero.
    pub fn birth_time(&self, _simplex_index: usize) -> f32 {
        const PLACEHOLDER_BIRTH: f32 = 0.0;
        PLACEHOLDER_BIRTH
    }
}
/// A persistence diagram: a list of `(birth, death)` pairs.
///
/// `significant_features` always keeps pairs whose death value is infinite.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PersistenceDiagram {
/// `(birth, death)` pairs, in no particular order.
pub pairs: Vec<(f32, f32)>,
}
impl PersistenceDiagram {
    /// Returns the pairs whose lifetime `death - birth` is at least `min_persistence`.
    ///
    /// Pairs with an infinite death are always kept. The pairs' original order is preserved.
    pub fn significant_features(&self, min_persistence: f32) -> Vec<(f32, f32)> {
        let mut kept = Vec::new();
        for &(birth, death) in &self.pairs {
            // Short-circuit so infinite deaths never reach the subtraction.
            let persists = death.is_infinite() || death - birth >= min_persistence;
            if persists {
                kept.push((birth, death));
            }
        }
        kept
    }
}
/// Reduces `matrix` over Z/2 using the standard left-to-right persistence column reduction.
///
/// Each column is treated as the set of row indices holding a 1; repeated indices are
/// collapsed. Column `j` is repeatedly summed (symmetric difference) with the earlier reduced
/// column that has the same lowest row index. This continues until column `j` is zero or its
/// lowest index is unique. The input is not modified. Every column of the result is an
/// ascending, duplicate-free index list.
#[allow(dead_code)] // Not called yet: `persistent_homology` is still a placeholder.
fn column_reduction(matrix: &BoundaryMatrix) -> BoundaryMatrix {
    /// Sum over Z/2 of two ascending, duplicate-free columns.
    fn add_columns(a: &[usize], b: &[usize]) -> Vec<usize> {
        let mut out = Vec::with_capacity(a.len() + b.len());
        let (mut i, mut j) = (0, 0);
        while i < a.len() && j < b.len() {
            match a[i].cmp(&b[j]) {
                std::cmp::Ordering::Less => {
                    out.push(a[i]);
                    i += 1;
                }
                std::cmp::Ordering::Greater => {
                    out.push(b[j]);
                    j += 1;
                }
                std::cmp::Ordering::Equal => {
                    i += 1;
                    j += 1;
                }
            }
        }
        out.extend_from_slice(&a[i..]);
        out.extend_from_slice(&b[j..]);
        out
    }

    let mut reduced = BoundaryMatrix {
        columns: matrix
            .columns
            .iter()
            .map(|column| {
                let mut rows = column.clone();
                rows.sort_unstable();
                rows.dedup();
                rows
            })
            .collect(),
    };
    // Lowest row index -> index of the reduced column that owns it as its pivot.
    let mut pivot_owner: HashMap<usize, usize> = HashMap::new();
    for j in 0..reduced.num_cols() {
        while let Some(low) = reduced.low(j) {
            match pivot_owner.get(&low) {
                Some(&owner) => {
                    reduced.columns[j] =
                        add_columns(&reduced.columns[j], &reduced.columns[owner]);
                }
                None => {
                    pivot_owner.insert(low, j);
                    break;
                }
            }
        }
    }
    reduced
}

/// Sparse Z/2 boundary matrix: `columns[j]` lists the row indices holding a 1 in column `j`.
#[derive(Debug, Clone)]
struct BoundaryMatrix {
    columns: Vec<Vec<usize>>,
}

impl BoundaryMatrix {
    /// Largest row index with a nonzero entry in column `col`.
    ///
    /// Returns `None` when the column is empty or `col` is out of range.
    #[allow(dead_code)] // Only reachable from `column_reduction`, which is itself unused.
    fn low(&self, col: usize) -> Option<usize> {
        self.columns.get(col)?.iter().copied().max()
    }

    /// Copy of column `index`, or an empty column when `index` is out of range.
    #[allow(dead_code)] // Not referenced outside this not-yet-wired reduction code.
    fn column(&self, index: usize) -> Vec<usize> {
        self.columns.get(index).cloned().unwrap_or_default()
    }

    /// Number of columns in the matrix.
    #[allow(dead_code)] // Only reachable from `column_reduction`, which is itself unused.
    fn num_cols(&self) -> usize {
        self.columns.len()
    }
}
/// Unit tests for simplices, complexes and persistence diagrams.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_simplex_dimension() {
        let e1 = EntityId::new();
        let e2 = EntityId::new();
        let e3 = EntityId::new();
        let s0 = Simplex::new(vec![e1]);
        assert_eq!(s0.dimension(), 0);
        let s1 = Simplex::new(vec![e1, e2]);
        assert_eq!(s1.dimension(), 1);
        let s2 = Simplex::new(vec![e1, e2, e3]);
        assert_eq!(s2.dimension(), 2);
    }

    #[test]
    fn test_simplex_new_normalizes_vertices() {
        let e1 = EntityId::new();
        let e2 = EntityId::new();
        // Order and repeats must not matter.
        let shuffled = Simplex::new(vec![e2, e1, e2]);
        assert_eq!(shuffled, Simplex::new(vec![e1, e2]));
        assert_eq!(shuffled.dimension(), 1);
    }

    #[test]
    fn test_simplex_faces() {
        let e1 = EntityId::new();
        let e2 = EntityId::new();
        let e3 = EntityId::new();
        let triangle = Simplex::new(vec![e1, e2, e3]);
        let faces = triangle.faces();
        assert_eq!(faces.len(), 3);
        assert!(faces.iter().all(|f| f.dimension() == 1));
    }

    #[test]
    fn test_small_simplices_have_no_faces() {
        let e1 = EntityId::new();
        assert!(Simplex::new(vec![e1]).faces().is_empty());
        assert!(Simplex::new(vec![]).faces().is_empty());
    }

    #[test]
    fn test_simplicial_complex() {
        let mut complex = SimplicialComplex::new();
        let e1 = EntityId::new();
        let e2 = EntityId::new();
        let e3 = EntityId::new();
        complex.add_simplex(&[e1, e2, e3]);
        assert_eq!(complex.count_simplices(0), 3);
        assert_eq!(complex.count_simplices(1), 3);
        assert_eq!(complex.count_simplices(2), 1);
        assert_eq!(complex.betti_number(0), 1);
    }

    #[test]
    fn test_empty_complex() {
        let mut complex = SimplicialComplex::new();
        complex.add_simplex(&[]);
        assert_eq!(complex.count_simplices(0), 0);
        assert_eq!(complex.betti_number(0), 0);
    }

    #[test]
    fn test_add_simplex_inserts_every_face() {
        let ids: Vec<EntityId> = (0..5).map(|_| EntityId::new()).collect();
        let mut complex = SimplicialComplex::new();
        complex.add_simplex(&ids);
        // A 4-simplex has C(5, d + 1) faces of dimension d.
        for (dim, &expected) in [5usize, 10, 10, 5, 1].iter().enumerate() {
            assert_eq!(complex.count_simplices(dim), expected);
        }
        assert_eq!(complex.count_simplices(5), 0);
        assert_eq!(complex.betti_number(0), 1);
    }

    #[test]
    fn test_betti_number_disconnected() {
        let mut complex = SimplicialComplex::new();
        let e1 = EntityId::new();
        let e2 = EntityId::new();
        let e3 = EntityId::new();
        let e4 = EntityId::new();
        complex.add_simplex(&[e1, e2]);
        complex.add_simplex(&[e3, e4]);
        assert_eq!(complex.betti_number(0), 2);
    }

    #[test]
    fn test_significant_features() {
        let diagram = PersistenceDiagram {
            pairs: vec![(0.0, 0.5), (0.0, 2.0), (1.0, f32::INFINITY)],
        };
        assert_eq!(
            diagram.significant_features(1.0),
            vec![(0.0, 2.0), (1.0, f32::INFINITY)]
        );
        assert!(PersistenceDiagram { pairs: vec![] }
            .significant_features(0.0)
            .is_empty());
    }
}