use std::collections::HashMap;
use serde::{Deserialize, Serialize};
use super::{AdjacencyGraph, CommunityResult};
use crate::error::{ClusteringError, Result};
/// Settings consumed by `label_propagation_community`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LabelPropagationConfig {
/// Upper bound on the number of sweeps over all nodes (`Default` uses 100).
pub max_iterations: usize,
/// Seed for the internal xorshift generator that shuffles the visiting order and
/// breaks ties between equally weighted labels (`Default` uses 42). A seed of 0 is
/// replaced with 1 by the generator.
pub seed: u64,
/// Optional per-node starting labels. `Some(l)` at index `i` fixes node `i` to label
/// `l` for the whole run; `None` leaves the node free to change. When present, the
/// vector length must equal the number of nodes in the graph.
///
/// Marked `#[serde(skip)]`, so it takes no part in serialization or deserialization.
#[serde(skip)]
pub seed_labels: Option<Vec<Option<usize>>>,
}
impl Default for LabelPropagationConfig {
fn default() -> Self {
Self {
max_iterations: 100,
seed: 42,
seed_labels: None,
}
}
}
/// Outcome returned by `label_propagation_community`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LabelPropagationResult {
/// Final assignment: labels renumbered to consecutive ids starting at 0 in order of
/// first appearance, the number of distinct labels, and the score returned by the
/// graph's `modularity` method.
pub community: CommunityResult,
/// Number of sweeps that were started before the run stopped.
pub iterations_used: usize,
/// `true` when a full sweep finished without changing any label; `false` when the
/// run stopped because `max_iterations` was exhausted.
pub converged: bool,
}
/// Small deterministic xorshift64 pseudo-random generator (shift triple 13, 7, 17).
///
/// It keeps runs reproducible for a given seed without pulling in an external RNG.
struct Xorshift64(u64);

impl Xorshift64 {
    /// Creates a generator from `seed`.
    ///
    /// A zero state would make every output zero, so a seed of 0 is mapped to 1.
    fn new(seed: u64) -> Self {
        match seed {
            0 => Self(1),
            nonzero => Self(nonzero),
        }
    }

    /// Advances the state by one step and returns the new state as the output.
    fn next_u64(&mut self) -> u64 {
        let start = self.0;
        let after_left = start ^ (start << 13);
        let after_right = after_left ^ (after_left >> 7);
        let mixed = after_right ^ (after_right << 17);
        self.0 = mixed;
        mixed
    }

    /// Shuffles `slice` in place, walking from the last position down to index 1 and
    /// swapping each position with a pseudo-randomly chosen index at or below it.
    ///
    /// Slices with fewer than two elements are left untouched and consume no
    /// random numbers.
    fn shuffle(&mut self, slice: &mut [usize]) {
        let mut upper = slice.len();
        while upper > 1 {
            upper -= 1;
            let pick = (self.next_u64() as usize) % (upper + 1);
            slice.swap(upper, pick);
        }
    }
}
/// Partitions `graph` into communities with asynchronous label propagation.
///
/// Every node starts with a label of its own unless `config.seed_labels` supplies one;
/// seeded nodes are pinned and never change. Each sweep visits the nodes in a freshly
/// shuffled order and gives every unpinned node the label carrying the largest total
/// edge weight among its neighbours. Labels whose totals are within an absolute
/// tolerance of `1e-12` of the maximum count as tied, and one of them is picked
/// pseudo-randomly from the sorted candidate list, so the outcome is reproducible for
/// a given `config.seed`. Nodes without neighbours keep their label.
///
/// The run stops after the first sweep that changes no label (`converged == true`) or
/// after `config.max_iterations` sweeps. Returned labels are renumbered to consecutive
/// ids starting at 0 in order of first appearance, and `quality_score` holds the value
/// returned by `graph.modularity` for that assignment.
///
/// A node whose vote totals are all non-finite (for example because of NaN or infinite
/// edge weights) has no well-defined winner; it keeps its current label instead of
/// causing a panic.
///
/// # Errors
///
/// Returns `ClusteringError::InvalidInput` when
/// - the graph has zero nodes,
/// - `config.seed_labels` is present but its length differs from the node count, or
/// - the seed labels are so close to `usize::MAX` that fresh labels for the unseeded
///   nodes cannot be allocated without overflow.
pub fn label_propagation_community(
    graph: &AdjacencyGraph,
    config: &LabelPropagationConfig,
) -> Result<LabelPropagationResult> {
    let n = graph.n_nodes;
    if n == 0 {
        return Err(ClusteringError::InvalidInput(
            "Graph has zero nodes".to_string(),
        ));
    }

    let seeds: Option<&[Option<usize>]> = config.seed_labels.as_deref();
    if let Some(sl) = seeds {
        if sl.len() != n {
            return Err(ClusteringError::InvalidInput(format!(
                "seed_labels length ({}) must equal number of nodes ({})",
                sl.len(),
                n
            )));
        }
    }

    // Fresh labels are numbered strictly above the largest seed label so they can
    // never collide with a user-supplied one.
    let max_seed_label = seeds
        .map(|sl| sl.iter().flatten().copied().max().unwrap_or(0))
        .unwrap_or(0);

    let mut labels: Vec<usize> = Vec::with_capacity(n);
    let mut pinned = vec![false; n];
    let mut last_label = max_seed_label;
    for (node, is_pinned) in pinned.iter_mut().enumerate() {
        match seeds.and_then(|sl| sl[node]) {
            Some(label) => {
                labels.push(label);
                *is_pinned = true;
            }
            None => {
                // Checked so that huge seed labels yield an error instead of a debug
                // panic or a silent wrap-around onto an existing label.
                last_label = last_label.checked_add(1).ok_or_else(|| {
                    ClusteringError::InvalidInput(
                        "seed label values are too large to allocate fresh labels for unseeded nodes"
                            .to_string(),
                    )
                })?;
                labels.push(last_label);
            }
        }
    }

    let mut rng = Xorshift64::new(config.seed);
    let mut converged = false;
    let mut iterations_used = 0;

    // Scratch buffers reused across nodes and sweeps to avoid per-node allocations.
    let mut order: Vec<usize> = Vec::with_capacity(n);
    let mut votes: HashMap<usize, f64> = HashMap::new();
    let mut best_labels: Vec<usize> = Vec::new();

    for _ in 0..config.max_iterations {
        iterations_used += 1;
        let mut changed = false;

        // Restart from the identity permutation every sweep so the visiting order
        // depends only on the RNG state, not on the previous sweep's order.
        order.clear();
        order.extend(0..n);
        rng.shuffle(&mut order);

        for &v in &order {
            if pinned[v] {
                continue;
            }

            votes.clear();
            for &(nb, w) in &graph.adjacency[v] {
                *votes.entry(labels[nb]).or_insert(0.0) += w;
            }
            if votes.is_empty() {
                continue;
            }

            let max_weight = votes.values().copied().fold(f64::NEG_INFINITY, f64::max);
            best_labels.clear();
            best_labels.extend(
                votes
                    .iter()
                    .filter(|&(_, &weight)| (weight - max_weight).abs() < 1e-12)
                    .map(|(&label, _)| label),
            );
            // NaN or infinite totals fail the tolerance test for every candidate;
            // without this guard the modulo below would divide by zero.
            if best_labels.is_empty() {
                continue;
            }
            // Hash-map iteration order is arbitrary; sorting makes the tie-break
            // depend only on the RNG.
            best_labels.sort_unstable();

            let chosen = best_labels[(rng.next_u64() as usize) % best_labels.len()];
            if chosen != labels[v] {
                labels[v] = chosen;
                changed = true;
            }
        }

        if !changed {
            converged = true;
            break;
        }
    }

    // Renumber labels to 0..k in order of first appearance.
    let mut mapping: HashMap<usize, usize> = HashMap::new();
    let compacted: Vec<usize> = labels
        .iter()
        .map(|&label| {
            let next_id = mapping.len();
            *mapping.entry(label).or_insert(next_id)
        })
        .collect();
    let num_communities = mapping.len();
    let quality = graph.modularity(&compacted);

    Ok(LabelPropagationResult {
        community: CommunityResult {
            labels: compacted,
            num_communities,
            quality_score: Some(quality),
        },
        iterations_used,
        converged,
    })
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a graph with `node_count` nodes and inserts `edges` in the given order.
    fn graph_with_edges(node_count: usize, edges: &[(usize, usize, f64)]) -> AdjacencyGraph {
        let mut graph = AdjacencyGraph::new(node_count);
        for &(a, b, weight) in edges {
            let _ = graph.add_edge(a, b, weight);
        }
        graph
    }

    /// Runs label propagation and unwraps the result, failing the test on error.
    fn run(graph: &AdjacencyGraph, config: &LabelPropagationConfig) -> LabelPropagationResult {
        label_propagation_community(graph, config).expect("lp should succeed")
    }

    /// Two disconnected 5-cliques must end up as exactly two communities.
    #[test]
    fn test_lp_two_cliques() {
        let clique = 5;
        let total = 2 * clique;
        let mut graph = AdjacencyGraph::new(total);
        for &(lo, hi) in &[(0, clique), (clique, total)] {
            for a in lo..hi {
                for b in (a + 1)..hi {
                    let _ = graph.add_edge(a, b, 1.0);
                }
            }
        }

        let outcome = run(&graph, &LabelPropagationConfig::default());
        let labels = &outcome.community.labels;
        assert_eq!(outcome.community.num_communities, 2);

        let first = labels[0];
        assert!(labels[1..clique].iter().all(|&l| l == first));
        let second = labels[clique];
        assert!(labels[(clique + 1)..total].iter().all(|&l| l == second));
        assert_ne!(first, second);
    }

    /// Two separate edges settle before the iteration budget is used up.
    #[test]
    fn test_lp_convergence() {
        let graph = graph_with_edges(4, &[(0, 1, 1.0), (2, 3, 1.0)]);
        let config = LabelPropagationConfig::default();
        let outcome = run(&graph, &config);
        assert!(outcome.converged);
        assert!(outcome.iterations_used < config.max_iterations);
    }

    /// Pinned end points of a path with different seed labels stay different.
    #[test]
    fn test_lp_seed_labels() {
        let graph = graph_with_edges(4, &[(0, 1, 1.0), (1, 2, 1.0), (2, 3, 1.0)]);
        let config = LabelPropagationConfig {
            seed_labels: Some(vec![Some(0), None, None, Some(1)]),
            ..Default::default()
        };
        let outcome = run(&graph, &config);
        assert_ne!(outcome.community.labels[0], outcome.community.labels[3]);
    }

    /// Mixed heavy and light edges still yield one label per node.
    #[test]
    fn test_lp_weighted_edges() {
        let graph = graph_with_edges(
            4,
            &[
                (0, 1, 10.0),
                (0, 2, 10.0),
                (1, 2, 0.1),
                (1, 3, 10.0),
                (2, 3, 0.1),
            ],
        );
        let outcome = run(&graph, &LabelPropagationConfig::default());
        assert!(outcome.community.num_communities >= 1);
        assert_eq!(outcome.community.labels.len(), 4);
    }

    /// A graph without nodes is rejected.
    #[test]
    fn test_lp_empty_graph() {
        let graph = AdjacencyGraph::new(0);
        let attempt = label_propagation_community(&graph, &LabelPropagationConfig::default());
        assert!(attempt.is_err());
    }

    /// A lone node forms a single community with id 0.
    #[test]
    fn test_lp_single_node() {
        let graph = AdjacencyGraph::new(1);
        let outcome = run(&graph, &LabelPropagationConfig::default());
        assert_eq!(outcome.community.num_communities, 1);
        assert_eq!(outcome.community.labels, vec![0]);
    }

    /// A seed-label vector shorter than the node count is rejected.
    #[test]
    fn test_lp_seed_labels_length_mismatch() {
        let graph = AdjacencyGraph::new(3);
        let config = LabelPropagationConfig {
            seed_labels: Some(vec![Some(0), None]),
            ..Default::default()
        };
        assert!(label_propagation_community(&graph, &config).is_err());
    }

    /// Nodes without any edges each keep their own community.
    #[test]
    fn test_lp_isolated_nodes() {
        let graph = AdjacencyGraph::new(5);
        let outcome = run(&graph, &LabelPropagationConfig::default());
        assert_eq!(outcome.community.num_communities, 5);
    }
}