use crate::core::error::{GraphinaError, Result};
use crate::core::types::{BaseGraph, GraphConstructor, NodeId};
use rand::prelude::*;
use rand::{SeedableRng, rngs::StdRng};
use std::collections::HashMap;
/// Builds the random number generator that drives the node visiting order.
///
/// When `seed` is `Some`, the generator is seeded from it so the sequence is reproducible.
/// When it is `None`, a seed is drawn from `rand::random`.
fn create_rng(seed: Option<u64>) -> StdRng {
    let effective_seed = seed.unwrap_or_else(rand::random::<u64>);
    StdRng::seed_from_u64(effective_seed)
}
/// Detects communities using the local-moving phase of the Louvain method.
///
/// Every node starts in its own community. On each pass, nodes are visited in a random order
/// (drawn from an RNG seeded by `seed`). Each node moves to the neighbouring community with the
/// highest modularity-gain score, but only if that score beats the score of staying where it is.
/// Passes repeat until one makes no move, or until 100 passes have run. Communities are not
/// collapsed into super-nodes afterwards; only this single local-moving level is computed.
///
/// Each edge contributes its weight to the degree of both endpoints and to both adjacency
/// lists, i.e. edges are treated as undirected. Weights are presumably non-negative; this is
/// not checked here.
///
/// With `Some(seed)` the result is reproducible: ties between equally good target communities
/// are broken by the smallest community id, not by hash-map iteration order.
///
/// # Returns
/// A partition of all nodes. Communities are ordered by the first appearance of their members
/// in `graph.nodes()`. A single-node graph yields one community. A graph whose edge weights sum
/// to zero yields one singleton community per node.
///
/// # Errors
/// Returns the error built by `GraphinaError::invalid_graph` when the graph has no nodes.
pub fn louvain<A, Ty>(graph: &BaseGraph<A, f64, Ty>, seed: Option<u64>) -> Result<Vec<Vec<NodeId>>>
where
    Ty: GraphConstructor<A, f64>,
{
    // Upper bound on local-moving passes.
    const MAX_PASSES: usize = 100;
    // A move must beat staying by more than this, which absorbs floating-point noise.
    const MIN_GAIN: f64 = 1e-10;

    let n = graph.node_count();
    if n == 0 {
        return Err(GraphinaError::invalid_graph("Louvain: empty graph"));
    }
    let node_list: Vec<NodeId> = graph.nodes().map(|(nid, _)| nid).collect();
    if n == 1 {
        return Ok(vec![node_list]);
    }

    // m: total edge weight. The degrees computed below sum to 2m.
    let m: f64 = graph.edges().map(|(_, _, &w)| w).sum();
    if m == 0.0 {
        return Ok(node_list.into_iter().map(|nid| vec![nid]).collect());
    }
    let two_m = 2.0 * m;

    let node_to_idx: HashMap<NodeId, usize> = node_list
        .iter()
        .enumerate()
        .map(|(idx, &nid)| (nid, idx))
        .collect();

    // Weighted degree k_i and the adjacency lists, built in one pass over the edges.
    let mut degrees = vec![0.0f64; n];
    let mut neighbors: Vec<Vec<(usize, f64)>> = vec![Vec::new(); n];
    for (u, v, &w) in graph.edges() {
        let (ui, vi) = (node_to_idx[&u], node_to_idx[&v]);
        degrees[ui] += w;
        degrees[vi] += w;
        neighbors[ui].push((vi, w));
        neighbors[vi].push((ui, w));
    }

    let mut community: Vec<usize> = (0..n).collect();
    // Sigma_tot per community. Every node starts alone, so it equals the node's own degree.
    let mut community_degree = degrees.clone();
    let mut rng = create_rng(seed);
    // Reused scratch map: community id -> total edge weight from the current node into it.
    let mut link_weights: HashMap<usize, f64> = HashMap::new();

    for _ in 0..MAX_PASSES {
        let mut moved = false;
        let mut order: Vec<usize> = (0..n).collect();
        order.shuffle(&mut rng);

        for &i in &order {
            let k_i = degrees[i];
            if k_i == 0.0 {
                continue;
            }
            let own_comm = community[i];

            // Self-loops are skipped: they stay with i wherever it goes, so they never change
            // the relative gain of a move.
            link_weights.clear();
            for &(j, w) in &neighbors[i] {
                if j != i {
                    *link_weights.entry(community[j]).or_insert(0.0) += w;
                }
            }

            // Score of staying. i's own degree is excluded from Sigma_tot because the
            // comparison treats i as already removed from its community.
            let k_i_own = link_weights.get(&own_comm).copied().unwrap_or(0.0);
            let stay_gain = k_i_own - (community_degree[own_comm] - k_i) * k_i / two_m;

            // Best other community. Exact ties go to the smallest id so that seeded runs do not
            // depend on HashMap iteration order.
            let mut best: Option<(usize, f64)> = None;
            for (&comm, &w_in) in &link_weights {
                if comm == own_comm {
                    continue;
                }
                let gain = w_in - community_degree[comm] * k_i / two_m;
                let better = match best {
                    None => true,
                    Some((best_comm, best_gain)) => {
                        gain > best_gain || (gain == best_gain && comm < best_comm)
                    }
                };
                if better {
                    best = Some((comm, gain));
                }
            }

            if let Some((target, _)) = best.filter(|&(_, gain)| gain > stay_gain + MIN_GAIN) {
                community_degree[own_comm] -= k_i;
                community_degree[target] += k_i;
                community[i] = target;
                moved = true;
            }
        }

        if !moved {
            break;
        }
    }

    // Relabel the surviving community ids densely, in order of first appearance over node_list.
    let mut dense_id: HashMap<usize, usize> = HashMap::new();
    let mut groups: Vec<Vec<NodeId>> = Vec::new();
    for (&comm, &node) in community.iter().zip(&node_list) {
        let next = dense_id.len();
        let idx = *dense_id.entry(comm).or_insert(next);
        if idx == groups.len() {
            groups.push(Vec::new());
        }
        groups[idx].push(node);
    }
    Ok(groups)
}
#[cfg(test)]
mod tests {
    use super::louvain;
    use crate::core::types::{Graph, NodeId};

    /// Asserts that `communities` is exactly the partition `expected`.
    ///
    /// The order of communities, and of nodes within each community, is ignored. Every group
    /// in `expected` must be non-empty, and the groups must be disjoint.
    fn assert_partition(communities: &[Vec<NodeId>], expected: &[&[NodeId]]) {
        assert_eq!(
            communities.len(),
            expected.len(),
            "unexpected community count: {:?}",
            communities
        );
        for group in expected {
            let found = communities
                .iter()
                .find(|c| c.contains(&group[0]))
                .unwrap_or_else(|| panic!("node {:?} missing from {:?}", group[0], communities));
            assert_eq!(
                found.len(),
                group.len(),
                "community {:?} should be {:?}",
                found,
                group
            );
            for node in group.iter() {
                assert!(
                    found.contains(node),
                    "{:?} should share a community with {:?}",
                    node,
                    group[0]
                );
            }
        }
    }

    #[test]
    fn test_louvain_simple() {
        let mut graph = Graph::new();
        let n1 = graph.add_node(1);
        let n2 = graph.add_node(2);
        let n3 = graph.add_node(3);
        let n4 = graph.add_node(4);
        graph.add_edge(n1, n2, 1.0);
        graph.add_edge(n3, n4, 1.0);
        let communities = louvain(&graph, Some(42)).unwrap();
        // Two disconnected edges must come out as exactly the two edge pairs.
        assert_partition(&communities, &[&[n1, n2], &[n3, n4]]);
    }

    #[test]
    fn test_louvain_empty_graph() {
        let graph: Graph<i32, f64> = Graph::new();
        let communities = louvain(&graph, Some(42)).unwrap_err();
        assert!(matches!(
            communities,
            crate::core::error::GraphinaError::InvalidGraph { .. }
        ));
    }

    #[test]
    fn test_louvain_single_node() {
        let mut graph = Graph::new();
        let n1 = graph.add_node(1);
        let communities = louvain(&graph, Some(42)).unwrap();
        assert_eq!(communities.len(), 1);
        assert_eq!(communities[0].len(), 1);
        assert_eq!(communities[0][0], n1);
    }

    #[test]
    fn test_louvain_no_edges() {
        let mut graph = Graph::new();
        let n1 = graph.add_node(1);
        let n2 = graph.add_node(2);
        let n3 = graph.add_node(3);
        let communities = louvain(&graph, Some(42)).unwrap();
        // With no edges, every node is its own community.
        assert_partition(&communities, &[&[n1], &[n2], &[n3]]);
    }

    #[test]
    fn test_louvain_with_removed_nodes() {
        let mut graph = Graph::new();
        let n0 = graph.add_node(0);
        let n1 = graph.add_node(1);
        let n2 = graph.add_node(2);
        let n3 = graph.add_node(3);
        let n4 = graph.add_node(4);
        graph.add_edge(n0, n1, 1.0);
        graph.add_edge(n1, n2, 1.0);
        graph.add_edge(n3, n4, 1.0);
        graph.remove_node(n2);
        let communities = louvain(&graph, Some(42)).unwrap();
        // The remaining structure is two disjoint edges, so expect two pairs.
        assert_eq!(communities.len(), 2, "communities: {:?}", communities);
        assert!(communities.iter().all(|c| c.len() == 2));
        let total_nodes: usize = communities.iter().map(|c| c.len()).sum();
        assert_eq!(total_nodes, graph.node_count());
        // Every surviving node appears in exactly one community.
        for (nid, _) in graph.nodes() {
            let hits = communities.iter().filter(|c| c.contains(&nid)).count();
            assert_eq!(hits, 1, "node {:?} listed {} times", nid, hits);
        }
    }

    #[test]
    fn test_louvain_performance_smoke() {
        let mut g = Graph::<u32, f64>::new();
        let n = 200;
        let nodes: Vec<_> = (0..n).map(|i| g.add_node(i as u32)).collect();
        for i in 0..n {
            for j in (i + 1)..n {
                if (j - i) <= 3 {
                    g.add_edge(nodes[i], nodes[j], 1.0);
                }
            }
        }
        let start = std::time::Instant::now();
        let comms = louvain(&g, Some(123)).unwrap();
        let dur = start.elapsed();
        assert!(!comms.is_empty());
        let total_nodes: usize = comms.iter().map(|c| c.len()).sum();
        assert_eq!(total_nodes, n);
        assert!(dur.as_secs_f32() < 1.5, "Louvain took too long: {:?}", dur);
    }
}