lean_ctx/core/
pagerank.rs1use std::collections::{HashMap, HashSet};
8
9use rusqlite::Connection;
10
/// File-level directed graph consumed by the PageRank functions in this module.
///
/// Usually loaded from the property-graph database with
/// [`PageRankInput::from_connection`], but the fields are public so an input can
/// also be assembled by hand (start from [`Default::default`]).
#[derive(Debug, Clone, Default)]
pub struct PageRankInput {
    /// Every file that receives a score.
    pub files: HashSet<String>,
    /// Outgoing adjacency: source file path -> file paths it has edges to.
    pub forward: HashMap<String, Vec<String>>,
}
15
16impl PageRankInput {
17 pub fn from_connection(conn: &Connection) -> Self {
18 let mut files: HashSet<String> = HashSet::new();
19 let mut forward: HashMap<String, Vec<String>> = HashMap::new();
20
21 if let Ok(mut stmt) =
22 conn.prepare("SELECT DISTINCT file_path FROM nodes WHERE kind = 'file'")
23 {
24 if let Ok(rows) = stmt.query_map([], |row| row.get::<_, String>(0)) {
25 for f in rows.flatten() {
26 files.insert(f);
27 }
28 }
29 }
30
31 let edge_sql = "
32 SELECT DISTINCT n1.file_path, n2.file_path
33 FROM edges e
34 JOIN nodes n1 ON e.source_id = n1.id
35 JOIN nodes n2 ON e.target_id = n2.id
36 WHERE n1.kind = 'file' AND n2.kind = 'file'
37 AND n1.file_path != n2.file_path
38 ";
39 if let Ok(mut stmt) = conn.prepare(edge_sql) {
40 if let Ok(rows) = stmt.query_map([], |row| {
41 Ok((row.get::<_, String>(0)?, row.get::<_, String>(1)?))
42 }) {
43 for row in rows.flatten() {
44 let (src, tgt) = row;
45 forward.entry(src).or_default().push(tgt);
46 }
47 }
48 }
49
50 for deps in forward.values_mut() {
51 deps.sort();
52 deps.dedup();
53 }
54
55 Self { files, forward }
56 }
57}
58
59pub fn compute(input: &PageRankInput, damping: f64, iterations: usize) -> HashMap<String, f64> {
60 compute_personalized(input, damping, iterations, &[])
61}
62
63pub fn compute_personalized(
67 input: &PageRankInput,
68 damping: f64,
69 iterations: usize,
70 seed_files: &[String],
71) -> HashMap<String, f64> {
72 let n = input.files.len();
73 if n == 0 {
74 return HashMap::new();
75 }
76
77 let personalization: HashMap<String, f64> = if seed_files.is_empty() {
78 let uniform = 1.0 / n as f64;
79 input.files.iter().map(|f| (f.clone(), uniform)).collect()
80 } else {
81 let valid_seeds: Vec<&String> = seed_files
82 .iter()
83 .filter(|f| input.files.contains(*f))
84 .collect();
85 if valid_seeds.is_empty() {
86 let uniform = 1.0 / n as f64;
87 input.files.iter().map(|f| (f.clone(), uniform)).collect()
88 } else {
89 let weight = 1.0 / valid_seeds.len() as f64;
90 let mut p = HashMap::new();
91 for f in &valid_seeds {
92 p.insert((*f).clone(), weight);
93 }
94 p
95 }
96 };
97
98 let dangling: HashSet<&String> = input
99 .files
100 .iter()
101 .filter(|f| !input.forward.contains_key(*f) || input.forward[*f].is_empty())
102 .collect();
103
104 let init = 1.0 / n as f64;
105 let mut rank: HashMap<String, f64> = input.files.iter().map(|f| (f.clone(), init)).collect();
106
107 let eps = 1e-8;
108 for _ in 0..iterations {
109 let dangling_sum: f64 = dangling
110 .iter()
111 .map(|f| rank.get(*f).copied().unwrap_or(0.0))
112 .sum();
113
114 let mut new_rank: HashMap<String, f64> = HashMap::with_capacity(n);
115
116 for f in &input.files {
117 let teleport = personalization.get(f).copied().unwrap_or(0.0);
118 let dangling_contrib = personalization.get(f).copied().unwrap_or(0.0) * dangling_sum;
119 new_rank.insert(
120 f.clone(),
121 (1.0 - damping) * teleport + damping * dangling_contrib,
122 );
123 }
124
125 for (node, neighbors) in &input.forward {
126 if neighbors.is_empty() {
127 continue;
128 }
129 let share = rank.get(node).copied().unwrap_or(0.0) / neighbors.len() as f64;
130 for neighbor in neighbors {
131 if let Some(nr) = new_rank.get_mut(neighbor) {
132 *nr += damping * share;
133 }
134 }
135 }
136
137 let diff: f64 = input
138 .files
139 .iter()
140 .map(|f| {
141 (rank.get(f).copied().unwrap_or(0.0) - new_rank.get(f).copied().unwrap_or(0.0))
142 .abs()
143 })
144 .sum();
145 rank = new_rank;
146
147 if diff < eps {
148 break;
149 }
150 }
151
152 rank
153}
154
155pub fn top_files(conn: &Connection, limit: usize) -> Vec<(String, f64)> {
156 top_files_personalized(conn, limit, &[])
157}
158
159pub fn top_files_personalized(
160 conn: &Connection,
161 limit: usize,
162 seed_files: &[String],
163) -> Vec<(String, f64)> {
164 let input = PageRankInput::from_connection(conn);
165 let ranks = compute_personalized(&input, 0.85, 50, seed_files);
166 let mut sorted: Vec<(String, f64)> = ranks.into_iter().collect();
167 sorted.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
168 sorted.truncate(limit);
169 sorted
170}
171
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::property_graph::{CodeGraph, Edge, EdgeKind, Node};

    /// Tolerance for comparing floating-point rank values.
    const TOL: f64 = 1e-9;

    #[test]
    fn pagerank_basic() {
        let g = CodeGraph::open_in_memory().unwrap();
        let a = g.upsert_node(&Node::file("a.rs")).unwrap();
        let b = g.upsert_node(&Node::file("b.rs")).unwrap();
        let c = g.upsert_node(&Node::file("c.rs")).unwrap();

        g.upsert_edge(&Edge::new(a, b, EdgeKind::Imports)).unwrap();
        g.upsert_edge(&Edge::new(a, c, EdgeKind::Imports)).unwrap();
        g.upsert_edge(&Edge::new(b, c, EdgeKind::Imports)).unwrap();

        let input = PageRankInput::from_connection(g.connection());
        let ranks = compute(&input, 0.85, 30);

        assert_eq!(ranks.len(), 3);
        let rank_c = ranks.get("c.rs").copied().unwrap_or(0.0);
        let rank_a = ranks.get("a.rs").copied().unwrap_or(0.0);
        assert!(
            rank_c > rank_a,
            "c.rs should rank higher (more incoming): c={rank_c} a={rank_a}"
        );
    }

    #[test]
    fn ranks_sum_to_one() {
        let g = CodeGraph::open_in_memory().unwrap();
        let a = g.upsert_node(&Node::file("a.rs")).unwrap();
        let b = g.upsert_node(&Node::file("b.rs")).unwrap();
        let c = g.upsert_node(&Node::file("c.rs")).unwrap();

        // c.rs has no outgoing edges, so this also exercises dangling redistribution.
        g.upsert_edge(&Edge::new(a, b, EdgeKind::Imports)).unwrap();
        g.upsert_edge(&Edge::new(b, c, EdgeKind::Imports)).unwrap();

        let input = PageRankInput::from_connection(g.connection());
        let ranks = compute(&input, 0.85, 30);
        let total: f64 = ranks.values().sum();
        assert!(
            (total - 1.0).abs() < TOL,
            "ranks should form a distribution, got total={total}"
        );
    }

    #[test]
    fn top_files_limit() {
        let g = CodeGraph::open_in_memory().unwrap();
        for i in 0..10 {
            g.upsert_node(&Node::file(&format!("f{i}.rs"))).unwrap();
        }
        let top = top_files(g.connection(), 3);
        assert_eq!(top.len(), 3, "ten files with limit 3 must yield exactly three");
        assert!(
            top.windows(2).all(|w| w[0].1 >= w[1].1),
            "results must be sorted by descending rank: {top:?}"
        );
    }

    #[test]
    fn empty_graph() {
        let g = CodeGraph::open_in_memory().unwrap();
        let top = top_files(g.connection(), 10);
        assert!(top.is_empty());
    }

    #[test]
    fn personalized_pagerank_boosts_seed() {
        let g = CodeGraph::open_in_memory().unwrap();
        let a = g.upsert_node(&Node::file("a.rs")).unwrap();
        let b = g.upsert_node(&Node::file("b.rs")).unwrap();
        let c = g.upsert_node(&Node::file("c.rs")).unwrap();

        g.upsert_edge(&Edge::new(a, b, EdgeKind::Imports)).unwrap();
        g.upsert_edge(&Edge::new(b, c, EdgeKind::Imports)).unwrap();

        let input = PageRankInput::from_connection(g.connection());

        let uniform = compute_personalized(&input, 0.85, 50, &[]);
        let seeded = compute_personalized(&input, 0.85, 50, &["a.rs".to_string()]);

        let a_uniform = uniform.get("a.rs").copied().unwrap_or(0.0);
        let a_seeded = seeded.get("a.rs").copied().unwrap_or(0.0);

        assert!(
            a_seeded > a_uniform,
            "seeded a.rs ({a_seeded}) should rank higher than uniform ({a_uniform})"
        );
    }

    #[test]
    fn early_convergence() {
        let g = CodeGraph::open_in_memory().unwrap();
        let a = g.upsert_node(&Node::file("a.rs")).unwrap();
        let b = g.upsert_node(&Node::file("b.rs")).unwrap();
        g.upsert_edge(&Edge::new(a, b, EdgeKind::Imports)).unwrap();
        g.upsert_edge(&Edge::new(b, a, EdgeKind::Imports)).unwrap();

        let input = PageRankInput::from_connection(g.connection());
        let ranks = compute(&input, 0.85, 1000);
        assert_eq!(ranks.len(), 2);
        // A symmetric two-cycle is already stationary at the uniform start.
        for f in ["a.rs", "b.rs"] {
            let r = ranks[f];
            assert!((r - 0.5).abs() < TOL, "{f} should hold half the mass, got {r}");
        }
    }
}