use super::traits::CommunityDetection;
use crate::error::{Error, Result};
use petgraph::graph::UnGraph;
use petgraph::visit::EdgeRef;
use rand::prelude::*;
/// Community detector based on label propagation.
///
/// Configure it with [`LabelPropagation::with_max_iter`] and
/// [`LabelPropagation::with_seed`], then run it through its `CommunityDetection` impl.
#[derive(Debug, Clone)]
pub struct LabelPropagation {
    /// Upper bound on the number of full sweeps over the nodes.
    max_iter: usize,
    /// RNG seed; `None` means an unseeded RNG (`rand::rng()`) is used.
    seed: Option<u64>,
}

impl LabelPropagation {
    /// Creates a detector with `max_iter = 100` and no seed.
    #[must_use]
    pub fn new() -> Self {
        Self {
            max_iter: 100,
            seed: None,
        }
    }

    /// Sets the maximum number of propagation sweeps.
    ///
    /// With `0`, no sweep runs and every node stays in its own community.
    #[must_use]
    pub fn with_max_iter(mut self, max_iter: usize) -> Self {
        self.max_iter = max_iter;
        self
    }

    /// Seeds the RNG that drives the node visiting order and tie-breaking.
    #[must_use]
    pub fn with_seed(mut self, seed: u64) -> Self {
        self.seed = Some(seed);
        self
    }
}

impl Default for LabelPropagation {
    /// Equivalent to [`LabelPropagation::new`].
    fn default() -> Self {
        Self::new()
    }
}
impl CommunityDetection for LabelPropagation {
    /// Assigns every node of `graph` to a community by asynchronous label propagation.
    ///
    /// Each node starts with its own index as its label. Every sweep visits the nodes in a
    /// freshly shuffled order and, updating labels in place, moves each node to the label
    /// carried by the most incident edges (ties broken uniformly at random). A node whose
    /// current label is already among the most frequent keeps it. Propagation therefore
    /// stops after the first sweep that changes nothing, or after `max_iter` sweeps.
    /// Isolated nodes keep their own label.
    ///
    /// Returns one community id per node, indexed by node index. Ids are contiguous
    /// (`0..k` for `k` communities) and follow the ascending order of the final labels.
    ///
    /// # Errors
    ///
    /// Returns [`Error::EmptyInput`] if `graph` has no nodes.
    fn detect<N, E>(&self, graph: &UnGraph<N, E>) -> Result<Vec<usize>> {
        let n = graph.node_count();
        if n == 0 {
            return Err(Error::EmptyInput);
        }
        let mut labels: Vec<usize> = (0..n).collect();
        let mut rng: Box<dyn RngCore> = match self.seed {
            Some(s) => Box::new(StdRng::seed_from_u64(s)),
            None => Box::new(rand::rng()),
        };

        // Scratch buffers reused across sweeps and nodes instead of being reallocated.
        let mut order: Vec<usize> = (0..n).collect();
        let mut label_counts: std::collections::HashMap<usize, usize> =
            std::collections::HashMap::new();
        let mut candidates: Vec<usize> = Vec::new();

        for _sweep in 0..self.max_iter {
            let mut changed = false;
            order.shuffle(&mut rng);
            for &node in &order {
                // Tally incident edges per neighbour label.
                label_counts.clear();
                for edge in graph.edges(petgraph::graph::NodeIndex::new(node)) {
                    *label_counts
                        .entry(labels[edge.target().index()])
                        .or_insert(0) += 1;
                }
                let max_count = match label_counts.values().max() {
                    Some(&count) => count,
                    // Isolated node: no neighbour label to adopt.
                    None => continue,
                };
                // Keep the current label if it is already a best choice. Otherwise nodes
                // sitting on a tie keep flipping at random, `changed` never settles to
                // `false`, and every run burns all `max_iter` sweeps.
                if label_counts.get(&labels[node]) == Some(&max_count) {
                    continue;
                }
                candidates.clear();
                candidates.extend(
                    label_counts
                        .iter()
                        .filter(|(_, &count)| count == max_count)
                        .map(|(&label, _)| label),
                );
                // std `HashMap` iteration order is unspecified and differs between map
                // instances; sort so that only the seeded RNG decides which tie wins.
                candidates.sort_unstable();
                let new_label = if candidates.len() == 1 {
                    candidates[0]
                } else {
                    candidates[rng.random_range(0..candidates.len())]
                };
                // Every candidate outnumbers the current label, so this is always a change.
                labels[node] = new_label;
                changed = true;
            }
            if !changed {
                break;
            }
        }

        // Compact labels to 0..k in ascending label order. `unique` is sorted, so a binary
        // search per node keeps this O(n log n).
        let mut unique = labels.clone();
        unique.sort_unstable();
        unique.dedup();
        Ok(labels
            .iter()
            .map(|label| {
                unique
                    .binary_search(label)
                    .expect("every label is present in its own deduplicated copy")
            })
            .collect())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an undirected graph with `n` nodes and the given `(a, b)` edges.
    fn build(n: usize, edges: &[(usize, usize)]) -> UnGraph<(), ()> {
        let mut graph = UnGraph::<(), ()>::new_undirected();
        let nodes: Vec<_> = (0..n).map(|_| graph.add_node(())).collect();
        for &(a, b) in edges {
            let _ = graph.add_edge(nodes[a], nodes[b], ());
        }
        graph
    }

    #[test]
    fn test_label_propagation_basic() {
        let mut graph = UnGraph::<(), ()>::new_undirected();
        let n0 = graph.add_node(());
        let n1 = graph.add_node(());
        let n2 = graph.add_node(());
        let n3 = graph.add_node(());
        let _ = graph.add_edge(n0, n1, ());
        let _ = graph.add_edge(n2, n3, ());
        let lp = LabelPropagation::new().with_seed(42);
        let communities = lp.detect(&graph).unwrap();
        assert_eq!(communities[0], communities[1]);
        assert_eq!(communities[2], communities[3]);
        assert_ne!(communities[0], communities[2]);

        // Ids are compacted to a contiguous range.
        let mut ids = communities.clone();
        ids.sort_unstable();
        ids.dedup();
        assert_eq!(ids, vec![0, 1]);
    }

    #[test]
    fn test_empty_graph_is_error() {
        let graph = UnGraph::<(), ()>::new_undirected();
        let result = LabelPropagation::new().detect(&graph);
        assert!(matches!(result, Err(Error::EmptyInput)));
    }

    #[test]
    fn test_isolated_nodes_keep_own_community() {
        let graph = build(3, &[]);
        let communities = LabelPropagation::new().with_seed(1).detect(&graph).unwrap();
        assert_eq!(communities, vec![0, 1, 2]);
    }

    #[test]
    fn test_single_node() {
        let graph = build(1, &[]);
        let communities = LabelPropagation::new().detect(&graph).unwrap();
        assert_eq!(communities, vec![0]);
    }

    #[test]
    fn test_zero_iterations_leaves_singletons() {
        let graph = build(3, &[(0, 1), (1, 2)]);
        let communities = LabelPropagation::new()
            .with_max_iter(0)
            .detect(&graph)
            .unwrap();
        assert_eq!(communities, vec![0, 1, 2]);
    }
}