1use std::collections::{HashMap, HashSet};
2
3use super::graph_index::ProjectIndex;
4
5use super::neural::attention_learned::LearnedAttention;
6
/// Per-file relevance verdict produced by `compute_relevance`.
#[derive(Debug, Clone)]
pub struct RelevanceScore {
    /// Project-relative file path (a key of `ProjectIndex::files`).
    pub path: String,
    /// Combined relevance in [0, 1]: 80% heat diffusion + 20% PageRank,
    /// plus a capped additive keyword boost.
    pub score: f64,
    /// Suggested context-rendering mode: "full", "signatures", "map",
    /// or "reference" (see `recommend_mode`).
    pub recommended_mode: &'static str,
}
13
/// Score every file in `index` for relevance to a task described by seed
/// files (`task_files`) and keywords (`task_keywords`).
///
/// Pipeline:
/// 1. Heat diffusion — seed files get heat 1.0, which spreads over the
///    resolved, undirected import graph for a few iterations.
/// 2. PageRank — a short power iteration estimates global file centrality.
/// 3. Both signals are max-normalized and blended (80% heat, 20% PageRank);
///    files with a combined score > 0.01 are kept.
/// 4. Keyword hits in a file's path or exports add a capped additive boost.
///
/// Returns scores sorted descending, each with a recommended context mode.
pub fn compute_relevance(
    index: &ProjectIndex,
    task_files: &[String],
    task_keywords: &[String],
) -> Vec<RelevanceScore> {
    let adj = build_adjacency_resolved(index);
    let all_nodes: Vec<String> = index.files.keys().cloned().collect();
    if all_nodes.is_empty() {
        return Vec::new();
    }

    // Stable node -> dense-index mapping for the score vectors below.
    let node_idx: HashMap<&str, usize> = all_nodes
        .iter()
        .enumerate()
        .map(|(i, n)| (n.as_str(), i))
        .collect();
    let n = all_nodes.len();

    // Per-node degree, floored at 1.0 so the heat division below is safe.
    let degrees: Vec<f64> = all_nodes
        .iter()
        .map(|node| {
            adj.get(node)
                .map_or(0.0, |neigh| neigh.len() as f64)
                .max(1.0)
        })
        .collect();

    // Seed heat at the task's focus files (silently skips unknown paths).
    let mut heat: Vec<f64> = vec![0.0; n];
    for f in task_files {
        if let Some(&idx) = node_idx.get(f.as_str()) {
            heat[idx] = 1.0;
        }
    }

    // Heat diffusion: each step keeps (1 - alpha) of a node's own heat and
    // pulls alpha * (degree-normalized) heat from its neighbors.
    let alpha = 0.5;
    let iterations = 4;
    for _ in 0..iterations {
        let mut new_heat = vec![0.0; n];
        for (i, node) in all_nodes.iter().enumerate() {
            let self_term = (1.0 - alpha) * heat[i];
            let mut neighbor_sum = 0.0;
            if let Some(neighbors) = adj.get(node) {
                for neighbor in neighbors {
                    if let Some(&j) = node_idx.get(neighbor.as_str()) {
                        neighbor_sum += heat[j] / degrees[j];
                    }
                }
            }
            new_heat[i] = self_term + alpha * neighbor_sum;
        }
        heat = new_heat;
    }

    // Short PageRank power iteration (8 rounds, damping 0.85).
    // NOTE(review): rank from dangling nodes simply leaks; fine for a
    // relative ranking, but this is not a normalized PageRank.
    let mut pagerank = vec![1.0 / n as f64; n];
    let damping = 0.85;
    for _ in 0..8 {
        let mut new_pr = vec![(1.0 - damping) / n as f64; n];
        for (i, node) in all_nodes.iter().enumerate() {
            if let Some(neighbors) = adj.get(node) {
                let out_deg = neighbors.len().max(1) as f64;
                for neighbor in neighbors {
                    if let Some(&j) = node_idx.get(neighbor.as_str()) {
                        new_pr[j] += damping * pagerank[i] / out_deg;
                    }
                }
            }
        }
        pagerank = new_pr;
    }

    // Max-normalize both signals (epsilon guards all-zero vectors) and keep
    // only files with a non-negligible combined score.
    let mut scores: HashMap<String, f64> = HashMap::new();
    let heat_max = heat.iter().cloned().fold(0.0_f64, f64::max).max(1e-10);
    let pr_max = pagerank.iter().cloned().fold(0.0_f64, f64::max).max(1e-10);

    for (i, node) in all_nodes.iter().enumerate() {
        let h = heat[i] / heat_max;
        let pr = pagerank[i] / pr_max;
        let combined = h * 0.8 + pr * 0.2;
        if combined > 0.01 {
            scores.insert(node.clone(), combined);
        }
    }

    // Additive keyword boost: 0.15 per hit in the path or exports, capped
    // at 0.6 per file; the total score is capped at 1.0.
    if !task_keywords.is_empty() {
        let kw_lower: Vec<String> = task_keywords.iter().map(|k| k.to_lowercase()).collect();
        for (file_path, file_entry) in &index.files {
            let path_lower = file_path.to_lowercase();
            let mut keyword_hits = 0;
            for kw in &kw_lower {
                if path_lower.contains(kw) {
                    keyword_hits += 1;
                }
                for export in &file_entry.exports {
                    if export.to_lowercase().contains(kw) {
                        keyword_hits += 1;
                    }
                }
            }
            if keyword_hits > 0 {
                let boost = (keyword_hits as f64 * 0.15).min(0.6);
                let entry = scores.entry(file_path.clone()).or_insert(0.0);
                *entry = (*entry + boost).min(1.0);
            }
        }
    }

    // Attach a recommended context mode and sort by score, highest first.
    let mut result: Vec<RelevanceScore> = scores
        .into_iter()
        .map(|(path, score)| {
            let mode = recommend_mode(score);
            RelevanceScore {
                path,
                score,
                recommended_mode: mode,
            }
        })
        .collect();

    result.sort_by(|a, b| {
        b.score
            .partial_cmp(&a.score)
            .unwrap_or(std::cmp::Ordering::Equal)
    });
    result
}
146
/// Map a relevance score onto a context-rendering mode string.
///
/// Thresholds: >= 0.8 "full", >= 0.5 "signatures", >= 0.2 "map",
/// otherwise "reference".
fn recommend_mode(score: f64) -> &'static str {
    match score {
        s if s >= 0.8 => "full",
        s if s >= 0.5 => "signatures",
        s if s >= 0.2 => "map",
        _ => "reference",
    }
}
158
159fn build_adjacency_resolved(index: &ProjectIndex) -> HashMap<String, Vec<String>> {
164 let module_to_file = build_module_map(index);
165 let mut adj: HashMap<String, Vec<String>> = HashMap::new();
166
167 for edge in &index.edges {
168 let from = &edge.from;
169 let to_resolved = module_to_file
170 .get(&edge.to)
171 .cloned()
172 .unwrap_or_else(|| edge.to.clone());
173
174 if index.files.contains_key(from) && index.files.contains_key(&to_resolved) {
175 adj.entry(from.clone())
176 .or_default()
177 .push(to_resolved.clone());
178 adj.entry(to_resolved).or_default().push(from.clone());
179 }
180 }
181 adj
182}
183
184fn build_module_map(index: &ProjectIndex) -> HashMap<String, String> {
187 let file_paths: Vec<&str> = index.files.keys().map(|s| s.as_str()).collect();
188 let mut mapping: HashMap<String, String> = HashMap::new();
189
190 let edge_targets: HashSet<String> = index.edges.iter().map(|e| e.to.clone()).collect();
191
192 for target in &edge_targets {
193 if index.files.contains_key(target) {
194 mapping.insert(target.clone(), target.clone());
195 continue;
196 }
197
198 if let Some(resolved) = resolve_module_to_file(target, &file_paths) {
199 mapping.insert(target.clone(), resolved);
200 }
201 }
202
203 mapping
204}
205
/// Best-effort resolution of a Rust module path (e.g. `crate::auth::login`)
/// to one of `file_paths`.
///
/// Tries progressively shorter segment prefixes as `<path>.rs` or
/// `<path>/mod.rs` (both against `rust/src/`- / `src/`-stripped paths and
/// as suffixes), then falls back to matching the final segment as a stem.
fn resolve_module_to_file(module_path: &str, file_paths: &[&str]) -> Option<String> {
    let cleaned = module_path
        .trim_start_matches("crate::")
        .trim_start_matches("super::");
    let segments: Vec<&str> = cleaned.split("::").collect();

    // Longest prefix first, shrinking until some file matches.
    for take in (1..=segments.len()).rev() {
        let rel = segments[..take].join("/");
        let as_file = format!("{rel}.rs");
        let as_mod = format!("{rel}/mod.rs");
        let suffix_file = format!("/{rel}.rs");
        let suffix_mod = format!("/{rel}/mod.rs");

        for &fp in file_paths {
            let norm = fp.trim_start_matches("rust/src/").trim_start_matches("src/");
            if norm == as_file
                || norm == as_mod
                || fp.ends_with(&suffix_file)
                || fp.ends_with(&suffix_mod)
            {
                return Some(fp.to_string());
            }
        }
    }

    // Fallback: match just the final segment as a file stem.
    let last = segments.last()?;
    let stem = format!("{last}.rs");
    file_paths
        .iter()
        .find(|fp| fp.ends_with(&stem))
        .map(|fp| fp.to_string())
}
246
/// Split a free-form task description into `(files, keywords)`.
///
/// A token counts as a file when it contains a dot and either a slash or a
/// known source-file extension. Remaining tokens of length >= 3 that are not
/// stop words become keywords. Surrounding punctuation is trimmed first.
pub fn parse_task_hints(task_description: &str) -> (Vec<String>, Vec<String>) {
    let mut files = Vec::new();
    let mut keywords = Vec::new();

    // Characters allowed inside a path-like token.
    let is_path_char = |c: char| c.is_alphanumeric() || matches!(c, '.' | '/' | '_' | '-');

    for raw in task_description.split_whitespace() {
        let token = raw.trim_matches(|c: char| !is_path_char(c));

        let has_code_ext = [".rs", ".ts", ".py", ".go", ".js"]
            .iter()
            .any(|ext| token.ends_with(ext));
        let looks_like_file = token.contains('.') && (token.contains('/') || has_code_ext);

        if looks_like_file {
            files.push(token.to_string());
        } else if token.len() >= 3 && !STOP_WORDS.contains(&token.to_lowercase().as_str()) {
            keywords.push(token.to_string());
        }
    }

    (files, keywords)
}

/// Common English (plus a few German) filler words ignored during keyword
/// extraction.
const STOP_WORDS: &[&str] = &[
    "the", "and", "for", "that", "this", "with", "from", "have", "has", "was", "are", "been",
    "not", "but", "all", "can", "had", "her", "one", "our", "out", "you", "its", "will", "each",
    "make", "like", "fix", "add", "use", "get", "set", "run", "new", "old", "should", "would",
    "could", "into", "also", "than", "them", "then", "when", "just", "only", "very", "some",
    "more", "other", "nach", "und", "die", "der", "das", "ist", "ein", "eine", "nicht", "auf",
    "mit",
];
281
282pub fn information_bottleneck_filter(
293 content: &str,
294 task_keywords: &[String],
295 budget_ratio: f64,
296) -> String {
297 let lines: Vec<&str> = content.lines().collect();
298 if lines.is_empty() {
299 return String::new();
300 }
301
302 let n = lines.len();
303 let kw_lower: Vec<String> = task_keywords.iter().map(|k| k.to_lowercase()).collect();
304 let attention = LearnedAttention::with_defaults();
305
306 let mut global_token_freq: HashMap<&str, usize> = HashMap::new();
307 for line in &lines {
308 for token in line.split_whitespace() {
309 *global_token_freq.entry(token).or_insert(0) += 1;
310 }
311 }
312 let total_unique = global_token_freq.len().max(1) as f64;
313
314 let mut scored_lines: Vec<(usize, &str, f64)> = lines
315 .iter()
316 .enumerate()
317 .map(|(i, line)| {
318 let trimmed = line.trim();
319 if trimmed.is_empty() {
320 return (i, *line, 0.05);
321 }
322
323 let line_lower = trimmed.to_lowercase();
324 let keyword_hits: f64 = kw_lower
325 .iter()
326 .filter(|kw| line_lower.contains(kw.as_str()))
327 .count() as f64;
328
329 let structural = if is_error_handling(trimmed) {
330 1.5
331 } else if is_definition_line(trimmed) {
332 1.0
333 } else if is_control_flow(trimmed) {
334 0.5
335 } else if is_closing_brace(trimmed) {
336 0.15
337 } else {
338 0.3
339 };
340 let relevance = keyword_hits * 0.5 + structural;
341
342 let line_tokens: Vec<&str> = trimmed.split_whitespace().collect();
343 let unique_in_line = line_tokens.iter().collect::<HashSet<_>>().len() as f64;
344 let line_token_count = line_tokens.len().max(1) as f64;
345 let token_diversity = unique_in_line / line_token_count;
346
347 let avg_idf: f64 = if line_tokens.is_empty() {
348 0.0
349 } else {
350 line_tokens
351 .iter()
352 .map(|t| {
353 let freq = *global_token_freq.get(t).unwrap_or(&1) as f64;
354 (total_unique / freq).ln().max(0.0)
355 })
356 .sum::<f64>()
357 / line_token_count
358 };
359 let information = (token_diversity * 0.4 + (avg_idf.min(3.0) / 3.0) * 0.6).min(1.0);
360
361 let pos = i as f64 / n.max(1) as f64;
362 let attn_weight = attention.weight(pos);
363
364 let score = (relevance * 0.6 + 0.05)
365 * (information * 0.25 + 0.05)
366 * (attn_weight * 0.15 + 0.05);
367
368 (i, *line, score)
369 })
370 .collect();
371
372 let budget = ((n as f64) * budget_ratio).ceil() as usize;
373
374 scored_lines.sort_by(|a, b| b.2.partial_cmp(&a.2).unwrap_or(std::cmp::Ordering::Equal));
375
376 let selected = mmr_select(&scored_lines, budget, 0.3);
379
380 let mut output_lines: Vec<&str> = Vec::with_capacity(budget + 1);
381
382 if !kw_lower.is_empty() {
383 output_lines.push(""); }
385
386 for (_, line, _) in &selected {
387 output_lines.push(line);
388 }
389
390 if !kw_lower.is_empty() {
391 let summary = format!("[task: {}]", task_keywords.join(", "));
392 let mut result = summary;
393 result.push('\n');
394 result.push_str(&output_lines[1..].to_vec().join("\n"));
395 return result;
396 }
397
398 output_lines.join("\n")
399}
400
/// Greedy maximal-marginal-relevance selection over score-sorted candidates.
///
/// `lambda` penalizes Jaccard token overlap with already-picked lines,
/// trading raw score against diversity. Assumes `candidates` is sorted by
/// score descending; the top candidate is always taken first.
fn mmr_select<'a>(
    candidates: &[(usize, &'a str, f64)],
    budget: usize,
    lambda: f64,
) -> Vec<(usize, &'a str, f64)> {
    if candidates.is_empty() || budget == 0 {
        return Vec::new();
    }

    let mut pool: Vec<(usize, &'a str, f64)> = candidates.to_vec();
    let mut picked: Vec<(usize, &'a str, f64)> = Vec::with_capacity(budget);
    // Seed with the highest-scored candidate.
    picked.push(pool.remove(0));

    while picked.len() < budget && !pool.is_empty() {
        let mut winner = 0;
        let mut winner_mmr = f64::NEG_INFINITY;

        for (idx, &(_, text, score)) in pool.iter().enumerate() {
            let tokens: HashSet<&str> = text.split_whitespace().collect();

            let mmr = if tokens.is_empty() {
                // Token-free line: overlap is zero by construction, so the
                // marginal score is just the raw score.
                score
            } else {
                let max_sim = picked
                    .iter()
                    .map(|&(_, chosen, _)| token_jaccard(&tokens, chosen))
                    .fold(0.0_f64, f64::max);
                score - lambda * max_sim
            };

            if mmr > winner_mmr {
                winner_mmr = mmr;
                winner = idx;
            }
        }

        picked.push(pool.remove(winner));
    }

    picked
}

/// Jaccard similarity between a prepared token set and the whitespace
/// tokens of `other` (0.0 when `other` has no tokens).
fn token_jaccard(a: &HashSet<&str>, other: &str) -> f64 {
    let b: HashSet<&str> = other.split_whitespace().collect();
    if b.is_empty() {
        return 0.0;
    }
    let inter = a.intersection(&b).count();
    let union = a.union(&b).count();
    if union == 0 {
        0.0
    } else {
        inter as f64 / union as f64
    }
}
463
/// Heuristic: does this (already-trimmed) line look like error handling?
/// Covers Rust (`Err`, `bail!`, `?;`, `panic!`), Python (`raise`/`except`),
/// and JS/TS (`throw`/`catch`/`try`).
fn is_error_handling(line: &str) -> bool {
    const PREFIXES: &[&str] = &[
        "return Err(",
        "Err(",
        "bail!(",
        "anyhow::bail!",
        "raise ",
        "throw ",
        "catch ",
        "except ",
        "try ",
        "panic!(",
    ];
    const SUBSTRINGS: &[&str] = &[".map_err(", "unwrap()", "expect(\"", "Error::", "error!"];

    PREFIXES.iter().any(|p| line.starts_with(p))
        || SUBSTRINGS.iter().any(|s| line.contains(s))
        // `?;` counts unless the line is a Rust line comment.
        || (line.contains("?;") && !line.starts_with("//"))
}
482
/// Adapt the filtering budget to how repetitive `content` is.
///
/// Files shorter than 10 lines are kept whole (ratio 1.0). Otherwise the
/// base ratio is reduced by up to 30%, proportionally to token repetition
/// (1 - unique/total), and clamped to [0.2, 1.0]. Token-free content
/// returns `base_ratio` unchanged.
pub fn adaptive_ib_budget(content: &str, base_ratio: f64) -> f64 {
    let lines: Vec<&str> = content.lines().collect();
    if lines.len() < 10 {
        return 1.0;
    }

    let mut seen: HashMap<&str, usize> = HashMap::new();
    let mut total = 0usize;
    for token in lines.iter().flat_map(|l| l.split_whitespace()) {
        *seen.entry(token).or_insert(0) += 1;
        total += 1;
    }

    if total == 0 {
        return base_ratio;
    }

    let unique_ratio = seen.len() as f64 / total as f64;
    let repetition = 1.0 - unique_ratio;
    (base_ratio * (1.0 - repetition * 0.3)).clamp(0.2, 1.0)
}
510
/// Heuristic: does this line start a definition (function, type, class, …)
/// in Rust, TypeScript/JavaScript, Python, or Go?
fn is_definition_line(line: &str) -> bool {
    const PREFIXES: &[&str] = &[
        "fn ",
        "pub fn ",
        "async fn ",
        "pub async fn ",
        "struct ",
        "pub struct ",
        "enum ",
        "pub enum ",
        "trait ",
        "pub trait ",
        "impl ",
        "type ",
        "pub type ",
        "const ",
        "pub const ",
        "static ",
        "pub static ",
        "class ",
        "export class ",
        "interface ",
        "export interface ",
        "function ",
        "export function ",
        "async function ",
        "def ",
        "async def ",
        "func ",
    ];
    // Hoisted out of the closure: the old code called `trim_start()` once
    // per prefix (27x per line), and its untrimmed `starts_with` check was
    // fully subsumed by the trimmed one (no prefix starts with whitespace).
    let trimmed = line.trim_start();
    PREFIXES.iter().any(|p| trimmed.starts_with(p))
}
545
/// Heuristic: does this line begin a control-flow construct?
fn is_control_flow(line: &str) -> bool {
    let t = line.trim();
    // Keywords that must be followed by a space...
    const SPACED: &[&str] = &["if ", "else ", "match ", "for ", "while ", "return ", "await "];
    // ...and ones that may stand alone (e.g. `break;`, bare `yield`).
    const BARE: &[&str] = &["break", "continue", "yield"];

    SPACED.iter().any(|k| t.starts_with(k)) || BARE.iter().any(|k| t.starts_with(k))
}
559
/// True for lines that are nothing but a closing brace, optionally followed
/// by `;` or `)` (`}`, `};`, `})`, `});`).
fn is_closing_brace(line: &str) -> bool {
    matches!(line.trim(), "}" | "};" | "})" | "});")
}
564
#[cfg(test)]
mod tests {
    use super::*;

    // Hint parsing must separate path-like tokens from non-stop-word
    // keywords.
    #[test]
    fn parse_task_finds_files_and_keywords() {
        let (files, keywords) =
            parse_task_hints("Fix the authentication bug in src/auth.rs and update tests");
        assert!(files.iter().any(|f| f.contains("auth.rs")));
        assert!(keywords
            .iter()
            .any(|k| k.to_lowercase().contains("authentication")));
    }

    // Mode thresholds: >= 0.8 full, >= 0.5 signatures, >= 0.2 map,
    // else reference.
    #[test]
    fn recommend_mode_by_score() {
        assert_eq!(recommend_mode(1.0), "full");
        assert_eq!(recommend_mode(0.6), "signatures");
        assert_eq!(recommend_mode(0.3), "map");
        assert_eq!(recommend_mode(0.1), "reference");
    }

    // Definition lines carry a high structural prior, so they should
    // survive the filter, and the keyword summary must be prepended.
    #[test]
    fn info_bottleneck_preserves_definitions() {
        let content = "fn main() {\n let x = 42;\n // boring comment\n println!(x);\n}\n";
        let result = information_bottleneck_filter(content, &["main".to_string()], 0.6);
        assert!(result.contains("fn main"), "definitions must be preserved");
        assert!(result.contains("[task: main]"), "should have task summary");
    }

    // Error-handling lines carry the highest structural prior (1.5).
    #[test]
    fn info_bottleneck_error_handling_priority() {
        let content = "fn validate() {\n let data = parse()?;\n return Err(\"invalid\");\n let x = 1;\n let y = 2;\n}\n";
        let result = information_bottleneck_filter(content, &["validate".to_string()], 0.5);
        assert!(
            result.contains("return Err"),
            "error handling should survive filtering"
        );
    }

    // Output is ordered by score (not source order), so a definition must
    // precede a closing brace when both are kept.
    #[test]
    fn info_bottleneck_score_sorted() {
        let content = "fn important() {\n let x = 1;\n let y = 2;\n let z = 3;\n}\n}\n";
        let result = information_bottleneck_filter(content, &[], 0.6);
        let lines: Vec<&str> = result.lines().collect();
        let def_pos = lines.iter().position(|l| l.contains("fn important"));
        let brace_pos = lines.iter().position(|l| l.trim() == "}");
        if let (Some(d), Some(b)) = (def_pos, brace_pos) {
            assert!(
                d < b,
                "definitions should appear before closing braces in score-sorted output"
            );
        }
    }

    // Higher token repetition should shrink the budget relative to
    // token-diverse content of the same size.
    #[test]
    fn adaptive_budget_reduces_for_repetitive() {
        let repetitive = "let x = 1;\n".repeat(50);
        let diverse = (0..50)
            .map(|i| format!("let var_{i} = func_{i}(arg_{i});"))
            .collect::<Vec<_>>()
            .join("\n");
        let budget_rep = super::adaptive_ib_budget(&repetitive, 0.7);
        let budget_div = super::adaptive_ib_budget(&diverse, 0.7);
        assert!(
            budget_rep < budget_div,
            "repetitive content should get lower budget"
        );
    }
}