use crate::error::{Error, Result};
use faer::Mat;
#[derive(Debug, Clone)]
pub struct CellularSheaf {
num_nodes: usize,
stalk_dims: Vec<usize>,
edges: Vec<(usize, usize)>,
restriction_maps: Vec<(Vec<f64>, Vec<f64>)>,
edge_dims: Vec<usize>,
}
impl CellularSheaf {
pub fn new(
num_nodes: usize,
stalk_dims: Vec<usize>,
edges: Vec<(usize, usize)>,
edge_dims: Vec<usize>,
restriction_maps: Vec<(Vec<f64>, Vec<f64>)>,
) -> Result<Self> {
if stalk_dims.len() != num_nodes {
return Err(Error::DimensionMismatch {
expected: num_nodes,
found: stalk_dims.len(),
});
}
if edge_dims.len() != edges.len() {
return Err(Error::DimensionMismatch {
expected: edges.len(),
found: edge_dims.len(),
});
}
if restriction_maps.len() != edges.len() {
return Err(Error::DimensionMismatch {
expected: edges.len(),
found: restriction_maps.len(),
});
}
for (i, &(u, v)) in edges.iter().enumerate() {
if u >= num_nodes || v >= num_nodes {
return Err(Error::Other(format!(
"edge {i} references node {} but only {num_nodes} nodes exist",
u.max(v)
)));
}
let de = edge_dims[i];
let du = stalk_dims[u];
let dv = stalk_dims[v];
let (ref fu, ref fv) = restriction_maps[i];
if fu.len() != de * du {
return Err(Error::ShapeMismatch {
expected: format!(
"{}x{} = {} entries for source map of edge {i}",
de,
du,
de * du
),
actual: format!("{} entries", fu.len()),
});
}
if fv.len() != de * dv {
return Err(Error::ShapeMismatch {
expected: format!(
"{}x{} = {} entries for target map of edge {i}",
de,
dv,
de * dv
),
actual: format!("{} entries", fv.len()),
});
}
}
Ok(Self {
num_nodes,
stalk_dims,
edges,
restriction_maps,
edge_dims,
})
}
pub fn num_nodes(&self) -> usize {
self.num_nodes
}
pub fn stalk_dims(&self) -> &[usize] {
&self.stalk_dims
}
pub fn edges(&self) -> &[(usize, usize)] {
&self.edges
}
pub fn edge_dims(&self) -> &[usize] {
&self.edge_dims
}
pub fn total_dim(&self) -> usize {
self.stalk_dims.iter().sum()
}
pub fn laplacian(&self) -> Mat<f64> {
let n = self.total_dim();
let mut lap = Mat::<f64>::zeros(n, n);
let offsets: Vec<usize> = {
let mut o = vec![0usize; self.num_nodes + 1];
for i in 0..self.num_nodes {
o[i + 1] = o[i] + self.stalk_dims[i];
}
o
};
for (idx, &(u, v)) in self.edges.iter().enumerate() {
let de = self.edge_dims[idx];
let du = self.stalk_dims[u];
let dv = self.stalk_dims[v];
let (ref fu_flat, ref fv_flat) = self.restriction_maps[idx];
let fu = faer::mat::from_column_major_slice::<f64>(fu_flat, de, du);
let fv = faer::mat::from_column_major_slice::<f64>(fv_flat, de, dv);
let ou = offsets[u];
let ov = offsets[v];
for r in 0..du {
for c in 0..du {
let mut val = 0.0;
for k in 0..de {
val += fu[(k, r)] * fu[(k, c)];
}
lap[(ou + r, ou + c)] += val;
}
}
for r in 0..dv {
for c in 0..dv {
let mut val = 0.0;
for k in 0..de {
val += fv[(k, r)] * fv[(k, c)];
}
lap[(ov + r, ov + c)] += val;
}
}
for r in 0..du {
for c in 0..dv {
let mut val = 0.0;
for k in 0..de {
val += fu[(k, r)] * fv[(k, c)];
}
lap[(ou + r, ov + c)] -= val;
lap[(ov + c, ou + r)] -= val;
}
}
}
lap
}
pub fn h0_dimension(&self, tol: f64) -> usize {
let lap = self.laplacian();
let n = lap.nrows();
if n == 0 {
return 0;
}
eigenvalues_below_tol(&lap, tol)
}
#[allow(clippy::expect_used)]
pub fn trivial(num_nodes: usize, edges: &[(usize, usize)]) -> Self {
let stalk_dims = vec![1; num_nodes];
let edge_dims = vec![1; edges.len()];
let restriction_maps = vec![(vec![1.0], vec![1.0]); edges.len()];
Self::new(
num_nodes,
stalk_dims,
edges.to_vec(),
edge_dims,
restriction_maps,
)
.expect("trivial sheaf edges must reference valid nodes")
}
#[allow(clippy::expect_used)]
pub fn constant(num_nodes: usize, edges: &[(usize, usize)], d: usize) -> Self {
let stalk_dims = vec![d; num_nodes];
let edge_dims = vec![d; edges.len()];
let eye: Vec<f64> = {
let mut m = vec![0.0; d * d];
for i in 0..d {
m[i * d + i] = 1.0; }
m
};
let restriction_maps = vec![(eye.clone(), eye); edges.len()];
Self::new(
num_nodes,
stalk_dims,
edges.to_vec(),
edge_dims,
restriction_maps,
)
.expect("constant sheaf edges must reference valid nodes")
}
}
/// Counts the eigenvalues of the symmetric matrix `mat` whose absolute value
/// is strictly less than `tol`.
///
/// The decomposition is requested with `faer::Side::Lower`, so `mat` is
/// expected to be symmetric. Comparing magnitudes means that small negative
/// round-off eigenvalues are counted as well.
fn eigenvalues_below_tol(mat: &Mat<f64>, tol: f64) -> usize {
    let decomposition = mat
        .as_ref()
        .selfadjoint_eigendecomposition(faer::Side::Lower);
    let spectrum = decomposition.s();
    (0..mat.nrows())
        .filter(|&idx| spectrum.column_vector().read(idx).abs() < tol)
        .count()
}
#[cfg(test)]
// Panicking on an unexpected `Err` is exactly how a failing test should fail.
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    fn triangle_edges() -> Vec<(usize, usize)> {
        vec![(0, 1), (1, 2), (0, 2)]
    }

    fn path_edges() -> Vec<(usize, usize)> {
        vec![(0, 1), (1, 2)]
    }

    /// Non-identity, non-symmetric 2x2 maps (column-major) on the triangle.
    /// They make transposition or layout mistakes in `laplacian` visible.
    fn twisted_triangle_maps() -> Vec<(Vec<f64>, Vec<f64>)> {
        vec![
            (vec![1.0, 0.0, 0.5, 1.0], vec![1.0, 0.3, 0.0, 1.0]),
            (vec![0.8, 0.2, 0.1, 0.9], vec![1.0, 0.0, 0.0, 1.0]),
            (vec![1.0, 0.0, 0.0, 1.0], vec![0.7, 0.4, 0.1, 0.6]),
        ]
    }

    fn twisted_triangle_sheaf() -> CellularSheaf {
        CellularSheaf::new(
            3,
            vec![2, 2, 2],
            triangle_edges(),
            vec![2, 2, 2],
            twisted_triangle_maps(),
        )
        .unwrap()
    }

    /// Asserts that `actual` has the shape of `expected` and agrees with it
    /// entrywise to within `tol`.
    fn assert_matrix_close<R: AsRef<[f64]>>(actual: &Mat<f64>, expected: &[R], tol: f64) {
        assert_eq!(actual.nrows(), expected.len(), "row count");
        for (r, row) in expected.iter().enumerate() {
            let row = row.as_ref();
            assert_eq!(actual.ncols(), row.len(), "column count");
            for (c, &want) in row.iter().enumerate() {
                let got = actual[(r, c)];
                assert!(
                    (got - want).abs() < tol,
                    "mismatch at ({r},{c}): got {got}, expected {want}"
                );
            }
        }
    }

    /// Dense coboundary built directly from the definition. For edge
    /// `e = (u, v)`, `(δx)_e = F_u x_u - F_v x_v`, with each `F` stored
    /// column-major as `edge_dim x stalk_dim`. Edges stack as row blocks and
    /// nodes as column blocks.
    fn dense_coboundary(
        stalk_dims: &[usize],
        edges: &[(usize, usize)],
        edge_dims: &[usize],
        maps: &[(Vec<f64>, Vec<f64>)],
    ) -> Vec<Vec<f64>> {
        let mut node_off = vec![0usize; stalk_dims.len() + 1];
        for (i, &d) in stalk_dims.iter().enumerate() {
            node_off[i + 1] = node_off[i] + d;
        }
        let cols = node_off[stalk_dims.len()];
        let rows: usize = edge_dims.iter().sum();
        let mut delta = vec![vec![0.0; cols]; rows];
        let mut row_off = 0;
        for (e, &(u, v)) in edges.iter().enumerate() {
            let de = edge_dims[e];
            let (fu, fv) = &maps[e];
            for k in 0..de {
                for r in 0..stalk_dims[u] {
                    delta[row_off + k][node_off[u] + r] += fu[k + r * de];
                }
                for c in 0..stalk_dims[v] {
                    delta[row_off + k][node_off[v] + c] -= fv[k + c * de];
                }
            }
            row_off += de;
        }
        delta
    }

    /// `δᵀδ` for a dense `delta` with `n` columns.
    fn gram(delta: &[Vec<f64>], n: usize) -> Vec<Vec<f64>> {
        let mut out = vec![vec![0.0; n]; n];
        for row in delta {
            for (out_row, &a) in out.iter_mut().zip(row) {
                for (cell, &b) in out_row.iter_mut().zip(row) {
                    *cell += a * b;
                }
            }
        }
        out
    }

    #[test]
    fn trivial_sheaf_equals_graph_laplacian_triangle() {
        let sheaf = CellularSheaf::trivial(3, &triangle_edges());
        let expected = [[2.0, -1.0, -1.0], [-1.0, 2.0, -1.0], [-1.0, -1.0, 2.0]];
        assert_matrix_close(&sheaf.laplacian(), &expected, 1e-12);
    }

    #[test]
    fn trivial_sheaf_equals_graph_laplacian_path() {
        let sheaf = CellularSheaf::trivial(3, &path_edges());
        let expected = [[1.0, -1.0, 0.0], [-1.0, 2.0, -1.0], [0.0, -1.0, 1.0]];
        assert_matrix_close(&sheaf.laplacian(), &expected, 1e-12);
    }

    #[test]
    fn laplacian_is_symmetric() {
        let lap = twisted_triangle_sheaf().laplacian();
        let n = lap.nrows();
        for r in 0..n {
            for c in 0..n {
                assert!(
                    (lap[(r, c)] - lap[(c, r)]).abs() < 1e-12,
                    "not symmetric at ({r},{c}): {} vs {}",
                    lap[(r, c)],
                    lap[(c, r)]
                );
            }
        }
    }

    #[test]
    fn laplacian_is_positive_semidefinite() {
        let lap = twisted_triangle_sheaf().laplacian();
        let n = lap.nrows();
        let eig = lap
            .as_ref()
            .selfadjoint_eigendecomposition(faer::Side::Lower);
        let s = eig.s();
        for i in 0..n {
            assert!(
                s.column_vector().read(i) >= -1e-10,
                "negative eigenvalue at index {i}: {}",
                s.column_vector().read(i)
            );
        }
    }

    #[test]
    fn laplacian_matches_coboundary_gram() {
        let maps = twisted_triangle_maps();
        let delta = dense_coboundary(&[2, 2, 2], &triangle_edges(), &[2, 2, 2], &maps);
        assert_matrix_close(
            &twisted_triangle_sheaf().laplacian(),
            &gram(&delta, 6),
            1e-12,
        );
    }

    #[test]
    fn laplacian_matches_coboundary_gram_mixed_dims() {
        let stalk_dims = vec![1, 2, 3];
        let edges = vec![(0, 1), (1, 2), (2, 0)];
        let edge_dims = vec![1, 2, 1];
        let maps = vec![
            (vec![2.0], vec![1.0, -1.0]),
            (vec![1.0, 0.5, 0.0, 1.0], vec![1.0, 0.0, 0.0, 1.0, 0.3, 0.2]),
            (vec![0.5, 0.5, 1.0], vec![1.0]),
        ];
        let expected = gram(
            &dense_coboundary(&stalk_dims, &edges, &edge_dims, &maps),
            6,
        );
        let sheaf = CellularSheaf::new(3, stalk_dims, edges, edge_dims, maps).unwrap();
        assert_matrix_close(&sheaf.laplacian(), &expected, 1e-12);
    }

    #[test]
    fn self_loop_contributes_difference_of_maps() {
        // On a self-loop (δx)_e = (F_u - F_v) x_u. Equal maps impose no
        // constraint; unequal 1x1 maps pin the only coordinate to zero.
        let free = CellularSheaf::new(
            1,
            vec![1],
            vec![(0, 0)],
            vec![1],
            vec![(vec![1.0], vec![1.0])],
        )
        .unwrap();
        assert_matrix_close(&free.laplacian(), &[[0.0]], 1e-12);
        assert_eq!(free.h0_dimension(1e-8), 1);

        let pinned = CellularSheaf::new(
            1,
            vec![1],
            vec![(0, 0)],
            vec![1],
            vec![(vec![2.0], vec![1.0])],
        )
        .unwrap();
        assert_matrix_close(&pinned.laplacian(), &[[1.0]], 1e-12);
        assert_eq!(pinned.h0_dimension(1e-8), 0);
    }

    #[test]
    fn h0_connected_trivial_is_one() {
        let sheaf = CellularSheaf::trivial(3, &triangle_edges());
        assert_eq!(sheaf.h0_dimension(1e-8), 1);
    }

    #[test]
    fn h0_disconnected_trivial_equals_components() {
        let sheaf = CellularSheaf::trivial(4, &[(0, 1), (2, 3)]);
        assert_eq!(sheaf.h0_dimension(1e-8), 2);
    }

    #[test]
    fn h0_isolated_nodes() {
        let sheaf = CellularSheaf::trivial(3, &[]);
        assert_eq!(sheaf.h0_dimension(1e-8), 3);
        assert_matrix_close(&sheaf.laplacian(), &[[0.0; 3]; 3], 1e-12);
    }

    #[test]
    fn h0_inconsistent_sheaf() {
        // Restrictions 1 and 2 on a single edge impose x0 = 2 * x1. That still
        // leaves a one-dimensional space of global sections.
        let sheaf = CellularSheaf::new(
            2,
            vec![1, 1],
            vec![(0, 1)],
            vec![1],
            vec![(vec![1.0], vec![2.0])],
        )
        .unwrap();
        assert_eq!(sheaf.h0_dimension(1e-8), 1);
    }

    #[test]
    fn h0_overconstrained_sheaf() {
        // x0 = x1, x1 = x2 and x0 = -x2 together admit only the zero section.
        let sheaf = CellularSheaf::new(
            3,
            vec![1, 1, 1],
            vec![(0, 1), (1, 2), (0, 2)],
            vec![1, 1, 1],
            vec![
                (vec![1.0], vec![1.0]),
                (vec![1.0], vec![1.0]),
                (vec![1.0], vec![-1.0]),
            ],
        )
        .unwrap();
        assert_eq!(sheaf.h0_dimension(1e-8), 0);
    }

    #[test]
    fn constant_sheaf_h0_equals_components_times_d() {
        let sheaf = CellularSheaf::constant(3, &triangle_edges(), 3);
        assert_eq!(sheaf.h0_dimension(1e-8), 3);
    }

    #[test]
    fn constant_sheaf_disconnected() {
        let sheaf = CellularSheaf::constant(4, &[(0, 1), (2, 3)], 2);
        assert_eq!(sheaf.h0_dimension(1e-8), 4);
    }

    #[test]
    fn validation_rejects_bad_dimensions() {
        let result = CellularSheaf::new(
            2,
            vec![1],
            vec![(0, 1)],
            vec![1],
            vec![(vec![1.0], vec![1.0])],
        );
        assert!(result.is_err());
    }

    #[test]
    fn validation_rejects_bad_edge_dim_count() {
        let result = CellularSheaf::new(
            2,
            vec![1, 1],
            vec![(0, 1)],
            vec![1, 1],
            vec![(vec![1.0], vec![1.0])],
        );
        assert!(result.is_err());
    }

    #[test]
    fn validation_rejects_bad_restriction_map_count() {
        let result = CellularSheaf::new(2, vec![1, 1], vec![(0, 1)], vec![1], vec![]);
        assert!(result.is_err());
    }

    #[test]
    fn validation_rejects_bad_map_size() {
        // Source stalk is 2-dimensional, so F_source needs 1 * 2 entries.
        let result = CellularSheaf::new(
            2,
            vec![2, 1],
            vec![(0, 1)],
            vec![1],
            vec![(vec![1.0], vec![1.0])],
        );
        assert!(result.is_err());
    }

    #[test]
    fn validation_rejects_bad_target_map_size() {
        // Target stalk is 2-dimensional, so F_target needs 1 * 2 entries.
        let result = CellularSheaf::new(
            2,
            vec![1, 2],
            vec![(0, 1)],
            vec![1],
            vec![(vec![1.0], vec![1.0])],
        );
        assert!(result.is_err());
    }

    #[test]
    fn validation_rejects_out_of_bounds_node() {
        let result = CellularSheaf::new(
            2,
            vec![1, 1],
            vec![(0, 5)],
            vec![1],
            vec![(vec![1.0], vec![1.0])],
        );
        assert!(result.is_err());
    }

    #[test]
    #[should_panic(expected = "trivial sheaf edges must reference valid nodes")]
    fn trivial_panics_on_out_of_bounds_edge() {
        let _ = CellularSheaf::trivial(2, &[(0, 2)]);
    }

    #[test]
    #[should_panic(expected = "constant sheaf edges must reference valid nodes")]
    fn constant_panics_on_out_of_bounds_edge() {
        let _ = CellularSheaf::constant(2, &[(3, 0)], 2);
    }
}