Skip to main content

lean_ctx/core/
community.rs

//! Louvain-style community detection on the Property Graph.
//!
//! Implements the local-moving phase of the Louvain algorithm for
//! modularity-based graph clustering:
//!   1. Start with each node in its own community.
//!   2. Greedily move nodes to the neighbouring community that yields the
//!      highest modularity gain, repeating sweeps until the assignment
//!      stabilises.
//!
//! The aggregation phase of the full Louvain algorithm (collapsing communities
//! into super-nodes and repeating) is not implemented in this module.
//!
//! The result maps each file to a community ID, along with modularity metrics.
10
11use std::collections::HashMap;
12
13use rusqlite::Connection;
14use serde::Serialize;
15
/// One cluster of files produced by [`detect_communities`].
#[derive(Debug, Clone, Serialize)]
pub struct Community {
    /// Rank of this community in the result list. Ids are assigned after the
    /// communities are sorted by descending file count, then by descending
    /// cohesion, so `0` is the largest community.
    pub id: usize,
    /// Paths of the files that belong to this community.
    pub files: Vec<String>,
    /// Number of stored file-to-file edges whose source and target are both
    /// members of this community.
    pub internal_edges: usize,
    /// Number of stored file-to-file edges that start at a member file and
    /// end at a file outside this community.
    pub external_edges: usize,
    /// `internal_edges / (internal_edges + external_edges)`; `0.0` when the
    /// community's members have no counted edges at all.
    pub cohesion: f64,
}
24
/// Outcome of a [`detect_communities`] run.
#[derive(Debug, Clone, Serialize)]
pub struct CommunityResult {
    /// Detected communities, ordered by descending file count and then by
    /// descending cohesion; each community's `id` equals its position here.
    pub communities: Vec<Community>,
    /// Modularity score computed for the final community assignment.
    pub modularity: f64,
    /// Number of file nodes that were clustered.
    pub node_count: usize,
    /// Number of file-to-file edges loaded from the graph (one per distinct
    /// source file / target file / edge kind combination).
    pub edge_count: usize,
}
32
/// In-memory weighted graph of file nodes, indexed by dense `usize` ids.
struct AdjGraph {
    /// File path of every node; the position in this vector is the node index.
    node_ids: Vec<String>,
    /// Reverse lookup from file path to node index.
    #[allow(dead_code)]
    node_to_idx: HashMap<String, usize>,
    /// `adj[i]` lists `(target index, weight)` for every edge whose *source*
    /// is node `i`. Each edge is stored once, in its source's list only.
    adj: Vec<Vec<(usize, f64)>>,
    /// Sum of the weights of all stored edges (each edge counted once).
    total_weight: f64,
    /// `degree[i]` is the summed weight of every edge that has node `i` as
    /// its source or its target.
    degree: Vec<f64>,
}
41
42impl AdjGraph {
43    fn from_property_graph(conn: &Connection) -> Self {
44        let mut node_ids: Vec<String> = Vec::new();
45        let mut node_to_idx: HashMap<String, usize> = HashMap::new();
46
47        let mut file_stmt = conn
48            .prepare("SELECT DISTINCT file_path FROM nodes WHERE kind = 'file'")
49            .unwrap();
50        let files = file_stmt
51            .query_map([], |row| row.get::<_, String>(0))
52            .unwrap()
53            .filter_map(std::result::Result::ok)
54            .collect::<Vec<_>>();
55
56        for f in &files {
57            let idx = node_ids.len();
58            node_ids.push(f.clone());
59            node_to_idx.insert(f.clone(), idx);
60        }
61
62        let n = node_ids.len();
63        let mut adj: Vec<Vec<(usize, f64)>> = vec![Vec::new(); n];
64        let mut total_weight = 0.0;
65        let mut degree = vec![0.0; n];
66
67        let edge_sql = "
68            SELECT DISTINCT n1.file_path, n2.file_path, e.kind
69            FROM edges e
70            JOIN nodes n1 ON e.source_id = n1.id
71            JOIN nodes n2 ON e.target_id = n2.id
72            WHERE n1.kind = 'file' AND n2.kind = 'file'
73              AND n1.file_path != n2.file_path
74        ";
75        let mut edge_stmt = conn.prepare(edge_sql).unwrap();
76        let edges = edge_stmt
77            .query_map([], |row| {
78                Ok((
79                    row.get::<_, String>(0)?,
80                    row.get::<_, String>(1)?,
81                    row.get::<_, String>(2)?,
82                ))
83            })
84            .unwrap()
85            .filter_map(std::result::Result::ok)
86            .collect::<Vec<_>>();
87
88        for (from, to, kind) in &edges {
89            let Some(&i) = node_to_idx.get(from) else {
90                continue;
91            };
92            let Some(&j) = node_to_idx.get(to) else {
93                continue;
94            };
95            let w = edge_weight(kind);
96            adj[i].push((j, w));
97            degree[i] += w;
98            degree[j] += w;
99            total_weight += w;
100        }
101
102        Self {
103            node_ids,
104            node_to_idx,
105            adj,
106            total_weight,
107            degree,
108        }
109    }
110}
111
/// Returns the clustering weight assigned to an edge of the given kind.
///
/// Known kinds have fixed weights; any other kind (including the empty
/// string) falls back to `0.5`. Matching is case-sensitive.
fn edge_weight(kind: &str) -> f64 {
    // Weight table for the recognised edge kinds.
    const KNOWN_WEIGHTS: [(&str, f64); 5] = [
        ("imports", 1.0),
        ("calls", 1.5),
        ("type_ref", 0.8),
        ("defines", 0.3),
        ("exports", 0.3),
    ];
    // Weight used for every kind not listed above.
    const FALLBACK_WEIGHT: f64 = 0.5;

    KNOWN_WEIGHTS
        .iter()
        .find(|(name, _)| *name == kind)
        .map_or(FALLBACK_WEIGHT, |&(_, weight)| weight)
}
121
122pub fn detect_communities(conn: &Connection) -> CommunityResult {
123    let graph = AdjGraph::from_property_graph(conn);
124    let n = graph.node_ids.len();
125
126    if n == 0 {
127        return CommunityResult {
128            communities: Vec::new(),
129            modularity: 0.0,
130            node_count: 0,
131            edge_count: 0,
132        };
133    }
134
135    let mut community: Vec<usize> = (0..n).collect();
136    let mut changed = true;
137    let m2 = graph.total_weight.max(1.0) * 2.0;
138
139    while changed {
140        changed = false;
141        for i in 0..n {
142            let current = community[i];
143            let mut best_delta = 0.0f64;
144            let mut best_community = current;
145
146            let mut neighbor_comm_weight: HashMap<usize, f64> = HashMap::new();
147            for &(j, w) in &graph.adj[i] {
148                *neighbor_comm_weight.entry(community[j]).or_default() += w;
149            }
150
151            let ki = graph.degree[i];
152            let sigma_current = comm_sum_degree(&graph, &community, current);
153            let sigma_in_current = neighbor_comm_weight.get(&current).copied().unwrap_or(0.0);
154
155            for (&c, &ki_in) in &neighbor_comm_weight {
156                if c == current {
157                    continue;
158                }
159                let sigma_c = comm_sum_degree(&graph, &community, c);
160
161                let delta_remove = -2.0 * (sigma_in_current - ki * (sigma_current - ki) / m2) / m2;
162                let delta_add = 2.0 * (ki_in - ki * sigma_c / m2) / m2;
163                let delta = delta_add + delta_remove;
164
165                if delta > best_delta {
166                    best_delta = delta;
167                    best_community = c;
168                }
169            }
170
171            if best_community != current {
172                community[i] = best_community;
173                changed = true;
174            }
175        }
176    }
177
178    let mut comm_map: HashMap<usize, Vec<usize>> = HashMap::new();
179    for (i, &c) in community.iter().enumerate() {
180        comm_map.entry(c).or_default().push(i);
181    }
182
183    let mut communities: Vec<Community> = Vec::new();
184    for members in comm_map.values() {
185        let files: Vec<String> = members.iter().map(|&i| graph.node_ids[i].clone()).collect();
186        let member_set: std::collections::HashSet<usize> = members.iter().copied().collect();
187
188        let mut internal = 0usize;
189        let mut external = 0usize;
190        for &i in members {
191            for &(j, _) in &graph.adj[i] {
192                if member_set.contains(&j) {
193                    internal += 1;
194                } else {
195                    external += 1;
196                }
197            }
198        }
199
200        let total = (internal + external).max(1) as f64;
201        let cohesion = internal as f64 / total;
202
203        communities.push(Community {
204            id: 0,
205            files,
206            internal_edges: internal,
207            external_edges: external,
208            cohesion,
209        });
210    }
211
212    communities.sort_by(|a, b| {
213        b.files.len().cmp(&a.files.len()).then_with(|| {
214            b.cohesion
215                .partial_cmp(&a.cohesion)
216                .unwrap_or(std::cmp::Ordering::Equal)
217        })
218    });
219
220    for (new_id, c) in communities.iter_mut().enumerate() {
221        c.id = new_id;
222    }
223
224    let modularity = compute_modularity(&graph, &community);
225    let edge_count = graph.adj.iter().map(Vec::len).sum::<usize>();
226
227    CommunityResult {
228        communities,
229        modularity,
230        node_count: n,
231        edge_count,
232    }
233}
234
235fn comm_sum_degree(graph: &AdjGraph, community: &[usize], c: usize) -> f64 {
236    let mut sum = 0.0;
237    for (i, &ci) in community.iter().enumerate() {
238        if ci == c {
239            sum += graph.degree[i];
240        }
241    }
242    sum
243}
244
245fn compute_modularity(graph: &AdjGraph, community: &[usize]) -> f64 {
246    let m2 = graph.total_weight.max(1.0) * 2.0;
247    let mut q = 0.0;
248
249    for (i, neighbors) in graph.adj.iter().enumerate() {
250        for &(j, w) in neighbors {
251            if community[i] == community[j] {
252                let ki = graph.degree[i];
253                let kj = graph.degree[j];
254                q += w - (ki * kj) / m2;
255            }
256        }
257    }
258
259    q / m2
260}
261
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::property_graph::{CodeGraph, Edge, EdgeKind, Node};

    /// Builds a five-file fixture: three `src/core` files linked to each
    /// other, two `src/tools` files linked in both directions, and a single
    /// bridging import from `c.rs` to `d.rs`.
    fn build_test_graph() -> CodeGraph {
        let graph = CodeGraph::open_in_memory().unwrap();

        let core_a = graph.upsert_node(&Node::file("src/core/a.rs")).unwrap();
        let core_b = graph.upsert_node(&Node::file("src/core/b.rs")).unwrap();
        let core_c = graph.upsert_node(&Node::file("src/core/c.rs")).unwrap();
        let tools_d = graph.upsert_node(&Node::file("src/tools/d.rs")).unwrap();
        let tools_e = graph.upsert_node(&Node::file("src/tools/e.rs")).unwrap();

        let links = [
            // Core cluster.
            (core_a, core_b, EdgeKind::Imports),
            (core_b, core_c, EdgeKind::Imports),
            (core_a, core_c, EdgeKind::Calls),
            // Tools cluster.
            (tools_d, tools_e, EdgeKind::Imports),
            (tools_e, tools_d, EdgeKind::Calls),
            // Bridge between the two clusters.
            (core_c, tools_d, EdgeKind::Imports),
        ];
        for (source, target, kind) in links {
            graph
                .upsert_edge(&Edge::new(source, target, kind))
                .unwrap();
        }

        graph
    }

    /// The fixture yields a non-empty result with the expected node count.
    #[test]
    fn detects_communities() {
        let fixture = build_test_graph();
        let outcome = detect_communities(fixture.connection());

        assert!(
            !outcome.communities.is_empty(),
            "Should detect at least one community"
        );
        assert!(outcome.node_count == 5);
        assert!(outcome.edge_count > 0);
    }

    /// Modularity of the detected assignment is not negative.
    #[test]
    fn modularity_positive() {
        let fixture = build_test_graph();
        let outcome = detect_communities(fixture.connection());

        assert!(
            outcome.modularity >= 0.0,
            "Modularity should be non-negative for clustered graph"
        );
    }

    /// Every file ends up in exactly one community.
    #[test]
    fn community_files_cover_all_nodes() {
        let fixture = build_test_graph();
        let outcome = detect_communities(fixture.connection());

        let assigned: usize = outcome
            .communities
            .iter()
            .map(|community| community.files.len())
            .sum();
        assert_eq!(assigned, 5, "All 5 files should be assigned");
    }

    /// A database without nodes produces an empty result.
    #[test]
    fn empty_graph() {
        let fixture = CodeGraph::open_in_memory().unwrap();
        let outcome = detect_communities(fixture.connection());
        assert!(outcome.communities.is_empty());
        assert_eq!(outcome.modularity, 0.0);
    }
}