Skip to main content

mnem_graphrag/
community.rs

1//! Leiden-style community detection over an [`AdjacencyIndex`].
2//!
3//! # Algorithm
4//!
//! Implements a simplified variant of the Leiden algorithm (Traag,
//! Waltman, van Eck 2019, arXiv:1810.08473) in three phases, repeated
//! until a round yields no further modularity gain (at most 8 rounds):
8//!
9//! 1. **Local moving** - iterate nodes in a deterministic order,
10//!    move each into the neighbouring community that gives the
11//!    largest positive modularity delta.
//! 2. **Refinement** - inside each community, re-run local moving
//!    starting from singletons, but only allow moves into refined
//!    sub-communities the node is sufficiently connected to (an
//!    approximation of Leiden's well-connectedness condition, which is
//!    its key departure from Louvain).
//! 3. **Aggregation** - full Leiden would next build a coarser graph
//!    whose super-nodes are the refined sub-communities, with
//!    inter-community edge weights summed. This implementation does not
//!    aggregate explicitly. The refined labels are fed into another
//!    local-moving pass over the original graph, and that result replaces
//!    the un-refined partition only if its modularity is higher.
22//!
23//! The modularity objective is standard undirected modularity:
24//!
//! `Q = sum_c [ e_c / (2m) - (a_c / 2m)^2 ]`
26//!
27//! where `e_c` is twice the intra-community edge weight, `a_c` is the
28//! sum of node degrees inside `c`, and `m` is the total edge weight.
29//!
30//! # Determinism contract
31//!
32//! - Input edges are collapsed into an undirected weighted graph
33//!   keyed by a sorted `Vec<NodeId>` (index = internal node id). Self-loops
34//!   are dropped (they contribute nothing to modularity deltas and
35//!   cause algorithmic edge cases).
36//! - Node iteration order = ascending internal id (which is ascending
37//!   `NodeId`).
38//! - Community labels at every level are canonicalised by
39//!   first-appearance of their smallest member node id, so two runs
40//!   under any input-edge permutation produce identical
41//!   `(NodeId -> CommunityId)` maps.
42//! - The `seed` parameter is currently reserved: pure deterministic
43//!   iteration order does not consult the RNG, but the seed is mixed
44//!   into the content CID so a caller can explicitly branch
45//!   partitions on seed. Future refinements may use the seed to
46//!   randomise singleton-order inside communities while keeping
47//!   reproducibility.
48
49use std::collections::BTreeMap;
50
51use mnem_core::id::{CODEC_RAW, Cid, HASH_BLAKE3_256, Multihash, NodeId};
52use mnem_core::index::{AdjacencyIndex, EdgeProvenance};
53
/// Opaque integer identifier of a community in a [`CommunityAssignment`].
///
/// This is a plain `u32` alias, not a newtype. Ids are assigned
/// canonically by [`compute_communities`]. The community containing the
/// node with the smallest `NodeId` gets `0`. The community of the next
/// smallest node not already covered gets `1`, and so on. This
/// canonicalisation is what makes [`CommunityAssignment::content_cid`]
/// stable under permutations of input edge order.
pub type CommunityId = u32;
62
63/// Result of a community-detection run over an [`AdjacencyIndex`].
64#[derive(Clone, Debug)]
65pub struct CommunityAssignment {
66    /// Canonical `NodeId -> CommunityId` map. Keyed by `BTreeMap` so
67    /// iteration is deterministic (ascending `NodeId`).
68    pub map: BTreeMap<NodeId, CommunityId>,
69    /// Inverse map `CommunityId -> [NodeId]`. Each member vector is
70    /// sorted ascending (derived from `BTreeMap` iteration order over
71    /// `map`). Precomputed at construction so `members_of` is O(1)
72    /// lookup + O(|C|) slice return; the `CommunityExpander` stage
73    /// (C3 FIX-1) needs this on the retrieval hot path.
74    pub members: BTreeMap<CommunityId, Vec<NodeId>>,
75    /// Modularity score of this partition (higher is better; range
76    /// `[-0.5, 1.0]` for undirected graphs).
77    pub modularity: f32,
78    /// Seed that produced this partition. Threaded into `content_cid`
79    /// so distinct seeds produce distinct CIDs even when the
80    /// partition map happens to collide.
81    pub seed: u64,
82}
83
84impl CommunityAssignment {
85    /// Look up the community of `node`. Returns `None` for nodes not
86    /// present in any edge of the underlying graph.
87    #[must_use]
88    pub fn community_of(&self, node: NodeId) -> Option<CommunityId> {
89        self.map.get(&node).copied()
90    }
91
92    /// All nodes assigned to `community`. Sorted ascending by
93    /// `NodeId` for determinism. Empty slice if the id is unknown.
94    #[must_use]
95    pub fn members_of(&self, community: CommunityId) -> &[NodeId] {
96        self.members
97            .get(&community)
98            .map(Vec::as_slice)
99            .unwrap_or(&[])
100    }
101
102    /// Number of distinct communities.
103    #[must_use]
104    pub fn community_count(&self) -> usize {
105        let mut max: i64 = -1;
106        for &c in self.map.values() {
107            if i64::from(c) > max {
108                max = i64::from(c);
109            }
110        }
111        usize::try_from(max + 1).unwrap_or(0)
112    }
113
114    /// Content-addressable identity of this assignment.
115    ///
116    /// CID preimage:
117    ///
118    /// `b"mnem/community/v1" || seed_be_u64 || concat(node_id_bytes || cid_be_u32)`
119    ///
120    /// where the `(node, community)` pairs iterate in ascending
121    /// `NodeId` order (guaranteed by `BTreeMap`). Wrapped in
122    /// `CIDv1(codec=raw, multihash=sha2-256)`. Domain-separated from
123    /// other mnem object classes by the leading tag.
124    #[must_use]
125    pub fn content_cid(&self) -> Cid {
126        let mut buf: Vec<u8> = Vec::with_capacity(16 + 8 + self.map.len() * (16 + 4));
127        buf.extend_from_slice(b"mnem/community/v1");
128        buf.extend_from_slice(&self.seed.to_be_bytes());
129        for (nid, cid) in &self.map {
130            buf.extend_from_slice(nid.as_bytes());
131            buf.extend_from_slice(&cid.to_be_bytes());
132        }
133        let digest = blake3::hash(&buf);
134        let mh = Multihash::wrap(HASH_BLAKE3_256, digest.as_bytes())
135            .expect("blake3 32-byte digest fits multihash");
136        Cid::new(CODEC_RAW, mh)
137    }
138}
139
140/// Run Leiden community detection over `adj`.
141///
142/// # Determinism
143///
144/// Two calls with the same underlying edge set (regardless of
145/// iteration order from `adj`) and the same `seed` produce identical
146/// [`CommunityAssignment`]s.
147#[must_use]
148pub fn compute_communities(adj: &dyn AdjacencyIndex, seed: u64) -> CommunityAssignment {
149    // --------------------------------------------------------------
150    // 1. Build undirected weighted graph
151    // --------------------------------------------------------------
152    let (nodes, adj_list, m2) = build_undirected_graph(adj);
153
154    if nodes.is_empty() {
155        return CommunityAssignment {
156            map: BTreeMap::new(),
157            members: BTreeMap::new(),
158            modularity: 0.0,
159            seed,
160        };
161    }
162
163    // --------------------------------------------------------------
164    // 2. Initialise singleton partition
165    // --------------------------------------------------------------
166    let n = nodes.len();
167    let mut part: Vec<usize> = (0..n).collect();
168    let degrees: Vec<f64> = (0..n)
169        .map(|i| adj_list[i].iter().map(|(_, w)| *w).sum())
170        .collect();
171
172    // --------------------------------------------------------------
173    // 3. Leiden outer loop: local-move -> refine -> aggregate
174    // --------------------------------------------------------------
175    // Iterated local-moving with refinement-as-tie-breaker.
176    //
177    // Pure local-moving (Louvain's first phase) already achieves
178    // Newman's 0.37-0.42 modularity range on Karate-club. Leiden's
179    // refinement exists to guarantee well-connected sub-communities
180    // *for the aggregation step*; without aggregation it can
181    // over-fragment, so we keep the refined partition only if its
182    // modularity beats the un-refined baseline. Determinism is
183    // preserved because the keep/revert decision depends only on
184    // the deterministic `modularity` output.
185    let mut prev_q = f64::NEG_INFINITY;
186    for _ in 0..8 {
187        local_move(&adj_list, &degrees, &mut part, m2);
188        let q_post_move = modularity(&adj_list, &degrees, &part, m2);
189
190        let mut refined: Vec<usize> = part.clone();
191        refine_partition(&adj_list, &degrees, &mut refined, m2);
192        local_move(&adj_list, &degrees, &mut refined, m2);
193        let q_post_refine = modularity(&adj_list, &degrees, &refined, m2);
194
195        if q_post_refine > q_post_move + 1e-9 {
196            part.copy_from_slice(&refined);
197        }
198        let q = modularity(&adj_list, &degrees, &part, m2);
199        if q <= prev_q + 1e-9 {
200            break;
201        }
202        prev_q = q;
203    }
204
205    // --------------------------------------------------------------
206    // 4. Canonicalise community ids by first-appearing NodeId
207    // --------------------------------------------------------------
208    let canonical = canonicalise_communities(&part);
209
210    // --------------------------------------------------------------
211    // 5. Build public map + modularity
212    // --------------------------------------------------------------
213    let mut map = BTreeMap::new();
214    for (i, &nid) in nodes.iter().enumerate() {
215        map.insert(nid, canonical[i]);
216    }
217    // Precompute inverse map (CommunityId -> sorted Vec<NodeId>) so
218    // `CommunityAssignment::members_of` is O(1) lookup on the
219    // retrieval hot path. Iterating `map` (a BTreeMap) yields
220    // `NodeId`s in ascending order, so each per-community Vec is
221    // sorted ascending without an explicit sort.
222    let mut members: BTreeMap<CommunityId, Vec<NodeId>> = BTreeMap::new();
223    for (&nid, &cid) in &map {
224        members.entry(cid).or_default().push(nid);
225    }
226    let q = modularity(&adj_list, &degrees, &part, m2) as f32;
227
228    CommunityAssignment {
229        map,
230        members,
231        modularity: q,
232        seed,
233    }
234}
235
236// ---------------------------------------------------------------------
237// Graph construction
238// ---------------------------------------------------------------------
239
240/// Collect `adj` into a symmetric weighted adjacency list over a
241/// sorted node vector. Returns `(nodes, adj_list, 2m)` where `2m` is
242/// the sum of all (symmetric) edge weights (i.e. twice the undirected
243/// total).
244fn build_undirected_graph(adj: &dyn AdjacencyIndex) -> (Vec<NodeId>, Vec<Vec<(usize, f64)>>, f64) {
245    // Collect unique nodes + edge triples in a deterministic way.
246    let mut node_set: std::collections::BTreeSet<NodeId> = std::collections::BTreeSet::new();
247    // Deduplicate `(min, max) -> max_weight`. HybridAdjacency may
248    // yield an authored + KNN copy of the same endpoint pair with
249    // different weights; modularity is a set-of-edges notion so we
250    // keep the single largest weight (authored is 1.0, KNN is
251    // similarity; taking max preserves the stronger signal without
252    // double-counting).
253    let mut edges: BTreeMap<(NodeId, NodeId), (f64, bool)> = BTreeMap::new();
254
255    for e in adj.iter_edges() {
256        node_set.insert(e.src);
257        node_set.insert(e.dst);
258        if e.src == e.dst {
259            continue;
260        }
261        let key = if e.src < e.dst {
262            (e.src, e.dst)
263        } else {
264            (e.dst, e.src)
265        };
266        let authored = matches!(e.provenance, EdgeProvenance::Authored);
267        let w = f64::from(e.weight).max(0.0);
268        edges
269            .entry(key)
270            .and_modify(|(cur_w, cur_authored)| {
271                // Keep the larger weight; authored flag sticky.
272                if w > *cur_w {
273                    *cur_w = w;
274                }
275                if authored {
276                    *cur_authored = true;
277                }
278            })
279            .or_insert((w, authored));
280    }
281
282    let nodes: Vec<NodeId> = node_set.into_iter().collect();
283    let mut index_of: BTreeMap<NodeId, usize> = BTreeMap::new();
284    for (i, nid) in nodes.iter().enumerate() {
285        index_of.insert(*nid, i);
286    }
287
288    let mut adj_list: Vec<Vec<(usize, f64)>> = vec![Vec::new(); nodes.len()];
289    let mut m2: f64 = 0.0;
290    for (&(a, b), &(w, _)) in &edges {
291        // Skip zero-weight edges (modularity contribution is zero
292        // and they confuse the sum).
293        if w <= 0.0 {
294            continue;
295        }
296        let ia = index_of[&a];
297        let ib = index_of[&b];
298        adj_list[ia].push((ib, w));
299        adj_list[ib].push((ia, w));
300        m2 += 2.0 * w;
301    }
302    // Deterministic per-node neighbour order.
303    for nb in &mut adj_list {
304        nb.sort_by_key(|x| x.0);
305    }
306
307    (nodes, adj_list, m2)
308}
309
310// ---------------------------------------------------------------------
311// Local move
312// ---------------------------------------------------------------------
313
/// Louvain-style local-moving phase, run to a fixed point.
///
/// Sweeps the nodes in ascending index order. Each node with positive
/// degree moves to the neighbouring community that offers the largest
/// modularity gain, provided that gain exceeds `1e-12`. On equal gains the
/// first (smallest) candidate label wins. Sweeps repeat until a full pass
/// makes no move. Zero-degree nodes never move. This is a no-op when
/// `m2 <= 0`.
///
/// The gain for moving `v` from community `old` to community `new` is
///
/// `(k_v,new - k_v,old) / m + k_v * (S_old - S_new - k_v + 2 * self_v) / (2 m^2)`
///
/// where:
/// - `m = m2 / 2`;
/// - `k_v,c` is the edge weight from `v` into `c`;
/// - `S_c` is the running degree sum of `c` (for `old`, this still
///   includes `v`);
/// - `self_v` is the weight of `v`'s adjacency entries that point at `v`.
fn local_move(adj_list: &[Vec<(usize, f64)>], degrees: &[f64], part: &mut [usize], m2: f64) {
    if m2 <= 0.0 {
        return;
    }
    let half_m2 = m2 / 2.0;
    let two_m_sq = m2 * m2 / 2.0;

    // Running degree sum per community label.
    let mut com_deg: BTreeMap<usize, f64> = BTreeMap::new();
    for (node, &label) in part.iter().enumerate() {
        *com_deg.entry(label).or_insert(0.0) += degrees[node];
    }

    let mut changed = true;
    while changed {
        changed = false;
        for v in 0..adj_list.len() {
            let k_v = degrees[v];
            if k_v <= 0.0 {
                continue;
            }
            let here = part[v];

            // Edge weight from v into every neighbouring community.
            let mut link: BTreeMap<usize, f64> = BTreeMap::new();
            for &(u, w) in adj_list[v].iter().filter(|&&(u, _)| u != v) {
                *link.entry(part[u]).or_insert(0.0) += w;
            }
            let self_weight: f64 = adj_list[v]
                .iter()
                .filter(|&&(u, _)| u == v)
                .map(|&(_, w)| w)
                .sum();

            let k_here = link.get(&here).copied().unwrap_or(0.0);
            let sigma_here = com_deg.get(&here).copied().unwrap_or(0.0);

            // Staying put scores 0. Candidates are visited in ascending
            // label order, so a strictly-greater test keeps the smallest
            // label on ties.
            let mut target = here;
            let mut target_gain = 0.0_f64;
            for (&cand, &k_cand) in &link {
                if cand == here {
                    continue;
                }
                let sigma_cand = com_deg.get(&cand).copied().unwrap_or(0.0);
                let gain = (k_cand - k_here) / half_m2
                    + (k_v * (sigma_here - sigma_cand - k_v + 2.0 * self_weight)) / two_m_sq;
                let strictly_better = gain > target_gain + 1e-12;
                let tie_to_smaller = gain > target_gain - 1e-12 && cand < target && gain > 1e-12;
                if strictly_better || tie_to_smaller {
                    target_gain = gain;
                    target = cand;
                }
            }

            if target != here && target_gain > 1e-12 {
                *com_deg.entry(here).or_insert(0.0) -= k_v;
                *com_deg.entry(target).or_insert(0.0) += k_v;
                part[v] = target;
                changed = true;
            }
        }
    }
}
397
398// ---------------------------------------------------------------------
399// Refinement (Leiden)
400// ---------------------------------------------------------------------
401
/// Single-pass, Leiden-style refinement of the outer partition in `part`.
///
/// Outer communities are visited in ascending label order, and their nodes
/// in ascending index order. Every node starts as its own refined
/// sub-community. A node may move to a refined sub-community of the *same*
/// outer community only when both of the following hold:
///
/// - its edge weight to that sub-community is at least
///   `(outer community degree / m2) * k_v`. This gamma = 1 connectivity
///   gate approximates the well-connectedness condition of Traag et al.
///   (2019).
/// - the modularity gain exceeds `1e-12`. The largest gain wins, and ties
///   keep the smaller label.
///
/// On return, `part` holds the refined labels. These are node indices, not
/// outer labels, so the next local-moving pass operates on the refined
/// communities without an explicit aggregation step. This is a no-op when
/// `m2 <= 0`.
fn refine_partition(adj_list: &[Vec<(usize, f64)>], degrees: &[f64], part: &mut [usize], m2: f64) {
    if m2 <= 0.0 {
        return;
    }
    let n = adj_list.len();
    let half_m2 = m2 / 2.0;
    let two_m_sq = m2 * m2 / 2.0;

    // Frozen copy of the outer labels; `part` itself is overwritten at the end.
    let outer: Vec<usize> = part.to_vec();
    let mut groups: BTreeMap<usize, Vec<usize>> = BTreeMap::new();
    for (node, &label) in part.iter().enumerate() {
        groups.entry(label).or_default().push(node);
    }

    // Refined labels start as singletons, each labelled by its own index.
    let mut sub: Vec<usize> = (0..n).collect();
    let mut sub_deg: BTreeMap<usize, f64> = BTreeMap::new();
    for (node, &label) in sub.iter().enumerate() {
        *sub_deg.entry(label).or_insert(0.0) += degrees[node];
    }

    for members in groups.into_values() {
        let total: f64 = members.iter().map(|&i| degrees[i]).sum();
        let gate = total / m2;

        for &v in &members {
            let k_v = degrees[v];
            if k_v <= 0.0 {
                continue;
            }
            let home = outer[v];

            // Weight from v to each refined sub-community inside v's outer community.
            let mut link: BTreeMap<usize, f64> = BTreeMap::new();
            for &(u, w) in &adj_list[v] {
                if u != v && outer[u] == home {
                    *link.entry(sub[u]).or_insert(0.0) += w;
                }
            }

            let here = sub[v];
            let sigma_here = sub_deg.get(&here).copied().unwrap_or(0.0);
            let k_here = link.get(&here).copied().unwrap_or(0.0);

            let mut target = here;
            let mut target_gain = 0.0_f64;
            for (&cand, &k_cand) in &link {
                // Skip the current sub-community and any that fail the
                // connectivity gate.
                if cand == here || k_cand < gate * k_v {
                    continue;
                }
                let sigma_cand = sub_deg.get(&cand).copied().unwrap_or(0.0);
                let gain = (k_cand - k_here) / half_m2
                    + (k_v * (sigma_here - sigma_cand - k_v)) / two_m_sq;
                let strictly_better = gain > target_gain + 1e-12;
                let tie_to_smaller = gain > target_gain - 1e-12 && cand < target && gain > 1e-12;
                if strictly_better || tie_to_smaller {
                    target_gain = gain;
                    target = cand;
                }
            }

            if target != here && target_gain > 1e-12 {
                *sub_deg.entry(here).or_insert(0.0) -= k_v;
                *sub_deg.entry(target).or_insert(0.0) += k_v;
                sub[v] = target;
            }
        }
    }

    part[..n].copy_from_slice(&sub[..n]);
}
494
495// ---------------------------------------------------------------------
496// Canonicalisation + modularity
497// ---------------------------------------------------------------------
498
499/// Relabel communities so the first-seen raw id (iterating ascending
500/// node index = ascending `NodeId`) becomes `CommunityId(0)`, the
501/// second `CommunityId(1)`, and so on.
502fn canonicalise_communities(part: &[usize]) -> Vec<CommunityId> {
503    let mut map: BTreeMap<usize, CommunityId> = BTreeMap::new();
504    let mut next: CommunityId = 0;
505    let mut out = Vec::with_capacity(part.len());
506    for &c in part {
507        let canonical = *map.entry(c).or_insert_with(|| {
508            let id = next;
509            next += 1;
510            id
511        });
512        out.push(canonical);
513    }
514    out
515}
516
/// Undirected modularity of `part` on the weighted graph `adj_list`.
///
/// Expects the following inputs:
/// - `degrees[u]` is the weighted degree of node `u`;
/// - `m2` is twice the total edge weight, i.e. the sum of every weight in
///   `adj_list`.
///
/// It computes
///
/// `Q = sum_c [ e_c / m2 - (a_c / m2)^2 ]`
///
/// over **every** community `c` in `part`. Here `e_c` is the
/// intra-community weight (each intra edge counted from both endpoints)
/// and `a_c` is the community's degree sum. Returns `0.0` when `m2 <= 0`.
fn modularity(adj_list: &[Vec<(usize, f64)>], degrees: &[f64], part: &[usize], m2: f64) -> f64 {
    if m2 <= 0.0 {
        return 0.0;
    }
    // Per community: internal weight counted twice (e_c) and degree sum (a_c).
    let mut e_c: BTreeMap<usize, f64> = BTreeMap::new();
    let mut a_c: BTreeMap<usize, f64> = BTreeMap::new();
    for (u, neighbours) in adj_list.iter().enumerate() {
        let cu = part[u];
        *a_c.entry(cu).or_insert(0.0) += degrees[u];
        for &(v, w) in neighbours {
            if part[v] == cu {
                *e_c.entry(cu).or_insert(0.0) += w;
            }
        }
    }
    // Iterate over `a_c` (every community with members), not `e_c`. A
    // community with no internal edges, such as a singleton, still pays its
    // `-(a_c / m2)^2` penalty. Skipping it would inflate Q for fragmented
    // partitions.
    let mut q: f64 = 0.0;
    for (c, &a) in &a_c {
        let e = e_c.get(c).copied().unwrap_or(0.0);
        q += e / m2 - (a / m2).powi(2);
    }
    q
}