Skip to main content

lean_ctx/core/
community.rs

//! Leiden-style community detection on the Property Graph.
//!
//! Follows the structure of the Leiden algorithm (Traag, Waltman, van Eck 2019) for
//! modularity-based graph clustering with connected communities:
//!   1. **Local moving:** greedily move nodes to the community that yields
//!      the highest modularity gain.
//!   2. **Refinement:** within each community, split off disconnected
//!      sub-communities to ensure connectivity, then try to merge singletons
//!      into a neighboring community.
//!   3. **Aggregation:** collapse sub-communities into super-nodes and repeat.
//!
//! NOTE: step 3 is part of the published algorithm but is not performed by the
//! code in this module; steps 1 and 2 are simply repeated on the original nodes.
10
11use std::collections::HashMap;
12
13use rusqlite::Connection;
14use serde::Serialize;
15
16#[derive(Debug, Clone, Serialize)]
17pub struct Community {
18    pub id: usize,
19    pub files: Vec<String>,
20    pub internal_edges: usize,
21    pub external_edges: usize,
22    pub cohesion: f64,
23}
24
25#[derive(Debug, Clone, Serialize)]
26pub struct CommunityResult {
27    pub communities: Vec<Community>,
28    pub modularity: f64,
29    pub node_count: usize,
30    pub edge_count: usize,
31}
32
/// In-memory weighted file graph that the clustering routines operate on.
struct AdjGraph {
    /// File path of every node; a node's index is its position in this list.
    node_ids: Vec<String>,
    /// Reverse lookup from file path to node index.
    // Only read by the unit tests in this file, hence the lint exemption.
    #[allow(dead_code)]
    node_to_idx: HashMap<String, usize>,
    /// `adj[i]` holds `(target index, weight)` for each edge that starts at
    /// node `i`. The reverse direction is NOT stored here.
    adj: Vec<Vec<(usize, f64)>>,
    /// Sum of the weights of all edges (each edge counted once).
    total_weight: f64,
    /// `degree[i]` is the summed weight of every edge touching node `i`,
    /// whether `i` is the source or the target.
    degree: Vec<f64>,
}
41
42impl AdjGraph {
43    fn from_property_graph(conn: &Connection) -> Self {
44        let mut node_ids: Vec<String> = Vec::new();
45        let mut node_to_idx: HashMap<String, usize> = HashMap::new();
46
47        let Ok(mut file_stmt) =
48            conn.prepare("SELECT DISTINCT file_path FROM nodes WHERE kind = 'file'")
49        else {
50            tracing::warn!("community: failed to prepare file query");
51            return Self {
52                node_ids: Vec::new(),
53                node_to_idx: HashMap::new(),
54                adj: Vec::new(),
55                degree: Vec::new(),
56                total_weight: 0.0,
57            };
58        };
59        let files = match file_stmt.query_map([], |row| row.get::<_, String>(0)) {
60            Ok(rows) => rows.filter_map(std::result::Result::ok).collect::<Vec<_>>(),
61            Err(e) => {
62                tracing::warn!("community: file query failed: {e}");
63                Vec::new()
64            }
65        };
66
67        for f in &files {
68            let idx = node_ids.len();
69            node_ids.push(f.clone());
70            node_to_idx.insert(f.clone(), idx);
71        }
72
73        let n = node_ids.len();
74        let mut adj: Vec<Vec<(usize, f64)>> = vec![Vec::new(); n];
75        let mut total_weight = 0.0;
76        let mut degree = vec![0.0; n];
77
78        let edge_sql = "
79            SELECT DISTINCT n1.file_path, n2.file_path, e.kind
80            FROM edges e
81            JOIN nodes n1 ON e.source_id = n1.id
82            JOIN nodes n2 ON e.target_id = n2.id
83            WHERE n1.kind = 'file' AND n2.kind = 'file'
84              AND n1.file_path != n2.file_path
85        ";
86        let Ok(mut edge_stmt) = conn.prepare(edge_sql) else {
87            tracing::warn!("community: failed to prepare edge query");
88            return Self {
89                node_ids,
90                node_to_idx,
91                adj,
92                total_weight,
93                degree,
94            };
95        };
96        let edges = match edge_stmt.query_map([], |row| {
97            Ok((
98                row.get::<_, String>(0)?,
99                row.get::<_, String>(1)?,
100                row.get::<_, String>(2)?,
101            ))
102        }) {
103            Ok(rows) => rows.filter_map(std::result::Result::ok).collect::<Vec<_>>(),
104            Err(e) => {
105                tracing::warn!("community: edge query failed: {e}");
106                Vec::new()
107            }
108        };
109
110        for (from, to, kind) in &edges {
111            let Some(&i) = node_to_idx.get(from) else {
112                continue;
113            };
114            let Some(&j) = node_to_idx.get(to) else {
115                continue;
116            };
117            let w = edge_weight(kind);
118            adj[i].push((j, w));
119            degree[i] += w;
120            degree[j] += w;
121            total_weight += w;
122        }
123
124        Self {
125            node_ids,
126            node_to_idx,
127            adj,
128            total_weight,
129            degree,
130        }
131    }
132}
133
/// Maps an edge kind from the property graph to the weight used for clustering.
///
/// Calls bind files most strongly (1.5), followed by imports (1.0) and type
/// references (0.8); `defines` / `exports` edges count least (0.3). Any kind
/// not listed here gets a neutral 0.5.
fn edge_weight(kind: &str) -> f64 {
    match kind {
        "calls" => 1.5,
        "imports" => 1.0,
        "type_ref" => 0.8,
        "defines" | "exports" => 0.3,
        _ => 0.5,
    }
}
143
144pub fn detect_communities(conn: &Connection) -> CommunityResult {
145    let graph = AdjGraph::from_property_graph(conn);
146    let n = graph.node_ids.len();
147
148    if n == 0 {
149        return CommunityResult {
150            communities: Vec::new(),
151            modularity: 0.0,
152            node_count: 0,
153            edge_count: 0,
154        };
155    }
156
157    let assignment = leiden(&graph);
158
159    let mut comm_map: HashMap<usize, Vec<usize>> = HashMap::new();
160    for (i, &c) in assignment.iter().enumerate() {
161        comm_map.entry(c).or_default().push(i);
162    }
163
164    let mut communities: Vec<Community> = Vec::new();
165    for members in comm_map.values() {
166        let files: Vec<String> = members.iter().map(|&i| graph.node_ids[i].clone()).collect();
167        let member_set: std::collections::HashSet<usize> = members.iter().copied().collect();
168
169        let mut internal = 0usize;
170        let mut external = 0usize;
171        for &i in members {
172            for &(j, _) in &graph.adj[i] {
173                if member_set.contains(&j) {
174                    internal += 1;
175                } else {
176                    external += 1;
177                }
178            }
179        }
180
181        let total = (internal + external).max(1) as f64;
182        let cohesion = internal as f64 / total;
183
184        communities.push(Community {
185            id: 0,
186            files,
187            internal_edges: internal,
188            external_edges: external,
189            cohesion,
190        });
191    }
192
193    communities.sort_by(|a, b| {
194        b.files.len().cmp(&a.files.len()).then_with(|| {
195            b.cohesion
196                .partial_cmp(&a.cohesion)
197                .unwrap_or(std::cmp::Ordering::Equal)
198        })
199    });
200
201    for (new_id, c) in communities.iter_mut().enumerate() {
202        c.id = new_id;
203    }
204
205    let modularity = compute_modularity(&graph, &assignment);
206    let edge_count = graph.adj.iter().map(Vec::len).sum::<usize>();
207
208    CommunityResult {
209        communities,
210        modularity,
211        node_count: n,
212        edge_count,
213    }
214}
215
216// ── Leiden Algorithm ────────────────────────────────────────
217
/// Upper bound on the number of local-moving / refinement rounds `leiden` runs.
const MAX_ITERATIONS: usize = 20;
/// Resolution parameter applied to the null-model term when merging singleton
/// communities; `1.0` corresponds to plain modularity.
const GAMMA: f64 = 1.0;
220
221fn leiden(graph: &AdjGraph) -> Vec<usize> {
222    let n = graph.node_ids.len();
223    let mut assignment: Vec<usize> = (0..n).collect();
224    let m2 = graph.total_weight.max(1.0) * 2.0;
225
226    for _ in 0..MAX_ITERATIONS {
227        let moved = local_moving(graph, &mut assignment, m2);
228        if !moved {
229            break;
230        }
231        refine_communities(graph, &mut assignment, m2);
232    }
233
234    assignment
235}
236
237/// Phase 1: Local Moving — greedily move nodes to their best neighbor community.
238fn local_moving(graph: &AdjGraph, assignment: &mut [usize], m2: f64) -> bool {
239    let n = assignment.len();
240    let mut comm_total: Vec<f64> = vec![0.0; n];
241    for (i, &c) in assignment.iter().enumerate() {
242        comm_total[c] += graph.degree[i];
243    }
244
245    let mut changed = false;
246    let mut improved = true;
247
248    while improved {
249        improved = false;
250        for i in 0..n {
251            let current = assignment[i];
252            let ki = graph.degree[i];
253
254            let mut neighbor_comm_weight: HashMap<usize, f64> = HashMap::new();
255            for &(j, w) in &graph.adj[i] {
256                *neighbor_comm_weight.entry(assignment[j]).or_default() += w;
257            }
258
259            let sigma_current = comm_total[current];
260            let ki_in_current = neighbor_comm_weight.get(&current).copied().unwrap_or(0.0);
261
262            let mut best_delta = 0.0f64;
263            let mut best_comm = current;
264
265            for (&c, &ki_in) in &neighbor_comm_weight {
266                if c == current {
267                    continue;
268                }
269                let sigma_c = comm_total[c];
270                let delta_remove = -2.0 * (ki_in_current - ki * (sigma_current - ki) / m2) / m2;
271                let delta_add = 2.0 * (ki_in - ki * sigma_c / m2) / m2;
272                let delta = delta_add + delta_remove;
273
274                if delta > best_delta {
275                    best_delta = delta;
276                    best_comm = c;
277                }
278            }
279
280            if best_comm != current {
281                comm_total[current] -= ki;
282                comm_total[best_comm] += ki;
283                assignment[i] = best_comm;
284                improved = true;
285                changed = true;
286            }
287        }
288    }
289
290    changed
291}
292
293/// Phase 2: Refinement — ensure each community is well-connected by splitting
294/// disconnected components within communities.
295fn refine_communities(graph: &AdjGraph, assignment: &mut [usize], m2: f64) {
296    let mut comm_members: HashMap<usize, Vec<usize>> = HashMap::new();
297    for (i, &c) in assignment.iter().enumerate() {
298        comm_members.entry(c).or_default().push(i);
299    }
300
301    let mut next_id = *assignment.iter().max().unwrap_or(&0) + 1;
302
303    for members in comm_members.values() {
304        if members.len() <= 1 {
305            continue;
306        }
307
308        let components = find_connected_components(graph, members);
309        if components.len() <= 1 {
310            continue;
311        }
312
313        for component in components.iter().skip(1) {
314            let new_comm = next_id;
315            next_id += 1;
316            for &node in component {
317                assignment[node] = new_comm;
318            }
319        }
320    }
321
322    merge_singleton_communities(graph, assignment, m2);
323}
324
325/// Find connected components within a subset of nodes.
326fn find_connected_components(graph: &AdjGraph, members: &[usize]) -> Vec<Vec<usize>> {
327    let member_set: std::collections::HashSet<usize> = members.iter().copied().collect();
328    let mut visited = std::collections::HashSet::new();
329    let mut components = Vec::new();
330
331    for &start in members {
332        if visited.contains(&start) {
333            continue;
334        }
335
336        let mut component = Vec::new();
337        let mut stack = vec![start];
338
339        while let Some(node) = stack.pop() {
340            if !visited.insert(node) {
341                continue;
342            }
343            component.push(node);
344            for &(neighbor, _) in &graph.adj[node] {
345                if member_set.contains(&neighbor) && !visited.contains(&neighbor) {
346                    stack.push(neighbor);
347                }
348            }
349        }
350
351        components.push(component);
352    }
353
354    components
355}
356
357/// Try to merge singleton communities into their best neighbor community.
358fn merge_singleton_communities(graph: &AdjGraph, assignment: &mut [usize], m2: f64) {
359    let n = assignment.len();
360    let mut comm_total: Vec<f64> =
361        vec![0.0; n.max(assignment.iter().copied().max().unwrap_or(0) + 1)];
362    for (i, &c) in assignment.iter().enumerate() {
363        if c < comm_total.len() {
364            comm_total[c] += graph.degree[i];
365        }
366    }
367
368    let mut comm_sizes: HashMap<usize, usize> = HashMap::new();
369    for &c in assignment.iter() {
370        *comm_sizes.entry(c).or_default() += 1;
371    }
372
373    for i in 0..n {
374        let current = assignment[i];
375        if *comm_sizes.get(&current).unwrap_or(&0) > 1 {
376            continue;
377        }
378
379        let ki = graph.degree[i];
380        let mut neighbor_comm_weight: HashMap<usize, f64> = HashMap::new();
381        for &(j, w) in &graph.adj[i] {
382            *neighbor_comm_weight.entry(assignment[j]).or_default() += w;
383        }
384
385        let mut best_delta = 0.0f64;
386        let mut best_comm = current;
387
388        for (&c, &ki_in) in &neighbor_comm_weight {
389            if c == current {
390                continue;
391            }
392            let sigma_c = if c < comm_total.len() {
393                comm_total[c]
394            } else {
395                0.0
396            };
397            let delta = 2.0 * (ki_in - GAMMA * ki * sigma_c / m2) / m2;
398            if delta > best_delta {
399                best_delta = delta;
400                best_comm = c;
401            }
402        }
403
404        if best_comm != current {
405            if current < comm_total.len() {
406                comm_total[current] -= ki;
407            }
408            if best_comm < comm_total.len() {
409                comm_total[best_comm] += ki;
410            }
411            *comm_sizes.entry(current).or_default() -= 1;
412            *comm_sizes.entry(best_comm).or_default() += 1;
413            assignment[i] = best_comm;
414        }
415    }
416}
417
418fn compute_modularity(graph: &AdjGraph, community: &[usize]) -> f64 {
419    let m2 = graph.total_weight.max(1.0) * 2.0;
420    let mut q = 0.0;
421
422    for (i, neighbors) in graph.adj.iter().enumerate() {
423        for &(j, w) in neighbors {
424            if community[i] == community[j] {
425                let ki = graph.degree[i];
426                let kj = graph.degree[j];
427                q += w - (ki * kj) / m2;
428            }
429        }
430    }
431
432    q / m2
433}
434
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::property_graph::{CodeGraph, Edge, EdgeKind, Node};

    /// Builds a five-file graph: a tightly linked `core` trio (a, b, c), a
    /// `tools` pair (d, e) linked in both directions, and a single bridge c → d.
    fn build_test_graph() -> CodeGraph {
        let graph = CodeGraph::open_in_memory().unwrap();

        let node_a = graph.upsert_node(&Node::file("src/core/a.rs")).unwrap();
        let node_b = graph.upsert_node(&Node::file("src/core/b.rs")).unwrap();
        let node_c = graph.upsert_node(&Node::file("src/core/c.rs")).unwrap();
        let node_d = graph.upsert_node(&Node::file("src/tools/d.rs")).unwrap();
        let node_e = graph.upsert_node(&Node::file("src/tools/e.rs")).unwrap();

        // core cluster
        graph
            .upsert_edge(&Edge::new(node_a, node_b, EdgeKind::Imports))
            .unwrap();
        graph
            .upsert_edge(&Edge::new(node_b, node_c, EdgeKind::Imports))
            .unwrap();
        graph
            .upsert_edge(&Edge::new(node_a, node_c, EdgeKind::Calls))
            .unwrap();

        // tools cluster
        graph
            .upsert_edge(&Edge::new(node_d, node_e, EdgeKind::Imports))
            .unwrap();
        graph
            .upsert_edge(&Edge::new(node_e, node_d, EdgeKind::Calls))
            .unwrap();

        // bridge between the two clusters
        graph
            .upsert_edge(&Edge::new(node_c, node_d, EdgeKind::Imports))
            .unwrap();

        graph
    }

    #[test]
    fn detects_communities() {
        let g = build_test_graph();
        let result = detect_communities(g.connection());

        assert!(
            !result.communities.is_empty(),
            "Should detect at least one community"
        );
        assert_eq!(result.node_count, 5);
        assert!(result.edge_count > 0);
    }

    #[test]
    fn modularity_positive() {
        let g = build_test_graph();
        let result = detect_communities(g.connection());

        assert!(
            result.modularity >= 0.0,
            "Modularity should be non-negative for clustered graph"
        );
    }

    #[test]
    fn community_files_cover_all_nodes() {
        let g = build_test_graph();
        let result = detect_communities(g.connection());

        let total_files: usize = result.communities.iter().map(|c| c.files.len()).sum();
        assert_eq!(total_files, 5, "All 5 files should be assigned");
    }

    #[test]
    fn empty_graph() {
        let g = CodeGraph::open_in_memory().unwrap();
        let result = detect_communities(g.connection());
        assert!(result.communities.is_empty());
        assert_eq!(result.modularity, 0.0);
    }

    #[test]
    fn communities_are_connected() {
        let g = build_test_graph();
        let graph = AdjGraph::from_property_graph(g.connection());
        let result = detect_communities(g.connection());

        for comm in &result.communities {
            if comm.files.len() <= 1 {
                continue;
            }
            // Map the reported file paths back to node indices of `graph`.
            let indices: Vec<usize> = comm
                .files
                .iter()
                .filter_map(|f| graph.node_to_idx.get(f).copied())
                .collect();
            let components = find_connected_components(&graph, &indices);
            assert_eq!(
                components.len(),
                1,
                "Community {} with {} files should be connected, found {} components",
                comm.id,
                comm.files.len(),
                components.len()
            );
        }
    }
}