Skip to main content

infigraph_core/cluster/
mod.rs

1//! Louvain community detection for discovering functional modules in the code graph.
2//!
3//! Builds an undirected weighted graph from CALLS edges, then runs single-level
4//! Louvain modularity optimization. Results are stored as Cluster nodes and
5//! MEMBER_OF edges in the graph.
6
7use std::collections::HashMap;
8
9use anyhow::{Context, Result};
10
11use crate::graph::GraphQuery;
12use kuzu::Connection;
13
/// Summary of a clustering run, as returned by [`detect_clusters`].
#[derive(Debug)]
pub struct ClusterStats {
    /// How many clusters (communities) the run produced.
    pub num_clusters: usize,
    /// Member count of each cluster, one entry per cluster.
    pub cluster_sizes: Vec<usize>,
    /// Modularity score of the final partition.
    pub modularity: f64,
}

impl std::fmt::Display for ClusterStats {
    /// Writes a short multi-line report: cluster count, modularity (4 decimals),
    /// and the ten largest cluster sizes, followed by a count of any not shown.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        const SHOWN: usize = 10;

        writeln!(f, "Cluster Statistics:")?;
        writeln!(f, "  Clusters:    {}", self.num_clusters)?;
        writeln!(f, "  Modularity:  {:.4}", self.modularity)?;

        // Largest clusters first.
        let mut by_size_desc = self.cluster_sizes.clone();
        by_size_desc.sort_unstable_by_key(|&size| std::cmp::Reverse(size));

        let listed = by_size_desc
            .iter()
            .take(SHOWN)
            .map(ToString::to_string)
            .collect::<Vec<_>>()
            .join(", ");
        write!(f, "  Top sizes:   {}", listed)?;

        let hidden = by_size_desc.len().saturating_sub(SHOWN);
        if hidden > 0 {
            write!(f, " ... ({} more)", hidden)?;
        }
        writeln!(f)
    }
}
47
48/// Run Louvain community detection on the code graph and store results.
49///
50/// 1. Queries all CALLS edges to build an undirected adjacency list.
51/// 2. Runs single-level Louvain (iterative modularity optimization).
52/// 3. Creates Cluster nodes and MEMBER_OF edges in the graph.
53pub fn detect_clusters(conn: &Connection) -> Result<ClusterStats> {
54    let gq = GraphQuery::new(conn);
55
56    // Step 1: Fetch all CALLS edges as (source_id, target_id) pairs.
57    let edge_rows = gq.raw_query("MATCH (a:Symbol)-[:CALLS]->(b:Symbol) RETURN a.id, b.id")?;
58
59    // Build node index: map symbol ID -> dense integer index.
60    let mut id_to_idx: HashMap<String, usize> = HashMap::new();
61    let mut idx_to_id: Vec<String> = Vec::new();
62
63    for row in &edge_rows {
64        for col in row {
65            if !id_to_idx.contains_key(col) {
66                let idx = idx_to_id.len();
67                id_to_idx.insert(col.clone(), idx);
68                idx_to_id.push(col.clone());
69            }
70        }
71    }
72
73    // Also include isolated symbols (no CALLS edges) so they appear in their own clusters.
74    let all_symbols = gq.raw_query(
75        "MATCH (s:Symbol) WHERE s.kind IN ['Function', 'Method', 'Class'] RETURN s.id",
76    )?;
77    for row in &all_symbols {
78        if let Some(id) = row.first() {
79            if !id_to_idx.contains_key(id) {
80                let idx = idx_to_id.len();
81                id_to_idx.insert(id.clone(), idx);
82                idx_to_id.push(id.clone());
83            }
84        }
85    }
86
87    let n = idx_to_id.len();
88    if n == 0 {
89        return Ok(ClusterStats {
90            num_clusters: 0,
91            cluster_sizes: vec![],
92            modularity: 0.0,
93        });
94    }
95
96    // Build undirected weighted adjacency: adj[node] = Vec<(neighbor, weight)>.
97    // For an undirected graph, each CALLS edge contributes weight 1 in both directions.
98    // Multiple edges between the same pair accumulate weight.
99    let mut edge_weight: HashMap<(usize, usize), f64> = HashMap::new();
100    for row in &edge_rows {
101        let a = id_to_idx[&row[0]];
102        let b = id_to_idx[&row[1]];
103        if a == b {
104            continue; // skip self-loops for community detection
105        }
106        let key_ab = (a.min(b), a.max(b));
107        *edge_weight.entry(key_ab).or_insert(0.0) += 1.0;
108    }
109
110    // Build adjacency list from the edge weights.
111    let mut adj: Vec<Vec<(usize, f64)>> = vec![vec![]; n];
112    let mut total_weight = 0.0;
113    for (&(a, b), &w) in &edge_weight {
114        adj[a].push((b, w));
115        adj[b].push((a, w));
116        total_weight += w; // each undirected edge counted once
117    }
118
119    if total_weight == 0.0 {
120        // No edges: each node is its own cluster.
121        let assignments: Vec<usize> = (0..n).collect();
122        let stats = store_clusters(conn, &idx_to_id, &assignments, 0.0)?;
123        return Ok(stats);
124    }
125
126    // m = sum of all edge weights (undirected). In modularity formula, 2m is the denominator.
127    let m = total_weight;
128    let two_m = 2.0 * m;
129
130    // Degree of each node (sum of weights of incident edges).
131    let mut degree: Vec<f64> = vec![0.0; n];
132    for (&(a, b), &w) in &edge_weight {
133        degree[a] += w;
134        degree[b] += w;
135    }
136
137    // Step 2: Louvain single-level optimization.
138    // community[i] = community label for node i
139    let mut community: Vec<usize> = (0..n).collect();
140    // Sum of degrees in each community.
141    let mut sigma_tot: Vec<f64> = degree.clone();
142
143    let max_iterations = 20;
144    for _iter in 0..max_iterations {
145        let mut improved = false;
146
147        for node in 0..n {
148            let node_comm = community[node];
149            let k_i = degree[node];
150
151            // Compute sum of weights from node to each neighboring community.
152            let mut comm_weights: HashMap<usize, f64> = HashMap::new();
153            for &(neighbor, w) in &adj[node] {
154                let nc = community[neighbor];
155                *comm_weights.entry(nc).or_insert(0.0) += w;
156            }
157
158            // Compute weight from node to its own community.
159            let k_i_in = comm_weights.get(&node_comm).copied().unwrap_or(0.0);
160
161            // Remove node from its community.
162            sigma_tot[node_comm] -= k_i;
163
164            let mut best_comm = node_comm;
165            let mut best_delta = 0.0;
166
167            for (&cand_comm, &k_i_cand) in &comm_weights {
168                // Delta modularity for moving node to cand_comm:
169                // delta_Q = [k_i_cand / m - sigma_tot[cand_comm] * k_i / (2 * m^2)]
170                // compared to keeping node in its own singleton:
171                // We use the standard Louvain delta formula:
172                // delta_Q = (k_i_cand - k_i_in) / m
173                //         - k_i * (sigma_tot[cand_comm] - sigma_tot[node_comm]) / (2 * m * m)
174                // But since we already removed node from node_comm, sigma_tot[node_comm] is updated.
175                // Simplified formula (after removing from current):
176                // gain = k_i_cand / m - sigma_tot[cand_comm] * k_i / (two_m * m)
177                // loss = k_i_in / m - sigma_tot[node_comm] * k_i / (two_m * m)
178                // delta = gain - loss
179                let gain = k_i_cand / m - sigma_tot[cand_comm] * k_i / (two_m * m);
180                let loss = k_i_in / m - sigma_tot[node_comm] * k_i / (two_m * m);
181                let delta = gain - loss;
182
183                if delta > best_delta {
184                    best_delta = delta;
185                    best_comm = cand_comm;
186                }
187            }
188
189            // Move node to best community.
190            community[node] = best_comm;
191            sigma_tot[best_comm] += k_i;
192
193            if best_comm != node_comm {
194                improved = true;
195            }
196        }
197
198        if !improved {
199            break;
200        }
201    }
202
203    // Compute final modularity.
204    let modularity = compute_modularity(&community, &edge_weight, &degree, m);
205
206    // Step 3: Store results in the graph.
207    let stats = store_clusters(conn, &idx_to_id, &community, modularity)?;
208    Ok(stats)
209}
210
/// Compute Newman modularity Q of the given partition of an undirected graph.
///
/// Uses the per-community form `Q = Σ_c [ in_c / m − (tot_c / 2m)² ]`, where
/// `in_c` is the total weight of edges with both endpoints in `c` (each
/// undirected edge counted once) and `tot_c` is the sum of the degrees of `c`'s
/// nodes. Unlike summing only over existing edges, this includes the expected-weight
/// penalty for every same-community node pair, adjacent or not.
///
/// * `community` — community label for each node.
/// * `edge_weight` — undirected edge weights keyed by node-index pair.
/// * `degree` — weighted degree for each node (same length as `community`).
/// * `m` — total undirected edge weight; returns 0.0 when it is 0.
fn compute_modularity(
    community: &[usize],
    edge_weight: &HashMap<(usize, usize), f64>,
    degree: &[f64],
    m: f64,
) -> f64 {
    if m == 0.0 {
        return 0.0;
    }
    let two_m = 2.0 * m;

    // community -> (internal edge weight, total degree). Ordered map so the final
    // floating-point sum is taken in a deterministic order.
    let mut per_comm: std::collections::BTreeMap<usize, (f64, f64)> =
        std::collections::BTreeMap::new();

    for (&label, &k) in community.iter().zip(degree) {
        per_comm.entry(label).or_insert((0.0, 0.0)).1 += k;
    }
    for (&(a, b), &w) in edge_weight {
        if community[a] == community[b] {
            per_comm.entry(community[a]).or_insert((0.0, 0.0)).0 += w;
        }
    }

    per_comm
        .values()
        .map(|&(internal, total)| internal / m - (total / two_m).powi(2))
        .sum()
}
232
233/// Store cluster results into the graph: create Cluster nodes and MEMBER_OF edges.
234/// Clears any existing Cluster/MEMBER_OF data first.
235fn store_clusters(
236    conn: &Connection,
237    idx_to_id: &[String],
238    community: &[usize],
239    modularity: f64,
240) -> Result<ClusterStats> {
241    // Clear existing cluster data.
242    let _ = conn.query("MATCH (s:Symbol)-[r:MEMBER_OF]->(c:Cluster) DELETE r");
243    let _ = conn.query("MATCH (c:Cluster) DELETE c");
244
245    // Build community -> members map, renumbering communities to be contiguous.
246    let mut comm_members: HashMap<usize, Vec<usize>> = HashMap::new();
247    for (node, &comm) in community.iter().enumerate() {
248        comm_members.entry(comm).or_default().push(node);
249    }
250
251    let mut cluster_sizes = Vec::new();
252
253    for (cluster_idx, members) in comm_members.values().enumerate() {
254        let cluster_id = format!("cluster_{}", cluster_idx);
255        let cluster_name = format!("Cluster {}", cluster_idx);
256
257        // Gather file names for description.
258        let mut files: Vec<&str> = Vec::new();
259        for &node in members {
260            let sym_id = &idx_to_id[node];
261            // Extract file part from symbol ID (format: "file::name").
262            if let Some((file, _)) = sym_id.rsplit_once("::") {
263                if !files.contains(&file) {
264                    files.push(file);
265                }
266            }
267        }
268        files.truncate(5);
269        let description = format!(
270            "{} symbols across files: {}",
271            members.len(),
272            files.join(", ")
273        );
274
275        // Create cluster node.
276        let create_cluster = format!(
277            "CREATE (c:Cluster {{id: '{}', name: '{}', description: '{}'}})",
278            escape(&cluster_id),
279            escape(&cluster_name),
280            escape(&description),
281        );
282        conn.query(&create_cluster)
283            .with_context(|| format!("failed to create cluster node: {}", cluster_id))?;
284
285        // Create MEMBER_OF edges.
286        for &node in members {
287            let sym_id = &idx_to_id[node];
288            let create_edge = format!(
289                "MATCH (s:Symbol), (c:Cluster) WHERE s.id = '{}' AND c.id = '{}' CREATE (s)-[:MEMBER_OF]->(c)",
290                escape(sym_id),
291                escape(&cluster_id),
292            );
293            let _ = conn.query(&create_edge);
294        }
295
296        cluster_sizes.push(members.len());
297    }
298
299    Ok(ClusterStats {
300        num_clusters: cluster_sizes.len(),
301        cluster_sizes,
302        modularity,
303    })
304}
305
/// Escape `s` for embedding inside a single-quoted Cypher string literal.
///
/// Backslashes are escaped as well as single quotes. Otherwise a trailing `\`
/// would escape the closing quote, and a path like `C:\new\a.rs` would be read
/// as containing escape sequences such as `\n`.
fn escape(s: &str) -> String {
    let mut out = String::with_capacity(s.len());
    for ch in s.chars() {
        match ch {
            '\\' => out.push_str("\\\\"),
            '\'' => out.push_str("\\'"),
            _ => out.push(ch),
        }
    }
    out
}