1use std::collections::HashMap;
2use std::path::PathBuf;
3
4use petgraph::graph::{DiGraph, NodeIndex};
5
6use crate::parser::{Reference, Symbol};
7
/// A node of the symbol graph: one parsed symbol together with its rank score.
#[derive(Debug, Clone)]
pub struct SymbolNode {
    /// The parsed symbol this node stands for.
    pub symbol: Symbol,
    /// Rank score of the symbol. `SymbolGraph::build` initializes it to `0.0`;
    /// `SymbolGraph::compute_pagerank` overwrites it with the computed value.
    pub rank: f64,
}
37
/// Directed graph of symbols in which an edge `a -> b` records that symbol `a`
/// references symbol `b`. Edges carry no weight (`()`).
pub struct SymbolGraph {
    /// Underlying petgraph storage; node weights hold the symbol and its rank.
    graph: DiGraph<SymbolNode, ()>,
    // The map is filled while building the graph but is not read afterwards in
    // this file, hence the lint suppression.
    // NOTE(review): presumably kept for future name-based lookups — confirm or remove.
    #[allow(dead_code)]
    name_to_index: HashMap<String, NodeIndex>,
}
83
84impl SymbolGraph {
85 pub fn build(symbols: Vec<Symbol>, references: Vec<Reference>) -> Self {
91 let mut graph = DiGraph::new();
92 let mut name_to_index: HashMap<String, NodeIndex> = HashMap::new();
93
94 for symbol in symbols {
95 let name = symbol.name.clone();
96 let idx = graph.add_node(SymbolNode { symbol, rank: 0.0 });
97 name_to_index.entry(name).or_insert(idx);
99 }
100
101 for reference in &references {
102 let Some(&to_idx) = name_to_index.get(&reference.to_name) else {
103 continue; };
105
106 let from_idx = reference
108 .from_symbol
109 .as_ref()
110 .and_then(|name| name_to_index.get(name).copied());
111
112 let Some(from_idx) = from_idx else {
113 continue;
114 };
115
116 if from_idx == to_idx {
118 continue;
119 }
120
121 graph.add_edge(from_idx, to_idx, ());
122 }
123
124 Self {
125 graph,
126 name_to_index,
127 }
128 }
129
130 pub fn compute_pagerank(&mut self) {
132 let n = self.graph.node_count();
133 if n == 0 {
134 return;
135 }
136
137 let d: f64 = 0.85;
138 let n_f64 = n as f64;
139 let base = (1.0 - d) / n_f64;
140
141 let mut ranks = vec![1.0 / n_f64; n];
143
144 for _ in 0..20 {
145 let mut new_ranks = vec![base; n];
146
147 for node_idx in self.graph.node_indices() {
148 let i = node_idx.index();
149 let out_degree = self
150 .graph
151 .neighbors_directed(node_idx, petgraph::Direction::Outgoing)
152 .count();
153
154 if out_degree == 0 {
155 continue;
156 }
157
158 let contribution = d * ranks[i] / out_degree as f64;
159 for neighbor in self
160 .graph
161 .neighbors_directed(node_idx, petgraph::Direction::Outgoing)
162 {
163 new_ranks[neighbor.index()] += contribution;
164 }
165 }
166
167 ranks = new_ranks;
168 }
169
170 for node_idx in self.graph.node_indices() {
172 self.graph[node_idx].rank = ranks[node_idx.index()];
173 }
174 }
175
176 pub fn ranked_symbols(&self) -> Vec<&SymbolNode> {
178 let mut nodes: Vec<&SymbolNode> = self.graph.node_weights().collect();
179 nodes.sort_by(|a, b| {
180 b.rank
181 .partial_cmp(&a.rank)
182 .unwrap_or(std::cmp::Ordering::Equal)
183 });
184 nodes
185 }
186
187 pub fn ranked_symbols_for_files(&self, focus_files: &[PathBuf]) -> Vec<&SymbolNode> {
189 let mut scored: Vec<(&SymbolNode, f64)> = self
190 .graph
191 .node_weights()
192 .map(|node| {
193 let multiplier = if focus_files.contains(&node.symbol.file) {
194 2.0
195 } else {
196 1.0
197 };
198 (node, node.rank * multiplier)
199 })
200 .collect();
201
202 scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
203 scored.into_iter().map(|(node, _)| node).collect()
204 }
205}
206
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::SymbolKind;

    /// Creates a function symbol called `name` located in `file`.
    fn make_symbol(name: &str, file: &str) -> Symbol {
        let signature = format!("fn {name}()");
        Symbol {
            name: String::from(name),
            kind: SymbolKind::Function,
            file: PathBuf::from(file),
            line: 1,
            signature,
            token_cost: 5,
        }
    }

    /// Creates a reference from symbol `from` to symbol `to`.
    fn make_ref(from: &str, to: &str) -> Reference {
        Reference {
            from_file: PathBuf::from("test.rs"),
            from_symbol: Some(String::from(from)),
            to_name: String::from(to),
            line: 1,
        }
    }

    /// Turns `(name, file)` pairs into symbols, preserving their order.
    fn make_symbols(specs: &[(&str, &str)]) -> Vec<Symbol> {
        specs
            .iter()
            .map(|&(name, file)| make_symbol(name, file))
            .collect()
    }

    /// In the chain A -> B -> C the final link accumulates the most rank.
    #[test]
    fn pagerank_linked_chain() {
        let nodes = make_symbols(&[("A", "a.rs"), ("B", "b.rs"), ("C", "c.rs")]);
        let edges = vec![make_ref("A", "B"), make_ref("B", "C")];

        let mut sg = SymbolGraph::build(nodes, edges);
        sg.compute_pagerank();
        let order = sg.ranked_symbols();

        assert_eq!(order.len(), 3);
        assert_eq!(order[0].symbol.name, "C");
        assert!(order[0].rank > order[2].rank);
    }

    /// Without any edges every node ends up with the same rank.
    #[test]
    fn disconnected_nodes_get_base_rank() {
        let nodes = make_symbols(&[("X", "x.rs"), ("Y", "y.rs")]);

        let mut sg = SymbolGraph::build(nodes, Vec::new());
        sg.compute_pagerank();
        let order = sg.ranked_symbols();

        assert_eq!(order.len(), 2);
        let diff = (order[0].rank - order[1].rank).abs();
        assert!(diff < 1e-10, "disconnected nodes should have equal rank");
    }

    /// A symbol from a focus file is ranked ahead of an equally ranked one.
    #[test]
    fn focus_files_boost_ranking() {
        let nodes = make_symbols(&[("A", "a.rs"), ("B", "b.rs")]);

        let mut sg = SymbolGraph::build(nodes, Vec::new());
        sg.compute_pagerank();

        // Precondition: both symbols are tied before any boost is applied.
        let plain = sg.ranked_symbols();
        assert!((plain[0].rank - plain[1].rank).abs() < 1e-10);

        let focus = [PathBuf::from("b.rs")];
        let boosted = sg.ranked_symbols_for_files(&focus);
        assert_eq!(boosted[0].symbol.name, "B");
    }

    /// Ranking an empty graph is a no-op that yields no symbols.
    #[test]
    fn empty_graph() {
        let mut sg = SymbolGraph::build(Vec::new(), Vec::new());
        sg.compute_pagerank();
        assert!(sg.ranked_symbols().is_empty());
    }
}