use crate::error::SheafError;
/// A sheaf over a graph: one vector-space dimension ("stalk") per node and
/// one matrix ("restriction map") per stored edge.
///
/// All three fields are indexed by node position, so they are expected to
/// describe the same set of nodes.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct CellularSheaf {
/// Adjacency list: `graph[i]` holds the indices of the neighbours of node `i`.
/// The constructors in this file record every edge in both directions.
pub graph: Vec<Vec<usize>>,
/// `stalk_dims[i]` is the stalk dimension attached to node `i`; its length
/// is what `node_count` reports.
pub stalk_dims: Vec<usize>,
/// One `(i, j, matrix)` triple per edge. `validate` requires the matrix to
/// have `stalk_dims[j]` rows and `stalk_dims[i]` columns.
pub restriction_maps: Vec<(usize, usize, Vec<Vec<f64>>)>,
}
impl CellularSheaf {
    /// Creates a sheaf with `n` isolated nodes, each with stalk dimension `stalk_dim`.
    ///
    /// No edges and no restriction maps are created.
    ///
    /// # Errors
    ///
    /// Returns [`SheafError::EmptySheaf`] when `n` is zero.
    pub fn constant(n: usize, stalk_dim: usize) -> Result<Self, SheafError> {
        if n == 0 {
            return Err(SheafError::EmptySheaf);
        }
        Ok(Self {
            graph: vec![Vec::new(); n],
            stalk_dims: vec![stalk_dim; n],
            restriction_maps: Vec::new(),
        })
    }

    /// Creates a path `0 - 1 - ... - (n - 1)` whose every edge carries the
    /// `stalk_dim x stalk_dim` identity as its restriction map.
    ///
    /// # Errors
    ///
    /// Returns [`SheafError::EmptySheaf`] when `n` is zero.
    pub fn path(n: usize, stalk_dim: usize) -> Result<Self, SheafError> {
        if n == 0 {
            return Err(SheafError::EmptySheaf);
        }
        let mut graph = vec![Vec::new(); n];
        let identity = identity_matrix(stalk_dim);
        // `n >= 1` here, so `n - 1` cannot underflow.
        let mut restriction_maps = Vec::with_capacity(n - 1);
        for i in 0..n - 1 {
            graph[i].push(i + 1);
            graph[i + 1].push(i);
            restriction_maps.push((i, i + 1, identity.clone()));
        }
        Ok(Self {
            graph,
            stalk_dims: vec![stalk_dim; n],
            restriction_maps,
        })
    }

    /// Creates a cycle on `n` nodes: the path from [`CellularSheaf::path`]
    /// plus a closing edge `(n - 1, 0)`, all with identity restriction maps.
    ///
    /// # Errors
    ///
    /// Returns [`SheafError::EmptySheaf`] when `n < 3`.
    pub fn cycle(n: usize, stalk_dim: usize) -> Result<Self, SheafError> {
        if n < 3 {
            return Err(SheafError::EmptySheaf);
        }
        let mut sheaf = Self::path(n, stalk_dim)?;
        sheaf.graph[n - 1].push(0);
        sheaf.graph[0].push(n - 1);
        sheaf
            .restriction_maps
            .push((n - 1, 0, identity_matrix(stalk_dim)));
        Ok(sheaf)
    }

    /// Creates the complete graph on `n` nodes; every pair `i < j` gets one
    /// edge `(i, j)` with an identity restriction map.
    ///
    /// # Errors
    ///
    /// Returns [`SheafError::EmptySheaf`] when `n` is zero.
    pub fn complete(n: usize, stalk_dim: usize) -> Result<Self, SheafError> {
        if n == 0 {
            return Err(SheafError::EmptySheaf);
        }
        let mut graph = vec![Vec::new(); n];
        let mut restriction_maps = Vec::new();
        let identity = identity_matrix(stalk_dim);
        for i in 0..n {
            for j in (i + 1)..n {
                graph[i].push(j);
                graph[j].push(i);
                restriction_maps.push((i, j, identity.clone()));
            }
        }
        Ok(Self {
            graph,
            stalk_dims: vec![stalk_dim; n],
            restriction_maps,
        })
    }

    /// Returns an empty [`SheafBuilder`] for assembling a sheaf node by node.
    pub fn builder() -> SheafBuilder {
        SheafBuilder::default()
    }

    /// Number of nodes, i.e. the length of `stalk_dims`.
    pub fn node_count(&self) -> usize {
        self.stalk_dims.len()
    }

    /// Sum of all stalk dimensions.
    pub fn total_dim(&self) -> usize {
        self.stalk_dims.iter().sum()
    }

    /// Checks that the sheaf is non-empty and that every restriction map
    /// refers to existing nodes and has the right shape.
    ///
    /// For an edge `(i, j)` the matrix must have `stalk_dims[j]` rows, and
    /// *every* row must have `stalk_dims[i]` entries.
    ///
    /// # Errors
    ///
    /// * [`SheafError::EmptySheaf`] when there are no nodes.
    /// * [`SheafError::InvalidEdge`] when an edge endpoint is out of range.
    /// * [`SheafError::DimensionMismatch`] when a matrix has the wrong number
    ///   of rows or any row has the wrong width. `got_cols` is the width of
    ///   the first offending row, or of the first row (0 for an empty matrix)
    ///   when only the row count is wrong.
    pub fn validate(&self) -> Result<(), SheafError> {
        if self.stalk_dims.is_empty() {
            return Err(SheafError::EmptySheaf);
        }
        let node_count = self.stalk_dims.len();
        for (i, j, mat) in &self.restriction_maps {
            let (i, j) = (*i, *j);
            if i >= node_count || j >= node_count {
                return Err(SheafError::InvalidEdge(i, j));
            }
            let expected_rows = self.stalk_dims[j];
            let expected_cols = self.stalk_dims[i];
            let got_rows = mat.len();
            // Inspect every row, not just the first, so ragged matrices are
            // rejected. When all rows agree this reduces to the first row's
            // width, and to 0 for a matrix with no rows.
            let got_cols = mat
                .iter()
                .map(Vec::len)
                .find(|&width| width != expected_cols)
                .or_else(|| mat.first().map(Vec::len))
                .unwrap_or(0);
            if got_rows != expected_rows || got_cols != expected_cols {
                return Err(SheafError::DimensionMismatch {
                    edge: (i, j),
                    expected_rows,
                    expected_cols,
                    got_rows,
                    got_cols,
                });
            }
        }
        Ok(())
    }

    /// Looks up the restriction map stored for the edge between `i` and `j`.
    ///
    /// The lookup ignores orientation: the first stored entry whose endpoints
    /// are `(i, j)` or `(j, i)` is returned, and the matrix is handed back
    /// exactly as stored (it is not transposed for the reversed direction).
    /// Returns `None` when no such edge is stored.
    pub fn get_restriction_map(&self, i: usize, j: usize) -> Option<&Vec<Vec<f64>>> {
        self.restriction_maps
            .iter()
            .find(|(src, tgt, _)| (*src, *tgt) == (i, j) || (*src, *tgt) == (j, i))
            .map(|(_, _, mat)| mat)
    }
}
/// Builds the `n x n` identity matrix as a vector of rows.
///
/// Each row is all zeros except for a `1.0` on the diagonal; `n == 0`
/// yields an empty matrix.
fn identity_matrix(n: usize) -> Vec<Vec<f64>> {
    (0..n)
        .map(|diag| {
            let mut row = vec![0.0; n];
            row[diag] = 1.0;
            row
        })
        .collect()
}
/// Incrementally assembles a `CellularSheaf`.
///
/// Starts empty via `Default`; nodes and edges are appended through the
/// consuming methods and the result is produced by `build`.
#[derive(Debug, Default)]
pub struct SheafBuilder {
// Adjacency list under construction; one entry is pushed per added node.
graph: Vec<Vec<usize>>,
// Stalk dimension of each added node, in insertion order.
stalk_dims: Vec<usize>,
// Accepted edges as `(i, j, matrix)` triples, in insertion order.
restriction_maps: Vec<(usize, usize, Vec<Vec<f64>>)>,
}
impl SheafBuilder {
    /// Appends a node with the given stalk dimension and returns the builder.
    ///
    /// The new node's index is the number of nodes added before it.
    pub fn add_node(mut self, stalk_dim: usize) -> Self {
        self.stalk_dims.push(stalk_dim);
        self.graph.push(Vec::new());
        self
    }

    /// Adds an undirected edge between nodes `i` and `j` carrying `map` as
    /// its restriction map.
    ///
    /// Both endpoints must already have been added with `add_node`; an edge
    /// that mentions an unknown node is silently ignored and the builder is
    /// returned unchanged. The shape of `map` is not checked here; that
    /// happens in `build`.
    pub fn add_edge(mut self, i: usize, j: usize, map: Vec<Vec<f64>>) -> Self {
        let known_nodes = self.graph.len();
        if i >= known_nodes || j >= known_nodes {
            return self;
        }
        self.graph[i].push(j);
        self.graph[j].push(i);
        self.restriction_maps.push((i, j, map));
        self
    }

    /// Consumes the builder and produces a validated `CellularSheaf`.
    ///
    /// # Errors
    ///
    /// Returns `SheafError::EmptySheaf` when no node was added, and otherwise
    /// forwards whatever error `CellularSheaf::validate` reports.
    pub fn build(self) -> Result<CellularSheaf, SheafError> {
        let Self {
            graph,
            stalk_dims,
            restriction_maps,
        } = self;
        if stalk_dims.is_empty() {
            return Err(SheafError::EmptySheaf);
        }
        let sheaf = CellularSheaf {
            graph,
            stalk_dims,
            restriction_maps,
        };
        sheaf.validate().map(|()| sheaf)
    }
}