lean_ctx/core/repomap/
ranking.rs1use crate::core::pagerank::{self, PageRankInput};
7use crate::core::repomap::graph::{RepoGraph, SymbolDef};
8
9#[derive(Debug, Clone)]
11pub struct RankedSymbol {
12 pub def: SymbolDef,
13 pub score: f64,
14}
15
16pub fn rank_symbols(
23 graph: &RepoGraph,
24 session_files: &[String],
25 focus_files: &[String],
26) -> Vec<RankedSymbol> {
27 let input = PageRankInput {
28 files: graph.files.clone(),
29 forward: graph.forward.clone(),
30 };
31
32 let seed_files = build_seed_files(session_files, focus_files, &graph.files);
33 let file_ranks = pagerank::compute_personalized(&input, 0.85, 30, &seed_files);
34
35 let mut ranked: Vec<RankedSymbol> = Vec::new();
36
37 for (file, symbols) in &graph.symbols_by_file {
38 let file_score = file_ranks.get(file).copied().unwrap_or(0.0);
39
40 for sym in symbols {
41 let export_boost = if sym.is_exported { 2.0 } else { 1.0 };
43 let score = file_score * export_boost;
44
45 ranked.push(RankedSymbol {
46 def: sym.clone(),
47 score,
48 });
49 }
50 }
51
52 ranked.sort_by(|a, b| {
53 b.score
54 .partial_cmp(&a.score)
55 .unwrap_or(std::cmp::Ordering::Equal)
56 });
57
58 ranked
59}
60
/// Builds the list of seed files from the session and focus files.
///
/// Entries of `session_files` are considered first, then `focus_files`. A path
/// is kept only if it is contained in `valid_files` and has not already been
/// emitted, so the result is de-duplicated and preserves first-seen order.
fn build_seed_files(
    session_files: &[String],
    focus_files: &[String],
    valid_files: &std::collections::HashSet<String>,
) -> Vec<String> {
    // Track seen paths by reference so each accepted seed is cloned only once.
    let mut seen: std::collections::HashSet<&str> = std::collections::HashSet::new();

    session_files
        .iter()
        .chain(focus_files)
        .filter(|&f| valid_files.contains(f.as_str()) && seen.insert(f.as_str()))
        .cloned()
        .collect()
}
82
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::repomap::graph::SymbolDef;
    use std::collections::{HashMap, HashSet};

    /// Builds a three-file fixture with edges `a.rs -> b.rs -> c.rs`.
    ///
    /// `a.rs` holds one private symbol, `b.rs` one exported symbol, and `c.rs`
    /// one exported plus one private symbol.
    fn test_graph() -> RepoGraph {
        let mut files = HashSet::new();
        for name in ["a.rs", "b.rs", "c.rs"] {
            files.insert(name.into());
        }

        let mut forward = HashMap::new();
        for (from, to) in [("a.rs", "b.rs"), ("b.rs", "c.rs")] {
            forward.insert(from.into(), vec![to.into()]);
        }

        let a_symbols = vec![test_sym("main", "fn", "a.rs", false)];
        let b_symbols = vec![test_sym("process", "fn", "b.rs", true)];
        let c_symbols = vec![
            test_sym("Config", "struct", "c.rs", true),
            test_sym("helper", "fn", "c.rs", false),
        ];

        let mut symbols_by_file = HashMap::new();
        symbols_by_file.insert("a.rs".into(), a_symbols);
        symbols_by_file.insert("b.rs".into(), b_symbols);
        symbols_by_file.insert("c.rs".into(), c_symbols);

        RepoGraph {
            files,
            forward,
            symbols_by_file,
        }
    }

    /// Creates a symbol spanning lines 1..=10 whose signature is `"{kind} {name}"`.
    fn test_sym(name: &str, kind: &str, file: &str, exported: bool) -> SymbolDef {
        let signature = format!("{kind} {name}");
        SymbolDef {
            name: name.into(),
            kind: kind.into(),
            file: file.into(),
            line: 1,
            end_line: 10,
            is_exported: exported,
            signature,
        }
    }

    /// Highest score among the symbols defined in `file` (0.0 if there are none).
    fn best_score_in(ranked: &[RankedSymbol], file: &str) -> f64 {
        ranked
            .iter()
            .filter(|r| r.def.file == file)
            .map(|r| r.score)
            .fold(0.0_f64, f64::max)
    }

    /// Score of the first ranked symbol called `name`; panics if it is absent.
    fn score_of(ranked: &[RankedSymbol], name: &str) -> f64 {
        ranked.iter().find(|r| r.def.name == name).unwrap().score
    }

    #[test]
    fn most_depended_file_ranks_highest() {
        let ranked = rank_symbols(&test_graph(), &[], &[]);

        let max_c = best_score_in(&ranked, "c.rs");
        let max_a = best_score_in(&ranked, "a.rs");

        assert!(
            max_c > max_a,
            "c.rs (most deps) should rank higher: c={max_c} a={max_a}"
        );
    }

    #[test]
    fn exported_symbols_get_boost() {
        let ranked = rank_symbols(&test_graph(), &[], &[]);

        let config_score = score_of(&ranked, "Config");
        let helper_score = score_of(&ranked, "helper");

        assert!(
            config_score > helper_score,
            "exported Config should rank higher than non-exported helper in same file"
        );
    }

    #[test]
    fn session_files_get_boosted() {
        let graph = test_graph();

        let unseeded = rank_symbols(&graph, &[], &[]);
        let seeded = rank_symbols(&graph, &["a.rs".into()], &[]);

        let a_no_seed = score_of(&unseeded, "main");
        let a_with_seed = score_of(&seeded, "main");

        assert!(
            a_with_seed > a_no_seed,
            "session-seeded a.rs should rank higher: {a_with_seed} vs {a_no_seed}"
        );
    }

    #[test]
    fn empty_graph_returns_empty() {
        let empty = RepoGraph {
            files: HashSet::new(),
            forward: HashMap::new(),
            symbols_by_file: HashMap::new(),
        };

        assert!(rank_symbols(&empty, &[], &[]).is_empty());
    }

    #[test]
    fn build_seed_filters_invalid_files() {
        let mut valid = HashSet::new();
        valid.insert("a.rs".into());

        let session = ["a.rs".into(), "nonexistent.rs".into()];
        let focus = ["also_missing.rs".into()];
        let seeds = build_seed_files(&session, &focus, &valid);

        assert_eq!(seeds.len(), 1, "only valid a.rs should remain");
        assert_eq!(seeds[0], "a.rs");
    }
}