1use std::collections::{HashMap, HashSet, VecDeque};
2
3use super::graph_index::ProjectIndex;
4
5use super::neural::attention_learned::LearnedAttention;
6
/// Relevance verdict for a single file, produced by `compute_relevance`.
#[derive(Debug, Clone)]
pub struct RelevanceScore {
    /// Project-relative file path.
    pub path: String,
    /// Relevance in [0.0, 1.0]; 1.0 means the file was named directly by the task.
    pub score: f64,
    /// Suggested rendering mode derived from `score`:
    /// one of "full", "signatures", "map", "reference" (see `recommend_mode`).
    pub recommended_mode: &'static str,
}
13
14pub fn compute_relevance(
15 index: &ProjectIndex,
16 task_files: &[String],
17 task_keywords: &[String],
18) -> Vec<RelevanceScore> {
19 let mut scores: HashMap<String, f64> = HashMap::new();
20
21 for f in task_files {
23 scores.insert(f.clone(), 1.0);
24 }
25
26 let adj = build_adjacency(index);
28 for seed in task_files {
29 let mut visited: HashSet<String> = HashSet::new();
30 let mut queue: VecDeque<(String, usize)> = VecDeque::new();
31 queue.push_back((seed.clone(), 0));
32 visited.insert(seed.clone());
33
34 while let Some((node, depth)) = queue.pop_front() {
35 if depth > 4 {
36 continue;
37 }
38 let decay = 1.0 / (1.0 + depth as f64).powi(2); let entry = scores.entry(node.clone()).or_insert(0.0);
40 *entry = entry.max(decay);
41
42 if let Some(neighbors) = adj.get(&node) {
43 for neighbor in neighbors {
44 if !visited.contains(neighbor) {
45 visited.insert(neighbor.clone());
46 queue.push_back((neighbor.clone(), depth + 1));
47 }
48 }
49 }
50 }
51 }
52
53 if !task_keywords.is_empty() {
55 let kw_lower: Vec<String> = task_keywords.iter().map(|k| k.to_lowercase()).collect();
56 for (file_path, file_entry) in &index.files {
57 let path_lower = file_path.to_lowercase();
58 let mut keyword_hits = 0;
59 for kw in &kw_lower {
60 if path_lower.contains(kw) {
61 keyword_hits += 1;
62 }
63 for export in &file_entry.exports {
64 if export.to_lowercase().contains(kw) {
65 keyword_hits += 1;
66 }
67 }
68 }
69 if keyword_hits > 0 {
70 let boost = (keyword_hits as f64 * 0.15).min(0.6);
71 let entry = scores.entry(file_path.clone()).or_insert(0.0);
72 *entry = (*entry + boost).min(1.0);
73 }
74 }
75 }
76
77 let mut result: Vec<RelevanceScore> = scores
78 .into_iter()
79 .map(|(path, score)| {
80 let mode = recommend_mode(score);
81 RelevanceScore {
82 path,
83 score,
84 recommended_mode: mode,
85 }
86 })
87 .collect();
88
89 result.sort_by(|a, b| {
90 b.score
91 .partial_cmp(&a.score)
92 .unwrap_or(std::cmp::Ordering::Equal)
93 });
94 result
95}
96
/// Map a relevance score onto a context-rendering mode: higher-scoring
/// files deserve more detail, low scorers just a reference.
fn recommend_mode(score: f64) -> &'static str {
    match score {
        s if s >= 0.8 => "full",
        s if s >= 0.5 => "signatures",
        s if s >= 0.2 => "map",
        _ => "reference",
    }
}
108
109fn build_adjacency(index: &ProjectIndex) -> HashMap<String, Vec<String>> {
110 let mut adj: HashMap<String, Vec<String>> = HashMap::new();
111 for edge in &index.edges {
112 adj.entry(edge.from.clone())
113 .or_default()
114 .push(edge.to.clone());
115 adj.entry(edge.to.clone())
116 .or_default()
117 .push(edge.from.clone());
118 }
119 adj
120}
121
122pub fn parse_task_hints(task_description: &str) -> (Vec<String>, Vec<String>) {
124 let mut files = Vec::new();
125 let mut keywords = Vec::new();
126
127 for word in task_description.split_whitespace() {
128 let clean = word.trim_matches(|c: char| {
129 !c.is_alphanumeric() && c != '.' && c != '/' && c != '_' && c != '-'
130 });
131 if clean.contains('.')
132 && (clean.contains('/')
133 || clean.ends_with(".rs")
134 || clean.ends_with(".ts")
135 || clean.ends_with(".py")
136 || clean.ends_with(".go")
137 || clean.ends_with(".js"))
138 {
139 files.push(clean.to_string());
140 } else if clean.len() >= 3 && !STOP_WORDS.contains(&clean.to_lowercase().as_str()) {
141 keywords.push(clean.to_string());
142 }
143 }
144
145 (files, keywords)
146}
147
/// Words too common to be useful task keywords. Mostly English filler plus a
/// handful of frequent German function words ("nach", "und", "die", ...).
/// Candidates in `parse_task_hints` are lowercased before this lookup.
const STOP_WORDS: &[&str] = &[
    "the", "and", "for", "that", "this", "with", "from", "have", "has", "was", "are", "been",
    "not", "but", "all", "can", "had", "her", "one", "our", "out", "you", "its", "will", "each",
    "make", "like", "fix", "add", "use", "get", "set", "run", "new", "old", "should", "would",
    "could", "into", "also", "than", "them", "then", "when", "just", "only", "very", "some",
    "more", "other", "nach", "und", "die", "der", "das", "ist", "ein", "eine", "nicht", "auf",
    "mit",
];
156
157pub fn information_bottleneck_filter(
168 content: &str,
169 task_keywords: &[String],
170 budget_ratio: f64,
171) -> String {
172 let lines: Vec<&str> = content.lines().collect();
173 if lines.is_empty() {
174 return String::new();
175 }
176
177 let n = lines.len();
178 let kw_lower: Vec<String> = task_keywords.iter().map(|k| k.to_lowercase()).collect();
179 let attention = LearnedAttention::with_defaults();
180
181 let mut global_token_freq: HashMap<&str, usize> = HashMap::new();
182 for line in &lines {
183 for token in line.split_whitespace() {
184 *global_token_freq.entry(token).or_insert(0) += 1;
185 }
186 }
187 let total_unique = global_token_freq.len().max(1) as f64;
188
189 let mut scored_lines: Vec<(usize, &str, f64)> = lines
190 .iter()
191 .enumerate()
192 .map(|(i, line)| {
193 let trimmed = line.trim();
194 if trimmed.is_empty() {
195 return (i, *line, 0.05);
196 }
197
198 let line_lower = trimmed.to_lowercase();
199 let keyword_hits: f64 = kw_lower
200 .iter()
201 .filter(|kw| line_lower.contains(kw.as_str()))
202 .count() as f64;
203
204 let structural = if is_error_handling(trimmed) {
205 1.5
206 } else if is_definition_line(trimmed) {
207 1.0
208 } else if is_control_flow(trimmed) {
209 0.5
210 } else if is_closing_brace(trimmed) {
211 0.15
212 } else {
213 0.3
214 };
215 let relevance = keyword_hits * 0.5 + structural;
216
217 let line_tokens: Vec<&str> = trimmed.split_whitespace().collect();
218 let unique_in_line = line_tokens.iter().collect::<HashSet<_>>().len() as f64;
219 let line_token_count = line_tokens.len().max(1) as f64;
220 let token_diversity = unique_in_line / line_token_count;
221
222 let avg_idf: f64 = if line_tokens.is_empty() {
223 0.0
224 } else {
225 line_tokens
226 .iter()
227 .map(|t| {
228 let freq = *global_token_freq.get(t).unwrap_or(&1) as f64;
229 (total_unique / freq).ln().max(0.0)
230 })
231 .sum::<f64>()
232 / line_token_count
233 };
234 let information = (token_diversity * 0.4 + (avg_idf.min(3.0) / 3.0) * 0.6).min(1.0);
235
236 let pos = i as f64 / n.max(1) as f64;
237 let attn_weight = attention.weight(pos);
238
239 let score = (relevance * 0.6 + 0.05)
240 * (information * 0.25 + 0.05)
241 * (attn_weight * 0.15 + 0.05);
242
243 (i, *line, score)
244 })
245 .collect();
246
247 let budget = ((n as f64) * budget_ratio).ceil() as usize;
248
249 scored_lines.sort_by(|a, b| b.2.partial_cmp(&a.2).unwrap_or(std::cmp::Ordering::Equal));
250
251 scored_lines.truncate(budget);
252
253 let mut output_lines: Vec<&str> = Vec::with_capacity(budget + 1);
254
255 if !kw_lower.is_empty() {
256 output_lines.push(""); }
258
259 for (_, line, _) in &scored_lines {
260 output_lines.push(line);
261 }
262
263 if !kw_lower.is_empty() {
264 let summary = format!("[task: {}]", task_keywords.join(", "));
265 let mut result = summary;
266 result.push('\n');
267 result.push_str(&output_lines[1..].to_vec().join("\n"));
268 return result;
269 }
270
271 output_lines.join("\n")
272}
273
/// Heuristic: does this (already trimmed) line look like error handling?
/// Covers Rust (`Err`, `bail!`, `panic!`, `?;`, `map_err`, `unwrap`),
/// Python (`raise`, `except`), and JS/Java-style (`throw`, `catch`, `try`).
fn is_error_handling(line: &str) -> bool {
    // Markers that must open the line.
    const PREFIXES: &[&str] = &[
        "return Err(",
        "Err(",
        "bail!(",
        "anyhow::bail!",
        "raise ",
        "throw ",
        "catch ",
        "except ",
        "try ",
        "panic!(",
    ];
    // Markers that may appear anywhere in the line.
    const FRAGMENTS: &[&str] = &[".map_err(", "unwrap()", "expect(\"", "Error::", "error!"];

    PREFIXES.iter().any(|p| line.starts_with(p))
        || FRAGMENTS.iter().any(|f| line.contains(f))
        // The `?` operator counts unless the line is a comment.
        || (line.contains("?;") && !line.starts_with("//"))
}
292
/// Scale `base_ratio` down for repetitive content: highly repetitive text
/// compresses well, so it gets a smaller line budget (at most a 30% cut,
/// clamped to [0.2, 1.0]).
///
/// Returns 1.0 (no filtering) for inputs shorter than 10 lines, and
/// `base_ratio` unchanged for inputs with no tokens at all.
///
/// The original collected all lines into a `Vec` and tracked per-token
/// counts in a `HashMap` whose values were never read; a `HashSet` of
/// distinct tokens plus direct iteration gives the same result without
/// the dead bookkeeping.
pub fn adaptive_ib_budget(content: &str, base_ratio: f64) -> f64 {
    if content.lines().count() < 10 {
        return 1.0;
    }

    // Only the number of distinct tokens matters, not their counts.
    // Newlines are whitespace, so splitting the whole content yields the
    // same token stream as splitting line by line.
    let mut unique_tokens: HashSet<&str> = HashSet::new();
    let mut total_tokens = 0usize;
    for token in content.split_whitespace() {
        unique_tokens.insert(token);
        total_tokens += 1;
    }

    if total_tokens == 0 {
        return base_ratio;
    }

    let unique_ratio = unique_tokens.len() as f64 / total_tokens as f64;
    let repetition_factor = 1.0 - unique_ratio;

    (base_ratio * (1.0 - repetition_factor * 0.3)).clamp(0.2, 1.0)
}
320
/// Heuristic: does this line open a definition (function, type, class, ...)?
/// Covers Rust, TypeScript/JavaScript, Python, and Go prefixes.
///
/// The original tested `line.starts_with(p) || line.trim_start().starts_with(p)`
/// per prefix; the first test is implied by the second (no prefix begins with
/// whitespace), and `trim_start` was recomputed for every prefix. The trim is
/// now hoisted and the redundant check dropped — same results, less work.
fn is_definition_line(line: &str) -> bool {
    const PREFIXES: &[&str] = &[
        "fn ",
        "pub fn ",
        "async fn ",
        "pub async fn ",
        "struct ",
        "pub struct ",
        "enum ",
        "pub enum ",
        "trait ",
        "pub trait ",
        "impl ",
        "type ",
        "pub type ",
        "const ",
        "pub const ",
        "static ",
        "pub static ",
        "class ",
        "export class ",
        "interface ",
        "export interface ",
        "function ",
        "export function ",
        "async function ",
        "def ",
        "async def ",
        "func ",
    ];
    let trimmed = line.trim_start();
    PREFIXES.iter().any(|p| trimmed.starts_with(p))
}
355
/// Heuristic: does this line start a control-flow construct? Keywords that
/// can legally stand alone (`break;`, `continue;`, `yield`) are matched
/// without a trailing space; the rest require one.
fn is_control_flow(line: &str) -> bool {
    const KEYWORDS: &[&str] = &[
        "if ", "else ", "match ", "for ", "while ", "return ", "await ",
        "break", "continue", "yield",
    ];
    let trimmed = line.trim();
    KEYWORDS.iter().any(|kw| trimmed.starts_with(kw))
}
369
/// True when the line is nothing but a closing delimiter — such lines carry
/// almost no information and are ranked lowest by the structural prior.
fn is_closing_brace(line: &str) -> bool {
    matches!(line.trim(), "}" | "};" | "})" | "});")
}
374
#[cfg(test)]
mod tests {
    use super::*;

    // Path-like tokens should land in `files`, significant words in `keywords`.
    #[test]
    fn parse_task_finds_files_and_keywords() {
        let (files, keywords) =
            parse_task_hints("Fix the authentication bug in src/auth.rs and update tests");
        assert!(files.iter().any(|f| f.contains("auth.rs")));
        assert!(keywords
            .iter()
            .any(|k| k.to_lowercase().contains("authentication")));
    }

    // One representative score per mode bucket, covering all four modes.
    #[test]
    fn recommend_mode_by_score() {
        assert_eq!(recommend_mode(1.0), "full");
        assert_eq!(recommend_mode(0.6), "signatures");
        assert_eq!(recommend_mode(0.3), "map");
        assert_eq!(recommend_mode(0.1), "reference");
    }

    // Definitions score high enough to survive filtering, and a non-empty
    // keyword list must produce the `[task: ...]` summary prefix.
    #[test]
    fn info_bottleneck_preserves_definitions() {
        let content = "fn main() {\n let x = 42;\n // boring comment\n println!(x);\n}\n";
        let result = information_bottleneck_filter(content, &["main".to_string()], 0.6);
        assert!(result.contains("fn main"), "definitions must be preserved");
        assert!(result.contains("[task: main]"), "should have task summary");
    }

    // Error handling carries the largest structural weight (1.5), so it
    // should outlive plain statements even under a tight budget.
    #[test]
    fn info_bottleneck_error_handling_priority() {
        let content = "fn validate() {\n let data = parse()?;\n return Err(\"invalid\");\n let x = 1;\n let y = 2;\n}\n";
        let result = information_bottleneck_filter(content, &["validate".to_string()], 0.5);
        assert!(
            result.contains("return Err"),
            "error handling should survive filtering"
        );
    }

    // Output is emitted in descending score order, so a definition (1.0)
    // must precede a bare closing brace (0.15).
    #[test]
    fn info_bottleneck_score_sorted() {
        let content = "fn important() {\n let x = 1;\n let y = 2;\n let z = 3;\n}\n}\n";
        let result = information_bottleneck_filter(content, &[], 0.6);
        let lines: Vec<&str> = result.lines().collect();
        let def_pos = lines.iter().position(|l| l.contains("fn important"));
        let brace_pos = lines.iter().position(|l| l.trim() == "}");
        if let (Some(d), Some(b)) = (def_pos, brace_pos) {
            assert!(
                d < b,
                "definitions should appear before closing braces in score-sorted output"
            );
        }
    }

    // Low token diversity (repetition) should shrink the budget relative to
    // fully distinct content at the same base ratio.
    #[test]
    fn adaptive_budget_reduces_for_repetitive() {
        let repetitive = "let x = 1;\n".repeat(50);
        let diverse = (0..50)
            .map(|i| format!("let var_{i} = func_{i}(arg_{i});"))
            .collect::<Vec<_>>()
            .join("\n");
        let budget_rep = super::adaptive_ib_budget(&repetitive, 0.7);
        let budget_div = super::adaptive_ib_budget(&diverse, 0.7);
        assert!(
            budget_rep < budget_div,
            "repetitive content should get lower budget"
        );
    }
}