use std::collections::HashMap;
use crate::graph::Graph;
use crate::types::{ulid_encode, DbError, Value};
use super::{opt_f64, opt_str, opt_usize, GraphSnapshot, Row};
/// Weighted adjacency list indexed by node position: `adj[i]` holds the
/// `(neighbour index, accumulated edge weight)` pairs of node `i`.
type Adj = Vec<Vec<(usize, f64)>>;
/// Builds a symmetric, weighted adjacency list over the snapshot's node indices.
///
/// Every edge yielded by `graph.all_edges()` is treated as undirected. Parallel edges
/// between the same pair of nodes are merged by summing their weights. The following
/// edges are skipped:
/// * self-loops;
/// * edges with an endpoint missing from `snap.id_to_idx`;
/// * edges whose weight is not a positive, finite number.
///
/// `weight_prop` names the edge property that holds the weight (`Float` or `Int`). When
/// it is empty, or the property is missing or non-numeric, the weight defaults to `1.0`.
///
/// Returns `(adj, degree, m)`:
/// * `adj[i]` lists `(neighbour, weight)` pairs sorted by neighbour index;
/// * `degree[i]` is the total weight incident to node `i`;
/// * `m` is half the sum of all degrees, i.e. the total edge weight.
fn build_undirected_adj(
    graph: &Graph,
    snap: &GraphSnapshot,
    weight_prop: &str,
) -> (Adj, Vec<f64>, f64) {
    let n = snap.n;
    let mut adj_map: Vec<HashMap<usize, f64>> = vec![HashMap::new(); n];
    for edge in graph.all_edges() {
        let (Some(&fi), Some(&ti)) = (
            snap.id_to_idx.get(&edge.from_node),
            snap.id_to_idx.get(&edge.to_node),
        ) else {
            continue;
        };
        if fi == ti {
            continue;
        }
        let w: f64 = if weight_prop.is_empty() {
            1.0
        } else {
            edge.properties
                .get(weight_prop)
                .and_then(|v| match v {
                    Value::Float(f) => Some(*f),
                    Value::Int(i) => Some(*i as f64),
                    _ => None,
                })
                .unwrap_or(1.0)
        };
        // `w <= 0.0` is false for NaN, so check finiteness explicitly: a NaN or infinite
        // weight would turn every modularity gain computed later into NaN/inf.
        if !w.is_finite() || w <= 0.0 {
            continue;
        }
        *adj_map[fi].entry(ti).or_insert(0.0) += w;
        *adj_map[ti].entry(fi).or_insert(0.0) += w;
    }
    let adj: Adj = adj_map
        .into_iter()
        .map(|nbrs| {
            let mut list: Vec<(usize, f64)> = nbrs.into_iter().collect();
            // HashMap iteration order differs between runs; sort so tie-breaking and
            // floating-point summation downstream are reproducible.
            list.sort_unstable_by_key(|&(j, _)| j);
            list
        })
        .collect();
    let degree: Vec<f64> = adj
        .iter()
        .map(|nbrs| nbrs.iter().map(|&(_, w)| w).sum())
        .collect();
    let m: f64 = degree.iter().sum::<f64>() / 2.0;
    (adj, degree, m)
}
/// Relabels arbitrary community ids to the dense range `0..k`, assigning new ids in
/// order of first appearance.
///
/// Returns the relabelled assignment (same length as `community`) and `k`, the number
/// of distinct ids seen.
fn renumber(community: &[usize]) -> (Vec<usize>, usize) {
    let mut dense_id: HashMap<usize, usize> = HashMap::new();
    let mut relabelled: Vec<usize> = Vec::with_capacity(community.len());
    for &label in community {
        // The next fresh id is simply the number of labels seen so far.
        let fresh = dense_id.len();
        relabelled.push(*dense_id.entry(label).or_insert(fresh));
    }
    let distinct = dense_id.len();
    (relabelled, distinct)
}
/// Collapses every community into a single super-node.
///
/// `community[i]` must be a dense id in `0..k` (for example the output of `renumber`).
/// The resulting graph has `k` nodes:
/// * The weight between super-nodes `a != b` is the total weight of the edges running
///   between the two communities.
/// * Intra-community edges are not stored as self-loops. Their weight is still counted
///   in `new_degree[c]`, which is the sum of the members' degrees.
/// * `new_m` is half the total degree, so it is unchanged by aggregation.
///
/// Neighbour lists are sorted by index so later passes are reproducible.
fn aggregate(adj: &Adj, degree: &[f64], community: &[usize], k: usize) -> (Adj, Vec<f64>, f64) {
    let mut new_adj_map: Vec<HashMap<usize, f64>> = vec![HashMap::new(); k];
    let mut new_degree = vec![0.0f64; k];
    for (i, nbrs) in adj.iter().enumerate() {
        let ci = community[i];
        new_degree[ci] += degree[i];
        for &(j, w) in nbrs {
            let cj = community[j];
            if ci != cj {
                *new_adj_map[ci].entry(cj).or_insert(0.0) += w;
            }
        }
    }
    let new_adj: Adj = new_adj_map
        .into_iter()
        .map(|nbrs| {
            let mut list: Vec<(usize, f64)> = nbrs.into_iter().collect();
            // HashMap iteration order is unspecified; sort for deterministic results.
            list.sort_unstable_by_key(|&(j, _)| j);
            list
        })
        .collect();
    let new_m = new_degree.iter().sum::<f64>() / 2.0;
    (new_adj, new_degree, new_m)
}
/// Louvain local-moving phase: greedily moves single nodes to the neighbouring
/// community with the best modularity gain.
///
/// # Algorithm
/// Each pass visits the nodes in index order. Node `i` is first taken out of its
/// community, then scored against each neighbouring community `C` as
/// `k_i,C - resolution * k_i * Σ_tot(C) / 2m`, i.e. the modularity gain scaled by `m`.
/// The node stays put unless some other community scores strictly higher. Ties among
/// candidates go to the smallest community id, so results are reproducible. Passes
/// repeat until one moves nothing or `max_passes` is reached.
///
/// # Arguments
/// * `community` — the starting assignment, updated in place. Every id in it must be
///   `< adj.len()`.
/// * `degree` — `degree[i]` is the weighted degree of node `i`. It may include weight
///   not present in `adj`, such as collapsed intra-community edges.
/// * `m` — the total edge weight.
///
/// # Returns
/// `true` if at least one node changed community. With `m == 0` (no edges) nothing can
/// move, and `false` is returned immediately.
fn local_move(
    adj: &Adj,
    community: &mut [usize],
    degree: &[f64],
    m: f64,
    resolution: f64,
    max_passes: usize,
) -> bool {
    if m == 0.0 {
        return false;
    }
    let n = adj.len();
    let two_m = 2.0 * m;
    // Σ_tot per community, derived from the actual starting assignment rather than
    // assuming every node starts alone.
    let mut comm_total = vec![0.0f64; n];
    for (&c, &k) in community.iter().zip(degree) {
        comm_total[c] += k;
    }
    // Scratch buffers reused for every node instead of allocating per visit.
    let mut k_to: HashMap<usize, f64> = HashMap::new();
    let mut candidates: Vec<(usize, f64)> = Vec::new();
    let mut any_improved = false;
    for _ in 0..max_passes {
        let mut pass_improved = false;
        for i in 0..n {
            let c_old = community[i];
            let ki = degree[i];
            // Remove i before scoring so "stay" and "move" are compared on equal terms.
            comm_total[c_old] -= ki;
            k_to.clear();
            for &(j, w) in &adj[i] {
                *k_to.entry(community[j]).or_insert(0.0) += w;
            }
            // HashMap iteration order is unspecified; scan candidates by ascending id so
            // ties are always broken the same way.
            candidates.clear();
            candidates.extend(k_to.iter().map(|(&c, &k)| (c, k)));
            candidates.sort_unstable_by_key(|&(c, _)| c);
            let mut best_c = c_old;
            let mut best_gain = k_to.get(&c_old).copied().unwrap_or(0.0)
                - resolution * ki * comm_total[c_old] / two_m;
            for &(c, k_ic) in &candidates {
                if c == c_old {
                    continue;
                }
                let gain = k_ic - resolution * ki * comm_total[c] / two_m;
                if gain > best_gain {
                    best_gain = gain;
                    best_c = c;
                }
            }
            community[i] = best_c;
            comm_total[best_c] += ki;
            if best_c != c_old {
                pass_improved = true;
                any_improved = true;
            }
        }
        if !pass_improved {
            break;
        }
    }
    any_improved
}
/// Leiden refinement phase: splits each community of `parent` into well-connected
/// sub-communities.
///
/// # Algorithm
/// Every node starts in its own refined community, and nodes are visited in index
/// order. A node is considered only while it is still alone in its refined community.
/// Once other nodes have joined it, moving it away could leave them disconnected,
/// which is exactly what Leiden refinement must avoid.
///
/// A singleton node `i` may join a refined community `C` inside the same parent
/// community when both of these hold:
/// * its edge weight into `C` is at least `theta * k_i`;
/// * the gain `k_i,C - resolution * k_i * Σ_tot(C) / 2m` is strictly positive.
///
/// Among eligible candidates the largest gain wins, with ties going to the smallest id.
///
/// This is a deterministic, greedy simplification of the randomised merge step in the
/// Leiden algorithm (Traag, Waltman & van Eck, 2019).
///
/// # Returns
/// The refined assignment. Its ids are node indices, not dense ids, so pass it through
/// `renumber` before aggregating.
fn leiden_refine(
    adj: &Adj,
    parent: &[usize],
    degree: &[f64],
    m: f64,
    resolution: f64,
    theta: f64,
) -> Vec<usize> {
    let n = adj.len();
    let mut refined: Vec<usize> = (0..n).collect();
    if m == 0.0 {
        // No edges: nothing can merge, and the gain formula would divide by zero.
        return refined;
    }
    let two_m = 2.0 * m;
    let mut comm_total: Vec<f64> = degree.to_vec();
    let mut comm_size: Vec<usize> = vec![1; n];
    let mut k_to: HashMap<usize, f64> = HashMap::new();
    let mut candidates: Vec<(usize, f64)> = Vec::new();
    for i in 0..n {
        let ci = refined[i];
        if comm_size[ci] > 1 {
            // Others have already merged into i's community; keep it intact.
            continue;
        }
        let pi = parent[i];
        let ki = degree[i];
        k_to.clear();
        for &(j, w) in &adj[i] {
            if parent[j] == pi {
                *k_to.entry(refined[j]).or_insert(0.0) += w;
            }
        }
        // Deterministic candidate order (HashMap iteration order is unspecified).
        candidates.clear();
        candidates.extend(k_to.iter().map(|(&c, &k)| (c, k)));
        candidates.sort_unstable_by_key(|&(c, _)| c);
        // i is alone in `ci`, so once removed Σ_tot(ci) is 0 and staying scores exactly 0.
        let mut best_c = ci;
        let mut best_gain = 0.0f64;
        for &(c, k_ic) in &candidates {
            if c == ci || k_ic < theta * ki {
                continue;
            }
            let gain = k_ic - resolution * ki * comm_total[c] / two_m;
            if gain > best_gain {
                best_gain = gain;
                best_c = c;
            }
        }
        if best_c != ci {
            refined[i] = best_c;
            comm_total[ci] -= ki;
            comm_total[best_c] += ki;
            comm_size[ci] -= 1;
            comm_size[best_c] += 1;
        }
    }
    refined
}
/// Turns a per-node community assignment into result rows.
///
/// Each community is labelled with the ULID of its lowest-indexed member, so every
/// row holds two ULID strings under the keys `node` and `community`. Always returns
/// `Ok`.
fn build_rows(snap: &GraphSnapshot, membership: &[usize]) -> Result<Vec<Row>, DbError> {
    // The first (lowest) node index seen for each community label names that community.
    let mut leader_of: HashMap<usize, usize> = HashMap::new();
    membership.iter().enumerate().for_each(|(idx, &label)| {
        leader_of.entry(label).or_insert(idx);
    });

    let mut rows: Vec<Row> = Vec::with_capacity(snap.n);
    for idx in 0..snap.n {
        let leader = leader_of[&membership[idx]];
        let node_ulid = ulid_encode(snap.node_ids[idx].0);
        let community_ulid = ulid_encode(snap.node_ids[leader].0);
        let mut row = HashMap::new();
        row.insert("node".to_string(), Value::String(node_ulid));
        row.insert("community".to_string(), Value::String(community_ulid));
        rows.push(row);
    }
    Ok(rows)
}
/// Runs Louvain community detection over the whole graph.
///
/// # Parameters
/// All are optional and read from `params`:
/// * `resolution` (default `1.0`, must be positive) — scales the penalty term of the
///   modularity gain; larger values favour smaller communities.
/// * `maxPasses` (default `10`, at least 1) — local-moving passes per level.
/// * `maxLevels` (default `10`, at least 1) — maximum number of aggregation levels.
/// * `weight` (default `""`) — edge property holding the weight; empty means every
///   edge weighs `1.0`.
///
/// # Returns
/// One row per node with `node` and `community` ULID strings. A community is named
/// after the ULID of one of its members. An empty graph yields no rows.
///
/// # Errors
/// * `DbError::Query` for out-of-range parameters, including a NaN `resolution`.
/// * Any error raised while reading the parameters.
pub fn run_louvain(graph: &Graph, params: &HashMap<String, Value>) -> Result<Vec<Row>, DbError> {
    let resolution = opt_f64(params, "resolution", 1.0)?;
    // `NaN <= 0.0` is false, so test NaN explicitly or it would slip through.
    if resolution.is_nan() || resolution <= 0.0 {
        return Err(DbError::Query("'resolution' must be positive".into()));
    }
    let max_passes = opt_usize(params, "maxPasses", 10)?;
    if max_passes == 0 {
        return Err(DbError::Query("'maxPasses' must be at least 1".into()));
    }
    let max_levels = opt_usize(params, "maxLevels", 10)?;
    if max_levels == 0 {
        return Err(DbError::Query("'maxLevels' must be at least 1".into()));
    }
    let weight_prop = opt_str(params, "weight", "")?.to_string();
    let snap = GraphSnapshot::build(graph, None);
    let n = snap.n;
    if n == 0 {
        return Ok(vec![]);
    }
    let (mut adj, mut degree, mut m) = build_undirected_adj(graph, &snap, &weight_prop);
    // membership[v] = index of the current-level (aggregated) node containing original node v.
    let mut membership: Vec<usize> = (0..n).collect();
    let mut cur_n = n;
    for _ in 0..max_levels {
        let mut community: Vec<usize> = (0..cur_n).collect();
        let improved = local_move(&adj, &mut community, &degree, m, resolution, max_passes);
        if !improved {
            break;
        }
        let (renumbered, k) = renumber(&community);
        for idx in membership.iter_mut() {
            *idx = renumbered[*idx];
        }
        if k >= cur_n {
            // Nothing merged at this level; further aggregation would be a no-op.
            break;
        }
        let (new_adj, new_degree, new_m) = aggregate(&adj, &degree, &renumbered, k);
        adj = new_adj;
        degree = new_degree;
        m = new_m;
        cur_n = k;
    }
    build_rows(&snap, &membership)
}
/// Runs Leiden-style community detection over the whole graph.
///
/// # Algorithm
/// Each level does three things:
/// 1. Louvain local moving produces a partition `P`.
/// 2. Every community of `P` is refined into well-connected sub-communities.
/// 3. The graph is aggregated by the refined partition.
///
/// When refinement can no longer shrink the graph, or the last allowed level is
/// reached, the partition `P` of that level is reported. `P` is the partition the
/// optimisation actually chose, and the refined communities are always nested inside
/// it.
///
/// # Parameters
/// All are optional and read from `params`:
/// * `resolution` (default `1.0`, must be positive).
/// * `maxPasses` (default `10`, at least 1) — local-moving passes per level.
/// * `maxLevels` (default `10`, at least 1).
/// * `theta` (default `0.01`, in `[0.0, 1.0]`) — minimum fraction of a node's weight
///   that must point into a refined community before the node may join it.
/// * `weight` (default `""`) — edge property holding the weight; empty means unweighted.
///
/// # Returns
/// One row per node with `node` and `community` ULID strings. An empty graph yields no
/// rows.
///
/// # Errors
/// * `DbError::Query` for out-of-range parameters, including NaN values.
/// * Any error raised while reading the parameters.
pub fn run_leiden(graph: &Graph, params: &HashMap<String, Value>) -> Result<Vec<Row>, DbError> {
    let resolution = opt_f64(params, "resolution", 1.0)?;
    // `NaN <= 0.0` is false, so test NaN explicitly or it would slip through.
    if resolution.is_nan() || resolution <= 0.0 {
        return Err(DbError::Query("'resolution' must be positive".into()));
    }
    let max_passes = opt_usize(params, "maxPasses", 10)?;
    if max_passes == 0 {
        return Err(DbError::Query("'maxPasses' must be at least 1".into()));
    }
    let max_levels = opt_usize(params, "maxLevels", 10)?;
    if max_levels == 0 {
        return Err(DbError::Query("'maxLevels' must be at least 1".into()));
    }
    let theta = opt_f64(params, "theta", 0.01)?;
    if !(0.0..=1.0).contains(&theta) {
        return Err(DbError::Query("'theta' must be in [0.0, 1.0]".into()));
    }
    let weight_prop = opt_str(params, "weight", "")?.to_string();
    let snap = GraphSnapshot::build(graph, None);
    let n = snap.n;
    if n == 0 {
        return Ok(vec![]);
    }
    let (mut adj, mut degree, mut m) = build_undirected_adj(graph, &snap, &weight_prop);
    // membership[v] = index of the current-level (aggregated) node containing original node v.
    let mut membership: Vec<usize> = (0..n).collect();
    let mut cur_n = n;
    for level in 0..max_levels {
        let mut community: Vec<usize> = (0..cur_n).collect();
        // The "improved" flag is not needed: with no moves, refinement yields singletons
        // and the stall check below ends the loop.
        local_move(&adj, &mut community, &degree, m, resolution, max_passes);
        let refined = leiden_refine(&adj, &community, &degree, m, resolution, theta);
        let (refined_ids, k) = renumber(&refined);
        let last_level = level + 1 == max_levels;
        if k >= cur_n || last_level {
            // No further aggregation will happen: report the local-moving partition,
            // which is at least as coarse as the refined one and is what modularity
            // optimisation selected. (Returning `refined` here would discard merges
            // that refinement declined, e.g. every merge when theta is high.)
            let (parent_ids, _) = renumber(&community);
            for idx in membership.iter_mut() {
                *idx = parent_ids[*idx];
            }
            break;
        }
        for idx in membership.iter_mut() {
            *idx = refined_ids[*idx];
        }
        let (new_adj, new_degree, new_m) = aggregate(&adj, &degree, &refined_ids, k);
        adj = new_adj;
        degree = new_degree;
        m = new_m;
        cur_n = k;
    }
    build_rows(&snap, &membership)
}
#[cfg(test)]
mod tests {
    // Unit tests for Louvain / Leiden. Parameter maps are passed as `&params`.
    use super::*;
    use crate::graph::Graph;
    use crate::types::{ulid_encode, Edge, Node, NodeId, Value};

    /// Inserts a node labelled `N` with no properties and returns its id.
    fn make_node(g: &mut Graph) -> NodeId {
        let id = g.alloc_node_id();
        g.apply_insert_node(Node {
            id,
            labels: vec!["N".into()],
            properties: Default::default(),
        });
        id
    }

    /// Inserts an unweighted, undirected edge labelled `E`.
    fn undirected_edge(g: &mut Graph, from: NodeId, to: NodeId) {
        let id = g.alloc_edge_id();
        g.apply_insert_edge(Edge {
            id,
            from_node: from,
            to_node: to,
            label: "E".into(),
            properties: Default::default(),
            directed: false,
        });
    }

    /// Inserts an undirected edge whose `weight` property is `w`.
    fn weighted_edge(g: &mut Graph, from: NodeId, to: NodeId, w: f64) {
        let id = g.alloc_edge_id();
        g.apply_insert_edge(Edge {
            id,
            from_node: from,
            to_node: to,
            label: "E".into(),
            properties: [("weight".to_string(), Value::Float(w))].into_iter().collect(),
            directed: false,
        });
    }

    /// Two disjoint triangles: {a, b, c} and {d, e, f}.
    fn two_cliques() -> (Graph, [NodeId; 6]) {
        let mut g = Graph::new();
        let a = make_node(&mut g);
        let b = make_node(&mut g);
        let c = make_node(&mut g);
        let d = make_node(&mut g);
        let e = make_node(&mut g);
        let f = make_node(&mut g);
        undirected_edge(&mut g, a, b);
        undirected_edge(&mut g, b, c);
        undirected_edge(&mut g, a, c);
        undirected_edge(&mut g, d, e);
        undirected_edge(&mut g, e, f);
        undirected_edge(&mut g, d, f);
        (g, [a, b, c, d, e, f])
    }

    /// Looks up the `community` string of the row whose `node` is `id`.
    fn community_of(rows: &[Row], id: NodeId) -> String {
        let ulid = ulid_encode(id.0);
        rows.iter()
            .find(|r| r["node"] == Value::String(ulid.clone()))
            .map(|r| match &r["community"] {
                Value::String(s) => s.clone(),
                other => panic!("expected String, got {other:?}"),
            })
            .expect("node not found in rows")
    }

    #[test]
    fn louvain_two_disconnected_cliques() {
        let (g, ids) = two_cliques();
        let rows = run_louvain(&g, &HashMap::new()).unwrap();
        assert_eq!(rows.len(), 6);
        let ca = community_of(&rows, ids[0]);
        let cb = community_of(&rows, ids[1]);
        let cc = community_of(&rows, ids[2]);
        let cd = community_of(&rows, ids[3]);
        let ce = community_of(&rows, ids[4]);
        let cf = community_of(&rows, ids[5]);
        assert_eq!(ca, cb);
        assert_eq!(cb, cc);
        assert_eq!(cd, ce);
        assert_eq!(ce, cf);
        assert_ne!(ca, cd);
    }

    #[test]
    fn louvain_fully_connected_triangle_one_community() {
        let mut g = Graph::new();
        let a = make_node(&mut g);
        let b = make_node(&mut g);
        let c = make_node(&mut g);
        undirected_edge(&mut g, a, b);
        undirected_edge(&mut g, b, c);
        undirected_edge(&mut g, a, c);
        let rows = run_louvain(&g, &HashMap::new()).unwrap();
        assert_eq!(rows.len(), 3);
        let comms: Vec<_> = rows.iter().map(|r| r["community"].clone()).collect();
        assert!(comms.iter().all(|c| *c == comms[0]));
    }

    #[test]
    fn louvain_isolated_nodes_each_own_community() {
        let mut g = Graph::new();
        let a = make_node(&mut g);
        let b = make_node(&mut g);
        let c = make_node(&mut g);
        let rows = run_louvain(&g, &HashMap::new()).unwrap();
        assert_eq!(rows.len(), 3);
        let ca = community_of(&rows, a);
        let cb = community_of(&rows, b);
        let cc = community_of(&rows, c);
        assert_ne!(ca, cb);
        assert_ne!(cb, cc);
        assert_ne!(ca, cc);
    }

    #[test]
    fn louvain_empty_graph() {
        let g = Graph::new();
        assert!(run_louvain(&g, &HashMap::new()).unwrap().is_empty());
    }

    #[test]
    fn louvain_weighted_two_cliques_with_bridge() {
        let mut g = Graph::new();
        let a = make_node(&mut g);
        let b = make_node(&mut g);
        let c = make_node(&mut g);
        let d = make_node(&mut g);
        let e = make_node(&mut g);
        let f = make_node(&mut g);
        weighted_edge(&mut g, a, b, 10.0);
        weighted_edge(&mut g, b, c, 10.0);
        weighted_edge(&mut g, a, c, 10.0);
        weighted_edge(&mut g, d, e, 10.0);
        weighted_edge(&mut g, e, f, 10.0);
        weighted_edge(&mut g, d, f, 10.0);
        weighted_edge(&mut g, c, d, 0.1);
        let params: HashMap<String, Value> =
            [("weight".to_string(), Value::String("weight".into()))].into_iter().collect();
        let rows = run_louvain(&g, &params).unwrap();
        assert_eq!(rows.len(), 6);
        let comm_set: std::collections::HashSet<String> = rows
            .iter()
            .map(|r| match &r["community"] {
                Value::String(s) => s.clone(),
                other => panic!("{other:?}"),
            })
            .collect();
        assert!(comm_set.len() >= 2);
    }

    #[test]
    fn louvain_high_resolution_more_communities() {
        let (g, _) = two_cliques();
        let lo: HashMap<String, Value> =
            [("resolution".to_string(), Value::Float(0.1))].into_iter().collect();
        let hi: HashMap<String, Value> =
            [("resolution".to_string(), Value::Float(5.0))].into_iter().collect();
        let count = |params| -> usize {
            let rows: Vec<Row> = run_louvain(&g, params).unwrap();
            rows.iter()
                .map(|r| match &r["community"] {
                    Value::String(s) => s.clone(),
                    _ => panic!(),
                })
                .collect::<std::collections::HashSet<String>>()
                .len()
        };
        assert!(count(&hi) >= count(&lo));
    }

    #[test]
    fn louvain_invalid_resolution_errors() {
        let g = Graph::new();
        let params: HashMap<String, Value> =
            [("resolution".to_string(), Value::Float(-1.0))].into_iter().collect();
        assert!(run_louvain(&g, &params).is_err());
    }

    #[test]
    fn louvain_zero_max_passes_errors() {
        let g = Graph::new();
        let params: HashMap<String, Value> =
            [("maxPasses".to_string(), Value::Int(0))].into_iter().collect();
        assert!(run_louvain(&g, &params).is_err());
    }

    #[test]
    fn louvain_zero_max_levels_errors() {
        let g = Graph::new();
        let params: HashMap<String, Value> =
            [("maxLevels".to_string(), Value::Int(0))].into_iter().collect();
        assert!(run_louvain(&g, &params).is_err());
    }

    #[test]
    fn leiden_two_disconnected_cliques() {
        let (g, ids) = two_cliques();
        let rows = run_leiden(&g, &HashMap::new()).unwrap();
        assert_eq!(rows.len(), 6);
        let ca = community_of(&rows, ids[0]);
        let cb = community_of(&rows, ids[1]);
        let cc = community_of(&rows, ids[2]);
        let cd = community_of(&rows, ids[3]);
        let ce = community_of(&rows, ids[4]);
        let cf = community_of(&rows, ids[5]);
        assert_eq!(ca, cb);
        assert_eq!(cb, cc);
        assert_eq!(cd, ce);
        assert_eq!(ce, cf);
        assert_ne!(ca, cd);
    }

    #[test]
    fn leiden_fully_connected_triangle_one_community() {
        let mut g = Graph::new();
        let a = make_node(&mut g);
        let b = make_node(&mut g);
        let c = make_node(&mut g);
        undirected_edge(&mut g, a, b);
        undirected_edge(&mut g, b, c);
        undirected_edge(&mut g, a, c);
        let rows = run_leiden(&g, &HashMap::new()).unwrap();
        let comms: Vec<_> = rows.iter().map(|r| r["community"].clone()).collect();
        assert!(comms.iter().all(|c| *c == comms[0]));
    }

    #[test]
    fn leiden_empty_graph() {
        let g = Graph::new();
        assert!(run_leiden(&g, &HashMap::new()).unwrap().is_empty());
    }

    #[test]
    fn leiden_theta_out_of_range_errors() {
        let g = Graph::new();
        let params: HashMap<String, Value> =
            [("theta".to_string(), Value::Float(1.5))].into_iter().collect();
        assert!(run_leiden(&g, &params).is_err());
    }

    #[test]
    fn leiden_theta_zero_matches_louvain_behaviour() {
        let (g, ids) = two_cliques();
        let params: HashMap<String, Value> =
            [("theta".to_string(), Value::Float(0.0))].into_iter().collect();
        let rows = run_leiden(&g, &params).unwrap();
        let ca = community_of(&rows, ids[0]);
        let cd = community_of(&rows, ids[3]);
        assert_ne!(ca, cd);
    }

    #[test]
    fn leiden_high_resolution_more_communities() {
        let (g, _) = two_cliques();
        let lo: HashMap<String, Value> =
            [("resolution".to_string(), Value::Float(0.1))].into_iter().collect();
        let hi: HashMap<String, Value> =
            [("resolution".to_string(), Value::Float(5.0))].into_iter().collect();
        let count = |params| -> usize {
            let rows: Vec<Row> = run_leiden(&g, params).unwrap();
            rows.iter()
                .map(|r| match &r["community"] {
                    Value::String(s) => s.clone(),
                    _ => panic!(),
                })
                .collect::<std::collections::HashSet<String>>()
                .len()
        };
        assert!(count(&hi) >= count(&lo));
    }

    #[test]
    fn leiden_weighted_two_cliques_with_bridge() {
        let mut g = Graph::new();
        let a = make_node(&mut g);
        let b = make_node(&mut g);
        let c = make_node(&mut g);
        let d = make_node(&mut g);
        let e = make_node(&mut g);
        let f = make_node(&mut g);
        weighted_edge(&mut g, a, b, 10.0);
        weighted_edge(&mut g, b, c, 10.0);
        weighted_edge(&mut g, a, c, 10.0);
        weighted_edge(&mut g, d, e, 10.0);
        weighted_edge(&mut g, e, f, 10.0);
        weighted_edge(&mut g, d, f, 10.0);
        weighted_edge(&mut g, c, d, 0.1);
        let params: HashMap<String, Value> =
            [("weight".to_string(), Value::String("weight".into()))].into_iter().collect();
        let rows = run_leiden(&g, &params).unwrap();
        let comm_set: std::collections::HashSet<String> = rows
            .iter()
            .map(|r| match &r["community"] {
                Value::String(s) => s.clone(),
                other => panic!("{other:?}"),
            })
            .collect();
        assert!(comm_set.len() >= 2);
    }

    #[test]
    fn leiden_isolated_nodes_each_own_community() {
        let mut g = Graph::new();
        let a = make_node(&mut g);
        let b = make_node(&mut g);
        let c = make_node(&mut g);
        let rows = run_leiden(&g, &HashMap::new()).unwrap();
        let ca = community_of(&rows, a);
        let cb = community_of(&rows, b);
        let cc = community_of(&rows, c);
        assert_ne!(ca, cb);
        assert_ne!(cb, cc);
        assert_ne!(ca, cc);
    }

    #[test]
    fn single_node_graph() {
        let mut g = Graph::new();
        let a = make_node(&mut g);
        for runner in [run_louvain, run_leiden] {
            let rows = runner(&g, &HashMap::new()).unwrap();
            assert_eq!(rows.len(), 1);
            let ca = community_of(&rows, a);
            // A lone node's community is named after itself.
            assert_eq!(ca, ulid_encode(a.0));
        }
    }

    #[test]
    fn two_node_edge() {
        let mut g = Graph::new();
        let a = make_node(&mut g);
        let b = make_node(&mut g);
        undirected_edge(&mut g, a, b);
        for runner in [run_louvain, run_leiden] {
            let rows = runner(&g, &HashMap::new()).unwrap();
            assert_eq!(rows.len(), 2);
            let ca = community_of(&rows, a);
            let cb = community_of(&rows, b);
            assert_eq!(ca, cb);
        }
    }
}