1use std::collections::HashMap;
8
9use anyhow::{Context, Result};
10
11use crate::graph::GraphQuery;
12use kuzu::Connection;
13
/// Summary of one clustering run, as returned by [`detect_clusters`].
#[derive(Debug, Clone, PartialEq)]
pub struct ClusterStats {
    /// Number of clusters that were produced.
    pub num_clusters: usize,
    /// Member count of every cluster; the order carries no meaning.
    pub cluster_sizes: Vec<usize>,
    /// Modularity score reported for the partition (`0.0` when the graph is
    /// empty or has no edges).
    pub modularity: f64,
}
24
25impl std::fmt::Display for ClusterStats {
26 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
27 writeln!(f, "Cluster Statistics:")?;
28 writeln!(f, " Clusters: {}", self.num_clusters)?;
29 writeln!(f, " Modularity: {:.4}", self.modularity)?;
30
31 let mut sorted_sizes = self.cluster_sizes.clone();
32 sorted_sizes.sort_unstable_by(|a, b| b.cmp(a));
33 let top: Vec<_> = sorted_sizes.iter().take(10).collect();
34 write!(f, " Top sizes: ")?;
35 for (i, size) in top.iter().enumerate() {
36 if i > 0 {
37 write!(f, ", ")?;
38 }
39 write!(f, "{}", size)?;
40 }
41 if sorted_sizes.len() > 10 {
42 write!(f, " ... ({} more)", sorted_sizes.len() - 10)?;
43 }
44 writeln!(f)
45 }
46}
47
48pub fn detect_clusters(conn: &Connection) -> Result<ClusterStats> {
54 let gq = GraphQuery::new(conn);
55
56 let edge_rows = gq.raw_query("MATCH (a:Symbol)-[:CALLS]->(b:Symbol) RETURN a.id, b.id")?;
58
59 let mut id_to_idx: HashMap<String, usize> = HashMap::new();
61 let mut idx_to_id: Vec<String> = Vec::new();
62
63 for row in &edge_rows {
64 for col in row {
65 if !id_to_idx.contains_key(col) {
66 let idx = idx_to_id.len();
67 id_to_idx.insert(col.clone(), idx);
68 idx_to_id.push(col.clone());
69 }
70 }
71 }
72
73 let all_symbols = gq.raw_query(
75 "MATCH (s:Symbol) WHERE s.kind IN ['Function', 'Method', 'Class'] RETURN s.id",
76 )?;
77 for row in &all_symbols {
78 if let Some(id) = row.first() {
79 if !id_to_idx.contains_key(id) {
80 let idx = idx_to_id.len();
81 id_to_idx.insert(id.clone(), idx);
82 idx_to_id.push(id.clone());
83 }
84 }
85 }
86
87 let n = idx_to_id.len();
88 if n == 0 {
89 return Ok(ClusterStats {
90 num_clusters: 0,
91 cluster_sizes: vec![],
92 modularity: 0.0,
93 });
94 }
95
96 let mut edge_weight: HashMap<(usize, usize), f64> = HashMap::new();
100 for row in &edge_rows {
101 let a = id_to_idx[&row[0]];
102 let b = id_to_idx[&row[1]];
103 if a == b {
104 continue; }
106 let key_ab = (a.min(b), a.max(b));
107 *edge_weight.entry(key_ab).or_insert(0.0) += 1.0;
108 }
109
110 let mut adj: Vec<Vec<(usize, f64)>> = vec![vec![]; n];
112 let mut total_weight = 0.0;
113 for (&(a, b), &w) in &edge_weight {
114 adj[a].push((b, w));
115 adj[b].push((a, w));
116 total_weight += w; }
118
119 if total_weight == 0.0 {
120 let assignments: Vec<usize> = (0..n).collect();
122 let stats = store_clusters(conn, &idx_to_id, &assignments, 0.0)?;
123 return Ok(stats);
124 }
125
126 let m = total_weight;
128 let two_m = 2.0 * m;
129
130 let mut degree: Vec<f64> = vec![0.0; n];
132 for (&(a, b), &w) in &edge_weight {
133 degree[a] += w;
134 degree[b] += w;
135 }
136
137 let mut community: Vec<usize> = (0..n).collect();
140 let mut sigma_tot: Vec<f64> = degree.clone();
142
143 let max_iterations = 20;
144 for _iter in 0..max_iterations {
145 let mut improved = false;
146
147 for node in 0..n {
148 let node_comm = community[node];
149 let k_i = degree[node];
150
151 let mut comm_weights: HashMap<usize, f64> = HashMap::new();
153 for &(neighbor, w) in &adj[node] {
154 let nc = community[neighbor];
155 *comm_weights.entry(nc).or_insert(0.0) += w;
156 }
157
158 let k_i_in = comm_weights.get(&node_comm).copied().unwrap_or(0.0);
160
161 sigma_tot[node_comm] -= k_i;
163
164 let mut best_comm = node_comm;
165 let mut best_delta = 0.0;
166
167 for (&cand_comm, &k_i_cand) in &comm_weights {
168 let gain = k_i_cand / m - sigma_tot[cand_comm] * k_i / (two_m * m);
180 let loss = k_i_in / m - sigma_tot[node_comm] * k_i / (two_m * m);
181 let delta = gain - loss;
182
183 if delta > best_delta {
184 best_delta = delta;
185 best_comm = cand_comm;
186 }
187 }
188
189 community[node] = best_comm;
191 sigma_tot[best_comm] += k_i;
192
193 if best_comm != node_comm {
194 improved = true;
195 }
196 }
197
198 if !improved {
199 break;
200 }
201 }
202
203 let modularity = compute_modularity(&community, &edge_weight, °ree, m);
205
206 let stats = store_clusters(conn, &idx_to_id, &community, modularity)?;
208 Ok(stats)
209}
210
/// Computes the modularity `Q` of a partition of an undirected weighted graph.
///
/// Uses the per-community form
/// `Q = sum_c [ L_c / m - (D_c / (2m))^2 ]`, where `L_c` is the weight of the
/// edges with both endpoints in community `c` and `D_c` is the summed degree
/// of its nodes. Summing the degree penalty per community (rather than per
/// edge) accounts for every same-community node pair, adjacent or not.
///
/// * `community` - community label of each node, indexed by node.
/// * `edge_weight` - weight of each undirected edge, one entry per node pair.
/// * `degree` - weighted degree of each node, indexed by node.
/// * `m` - total edge weight of the graph.
///
/// Returns `0.0` when `m` is zero.
///
/// # Panics
/// Panics if an edge refers to a node index outside `community`.
fn compute_modularity(
    community: &[usize],
    edge_weight: &HashMap<(usize, usize), f64>,
    degree: &[f64],
    m: f64,
) -> f64 {
    if m == 0.0 {
        return 0.0;
    }

    // Total weight of the edges that stay inside a single community.
    let mut internal_weight = 0.0;
    for (&(a, b), &w) in edge_weight {
        if community[a] == community[b] {
            internal_weight += w;
        }
    }

    // Summed degree of every community.
    let mut community_degree: HashMap<usize, f64> = HashMap::new();
    for (&comm, &k) in community.iter().zip(degree) {
        *community_degree.entry(comm).or_insert(0.0) += k;
    }
    let degree_sq_sum: f64 = community_degree.values().map(|total| total * total).sum();

    internal_weight / m - degree_sq_sum / (4.0 * m * m)
}
232
233fn store_clusters(
236 conn: &Connection,
237 idx_to_id: &[String],
238 community: &[usize],
239 modularity: f64,
240) -> Result<ClusterStats> {
241 let _ = conn.query("MATCH (s:Symbol)-[r:MEMBER_OF]->(c:Cluster) DELETE r");
243 let _ = conn.query("MATCH (c:Cluster) DELETE c");
244
245 let mut comm_members: HashMap<usize, Vec<usize>> = HashMap::new();
247 for (node, &comm) in community.iter().enumerate() {
248 comm_members.entry(comm).or_default().push(node);
249 }
250
251 let mut cluster_sizes = Vec::new();
252
253 for (cluster_idx, members) in comm_members.values().enumerate() {
254 let cluster_id = format!("cluster_{}", cluster_idx);
255 let cluster_name = format!("Cluster {}", cluster_idx);
256
257 let mut files: Vec<&str> = Vec::new();
259 for &node in members {
260 let sym_id = &idx_to_id[node];
261 if let Some((file, _)) = sym_id.rsplit_once("::") {
263 if !files.contains(&file) {
264 files.push(file);
265 }
266 }
267 }
268 files.truncate(5);
269 let description = format!(
270 "{} symbols across files: {}",
271 members.len(),
272 files.join(", ")
273 );
274
275 let create_cluster = format!(
277 "CREATE (c:Cluster {{id: '{}', name: '{}', description: '{}'}})",
278 escape(&cluster_id),
279 escape(&cluster_name),
280 escape(&description),
281 );
282 conn.query(&create_cluster)
283 .with_context(|| format!("failed to create cluster node: {}", cluster_id))?;
284
285 for &node in members {
287 let sym_id = &idx_to_id[node];
288 let create_edge = format!(
289 "MATCH (s:Symbol), (c:Cluster) WHERE s.id = '{}' AND c.id = '{}' CREATE (s)-[:MEMBER_OF]->(c)",
290 escape(sym_id),
291 escape(&cluster_id),
292 );
293 let _ = conn.query(&create_edge);
294 }
295
296 cluster_sizes.push(members.len());
297 }
298
299 Ok(ClusterStats {
300 num_clusters: cluster_sizes.len(),
301 cluster_sizes,
302 modularity,
303 })
304}
305
/// Escapes `s` for use inside a single-quoted string literal of a query.
///
/// Both the backslash and the single quote are prefixed with a backslash.
/// Escaping the backslash itself is required because it is the escape
/// character: otherwise a value containing or ending in `\` could neutralise
/// the closing quote and change the meaning of the query.
///
/// NOTE(review): binding values through a prepared statement would remove the
/// need for manual escaping - confirm what the connection API offers.
fn escape(s: &str) -> String {
    let mut escaped = String::with_capacity(s.len());
    for ch in s.chars() {
        if matches!(ch, '\\' | '\'') {
            escaped.push('\\');
        }
        escaped.push(ch);
    }
    escaped
}