// argus_repomap/graph.rs

use std::collections::{HashMap, HashSet};
use std::path::PathBuf;

use petgraph::graph::{DiGraph, NodeIndex};

use crate::parser::{Reference, Symbol};

8/// A node in the symbol graph: a symbol annotated with its PageRank score.
9///
10/// # Examples
11///
12/// ```
13/// use std::path::PathBuf;
14/// use argus_repomap::parser::{Symbol, SymbolKind};
15/// use argus_repomap::graph::SymbolNode;
16///
17/// let node = SymbolNode {
18///     symbol: Symbol {
19///         name: "main".into(),
20///         kind: SymbolKind::Function,
21///         file: PathBuf::from("src/main.rs"),
22///         line: 1,
23///         signature: "fn main()".into(),
24///         token_cost: 2,
25///     },
26///     rank: 0.0,
27/// };
28/// assert_eq!(node.rank, 0.0);
29/// ```
30#[derive(Debug, Clone)]
31pub struct SymbolNode {
32    /// The original parsed symbol.
33    pub symbol: Symbol,
34    /// PageRank score (higher = more important).
35    pub rank: f64,
36}
37
38/// Directed graph of symbols linked by cross-references, with PageRank ranking.
39///
40/// # Examples
41///
42/// ```
43/// use std::path::PathBuf;
44/// use argus_repomap::parser::{Symbol, SymbolKind, Reference};
45/// use argus_repomap::graph::SymbolGraph;
46///
47/// let symbols = vec![
48///     Symbol {
49///         name: "caller".into(),
50///         kind: SymbolKind::Function,
51///         file: PathBuf::from("a.rs"),
52///         line: 1,
53///         signature: "fn caller()".into(),
54///         token_cost: 3,
55///     },
56///     Symbol {
57///         name: "callee".into(),
58///         kind: SymbolKind::Function,
59///         file: PathBuf::from("b.rs"),
60///         line: 1,
61///         signature: "fn callee()".into(),
62///         token_cost: 3,
63///     },
64/// ];
65/// let refs = vec![
66///     Reference {
67///         from_file: PathBuf::from("a.rs"),
68///         from_symbol: Some("caller".into()),
69///         to_name: "callee".into(),
70///         line: 2,
71///     },
72/// ];
73/// let mut graph = SymbolGraph::build(symbols, refs);
74/// graph.compute_pagerank();
75/// let ranked = graph.ranked_symbols();
76/// assert!(!ranked.is_empty());
77/// ```
78pub struct SymbolGraph {
79    graph: DiGraph<SymbolNode, ()>,
80    #[allow(dead_code)]
81    name_to_index: HashMap<String, NodeIndex>,
82}
83
84impl SymbolGraph {
85    /// Build a graph from extracted symbols and references.
86    ///
87    /// Each symbol becomes a node. Each reference that resolves to a known
88    /// symbol name creates a directed edge from the referencing context to
89    /// the referenced symbol.
90    pub fn build(symbols: Vec<Symbol>, references: Vec<Reference>) -> Self {
91        let mut graph = DiGraph::new();
92        let mut name_to_index: HashMap<String, NodeIndex> = HashMap::new();
93
94        for symbol in symbols {
95            let name = symbol.name.clone();
96            let idx = graph.add_node(SymbolNode { symbol, rank: 0.0 });
97            // First symbol with a given name wins
98            name_to_index.entry(name).or_insert(idx);
99        }
100
101        for reference in &references {
102            let Some(&to_idx) = name_to_index.get(&reference.to_name) else {
103                continue; // Unresolved reference
104            };
105
106            // Find the "from" node: prefer the enclosing symbol, fall back to file-level
107            let from_idx = reference
108                .from_symbol
109                .as_ref()
110                .and_then(|name| name_to_index.get(name).copied());
111
112            let Some(from_idx) = from_idx else {
113                continue;
114            };
115
116            // Don't add self-loops
117            if from_idx == to_idx {
118                continue;
119            }
120
121            graph.add_edge(from_idx, to_idx, ());
122        }
123
124        Self {
125            graph,
126            name_to_index,
127        }
128    }
129
130    /// Run PageRank (damping=0.85, 20 iterations) and store scores on nodes.
131    pub fn compute_pagerank(&mut self) {
132        let n = self.graph.node_count();
133        if n == 0 {
134            return;
135        }
136
137        let d: f64 = 0.85;
138        let n_f64 = n as f64;
139        let base = (1.0 - d) / n_f64;
140
141        // Initialize all ranks to 1/N
142        let mut ranks = vec![1.0 / n_f64; n];
143
144        for _ in 0..20 {
145            let mut new_ranks = vec![base; n];
146
147            for node_idx in self.graph.node_indices() {
148                let i = node_idx.index();
149                let out_degree = self
150                    .graph
151                    .neighbors_directed(node_idx, petgraph::Direction::Outgoing)
152                    .count();
153
154                if out_degree == 0 {
155                    continue;
156                }
157
158                let contribution = d * ranks[i] / out_degree as f64;
159                for neighbor in self
160                    .graph
161                    .neighbors_directed(node_idx, petgraph::Direction::Outgoing)
162                {
163                    new_ranks[neighbor.index()] += contribution;
164                }
165            }
166
167            ranks = new_ranks;
168        }
169
170        // Write ranks back to nodes
171        for node_idx in self.graph.node_indices() {
172            self.graph[node_idx].rank = ranks[node_idx.index()];
173        }
174    }
175
176    /// Get all symbols sorted by rank (highest first).
177    pub fn ranked_symbols(&self) -> Vec<&SymbolNode> {
178        let mut nodes: Vec<&SymbolNode> = self.graph.node_weights().collect();
179        nodes.sort_by(|a, b| {
180            b.rank
181                .partial_cmp(&a.rank)
182                .unwrap_or(std::cmp::Ordering::Equal)
183        });
184        nodes
185    }
186
187    /// Get symbols ranked with a 2x boost for those in the given focus files.
188    pub fn ranked_symbols_for_files(&self, focus_files: &[PathBuf]) -> Vec<&SymbolNode> {
189        let mut scored: Vec<(&SymbolNode, f64)> = self
190            .graph
191            .node_weights()
192            .map(|node| {
193                let multiplier = if focus_files.contains(&node.symbol.file) {
194                    2.0
195                } else {
196                    1.0
197                };
198                (node, node.rank * multiplier)
199            })
200            .collect();
201
202        scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
203        scored.into_iter().map(|(node, _)| node).collect()
204    }
205}
206
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::SymbolKind;

    /// Damping factor `compute_pagerank` is documented to use.
    const DAMPING: f64 = 0.85;

    /// Build a function symbol named `name` located in `file`.
    fn make_symbol(name: &str, file: &str) -> Symbol {
        Symbol {
            name: name.to_string(),
            kind: SymbolKind::Function,
            file: PathBuf::from(file),
            line: 1,
            signature: format!("fn {name}()"),
            token_cost: 5,
        }
    }

    /// Build a reference from symbol `from` to the name `to`.
    fn make_ref(from: &str, to: &str) -> Reference {
        Reference {
            from_file: PathBuf::from("test.rs"),
            from_symbol: Some(from.to_string()),
            to_name: to.to_string(),
            line: 1,
        }
    }

    /// Names of all symbols in ranked order (highest rank first).
    fn ranked_names(graph: &SymbolGraph) -> Vec<&str> {
        graph
            .ranked_symbols()
            .into_iter()
            .map(|node| node.symbol.name.as_str())
            .collect()
    }

    #[test]
    fn pagerank_linked_chain() {
        // A -> B -> C: rank accumulates along the chain, so C > B > A.
        let symbols = vec![
            make_symbol("A", "a.rs"),
            make_symbol("B", "b.rs"),
            make_symbol("C", "c.rs"),
        ];
        let refs = vec![make_ref("A", "B"), make_ref("B", "C")];

        let mut graph = SymbolGraph::build(symbols, refs);
        graph.compute_pagerank();

        assert_eq!(ranked_names(&graph), ["C", "B", "A"]);
        let ranked = graph.ranked_symbols();
        assert!(ranked.windows(2).all(|pair| pair[0].rank > pair[1].rank));
    }

    #[test]
    fn disconnected_nodes_get_base_rank() {
        let symbols = vec![make_symbol("X", "x.rs"), make_symbol("Y", "y.rs")];
        let refs: Vec<Reference> = vec![];

        let mut graph = SymbolGraph::build(symbols, refs);
        graph.compute_pagerank();
        let ranked = graph.ranked_symbols();

        // With no edges every node ends at exactly (1-d)/N.
        assert_eq!(ranked.len(), 2);
        let base = (1.0 - DAMPING) / 2.0;
        for node in &ranked {
            assert!(
                (node.rank - base).abs() < 1e-12,
                "disconnected node should have base rank"
            );
        }
    }

    #[test]
    fn focus_files_boost_ranking() {
        let symbols = vec![make_symbol("A", "a.rs"), make_symbol("B", "b.rs")];
        let refs: Vec<Reference> = vec![];

        let mut graph = SymbolGraph::build(symbols, refs);
        graph.compute_pagerank();

        // Without focus, both should be equal
        let ranked = graph.ranked_symbols();
        assert!((ranked[0].rank - ranked[1].rank).abs() < 1e-10);

        // With focus on b.rs, B should rank higher
        let focus = vec![PathBuf::from("b.rs")];
        let boosted = graph.ranked_symbols_for_files(&focus);
        assert_eq!(boosted.len(), 2);
        assert_eq!(boosted[0].symbol.name, "B");
        assert_eq!(boosted[1].symbol.name, "A");
    }

    #[test]
    fn unresolved_references_are_ignored() {
        // Neither an unknown target nor an unknown source may create an edge.
        let symbols = vec![make_symbol("A", "a.rs")];
        let refs = vec![make_ref("A", "Missing"), make_ref("Missing", "A")];

        let mut graph = SymbolGraph::build(symbols, refs);
        graph.compute_pagerank();
        let ranked = graph.ranked_symbols();

        assert_eq!(ranked.len(), 1);
        assert!((ranked[0].rank - (1.0 - DAMPING)).abs() < 1e-12);
    }

    #[test]
    fn self_references_add_no_edge() {
        let symbols = vec![make_symbol("A", "a.rs"), make_symbol("B", "b.rs")];
        let refs = vec![make_ref("A", "A")];

        let mut graph = SymbolGraph::build(symbols, refs);
        graph.compute_pagerank();
        let ranked = graph.ranked_symbols();

        // A self-loop would change A's score; both must stay at base rank.
        assert_eq!(ranked.len(), 2);
        assert!((ranked[0].rank - ranked[1].rank).abs() < 1e-12);
    }

    #[test]
    fn empty_graph() {
        let mut graph = SymbolGraph::build(vec![], vec![]);
        graph.compute_pagerank();
        assert!(graph.ranked_symbols().is_empty());
        assert!(graph.ranked_symbols_for_files(&[]).is_empty());
    }
}
292}