Skip to main content

lean_ctx/core/
pagerank.rs

1//! PageRank computation on the Property Graph.
2//!
3//! Provides a reusable `compute` function that can be called by
4//! `ctx_architecture`, `ctx_overview`, and `ctx_fill` for importance-weighted
5//! context selection.
6
7use std::collections::{HashMap, HashSet};
8
9use rusqlite::Connection;
10
/// File-level dependency graph that [`compute`] ranks.
///
/// The two fields are independent collections; `compute` only ranks paths
/// listed in `files` and ignores edge endpoints that are missing from it.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct PageRankInput {
    /// Paths of every file taking part in the ranking.
    pub files: HashSet<String>,
    /// Outgoing edges: source file path -> paths of the files it points to.
    pub forward: HashMap<String, Vec<String>>,
}
15
16impl PageRankInput {
17    pub fn from_connection(conn: &Connection) -> Self {
18        let mut files: HashSet<String> = HashSet::new();
19        let mut forward: HashMap<String, Vec<String>> = HashMap::new();
20
21        if let Ok(mut stmt) =
22            conn.prepare("SELECT DISTINCT file_path FROM nodes WHERE kind = 'file'")
23        {
24            if let Ok(rows) = stmt.query_map([], |row| row.get::<_, String>(0)) {
25                for f in rows.flatten() {
26                    files.insert(f);
27                }
28            }
29        }
30
31        let edge_sql = "
32            SELECT DISTINCT n1.file_path, n2.file_path
33            FROM edges e
34            JOIN nodes n1 ON e.source_id = n1.id
35            JOIN nodes n2 ON e.target_id = n2.id
36            WHERE n1.kind = 'file' AND n2.kind = 'file'
37              AND n1.file_path != n2.file_path
38        ";
39        if let Ok(mut stmt) = conn.prepare(edge_sql) {
40            if let Ok(rows) = stmt.query_map([], |row| {
41                Ok((row.get::<_, String>(0)?, row.get::<_, String>(1)?))
42            }) {
43                for row in rows.flatten() {
44                    let (src, tgt) = row;
45                    forward.entry(src).or_default().push(tgt);
46                }
47            }
48        }
49
50        for deps in forward.values_mut() {
51            deps.sort();
52            deps.dedup();
53        }
54
55        Self { files, forward }
56    }
57}
58
59pub fn compute(input: &PageRankInput, damping: f64, iterations: usize) -> HashMap<String, f64> {
60    let n = input.files.len();
61    if n == 0 {
62        return HashMap::new();
63    }
64
65    let init = 1.0 / n as f64;
66    let mut rank: HashMap<String, f64> = input.files.iter().map(|f| (f.clone(), init)).collect();
67
68    for _ in 0..iterations {
69        let mut new_rank: HashMap<String, f64> = input
70            .files
71            .iter()
72            .map(|f| (f.clone(), (1.0 - damping) / n as f64))
73            .collect();
74
75        for (node, neighbors) in &input.forward {
76            if neighbors.is_empty() {
77                continue;
78            }
79            let share = rank.get(node).copied().unwrap_or(0.0) / neighbors.len() as f64;
80            for neighbor in neighbors {
81                if let Some(nr) = new_rank.get_mut(neighbor) {
82                    *nr += damping * share;
83                }
84            }
85        }
86
87        rank = new_rank;
88    }
89
90    rank
91}
92
93pub fn top_files(conn: &Connection, limit: usize) -> Vec<(String, f64)> {
94    let input = PageRankInput::from_connection(conn);
95    let ranks = compute(&input, 0.85, 30);
96    let mut sorted: Vec<(String, f64)> = ranks.into_iter().collect();
97    sorted.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
98    sorted.truncate(limit);
99    sorted
100}
101
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::property_graph::{CodeGraph, Edge, EdgeKind, Node};

    /// `c.rs` is imported by both other files while nothing imports `a.rs`,
    /// so `c.rs` must end up with the higher rank.
    #[test]
    fn pagerank_basic() {
        let graph = CodeGraph::open_in_memory().unwrap();
        let a = graph.upsert_node(&Node::file("a.rs")).unwrap();
        let b = graph.upsert_node(&Node::file("b.rs")).unwrap();
        let c = graph.upsert_node(&Node::file("c.rs")).unwrap();

        // a -> b, a -> c, b -> c
        for (from, to) in [(a, b), (a, c), (b, c)] {
            graph
                .upsert_edge(&Edge::new(from, to, EdgeKind::Imports))
                .unwrap();
        }

        let ranks = compute(&PageRankInput::from_connection(graph.connection()), 0.85, 30);
        assert_eq!(ranks.len(), 3);

        let rank_of = |path: &str| ranks.get(path).copied().unwrap_or(0.0);
        let rank_c = rank_of("c.rs");
        let rank_a = rank_of("a.rs");
        assert!(
            rank_c > rank_a,
            "c.rs should rank higher (more incoming): c={rank_c} a={rank_a}"
        );
    }

    /// `top_files` never returns more entries than the requested limit.
    #[test]
    fn top_files_limit() {
        let graph = CodeGraph::open_in_memory().unwrap();
        (0..10).for_each(|i| {
            graph.upsert_node(&Node::file(&format!("f{i}.rs"))).unwrap();
        });

        let top = top_files(graph.connection(), 3);
        assert!(top.len() <= 3);
    }

    /// A graph without any nodes produces no ranked files.
    #[test]
    fn empty_graph() {
        let graph = CodeGraph::open_in_memory().unwrap();
        let top = top_files(graph.connection(), 10);
        assert!(top.is_empty());
    }
}