use crate::algo::GraphProjection;
use crate::algo::algorithms::Algorithm;
use rand::prelude::*;
use std::collections::HashMap;
use uni_common::core::id::Vid;
pub struct LabelPropagation;
#[derive(Debug, Clone)]
pub struct LabelPropagationConfig {
pub max_iterations: usize,
pub seed_property: Option<String>,
pub write: bool,
pub write_property: String,
}
impl Default for LabelPropagationConfig {
fn default() -> Self {
Self {
max_iterations: 10,
seed_property: None,
write: false,
write_property: "community".to_string(),
}
}
}
#[derive(Debug)]
pub struct LabelPropagationResult {
pub communities: Vec<(Vid, u64)>,
pub iterations: usize,
pub converged: bool,
}
impl Algorithm for LabelPropagation {
type Config = LabelPropagationConfig;
type Result = LabelPropagationResult;
fn name() -> &'static str {
"labelPropagation"
}
fn needs_reverse() -> bool {
true
}
fn run(graph: &GraphProjection, config: Self::Config) -> Self::Result {
let num_nodes = graph.vertex_count();
if num_nodes == 0 {
return LabelPropagationResult {
communities: Vec::new(),
iterations: 0,
converged: true,
};
}
let mut labels = vec![0u64; num_nodes];
for (i, label) in labels.iter_mut().enumerate().take(num_nodes) {
*label = graph.to_vid(i as u32).as_u64();
}
let mut converged = false;
let mut iterations = 0;
let mut node_indices: Vec<u32> = (0..num_nodes as u32).collect();
let mut rng = rand::thread_rng();
while iterations < config.max_iterations {
let mut changes = 0;
node_indices.shuffle(&mut rng);
for &node_idx in &node_indices {
let out_neighbors = graph.out_neighbors(node_idx);
let mut label_counts: HashMap<u64, usize> = HashMap::new();
for &neighbor_idx in out_neighbors {
let neighbor_label = labels[neighbor_idx as usize];
*label_counts.entry(neighbor_label).or_insert(0) += 1;
}
if graph.has_reverse() {
let in_neighbors = graph.in_neighbors(node_idx);
for &neighbor_idx in in_neighbors {
let neighbor_label = labels[neighbor_idx as usize];
*label_counts.entry(neighbor_label).or_insert(0) += 1;
}
}
if label_counts.is_empty() {
continue;
}
let mut max_count = 0;
for &count in label_counts.values() {
if count > max_count {
max_count = count;
}
}
let best_labels: Vec<u64> = label_counts
.iter()
.filter(|(_, count)| **count == max_count)
.map(|(label, _)| *label)
.collect();
let new_label = if best_labels.len() == 1 {
best_labels[0]
} else {
*best_labels.choose(&mut rng).unwrap()
};
if labels[node_idx as usize] != new_label {
labels[node_idx as usize] = new_label;
changes += 1;
}
}
iterations += 1;
if changes == 0 {
converged = true;
break;
}
}
let communities = labels
.into_iter()
.enumerate()
.map(|(slot, label)| (graph.to_vid(slot as u32), label))
.collect();
LabelPropagationResult {
communities,
iterations,
converged,
}
}
}