use std::collections::HashMap;
use super::params::AlgoParams;
use super::progress::ProgressReporter;
use super::result::AlgoResultBatch;
use crate::engine::graph::algo::GraphAlgorithm;
use crate::engine::graph::csr::CsrIndex;
pub fn run(csr: &CsrIndex, params: &AlgoParams) -> AlgoResultBatch {
let n = csr.node_count();
if n == 0 {
return AlgoResultBatch::new(GraphAlgorithm::Louvain);
}
let max_iter = params.iterations(10);
let resolution = params.louvain_resolution();
let mut reporter = ProgressReporter::new(GraphAlgorithm::Louvain, max_iter, None, n);
let (adj, node_degree, total_weight) = build_undirected_adjacency(csr, n);
let mut community: Vec<usize> = (0..n).collect();
let mut comm_total_degree: Vec<f64> = node_degree.clone();
for iter in 1..=max_iter {
let mut moved = 0usize;
for node in 0..n {
let current_comm = community[node];
let node_deg = node_degree[node];
let mut comm_weights: HashMap<usize, f64> = HashMap::new();
for &(neighbor, weight) in &adj[node] {
let nc = community[neighbor];
*comm_weights.entry(nc).or_insert(0.0) += weight;
}
let w_current = comm_weights.get(¤t_comm).copied().unwrap_or(0.0);
let mut best_comm = current_comm;
let mut best_gain = 0.0f64;
for (&candidate_comm, &w_candidate) in &comm_weights {
if candidate_comm == current_comm {
continue;
}
let net_gain = (w_candidate - w_current)
- resolution
* node_deg
* (comm_total_degree[candidate_comm] - comm_total_degree[current_comm]
+ node_deg)
/ (2.0 * total_weight);
if net_gain > best_gain {
best_gain = net_gain;
best_comm = candidate_comm;
}
}
if best_comm != current_comm {
comm_total_degree[current_comm] -= node_deg;
comm_total_degree[best_comm] += node_deg;
community[node] = best_comm;
moved += 1;
}
}
reporter.report_iteration(iter, Some(moved as f64));
if moved == 0 {
break;
}
}
reporter.finish();
let modularity = compute_modularity(&adj, &community, &node_degree, total_weight, resolution);
let mut comm_map: HashMap<usize, i64> = HashMap::new();
let mut next_id = 0i64;
for &c in &community {
comm_map.entry(c).or_insert_with(|| {
let id = next_id;
next_id += 1;
id
});
}
let mut batch = AlgoResultBatch::new(GraphAlgorithm::Louvain);
for node in 0..n {
let comm_id = comm_map[&community[node]];
batch.push_louvain(csr.node_name(node as u32).to_string(), comm_id, modularity);
}
batch
}
type UndirectedAdjacency = (Vec<Vec<(usize, f64)>>, Vec<f64>, f64);
fn build_undirected_adjacency(csr: &CsrIndex, n: usize) -> UndirectedAdjacency {
let mut adj: Vec<Vec<(usize, f64)>> = vec![Vec::new(); n];
let mut node_degree = vec![0.0f64; n];
let mut total_weight = 0.0f64;
let mut seen_edges: std::collections::HashSet<(usize, usize)> =
std::collections::HashSet::new();
for u in 0..n {
for (_lid, v, w) in csr.iter_out_edges_weighted(u as u32) {
let vi = v as usize;
if vi == u {
continue; }
adj[u].push((vi, w));
node_degree[u] += w;
let edge_key = if u < vi { (u, vi) } else { (vi, u) };
if seen_edges.insert(edge_key) {
total_weight += w;
}
}
for (_lid, v) in csr.iter_in_edges(u as u32) {
let vi = v as usize;
if vi == u {
continue;
}
if !adj[u].iter().any(|&(n, _)| n == vi) {
let w = 1.0; adj[u].push((vi, w));
node_degree[u] += w;
let edge_key = if u < vi { (u, vi) } else { (vi, u) };
if seen_edges.insert(edge_key) {
total_weight += w;
}
}
}
}
if total_weight == 0.0 {
total_weight = 1.0; }
(adj, node_degree, total_weight)
}
fn compute_modularity(
adj: &[Vec<(usize, f64)>],
community: &[usize],
node_degree: &[f64],
total_weight: f64,
resolution: f64,
) -> f64 {
let m2 = 2.0 * total_weight;
let mut q = 0.0f64;
for (u, neighbors) in adj.iter().enumerate() {
for &(v, w) in neighbors {
if community[u] == community[v] {
q += w - resolution * node_degree[u] * node_degree[v] / m2;
}
}
}
q / m2
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn louvain_two_cliques() {
let mut csr = CsrIndex::new();
for (s, d) in &[
("a", "b"),
("b", "a"),
("a", "c"),
("c", "a"),
("b", "c"),
("c", "b"),
("d", "e"),
("e", "d"),
("d", "f"),
("f", "d"),
("e", "f"),
("f", "e"),
("c", "d"),
("d", "c"),
] {
csr.add_edge(s, "L", d);
}
csr.compact();
let batch = run(&csr, &AlgoParams::default());
assert_eq!(batch.len(), 6);
let json = batch.to_json().unwrap();
let rows: Vec<serde_json::Value> = serde_json::from_slice(&json).unwrap();
let map: HashMap<&str, i64> = rows
.iter()
.map(|r| {
(
r["node_id"].as_str().unwrap(),
r["community_id"].as_i64().unwrap(),
)
})
.collect();
assert_eq!(map["a"], map["b"]);
assert_eq!(map["a"], map["c"]);
assert_eq!(map["d"], map["e"]);
assert_eq!(map["d"], map["f"]);
}
#[test]
fn louvain_positive_modularity() {
let mut csr = CsrIndex::new();
for (s, d) in &[
("a", "b"),
("b", "a"),
("a", "c"),
("c", "a"),
("b", "c"),
("c", "b"),
] {
csr.add_edge(s, "L", d);
}
csr.compact();
let batch = run(&csr, &AlgoParams::default());
let json = batch.to_json().unwrap();
let rows: Vec<serde_json::Value> = serde_json::from_slice(&json).unwrap();
let modularity = rows[0]["modularity"].as_f64().unwrap();
assert!(modularity >= -0.5, "modularity {modularity} too low");
}
#[test]
fn louvain_empty() {
let csr = CsrIndex::new();
assert!(run(&csr, &AlgoParams::default()).is_empty());
}
#[test]
fn louvain_single_node() {
let mut csr = CsrIndex::new();
csr.add_node("solo");
csr.compact();
let batch = run(&csr, &AlgoParams::default());
assert_eq!(batch.len(), 1);
}
}