Skip to main content

lean_ctx/core/
graph_features.rs

1//! Multi-layer graph descriptors inspired by GNN message passing:
2//! PageRank (global importance), local clustering, HITS hubs/authorities,
3//! and weakly-connected community ids.
4
5use std::collections::{HashMap, HashSet};
6
7use rusqlite::Connection;
8
9use crate::core::pagerank::{compute as pagerank_compute, PageRankInput};
10
/// Aggregated graph-derived features per file node.
///
/// One value of this type is produced for every file by
/// `compute_graph_features`; each numeric field falls back to `0.0` when the
/// corresponding algorithm returned no entry for the file.
#[derive(Debug, Clone, PartialEq)]
pub struct GraphFeatures {
    /// Score returned by the PageRank computation for this file.
    pub centrality: f64,
    /// Local clustering coefficient in the undirected file adjacency:
    /// the fraction of neighbour pairs that are themselves connected
    /// (`0.0` for files with fewer than two neighbours).
    pub clustering_coeff: f64,
    /// HITS hub score (derived from the authorities this file points at).
    pub hub_score: f64,
    /// HITS authority score (derived from the hubs pointing at this file).
    pub authority_score: f64,
    /// Identifier of the weakly-connected component the file belongs to,
    /// or `None` when no label was produced for it.
    pub community_id: Option<usize>,
}
20
/// Disjoint-set forest over the indices `0..n` with union by rank and
/// path compression.
struct UnionFind {
    /// `parent[i]` is the parent of element `i`; a root is its own parent.
    parent: Vec<usize>,
    /// Per-index rank, consulted in `union` to pick which root absorbs the other.
    rank: Vec<u8>,
}

impl UnionFind {
    /// Creates `n` singleton sets, one for each index in `0..n`.
    fn new(n: usize) -> Self {
        let parent: Vec<usize> = (0..n).collect();
        let rank = vec![0u8; n];
        Self { parent, rank }
    }

    /// Returns the representative (root) of the set containing `x` and
    /// re-points every element on the walked path directly at that root.
    ///
    /// Panics if `x` is not a valid index.
    fn find(&mut self, x: usize) -> usize {
        // First pass: climb to the root.
        let mut root = x;
        while self.parent[root] != root {
            root = self.parent[root];
        }

        // Second pass: compress the path that was just walked.
        let mut cursor = x;
        while cursor != root {
            cursor = std::mem::replace(&mut self.parent[cursor], root);
        }

        root
    }

    /// Merges the sets containing `a` and `b`; a no-op when they already
    /// share a root. The root with the smaller rank is attached beneath the
    /// other; on a tie `a`'s root wins and its rank grows by one.
    fn union(&mut self, a: usize, b: usize) {
        let root_a = self.find(a);
        let root_b = self.find(b);
        if root_a == root_b {
            return;
        }

        match self.rank[root_a].cmp(&self.rank[root_b]) {
            std::cmp::Ordering::Less => self.parent[root_a] = root_b,
            std::cmp::Ordering::Greater => self.parent[root_b] = root_a,
            std::cmp::Ordering::Equal => {
                self.parent[root_b] = root_a;
                self.rank[root_a] += 1;
            }
        }
    }
}
56
57fn file_adjacency(input: &PageRankInput) -> HashMap<String, HashSet<String>> {
58    let mut adj: HashMap<String, HashSet<String>> = HashMap::new();
59
60    for f in &input.files {
61        adj.entry(f.clone()).or_default();
62    }
63
64    for (u, outs) in &input.forward {
65        for v in outs {
66            if !input.files.contains(v) || u == v {
67                continue;
68            }
69            adj.entry(u.clone()).or_default().insert(v.clone());
70            adj.entry(v.clone()).or_default().insert(u.clone());
71        }
72    }
73
74    adj
75}
76
/// Computes the local clustering coefficient of every node in `adj`.
///
/// For a node with `k >= 2` neighbours the coefficient is the number of
/// neighbour pairs that are linked to each other divided by `k * (k - 1) / 2`.
/// Nodes with fewer than two neighbours score `0.0`.
///
/// A pair `(x, y)` counts as linked when `y` appears in `x`'s neighbour set,
/// so `adj` is expected to be symmetric (undirected).
fn clustering_coefficients(adj: &HashMap<String, HashSet<String>>) -> HashMap<String, f64> {
    adj.iter()
        .map(|(node, neighbours)| {
            let degree = neighbours.len();
            if degree < 2 {
                return (node.clone(), 0.0);
            }

            // Enumerate each unordered neighbour pair once and count the linked ones.
            let members: Vec<&String> = neighbours.iter().collect();
            let linked_pairs = members
                .iter()
                .enumerate()
                .flat_map(|(idx, first)| {
                    members[idx + 1..]
                        .iter()
                        .map(move |second| (*first, *second))
                })
                .filter(|(first, second)| {
                    adj.get(*first).is_some_and(|set| set.contains(*second))
                })
                .count();

            // degree >= 2 here, so the pair count is at least one.
            let possible_pairs = degree * (degree - 1) / 2;
            (node.clone(), linked_pairs as f64 / possible_pairs as f64)
        })
        .collect()
}
108
/// Runs the HITS hub/authority iteration over the file-level graph.
///
/// * `forward` – adjacency list mapping a source path to the paths it points at.
/// * `files` – the nodes to score; the returned maps contain exactly these keys.
/// * `iterations` – number of update rounds; with `0` every score stays at its
///   initial value of `1.0`.
///
/// Only edges whose source *and* target are members of `files` take part.
/// Sources outside `files` are dropped up front; otherwise they would start
/// with hub score 0, be inserted into the hub map during the first round and
/// then distort the normalisation and the authority scores of later rounds.
///
/// Returns `(hubs, authorities)`. After each round both score sets are
/// L1-normalised (divided by their sum, floored at `1e-12` so an edgeless
/// graph yields all zeros rather than NaN). Empty `files` yields two empty maps.
fn hits_scores(
    forward: &HashMap<String, Vec<String>>,
    files: &HashSet<String>,
    iterations: usize,
) -> (HashMap<String, f64>, HashMap<String, f64>) {
    /// Divides every score by the sum of all scores (floored at `1e-12`).
    fn normalize_l1(scores: &mut HashMap<String, f64>) {
        let total = scores.values().sum::<f64>().max(1e-12);
        for score in scores.values_mut() {
            *score /= total;
        }
    }

    if files.is_empty() {
        return (HashMap::new(), HashMap::new());
    }

    // Project the edge list onto the file set once instead of re-checking
    // membership (and cloning keys) on every iteration.
    let edges: Vec<(&String, Vec<&String>)> = forward
        .iter()
        .filter(|(u, _)| files.contains(*u))
        .map(|(u, outs)| (u, outs.iter().filter(|v| files.contains(*v)).collect()))
        .collect();

    let uniform = |value: f64| -> HashMap<String, f64> {
        files.iter().map(|f| (f.clone(), value)).collect()
    };

    let mut hubs = uniform(1.0);
    let mut authorities = uniform(1.0);

    for _ in 0..iterations {
        // Authority update: sum of the hub scores of everything pointing at a node.
        let mut new_auth = uniform(0.0);
        for (u, outs) in &edges {
            let hub_u = hubs.get(*u).copied().unwrap_or(0.0);
            for v in outs {
                if let Some(score) = new_auth.get_mut(*v) {
                    *score += hub_u;
                }
            }
        }
        normalize_l1(&mut new_auth);

        // Hub update: sum of the (fresh) authority scores a node points at.
        let mut new_hub = uniform(0.0);
        for (u, outs) in &edges {
            let pointed_at: f64 = outs
                .iter()
                .map(|v| new_auth.get(*v).copied().unwrap_or(0.0))
                .sum();
            if let Some(score) = new_hub.get_mut(*u) {
                *score += pointed_at;
            }
        }
        normalize_l1(&mut new_hub);

        hubs = new_hub;
        authorities = new_auth;
    }

    (hubs, authorities)
}
160
161fn community_labels(
162    input: &PageRankInput,
163    index_of: &HashMap<String, usize>,
164) -> HashMap<String, usize> {
165    let mut uf = UnionFind::new(index_of.len());
166
167    for (u, outs) in &input.forward {
168        let Some(&iu) = index_of.get(u) else {
169            continue;
170        };
171        for v in outs {
172            let Some(&iv) = index_of.get(v) else {
173                continue;
174            };
175            if iu != iv {
176                uf.union(iu, iv);
177            }
178        }
179    }
180
181    let mut root_map: HashMap<usize, usize> = HashMap::new();
182    let mut next_id = 0usize;
183    let mut labels: HashMap<String, usize> = HashMap::new();
184
185    let mut file_vec: Vec<&String> = input.files.iter().collect();
186    file_vec.sort();
187
188    for f in file_vec {
189        let i = index_of[f];
190        let r = uf.find(i);
191        let cid = *root_map.entry(r).or_insert_with(|| {
192            let id = next_id;
193            next_id += 1;
194            id
195        });
196        labels.insert(f.clone(), cid);
197    }
198
199    labels
200}
201
202/// Computes per-file graph features using the same file-level projection as PageRank.
203pub fn compute_graph_features(conn: &Connection) -> HashMap<String, GraphFeatures> {
204    let input = PageRankInput::from_connection(conn);
205    let files = &input.files;
206
207    if files.is_empty() {
208        return HashMap::new();
209    }
210
211    let ranks = pagerank_compute(&input, 0.85, 50);
212    let adj = file_adjacency(&input);
213    let clustering = clustering_coefficients(&adj);
214    let (hubs, authorities) = hits_scores(&input.forward, files, 40);
215
216    let mut sorted_files: Vec<String> = files.iter().cloned().collect();
217    sorted_files.sort();
218    let index_of: HashMap<String, usize> = sorted_files
219        .iter()
220        .enumerate()
221        .map(|(i, p)| (p.clone(), i))
222        .collect();
223
224    let communities = community_labels(&input, &index_of);
225
226    let mut result = HashMap::with_capacity(files.len());
227    for f in files {
228        result.insert(
229            f.clone(),
230            GraphFeatures {
231                centrality: ranks.get(f).copied().unwrap_or(0.0),
232                clustering_coeff: clustering.get(f).copied().unwrap_or(0.0),
233                hub_score: hubs.get(f).copied().unwrap_or(0.0),
234                authority_score: authorities.get(f).copied().unwrap_or(0.0),
235                community_id: communities.get(f).copied(),
236            },
237        );
238    }
239
240    result
241}
242
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::property_graph::{CodeGraph, Edge, EdgeKind, Node};

    /// Three files that all import one another form a triangle, so each
    /// vertex's neighbours are themselves connected.
    #[test]
    fn triangle_boosts_clustering() {
        let graph = CodeGraph::open_in_memory().unwrap();
        let node_a = graph.upsert_node(&Node::file("a.rs")).unwrap();
        let node_b = graph.upsert_node(&Node::file("b.rs")).unwrap();
        let node_c = graph.upsert_node(&Node::file("c.rs")).unwrap();

        for (src, dst) in [(node_a, node_b), (node_a, node_c), (node_b, node_c)] {
            graph
                .upsert_edge(&Edge::new(src, dst, EdgeKind::Imports))
                .unwrap();
        }

        let features = compute_graph_features(graph.connection());
        let feat_a = features.get("a.rs").expect("a.rs");
        let feat_b = features.get("b.rs").expect("b.rs");
        let both_high = feat_a.clustering_coeff > 0.9 && feat_b.clustering_coeff > 0.9;
        assert!(
            both_high,
            "triangle graph should have ~1 clustering: a={} b={}",
            feat_a.clustering_coeff,
            feat_b.clustering_coeff
        );
    }

    /// In an out-star the centre is the only hub and the leaves are the
    /// only authorities.
    #[test]
    fn authority_on_star() {
        let graph = CodeGraph::open_in_memory().unwrap();
        let centre = graph.upsert_node(&Node::file("hub.rs")).unwrap();
        let first_leaf = graph.upsert_node(&Node::file("leaf_a.rs")).unwrap();
        let second_leaf = graph.upsert_node(&Node::file("leaf_b.rs")).unwrap();

        let to_first = Edge::new(centre, first_leaf, EdgeKind::Imports);
        graph.upsert_edge(&to_first).unwrap();
        let to_second = Edge::new(centre, second_leaf, EdgeKind::Imports);
        graph.upsert_edge(&to_second).unwrap();

        let features = compute_graph_features(graph.connection());
        let centre_feat = features.get("hub.rs").unwrap();
        let leaf_feat = features.get("leaf_a.rs").unwrap();

        assert!(
            leaf_feat.authority_score > centre_feat.authority_score,
            "leaf should have higher authority than hub in out-star"
        );
        assert!(
            centre_feat.hub_score > leaf_feat.hub_score,
            "hub node should have larger hub score"
        );
    }

    /// A file with no edges must land in a different component than a
    /// connected pair.
    #[test]
    fn disconnected_components_differ_community() {
        let graph = CodeGraph::open_in_memory().unwrap();
        let node_x = graph.upsert_node(&Node::file("x.rs")).unwrap();
        let node_y = graph.upsert_node(&Node::file("y.rs")).unwrap();
        let _isolated = graph.upsert_node(&Node::file("z.rs")).unwrap();
        graph
            .upsert_edge(&Edge::new(node_x, node_y, EdgeKind::Imports))
            .unwrap();

        let features = compute_graph_features(graph.connection());
        assert_ne!(
            features["x.rs"].community_id, features["z.rs"].community_id,
            "isolated file should not share weak component with x-y pair"
        );
    }

    /// A graph without any nodes yields no features at all.
    #[test]
    fn empty_graph() {
        let graph = CodeGraph::open_in_memory().unwrap();
        let features = compute_graph_features(graph.connection());
        assert!(features.is_empty());
    }
}