#![allow(clippy::expect_used)]
use clump::Kmeans;
use sheaf::CellularSheaf;
/// Number of planted communities; also the `k` passed to k-means.
const K: usize = 3;
/// Nodes per community. Community `c` owns the consecutive node ids
/// `c * NODES_PER_COMMUNITY .. (c + 1) * NODES_PER_COMMUNITY`.
const NODES_PER_COMMUNITY: usize = 4;
/// Total number of nodes in the generated graph.
const N: usize = NODES_PER_COMMUNITY * K;
#[allow(clippy::expect_used)]
fn main() {
println!("=== Sheaf Spectral Clustering ===\n");
let (edges, intra_mask) = build_graph();
println!("Graph: {N} nodes, {} edges", edges.len());
println!(
" intra-community: {} inter-community: {}\n",
intra_mask.iter().filter(|&&b| b).count(),
intra_mask.iter().filter(|&&b| !b).count(),
);
let sheaf = build_sheaf(&edges, &intra_mask);
let sheaf_labels = spectral_cluster(&sheaf, K);
println!("Sheaf-based clustering (rotation maps on bridges):");
print_labels(&sheaf_labels);
let sheaf_acc = cluster_accuracy(&sheaf_labels);
println!(" accuracy: {:.0}%\n", sheaf_acc * 100.0);
let trivial = CellularSheaf::trivial(N, &edges);
let trivial_labels = spectral_cluster(&trivial, K);
println!("Standard spectral clustering (trivial sheaf):");
print_labels(&trivial_labels);
let trivial_acc = cluster_accuracy(&trivial_labels);
println!(" accuracy: {:.0}%\n", trivial_acc * 100.0);
println!("--- Summary ---\n");
println!(
"Sheaf clustering accuracy: {:.0}% (rotation maps encode community frames)",
sheaf_acc * 100.0
);
println!(
"Standard clustering accuracy: {:.0}% (topology only, no edge structure)",
trivial_acc * 100.0
);
if sheaf_acc > trivial_acc {
println!("\nSheaf Laplacian eigenvectors provide richer features for k-means.");
}
assert!(
sheaf_acc >= 1.0,
"Sheaf clustering should perfectly separate communities"
);
}
/// Builds the planted-partition test graph from `K` and `NODES_PER_COMMUNITY`.
///
/// Each community of `NODES_PER_COMMUNITY` consecutive nodes is a clique. The communities
/// are then joined in a ring. The last (up to) two nodes of community `c` are fully
/// connected to the first (up to) two nodes of community `(c + 1) % K`.
///
/// Every bridge is stored as `(lower id, higher id)`. For `K = 3` with four nodes per
/// community, the bridges are
/// `(2,4) (2,5) (3,4) (3,5) (6,8) (6,9) (7,8) (7,9) (0,10) (0,11) (1,10) (1,11)`.
///
/// Returns the edge list and a parallel mask. The mask is `true` for intra-community
/// edges and `false` for bridges. Intra edges come first, then bridges.
fn build_graph() -> (Vec<(usize, usize)>, Vec<bool>) {
    // Number of boundary nodes per community that take part in a bridge.
    let width = NODES_PER_COMMUNITY.min(2);

    let mut edges = Vec::new();
    let mut intra = Vec::new();

    for c in 0..K {
        let base = c * NODES_PER_COMMUNITY;
        for i in 0..NODES_PER_COMMUNITY {
            for j in (i + 1)..NODES_PER_COMMUNITY {
                edges.push((base + i, base + j));
                intra.push(true);
            }
        }
    }

    // A single community has no neighbour to bridge to.
    if K > 1 {
        for c in 0..K {
            let next = (c + 1) % K;
            let tail_start = c * NODES_PER_COMMUNITY + NODES_PER_COMMUNITY - width;
            let tail = tail_start..tail_start + width;
            let head_start = next * NODES_PER_COMMUNITY;
            let head = head_start..head_start + width;
            // Orient each bridge (lower, higher). The orientation matters downstream:
            // build_sheaf assigns a different restriction map to each endpoint.
            let (low, high) = if tail.start < head.start {
                (tail, head)
            } else {
                (head, tail)
            };
            for u in low {
                for v in high.clone() {
                    // With K == 2 and very small communities, both ring directions can
                    // select the same node pairs; keep each bridge only once.
                    if !edges.contains(&(u, v)) {
                        edges.push((u, v));
                        intra.push(false);
                    }
                }
            }
        }
    }

    (edges, intra)
}
/// Assembles the cellular sheaf used by the sheaf-aware clustering run.
///
/// Every node and every edge gets a 2-dimensional stalk. Intra-community edges (mask
/// `true`) use the identity `[1, 0, 0, 1]` for both restriction maps. Bridge edges use
/// the identity for the first map and the 90° rotation `[0, 1, -1, 0]` for the second.
/// NOTE(review): presumably the two maps belong to the edge's first and second endpoint
/// respectively, per the `CellularSheaf::new` convention. Confirm against the library.
///
/// # Panics
/// Panics with "valid sheaf" if `CellularSheaf::new` rejects the inputs.
fn build_sheaf(edges: &[(usize, usize)], intra_mask: &[bool]) -> CellularSheaf {
    let stalk_dim = 2;
    let identity: Vec<f64> = vec![1.0, 0.0, 0.0, 1.0];
    let quarter_turn: Vec<f64> = vec![0.0, 1.0, -1.0, 0.0];

    let mut restriction_maps: Vec<(Vec<f64>, Vec<f64>)> = Vec::with_capacity(intra_mask.len());
    for &is_intra in intra_mask {
        // Only bridges twist the second endpoint's frame.
        let second = if is_intra {
            identity.clone()
        } else {
            quarter_turn.clone()
        };
        restriction_maps.push((identity.clone(), second));
    }

    CellularSheaf::new(
        N,
        vec![stalk_dim; N],
        edges.to_vec(),
        vec![stalk_dim; edges.len()],
        restriction_maps,
    )
    .expect("valid sheaf")
}
/// Clusters the sheaf's nodes with k-means on low-frequency Laplacian eigenvectors.
///
/// Steps:
/// 1. Eigendecompose the sheaf Laplacian (reading its lower triangle).
/// 2. Rank eigenpairs by ascending eigenvalue.
/// 3. Skip every eigenvalue `<= 1e-8` (the numerical kernel).
/// 4. Use the next `k` eigenvectors as feature columns; fewer if not enough remain.
///
/// Each node's feature vector is the mean, over that node's stalk coordinates, of the
/// selected eigenvector entries, cast to `f32`. The features are then clustered with
/// k-means seeded with `42`.
///
/// # Panics
/// Panics with "k-means succeeds" if k-means returns an error.
fn spectral_cluster(sheaf: &CellularSheaf, k: usize) -> Vec<usize> {
    let laplacian = sheaf.laplacian();
    let dim = laplacian.nrows();
    let eig = laplacian
        .as_ref()
        .selfadjoint_eigendecomposition(faer::Side::Lower);
    let (values, vectors) = (eig.s(), eig.u());

    // Copy the eigenvalues out once; ranking and the kernel cut then read plain floats.
    let eigenvalues: Vec<_> = (0..dim).map(|i| values.column_vector().read(i)).collect();

    // Stable sort by eigenvalue (total order, so NaNs cannot break the sort).
    let mut ranked: Vec<usize> = (0..dim).collect();
    ranked.sort_by(|&a, &b| eigenvalues[a].total_cmp(&eigenvalues[b]));

    let first_nonzero = ranked
        .iter()
        .position(|&i| eigenvalues[i] > 1e-8)
        .unwrap_or(dim);
    let columns: Vec<usize> = ranked[first_nonzero..].iter().take(k).copied().collect();

    // Each node owns a contiguous run of Laplacian rows, one row per stalk coordinate.
    let mut node_features: Vec<Vec<f32>> = Vec::with_capacity(N);
    let mut row_start = 0;
    for &stalk in sheaf.stalk_dims() {
        let rows = row_start..row_start + stalk;
        let feature: Vec<f32> = columns
            .iter()
            .map(|&col| {
                let total = rows
                    .clone()
                    .fold(0.0f32, |acc, row| acc + vectors.read(row, col) as f32);
                total / stalk as f32
            })
            .collect();
        node_features.push(feature);
        row_start += stalk;
    }

    Kmeans::new(k)
        .with_seed(42)
        .fit_predict(&node_features)
        .expect("k-means succeeds")
}
fn cluster_accuracy(labels: &[usize]) -> f64 {
let perms = [
[0, 1, 2],
[0, 2, 1],
[1, 0, 2],
[1, 2, 0],
[2, 0, 1],
[2, 1, 0],
];
let ground_truth: Vec<usize> = (0..N).map(|i| i / NODES_PER_COMMUNITY).collect();
let mut best = 0usize;
for perm in &perms {
let correct = labels
.iter()
.zip(&ground_truth)
.filter(|(&predicted, >)| perm[predicted] == gt)
.count();
best = best.max(correct);
}
best as f64 / N as f64
}
/// Prints one line per planted community listing the predicted label of each member node.
///
/// `labels` must hold one entry per node (at least `N` entries); shorter input panics.
fn print_labels(labels: &[usize]) {
    for c in 0..K {
        let base = c * NODES_PER_COMMUNITY;
        let last = base + NODES_PER_COMMUNITY - 1;
        // Borrow the community's slice directly; its Debug output matches a Vec's.
        let members = &labels[base..=last];
        println!(
            " community {c} (nodes {base}-{}): labels {:?}",
            last, members
        );
    }
}