1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
use std::collections::{HashMap, HashSet};
use std::path::{Path, PathBuf};
use crate::repo_map::parser::Symbol;
/// PageRank-inspired symbol ranker.
///
/// Files are ranked by how often the symbols they define are mentioned across the
/// scanned files, with a small bonus per defined symbol and a boost for files that
/// were mentioned in the current conversation.
#[derive(Debug, Clone)]
pub struct SymbolRanker {
    /// Number of references recorded for each symbol name.
    reference_counts: HashMap<String, u32>,
    /// Files mentioned in the conversation (higher priority). Stored as a set for O(1) lookup.
    conversation_files: HashSet<PathBuf>,
}
impl Default for SymbolRanker {
fn default() -> Self {
Self::new()
}
}
impl SymbolRanker {
/// Create a ranker with no reference counts and no conversation files.
pub fn new() -> Self {
    let reference_counts = HashMap::new();
    let conversation_files = HashSet::new();
    Self {
        reference_counts,
        conversation_files,
    }
}
/// Rebuild the reference counts by scanning every file in `symbols` for symbol names.
///
/// All distinct names of at least three bytes are compiled into one alternation
/// regex anchored with `\b`, so each file is scanned in a single pass instead of
/// once per symbol. If that regex cannot be compiled (for example because it is
/// too large), the work is delegated to `build_references_brute`.
///
/// After counting, one occurrence per distinct name is discounted as the symbol's
/// own definition; names with nothing left are dropped from the map.
///
/// * `symbols` - file path to the symbols defined in that file. The keys are also
///   the files that get read and scanned.
/// * `root` - only used to label the debug log output.
///
/// Files that cannot be read as UTF-8 text are skipped.
pub fn build_references(&mut self, symbols: &HashMap<PathBuf, Vec<Symbol>>, root: &Path) {
    self.reference_counts.clear();

    // Borrow the names instead of cloning them; very short names are skipped to
    // avoid false positives.
    let mut names: Vec<&str> = symbols
        .values()
        .flatten()
        .map(|sym| sym.name.as_str())
        .filter(|name| name.len() >= 3)
        .collect();
    if names.is_empty() {
        return;
    }

    // Deduplicate so a name defined in several files appears once in the pattern
    // (smaller regex, less chance of hitting the size limit). Longest names go
    // first: the regex alternation prefers earlier alternatives, so a name that is
    // a prefix of another cannot shadow the longer one, and the pattern no longer
    // depends on `HashMap` iteration order.
    names.sort_unstable_by(|a, b| b.len().cmp(&a.len()).then_with(|| a.cmp(b)));
    names.dedup();

    // Build a single regex that matches any symbol name at word boundaries.
    // \b anchors prevent short names like `set` matching inside `reset` or `offset`.
    let pattern = format!(
        r"\b(?:{})\b",
        names
            .iter()
            .copied()
            .map(regex::escape)
            .collect::<Vec<_>>()
            .join("|")
    );
    let re = match regex::Regex::new(&pattern) {
        Ok(re) => re,
        Err(err) => {
            // Fallback: regex could not be built (typically too large), use brute force.
            tracing::debug!(
                "SymbolRanker: regex build failed ({}), falling back to brute force",
                err
            );
            self.build_references_brute(symbols, root);
            return;
        }
    };

    // Scan each file once and tally every whole-word hit.
    for path in symbols.keys() {
        let content = match std::fs::read_to_string(path) {
            Ok(text) => text,
            Err(err) => {
                tracing::debug!(
                    "SymbolRanker: skipping unreadable file {}: {}",
                    path.display(),
                    err
                );
                continue;
            }
        };
        for hit in re.find_iter(&content) {
            let name = hit.as_str();
            // Look up by `&str` first so an owned key is only allocated the first
            // time a name is seen, not on every match.
            match self.reference_counts.get_mut(name) {
                Some(count) => *count += 1,
                None => {
                    self.reference_counts.insert(name.to_owned(), 1);
                }
            }
        }
    }

    // Discount one occurrence per name for the symbol's own definition.
    // The map is keyed by name, so a symbol defined in N files is decremented once,
    // not N times. Every key is a matched symbol name with a count of at least 1;
    // entries that reach zero were only ever defined, never referenced.
    self.reference_counts.retain(|_, count| {
        *count -= 1;
        *count > 0
    });

    tracing::debug!(
        "SymbolRanker: built reference counts for {} symbols under {}",
        self.reference_counts.len(),
        root.display(),
    );
}
/// Brute-force fallback for when the regex becomes too large.
fn build_references_brute(&mut self, symbols: &HashMap<PathBuf, Vec<Symbol>>, root: &Path) {
let all_names: Vec<String> = symbols
.values()
.flat_map(|syms| syms.iter().map(|s| s.name.clone()))
.collect();
for path in symbols.keys() {
let content = match std::fs::read_to_string(path) {
Ok(c) => c,
Err(_) => continue,
};
for name in &all_names {
if name.len() < 3 {
continue;
}
let count = content.matches(name.as_str()).count() as u32;
if count > 1 {
*self.reference_counts.entry(name.clone()).or_default() += count - 1;
}
}
}
tracing::debug!(
"SymbolRanker: built reference counts (brute force) for {} symbols under {}",
self.reference_counts.len(),
root.display(),
);
}
/// Replace the set of files mentioned in the current conversation.
///
/// Any previously stored files are discarded; duplicate paths in `files` collapse
/// into a single entry. Files in this set receive a score boost in `score_file`.
pub fn set_conversation_files(&mut self, files: Vec<PathBuf>) {
    self.conversation_files.clear();
    self.conversation_files.extend(files);
}
/// Compute the importance score of one file; a larger value means more important.
///
/// The score is the sum of the reference counts of `symbols`, plus 0.5 per symbol,
/// tripled when `path` is one of the conversation files.
pub fn score_file(&self, path: &Path, symbols: &[Symbol]) -> f64 {
    // Base: reference counts of every symbol in this file; unknown names count as zero.
    let reference_total = symbols.iter().fold(0.0_f64, |acc, sym| {
        let refs = self.reference_counts.get(&sym.name).copied().unwrap_or(0);
        acc + refs as f64
    });

    // Bonus: files that define more symbols are considered more important.
    let base = reference_total + symbols.len() as f64 * 0.5;

    // Conversation relevance boost (3x).
    if self.conversation_files.contains(path) {
        base * 3.0
    } else {
        base
    }
}
/// Borrow the reference counts, keyed by symbol name.
///
/// Read-only access, intended for serializing the counts to the disk cache.
pub fn reference_counts(&self) -> &HashMap<String, u32> {
&self.reference_counts
}
/// Construct a ranker from pre-computed reference counts (for disk cache restore).
pub fn from_reference_counts(counts: HashMap<String, u32>) -> Self {
Self {
reference_counts: counts,
conversation_files: HashSet::new(),
}
}
/// Rank files by importance, highest score first.
///
/// Every entry of `symbols` is scored with `score_file`. Files with equal scores
/// are ordered by path, so the result does not depend on `HashMap` iteration order
/// and is reproducible between runs.
///
/// Returns `(path, symbols, score)` tuples that borrow from `symbols`.
pub fn rank_files<'a>(
    &self,
    symbols: &'a HashMap<PathBuf, Vec<Symbol>>,
) -> Vec<(&'a PathBuf, &'a Vec<Symbol>, f64)> {
    let mut ranked: Vec<_> = symbols
        .iter()
        .map(|(path, syms)| (path, syms, self.score_file(path, syms)))
        .collect();
    ranked.sort_by(|a, b| {
        // Descending by score; incomparable scores are treated as equal.
        b.2.partial_cmp(&a.2)
            .unwrap_or(std::cmp::Ordering::Equal)
            // Deterministic tie-break: ascending by path.
            .then_with(|| a.0.cmp(b.0))
    });
    ranked
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::repo_map::parser::{Symbol, SymbolKind};
/// Build a minimal function symbol called `name` for use in the tests below.
fn sym(name: &str) -> Symbol {
    let name = name.to_owned();
    Symbol {
        name,
        kind: SymbolKind::Function,
        line: 1,
        signature: None,
    }
}
/// Symbols with the same name defined in two different files must only
/// decrement the reference count once, not twice.
///
/// Runs the real `build_references` over two temporary files rather than
/// re-implementing the decrement loop, so a regression in the production code
/// is actually caught.
#[test]
fn test_dedup_prevents_double_decrement() {
    // Per-process directory name keeps parallel test runs from colliding.
    let dir = std::env::temp_dir().join(format!("symbol_ranker_dedup_{}", std::process::id()));
    std::fs::create_dir_all(&dir).expect("create temp dir");
    let file_a = dir.join("a.rs");
    let file_b = dir.join("b.rs");

    // "foo" occurs 3 times across both files: two definitions plus one call.
    std::fs::write(&file_a, "fn foo() {}\nfoo();\n").expect("write a.rs");
    std::fs::write(&file_b, "fn foo() {}\n").expect("write b.rs");

    // "foo" is defined in both files, so it is listed twice among the symbol names.
    let mut symbols: HashMap<PathBuf, Vec<Symbol>> = HashMap::new();
    symbols.insert(file_a, vec![sym("foo")]);
    symbols.insert(file_b, vec![sym("foo")]);

    let mut ranker = SymbolRanker::new();
    ranker.build_references(&symbols, &dir);
    let foo_count = ranker.reference_counts.get("foo").copied();

    // Clean up before asserting so a failure does not leave the directory behind.
    let _ = std::fs::remove_dir_all(&dir);

    // Should be 2 (decremented once), not 1 (which double-decrement would produce)
    assert_eq!(foo_count, Some(2));
}
/// The file whose symbol carries the larger reference count must be ranked first.
#[test]
fn test_basic_ranking_by_reference_count() {
    let popular_file = PathBuf::from("/fake/a.rs");
    let rare_file = PathBuf::from("/fake/b.rs");

    // One symbol per file, named after how often it is referenced.
    let symbols: HashMap<PathBuf, Vec<Symbol>> = vec![
        (popular_file.clone(), vec![sym("popular")]),
        (rare_file, vec![sym("rare")]),
    ]
    .into_iter()
    .collect();

    let mut ranker = SymbolRanker::new();
    ranker.reference_counts.insert("rare".to_string(), 1);
    ranker.reference_counts.insert("popular".to_string(), 10);

    let order = ranker.rank_files(&symbols);
    assert_eq!(
        order[0].0, &popular_file,
        "higher-reference file should rank first"
    );
}
/// A file listed in the conversation files must score exactly 3× an otherwise
/// identical file that is not listed.
#[test]
fn test_conversation_files_boost() {
    let in_conversation = PathBuf::from("/fake/boosted.rs");
    let elsewhere = PathBuf::from("/fake/normal.rs");
    // Both files define the same single symbol, so their unboosted scores match.
    let file_symbols = [sym("sym")];

    let mut ranker = SymbolRanker::new();
    ranker.reference_counts.insert("sym".to_string(), 4);
    ranker.set_conversation_files(vec![in_conversation.clone()]);

    let score_boosted = ranker.score_file(&in_conversation, &file_symbols);
    let score_plain = ranker.score_file(&elsewhere, &file_symbols);
    assert!(score_boosted > score_plain);

    // The per-symbol bonus (0.5) is the same for both; only the 3× multiplier differs.
    let expected = score_plain * 3.0;
    let delta = (score_boosted - expected).abs();
    assert!(delta < 1e-9, "expected {expected}, got {score_boosted}");
}
}