Skip to main content

lean_ctx/core/
pagerank.rs

//! PageRank computation on the Property Graph.
//!
//! Provides a reusable `compute` function that can be called by
//! `ctx_architecture`, `ctx_overview`, and `ctx_fill` for importance-weighted
//! context selection.
6
7use std::collections::{HashMap, HashSet};
8
9use rusqlite::Connection;
10
/// File-level dependency graph consumed by the PageRank functions in this module.
///
/// `files` is the node set. `forward` maps a file path to the file paths it has
/// edges to; a file with no entry in `forward`, or an empty list, has no
/// outgoing edges.
#[derive(Debug, Clone, Default)]
pub struct PageRankInput {
    /// Every file path that participates in the ranking.
    pub files: HashSet<String>,
    /// Outgoing edges: source file path -> target file paths.
    pub forward: HashMap<String, Vec<String>>,
}
15
16impl PageRankInput {
17    pub fn from_connection(conn: &Connection) -> Self {
18        let mut files: HashSet<String> = HashSet::new();
19        let mut forward: HashMap<String, Vec<String>> = HashMap::new();
20
21        if let Ok(mut stmt) =
22            conn.prepare("SELECT DISTINCT file_path FROM nodes WHERE kind = 'file'")
23        {
24            if let Ok(rows) = stmt.query_map([], |row| row.get::<_, String>(0)) {
25                for f in rows.flatten() {
26                    files.insert(f);
27                }
28            }
29        }
30
31        let edge_sql = "
32            SELECT DISTINCT n1.file_path, n2.file_path
33            FROM edges e
34            JOIN nodes n1 ON e.source_id = n1.id
35            JOIN nodes n2 ON e.target_id = n2.id
36            WHERE n1.kind = 'file' AND n2.kind = 'file'
37              AND n1.file_path != n2.file_path
38        ";
39        if let Ok(mut stmt) = conn.prepare(edge_sql) {
40            if let Ok(rows) = stmt.query_map([], |row| {
41                Ok((row.get::<_, String>(0)?, row.get::<_, String>(1)?))
42            }) {
43                for row in rows.flatten() {
44                    let (src, tgt) = row;
45                    forward.entry(src).or_default().push(tgt);
46                }
47            }
48        }
49
50        for deps in forward.values_mut() {
51            deps.sort();
52            deps.dedup();
53        }
54
55        Self { files, forward }
56    }
57}
58
59pub fn compute(input: &PageRankInput, damping: f64, iterations: usize) -> HashMap<String, f64> {
60    compute_personalized(input, damping, iterations, &[])
61}
62
63/// Personalized PageRank: if `seed_files` is non-empty, teleportation bias goes
64/// to those files instead of uniform distribution. Handles dangling nodes
65/// (nodes with no outgoing edges) by redistributing their rank.
66pub fn compute_personalized(
67    input: &PageRankInput,
68    damping: f64,
69    iterations: usize,
70    seed_files: &[String],
71) -> HashMap<String, f64> {
72    let n = input.files.len();
73    if n == 0 {
74        return HashMap::new();
75    }
76
77    let personalization: HashMap<String, f64> = if seed_files.is_empty() {
78        let uniform = 1.0 / n as f64;
79        input.files.iter().map(|f| (f.clone(), uniform)).collect()
80    } else {
81        let valid_seeds: Vec<&String> = seed_files
82            .iter()
83            .filter(|f| input.files.contains(*f))
84            .collect();
85        if valid_seeds.is_empty() {
86            let uniform = 1.0 / n as f64;
87            input.files.iter().map(|f| (f.clone(), uniform)).collect()
88        } else {
89            let weight = 1.0 / valid_seeds.len() as f64;
90            let mut p = HashMap::new();
91            for f in &valid_seeds {
92                p.insert((*f).clone(), weight);
93            }
94            p
95        }
96    };
97
98    let dangling: HashSet<&String> = input
99        .files
100        .iter()
101        .filter(|f| !input.forward.contains_key(*f) || input.forward[*f].is_empty())
102        .collect();
103
104    let init = 1.0 / n as f64;
105    let mut rank: HashMap<String, f64> = input.files.iter().map(|f| (f.clone(), init)).collect();
106
107    let eps = 1e-8;
108    for _ in 0..iterations {
109        let dangling_sum: f64 = dangling
110            .iter()
111            .map(|f| rank.get(*f).copied().unwrap_or(0.0))
112            .sum();
113
114        let mut new_rank: HashMap<String, f64> = HashMap::with_capacity(n);
115
116        for f in &input.files {
117            let teleport = personalization.get(f).copied().unwrap_or(0.0);
118            let dangling_contrib = personalization.get(f).copied().unwrap_or(0.0) * dangling_sum;
119            new_rank.insert(
120                f.clone(),
121                (1.0 - damping) * teleport + damping * dangling_contrib,
122            );
123        }
124
125        for (node, neighbors) in &input.forward {
126            if neighbors.is_empty() {
127                continue;
128            }
129            let share = rank.get(node).copied().unwrap_or(0.0) / neighbors.len() as f64;
130            for neighbor in neighbors {
131                if let Some(nr) = new_rank.get_mut(neighbor) {
132                    *nr += damping * share;
133                }
134            }
135        }
136
137        let diff: f64 = input
138            .files
139            .iter()
140            .map(|f| {
141                (rank.get(f).copied().unwrap_or(0.0) - new_rank.get(f).copied().unwrap_or(0.0))
142                    .abs()
143            })
144            .sum();
145        rank = new_rank;
146
147        if diff < eps {
148            break;
149        }
150    }
151
152    rank
153}
154
155pub fn top_files(conn: &Connection, limit: usize) -> Vec<(String, f64)> {
156    top_files_personalized(conn, limit, &[])
157}
158
159pub fn top_files_personalized(
160    conn: &Connection,
161    limit: usize,
162    seed_files: &[String],
163) -> Vec<(String, f64)> {
164    let input = PageRankInput::from_connection(conn);
165    let ranks = compute_personalized(&input, 0.85, 50, seed_files);
166    let mut sorted: Vec<(String, f64)> = ranks.into_iter().collect();
167    sorted.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
168    sorted.truncate(limit);
169    sorted
170}
171
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::property_graph::{CodeGraph, Edge, EdgeKind, Node};

    /// Builds an in-memory graph with one file node per entry of `paths` and an
    /// `Imports` edge for each `(from, to)` pair of indices into `paths`.
    fn graph_with(paths: &[&str], imports: &[(usize, usize)]) -> CodeGraph {
        let graph = CodeGraph::open_in_memory().unwrap();
        let ids: Vec<_> = paths
            .iter()
            .map(|&path| graph.upsert_node(&Node::file(path)).unwrap())
            .collect();
        for &(from, to) in imports {
            graph
                .upsert_edge(&Edge::new(ids[from], ids[to], EdgeKind::Imports))
                .unwrap();
        }
        graph
    }

    /// Rank of `path`, or `0.0` when it is absent from `ranks`.
    fn rank_of(ranks: &HashMap<String, f64>, path: &str) -> f64 {
        ranks.get(path).copied().unwrap_or(0.0)
    }

    #[test]
    fn pagerank_basic() {
        let graph = graph_with(&["a.rs", "b.rs", "c.rs"], &[(0, 1), (0, 2), (1, 2)]);
        let ranks = compute(&PageRankInput::from_connection(graph.connection()), 0.85, 30);

        assert_eq!(ranks.len(), 3);
        let (rank_a, rank_c) = (rank_of(&ranks, "a.rs"), rank_of(&ranks, "c.rs"));
        assert!(
            rank_c > rank_a,
            "c.rs should rank higher (more incoming): c={rank_c} a={rank_a}"
        );
    }

    #[test]
    fn top_files_limit() {
        let paths: Vec<String> = (0..10).map(|i| format!("f{i}.rs")).collect();
        let path_refs: Vec<&str> = paths.iter().map(String::as_str).collect();
        let graph = graph_with(&path_refs, &[]);
        assert!(top_files(graph.connection(), 3).len() <= 3);
    }

    #[test]
    fn empty_graph() {
        let graph = graph_with(&[], &[]);
        assert!(top_files(graph.connection(), 10).is_empty());
    }

    #[test]
    fn personalized_pagerank_boosts_seed() {
        let graph = graph_with(&["a.rs", "b.rs", "c.rs"], &[(0, 1), (1, 2)]);
        let input = PageRankInput::from_connection(graph.connection());
        let seeds = ["a.rs".to_string()];

        let a_uniform = rank_of(&compute_personalized(&input, 0.85, 50, &[]), "a.rs");
        let a_seeded = rank_of(&compute_personalized(&input, 0.85, 50, &seeds), "a.rs");

        assert!(
            a_seeded > a_uniform,
            "seeded a.rs ({a_seeded}) should rank higher than uniform ({a_uniform})"
        );
    }

    #[test]
    fn early_convergence() {
        let graph = graph_with(&["a.rs", "b.rs"], &[(0, 1), (1, 0)]);
        let input = PageRankInput::from_connection(graph.connection());
        assert_eq!(compute(&input, 0.85, 1000).len(), 2);
    }
}