use sheaf::community::CommunityDetection;
use sheaf::{knn_graph_with_config, KnnGraphConfig, Leiden, WeightFunction};
use std::collections::BTreeMap;
/// Small deterministic pseudo-random source used to build the synthetic data.
///
/// A 64-bit linear congruential generator with a fixed multiplier and
/// increment; the same seed always yields the same sequence, so every run of
/// this program produces identical points.
struct Lcg {
    state: u64,
}

impl Lcg {
    /// Creates a generator whose first step starts from `seed`.
    fn new(seed: u64) -> Self {
        Self { state: seed }
    }

    /// Advances the generator and returns a value in `[0, 1)`.
    fn next_uniform(&mut self) -> f64 {
        self.state = self
            .state
            .wrapping_mul(6364136223846793005)
            .wrapping_add(1442695040888963407);
        // Drop the 11 low bits so the remaining 53 bits fit an f64 mantissa
        // exactly, then scale by 2^53.
        (self.state >> 11) as f64 / (1u64 << 53) as f64
    }

    /// Returns a standard-normal sample via the Box-Muller transform.
    ///
    /// Consumes exactly two uniform draws per call.
    fn next_normal(&mut self) -> f64 {
        // Clamp away from zero so `ln` never sees 0.0.
        let u1 = self.next_uniform().max(1e-15);
        let u2 = self.next_uniform();
        (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos()
    }
}

/// Builds one centre per cluster: all zeros except `offset` on axis `2 * i`.
///
/// # Panics
///
/// Panics if `dim` is too small to give every cluster its own axis
/// (i.e. `2 * (n_clusters - 1) >= dim`).
fn cluster_centers(n_clusters: usize, dim: usize, offset: f32) -> Vec<Vec<f32>> {
    (0..n_clusters)
        .map(|cluster| {
            let axis = 2 * cluster;
            assert!(
                axis < dim,
                "dim={dim} is too small to place {n_clusters} cluster centres on every other axis"
            );
            let mut center = vec![0.0f32; dim];
            center[axis] = offset;
            center
        })
        .collect()
}

/// Sums, over all communities, the size of the largest ground-truth class
/// found inside that community (the numerator of the purity score).
///
/// Every member index in `by_comm` must be a valid index into `ground_truth`;
/// an out-of-range index panics.
fn majority_label_total(by_comm: &BTreeMap<usize, Vec<usize>>, ground_truth: &[usize]) -> usize {
    by_comm
        .values()
        .map(|members| {
            let mut label_counts: BTreeMap<usize, usize> = BTreeMap::new();
            for &idx in members {
                *label_counts.entry(ground_truth[idx]).or_default() += 1;
            }
            label_counts.values().copied().max().unwrap_or(0)
        })
        .sum()
}

/// Generates well-separated synthetic clusters plus uniform noise points,
/// builds a symmetric kNN graph over them, runs Leiden community detection on
/// that graph, and prints the community sizes and the purity of the result
/// against the known cluster labels.
///
/// # Errors
///
/// Propagates any error returned by kNN graph construction or by community
/// detection.
fn main() -> Result<(), Box<dyn std::error::Error>> {
    let n_per_cluster = 60;
    let n_clusters = 4;
    let dim = 8;
    let n_total = n_per_cluster * n_clusters;
    let n_noise = 10;
    let noise_scale = 0.4f32;

    // Centres are derived from `n_clusters` so the data, the labels and the
    // printed summary cannot drift apart when the constants are edited.
    let centers = cluster_centers(n_clusters, dim, 5.0);

    let mut rng = Lcg::new(12345);
    let mut embeddings: Vec<Vec<f32>> = Vec::with_capacity(n_total + n_noise);
    let mut ground_truth: Vec<usize> = Vec::with_capacity(n_total + n_noise);

    // Clustered points: centre plus scaled normal noise on every coordinate.
    for (cluster_id, center) in centers.iter().enumerate() {
        for _ in 0..n_per_cluster {
            let point: Vec<f32> = center
                .iter()
                .map(|&c| c + noise_scale * rng.next_normal() as f32)
                .collect();
            embeddings.push(point);
            ground_truth.push(cluster_id);
        }
    }

    // Noise points: uniform in [-10, 10) per coordinate, labelled with the
    // extra class id `n_clusters` so they never match a real cluster label.
    for _ in 0..n_noise {
        let point: Vec<f32> = (0..dim)
            .map(|_| (rng.next_uniform() * 20.0 - 10.0) as f32)
            .collect();
        embeddings.push(point);
        ground_truth.push(n_clusters);
    }

    let total = embeddings.len();
    println!("Generated {total} points: {n_clusters} clusters x {n_per_cluster} + {n_noise} noise, dim={dim}");

    let config = KnnGraphConfig {
        k: 10,
        symmetric: true,
        weight_fn: WeightFunction::InverseDistance,
        ..Default::default()
    };
    let graph = knn_graph_with_config(&embeddings, &config)?;
    println!(
        "kNN graph: {} nodes, {} edges",
        graph.node_count(),
        graph.edge_count()
    );

    let leiden = Leiden::new().with_resolution(1.0);
    let labels = leiden.detect(&graph)?;

    // Group point indices by detected community; BTreeMap keeps the report
    // ordered by community id.
    let mut by_comm: BTreeMap<usize, Vec<usize>> = BTreeMap::new();
    for (idx, &comm) in labels.iter().enumerate() {
        by_comm.entry(comm).or_default().push(idx);
    }

    println!("\nDetected {} communities:", by_comm.len());
    println!("{:>5} {:>8}", "comm", "size");
    println!("{}", "-".repeat(15));
    for (&cid, members) in &by_comm {
        println!("{:>5} {:>8}", cid, members.len());
    }

    let correct = majority_label_total(&by_comm, &ground_truth);
    let purity = correct as f64 / total as f64;
    println!("\nPurity: {correct}/{total} = {purity:.4}");
    println!("(Purity 1.0 means every detected community contains only one ground-truth cluster.)");
    Ok(())
}