use std::collections::{HashMap, HashSet};
use super::super::graph_store::GraphStore;
/// Configuration for label-propagation community detection.
///
/// Build one with [`LabelPropagation::new`] (or `Default`) and adjust it with
/// the builder-style setters before calling `run`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LabelPropagation {
    /// Upper bound on the number of full sweeps over the nodes that `run`
    /// performs before giving up on convergence.
    pub max_iterations: usize,
}

impl Default for LabelPropagation {
    /// Returns a configuration capped at 100 sweeps.
    fn default() -> Self {
        Self {
            max_iterations: 100,
        }
    }
}
/// A group of nodes that share one label.
#[derive(Debug, Clone)]
pub struct Community {
    /// The label shared by every member of the group.
    pub label: String,
    /// Identifiers of the member nodes.
    pub nodes: Vec<String>,
    /// Number of members; `LabelPropagation::run` sets this to `nodes.len()`.
    pub size: usize,
}

/// Outcome of a label-propagation run.
#[derive(Debug, Clone)]
pub struct CommunitiesResult {
    /// Detected communities; `LabelPropagation::run` stores them largest first.
    pub communities: Vec<Community>,
    /// Number of sweeps that were executed.
    pub iterations: usize,
    /// `true` when a sweep finished without changing any label.
    pub converged: bool,
}

impl CommunitiesResult {
    /// Returns the first community, or `None` when there are none.
    ///
    /// This is the largest community only as long as `communities` is kept
    /// sorted by descending size, which is how `LabelPropagation::run`
    /// produces it.
    pub fn largest(&self) -> Option<&Community> {
        self.communities.first()
    }

    /// Returns the first community whose member list contains `node_id`,
    /// or `None` when no community lists that node.
    pub fn community_of(&self, node_id: &str) -> Option<&Community> {
        // Compare by borrow: no temporary `String` is allocated per lookup.
        self.communities
            .iter()
            .find(|community| community.nodes.iter().any(|member| member == node_id))
    }
}
impl LabelPropagation {
    /// Creates a detector with the default settings.
    pub fn new() -> Self {
        Self::default()
    }

    /// Builder-style setter for the maximum number of sweeps.
    pub fn max_iterations(mut self, max: usize) -> Self {
        self.max_iterations = max;
        self
    }

    /// Detects communities by label propagation.
    ///
    /// Every node starts with its own id as label. In each sweep every node
    /// adopts the label that occurs most often among the endpoints of its
    /// outgoing and incoming edges (edge direction is ignored; neighbours
    /// without a label entry are skipped). Labels are updated in place, so
    /// later nodes in a sweep already see earlier changes.
    ///
    /// Ties are resolved without relying on hash iteration order: a node
    /// keeps its current label when that label is among the most frequent
    /// ones, otherwise the lexicographically smallest candidate wins. Keeping
    /// the current label on ties prevents nodes from flip-flopping between
    /// equally good labels, which would keep a sweep reporting changes.
    ///
    /// # Returns
    ///
    /// Communities sorted by descending size (ties by label), each with its
    /// members sorted, together with the number of sweeps executed and
    /// whether a sweep completed without any change. An empty graph yields no
    /// communities, zero iterations and `converged == true`.
    pub fn run(&self, graph: &GraphStore) -> CommunitiesResult {
        let nodes: Vec<String> = graph.iter_nodes().map(|n| n.id.clone()).collect();
        if nodes.is_empty() {
            return CommunitiesResult {
                communities: Vec::new(),
                iterations: 0,
                converged: true,
            };
        }

        let mut labels: HashMap<String, String> =
            nodes.iter().map(|id| (id.clone(), id.clone())).collect();
        let mut converged = false;
        let mut iterations = 0;

        for iter in 0..self.max_iterations {
            iterations = iter + 1;
            let mut changed = false;

            for node_id in &nodes {
                // Keys borrow from `labels`, so counting allocates no strings.
                let mut label_counts: HashMap<&str, usize> = HashMap::new();
                for (_, neighbor, _) in graph.outgoing_edges(node_id) {
                    if let Some(label) = labels.get(&neighbor) {
                        *label_counts.entry(label.as_str()).or_insert(0) += 1;
                    }
                }
                for (_, neighbor, _) in graph.incoming_edges(node_id) {
                    if let Some(label) = labels.get(&neighbor) {
                        *label_counts.entry(label.as_str()).or_insert(0) += 1;
                    }
                }

                let current = labels.get(node_id).map(String::as_str).unwrap_or_default();
                // Convert to an owned label so the borrow of `labels` ends
                // before the map is updated.
                let next = Self::dominant_label(label_counts, current)
                    .filter(|best| *best != current)
                    .map(str::to_owned);

                if let Some(next) = next {
                    if let Some(slot) = labels.get_mut(node_id) {
                        *slot = next;
                        changed = true;
                    }
                }
            }

            if !changed {
                converged = true;
                break;
            }
        }

        let mut groups: HashMap<String, Vec<String>> = HashMap::new();
        for (node_id, label) in labels {
            groups.entry(label).or_default().push(node_id);
        }

        let mut communities: Vec<Community> = groups
            .into_iter()
            .map(|(label, mut nodes)| {
                // Hash-map order is arbitrary; sort for reproducible output.
                nodes.sort_unstable();
                let size = nodes.len();
                Community { label, nodes, size }
            })
            .collect();
        communities.sort_by(|a, b| b.size.cmp(&a.size).then_with(|| a.label.cmp(&b.label)));

        CommunitiesResult {
            communities,
            iterations,
            converged,
        }
    }

    /// Picks the label a node should adopt given its neighbours' label counts.
    ///
    /// Returns `None` when there are no counts. Prefers `current` when it is
    /// among the most frequent labels; otherwise returns the smallest of the
    /// most frequent labels.
    fn dominant_label<'a>(counts: HashMap<&'a str, usize>, current: &'a str) -> Option<&'a str> {
        let top = counts.values().copied().max()?;
        if counts.get(current) == Some(&top) {
            return Some(current);
        }
        counts
            .into_iter()
            .filter(|&(_, count)| count == top)
            .map(|(label, _)| label)
            .min()
    }
}
/// Configuration for Louvain-style modularity optimisation.
#[derive(Debug, Clone, PartialEq)]
pub struct Louvain {
    /// Factor applied to the expected-weight (null model) term of the
    /// modularity gain and of the reported modularity.
    pub resolution: f64,
    /// Upper bound on the number of passes over all nodes.
    pub max_iterations: usize,
    /// Margin by which a candidate move's gain must exceed the best gain
    /// found so far for the move to be preferred.
    pub min_improvement: f64,
}

impl Default for Louvain {
    /// Returns resolution `1.0`, at most 10 passes and a `1e-6` gain margin.
    fn default() -> Self {
        Self {
            resolution: 1.0,
            max_iterations: 10,
            min_improvement: 1e-6,
        }
    }
}
/// Outcome of a Louvain run.
#[derive(Debug, Clone)]
pub struct LouvainResult {
    /// Community id assigned to each node id.
    pub communities: HashMap<String, usize>,
    /// Number of communities.
    pub count: usize,
    /// Modularity score reported for the assignment.
    pub modularity: f64,
    /// Number of passes that were executed.
    pub passes: usize,
}

impl LouvainResult {
    /// Collects the ids of all nodes assigned to `community_id`.
    ///
    /// The order of the returned ids follows the map's iteration order and is
    /// therefore unspecified. An unknown id yields an empty vector.
    pub fn get_community(&self, community_id: usize) -> Vec<String> {
        let mut members = Vec::new();
        for (node, &assigned) in &self.communities {
            if assigned == community_id {
                members.push(node.clone());
            }
        }
        members
    }

    /// Counts how many nodes belong to each community id.
    pub fn community_sizes(&self) -> HashMap<usize, usize> {
        self.communities
            .values()
            .fold(HashMap::new(), |mut tally, &id| {
                *tally.entry(id).or_default() += 1;
                tally
            })
    }
}
impl Louvain {
    /// Creates a detector with the default settings.
    pub fn new() -> Self {
        Self::default()
    }

    /// Builder-style setter for the resolution factor.
    pub fn resolution(mut self, resolution: f64) -> Self {
        self.resolution = resolution;
        self
    }

    /// Builder-style setter for the maximum number of passes.
    pub fn max_iterations(mut self, max: usize) -> Self {
        self.max_iterations = max;
        self
    }

    /// Detects communities by greedy modularity optimisation.
    ///
    /// The graph is treated as undirected: every outgoing edge adds weight
    /// `1.0` to the unordered pair of its endpoints, self-loops are ignored,
    /// and edges whose target is not one of the graph's nodes are skipped.
    /// Starting from one community per node, each pass visits every node and
    /// moves it to the neighbouring community with the best positive
    /// modularity gain. Passes repeat until one makes no move or
    /// `max_iterations` is reached. Only this local-moving phase is
    /// performed; communities are not merged into a coarser graph.
    ///
    /// Candidate communities are evaluated in ascending id order so that ties
    /// do not depend on hash iteration order.
    ///
    /// # Returns
    ///
    /// Community ids renumbered to `0..count`, the modularity of the final
    /// assignment and the number of passes executed. An empty graph yields an
    /// empty result; a graph without usable edges leaves every node in its
    /// own community with modularity `0.0` and zero passes.
    pub fn run(&self, graph: &GraphStore) -> LouvainResult {
        let nodes: Vec<String> = graph.iter_nodes().map(|n| n.id.clone()).collect();
        if nodes.is_empty() {
            return LouvainResult {
                communities: HashMap::new(),
                count: 0,
                modularity: 0.0,
                passes: 0,
            };
        }

        // Every node starts alone, numbered by its position in `nodes`.
        let mut communities: HashMap<String, usize> = nodes
            .iter()
            .enumerate()
            .map(|(i, n)| (n.clone(), i))
            .collect();

        // Undirected edge weights keyed by the lexicographically ordered pair.
        let mut weights: HashMap<(String, String), f64> = HashMap::new();
        for node in &nodes {
            for (_, target, _) in graph.outgoing_edges(node) {
                if node == &target || !communities.contains_key(&target) {
                    continue;
                }
                let key = if node < &target {
                    (node.clone(), target)
                } else {
                    (target, node.clone())
                };
                *weights.entry(key).or_insert(0.0) += 1.0;
            }
        }

        let mut node_strength: HashMap<String, f64> = HashMap::new();
        let mut total_weight = 0.0;
        for ((a, b), w) in &weights {
            *node_strength.entry(a.clone()).or_insert(0.0) += w;
            *node_strength.entry(b.clone()).or_insert(0.0) += w;
            total_weight += w;
        }

        if total_weight == 0.0 {
            return LouvainResult {
                count: nodes.len(),
                communities,
                modularity: 0.0,
                passes: 0,
            };
        }

        // Adjacency list built once, so a pass costs O(edges) instead of
        // scanning the whole edge map for every node.
        let mut adjacency: HashMap<&str, Vec<(&str, f64)>> = HashMap::new();
        for ((a, b), w) in &weights {
            adjacency.entry(a.as_str()).or_default().push((b.as_str(), *w));
            adjacency.entry(b.as_str()).or_default().push((a.as_str(), *w));
        }

        // Sum of member strengths per community.
        let mut comm_total: HashMap<usize, f64> = nodes
            .iter()
            .enumerate()
            .map(|(i, n)| (i, node_strength.get(n).copied().unwrap_or(0.0)))
            .collect();

        let mut passes = 0;
        let mut improved = true;
        while improved && passes < self.max_iterations {
            improved = false;
            passes += 1;

            for node in &nodes {
                // A node without edges has no community to move to.
                let neighbors = match adjacency.get(node.as_str()) {
                    Some(list) => list,
                    None => continue,
                };
                let current_comm = match communities.get(node) {
                    Some(&comm) => comm,
                    None => continue,
                };
                let node_w = node_strength.get(node).copied().unwrap_or(0.0);

                // Edge weight from this node into each adjacent community.
                let mut neighbor_comm_weights: HashMap<usize, f64> = HashMap::new();
                for &(neighbor, w) in neighbors {
                    if let Some(&comm) = communities.get(neighbor) {
                        *neighbor_comm_weights.entry(comm).or_insert(0.0) += w;
                    }
                }

                let current_internal = neighbor_comm_weights
                    .get(&current_comm)
                    .copied()
                    .unwrap_or(0.0);
                let current_total = comm_total.get(&current_comm).copied().unwrap_or(0.0);

                let mut candidates: Vec<(usize, f64)> = neighbor_comm_weights
                    .into_iter()
                    .filter(|&(comm, _)| comm != current_comm)
                    .collect();
                candidates.sort_unstable_by_key(|&(comm, _)| comm);

                let mut best_comm = current_comm;
                let mut best_delta = 0.0;
                for (target_comm, weight_to_target) in candidates {
                    let target_total = comm_total.get(&target_comm).copied().unwrap_or(0.0);
                    // Modularity change of moving the node from its current
                    // community into `target_comm`.
                    let delta = (weight_to_target - current_internal) / total_weight
                        - self.resolution * node_w * (target_total - current_total + node_w)
                            / (2.0 * total_weight * total_weight);
                    if delta > best_delta + self.min_improvement {
                        best_delta = delta;
                        best_comm = target_comm;
                    }
                }

                if best_comm != current_comm {
                    improved = true;
                    *comm_total.entry(current_comm).or_insert(0.0) -= node_w;
                    *comm_total.entry(best_comm).or_insert(0.0) += node_w;
                    if let Some(slot) = communities.get_mut(node) {
                        *slot = best_comm;
                    }
                }
            }
        }

        // Renumber the surviving community ids to a dense 0..count range.
        let unique_communities: Vec<usize> = {
            let distinct: HashSet<usize> = communities.values().copied().collect();
            let mut ids: Vec<usize> = distinct.into_iter().collect();
            ids.sort_unstable();
            ids
        };
        let comm_map: HashMap<usize, usize> = unique_communities
            .iter()
            .enumerate()
            .map(|(new, &old)| (old, new))
            .collect();
        let remapped: HashMap<String, usize> = communities
            .into_iter()
            .map(|(n, c)| (n, comm_map.get(&c).copied().unwrap_or(0)))
            .collect();

        let modularity =
            self.calculate_modularity(&remapped, &weights, &node_strength, total_weight);

        LouvainResult {
            count: unique_communities.len(),
            communities: remapped,
            modularity,
            passes,
        }
    }

    /// Computes the modularity of `communities`.
    ///
    /// With `m = total_weight`, `L_c` the edge weight inside community `c`
    /// and `d_c` the summed strength of its members, the result is
    /// `sum over c of (L_c / m - resolution * (d_c / (2 m))^2)`. This is the
    /// quantity whose change the gain formula in `run` measures. The
    /// expected-weight term covers every pair of members of a community, not
    /// only pairs joined by an edge.
    ///
    /// Returns `0.0` when `total_weight` is zero. Edges or nodes without a
    /// community entry are ignored.
    fn calculate_modularity(
        &self,
        communities: &HashMap<String, usize>,
        weights: &HashMap<(String, String), f64>,
        node_strength: &HashMap<String, f64>,
        total_weight: f64,
    ) -> f64 {
        if total_weight == 0.0 {
            return 0.0;
        }

        // Per community: (internal edge weight, summed member strength).
        let mut stats: HashMap<usize, (f64, f64)> = HashMap::new();
        for ((a, b), w) in weights {
            if let (Some(ca), Some(cb)) = (communities.get(a), communities.get(b)) {
                if ca == cb {
                    stats.entry(*ca).or_insert((0.0, 0.0)).0 += w;
                }
            }
        }
        for (node, strength) in node_strength {
            if let Some(comm) = communities.get(node) {
                stats.entry(*comm).or_insert((0.0, 0.0)).1 += strength;
            }
        }

        // Sum in id order so floating-point rounding is reproducible.
        let mut per_community: Vec<(usize, (f64, f64))> = stats.into_iter().collect();
        per_community.sort_unstable_by_key(|&(id, _)| id);

        let two_m = 2.0 * total_weight;
        per_community
            .iter()
            .map(|&(_, (internal, degree))| {
                internal / total_weight - self.resolution * (degree / two_m).powi(2)
            })
            .sum()
    }
}