Skip to main content

lean_ctx/core/
task_relevance.rs

1use std::collections::{HashMap, HashSet};
2
3use super::graph_index::ProjectIndex;
4
5use super::neural::attention_learned::LearnedAttention;
6
/// Relevance verdict for a single file of the project index.
#[derive(Clone, Debug)]
pub struct RelevanceScore {
    /// Path of the file, as used as a key in the project index.
    pub path: String,
    /// Relevance score; larger means more relevant to the task.
    pub score: f64,
    /// Suggested read mode for the file: `"full"`, `"signatures"`, `"map"`
    /// or `"reference"`.
    pub recommended_mode: &'static str,
}
13
14pub fn compute_relevance(
15    index: &ProjectIndex,
16    task_files: &[String],
17    task_keywords: &[String],
18) -> Vec<RelevanceScore> {
19    let adj = build_adjacency_resolved(index);
20    let all_nodes: Vec<String> = index.files.keys().cloned().collect();
21    if all_nodes.is_empty() {
22        return Vec::new();
23    }
24
25    let node_idx: HashMap<&str, usize> = all_nodes
26        .iter()
27        .enumerate()
28        .map(|(i, n)| (n.as_str(), i))
29        .collect();
30    let n = all_nodes.len();
31
32    // Build degree-normalized adjacency for heat diffusion
33    let degrees: Vec<f64> = all_nodes
34        .iter()
35        .map(|node| {
36            adj.get(node)
37                .map_or(0.0, |neigh| neigh.len() as f64)
38                .max(1.0)
39        })
40        .collect();
41
42    // Seed vector: task files get 1.0
43    let mut heat: Vec<f64> = vec![0.0; n];
44    for f in task_files {
45        if let Some(&idx) = node_idx.get(f.as_str()) {
46            heat[idx] = 1.0;
47        }
48    }
49
50    // Heat diffusion: h(t+1) = (1-alpha)*h(t) + alpha * A_norm * h(t)
51    // Run for k iterations
52    let alpha = 0.5;
53    let iterations = 4;
54    for _ in 0..iterations {
55        let mut new_heat = vec![0.0; n];
56        for (i, node) in all_nodes.iter().enumerate() {
57            let self_term = (1.0 - alpha) * heat[i];
58            let mut neighbor_sum = 0.0;
59            if let Some(neighbors) = adj.get(node) {
60                for neighbor in neighbors {
61                    if let Some(&j) = node_idx.get(neighbor.as_str()) {
62                        neighbor_sum += heat[j] / degrees[j];
63                    }
64                }
65            }
66            new_heat[i] = self_term + alpha * neighbor_sum;
67        }
68        heat = new_heat;
69    }
70
71    // PageRank centrality for gateway detection
72    let mut pagerank = vec![1.0 / n as f64; n];
73    let damping = 0.85;
74    for _ in 0..8 {
75        let mut new_pr = vec![(1.0 - damping) / n as f64; n];
76        for (i, node) in all_nodes.iter().enumerate() {
77            if let Some(neighbors) = adj.get(node) {
78                let out_deg = neighbors.len().max(1) as f64;
79                for neighbor in neighbors {
80                    if let Some(&j) = node_idx.get(neighbor.as_str()) {
81                        new_pr[j] += damping * pagerank[i] / out_deg;
82                    }
83                }
84            }
85        }
86        pagerank = new_pr;
87    }
88
89    // Combine: heat (primary) + pagerank centrality (gateway bonus)
90    let mut scores: HashMap<String, f64> = HashMap::new();
91    let heat_max = heat.iter().cloned().fold(0.0_f64, f64::max).max(1e-10);
92    let pr_max = pagerank.iter().cloned().fold(0.0_f64, f64::max).max(1e-10);
93
94    for (i, node) in all_nodes.iter().enumerate() {
95        let h = heat[i] / heat_max;
96        let pr = pagerank[i] / pr_max;
97        let combined = h * 0.8 + pr * 0.2;
98        if combined > 0.01 {
99            scores.insert(node.clone(), combined);
100        }
101    }
102
103    // Keyword boost
104    if !task_keywords.is_empty() {
105        let kw_lower: Vec<String> = task_keywords.iter().map(|k| k.to_lowercase()).collect();
106        for (file_path, file_entry) in &index.files {
107            let path_lower = file_path.to_lowercase();
108            let mut keyword_hits = 0;
109            for kw in &kw_lower {
110                if path_lower.contains(kw) {
111                    keyword_hits += 1;
112                }
113                for export in &file_entry.exports {
114                    if export.to_lowercase().contains(kw) {
115                        keyword_hits += 1;
116                    }
117                }
118            }
119            if keyword_hits > 0 {
120                let boost = (keyword_hits as f64 * 0.15).min(0.6);
121                let entry = scores.entry(file_path.clone()).or_insert(0.0);
122                *entry = (*entry + boost).min(1.0);
123            }
124        }
125    }
126
127    let mut result: Vec<RelevanceScore> = scores
128        .into_iter()
129        .map(|(path, score)| {
130            let mode = recommend_mode(score);
131            RelevanceScore {
132                path,
133                score,
134                recommended_mode: mode,
135            }
136        })
137        .collect();
138
139    result.sort_by(|a, b| {
140        b.score
141            .partial_cmp(&a.score)
142            .unwrap_or(std::cmp::Ordering::Equal)
143    });
144    result
145}
146
/// Translate a relevance score into the read mode recommended for the file.
///
/// Thresholds: `>= 0.8` → `"full"`, `>= 0.5` → `"signatures"`,
/// `>= 0.2` → `"map"`, anything else (including NaN) → `"reference"`.
fn recommend_mode(score: f64) -> &'static str {
    match score {
        s if s >= 0.8 => "full",
        s if s >= 0.5 => "signatures",
        s if s >= 0.2 => "map",
        _ => "reference",
    }
}
158
159/// Build adjacency with module-path → file-path resolution.
160/// Graph edges store file paths as `from` and Rust module paths as `to`
161/// (e.g. `crate::core::tokens::count_tokens`). We resolve `to` back to file
162/// paths so heat diffusion and PageRank can propagate across the graph.
163fn build_adjacency_resolved(index: &ProjectIndex) -> HashMap<String, Vec<String>> {
164    let module_to_file = build_module_map(index);
165    let mut adj: HashMap<String, Vec<String>> = HashMap::new();
166
167    for edge in &index.edges {
168        let from = &edge.from;
169        let to_resolved = module_to_file
170            .get(&edge.to)
171            .cloned()
172            .unwrap_or_else(|| edge.to.clone());
173
174        if index.files.contains_key(from) && index.files.contains_key(&to_resolved) {
175            adj.entry(from.clone())
176                .or_default()
177                .push(to_resolved.clone());
178            adj.entry(to_resolved).or_default().push(from.clone());
179        }
180    }
181    adj
182}
183
184/// Map module/import paths to file paths using heuristics.
185/// e.g. `crate::core::tokens::count_tokens` → `rust/src/core/tokens.rs`
186fn build_module_map(index: &ProjectIndex) -> HashMap<String, String> {
187    let file_paths: Vec<&str> = index.files.keys().map(|s| s.as_str()).collect();
188    let mut mapping: HashMap<String, String> = HashMap::new();
189
190    let edge_targets: HashSet<String> = index.edges.iter().map(|e| e.to.clone()).collect();
191
192    for target in &edge_targets {
193        if index.files.contains_key(target) {
194            mapping.insert(target.clone(), target.clone());
195            continue;
196        }
197
198        if let Some(resolved) = resolve_module_to_file(target, &file_paths) {
199            mapping.insert(target.clone(), resolved);
200        }
201    }
202
203    mapping
204}
205
/// Heuristically resolve a Rust module path to one of `file_paths`.
///
/// Leading `crate::` / `super::` prefixes are dropped, then progressively
/// shorter `::`-prefixes of the path are tried (longest first) as
/// `<prefix>.rs` or `<prefix>/mod.rs`, either relative to a `rust/src/` /
/// `src/` root or as a path suffix after a `/`. This strips trailing symbol
/// names, e.g. `core::tokens::count_tokens` → `core/tokens.rs`.
///
/// If no prefix matches, the last segment alone is tried as a file name
/// (`<last>.rs`). The match must be a whole path component: `index` matches
/// `foo/index.rs` but not `foo/graph_index.rs`. An empty last segment (for
/// example from `crate::`) never matches.
///
/// Returns the first matching entry of `file_paths`, or `None`.
fn resolve_module_to_file(module_path: &str, file_paths: &[&str]) -> Option<String> {
    let cleaned = module_path
        .trim_start_matches("crate::")
        .trim_start_matches("super::");

    let parts: Vec<&str> = cleaned.split("::").collect();

    // Try progressively shorter prefixes to find a matching file.
    for end in (1..=parts.len()).rev() {
        let candidate = parts[..end].join("/");

        // Build the four patterns once per candidate, not once per file.
        let as_file = format!("{candidate}.rs");
        let as_mod = format!("{candidate}/mod.rs");
        let nested_file = format!("/{as_file}");
        let nested_mod = format!("/{as_mod}");

        for fp in file_paths {
            let fp_normalized = fp
                .trim_start_matches("rust/src/")
                .trim_start_matches("src/");

            if fp_normalized == as_file
                || fp_normalized == as_mod
                || fp.ends_with(&nested_file)
                || fp.ends_with(&nested_mod)
            {
                return Some(fp.to_string());
            }
        }
    }

    // Fallback: match by last segment as file name. Require a path boundary
    // so that a short segment cannot match the tail of a longer file name.
    let last = parts.last().filter(|segment| !segment.is_empty())?;
    let stem = format!("{last}.rs");
    let nested_stem = format!("/{stem}");
    for fp in file_paths {
        if *fp == stem || fp.ends_with(&nested_stem) {
            return Some(fp.to_string());
        }
    }

    None
}
246
/// Extract likely task-relevant file paths and keywords from a task description.
///
/// The description is split on whitespace. Surrounding punctuation is removed
/// from every word, except that `.`, `/`, `_` and `-` are kept (a leading `.`
/// is kept so that `./x` and `../x` survive; trailing dots are removed so that
/// a sentence-ending period does not become part of a path or keyword).
///
/// A word is reported as a file if it contains a `.` and either contains a
/// `/` or ends in `.rs`, `.ts`, `.py`, `.go` or `.js`. Any other word of at
/// least three bytes that is not a stop word is reported as a keyword, with
/// its original casing.
///
/// Returns `(files, keywords)`, both in order of appearance; duplicates are
/// not removed.
pub fn parse_task_hints(task_description: &str) -> (Vec<String>, Vec<String>) {
    const SOURCE_EXTENSIONS: &[&str] = &[".rs", ".ts", ".py", ".go", ".js"];

    let keep = |c: char| c.is_alphanumeric() || matches!(c, '.' | '/' | '_' | '-');

    let mut files = Vec::new();
    let mut keywords = Vec::new();

    for word in task_description.split_whitespace() {
        let clean = word
            .trim_start_matches(|c: char| !keep(c))
            // Trailing dots are sentence punctuation, not part of the token.
            .trim_end_matches(|c: char| c == '.' || !keep(c));

        let looks_like_path = clean.contains('.')
            && (clean.contains('/') || SOURCE_EXTENSIONS.iter().any(|&ext| clean.ends_with(ext)));

        if looks_like_path {
            files.push(clean.to_string());
        } else if clean.len() >= 3 && !STOP_WORDS.contains(&clean.to_lowercase().as_str()) {
            keywords.push(clean.to_string());
        }
    }

    (files, keywords)
}

/// Lowercase words that are never reported as keywords: common English
/// function words and task verbs, plus a few German function words.
const STOP_WORDS: &[&str] = &[
    "the", "and", "for", "that", "this", "with", "from", "have", "has", "was", "are", "been",
    "not", "but", "all", "can", "had", "her", "one", "our", "out", "you", "its", "will", "each",
    "make", "like", "fix", "add", "use", "get", "set", "run", "new", "old", "should", "would",
    "could", "into", "also", "than", "them", "then", "when", "just", "only", "very", "some",
    "more", "other", "nach", "und", "die", "der", "das", "ist", "ein", "eine", "nicht", "auf",
    "mit",
];
281
282/// Information Bottleneck filter v2 — L-Curve aware, score-sorted output.
283///
284/// IB principle: maximize I(T;Y) (task relevance) while minimizing I(T;X) (input redundancy).
285/// Each line is scored by: relevance_to_task * information_density * attention_weight.
286///
287/// v2 changes (based on Lab Experiments A-C):
288///   - Uses empirical L-curve attention from attention_learned.rs instead of heuristic U-curve
289///   - Output is sorted by score DESC (most important first), not by line number
290///   - Error-handling lines get a priority boost (fragile under compression)
291///   - Emits a one-line task summary as the first line when keywords are present
292pub fn information_bottleneck_filter(
293    content: &str,
294    task_keywords: &[String],
295    budget_ratio: f64,
296) -> String {
297    let lines: Vec<&str> = content.lines().collect();
298    if lines.is_empty() {
299        return String::new();
300    }
301
302    let n = lines.len();
303    let kw_lower: Vec<String> = task_keywords.iter().map(|k| k.to_lowercase()).collect();
304    let attention = LearnedAttention::with_defaults();
305
306    let mut global_token_freq: HashMap<&str, usize> = HashMap::new();
307    for line in &lines {
308        for token in line.split_whitespace() {
309            *global_token_freq.entry(token).or_insert(0) += 1;
310        }
311    }
312    let total_unique = global_token_freq.len().max(1) as f64;
313
314    let mut scored_lines: Vec<(usize, &str, f64)> = lines
315        .iter()
316        .enumerate()
317        .map(|(i, line)| {
318            let trimmed = line.trim();
319            if trimmed.is_empty() {
320                return (i, *line, 0.05);
321            }
322
323            let line_lower = trimmed.to_lowercase();
324            let keyword_hits: f64 = kw_lower
325                .iter()
326                .filter(|kw| line_lower.contains(kw.as_str()))
327                .count() as f64;
328
329            let structural = if is_error_handling(trimmed) {
330                1.5
331            } else if is_definition_line(trimmed) {
332                1.0
333            } else if is_control_flow(trimmed) {
334                0.5
335            } else if is_closing_brace(trimmed) {
336                0.15
337            } else {
338                0.3
339            };
340            let relevance = keyword_hits * 0.5 + structural;
341
342            let line_tokens: Vec<&str> = trimmed.split_whitespace().collect();
343            let unique_in_line = line_tokens.iter().collect::<HashSet<_>>().len() as f64;
344            let line_token_count = line_tokens.len().max(1) as f64;
345            let token_diversity = unique_in_line / line_token_count;
346
347            let avg_idf: f64 = if line_tokens.is_empty() {
348                0.0
349            } else {
350                line_tokens
351                    .iter()
352                    .map(|t| {
353                        let freq = *global_token_freq.get(t).unwrap_or(&1) as f64;
354                        (total_unique / freq).ln().max(0.0)
355                    })
356                    .sum::<f64>()
357                    / line_token_count
358            };
359            let information = (token_diversity * 0.4 + (avg_idf.min(3.0) / 3.0) * 0.6).min(1.0);
360
361            let pos = i as f64 / n.max(1) as f64;
362            let attn_weight = attention.weight(pos);
363
364            let score = (relevance * 0.6 + 0.05)
365                * (information * 0.25 + 0.05)
366                * (attn_weight * 0.15 + 0.05);
367
368            (i, *line, score)
369        })
370        .collect();
371
372    let budget = ((n as f64) * budget_ratio).ceil() as usize;
373
374    scored_lines.sort_by(|a, b| b.2.partial_cmp(&a.2).unwrap_or(std::cmp::Ordering::Equal));
375
376    // MMR deduplication: penalize lines that are redundant with already-selected ones.
377    // score_mmr(i) = relevance(i) - lambda * max_j∈S(similarity(i, j))
378    let selected = mmr_select(&scored_lines, budget, 0.3);
379
380    let mut output_lines: Vec<&str> = Vec::with_capacity(budget + 1);
381
382    if !kw_lower.is_empty() {
383        output_lines.push(""); // placeholder for summary
384    }
385
386    for (_, line, _) in &selected {
387        output_lines.push(line);
388    }
389
390    if !kw_lower.is_empty() {
391        let summary = format!("[task: {}]", task_keywords.join(", "));
392        let mut result = summary;
393        result.push('\n');
394        result.push_str(&output_lines[1..].to_vec().join("\n"));
395        return result;
396    }
397
398    output_lines.join("\n")
399}
400
/// Maximum Marginal Relevance selection — greedy selection that penalizes
/// redundancy with already-selected lines using token-set Jaccard similarity.
///
/// MMR(i) = relevance(i) - lambda * max_{j in S} jaccard(i, j)
///
/// `candidates` are `(line_index, line, score)` tuples. The first candidate is
/// always selected first, so callers should pass them best-first. Ties are
/// resolved in favour of the earlier candidate. Lines without any tokens are
/// never penalized. At most `min(budget, candidates.len())` entries are
/// returned, in selection order.
///
/// Token sets are built once per candidate. Each candidate's maximum
/// similarity to the selected set is updated incrementally against the most
/// recently selected line only; this is equivalent to recomputing the maximum
/// over all selected lines every round.
fn mmr_select<'a>(
    candidates: &[(usize, &'a str, f64)],
    budget: usize,
    lambda: f64,
) -> Vec<(usize, &'a str, f64)> {
    if candidates.is_empty() || budget == 0 {
        return Vec::new();
    }

    let total = candidates.len();
    // Clamp so that a huge budget cannot request a huge allocation.
    let target = budget.min(total);

    let token_sets: Vec<HashSet<&str>> = candidates
        .iter()
        .map(|&(_, line, _)| line.split_whitespace().collect())
        .collect();

    // max_sim[i]: highest Jaccard similarity between candidate i and any
    // selected line so far. taken[i]: candidate i has been selected.
    let mut max_sim = vec![0.0_f64; total];
    let mut taken = vec![false; total];
    let mut selected: Vec<(usize, &'a str, f64)> = Vec::with_capacity(target);

    // Always take the top-scored line first.
    taken[0] = true;
    selected.push(candidates[0]);
    let mut newest = 0usize;

    while selected.len() < target {
        let newest_tokens = &token_sets[newest];
        let mut best_idx: Option<usize> = None;
        let mut best_mmr = f64::NEG_INFINITY;

        for idx in 0..total {
            if taken[idx] {
                continue;
            }
            // Default to the first remaining candidate, even if no MMR value
            // ever compares greater (e.g. NaN scores).
            if best_idx.is_none() {
                best_idx = Some(idx);
            }

            let tokens = &token_sets[idx];
            let score = candidates[idx].2;
            let mmr = if tokens.is_empty() {
                score
            } else {
                if !newest_tokens.is_empty() {
                    let inter = tokens.intersection(newest_tokens).count();
                    let union = tokens.len() + newest_tokens.len() - inter;
                    let sim = inter as f64 / union as f64;
                    if sim > max_sim[idx] {
                        max_sim[idx] = sim;
                    }
                }
                score - lambda * max_sim[idx]
            };

            if mmr > best_mmr {
                best_mmr = mmr;
                best_idx = Some(idx);
            }
        }

        match best_idx {
            Some(pick) => {
                taken[pick] = true;
                selected.push(candidates[pick]);
                newest = pick;
            }
            None => break,
        }
    }

    selected
}
463
/// Heuristic: does this line look like error handling?
///
/// Checks a fixed set of line prefixes (Rust `Err`/`bail!`/`panic!`, plus
/// `raise`/`throw`/`catch`/`except`/`try`) and substrings (`.map_err(`,
/// `unwrap()`, `expect("`, `Error::`, `error!`), as well as the `?;`
/// propagation pattern on lines that do not start with `//`. Prefix checks are
/// applied to `line` as given, without trimming.
fn is_error_handling(line: &str) -> bool {
    const PREFIXES: &[&str] = &[
        "return Err(",
        "Err(",
        "bail!(",
        "anyhow::bail!",
        "raise ",
        "throw ",
        "catch ",
        "except ",
        "try ",
        "panic!(",
    ];
    const FRAGMENTS: &[&str] = &[".map_err(", "unwrap()", "expect(\"", "Error::", "error!"];

    if PREFIXES.iter().any(|&prefix| line.starts_with(prefix)) {
        return true;
    }
    if FRAGMENTS.iter().any(|&fragment| line.contains(fragment)) {
        return true;
    }

    // `?;` counts only outside of line comments.
    !line.starts_with("//") && line.contains("?;")
}
482
/// Derive an IB budget ratio from how repetitive `content` is.
///
/// Highly repetitive content is filtered more aggressively (lower ratio);
/// diverse content keeps a ratio close to `base_ratio`.
///
/// * Fewer than 10 lines: returns `1.0` (keep everything).
/// * No whitespace-separated tokens at all: returns `base_ratio` unchanged.
/// * Otherwise: `base_ratio * (1 - 0.3 * repetition)`, clamped to
///   `[0.2, 1.0]`, where `repetition = 1 - distinct_tokens / total_tokens`.
pub fn adaptive_ib_budget(content: &str, base_ratio: f64) -> f64 {
    if content.lines().count() < 10 {
        return 1.0;
    }

    // Only the number of distinct tokens matters, so a set is enough.
    let mut distinct_tokens: HashSet<&str> = HashSet::new();
    let mut total_tokens = 0usize;
    for token in content.lines().flat_map(|line| line.split_whitespace()) {
        distinct_tokens.insert(token);
        total_tokens += 1;
    }

    if total_tokens == 0 {
        return base_ratio;
    }

    let unique_ratio = distinct_tokens.len() as f64 / total_tokens as f64;
    let repetition_factor = 1.0 - unique_ratio;
    let scaled = base_ratio * (1.0 - repetition_factor * 0.3);

    scaled.clamp(0.2, 1.0)
}
510
/// Heuristic: does this line start a definition?
///
/// Leading whitespace is ignored. The line must begin with one of a fixed set
/// of declaration keywords (with optional `pub` / `export` / `async`
/// modifiers as listed) covering Rust, TypeScript/JavaScript, Python and Go
/// style declarations.
fn is_definition_line(line: &str) -> bool {
    const DEFINITION_PREFIXES: &[&str] = &[
        "fn ",
        "pub fn ",
        "async fn ",
        "pub async fn ",
        "struct ",
        "pub struct ",
        "enum ",
        "pub enum ",
        "trait ",
        "pub trait ",
        "impl ",
        "type ",
        "pub type ",
        "const ",
        "pub const ",
        "static ",
        "pub static ",
        "class ",
        "export class ",
        "interface ",
        "export interface ",
        "function ",
        "export function ",
        "async function ",
        "def ",
        "async def ",
        "func ",
    ];

    // No prefix starts with whitespace, so checking the left-trimmed line
    // also covers lines that match without trimming.
    let body = line.trim_start();
    DEFINITION_PREFIXES
        .iter()
        .any(|&prefix| body.starts_with(prefix))
}
545
/// Heuristic: does this line (ignoring surrounding whitespace) begin with a
/// control-flow keyword?
///
/// `break`, `continue` and `yield` match as bare prefixes; the other keywords
/// must be followed by a space.
fn is_control_flow(line: &str) -> bool {
    const CONTROL_PREFIXES: &[&str] = &[
        "if ", "else ", "match ", "for ", "while ", "return ", "break", "continue", "yield",
        "await ",
    ];

    let statement = line.trim();
    CONTROL_PREFIXES
        .iter()
        .any(|&keyword| statement.starts_with(keyword))
}
559
/// True when the line, ignoring surrounding whitespace, is nothing but a
/// block terminator: `}`, `};`, `})` or `});`.
fn is_closing_brace(line: &str) -> bool {
    matches!(line.trim(), "}" | "};" | "})" | "});")
}
564
#[cfg(test)]
mod tests {
    use super::*;

    /// A path-like word is reported as a file, ordinary words as keywords.
    #[test]
    fn parse_task_finds_files_and_keywords() {
        let description = "Fix the authentication bug in src/auth.rs and update tests";
        let (files, keywords) = parse_task_hints(description);

        let has_auth_file = files.iter().any(|f| f.contains("auth.rs"));
        let has_auth_keyword = keywords
            .iter()
            .any(|k| k.to_lowercase().contains("authentication"));

        assert!(has_auth_file);
        assert!(has_auth_keyword);
    }

    /// One representative score per recommendation tier.
    #[test]
    fn recommend_mode_by_score() {
        let expectations = [
            (1.0, "full"),
            (0.6, "signatures"),
            (0.3, "map"),
            (0.1, "reference"),
        ];
        for &(score, mode) in expectations.iter() {
            assert_eq!(recommend_mode(score), mode);
        }
    }

    /// The definition line survives filtering and the task summary is emitted.
    #[test]
    fn info_bottleneck_preserves_definitions() {
        let content = "fn main() {\n    let x = 42;\n    // boring comment\n    println!(x);\n}\n";
        let keywords = ["main".to_string()];
        let filtered = information_bottleneck_filter(content, &keywords, 0.6);

        assert!(filtered.contains("fn main"), "definitions must be preserved");
        assert!(filtered.contains("[task: main]"), "should have task summary");
    }

    /// Error-handling lines are kept even under a tight budget.
    #[test]
    fn info_bottleneck_error_handling_priority() {
        let content = "fn validate() {\n    let data = parse()?;\n    return Err(\"invalid\");\n    let x = 1;\n    let y = 2;\n}\n";
        let keywords = ["validate".to_string()];
        let filtered = information_bottleneck_filter(content, &keywords, 0.5);

        assert!(
            filtered.contains("return Err"),
            "error handling should survive filtering"
        );
    }

    /// When both survive, a definition is emitted before a closing brace.
    #[test]
    fn info_bottleneck_score_sorted() {
        let content = "fn important() {\n    let x = 1;\n    let y = 2;\n    let z = 3;\n}\n}\n";
        let filtered = information_bottleneck_filter(content, &[], 0.6);
        let emitted: Vec<&str> = filtered.lines().collect();

        let definition_at = emitted.iter().position(|l| l.contains("fn important"));
        let brace_at = emitted.iter().position(|l| l.trim() == "}");

        if let (Some(definition), Some(brace)) = (definition_at, brace_at) {
            assert!(
                definition < brace,
                "definitions should appear before closing braces in score-sorted output"
            );
        }
    }

    /// Repetitive input must receive a smaller budget than diverse input.
    #[test]
    fn adaptive_budget_reduces_for_repetitive() {
        let repetitive = "let x = 1;\n".repeat(50);
        let diverse = (0..50)
            .map(|i| format!("let var_{i} = func_{i}(arg_{i});"))
            .collect::<Vec<_>>()
            .join("\n");

        let repetitive_budget = adaptive_ib_budget(&repetitive, 0.7);
        let diverse_budget = adaptive_ib_budget(&diverse, 0.7);

        assert!(
            repetitive_budget < diverse_budget,
            "repetitive content should get lower budget"
        );
    }
}