use std::collections::HashMap;
/// Errors produced while building or evaluating a sheaf graph.
#[derive(Debug, Clone, PartialEq, thiserror::Error)]
#[non_exhaustive]
pub enum SheafError {
    /// The given node id does not exist in the graph.
    #[error("Node {0} not found")]
    NodeNotFound(usize),
    /// No edge joins the given pair of nodes.
    #[error("Edge ({0}, {1}) not found")]
    EdgeNotFound(usize, usize),
    /// A vector or matrix did not have the length the operation required.
    #[error("Dimension mismatch: expected {expected}, got {actual}")]
    DimensionMismatch {
        /// Length the operation required.
        expected: usize,
        /// Length that was actually supplied.
        actual: usize,
    },
    /// A restriction map was structurally unusable; the payload says why.
    #[error("Invalid restriction: {0}")]
    InvalidRestriction(String),
}
/// An edge of a sheaf graph, together with the restriction maps that send each endpoint's
/// stalk into the shared edge space.
#[derive(Clone, Debug)]
pub struct SheafEdge {
    /// Id of the first endpoint.
    pub source: usize,
    /// Id of the second endpoint.
    pub target: usize,
    /// Map from the source stalk into the edge space.
    pub restriction_source: DenseRestriction,
    /// Map from the target stalk into the edge space.
    pub restriction_target: DenseRestriction,
    /// Scale factor applied to this edge's energy and Laplacian contributions.
    pub weight: f32,
}
/// Parameters for iterated sheaf diffusion.
#[derive(Clone, Debug)]
pub struct DiffusionConfig {
    /// Maximum number of diffusion steps to run.
    pub num_steps: usize,
    /// Step size used for every diffusion step.
    pub step_size: f32,
}
impl Default for DiffusionConfig {
fn default() -> Self {
Self {
num_steps: 5,
step_size: 0.1,
}
}
}
/// A dense `rows x cols` matrix used as a sheaf restriction map.
///
/// Entries are stored row-major: entry `(i, j)` lives at `data[i * cols + j]`. The
/// constructor `DenseRestriction::new` checks that `data.len() == rows * cols`. The fields
/// are public, so that invariant is not enforced if they are edited directly.
#[derive(Clone, Debug)]
pub struct DenseRestriction {
    /// Matrix entries in row-major order.
    pub data: Vec<f32>,
    /// Number of rows (the output / edge-space dimension).
    pub rows: usize,
    /// Number of columns (the input / stalk dimension).
    pub cols: usize,
}
impl DenseRestriction {
    /// Builds a `rows x cols` restriction map from `data` stored in row-major order.
    ///
    /// # Errors
    ///
    /// * [`SheafError::DimensionMismatch`] if `data.len() != rows * cols`.
    /// * [`SheafError::InvalidRestriction`] if `rows * cols` does not fit in `usize`.
    pub fn new(data: Vec<f32>, rows: usize, cols: usize) -> Result<Self, SheafError> {
        // Checked product: a plain `rows * cols` panics on overflow in debug builds and
        // wraps in release builds, where it could validate a buffer far smaller than the
        // declared shape.
        let expected = rows.checked_mul(cols).ok_or_else(|| {
            SheafError::InvalidRestriction(format!("shape {rows}x{cols} overflows usize"))
        })?;
        if data.len() != expected {
            return Err(SheafError::DimensionMismatch {
                expected,
                actual: data.len(),
            });
        }
        Ok(Self { data, rows, cols })
    }

    /// Returns the `dim x dim` identity map.
    pub fn identity(dim: usize) -> Self {
        let mut data = vec![0.0; dim * dim];
        for i in 0..dim {
            data[i * dim + i] = 1.0;
        }
        Self {
            data,
            rows: dim,
            cols: dim,
        }
    }

    /// Dimension of the vectors this map accepts (the column count).
    pub fn in_dim(&self) -> usize {
        self.cols
    }

    /// Dimension of the vectors this map produces (the row count).
    pub fn out_dim(&self) -> usize {
        self.rows
    }

    /// Borrows row `i` of the row-major matrix.
    ///
    /// Panics if `data` is shorter than `rows * cols`. That can only happen when the public
    /// fields were edited directly instead of going through [`DenseRestriction::new`].
    fn row(&self, i: usize) -> &[f32] {
        let start = i * self.cols;
        &self.data[start..start + self.cols]
    }

    /// Computes the matrix-vector product `A x`.
    ///
    /// # Errors
    ///
    /// [`SheafError::DimensionMismatch`] if `x.len() != self.in_dim()`.
    pub fn apply(&self, x: &[f32]) -> Result<Vec<f32>, SheafError> {
        if x.len() != self.cols {
            return Err(SheafError::DimensionMismatch {
                expected: self.cols,
                actual: x.len(),
            });
        }
        // Each output entry is the dot product of one row with `x`, accumulated left to
        // right from 0.0.
        Ok((0..self.rows)
            .map(|i| {
                self.row(i)
                    .iter()
                    .zip(x)
                    .fold(0.0, |acc, (a, b)| acc + a * b)
            })
            .collect())
    }

    /// Computes the transposed product `Aᵀ x`.
    ///
    /// # Errors
    ///
    /// [`SheafError::DimensionMismatch`] if `x.len() != self.out_dim()`.
    pub fn apply_transpose(&self, x: &[f32]) -> Result<Vec<f32>, SheafError> {
        if x.len() != self.rows {
            return Err(SheafError::DimensionMismatch {
                expected: self.rows,
                actual: x.len(),
            });
        }
        let mut result = vec![0.0; self.cols];
        // Walk the matrix row by row (contiguous in memory) and scatter each scaled row into
        // the output. Every `result[j]` still receives its terms in increasing row order,
        // so the values match a column-by-column traversal exactly.
        for (i, &weight) in x.iter().enumerate() {
            for (out, &entry) in result.iter_mut().zip(self.row(i)) {
                *out += entry * weight;
            }
        }
        Ok(result)
    }

    /// Returns the matrix as a vector of rows.
    pub fn as_matrix(&self) -> Vec<Vec<f32>> {
        (0..self.rows).map(|i| self.row(i).to_vec()).collect()
    }

    /// Frobenius norm: the square root of the sum of squared entries.
    pub fn frobenius_norm(&self) -> f32 {
        self.data.iter().map(|x| x * x).sum::<f32>().sqrt()
    }
}
/// A stalk holding a dense vector of `f32` values.
#[derive(Clone, Debug)]
pub struct VecStalk {
    /// Current value of the stalk; its length is the stalk's dimension.
    value: Vec<f32>,
}
impl VecStalk {
pub fn new(value: Vec<f32>) -> Self {
Self { value }
}
pub fn dim(&self) -> usize {
self.value.len()
}
pub fn value(&self) -> &Vec<f32> {
&self.value
}
pub fn set_value(&mut self, v: Vec<f32>) -> Result<(), SheafError> {
if v.len() != self.value.len() {
return Err(SheafError::DimensionMismatch {
expected: self.value.len(),
actual: v.len(),
});
}
self.value = v;
Ok(())
}
pub fn zero(&self) -> Vec<f32> {
vec![0.0; self.value.len()]
}
}
/// A graph of vector-valued stalks connected by edges that carry restriction maps.
///
/// Node ids are indices into `stalks`.
#[derive(Clone, Debug)]
pub struct SimpleSheafGraph {
    /// Node stalks, indexed by node id.
    stalks: Vec<VecStalk>,
    /// All edges, in insertion order.
    edges: Vec<SheafEdge>,
    /// Node id -> neighbour ids, with one entry per incident edge endpoint.
    adjacency: HashMap<usize, Vec<usize>>,
}
impl SimpleSheafGraph {
pub fn new() -> Self {
Self {
stalks: Vec::new(),
edges: Vec::new(),
adjacency: HashMap::new(),
}
}
pub fn add_node(&mut self, value: Vec<f32>) -> usize {
let id = self.stalks.len();
self.stalks.push(VecStalk::new(value));
self.adjacency.insert(id, Vec::new());
id
}
pub fn add_edge(
&mut self,
source: usize,
target: usize,
restriction_source: DenseRestriction,
restriction_target: DenseRestriction,
weight: f32,
) -> Result<(), SheafError> {
if source >= self.stalks.len() {
return Err(SheafError::NodeNotFound(source));
}
if target >= self.stalks.len() {
return Err(SheafError::NodeNotFound(target));
}
if restriction_source.in_dim() != self.stalks[source].dim() {
return Err(SheafError::DimensionMismatch {
expected: self.stalks[source].dim(),
actual: restriction_source.in_dim(),
});
}
if restriction_target.in_dim() != self.stalks[target].dim() {
return Err(SheafError::DimensionMismatch {
expected: self.stalks[target].dim(),
actual: restriction_target.in_dim(),
});
}
if restriction_source.out_dim() != restriction_target.out_dim() {
return Err(SheafError::InvalidRestriction(
"Source and target restrictions must have same output dimension".into(),
));
}
self.edges.push(SheafEdge {
source,
target,
restriction_source,
restriction_target,
weight,
});
self.adjacency.entry(source).or_default().push(target);
self.adjacency.entry(target).or_default().push(source);
Ok(())
}
pub fn num_nodes(&self) -> usize {
self.stalks.len()
}
pub fn num_edges(&self) -> usize {
self.edges.len()
}
pub fn stalk(&self, node: usize) -> Result<&VecStalk, SheafError> {
self.stalks.get(node).ok_or(SheafError::NodeNotFound(node))
}
pub fn stalk_mut(&mut self, node: usize) -> Result<&mut VecStalk, SheafError> {
self.stalks
.get_mut(node)
.ok_or(SheafError::NodeNotFound(node))
}
pub fn edge(&self, source: usize, target: usize) -> Result<&SheafEdge, SheafError> {
self.edges
.iter()
.find(|e| {
(e.source == source && e.target == target)
|| (e.source == target && e.target == source)
})
.ok_or(SheafError::EdgeNotFound(source, target))
}
pub fn edges(&self) -> impl Iterator<Item = &SheafEdge> {
self.edges.iter()
}
pub fn neighbors(&self, node: usize) -> Result<Vec<usize>, SheafError> {
self.adjacency
.get(&node)
.cloned()
.ok_or(SheafError::NodeNotFound(node))
}
pub fn dirichlet_energy(&self) -> Result<f32, SheafError> {
let mut energy = 0.0;
for edge in &self.edges {
let x_u = self.stalks[edge.source].value();
let x_v = self.stalks[edge.target].value();
let r_u = edge.restriction_source.apply(x_u)?;
let r_v = edge.restriction_target.apply(x_v)?;
let diff_sq: f32 = r_u
.iter()
.zip(r_v.iter())
.map(|(a, b)| (a - b) * (a - b))
.sum();
energy += edge.weight * diff_sq;
}
Ok(energy)
}
pub fn laplacian_at(&self, node: usize) -> Result<Vec<f32>, SheafError> {
let stalk = self.stalk(node)?;
let mut result = stalk.zero();
for edge in &self.edges {
let (is_source, other) = if edge.source == node {
(true, edge.target)
} else if edge.target == node {
(false, edge.source)
} else {
continue;
};
let x_node = self.stalks[node].value();
let x_other = self.stalks[other].value();
let (r_node, r_other) = if is_source {
(&edge.restriction_source, &edge.restriction_target)
} else {
(&edge.restriction_target, &edge.restriction_source)
};
let r_x_node = r_node.apply(x_node)?;
let r_x_other = r_other.apply(x_other)?;
let diff: Vec<f32> = r_x_node
.iter()
.zip(r_x_other.iter())
.map(|(a, b)| a - b)
.collect();
let contrib = r_node.apply_transpose(&diff)?;
for (i, c) in contrib.iter().enumerate() {
result[i] += edge.weight * c;
}
}
Ok(result)
}
pub fn diffusion_step(&mut self, step_size: f32) -> Result<(), SheafError> {
let laplacians: Vec<Vec<f32>> = (0..self.num_nodes())
.map(|i| self.laplacian_at(i))
.collect::<Result<_, _>>()?;
for (i, lap) in laplacians.into_iter().enumerate() {
let stalk = &mut self.stalks[i];
let new_value: Vec<f32> = stalk
.value()
.iter()
.zip(lap.iter())
.map(|(x, l)| x - step_size * l)
.collect();
stalk.set_value(new_value)?;
}
Ok(())
}
}
impl Default for SimpleSheafGraph {
fn default() -> Self {
Self::new()
}
}
/// Maps the graph's Dirichlet energy `E` to the score `exp(-E)`.
///
/// A graph whose restricted stalk values agree on every edge has zero energy and scores
/// `1.0`. Larger energies give smaller scores.
///
/// # Errors
///
/// Propagates any error from `SimpleSheafGraph::dirichlet_energy`.
pub fn consistency_score(graph: &SimpleSheafGraph) -> Result<f32, SheafError> {
    graph.dirichlet_energy().map(|energy| (-energy).exp())
}
/// Runs diffusion steps until the Dirichlet energy stops changing, or the step budget runs out.
///
/// After each step of size `config.step_size`, the new energy is compared with the previous
/// one. If they differ by less than `tolerance`, the number of steps taken so far is
/// returned. Otherwise the loop stops after `config.num_steps` steps and returns that count.
///
/// # Errors
///
/// Propagates any error from the diffusion step or the energy computation.
pub fn diffuse_until_convergence(
    graph: &mut SimpleSheafGraph,
    config: &DiffusionConfig,
    tolerance: f32,
) -> Result<usize, SheafError> {
    let mut last_energy = graph.dirichlet_energy()?;
    for steps_taken in 1..=config.num_steps {
        graph.diffusion_step(config.step_size)?;
        let current_energy = graph.dirichlet_energy()?;
        let change = (last_energy - current_energy).abs();
        if change < tolerance {
            return Ok(steps_taken);
        }
        last_energy = current_energy;
    }
    Ok(config.num_steps)
}
#[cfg(test)]
mod tests {
// Unit tests for restriction maps, stalks, graph construction, Dirichlet energy,
// Laplacian, diffusion, and the free functions. Expected numeric values are worked out
// by hand from the formulas in the code under test; derivations are noted inline.
use super::*;
#[test]
fn test_identity_restriction() {
let r = DenseRestriction::identity(3);
let x = vec![1.0, 2.0, 3.0];
let y = r.apply(&x).unwrap();
assert_eq!(y, x);
}
#[test]
fn test_restriction_transpose() {
// r = [[1, 2, 3], [4, 5, 6]] in row-major order (2 rows, 3 columns).
let r = DenseRestriction::new(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 2, 3).unwrap();
let x = vec![1.0, 2.0, 3.0];
let y = r.apply(&x).unwrap();
assert_eq!(y.len(), 2);
let z = vec![1.0, 1.0];
let w = r.apply_transpose(&z).unwrap();
assert_eq!(w.len(), 3);
// r^T * [1, 1] is the vector of column sums: [1+4, 2+5, 3+6].
assert_eq!(w, vec![5.0, 7.0, 9.0]);
}
#[test]
fn test_simple_sheaf_graph() {
let mut graph = SimpleSheafGraph::new();
let n0 = graph.add_node(vec![1.0, 0.0]);
let n1 = graph.add_node(vec![0.0, 1.0]);
let r = DenseRestriction::identity(2);
graph.add_edge(n0, n1, r.clone(), r.clone(), 1.0).unwrap();
assert_eq!(graph.num_nodes(), 2);
assert_eq!(graph.num_edges(), 1);
// ||[1, 0] - [0, 1]||^2 = 1 + 1 = 2, with weight 1.
let energy = graph.dirichlet_energy().unwrap();
assert!((energy - 2.0).abs() < 1e-6);
}
#[test]
fn test_diffusion_reduces_energy() {
// Path graph 0 - 1 - 2 with identity restrictions and disagreeing stalks.
let mut graph = SimpleSheafGraph::new();
let n0 = graph.add_node(vec![1.0, 0.0]);
let n1 = graph.add_node(vec![0.5, 0.5]);
let n2 = graph.add_node(vec![0.0, 1.0]);
let r = DenseRestriction::identity(2);
graph.add_edge(n0, n1, r.clone(), r.clone(), 1.0).unwrap();
graph.add_edge(n1, n2, r.clone(), r.clone(), 1.0).unwrap();
let initial_energy = graph.dirichlet_energy().unwrap();
for _ in 0..10 {
graph.diffusion_step(0.1).unwrap();
}
let final_energy = graph.dirichlet_energy().unwrap();
assert!(
final_energy < initial_energy,
"Diffusion should reduce energy"
);
}
#[test]
fn test_consistency_score() {
// Identical stalks => zero energy => exp(0) = 1.
let mut graph = SimpleSheafGraph::new();
graph.add_node(vec![1.0, 2.0]);
graph.add_node(vec![1.0, 2.0]);
let r = DenseRestriction::identity(2);
graph.add_edge(0, 1, r.clone(), r.clone(), 1.0).unwrap();
let score = consistency_score(&graph).unwrap();
assert!((score - 1.0).abs() < 1e-6);
}
#[test]
fn test_dense_restriction_new_dimension_mismatch() {
// A 2x2 shape needs 4 entries; only 3 are supplied.
let result = DenseRestriction::new(vec![1.0, 2.0, 3.0], 2, 2);
assert!(matches!(
result,
Err(SheafError::DimensionMismatch {
expected: 4,
actual: 3
})
));
}
#[test]
fn test_dense_restriction_1x1() {
// A 1x1 matrix is its own transpose, so both products give 3 * 2 = 6.
let r = DenseRestriction::new(vec![3.0], 1, 1).unwrap();
let x = vec![2.0];
let y = r.apply(&x).unwrap();
assert_eq!(y, vec![6.0]);
let yt = r.apply_transpose(&[2.0]).unwrap();
assert_eq!(yt, vec![6.0]); }
#[test]
fn test_dense_restriction_apply_wrong_dim() {
// apply() expects in_dim (= 3) entries.
let r = DenseRestriction::identity(3);
let x = vec![1.0, 2.0]; let result = r.apply(&x);
assert!(matches!(
result,
Err(SheafError::DimensionMismatch {
expected: 3,
actual: 2
})
));
}
#[test]
fn test_dense_restriction_apply_transpose_wrong_dim() {
// apply_transpose() on a 2x3 map expects out_dim (= 2) entries.
let r = DenseRestriction::new(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 2, 3).unwrap();
let x = vec![1.0, 2.0, 3.0]; let result = r.apply_transpose(&x);
assert!(matches!(
result,
Err(SheafError::DimensionMismatch {
expected: 2,
actual: 3
})
));
}
#[test]
fn test_dense_restriction_as_matrix() {
let r = DenseRestriction::new(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 2, 3).unwrap();
let m = r.as_matrix();
assert_eq!(m.len(), 2);
assert_eq!(m[0], vec![1.0, 2.0, 3.0]);
assert_eq!(m[1], vec![4.0, 5.0, 6.0]);
}
#[test]
fn test_dense_restriction_frobenius_norm() {
// sqrt(3^2 + 4^2) = 5.
let r = DenseRestriction::new(vec![3.0, 4.0], 1, 2).unwrap();
let norm = r.frobenius_norm();
assert!((norm - 5.0).abs() < 1e-6); }
#[test]
fn test_identity_restriction_is_identity() {
let r = DenseRestriction::identity(4);
assert_eq!(r.in_dim(), 4);
assert_eq!(r.out_dim(), 4);
let x = vec![1.0, 2.0, 3.0, 4.0];
assert_eq!(r.apply(&x).unwrap(), x);
assert_eq!(r.apply_transpose(&x).unwrap(), x); }
#[test]
fn test_vec_stalk_set_value_dimension_mismatch() {
let mut s = VecStalk::new(vec![1.0, 2.0]);
let result = s.set_value(vec![1.0]);
assert!(matches!(
result,
Err(SheafError::DimensionMismatch {
expected: 2,
actual: 1
})
));
}
#[test]
fn test_vec_stalk_zero() {
let s = VecStalk::new(vec![5.0, 6.0, 7.0]);
assert_eq!(s.zero(), vec![0.0, 0.0, 0.0]);
}
#[test]
fn test_vec_stalk_roundtrip() {
let mut s = VecStalk::new(vec![1.0, 2.0]);
s.set_value(vec![3.0, 4.0]).unwrap();
assert_eq!(s.value(), &vec![3.0, 4.0]);
assert_eq!(s.dim(), 2);
}
#[test]
fn test_add_edge_source_not_found() {
let mut graph = SimpleSheafGraph::new();
graph.add_node(vec![1.0]);
let r = DenseRestriction::identity(1);
let result = graph.add_edge(5, 0, r.clone(), r.clone(), 1.0);
assert!(matches!(result, Err(SheafError::NodeNotFound(5))));
}
#[test]
fn test_add_edge_target_not_found() {
let mut graph = SimpleSheafGraph::new();
graph.add_node(vec![1.0]);
let r = DenseRestriction::identity(1);
let result = graph.add_edge(0, 99, r.clone(), r.clone(), 1.0);
assert!(matches!(result, Err(SheafError::NodeNotFound(99))));
}
#[test]
fn test_add_edge_restriction_dim_mismatch_source() {
// The source restriction expects 3 inputs but the source stalk has dimension 2.
let mut graph = SimpleSheafGraph::new();
graph.add_node(vec![1.0, 2.0]); graph.add_node(vec![1.0, 2.0]); let r_wrong = DenseRestriction::identity(3); let r_ok = DenseRestriction::identity(2);
let result = graph.add_edge(0, 1, r_wrong, r_ok, 1.0);
assert!(matches!(result, Err(SheafError::DimensionMismatch { .. })));
}
#[test]
fn test_add_edge_restriction_output_dim_mismatch() {
// Input dims both match (2), but the outputs differ: 3 (source) vs 2 (target).
let mut graph = SimpleSheafGraph::new();
graph.add_node(vec![1.0, 2.0]);
graph.add_node(vec![1.0, 2.0]);
let r_src = DenseRestriction::new(vec![1.0; 6], 3, 2).unwrap();
let r_tgt = DenseRestriction::identity(2);
let result = graph.add_edge(0, 1, r_src, r_tgt, 1.0);
assert!(matches!(result, Err(SheafError::InvalidRestriction(_))));
}
#[test]
fn test_stalk_not_found() {
let graph = SimpleSheafGraph::new();
assert!(matches!(graph.stalk(0), Err(SheafError::NodeNotFound(0))));
}
#[test]
fn test_edge_not_found() {
let mut graph = SimpleSheafGraph::new();
graph.add_node(vec![1.0]);
graph.add_node(vec![1.0]);
assert!(matches!(
graph.edge(0, 1),
Err(SheafError::EdgeNotFound(0, 1))
));
}
#[test]
fn test_neighbors_not_found() {
let graph = SimpleSheafGraph::new();
assert!(matches!(
graph.neighbors(0),
Err(SheafError::NodeNotFound(0))
));
}
#[test]
fn test_edge_lookup_bidirectional() {
// edge() matches either orientation of the stored (source, target) pair.
let mut graph = SimpleSheafGraph::new();
graph.add_node(vec![1.0]);
graph.add_node(vec![2.0]);
let r = DenseRestriction::identity(1);
graph.add_edge(0, 1, r.clone(), r.clone(), 1.0).unwrap();
assert!(graph.edge(0, 1).is_ok());
assert!(graph.edge(1, 0).is_ok());
}
#[test]
fn test_neighbors_bidirectional() {
// Node 1 is the target of edge (0, 1) and the source of edge (1, 2).
let mut graph = SimpleSheafGraph::new();
graph.add_node(vec![1.0]);
graph.add_node(vec![2.0]);
graph.add_node(vec![3.0]);
let r = DenseRestriction::identity(1);
graph.add_edge(0, 1, r.clone(), r.clone(), 1.0).unwrap();
graph.add_edge(1, 2, r.clone(), r.clone(), 1.0).unwrap();
let n1 = graph.neighbors(1).unwrap();
assert_eq!(n1.len(), 2); }
#[test]
fn test_dirichlet_energy_zero_for_identical_stalks() {
let mut graph = SimpleSheafGraph::new();
graph.add_node(vec![1.0, 2.0, 3.0]);
graph.add_node(vec![1.0, 2.0, 3.0]);
graph.add_node(vec![1.0, 2.0, 3.0]);
let r = DenseRestriction::identity(3);
graph.add_edge(0, 1, r.clone(), r.clone(), 1.0).unwrap();
graph.add_edge(1, 2, r.clone(), r.clone(), 1.0).unwrap();
let energy = graph.dirichlet_energy().unwrap();
assert!(
(energy - 0.0).abs() < 1e-6,
"identical stalks should have zero energy"
);
}
#[test]
fn test_dirichlet_energy_weighted() {
// Squared difference is 2; weight 2 scales it to 4.
let mut graph = SimpleSheafGraph::new();
graph.add_node(vec![1.0, 0.0]);
graph.add_node(vec![0.0, 1.0]);
let r = DenseRestriction::identity(2);
graph.add_edge(0, 1, r.clone(), r.clone(), 2.0).unwrap();
let energy = graph.dirichlet_energy().unwrap();
assert!((energy - 4.0).abs() < 1e-6);
}
#[test]
fn test_laplacian_at_zero_for_consistent_signal() {
let mut graph = SimpleSheafGraph::new();
graph.add_node(vec![1.0, 2.0]);
graph.add_node(vec![1.0, 2.0]);
let r = DenseRestriction::identity(2);
graph.add_edge(0, 1, r.clone(), r.clone(), 1.0).unwrap();
let lap = graph.laplacian_at(0).unwrap();
assert!(
lap.iter().all(|&x| x.abs() < 1e-6),
"Laplacian should be zero for consistent signal"
);
}
#[test]
fn test_laplacian_symmetry() {
// With identity maps on a single edge, lap0 = x0 - x1 and lap1 = x1 - x0,
// so the two Laplacians cancel component-wise.
let mut graph = SimpleSheafGraph::new();
graph.add_node(vec![1.0, 0.0]);
graph.add_node(vec![0.0, 1.0]);
let r = DenseRestriction::identity(2);
graph.add_edge(0, 1, r.clone(), r.clone(), 1.0).unwrap();
let lap0 = graph.laplacian_at(0).unwrap();
let lap1 = graph.laplacian_at(1).unwrap();
for i in 0..2 {
assert!(
(lap0[i] + lap1[i]).abs() < 1e-6,
"Laplacian should sum to zero"
);
}
}
#[test]
fn test_diffuse_until_convergence_identical_stalks() {
// Energy starts at 0 and stays 0, so the first step already meets the tolerance.
let mut graph = SimpleSheafGraph::new();
graph.add_node(vec![1.0, 1.0]);
graph.add_node(vec![1.0, 1.0]);
let r = DenseRestriction::identity(2);
graph.add_edge(0, 1, r.clone(), r.clone(), 1.0).unwrap();
let config = DiffusionConfig {
num_steps: 100,
step_size: 0.1,
};
let steps = diffuse_until_convergence(&mut graph, &config, 1e-8).unwrap();
assert!(
steps <= 2,
"already-converged graph should converge immediately, took {steps}"
);
}
#[test]
fn test_diffuse_until_convergence_reaches_max_steps() {
// Initial energy is 20000. Each small step changes it by hundreds,
// far above the 1e-12 tolerance, so all 3 steps are used.
let mut graph = SimpleSheafGraph::new();
graph.add_node(vec![100.0, 0.0]);
graph.add_node(vec![0.0, 100.0]);
let r = DenseRestriction::identity(2);
graph.add_edge(0, 1, r.clone(), r.clone(), 1.0).unwrap();
let config = DiffusionConfig {
num_steps: 3,
step_size: 0.01, };
let steps = diffuse_until_convergence(&mut graph, &config, 1e-12).unwrap();
assert_eq!(steps, 3, "should reach max steps");
}
#[test]
fn test_consistency_score_decreases_with_distance() {
// g1 energy = 0.1^2 + 0.1^2 = 0.02; g2 energy = 2, so exp(-0.02) > exp(-2).
let mut g1 = SimpleSheafGraph::new();
g1.add_node(vec![1.0, 0.0]);
g1.add_node(vec![0.9, 0.1]);
let r = DenseRestriction::identity(2);
g1.add_edge(0, 1, r.clone(), r.clone(), 1.0).unwrap();
let score1 = consistency_score(&g1).unwrap();
let mut g2 = SimpleSheafGraph::new();
g2.add_node(vec![1.0, 0.0]);
g2.add_node(vec![0.0, 1.0]);
g2.add_edge(0, 1, r.clone(), r.clone(), 1.0).unwrap();
let score2 = consistency_score(&g2).unwrap();
assert!(
score1 > score2,
"closer stalks should have higher consistency"
);
}
#[test]
fn test_diffusion_config_default() {
let config = DiffusionConfig::default();
assert_eq!(config.num_steps, 5);
assert!((config.step_size - 0.1).abs() < 1e-6);
}
#[test]
fn test_sheaf_error_display() {
assert_eq!(
format!("{}", SheafError::NodeNotFound(5)),
"Node 5 not found"
);
assert_eq!(
format!("{}", SheafError::EdgeNotFound(1, 2)),
"Edge (1, 2) not found"
);
assert_eq!(
format!(
"{}",
SheafError::DimensionMismatch {
expected: 3,
actual: 2
}
),
"Dimension mismatch: expected 3, got 2"
);
assert!(format!("{}", SheafError::InvalidRestriction("bad".into())).contains("bad"));
}
#[test]
fn test_simple_sheaf_graph_default() {
let graph = SimpleSheafGraph::default();
assert_eq!(graph.num_nodes(), 0);
assert_eq!(graph.num_edges(), 0);
}
#[test]
fn test_non_square_restriction_maps() {
// proj (2x3) keeps the first two coordinates: [1, 0] vs [0, 1] -> energy 2.
let mut graph = SimpleSheafGraph::new();
graph.add_node(vec![1.0, 0.0, 0.0]);
graph.add_node(vec![0.0, 1.0, 0.0]);
let proj = DenseRestriction::new(vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0], 2, 3).unwrap();
graph
.add_edge(0, 1, proj.clone(), proj.clone(), 1.0)
.unwrap();
let energy = graph.dirichlet_energy().unwrap();
assert!((energy - 2.0).abs() < 1e-6);
}
#[test]
fn test_diffusion_with_non_square_restrictions() {
// One step of size 0.1 moves x0 to [0.9, 0.1, 0] and x1 to [0.1, 0.9, 0],
// so the energy drops from 2 to 1.28.
let mut graph = SimpleSheafGraph::new();
graph.add_node(vec![1.0, 0.0, 0.0]);
graph.add_node(vec![0.0, 1.0, 0.0]);
let proj = DenseRestriction::new(vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0], 2, 3).unwrap();
graph
.add_edge(0, 1, proj.clone(), proj.clone(), 1.0)
.unwrap();
let initial_energy = graph.dirichlet_energy().unwrap();
graph.diffusion_step(0.1).unwrap();
let final_energy = graph.dirichlet_energy().unwrap();
assert!(
final_energy < initial_energy,
"diffusion should reduce energy with non-square maps"
);
}
#[test]
fn test_empty_graph_energy() {
let graph = SimpleSheafGraph::new();
let energy = graph.dirichlet_energy().unwrap();
assert_eq!(energy, 0.0);
}
#[test]
fn test_single_node_graph() {
// No edges means no energy terms.
let mut graph = SimpleSheafGraph::new();
graph.add_node(vec![1.0, 2.0]);
assert_eq!(graph.num_nodes(), 1);
assert_eq!(graph.num_edges(), 0);
let energy = graph.dirichlet_energy().unwrap();
assert_eq!(energy, 0.0);
}
}