use crate::node::NodeId;
use std::collections::{HashMap, HashSet};
/// Tuning parameters for the community detection performed by `run_leiden`.
#[derive(Debug, Clone)]
pub struct LeidenConfig {
    /// Minimum number of member nodes a cluster needs in order to be kept.
    /// Clusters below this size are dropped from the result together with
    /// their nodes.
    pub min_community_size: usize,
    /// Upper bound on the number of full sweeps over the node list; the
    /// sweeps stop earlier as soon as one of them changes no label.
    pub max_iterations: usize,
    /// NOTE(review): this flag is not read by `run_leiden` in this file;
    /// presumably callers consult it to decide whether to invoke
    /// `compute_centroids` afterwards — confirm against the call sites.
    pub compute_centroids: bool,
}
impl Default for LeidenConfig {
fn default() -> Self {
Self {
min_community_size: 3,
max_iterations: 15,
compute_centroids: true,
}
}
}
/// Outcome of a clustering run.
#[derive(Debug, Clone)]
pub struct LeidenResult {
    /// Cluster id assigned to each retained node. Nodes whose cluster was
    /// dropped for being too small have no entry here.
    pub node_to_cluster: HashMap<NodeId, u32>,
    /// Number of member nodes for every cluster id in `node_to_cluster`.
    pub cluster_sizes: HashMap<u32, usize>,
    /// Per-cluster mean vector. `run_leiden` leaves this empty; it is filled
    /// in by `compute_centroids`.
    pub centroids: HashMap<u32, Vec<f32>>,
    /// Number of distinct clusters in `cluster_sizes`.
    pub num_clusters: u32,
}
/// Weighted adjacency view of the graph that `run_leiden` clusters.
///
/// Derives `Debug`, `Clone` and `Default` so it can be logged, copied and
/// built incrementally like the other public types in this module.
#[derive(Debug, Clone, Default)]
pub struct AdjacencySnapshot {
    /// For each node, the `(neighbour, weight)` pairs it votes with.
    /// `run_leiden` only looks up the list stored under the node being
    /// visited, so an undirected edge has to be listed under both endpoints.
    pub edges: HashMap<NodeId, Vec<(NodeId, f32)>>,
    /// The nodes to cluster. Their order fixes both the initial label of each
    /// node (its index) and the order in which nodes are visited. Neighbours
    /// that do not appear here are ignored when votes are counted.
    pub node_ids: Vec<NodeId>,
}
/// Groups the nodes of `adj` into communities and returns the assignment.
///
/// The procedure implemented here is asynchronous weighted label
/// propagation: every node starts in its own cluster and repeatedly adopts
/// the label that carries the largest summed edge weight among its
/// neighbours. No modularity optimisation or refinement step is performed.
///
/// * `adj` – nodes to cluster and their weighted neighbour lists.
/// * `config` – `max_iterations` bounds the number of sweeps and
///   `min_community_size` is the smallest cluster that is kept.
///   `config.compute_centroids` is not consulted here.
///
/// Returns a [`LeidenResult`] whose cluster ids are renumbered `1..=k` in
/// ascending order of the surviving internal labels. Nodes that ended up in
/// a cluster smaller than `min_community_size` are absent from
/// `node_to_cluster`. `centroids` is always returned empty.
///
/// The result is a pure function of the inputs: ties between equally strong
/// labels are resolved deterministically (see `strongest_label`) rather than
/// by hash-map iteration order.
pub fn run_leiden(
    adj: &AdjacencySnapshot,
    config: &LeidenConfig,
) -> LeidenResult {
    if adj.node_ids.is_empty() {
        return LeidenResult {
            node_to_cluster: HashMap::new(),
            cluster_sizes: HashMap::new(),
            centroids: HashMap::new(),
            num_clusters: 0,
        };
    }

    let labels = propagate_labels(adj, config.max_iterations);
    prune_and_renumber(&labels, config.min_community_size)
}

/// Runs up to `max_iterations` propagation sweeps and returns the raw label of every node.
fn propagate_labels(adj: &AdjacencySnapshot, max_iterations: usize) -> HashMap<NodeId, u32> {
    let nodes = &adj.node_ids;
    let mut labels: HashMap<NodeId, u32> = HashMap::with_capacity(nodes.len());
    for (index, &node) in nodes.iter().enumerate() {
        // Labels are `u32`; the cast only truncates for graphs with more than
        // `u32::MAX` nodes, which this representation does not support.
        labels.insert(node, index as u32);
    }

    // One vote buffer reused for every node instead of a fresh map per visit.
    let mut votes: HashMap<u32, f32> = HashMap::new();
    for _ in 0..max_iterations {
        let mut changed = false;
        for &node in nodes {
            // Nodes without an adjacency entry keep their label.
            let (current, neighbors) = match (labels.get(&node), adj.edges.get(&node)) {
                (Some(&current), Some(neighbors)) => (current, neighbors),
                _ => continue,
            };

            votes.clear();
            for &(target, weight) in neighbors {
                // Neighbours outside `node_ids` carry no label and do not vote.
                if let Some(&label) = labels.get(&target) {
                    *votes.entry(label).or_insert(0.0) += weight;
                }
            }

            if let Some(best) = strongest_label(&votes, current) {
                if best != current {
                    // Updated in place, so later nodes in this sweep already
                    // see the new label.
                    labels.insert(node, best);
                    changed = true;
                }
            }
        }
        if !changed {
            break;
        }
    }
    labels
}

/// Picks the label with the largest vote, preferring `current` and then the smaller id on ties.
///
/// Weights are compared with `f32::total_cmp`, so the ordering is total even
/// if a vote sum is NaN. Because map keys are unique, no two candidates ever
/// compare equal and the winner does not depend on iteration order.
fn strongest_label(votes: &HashMap<u32, f32>, current: u32) -> Option<u32> {
    votes
        .iter()
        .max_by(|&(&label_a, weight_a), &(&label_b, weight_b)| {
            weight_a
                .total_cmp(weight_b)
                // Staying put beats moving: this stops a node from flipping
                // between two equally attractive communities on every sweep.
                .then_with(|| (label_a == current).cmp(&(label_b == current)))
                // Reversed so that the smaller label id ranks higher.
                .then_with(|| label_b.cmp(&label_a))
        })
        .map(|(&label, _)| label)
}

/// Drops clusters below `min_community_size` and renumbers the survivors from 1.
fn prune_and_renumber(
    labels: &HashMap<NodeId, u32>,
    min_community_size: usize,
) -> LeidenResult {
    let mut members_per_label: HashMap<u32, usize> = HashMap::new();
    for &label in labels.values() {
        *members_per_label.entry(label).or_insert(0) += 1;
    }

    let surviving: HashSet<u32> = members_per_label
        .iter()
        .filter(|&(_, &members)| members >= min_community_size)
        .map(|(&label, _)| label)
        .collect();
    // Sorted so that the new ids do not depend on hash iteration order.
    let mut ordered: Vec<u32> = surviving.into_iter().collect();
    ordered.sort_unstable();

    let mut remap: HashMap<u32, u32> = HashMap::with_capacity(ordered.len());
    let mut cluster_sizes: HashMap<u32, usize> = HashMap::with_capacity(ordered.len());
    for (new_id, label) in (1u32..).zip(ordered) {
        remap.insert(label, new_id);
        cluster_sizes.insert(new_id, members_per_label[&label]);
    }

    // Nodes whose label did not survive have no entry in `remap` and are skipped.
    let node_to_cluster: HashMap<NodeId, u32> = labels
        .iter()
        .filter_map(|(&node, label)| remap.get(label).map(|&id| (node, id)))
        .collect();

    let num_clusters = cluster_sizes.len() as u32;
    LeidenResult {
        node_to_cluster,
        cluster_sizes,
        centroids: HashMap::new(),
        num_clusters,
    }
}
/// Fills `result.centroids` with the mean vector of every cluster.
///
/// * `result` – clustering whose `node_to_cluster` map is read and whose
///   `centroids` map receives one entry per cluster that has at least one
///   member with a vector.
/// * `vectors` – per-node vectors; nodes without an entry are skipped and do
///   not count towards the mean.
/// * `dim` – length of every produced centroid. Vectors longer than `dim`
///   are truncated; shorter ones contribute zero to the missing components
///   while still counting as a member.
///
/// Existing entries in `result.centroids` are overwritten for clusters that
/// are recomputed and left untouched otherwise.
pub fn compute_centroids(
    result: &mut LeidenResult,
    vectors: &HashMap<NodeId, Vec<f32>>,
    dim: usize,
) {
    // Running component sums (in f64) and member count, kept together per cluster.
    let mut accumulators: HashMap<u32, (Vec<f64>, usize)> = HashMap::new();

    for (node_id, &cluster_id) in &result.node_to_cluster {
        let embedding = match vectors.get(node_id) {
            Some(embedding) => embedding,
            None => continue,
        };
        let (sum, members) = accumulators
            .entry(cluster_id)
            .or_insert_with(|| (vec![0.0f64; dim], 0));
        // `zip` stops at the shorter side, i.e. after min(dim, embedding.len()) components.
        for (slot, &component) in sum.iter_mut().zip(embedding) {
            *slot += f64::from(component);
        }
        *members += 1;
    }

    for (cluster_id, (sum, members)) in accumulators {
        let divisor = members as f64;
        let centroid: Vec<f32> = sum
            .into_iter()
            .map(|total| (total / divisor) as f32)
            .collect();
        result.centroids.insert(cluster_id, centroid);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an undirected snapshot from `(src, dst, weight)` triples.
    ///
    /// Every edge is stored under both endpoints. The node list is sorted so
    /// that each test run feeds `run_leiden` the same node order instead of
    /// whatever order the hash set happens to iterate in.
    fn make_snapshot(edges: Vec<(NodeId, NodeId, f32)>) -> AdjacencySnapshot {
        let mut adj: HashMap<NodeId, Vec<(NodeId, f32)>> = HashMap::new();
        let mut all_ids: HashSet<NodeId> = HashSet::new();
        for (src, dst, w) in edges {
            adj.entry(src).or_default().push((dst, w));
            adj.entry(dst).or_default().push((src, w));
            all_ids.insert(src);
            all_ids.insert(dst);
        }
        let mut node_ids: Vec<NodeId> = all_ids.into_iter().collect();
        node_ids.sort_unstable();
        AdjacencySnapshot {
            edges: adj,
            node_ids,
        }
    }

    /// An empty snapshot produces no clusters.
    #[test]
    fn test_empty_graph() {
        let snap = AdjacencySnapshot {
            edges: HashMap::new(),
            node_ids: vec![],
        };
        let result = run_leiden(&snap, &LeidenConfig::default());
        assert_eq!(result.num_clusters, 0);
    }

    /// Two disconnected triangles end up as two separate communities.
    #[test]
    fn test_two_cliques() {
        let snap = make_snapshot(vec![
            (1, 2, 1.0),
            (1, 3, 1.0),
            (2, 3, 1.0),
            (4, 5, 1.0),
            (4, 6, 1.0),
            (5, 6, 1.0),
        ]);
        let result = run_leiden(
            &snap,
            &LeidenConfig { min_community_size: 3, ..Default::default() },
        );
        assert_eq!(result.num_clusters, 2, "应发现 2 个社区");
        assert_eq!(result.node_to_cluster[&1], result.node_to_cluster[&2]);
        assert_eq!(result.node_to_cluster[&4], result.node_to_cluster[&5]);
        assert_ne!(result.node_to_cluster[&1], result.node_to_cluster[&4]);
    }

    /// A two-node fragment is below the minimum size and is filtered out.
    #[test]
    fn test_fragment_filtering() {
        let snap = make_snapshot(vec![
            (1, 2, 1.0),
            (1, 3, 1.0),
            (2, 3, 1.0),
            (4, 5, 1.0),
        ]);
        let result = run_leiden(
            &snap,
            &LeidenConfig { min_community_size: 3, ..Default::default() },
        );
        assert_eq!(result.num_clusters, 1, "碎片簇应被过滤");
        assert!(result.node_to_cluster.contains_key(&1));
        assert!(!result.node_to_cluster.contains_key(&4), "碎片节点不应出现");
    }

    /// The centroid of three unit basis vectors is (1/3, 1/3, 1/3).
    #[test]
    fn test_centroid_computation() {
        let snap = make_snapshot(vec![
            (1, 2, 1.0),
            (1, 3, 1.0),
            (2, 3, 1.0),
        ]);
        let mut result = run_leiden(
            &snap,
            &LeidenConfig { min_community_size: 3, ..Default::default() },
        );
        let mut vectors = HashMap::new();
        vectors.insert(1u64, vec![1.0f32, 0.0, 0.0]);
        vectors.insert(2, vec![0.0, 1.0, 0.0]);
        vectors.insert(3, vec![0.0, 0.0, 1.0]);
        compute_centroids(&mut result, &vectors, 3);
        assert_eq!(result.centroids.len(), 1);
        let c = result.centroids.values().next().unwrap();
        assert_eq!(c.len(), 3);
        assert!((c[0] - 1.0/3.0).abs() < 0.01);
        assert!((c[1] - 1.0/3.0).abs() < 0.01);
        assert!((c[2] - 1.0/3.0).abs() < 0.01);
    }
}