use std::collections::HashMap;
use scirs2_core::random::{Rng, SeedableRng, StdRng};
use crate::error::{GraphError, Result};
use super::louvain::compact_communities;
fn build_adj(edges: &[(usize, usize, f64)], n: usize) -> Vec<Vec<(usize, f64)>> {
let mut adj: Vec<Vec<(usize, f64)>> = vec![vec![]; n];
for &(u, v, w) in edges {
if u < n && v < n {
adj[u].push((v, w));
if u != v {
adj[v].push((u, w));
}
}
}
adj
}
fn plurality_label(
node: usize,
adj: &[Vec<(usize, f64)>],
labels: &[usize],
rng: &mut impl Rng,
) -> Option<usize> {
let mut label_weight: HashMap<usize, f64> = HashMap::new();
for &(nbr, w) in &adj[node] {
if nbr < labels.len() {
*label_weight.entry(labels[nbr]).or_insert(0.0) += w;
}
}
if label_weight.is_empty() {
return None;
}
let max_w = label_weight
.values()
.cloned()
.fold(f64::NEG_INFINITY, f64::max);
let best: Vec<usize> = label_weight
.iter()
.filter(|(_, &w)| (w - max_w).abs() < 1e-12)
.map(|(&l, _)| l)
.collect();
if best.len() == 1 {
Some(best[0])
} else {
Some(best[rng.random_range(0..best.len())])
}
}
pub fn label_propagation_edge_list(
adj: &[(usize, usize, f64)],
n_nodes: usize,
max_iter: usize,
) -> Result<Vec<usize>> {
if n_nodes == 0 {
return Err(GraphError::InvalidGraph("label_propagation: n_nodes must be > 0".into()));
}
let graph = build_adj(adj, n_nodes);
let mut labels: Vec<usize> = (0..n_nodes).collect();
let mut rng = StdRng::seed_from_u64(0x1234567890abcdef);
for _ in 0..max_iter {
let prev_labels = labels.clone();
let mut changed = false;
let mut order: Vec<usize> = (0..n_nodes).collect();
for i in (1..n_nodes).rev() {
let j = rng.random_range(0..=i);
order.swap(i, j);
}
for &node in &order {
if let Some(chosen) = plurality_label(node, &graph, &prev_labels, &mut rng) {
if chosen != labels[node] {
labels[node] = chosen;
changed = true;
}
}
}
if !changed {
break;
}
}
compact_communities(&mut labels);
Ok(labels)
}
pub fn async_label_propagation(
adj: &[(usize, usize, f64)],
n_nodes: usize,
max_iter: usize,
) -> Result<Vec<usize>> {
if n_nodes == 0 {
return Err(GraphError::InvalidGraph(
"async_label_propagation: n_nodes must be > 0".into(),
));
}
let graph = build_adj(adj, n_nodes);
let mut labels: Vec<usize> = (0..n_nodes).collect();
let mut rng = StdRng::seed_from_u64(0xfedcba9876543210);
for _ in 0..max_iter {
let mut changed = false;
let mut order: Vec<usize> = (0..n_nodes).collect();
for i in (1..n_nodes).rev() {
let j = rng.random_range(0..=i);
order.swap(i, j);
}
for &node in &order {
if let Some(chosen) = plurality_label(node, &graph, &labels, &mut rng) {
if chosen != labels[node] {
labels[node] = chosen;
changed = true;
}
}
}
if !changed {
break;
}
}
compact_communities(&mut labels);
Ok(labels)
}
#[cfg(test)]
mod tests {
use super::*;
fn two_clique_edges(k: usize) -> (Vec<(usize, usize, f64)>, usize) {
let n = 2 * k;
let mut edges = Vec::new();
for i in 0..k {
for j in (i + 1)..k {
edges.push((i, j, 1.0));
edges.push((k + i, k + j, 1.0));
}
}
edges.push((0, k, 0.05));
(edges, n)
}
#[test]
fn test_sync_lpa_two_cliques() {
let (edges, n) = two_clique_edges(4);
let labels = label_propagation_edge_list(&edges, n, 50).expect("sync lpa");
assert_eq!(labels.len(), 8);
let l0 = labels[0];
for i in 1..4 {
assert_eq!(labels[i], l0, "clique1 node {} has wrong label", i);
}
let l1 = labels[4];
for i in 5..8 {
assert_eq!(labels[i], l1, "clique2 node {} has wrong label", i);
}
assert_ne!(l0, l1, "the two cliques must have different labels");
}
#[test]
fn test_async_lpa_two_cliques() {
let (edges, n) = two_clique_edges(4);
let labels = async_label_propagation(&edges, n, 50).expect("async lpa");
assert_eq!(labels.len(), 8);
let l0 = labels[0];
for i in 1..4 {
assert_eq!(labels[i], l0, "clique1 node {} has wrong label", i);
}
let l1 = labels[4];
for i in 5..8 {
assert_eq!(labels[i], l1, "clique2 node {} has wrong label", i);
}
assert_ne!(l0, l1);
}
#[test]
fn test_lpa_single_node() {
let labels = label_propagation_edge_list(&[], 1, 10).expect("single node");
assert_eq!(labels, vec![0]);
}
#[test]
fn test_async_lpa_single_node() {
let labels = async_label_propagation(&[], 1, 10).expect("async single node");
assert_eq!(labels, vec![0]);
}
#[test]
fn test_lpa_empty_error() {
assert!(label_propagation_edge_list(&[], 0, 10).is_err());
assert!(async_label_propagation(&[], 0, 10).is_err());
}
#[test]
fn test_lpa_isolated_nodes() {
let labels = label_propagation_edge_list(&[], 5, 10).expect("isolated");
assert_eq!(labels.len(), 5);
let unique: std::collections::HashSet<usize> = labels.iter().cloned().collect();
assert_eq!(unique.len(), 5);
}
}