use crate::core::error::{GraphinaError, Result};
use crate::core::types::{BaseGraph, GraphConstructor, NodeId};
use rand::prelude::*;
use rand::{SeedableRng, rngs::StdRng};
use std::collections::HashMap as StdHashMap;
/// Builds the RNG that drives node ordering in label propagation.
///
/// `Some(seed)` yields a reproducible stream. `None` seeds the generator from
/// `rand::random`, so every call produces a different stream.
fn create_rng(seed: Option<u64>) -> StdRng {
    let effective_seed = seed.unwrap_or_else(rand::random::<u64>);
    StdRng::seed_from_u64(effective_seed)
}
/// Detects communities with asynchronous label propagation.
///
/// Every node starts in its own community: its label is its position in
/// `graph.nodes()`. On each pass the nodes are visited in a random order, and
/// each node adopts the label carried by the largest number of its neighbours.
///
/// Edge direction is ignored: every edge `(u, v)` is treated as a link in both
/// directions. Edge weights are not used; each edge contributes one vote.
///
/// Ties between equally frequent labels are resolved as follows:
/// * If the node's current label is among the most frequent ones, it is kept.
///   Without this rule a node could flip between tied labels on every pass and
///   the algorithm would never reach a pass with no changes.
/// * Otherwise one of the tied labels is drawn with the seeded RNG. This
///   replaces reliance on hash-map iteration order, which std randomizes per
///   map instance and which made results non-reproducible even with a fixed
///   `seed`.
///
/// # Arguments
/// * `graph` - the graph to partition.
/// * `max_iter` - maximum number of full passes over the nodes; must be > 0.
/// * `seed` - RNG seed. `Some(s)` makes the result deterministic; `None` draws
///   a fresh seed.
///
/// # Returns
/// A vector of length `graph.node_count()`. Its k-th entry is the community
/// label of the k-th node yielded by `graph.nodes()`. Labels are node positions
/// in `0..n`, so they are not necessarily contiguous.
///
/// # Errors
/// Returns an `invalid_graph` error if the graph has no nodes or if
/// `max_iter == 0`.
pub fn label_propagation<A, W, Ty>(
    graph: &BaseGraph<A, W, Ty>,
    max_iter: usize,
    seed: Option<u64>,
) -> Result<Vec<usize>>
where
    W: Copy + PartialOrd + Into<f64>,
    Ty: GraphConstructor<A, W>,
{
    let n = graph.node_count();
    if n == 0 {
        return Err(GraphinaError::invalid_graph(
            "LabelPropagation: empty graph",
        ));
    }
    if max_iter == 0 {
        return Err(GraphinaError::invalid_graph("LabelPropagation: max_iter=0"));
    }

    // Dense positions 0..n let labels and adjacency live in plain Vecs.
    let node_to_idx: StdHashMap<NodeId, usize> = graph
        .nodes()
        .enumerate()
        .map(|(i, (nid, _))| (nid, i))
        .collect();

    // Symmetrized adjacency: each edge is a neighbour link for both endpoints.
    let mut adjacency: Vec<Vec<usize>> = vec![Vec::new(); n];
    for (src, tgt, _w) in graph.edges() {
        let (si, ti) = (node_to_idx[&src], node_to_idx[&tgt]);
        adjacency[si].push(ti);
        adjacency[ti].push(si);
    }

    let mut labels: Vec<usize> = (0..n).collect();
    let mut rng = create_rng(seed);

    // Scratch buffers, reused across nodes and passes to avoid per-node allocation.
    let mut order: Vec<usize> = (0..n).collect();
    // votes[label] = number of neighbours carrying `label` for the current node.
    // Labels always lie in 0..n, so a dense array works. It is reset to all zeros
    // after each node.
    let mut votes: Vec<usize> = vec![0; n];
    // Labels that received at least one vote for the current node, in first-seen order.
    let mut seen: Vec<usize> = Vec::new();
    // Labels tied for the highest vote count.
    let mut best: Vec<usize> = Vec::new();

    for _ in 0..max_iter {
        order.shuffle(&mut rng);
        let mut changed = false;

        for &i in &order {
            seen.clear();
            for &nbr in &adjacency[i] {
                let label = labels[nbr];
                if votes[label] == 0 {
                    seen.push(label);
                }
                votes[label] += 1;
            }

            // Isolated node: no neighbours, nothing to adopt, no counters touched.
            let Some(max_votes) = seen.iter().map(|&l| votes[l]).max() else {
                continue;
            };

            best.clear();
            best.extend(seen.iter().copied().filter(|&l| votes[l] == max_votes));

            // Reset only the counters we touched so `votes` is clean for the next node.
            for &l in &seen {
                votes[l] = 0;
            }

            // Keep the current label when it is already a winner; this is what
            // lets the algorithm converge instead of oscillating on ties.
            if best.contains(&labels[i]) {
                continue;
            }
            if let Some(&new_label) = best.choose(&mut rng) {
                labels[i] = new_label;
                changed = true;
            }
        }

        if !changed {
            break;
        }
    }

    Ok(labels)
}
#[cfg(test)]
mod tests {
    use super::label_propagation;
    use crate::core::types::Graph;

    /// Two 5-cliques (nodes 0-4 and 5-9) joined by one weak bridge edge (2 -- 7).
    #[test]
    fn test_label_propagation_stability() {
        let mut g: Graph<i32, f64> = Graph::new();
        let nodes: Vec<_> = (0..10).map(|i| g.add_node(i)).collect();
        for i in 0..4 {
            for j in (i + 1)..5 {
                g.add_edge(nodes[i], nodes[j], 1.0);
            }
        }
        for i in 5..9 {
            for j in (i + 1)..10 {
                g.add_edge(nodes[i], nodes[j], 1.0);
            }
        }
        g.add_edge(nodes[2], nodes[7], 0.1);
        let communities = label_propagation(&g, 100, Some(42)).unwrap();
        // Exactly one label per node, and every label is a valid node position.
        assert_eq!(communities.len(), 10);
        assert!(communities.iter().all(|&label| label < 10));
    }

    /// Nodes without neighbours have nothing to adopt and keep their initial label.
    #[test]
    fn test_label_propagation_isolated_nodes_keep_own_label() {
        let mut g: Graph<i32, f64> = Graph::new();
        for i in 0..3 {
            let _ = g.add_node(i);
        }
        let labels = label_propagation(&g, 10, Some(7)).unwrap();
        assert_eq!(labels, vec![0, 1, 2]);
    }

    /// The two endpoints of a single edge always end up in the same community.
    #[test]
    fn test_label_propagation_single_edge_merges() {
        let mut g: Graph<i32, f64> = Graph::new();
        let a = g.add_node(0);
        let b = g.add_node(1);
        g.add_edge(a, b, 1.0);
        let labels = label_propagation(&g, 10, None).unwrap();
        assert_eq!(labels.len(), 2);
        assert_eq!(labels[0], labels[1]);
    }

    /// An empty graph and `max_iter == 0` are both rejected with an error.
    #[test]
    fn test_label_propagation_rejects_invalid_input() {
        let empty: Graph<i32, f64> = Graph::new();
        assert!(label_propagation(&empty, 10, Some(1)).is_err());

        let mut g: Graph<i32, f64> = Graph::new();
        let _ = g.add_node(0);
        assert!(label_propagation(&g, 0, Some(1)).is_err());
    }
}