1use std::collections::{HashMap, HashSet};
2
3use super::graph_index::ProjectIndex;
4
5use super::neural::attention_learned::LearnedAttention;
6
/// Per-file relevance verdict produced by `compute_relevance`.
#[derive(Debug, Clone)]
pub struct RelevanceScore {
    /// File path as keyed in `ProjectIndex.files`.
    pub path: String,
    /// Combined relevance score (diffused task heat blended with PageRank,
    /// plus a bounded keyword boost); capped at 1.0 on the boost path.
    pub score: f64,
    /// Suggested context-inclusion mode: "full" | "signatures" | "map" | "reference".
    pub recommended_mode: &'static str,
}
13
14pub fn compute_relevance(
15 index: &ProjectIndex,
16 task_files: &[String],
17 task_keywords: &[String],
18) -> Vec<RelevanceScore> {
19 let adj = build_adjacency_resolved(index);
20 let all_nodes: Vec<String> = index.files.keys().cloned().collect();
21 if all_nodes.is_empty() {
22 return Vec::new();
23 }
24
25 let node_idx: HashMap<&str, usize> = all_nodes
26 .iter()
27 .enumerate()
28 .map(|(i, n)| (n.as_str(), i))
29 .collect();
30 let n = all_nodes.len();
31
32 let degrees: Vec<f64> = all_nodes
34 .iter()
35 .map(|node| {
36 adj.get(node)
37 .map_or(0.0, |neigh| neigh.len() as f64)
38 .max(1.0)
39 })
40 .collect();
41
42 let mut heat: Vec<f64> = vec![0.0; n];
44 for f in task_files {
45 if let Some(&idx) = node_idx.get(f.as_str()) {
46 heat[idx] = 1.0;
47 }
48 }
49
50 let alpha = 0.5;
53 let iterations = 4;
54 for _ in 0..iterations {
55 let mut new_heat = vec![0.0; n];
56 for (i, node) in all_nodes.iter().enumerate() {
57 let self_term = (1.0 - alpha) * heat[i];
58 let mut neighbor_sum = 0.0;
59 if let Some(neighbors) = adj.get(node) {
60 for neighbor in neighbors {
61 if let Some(&j) = node_idx.get(neighbor.as_str()) {
62 neighbor_sum += heat[j] / degrees[j];
63 }
64 }
65 }
66 new_heat[i] = self_term + alpha * neighbor_sum;
67 }
68 heat = new_heat;
69 }
70
71 let mut pagerank = vec![1.0 / n as f64; n];
73 let damping = 0.85;
74 for _ in 0..8 {
75 let mut new_pr = vec![(1.0 - damping) / n as f64; n];
76 for (i, node) in all_nodes.iter().enumerate() {
77 if let Some(neighbors) = adj.get(node) {
78 let out_deg = neighbors.len().max(1) as f64;
79 for neighbor in neighbors {
80 if let Some(&j) = node_idx.get(neighbor.as_str()) {
81 new_pr[j] += damping * pagerank[i] / out_deg;
82 }
83 }
84 }
85 }
86 pagerank = new_pr;
87 }
88
89 let mut scores: HashMap<String, f64> = HashMap::new();
91 let heat_max = heat.iter().cloned().fold(0.0_f64, f64::max).max(1e-10);
92 let pr_max = pagerank.iter().cloned().fold(0.0_f64, f64::max).max(1e-10);
93
94 for (i, node) in all_nodes.iter().enumerate() {
95 let h = heat[i] / heat_max;
96 let pr = pagerank[i] / pr_max;
97 let combined = h * 0.8 + pr * 0.2;
98 if combined > 0.01 {
99 scores.insert(node.clone(), combined);
100 }
101 }
102
103 if !task_keywords.is_empty() {
105 let kw_lower: Vec<String> = task_keywords.iter().map(|k| k.to_lowercase()).collect();
106 for (file_path, file_entry) in &index.files {
107 let path_lower = file_path.to_lowercase();
108 let mut keyword_hits = 0;
109 for kw in &kw_lower {
110 if path_lower.contains(kw) {
111 keyword_hits += 1;
112 }
113 for export in &file_entry.exports {
114 if export.to_lowercase().contains(kw) {
115 keyword_hits += 1;
116 }
117 }
118 }
119 if keyword_hits > 0 {
120 let boost = (keyword_hits as f64 * 0.15).min(0.6);
121 let entry = scores.entry(file_path.clone()).or_insert(0.0);
122 *entry = (*entry + boost).min(1.0);
123 }
124 }
125 }
126
127 let mut result: Vec<RelevanceScore> = scores
128 .into_iter()
129 .map(|(path, score)| {
130 let mode = recommend_mode(score);
131 RelevanceScore {
132 path,
133 score,
134 recommended_mode: mode,
135 }
136 })
137 .collect();
138
139 result.sort_by(|a, b| {
140 b.score
141 .partial_cmp(&a.score)
142 .unwrap_or(std::cmp::Ordering::Equal)
143 });
144 result
145}
146
/// Maps a relevance score to a context-inclusion mode.
/// Thresholds: >= 0.8 full, >= 0.5 signatures, >= 0.2 map, else reference.
fn recommend_mode(score: f64) -> &'static str {
    match score {
        s if s >= 0.8 => "full",
        s if s >= 0.5 => "signatures",
        s if s >= 0.2 => "map",
        _ => "reference",
    }
}
158
159fn build_adjacency_resolved(index: &ProjectIndex) -> HashMap<String, Vec<String>> {
164 let module_to_file = build_module_map(index);
165 let mut adj: HashMap<String, Vec<String>> = HashMap::new();
166
167 for edge in &index.edges {
168 let from = &edge.from;
169 let to_resolved = module_to_file
170 .get(&edge.to)
171 .cloned()
172 .unwrap_or_else(|| edge.to.clone());
173
174 if index.files.contains_key(from) && index.files.contains_key(&to_resolved) {
175 adj.entry(from.clone())
176 .or_default()
177 .push(to_resolved.clone());
178 adj.entry(to_resolved).or_default().push(from.clone());
179 }
180 }
181 adj
182}
183
184fn build_module_map(index: &ProjectIndex) -> HashMap<String, String> {
187 let file_paths: Vec<&str> = index.files.keys().map(|s| s.as_str()).collect();
188 let mut mapping: HashMap<String, String> = HashMap::new();
189
190 let edge_targets: HashSet<String> = index.edges.iter().map(|e| e.to.clone()).collect();
191
192 for target in &edge_targets {
193 if index.files.contains_key(target) {
194 mapping.insert(target.clone(), target.clone());
195 continue;
196 }
197
198 if let Some(resolved) = resolve_module_to_file(target, &file_paths) {
199 mapping.insert(target.clone(), resolved);
200 }
201 }
202
203 mapping
204}
205
/// Resolves a Rust module path (e.g. `crate::auth::login`) to a file path in
/// `file_paths`, trying the longest module prefix first (`auth/login.rs` or
/// `auth/login/mod.rs`), then falling back to the final segment alone.
///
/// Returns `None` when no candidate matches.
fn resolve_module_to_file(module_path: &str, file_paths: &[&str]) -> Option<String> {
    // `trim_start_matches` removes *repeated* occurrences, so chained
    // `super::super::` prefixes are stripped too.
    let cleaned = module_path
        .trim_start_matches("crate::")
        .trim_start_matches("super::");

    let parts: Vec<&str> = cleaned.split("::").collect();

    for end in (1..=parts.len()).rev() {
        let candidate = parts[..end].join("/");
        // Hoisted out of the file loop: these strings are loop-invariant per
        // candidate (the original re-allocated them for every file path).
        let file_form = format!("{candidate}.rs");
        let mod_form = format!("{candidate}/mod.rs");
        let file_suffix = format!("/{file_form}");
        let mod_suffix = format!("/{mod_form}");

        for fp in file_paths {
            let fp_normalized = fp
                .trim_start_matches("rust/src/")
                .trim_start_matches("src/");

            if fp_normalized == file_form
                || fp_normalized == mod_form
                || fp.ends_with(&file_suffix)
                || fp.ends_with(&mod_suffix)
            {
                return Some(fp.to_string());
            }
        }
    }

    // Fallback: match on the final segment alone. Require a full path
    // component (exact name or "/<name>.rs") so that e.g. "crowbar.rs" does
    // not spuriously resolve module `bar` (the original's bare `ends_with`
    // allowed that false match).
    if let Some(last) = parts.last() {
        let stem = format!("{last}.rs");
        let stem_suffix = format!("/{stem}");
        for fp in file_paths {
            if *fp == stem || fp.ends_with(&stem_suffix) {
                return Some(fp.to_string());
            }
        }
    }

    None
}
246
/// Splits a free-form task description into explicit file-path hints and
/// salient keywords.
///
/// A token counts as a file when it contains a dot and either a slash or a
/// known source-file extension. Remaining tokens of length >= 3 that are not
/// stop words become keywords (original casing preserved).
pub fn parse_task_hints(task_description: &str) -> (Vec<String>, Vec<String>) {
    let mut files = Vec::new();
    let mut keywords = Vec::new();

    const SOURCE_EXTENSIONS: [&str; 5] = [".rs", ".ts", ".py", ".go", ".js"];

    for raw in task_description.split_whitespace() {
        // Strip surrounding punctuation, keeping path-ish characters intact.
        let token = raw.trim_matches(|c: char| {
            !c.is_alphanumeric() && !matches!(c, '.' | '/' | '_' | '-')
        });

        let looks_like_path = token.contains('.')
            && (token.contains('/')
                || SOURCE_EXTENSIONS.iter().any(|ext| token.ends_with(ext)));

        if looks_like_path {
            files.push(token.to_string());
        } else if token.len() >= 3 {
            let lowered = token.to_lowercase();
            if !STOP_WORDS.contains(&lowered.as_str()) {
                keywords.push(token.to_string());
            }
        }
    }

    (files, keywords)
}

/// Common filler words (English plus a few German) excluded from keywords.
const STOP_WORDS: &[&str] = &[
    "the", "and", "for", "that", "this", "with", "from", "have", "has", "was", "are", "been",
    "not", "but", "all", "can", "had", "her", "one", "our", "out", "you", "its", "will", "each",
    "make", "like", "fix", "add", "use", "get", "set", "run", "new", "old", "should", "would",
    "could", "into", "also", "than", "them", "then", "when", "just", "only", "very", "some",
    "more", "other", "nach", "und", "die", "der", "das", "ist", "ein", "eine", "nicht", "auf",
    "mit",
];
281
282pub fn information_bottleneck_filter(
293 content: &str,
294 task_keywords: &[String],
295 budget_ratio: f64,
296) -> String {
297 let lines: Vec<&str> = content.lines().collect();
298 if lines.is_empty() {
299 return String::new();
300 }
301
302 let n = lines.len();
303 let kw_lower: Vec<String> = task_keywords.iter().map(|k| k.to_lowercase()).collect();
304 let attention = LearnedAttention::with_defaults();
305
306 let mut global_token_freq: HashMap<&str, usize> = HashMap::new();
307 for line in &lines {
308 for token in line.split_whitespace() {
309 *global_token_freq.entry(token).or_insert(0) += 1;
310 }
311 }
312 let total_unique = global_token_freq.len().max(1) as f64;
313 let total_lines = n.max(1) as f64;
314
315 let task_token_set: HashSet<String> = kw_lower
316 .iter()
317 .flat_map(|kw| kw.split(|c: char| !c.is_alphanumeric()).map(String::from))
318 .filter(|t| t.len() >= 2)
319 .collect();
320
321 let effective_ratio = if !task_token_set.is_empty() {
322 adaptive_ib_budget(content, budget_ratio)
323 } else {
324 budget_ratio
325 };
326
327 let mut scored_lines: Vec<(usize, &str, f64)> = lines
328 .iter()
329 .enumerate()
330 .map(|(i, line)| {
331 let trimmed = line.trim();
332 if trimmed.is_empty() {
333 return (i, *line, 0.05);
334 }
335
336 let line_lower = trimmed.to_lowercase();
337 let line_tokens: Vec<&str> = trimmed.split_whitespace().collect();
338 let line_token_count = line_tokens.len().max(1) as f64;
339
340 let mi_score = if task_token_set.is_empty() {
341 0.0
342 } else {
343 let line_token_set: HashSet<String> =
344 line_tokens.iter().map(|t| t.to_lowercase()).collect();
345 let overlap: f64 = line_token_set
346 .iter()
347 .filter(|t| task_token_set.iter().any(|kw| t.contains(kw.as_str())))
348 .map(|t| {
349 let freq = *global_token_freq.get(t.as_str()).unwrap_or(&1) as f64;
350 (total_lines / freq).ln().max(0.1)
351 })
352 .sum();
353 overlap / line_token_count
354 };
355
356 let keyword_hits: f64 = kw_lower
357 .iter()
358 .filter(|kw| line_lower.contains(kw.as_str()))
359 .count() as f64;
360
361 let structural = if is_error_handling(trimmed) {
362 1.5
363 } else if is_definition_line(trimmed) {
364 1.0
365 } else if is_control_flow(trimmed) {
366 0.5
367 } else if is_closing_brace(trimmed) {
368 0.15
369 } else {
370 0.3
371 };
372 let relevance = mi_score * 0.4 + keyword_hits * 0.3 + structural;
373
374 let unique_in_line = line_tokens.iter().collect::<HashSet<_>>().len() as f64;
375 let token_diversity = unique_in_line / line_token_count;
376
377 let avg_idf: f64 = if line_tokens.is_empty() {
378 0.0
379 } else {
380 line_tokens
381 .iter()
382 .map(|t| {
383 let freq = *global_token_freq.get(t).unwrap_or(&1) as f64;
384 (total_unique / freq).ln().max(0.0)
385 })
386 .sum::<f64>()
387 / line_token_count
388 };
389 let information = (token_diversity * 0.4 + (avg_idf.min(3.0) / 3.0) * 0.6).min(1.0);
390
391 let pos = i as f64 / n.max(1) as f64;
392 let attn_weight = attention.weight(pos);
393
394 let score = (relevance * 0.6 + 0.05)
395 * (information * 0.25 + 0.05)
396 * (attn_weight * 0.15 + 0.05);
397
398 (i, *line, score)
399 })
400 .collect();
401
402 let budget = ((n as f64) * effective_ratio).ceil() as usize;
403
404 scored_lines.sort_by(|a, b| b.2.partial_cmp(&a.2).unwrap_or(std::cmp::Ordering::Equal));
405
406 let selected = mmr_select(&scored_lines, budget, 0.3);
407
408 let mut output_lines: Vec<&str> = Vec::with_capacity(budget + 1);
409
410 if !kw_lower.is_empty() {
411 output_lines.push("");
412 }
413
414 for (_, line, _) in &selected {
415 output_lines.push(line);
416 }
417
418 if !kw_lower.is_empty() {
419 let summary = format!("[task: {}]", task_keywords.join(", "));
420 let mut result = summary;
421 result.push('\n');
422 result.push_str(&output_lines[1..].to_vec().join("\n"));
423 return result;
424 }
425
426 output_lines.join("\n")
427}
428
/// Maximal Marginal Relevance selection: greedily picks up to `budget`
/// candidates, trading raw score against Jaccard token similarity to the
/// already-selected set (`lambda` scales the similarity penalty).
///
/// `candidates` is assumed pre-sorted best-first; the first entry is always
/// taken. Returns selections in pick order. Ties keep the earliest candidate.
fn mmr_select<'a>(
    candidates: &[(usize, &'a str, f64)],
    budget: usize,
    lambda: f64,
) -> Vec<(usize, &'a str, f64)> {
    if candidates.is_empty() || budget == 0 {
        return Vec::new();
    }

    // Tokenize every candidate exactly once up front. The original rebuilt
    // the candidate and selected token sets on every outer iteration,
    // re-hashing O(selected x remaining) sets per pick.
    let mut remaining: Vec<((usize, &'a str, f64), HashSet<&'a str>)> = candidates
        .iter()
        .map(|&cand| (cand, cand.1.split_whitespace().collect()))
        .collect();

    let mut selected: Vec<(usize, &'a str, f64)> = Vec::with_capacity(budget);
    let mut selected_sets: Vec<HashSet<&'a str>> = Vec::with_capacity(budget);

    // Seed with the highest-scored candidate.
    let (first, first_set) = remaining.remove(0);
    selected.push(first);
    selected_sets.push(first_set);

    while selected.len() < budget && !remaining.is_empty() {
        let mut best_idx = 0;
        let mut best_mmr = f64::NEG_INFINITY;

        for (i, ((_, _, cand_score), cand_tokens)) in remaining.iter().enumerate() {
            // Token-less candidates incur no similarity penalty.
            if cand_tokens.is_empty() {
                if *cand_score > best_mmr {
                    best_mmr = *cand_score;
                    best_idx = i;
                }
                continue;
            }

            // Maximum Jaccard similarity against anything already selected.
            let max_sim = selected_sets
                .iter()
                .map(|sel_tokens| {
                    if sel_tokens.is_empty() {
                        return 0.0;
                    }
                    let inter = cand_tokens.intersection(sel_tokens).count();
                    let union = cand_tokens.union(sel_tokens).count();
                    if union == 0 {
                        0.0
                    } else {
                        inter as f64 / union as f64
                    }
                })
                .fold(0.0_f64, f64::max);

            let mmr = *cand_score - lambda * max_sim;
            if mmr > best_mmr {
                best_mmr = mmr;
                best_idx = i;
            }
        }

        let (picked, picked_set) = remaining.remove(best_idx);
        selected.push(picked);
        selected_sets.push(picked_set);
    }

    selected
}
491
/// Heuristic: does this (pre-trimmed) line look like error handling?
/// Covers Rust, Python, and JS/TS-style constructs.
fn is_error_handling(line: &str) -> bool {
    // Constructs recognized at the start of the line.
    const PREFIXES: [&str; 10] = [
        "return Err(",
        "Err(",
        "bail!(",
        "anyhow::bail!",
        "raise ",
        "throw ",
        "catch ",
        "except ",
        "try ",
        "panic!(",
    ];
    // Constructs recognized anywhere in the line.
    const FRAGMENTS: [&str; 5] = [".map_err(", "unwrap()", "expect(\"", "Error::", "error!"];

    if PREFIXES.iter().any(|p| line.starts_with(p)) {
        return true;
    }
    if FRAGMENTS.iter().any(|f| line.contains(f)) {
        return true;
    }
    // The `?` operator counts, unless the line is a `//` comment.
    line.contains("?;") && !line.starts_with("//")
}
510
/// Adapts the filtering budget to how repetitive `content` is: highly
/// repetitive token streams get a smaller budget (down to 0.2), while diverse
/// content keeps close to `base_ratio`. Files under 10 lines are never
/// filtered (budget 1.0).
pub fn adaptive_ib_budget(content: &str, base_ratio: f64) -> f64 {
    let lines: Vec<&str> = content.lines().collect();
    if lines.len() < 10 {
        return 1.0;
    }

    // Only the unique-token count and the total matter, so a set suffices.
    let mut vocabulary: HashSet<&str> = HashSet::new();
    let mut total_tokens = 0usize;
    for token in lines.iter().flat_map(|line| line.split_whitespace()) {
        vocabulary.insert(token);
        total_tokens += 1;
    }

    if total_tokens == 0 {
        return base_ratio;
    }

    let unique_ratio = vocabulary.len() as f64 / total_tokens as f64;
    let repetition_factor = 1.0 - unique_ratio;
    (base_ratio * (1.0 - repetition_factor * 0.3)).clamp(0.2, 1.0)
}
538
/// Heuristic: does this line start a definition (fn/struct/class/def/...)?
/// Covers Rust, JS/TS, Python, and Go keywords, with or without indentation.
fn is_definition_line(line: &str) -> bool {
    const PREFIXES: &[&str] = &[
        "fn ",
        "pub fn ",
        "async fn ",
        "pub async fn ",
        "struct ",
        "pub struct ",
        "enum ",
        "pub enum ",
        "trait ",
        "pub trait ",
        "impl ",
        "type ",
        "pub type ",
        "const ",
        "pub const ",
        "static ",
        "pub static ",
        "class ",
        "export class ",
        "interface ",
        "export interface ",
        "function ",
        "export function ",
        "async function ",
        "def ",
        "async def ",
        "func ",
    ];
    // A single trimmed check suffices: `trim_start` is the identity on an
    // already un-indented line, so the original's extra
    // `line.starts_with(p)` test per prefix was redundant work.
    let trimmed = line.trim_start();
    PREFIXES.iter().any(|p| trimmed.starts_with(p))
}
573
/// Heuristic: does this line begin a control-flow construct?
/// `break`/`continue`/`yield` match without a trailing space so bare
/// statements like `break;` are recognized.
fn is_control_flow(line: &str) -> bool {
    const KEYWORDS: &[&str] = &[
        "if ", "else ", "match ", "for ", "while ",
        "return ", "break", "continue", "yield", "await ",
    ];
    let trimmed = line.trim();
    KEYWORDS.iter().any(|kw| trimmed.starts_with(kw))
}
587
/// True when the line is nothing but a closing delimiter
/// (`}`, `};`, `})`, or `});`) after trimming.
fn is_closing_brace(line: &str) -> bool {
    matches!(line.trim(), "}" | "};" | "})" | "});")
}
592
#[cfg(test)]
mod tests {
    use super::*;

    // Task parsing separates explicit file paths from keywords and drops
    // stop words ("Fix", "the", "and", ...).
    #[test]
    fn parse_task_finds_files_and_keywords() {
        let (files, keywords) =
            parse_task_hints("Fix the authentication bug in src/auth.rs and update tests");
        assert!(files.iter().any(|f| f.contains("auth.rs")));
        assert!(keywords
            .iter()
            .any(|k| k.to_lowercase().contains("authentication")));
    }

    // Mode thresholds: >= 0.8 full, >= 0.5 signatures, >= 0.2 map,
    // otherwise reference.
    #[test]
    fn recommend_mode_by_score() {
        assert_eq!(recommend_mode(1.0), "full");
        assert_eq!(recommend_mode(0.6), "signatures");
        assert_eq!(recommend_mode(0.3), "map");
        assert_eq!(recommend_mode(0.1), "reference");
    }

    // Definition lines carry a high structural prior, so they must survive
    // filtering; a keyword input also triggers the "[task: ...]" header.
    #[test]
    fn info_bottleneck_preserves_definitions() {
        let content = "fn main() {\n let x = 42;\n // boring comment\n println!(x);\n}\n";
        let result = information_bottleneck_filter(content, &["main".to_string()], 0.6);
        assert!(result.contains("fn main"), "definitions must be preserved");
        assert!(result.contains("[task: main]"), "should have task summary");
    }

    // Error-handling lines receive the highest structural prior, so they
    // should outlive plain statements at a 0.5 budget.
    #[test]
    fn info_bottleneck_error_handling_priority() {
        let content = "fn validate() {\n let data = parse()?;\n return Err(\"invalid\");\n let x = 1;\n let y = 2;\n}\n";
        let result = information_bottleneck_filter(content, &["validate".to_string()], 0.5);
        assert!(
            result.contains("return Err"),
            "error handling should survive filtering"
        );
    }

    // Output is emitted in descending score order, so a definition must
    // precede a low-priority closing brace when both are kept.
    #[test]
    fn info_bottleneck_score_sorted() {
        let content = "fn important() {\n let x = 1;\n let y = 2;\n let z = 3;\n}\n}\n";
        let result = information_bottleneck_filter(content, &[], 0.6);
        let lines: Vec<&str> = result.lines().collect();
        let def_pos = lines.iter().position(|l| l.contains("fn important"));
        let brace_pos = lines.iter().position(|l| l.trim() == "}");
        if let (Some(d), Some(b)) = (def_pos, brace_pos) {
            assert!(
                d < b,
                "definitions should appear before closing braces in score-sorted output"
            );
        }
    }

    // Highly repetitive token streams should be compressed more aggressively
    // than token-diverse content at the same base ratio.
    #[test]
    fn adaptive_budget_reduces_for_repetitive() {
        let repetitive = "let x = 1;\n".repeat(50);
        let diverse = (0..50)
            .map(|i| format!("let var_{i} = func_{i}(arg_{i});"))
            .collect::<Vec<_>>()
            .join("\n");
        let budget_rep = super::adaptive_ib_budget(&repetitive, 0.7);
        let budget_div = super::adaptive_ib_budget(&diverse, 0.7);
        assert!(
            budget_rep < budget_div,
            "repetitive content should get lower budget"
        );
    }
}