Skip to main content

lean_ctx/core/
task_relevance.rs

1use std::collections::{HashMap, HashSet};
2
3use super::graph_index::ProjectIndex;
4
5use super::neural::attention_learned::LearnedAttention;
6
/// Relevance of a single project file to the task at hand.
#[derive(Clone, Debug)]
pub struct RelevanceScore {
    /// Path of the file, as it appears as a key in the project index.
    pub path: String,
    /// Combined relevance score; higher means more relevant.
    pub score: f64,
    /// Suggested read mode for the file, derived from `score`.
    pub recommended_mode: &'static str,
}
13
14pub fn compute_relevance(
15    index: &ProjectIndex,
16    task_files: &[String],
17    task_keywords: &[String],
18) -> Vec<RelevanceScore> {
19    let adj = build_adjacency_resolved(index);
20    let all_nodes: Vec<String> = index.files.keys().cloned().collect();
21    if all_nodes.is_empty() {
22        return Vec::new();
23    }
24
25    let node_idx: HashMap<&str, usize> = all_nodes
26        .iter()
27        .enumerate()
28        .map(|(i, n)| (n.as_str(), i))
29        .collect();
30    let n = all_nodes.len();
31
32    // Build degree-normalized adjacency for heat diffusion
33    let degrees: Vec<f64> = all_nodes
34        .iter()
35        .map(|node| {
36            adj.get(node)
37                .map_or(0.0, |neigh| neigh.len() as f64)
38                .max(1.0)
39        })
40        .collect();
41
42    // Seed vector: task files get 1.0
43    let mut heat: Vec<f64> = vec![0.0; n];
44    for f in task_files {
45        if let Some(&idx) = node_idx.get(f.as_str()) {
46            heat[idx] = 1.0;
47        }
48    }
49
50    // Heat diffusion: h(t+1) = (1-alpha)*h(t) + alpha * A_norm * h(t)
51    // Run for k iterations
52    let alpha = 0.5;
53    let iterations = 4;
54    for _ in 0..iterations {
55        let mut new_heat = vec![0.0; n];
56        for (i, node) in all_nodes.iter().enumerate() {
57            let self_term = (1.0 - alpha) * heat[i];
58            let mut neighbor_sum = 0.0;
59            if let Some(neighbors) = adj.get(node) {
60                for neighbor in neighbors {
61                    if let Some(&j) = node_idx.get(neighbor.as_str()) {
62                        neighbor_sum += heat[j] / degrees[j];
63                    }
64                }
65            }
66            new_heat[i] = self_term + alpha * neighbor_sum;
67        }
68        heat = new_heat;
69    }
70
71    // PageRank centrality for gateway detection
72    let mut pagerank = vec![1.0 / n as f64; n];
73    let damping = 0.85;
74    for _ in 0..8 {
75        let mut new_pr = vec![(1.0 - damping) / n as f64; n];
76        for (i, node) in all_nodes.iter().enumerate() {
77            if let Some(neighbors) = adj.get(node) {
78                let out_deg = neighbors.len().max(1) as f64;
79                for neighbor in neighbors {
80                    if let Some(&j) = node_idx.get(neighbor.as_str()) {
81                        new_pr[j] += damping * pagerank[i] / out_deg;
82                    }
83                }
84            }
85        }
86        pagerank = new_pr;
87    }
88
89    // Combine: heat (primary) + pagerank centrality (gateway bonus)
90    let mut scores: HashMap<String, f64> = HashMap::new();
91    let heat_max = heat.iter().cloned().fold(0.0_f64, f64::max).max(1e-10);
92    let pr_max = pagerank.iter().cloned().fold(0.0_f64, f64::max).max(1e-10);
93
94    for (i, node) in all_nodes.iter().enumerate() {
95        let h = heat[i] / heat_max;
96        let pr = pagerank[i] / pr_max;
97        let combined = h * 0.8 + pr * 0.2;
98        if combined > 0.01 {
99            scores.insert(node.clone(), combined);
100        }
101    }
102
103    // Keyword boost
104    if !task_keywords.is_empty() {
105        let kw_lower: Vec<String> = task_keywords.iter().map(|k| k.to_lowercase()).collect();
106        for (file_path, file_entry) in &index.files {
107            let path_lower = file_path.to_lowercase();
108            let mut keyword_hits = 0;
109            for kw in &kw_lower {
110                if path_lower.contains(kw) {
111                    keyword_hits += 1;
112                }
113                for export in &file_entry.exports {
114                    if export.to_lowercase().contains(kw) {
115                        keyword_hits += 1;
116                    }
117                }
118            }
119            if keyword_hits > 0 {
120                let boost = (keyword_hits as f64 * 0.15).min(0.6);
121                let entry = scores.entry(file_path.clone()).or_insert(0.0);
122                *entry = (*entry + boost).min(1.0);
123            }
124        }
125    }
126
127    let mut result: Vec<RelevanceScore> = scores
128        .into_iter()
129        .map(|(path, score)| {
130            let mode = recommend_mode(score);
131            RelevanceScore {
132                path,
133                score,
134                recommended_mode: mode,
135            }
136        })
137        .collect();
138
139    result.sort_by(|a, b| {
140        b.score
141            .partial_cmp(&a.score)
142            .unwrap_or(std::cmp::Ordering::Equal)
143    });
144    result
145}
146
/// Translate a relevance score into a suggested read mode.
///
/// The tiers are checked from the highest threshold down; a score that reaches
/// none of them (including NaN) maps to `"reference"`.
fn recommend_mode(score: f64) -> &'static str {
    const TIERS: [(f64, &str); 3] = [(0.8, "full"), (0.5, "signatures"), (0.2, "map")];

    TIERS
        .iter()
        .find(|tier| score >= tier.0)
        .map_or("reference", |tier| tier.1)
}
158
159/// Build adjacency with module-path → file-path resolution.
160/// Graph edges store file paths as `from` and Rust module paths as `to`
161/// (e.g. `crate::core::tokens::count_tokens`). We resolve `to` back to file
162/// paths so heat diffusion and PageRank can propagate across the graph.
163fn build_adjacency_resolved(index: &ProjectIndex) -> HashMap<String, Vec<String>> {
164    let module_to_file = build_module_map(index);
165    let mut adj: HashMap<String, Vec<String>> = HashMap::new();
166
167    for edge in &index.edges {
168        let from = &edge.from;
169        let to_resolved = module_to_file
170            .get(&edge.to)
171            .cloned()
172            .unwrap_or_else(|| edge.to.clone());
173
174        if index.files.contains_key(from) && index.files.contains_key(&to_resolved) {
175            adj.entry(from.clone())
176                .or_default()
177                .push(to_resolved.clone());
178            adj.entry(to_resolved).or_default().push(from.clone());
179        }
180    }
181    adj
182}
183
184/// Map module/import paths to file paths using heuristics.
185/// e.g. `crate::core::tokens::count_tokens` → `rust/src/core/tokens.rs`
186fn build_module_map(index: &ProjectIndex) -> HashMap<String, String> {
187    let file_paths: Vec<&str> = index.files.keys().map(|s| s.as_str()).collect();
188    let mut mapping: HashMap<String, String> = HashMap::new();
189
190    let edge_targets: HashSet<String> = index.edges.iter().map(|e| e.to.clone()).collect();
191
192    for target in &edge_targets {
193        if index.files.contains_key(target) {
194            mapping.insert(target.clone(), target.clone());
195            continue;
196        }
197
198        if let Some(resolved) = resolve_module_to_file(target, &file_paths) {
199            mapping.insert(target.clone(), resolved);
200        }
201    }
202
203    mapping
204}
205
/// Resolve a Rust module/import path to one of `file_paths` using heuristics.
///
/// Leading `crate::` and `super::` prefixes are stripped and the remainder is
/// split on `::`. Progressively shorter prefixes of the path are then tried,
/// longest first, so that a trailing symbol name is ignored
/// (`core::tokens::count_tokens` → `core/tokens`). A prefix matches a file when
/// the file is `<prefix>.rs` or `<prefix>/mod.rs`, either relative to a leading
/// `rust/src/` / `src/` directory or as a path suffix starting at a `/`.
///
/// If no prefix matches, the last path segment is tried as a bare file name
/// (`<last>.rs`). The match must cover a whole file name, so `count_tokens`
/// does not resolve to `discount_tokens.rs`.
///
/// Returns the first matching entry of `file_paths`, or `None` when nothing
/// matches or when the module path contains no non-empty segment.
fn resolve_module_to_file(module_path: &str, file_paths: &[&str]) -> Option<String> {
    let cleaned = module_path
        .trim_start_matches("crate::")
        .trim_start_matches("super::");

    // Empty segments (from "", "crate::" or "a::::b") carry no information and
    // would otherwise produce patterns such as ".rs" that match unrelated files.
    let parts: Vec<&str> = cleaned.split("::").filter(|s| !s.is_empty()).collect();
    let last = *parts.last()?;

    // Try progressively shorter prefixes to find a matching file.
    for end in (1..=parts.len()).rev() {
        let candidate = parts[..end].join("/");
        let as_file = format!("{candidate}.rs");
        let as_mod = format!("{candidate}/mod.rs");
        let nested_file = format!("/{as_file}");
        let nested_mod = format!("/{as_mod}");

        for fp in file_paths {
            let fp_normalized = fp
                .trim_start_matches("rust/src/")
                .trim_start_matches("src/");

            if fp_normalized == as_file
                || fp_normalized == as_mod
                || fp.ends_with(nested_file.as_str())
                || fp.ends_with(nested_mod.as_str())
            {
                return Some((*fp).to_string());
            }
        }
    }

    // Fallback: match the last segment as a complete file name anywhere in the tree.
    let file_name = format!("{last}.rs");
    let nested_name = format!("/{file_name}");
    file_paths
        .iter()
        .find(|fp| **fp == file_name || fp.ends_with(nested_name.as_str()))
        .map(|fp| (*fp).to_string())
}
246
/// Extract likely task-relevant file paths and keywords from a task description.
///
/// The description is split on whitespace. Each word is stripped of surrounding
/// punctuation (characters other than alphanumerics, `.`, `/`, `_` and `-`);
/// trailing dots are stripped as well, so a path or word that ends a sentence
/// (`src/auth.rs.`, `this.`) is handled like the bare token.
///
/// A cleaned word is treated as a file path when it contains a `.` and either
/// contains a `/` or ends in one of the recognized source extensions. Any other
/// word of at least three bytes that is not listed in [`STOP_WORDS`] becomes a
/// keyword (original casing is preserved; the stop-word check is
/// case-insensitive).
///
/// Returns `(files, keywords)` in order of appearance; duplicates are kept.
pub fn parse_task_hints(task_description: &str) -> (Vec<String>, Vec<String>) {
    const SOURCE_EXTENSIONS: [&str; 5] = [".rs", ".ts", ".py", ".go", ".js"];

    let is_hint_char =
        |c: char| c.is_alphanumeric() || matches!(c, '.' | '/' | '_' | '-');

    let mut files = Vec::new();
    let mut keywords = Vec::new();

    for word in task_description.split_whitespace() {
        // A leading '.' is kept ("./src/x.rs"); a trailing '.' is sentence
        // punctuation and must not end up in the path or keyword.
        let clean = word
            .trim_start_matches(|c: char| !is_hint_char(c))
            .trim_end_matches(|c: char| c == '.' || !is_hint_char(c));
        if clean.is_empty() {
            continue;
        }

        let looks_like_path = clean.contains('.')
            && (clean.contains('/')
                || SOURCE_EXTENSIONS.iter().any(|ext| clean.ends_with(*ext)));

        if looks_like_path {
            files.push(clean.to_string());
        } else if clean.len() >= 3 && !STOP_WORDS.contains(&clean.to_lowercase().as_str()) {
            keywords.push(clean.to_string());
        }
    }

    (files, keywords)
}

/// Lowercase words (English and German) that are never reported as keywords by
/// [`parse_task_hints`].
const STOP_WORDS: &[&str] = &[
    "the", "and", "for", "that", "this", "with", "from", "have", "has", "was", "are", "been",
    "not", "but", "all", "can", "had", "her", "one", "our", "out", "you", "its", "will", "each",
    "make", "like", "fix", "add", "use", "get", "set", "run", "new", "old", "should", "would",
    "could", "into", "also", "than", "them", "then", "when", "just", "only", "very", "some",
    "more", "other", "nach", "und", "die", "der", "das", "ist", "ein", "eine", "nicht", "auf",
    "mit",
];
281
282/// Information Bottleneck filter v3 — Mutual Information scoring, QUITO-X inspired.
283///
284/// IB principle: maximize I(T;Y) (task relevance) while minimizing I(T;X) (input redundancy).
285/// v3: MI(line, task) approximated via token overlap + IDF weighting + structural importance.
286///
287/// Key changes from v2:
288///   - Mutual Information scoring: MI(line, task) = H(line) - H(line|task)
289///   - Adaptive budget allocation based on task type via TaskClassifier
290///   - Token-level IDF computed over full document for better term weighting
291///   - Maintains L-curve attention, MMR dedup, error-handling priority from v2
292pub fn information_bottleneck_filter(
293    content: &str,
294    task_keywords: &[String],
295    budget_ratio: f64,
296) -> String {
297    let lines: Vec<&str> = content.lines().collect();
298    if lines.is_empty() {
299        return String::new();
300    }
301
302    let n = lines.len();
303    let kw_lower: Vec<String> = task_keywords.iter().map(|k| k.to_lowercase()).collect();
304    let attention = LearnedAttention::with_defaults();
305
306    let mut global_token_freq: HashMap<&str, usize> = HashMap::new();
307    for line in &lines {
308        for token in line.split_whitespace() {
309            *global_token_freq.entry(token).or_insert(0) += 1;
310        }
311    }
312    let total_unique = global_token_freq.len().max(1) as f64;
313    let total_lines = n.max(1) as f64;
314
315    let task_token_set: HashSet<String> = kw_lower
316        .iter()
317        .flat_map(|kw| kw.split(|c: char| !c.is_alphanumeric()).map(String::from))
318        .filter(|t| t.len() >= 2)
319        .collect();
320
321    let effective_ratio = if !task_token_set.is_empty() {
322        adaptive_ib_budget(content, budget_ratio)
323    } else {
324        budget_ratio
325    };
326
327    let mut scored_lines: Vec<(usize, &str, f64)> = lines
328        .iter()
329        .enumerate()
330        .map(|(i, line)| {
331            let trimmed = line.trim();
332            if trimmed.is_empty() {
333                return (i, *line, 0.05);
334            }
335
336            let line_lower = trimmed.to_lowercase();
337            let line_tokens: Vec<&str> = trimmed.split_whitespace().collect();
338            let line_token_count = line_tokens.len().max(1) as f64;
339
340            let mi_score = if task_token_set.is_empty() {
341                0.0
342            } else {
343                let line_token_set: HashSet<String> =
344                    line_tokens.iter().map(|t| t.to_lowercase()).collect();
345                let overlap: f64 = line_token_set
346                    .iter()
347                    .filter(|t| task_token_set.iter().any(|kw| t.contains(kw.as_str())))
348                    .map(|t| {
349                        let freq = *global_token_freq.get(t.as_str()).unwrap_or(&1) as f64;
350                        (total_lines / freq).ln().max(0.1)
351                    })
352                    .sum();
353                overlap / line_token_count
354            };
355
356            let keyword_hits: f64 = kw_lower
357                .iter()
358                .filter(|kw| line_lower.contains(kw.as_str()))
359                .count() as f64;
360
361            let structural = if is_error_handling(trimmed) {
362                1.5
363            } else if is_definition_line(trimmed) {
364                1.0
365            } else if is_control_flow(trimmed) {
366                0.5
367            } else if is_closing_brace(trimmed) {
368                0.15
369            } else {
370                0.3
371            };
372            let relevance = mi_score * 0.4 + keyword_hits * 0.3 + structural;
373
374            let unique_in_line = line_tokens.iter().collect::<HashSet<_>>().len() as f64;
375            let token_diversity = unique_in_line / line_token_count;
376
377            let avg_idf: f64 = if line_tokens.is_empty() {
378                0.0
379            } else {
380                line_tokens
381                    .iter()
382                    .map(|t| {
383                        let freq = *global_token_freq.get(t).unwrap_or(&1) as f64;
384                        (total_unique / freq).ln().max(0.0)
385                    })
386                    .sum::<f64>()
387                    / line_token_count
388            };
389            let information = (token_diversity * 0.4 + (avg_idf.min(3.0) / 3.0) * 0.6).min(1.0);
390
391            let pos = i as f64 / n.max(1) as f64;
392            let attn_weight = attention.weight(pos);
393
394            let score = (relevance * 0.6 + 0.05)
395                * (information * 0.25 + 0.05)
396                * (attn_weight * 0.15 + 0.05);
397
398            (i, *line, score)
399        })
400        .collect();
401
402    let budget = ((n as f64) * effective_ratio).ceil() as usize;
403
404    scored_lines.sort_by(|a, b| b.2.partial_cmp(&a.2).unwrap_or(std::cmp::Ordering::Equal));
405
406    let selected = mmr_select(&scored_lines, budget, 0.3);
407
408    let mut output_lines: Vec<&str> = Vec::with_capacity(budget + 1);
409
410    if !kw_lower.is_empty() {
411        output_lines.push("");
412    }
413
414    for (_, line, _) in &selected {
415        output_lines.push(line);
416    }
417
418    if !kw_lower.is_empty() {
419        let summary = format!("[task: {}]", task_keywords.join(", "));
420        let mut result = summary;
421        result.push('\n');
422        result.push_str(&output_lines[1..].to_vec().join("\n"));
423        return result;
424    }
425
426    output_lines.join("\n")
427}
428
/// Maximum Marginal Relevance selection — greedy selection that penalizes
/// redundancy with already-selected lines using token-set Jaccard similarity.
///
/// MMR(i) = relevance(i) - lambda * max_{j in S} jaccard(i, j)
///
/// `candidates` are `(line index, line text, score)` tuples and are expected to
/// be sorted by descending score: the first candidate is always selected.
/// Each further round picks the remaining candidate with the highest MMR value;
/// on ties the candidate that comes first in `candidates` wins. Lines without
/// any tokens are never penalized and compete with their raw score.
///
/// Returns at most `budget` candidates (and never more than `candidates.len()`)
/// in selection order. Returns an empty vector when `candidates` is empty or
/// `budget` is 0.
fn mmr_select<'a>(
    candidates: &[(usize, &'a str, f64)],
    budget: usize,
    lambda: f64,
) -> Vec<(usize, &'a str, f64)> {
    /// Jaccard similarity of two token sets; 0.0 when either set is empty.
    fn jaccard<'t>(a: &HashSet<&'t str>, b: &HashSet<&'t str>) -> f64 {
        if a.is_empty() || b.is_empty() {
            return 0.0;
        }
        let inter = a.intersection(b).count();
        let union = a.len() + b.len() - inter;
        inter as f64 / union as f64
    }

    if candidates.is_empty() || budget == 0 {
        return Vec::new();
    }

    // Clamp so that a huge budget cannot trigger a capacity-overflow panic.
    let limit = budget.min(candidates.len());

    // Tokenize every line once instead of once per comparison.
    let token_sets: Vec<HashSet<&str>> = candidates
        .iter()
        .map(|&(_, line, _)| line.split_whitespace().collect())
        .collect();

    // max_sim[i]: highest similarity of candidate i to anything selected so far.
    // It only needs to be refreshed against the most recently selected line.
    let mut max_sim = vec![0.0_f64; candidates.len()];
    let mut remaining: Vec<usize> = (1..candidates.len()).collect();
    let mut selected: Vec<(usize, &'a str, f64)> = Vec::with_capacity(limit);

    // Always take the top-scored line first.
    selected.push(candidates[0]);
    let mut newest = 0usize;

    while selected.len() < limit && !remaining.is_empty() {
        let mut best_pos = 0;
        let mut best_mmr = f64::NEG_INFINITY;

        for (pos, &idx) in remaining.iter().enumerate() {
            let score = candidates[idx].2;
            let mmr = if token_sets[idx].is_empty() {
                score
            } else {
                let sim = jaccard(&token_sets[idx], &token_sets[newest]);
                if sim > max_sim[idx] {
                    max_sim[idx] = sim;
                }
                score - lambda * max_sim[idx]
            };

            if mmr > best_mmr {
                best_mmr = mmr;
                best_pos = pos;
            }
        }

        // `remove` keeps the relative order of the rest, preserving tie-breaking.
        newest = remaining.remove(best_pos);
        selected.push(candidates[newest]);
    }

    selected
}
491
/// Heuristically decide whether a line deals with errors.
///
/// A line qualifies when it starts with one of the known error-related
/// statements, contains one of the error-related fragments anywhere, or
/// contains `?;` while not starting with `//`. The line is inspected as given;
/// no whitespace is trimmed here.
fn is_error_handling(line: &str) -> bool {
    const LEADING_MARKERS: &[&str] = &[
        "return Err(",
        "Err(",
        "bail!(",
        "anyhow::bail!",
        "raise ",
        "throw ",
        "catch ",
        "except ",
        "try ",
        "panic!(",
    ];
    const EMBEDDED_MARKERS: &[&str] = &[".map_err(", "unwrap()", "expect(\"", "Error::", "error!"];

    if LEADING_MARKERS.iter().any(|marker| line.starts_with(*marker)) {
        return true;
    }
    if EMBEDDED_MARKERS.iter().any(|marker| line.contains(*marker)) {
        return true;
    }

    // The `?` operator counts unless the whole line is a `//` comment.
    line.contains("?;") && !line.starts_with("//")
}
510
/// Derive an IB budget ratio from how repetitive `content` is.
///
/// The repetition factor is `1 - distinct_tokens / total_tokens` over all
/// whitespace-separated tokens. `base_ratio` is scaled down by up to 30% as the
/// repetition factor approaches 1 and the result is clamped to `[0.2, 1.0]`,
/// so repetitive content is filtered more aggressively than diverse content.
///
/// Special cases: content with fewer than 10 lines always yields `1.0`, and
/// content without any tokens yields `base_ratio` unchanged.
pub fn adaptive_ib_budget(content: &str, base_ratio: f64) -> f64 {
    if content.lines().count() < 10 {
        return 1.0;
    }

    let mut distinct: HashSet<&str> = HashSet::new();
    let mut token_total = 0usize;
    for token in content.lines().flat_map(str::split_whitespace) {
        distinct.insert(token);
        token_total += 1;
    }

    if token_total == 0 {
        return base_ratio;
    }

    let unique_ratio = distinct.len() as f64 / token_total as f64;
    let repetition_factor = 1.0 - unique_ratio;
    let scaled = base_ratio * (1.0 - repetition_factor * 0.3);

    scaled.clamp(0.2, 1.0)
}
538
/// Heuristically decide whether a line opens a definition (function, type,
/// constant, class, ...) in one of several languages.
///
/// Leading whitespace is ignored; the rest of the line must start with one of
/// the keyword prefixes below, each of which ends in a space.
fn is_definition_line(line: &str) -> bool {
    const DEFINITION_PREFIXES: &[&str] = &[
        "fn ",
        "pub fn ",
        "async fn ",
        "pub async fn ",
        "struct ",
        "pub struct ",
        "enum ",
        "pub enum ",
        "trait ",
        "pub trait ",
        "impl ",
        "type ",
        "pub type ",
        "const ",
        "pub const ",
        "static ",
        "pub static ",
        "class ",
        "export class ",
        "interface ",
        "export interface ",
        "function ",
        "export function ",
        "async function ",
        "def ",
        "async def ",
        "func ",
    ];

    let unindented = line.trim_start();
    DEFINITION_PREFIXES
        .iter()
        .any(|prefix| unindented.starts_with(*prefix))
}
573
/// Heuristically decide whether a line is a control-flow statement.
///
/// Surrounding whitespace is trimmed first; the remainder must start with one
/// of the keywords below. Keywords listed with a trailing space require that
/// space to be present; `break`, `continue` and `yield` match as bare prefixes.
fn is_control_flow(line: &str) -> bool {
    const KEYWORDS: &[&str] = &[
        "if ", "else ", "match ", "for ", "while ", "return ", "break", "continue", "yield",
        "await ",
    ];

    let statement = line.trim();
    KEYWORDS.iter().any(|keyword| statement.starts_with(*keyword))
}
587
/// Report whether a line, ignoring surrounding whitespace, consists solely of a
/// block terminator: `}`, `};`, `})` or `});`.
fn is_closing_brace(line: &str) -> bool {
    matches!(line.trim(), "}" | "};" | "})" | "});")
}
592
#[cfg(test)]
mod tests {
    use super::*;

    /// A description that names a path and ordinary words yields both hint kinds.
    #[test]
    fn parse_task_finds_files_and_keywords() {
        let description = "Fix the authentication bug in src/auth.rs and update tests";
        let (paths, terms) = parse_task_hints(description);

        let has_auth_file = paths.iter().any(|p| p.contains("auth.rs"));
        let has_auth_term = terms
            .iter()
            .any(|t| t.to_lowercase().contains("authentication"));

        assert!(has_auth_file);
        assert!(has_auth_term);
    }

    /// Each score tier maps to its read mode.
    #[test]
    fn recommend_mode_by_score() {
        let expectations = [
            (1.0, "full"),
            (0.6, "signatures"),
            (0.3, "map"),
            (0.1, "reference"),
        ];
        for (score, mode) in expectations {
            assert_eq!(recommend_mode(score), mode);
        }
    }

    /// A definition line matching the task keyword survives filtering and the
    /// output carries the task summary.
    #[test]
    fn info_bottleneck_preserves_definitions() {
        let source = "fn main() {\n    let x = 42;\n    // boring comment\n    println!(x);\n}\n";
        let keywords = ["main".to_string()];
        let filtered = information_bottleneck_filter(source, &keywords, 0.6);

        assert!(filtered.contains("fn main"), "definitions must be preserved");
        assert!(filtered.contains("[task: main]"), "should have task summary");
    }

    /// Error-handling lines outrank plain statements under a tight budget.
    #[test]
    fn info_bottleneck_error_handling_priority() {
        let source = "fn validate() {\n    let data = parse()?;\n    return Err(\"invalid\");\n    let x = 1;\n    let y = 2;\n}\n";
        let keywords = ["validate".to_string()];
        let filtered = information_bottleneck_filter(source, &keywords, 0.5);

        assert!(
            filtered.contains("return Err"),
            "error handling should survive filtering"
        );
    }

    /// Output is ordered by score, so a definition precedes a closing brace
    /// whenever both are present.
    #[test]
    fn info_bottleneck_score_sorted() {
        let source = "fn important() {\n    let x = 1;\n    let y = 2;\n    let z = 3;\n}\n}\n";
        let filtered = information_bottleneck_filter(source, &[], 0.6);
        let kept: Vec<&str> = filtered.lines().collect();

        let definition_at = kept.iter().position(|l| l.contains("fn important"));
        let brace_at = kept.iter().position(|l| l.trim() == "}");

        // Only meaningful when both kinds of line made it into the output.
        if let Some(d) = definition_at {
            if let Some(b) = brace_at {
                assert!(
                    d < b,
                    "definitions should appear before closing braces in score-sorted output"
                );
            }
        }
    }

    /// Highly repetitive content receives a smaller budget than diverse content.
    #[test]
    fn adaptive_budget_reduces_for_repetitive() {
        let repetitive = "let x = 1;\n".repeat(50);
        let diverse_lines: Vec<String> = (0..50)
            .map(|i| format!("let var_{i} = func_{i}(arg_{i});"))
            .collect();
        let diverse = diverse_lines.join("\n");

        let budget_rep = adaptive_ib_budget(&repetitive, 0.7);
        let budget_div = adaptive_ib_budget(&diverse, 0.7);

        assert!(
            budget_rep < budget_div,
            "repetitive content should get lower budget"
        );
    }
}