Skip to main content

lean_ctx/core/
task_relevance.rs

1use std::collections::{HashMap, HashSet};
2
3use super::graph_provider::{EdgeInfo, GraphProvider};
4use super::neural::attention_learned::LearnedAttention;
5
/// Relevance of one file to the current task, as produced by
/// `compute_relevance`.
#[derive(Debug, Clone)]
pub struct RelevanceScore {
    /// Path of the scored file, taken from the graph provider's file list.
    pub path: String,
    /// Combined relevance score; higher means more relevant. Keyword boosts
    /// are capped so the boosted value does not exceed 1.0.
    pub score: f64,
    /// Read mode suggested for this file, derived from `score` by
    /// `recommend_mode`: `"full"`, `"signatures"`, `"map"` or `"reference"`.
    pub recommended_mode: &'static str,
}
12
13pub fn compute_relevance(
14    gp: &GraphProvider,
15    task_files: &[String],
16    task_keywords: &[String],
17) -> Vec<RelevanceScore> {
18    let all_edges = gp.edges();
19    let file_set: HashSet<String> = gp.file_paths().into_iter().collect();
20    let adj = build_adjacency_resolved(&all_edges, &file_set);
21    let all_nodes: Vec<String> = file_set.into_iter().collect();
22    if all_nodes.is_empty() {
23        return Vec::new();
24    }
25
26    let node_idx: HashMap<&str, usize> = all_nodes
27        .iter()
28        .enumerate()
29        .map(|(i, n)| (n.as_str(), i))
30        .collect();
31    let n = all_nodes.len();
32
33    // Build degree-normalized adjacency for heat diffusion
34    let degrees: Vec<f64> = all_nodes
35        .iter()
36        .map(|node| {
37            adj.get(node)
38                .map_or(0.0, |neigh| neigh.len() as f64)
39                .max(1.0)
40        })
41        .collect();
42
43    // Seed vector: task files get 1.0
44    let mut heat: Vec<f64> = vec![0.0; n];
45    for f in task_files {
46        if let Some(&idx) = node_idx.get(f.as_str()) {
47            heat[idx] = 1.0;
48        }
49    }
50
51    // Heat diffusion: h(t+1) = (1-alpha)*h(t) + alpha * A_norm * h(t)
52    // Run for k iterations
53    let alpha = 0.5;
54    let iterations = 4;
55    for _ in 0..iterations {
56        let mut new_heat = vec![0.0; n];
57        for (i, node) in all_nodes.iter().enumerate() {
58            let self_term = (1.0 - alpha) * heat[i];
59            let mut neighbor_sum = 0.0;
60            if let Some(neighbors) = adj.get(node) {
61                for neighbor in neighbors {
62                    if let Some(&j) = node_idx.get(neighbor.as_str()) {
63                        neighbor_sum += heat[j] / degrees[j];
64                    }
65                }
66            }
67            new_heat[i] = self_term + alpha * neighbor_sum;
68        }
69        heat = new_heat;
70    }
71
72    // PageRank centrality for gateway detection
73    let mut pagerank = vec![1.0 / n as f64; n];
74    let damping = 0.85;
75    for _ in 0..8 {
76        let mut new_pr = vec![(1.0 - damping) / n as f64; n];
77        for (i, node) in all_nodes.iter().enumerate() {
78            if let Some(neighbors) = adj.get(node) {
79                let out_deg = neighbors.len().max(1) as f64;
80                for neighbor in neighbors {
81                    if let Some(&j) = node_idx.get(neighbor.as_str()) {
82                        new_pr[j] += damping * pagerank[i] / out_deg;
83                    }
84                }
85            }
86        }
87        pagerank = new_pr;
88    }
89
90    // Combine: heat (primary) + pagerank centrality (gateway bonus)
91    let mut scores: HashMap<String, f64> = HashMap::new();
92    let heat_max = heat.iter().copied().fold(0.0_f64, f64::max).max(1e-10);
93    let pr_max = pagerank.iter().copied().fold(0.0_f64, f64::max).max(1e-10);
94
95    for (i, node) in all_nodes.iter().enumerate() {
96        let h = heat[i] / heat_max;
97        let pr = pagerank[i] / pr_max;
98        let combined = h * 0.8 + pr * 0.2;
99        if combined > 0.01 {
100            scores.insert(node.clone(), combined);
101        }
102    }
103
104    if !task_keywords.is_empty() {
105        let kw_lower: Vec<String> = task_keywords.iter().map(|k| k.to_lowercase()).collect();
106        for file_path in &all_nodes {
107            let path_lower = file_path.to_lowercase();
108            let mut keyword_hits = 0;
109            for kw in &kw_lower {
110                if path_lower.contains(kw) {
111                    keyword_hits += 1;
112                }
113                if let Some(entry) = gp.get_file_entry(file_path) {
114                    for export in &entry.exports {
115                        if export.to_lowercase().contains(kw) {
116                            keyword_hits += 1;
117                        }
118                    }
119                }
120            }
121            if keyword_hits > 0 {
122                let boost = (keyword_hits as f64 * 0.15).min(0.6);
123                let entry = scores.entry(file_path.clone()).or_insert(0.0);
124                *entry = (*entry + boost).min(1.0);
125            }
126        }
127    }
128
129    let mut result: Vec<RelevanceScore> = scores
130        .into_iter()
131        .map(|(path, score)| {
132            let mode = recommend_mode(score);
133            RelevanceScore {
134                path,
135                score,
136                recommended_mode: mode,
137            }
138        })
139        .collect();
140
141    result.sort_by(|a, b| {
142        b.score
143            .partial_cmp(&a.score)
144            .unwrap_or(std::cmp::Ordering::Equal)
145    });
146    result
147}
148
149pub fn compute_relevance_from_intent(
150    gp: &GraphProvider,
151    intent: &super::intent_engine::StructuredIntent,
152) -> Vec<RelevanceScore> {
153    use super::intent_engine::IntentScope;
154
155    let mut file_seeds: Vec<String> = Vec::new();
156    let mut extra_keywords: Vec<String> = intent.keywords.clone();
157
158    let file_paths = gp.file_paths();
159    for target in &intent.targets {
160        if target.contains('.') || target.contains('/') {
161            let matched = resolve_target_to_files(&file_paths, target);
162            if matched.is_empty() {
163                extra_keywords.push(target.clone());
164            } else {
165                file_seeds.extend(matched);
166            }
167        } else {
168            let from_symbol = resolve_symbol_to_files(gp, target);
169            if from_symbol.is_empty() {
170                extra_keywords.push(target.clone());
171            } else {
172                file_seeds.extend(from_symbol);
173            }
174        }
175    }
176
177    if let Some(lang) = &intent.language_hint {
178        let lang_ext = match lang.as_str() {
179            "rust" => Some("rs"),
180            "typescript" => Some("ts"),
181            "javascript" => Some("js"),
182            "python" => Some("py"),
183            "go" => Some("go"),
184            "ruby" => Some("rb"),
185            "java" => Some("java"),
186            _ => None,
187        };
188        if let Some(ext) = lang_ext {
189            if file_seeds.is_empty() {
190                for path in &file_paths {
191                    if path.ends_with(&format!(".{ext}")) {
192                        extra_keywords.push(
193                            std::path::Path::new(path)
194                                .file_stem()
195                                .and_then(|s| s.to_str())
196                                .unwrap_or("")
197                                .to_string(),
198                        );
199                        break;
200                    }
201                }
202            }
203        }
204    }
205
206    let mut result = compute_relevance(gp, &file_seeds, &extra_keywords);
207
208    match intent.scope {
209        IntentScope::SingleFile => {
210            result.truncate(5);
211        }
212        IntentScope::MultiFile => {
213            result.truncate(15);
214        }
215        IntentScope::CrossModule | IntentScope::ProjectWide => {}
216    }
217
218    result
219}
220
/// Returns, in input order, every path in `file_paths` that contains `target`
/// as a substring (a path that merely ends with `target` is one such case).
fn resolve_target_to_files(file_paths: &[String], target: &str) -> Vec<String> {
    let mut hits = Vec::new();
    for candidate in file_paths {
        if candidate.contains(target) {
            hits.push(candidate.clone());
        }
    }
    hits
}
228
229fn resolve_symbol_to_files(gp: &GraphProvider, symbol: &str) -> Vec<String> {
230    let found = gp.find_symbols(symbol, None, None);
231    let mut matches: Vec<String> = found
232        .into_iter()
233        .map(|s| s.file)
234        .collect::<HashSet<_>>()
235        .into_iter()
236        .collect();
237    if matches.is_empty() {
238        let sym_lower = symbol.to_lowercase();
239        for path in gp.file_paths() {
240            if let Some(entry) = gp.get_file_entry(&path) {
241                if entry
242                    .exports
243                    .iter()
244                    .any(|e| e.to_lowercase().contains(&sym_lower))
245                    && !matches.contains(&path)
246                {
247                    matches.push(path);
248                }
249            }
250        }
251    }
252    matches
253}
254
/// Maps a relevance score to a read mode: `"full"` from 0.8, `"signatures"`
/// from 0.5, `"map"` from 0.2, and `"reference"` for everything below
/// (including NaN, which fails every comparison).
fn recommend_mode(score: f64) -> &'static str {
    match score {
        s if s >= 0.8 => "full",
        s if s >= 0.5 => "signatures",
        s if s >= 0.2 => "map",
        _ => "reference",
    }
}
266
267fn build_adjacency_resolved(
268    edges: &[EdgeInfo],
269    file_set: &HashSet<String>,
270) -> HashMap<String, Vec<String>> {
271    let file_paths_vec: Vec<&str> = file_set.iter().map(String::as_str).collect();
272    let module_to_file = build_module_map(edges, file_set, &file_paths_vec);
273    let mut adj: HashMap<String, Vec<String>> = HashMap::new();
274
275    for edge in edges {
276        let from = &edge.from;
277        let to_resolved = module_to_file
278            .get(&edge.to)
279            .cloned()
280            .unwrap_or_else(|| edge.to.clone());
281
282        if file_set.contains(from) && file_set.contains(&to_resolved) {
283            adj.entry(from.clone())
284                .or_default()
285                .push(to_resolved.clone());
286            adj.entry(to_resolved).or_default().push(from.clone());
287        }
288    }
289    adj
290}
291
292fn build_module_map(
293    edges: &[EdgeInfo],
294    file_set: &HashSet<String>,
295    file_paths: &[&str],
296) -> HashMap<String, String> {
297    let mut mapping: HashMap<String, String> = HashMap::new();
298
299    let edge_targets: HashSet<String> = edges.iter().map(|e| e.to.clone()).collect();
300
301    for target in &edge_targets {
302        if file_set.contains(target) {
303            mapping.insert(target.clone(), target.clone());
304            continue;
305        }
306
307        if let Some(resolved) = resolve_module_to_file(target, file_paths) {
308            mapping.insert(target.clone(), resolved);
309        }
310    }
311
312    mapping
313}
314
/// Resolves a Rust module path (e.g. `crate::core::tokens::count_tokens`) to
/// one of `file_paths`.
///
/// Leading `crate::` / `super::` prefixes are stripped. The remaining segments
/// are then tried from the longest prefix to the shortest, so a trailing
/// symbol name is dropped automatically (`core::tokens::count_tokens` →
/// `core/tokens`). A prefix matches a file when the file is
/// `<prefix>.rs` or `<prefix>/mod.rs`, either relative to a leading
/// `rust/src/` / `src/` or as a path suffix starting at a `/` boundary.
///
/// If no prefix matches, the last segment alone is tried as a file name
/// (`<last>.rs`). The file name must match completely, so `tokens` does not
/// resolve to `my_tokens.rs`; an empty last segment never matches.
///
/// Returns the first matching entry of `file_paths`, or `None`.
fn resolve_module_to_file(module_path: &str, file_paths: &[&str]) -> Option<String> {
    let cleaned = module_path
        .trim_start_matches("crate::")
        .trim_start_matches("super::");

    let parts: Vec<&str> = cleaned.split("::").collect();

    // Try progressively shorter prefixes to find a matching file.
    for end in (1..=parts.len()).rev() {
        let candidate = parts[..end].join("/");
        // Built once per prefix rather than once per (prefix, file) pair.
        let as_file = format!("{candidate}.rs");
        let as_mod = format!("{candidate}/mod.rs");
        let file_suffix = format!("/{as_file}");
        let mod_suffix = format!("/{as_mod}");

        for fp in file_paths {
            let fp_normalized = fp
                .trim_start_matches("rust/src/")
                .trim_start_matches("src/");

            if fp_normalized == as_file
                || fp_normalized == as_mod
                || fp.ends_with(&file_suffix)
                || fp.ends_with(&mod_suffix)
            {
                return Some(fp.to_string());
            }
        }
    }

    // Fallback: match the last segment as a complete file name. The `/`
    // boundary prevents `tokens` from matching `my_tokens.rs`.
    if let Some(last) = parts.last().filter(|segment| !segment.is_empty()) {
        let file_name = format!("{last}.rs");
        let suffix = format!("/{file_name}");
        for fp in file_paths {
            if *fp == file_name || fp.ends_with(&suffix) {
                return Some(fp.to_string());
            }
        }
    }

    None
}
355
/// Extract likely task-relevant file paths and keywords from a task description.
///
/// The description is split on whitespace. Each word is stripped of
/// surrounding punctuation; characters that can be part of a path
/// (`.`, `/`, `_`, `-`) are kept, except that trailing `.` is removed so a
/// file mentioned at the end of a sentence (`"... in src/auth.rs."`) is still
/// recognised and keywords do not carry a sentence-final dot.
///
/// A word is treated as a file when it contains a `.` and either contains a
/// `/` or has one of the extensions `rs`, `ts`, `py`, `go`, `js`
/// (case-insensitive). Any other word of at least 3 bytes that is not in
/// `STOP_WORDS` becomes a keyword.
///
/// Returns `(files, keywords)` in order of appearance; duplicates are kept.
pub fn parse_task_hints(task_description: &str) -> (Vec<String>, Vec<String>) {
    /// Punctuation that is never part of a path or identifier.
    fn is_noise(c: char) -> bool {
        !c.is_alphanumeric() && c != '.' && c != '/' && c != '_' && c != '-'
    }

    let mut files = Vec::new();
    let mut keywords = Vec::new();

    for word in task_description.split_whitespace() {
        // A leading '.' may be meaningful ("./src", ".gitignore"); a trailing
        // one is sentence punctuation, so it is trimmed together with noise.
        let clean = word
            .trim_start_matches(is_noise)
            .trim_end_matches(|c: char| c == '.' || is_noise(c));

        let looks_like_file = clean.contains('.')
            && (clean.contains('/')
                || std::path::Path::new(clean).extension().is_some_and(|e| {
                    e.eq_ignore_ascii_case("rs")
                        || e.eq_ignore_ascii_case("ts")
                        || e.eq_ignore_ascii_case("py")
                        || e.eq_ignore_ascii_case("go")
                        || e.eq_ignore_ascii_case("js")
                }));

        if looks_like_file {
            files.push(clean.to_string());
        } else if clean.len() >= 3 && !STOP_WORDS.contains(&clean.to_lowercase().as_str()) {
            keywords.push(clean.to_string());
        }
    }

    (files, keywords)
}

/// Lowercase English and German filler words that are never used as keywords.
const STOP_WORDS: &[&str] = &[
    "the", "and", "for", "that", "this", "with", "from", "have", "has", "was", "are", "been",
    "not", "but", "all", "can", "had", "her", "one", "our", "out", "you", "its", "will", "each",
    "make", "like", "fix", "add", "use", "get", "set", "run", "new", "old", "should", "would",
    "could", "into", "also", "than", "them", "then", "when", "just", "only", "very", "some",
    "more", "other", "nach", "und", "die", "der", "das", "ist", "ein", "eine", "nicht", "auf",
    "mit",
];
393
394struct StructuralWeights {
395    error_handling: f64,
396    definition: f64,
397    control_flow: f64,
398    closing_brace: f64,
399    other: f64,
400}
401
402impl StructuralWeights {
403    const DEFAULT: Self = Self {
404        error_handling: 1.5,
405        definition: 1.0,
406        control_flow: 0.5,
407        closing_brace: 0.15,
408        other: 0.3,
409    };
410
411    fn for_task_type(task_type: Option<super::intent_engine::TaskType>) -> Self {
412        use super::intent_engine::TaskType;
413        match task_type {
414            Some(TaskType::FixBug) => Self {
415                error_handling: 2.0,
416                definition: 0.8,
417                control_flow: 0.8,
418                closing_brace: 0.1,
419                other: 0.2,
420            },
421            Some(TaskType::Debug) => Self {
422                error_handling: 2.0,
423                definition: 0.6,
424                control_flow: 1.0,
425                closing_brace: 0.1,
426                other: 0.2,
427            },
428            Some(TaskType::Generate) => Self {
429                error_handling: 0.8,
430                definition: 1.5,
431                control_flow: 0.3,
432                closing_brace: 0.15,
433                other: 0.4,
434            },
435            Some(TaskType::Refactor) => Self {
436                error_handling: 1.0,
437                definition: 1.5,
438                control_flow: 0.6,
439                closing_brace: 0.2,
440                other: 0.3,
441            },
442            Some(TaskType::Test) => Self {
443                error_handling: 1.2,
444                definition: 1.3,
445                control_flow: 0.4,
446                closing_brace: 0.15,
447                other: 0.3,
448            },
449            Some(TaskType::Review) => Self {
450                error_handling: 1.3,
451                definition: 1.2,
452                control_flow: 0.6,
453                closing_brace: 0.15,
454                other: 0.3,
455            },
456            None | Some(TaskType::Explore | _) => Self::DEFAULT,
457        }
458    }
459}
460
461/// Information Bottleneck filter v3 — Mutual Information scoring, QUITO-X inspired.
462///
463/// IB principle: maximize I(T;Y) (task relevance) while minimizing I(T;X) (input redundancy).
464/// v3: MI(line, task) approximated via token overlap + IDF weighting + structural importance.
465///
466/// Key changes from v2:
467///   - Mutual Information scoring: MI(line, task) = H(line) - H(line|task)
468///   - Adaptive budget allocation based on task type via TaskClassifier
469///   - Token-level IDF computed over full document for better term weighting
470///   - Maintains L-curve attention, MMR dedup, error-handling priority from v2
471pub fn information_bottleneck_filter(
472    content: &str,
473    task_keywords: &[String],
474    budget_ratio: f64,
475) -> String {
476    information_bottleneck_filter_typed(content, task_keywords, budget_ratio, None)
477}
478
479/// Task-type-aware IB filter. Uses `TaskType` to adjust structural weights.
480pub fn information_bottleneck_filter_typed(
481    content: &str,
482    task_keywords: &[String],
483    budget_ratio: f64,
484    task_type: Option<super::intent_engine::TaskType>,
485) -> String {
486    let lines: Vec<&str> = content.lines().collect();
487    if lines.is_empty() {
488        return String::new();
489    }
490
491    let n = lines.len();
492    let kw_lower: Vec<String> = task_keywords.iter().map(|k| k.to_lowercase()).collect();
493    let attention = LearnedAttention::with_defaults();
494
495    let mut global_token_freq: HashMap<&str, usize> = HashMap::new();
496    for line in &lines {
497        for token in line.split_whitespace() {
498            *global_token_freq.entry(token).or_insert(0) += 1;
499        }
500    }
501    let total_unique = global_token_freq.len().max(1) as f64;
502    let total_lines = n.max(1) as f64;
503
504    let task_token_set: HashSet<String> = kw_lower
505        .iter()
506        .flat_map(|kw| kw.split(|c: char| !c.is_alphanumeric()).map(String::from))
507        .filter(|t| t.len() >= 2)
508        .collect();
509
510    let effective_ratio = if task_token_set.is_empty() {
511        budget_ratio
512    } else {
513        adaptive_ib_budget(content, budget_ratio)
514    };
515
516    let weights = StructuralWeights::for_task_type(task_type);
517
518    let mut scored_lines: Vec<(usize, &str, f64)> = lines
519        .iter()
520        .enumerate()
521        .map(|(i, line)| {
522            let trimmed = line.trim();
523            if trimmed.is_empty() {
524                return (i, *line, 0.05);
525            }
526
527            let line_lower = trimmed.to_lowercase();
528            let line_tokens: Vec<&str> = trimmed.split_whitespace().collect();
529            let line_token_count = line_tokens.len().max(1) as f64;
530
531            let mi_score = if task_token_set.is_empty() {
532                0.0
533            } else {
534                let line_token_set: HashSet<String> =
535                    line_tokens.iter().map(|t| t.to_lowercase()).collect();
536                let overlap: f64 = line_token_set
537                    .iter()
538                    .filter(|t| task_token_set.iter().any(|kw| t.contains(kw.as_str())))
539                    .map(|t| {
540                        let freq = *global_token_freq.get(t.as_str()).unwrap_or(&1) as f64;
541                        (total_lines / freq).ln().max(0.1)
542                    })
543                    .sum();
544                overlap / line_token_count
545            };
546
547            let keyword_hits: f64 = kw_lower
548                .iter()
549                .filter(|kw| line_lower.contains(kw.as_str()))
550                .count() as f64;
551
552            let structural = if is_error_handling(trimmed) {
553                weights.error_handling
554            } else if is_definition_line(trimmed) {
555                weights.definition
556            } else if is_control_flow(trimmed) {
557                weights.control_flow
558            } else if is_closing_brace(trimmed) {
559                weights.closing_brace
560            } else {
561                weights.other
562            };
563            let relevance = mi_score * 0.4 + keyword_hits * 0.3 + structural;
564
565            let unique_in_line = line_tokens.iter().collect::<HashSet<_>>().len() as f64;
566            let token_diversity = unique_in_line / line_token_count;
567
568            let avg_idf: f64 = if line_tokens.is_empty() {
569                0.0
570            } else {
571                line_tokens
572                    .iter()
573                    .map(|t| {
574                        let freq = *global_token_freq.get(t).unwrap_or(&1) as f64;
575                        (total_unique / freq).ln().max(0.0)
576                    })
577                    .sum::<f64>()
578                    / line_token_count
579            };
580            let information = (token_diversity * 0.4 + (avg_idf.min(3.0) / 3.0) * 0.6).min(1.0);
581
582            let pos = i as f64 / n.max(1) as f64;
583            let attn_weight = attention.weight(pos);
584
585            let score = (relevance * 0.6 + 0.05)
586                * (information * 0.25 + 0.05)
587                * (attn_weight * 0.15 + 0.05);
588
589            (i, *line, score)
590        })
591        .collect();
592
593    let budget = ((n as f64) * effective_ratio).ceil() as usize;
594
595    scored_lines.sort_by(|a, b| b.2.partial_cmp(&a.2).unwrap_or(std::cmp::Ordering::Equal));
596
597    let selected = mmr_select(&scored_lines, budget, 0.3);
598
599    let mut output_lines: Vec<&str> = Vec::with_capacity(budget + 1);
600
601    if !kw_lower.is_empty() {
602        output_lines.push("");
603    }
604
605    for (_, line, _) in &selected {
606        output_lines.push(line);
607    }
608
609    if !kw_lower.is_empty() {
610        let summary = format!("[task: {}]", task_keywords.join(", "));
611        let mut result = summary;
612        result.push('\n');
613        result.push_str(&output_lines[1..].to_vec().join("\n"));
614        return result;
615    }
616
617    output_lines.join("\n")
618}
619
/// Maximum Marginal Relevance selection — greedy selection that penalizes
/// redundancy with already-selected lines using token-set Jaccard similarity.
///
/// MMR(i) = relevance(i) - lambda * max_{j in S} jaccard(i, j)
///
/// `candidates` are `(line_index, line, score)` tuples and are expected to be
/// sorted by descending score: the first candidate is always selected first.
/// Up to `budget` candidates are returned in selection order. Among
/// candidates with equal MMR the one that comes first in `candidates` wins.
/// Lines without any token are never penalised.
///
/// Token sets are computed once per candidate and each candidate carries a
/// running maximum similarity that is updated only against the most recently
/// selected line, so a round costs one Jaccard computation per remaining
/// candidate.
fn mmr_select<'a>(
    candidates: &[(usize, &'a str, f64)],
    budget: usize,
    lambda: f64,
) -> Vec<(usize, &'a str, f64)> {
    if candidates.is_empty() || budget == 0 {
        return Vec::new();
    }

    let total = candidates.len();
    // No more than `total` lines can ever be selected; clamping also keeps the
    // allocation below bounded for arbitrarily large budgets.
    let target = budget.min(total);

    let token_sets: Vec<HashSet<&str>> = candidates
        .iter()
        .map(|&(_, line, _)| line.split_whitespace().collect())
        .collect();

    let mut taken = vec![false; total];
    // max_sim[i] = highest Jaccard similarity of candidate i to any selected line.
    let mut max_sim = vec![0.0_f64; total];
    let mut selected: Vec<(usize, &'a str, f64)> = Vec::with_capacity(target);

    // Always take the top-scored line first.
    taken[0] = true;
    selected.push(candidates[0]);
    let mut newest = 0usize;

    while selected.len() < target {
        let newest_tokens = &token_sets[newest];
        // Defaults to the first remaining candidate if no MMR value compares
        // greater than negative infinity (e.g. NaN scores).
        let mut best_idx: Option<usize> = None;
        let mut best_mmr = f64::NEG_INFINITY;

        for (i, &(_, _, cand_score)) in candidates.iter().enumerate() {
            if taken[i] {
                continue;
            }
            if best_idx.is_none() {
                best_idx = Some(i);
            }

            let cand_tokens = &token_sets[i];
            let mmr = if cand_tokens.is_empty() {
                cand_score
            } else {
                if !newest_tokens.is_empty() {
                    let inter = cand_tokens.intersection(newest_tokens).count();
                    let union = cand_tokens.len() + newest_tokens.len() - inter;
                    max_sim[i] = max_sim[i].max(inter as f64 / union as f64);
                }
                cand_score - lambda * max_sim[i]
            };

            if mmr > best_mmr {
                best_mmr = mmr;
                best_idx = Some(i);
            }
        }

        let Some(pick) = best_idx else {
            break;
        };
        taken[pick] = true;
        selected.push(candidates[pick]);
        newest = pick;
    }

    selected
}
682
/// Heuristically decides whether a (pre-trimmed) line deals with errors:
/// it starts with a known error construct, contains an error-related call or
/// path, or uses the `?;` operator outside of a `//` comment line.
fn is_error_handling(line: &str) -> bool {
    const LEADING: &[&str] = &[
        "return Err(",
        "Err(",
        "bail!(",
        "anyhow::bail!",
        "raise ",
        "throw ",
        "catch ",
        "except ",
        "try ",
        "panic!(",
    ];
    const ANYWHERE: &[&str] = &[".map_err(", "unwrap()", "expect(\"", "Error::", "error!"];

    if LEADING.iter().any(|prefix| line.starts_with(prefix)) {
        return true;
    }
    if ANYWHERE.iter().any(|needle| line.contains(needle)) {
        return true;
    }
    line.contains("?;") && !line.starts_with("//")
}
701
/// Derives the fraction of lines to keep from how repetitive `content` is.
///
/// * Fewer than 10 lines: everything is kept (`1.0`).
/// * No tokens at all: `base_ratio` is returned unchanged.
/// * Otherwise `base_ratio` is reduced by up to 30 % in proportion to the
///   share of repeated whitespace-separated tokens, and the result is clamped
///   to `[0.2, 1.0]`.
pub fn adaptive_ib_budget(content: &str, base_ratio: f64) -> f64 {
    if content.lines().count() < 10 {
        return 1.0;
    }

    // Line breaks are whitespace, so tokenising the whole text yields the same
    // tokens as tokenising line by line.
    let mut distinct: HashSet<&str> = HashSet::new();
    let mut total_tokens = 0usize;
    for token in content.split_whitespace() {
        distinct.insert(token);
        total_tokens += 1;
    }

    if total_tokens == 0 {
        return base_ratio;
    }

    let unique_ratio = distinct.len() as f64 / total_tokens as f64;
    let repetition_factor = 1.0 - unique_ratio;

    (base_ratio * (1.0 - repetition_factor * 0.3)).clamp(0.2, 1.0)
}
729
/// Returns `true` when the line, ignoring leading whitespace, starts with a
/// keyword sequence that introduces a definition in Rust, TypeScript/JavaScript,
/// Python or Go (e.g. `fn `, `pub struct `, `class `, `def `, `func `).
fn is_definition_line(line: &str) -> bool {
    const DEFINITION_PREFIXES: &[&str] = &[
        "fn ",
        "pub fn ",
        "async fn ",
        "pub async fn ",
        "struct ",
        "pub struct ",
        "enum ",
        "pub enum ",
        "trait ",
        "pub trait ",
        "impl ",
        "type ",
        "pub type ",
        "const ",
        "pub const ",
        "static ",
        "pub static ",
        "class ",
        "export class ",
        "interface ",
        "export interface ",
        "function ",
        "export function ",
        "async function ",
        "def ",
        "async def ",
        "func ",
    ];

    // Every prefix starts with a letter, so checking the unindented line also
    // covers lines that have no indentation at all.
    let unindented = line.trim_start();
    DEFINITION_PREFIXES
        .iter()
        .any(|prefix| unindented.starts_with(prefix))
}
764
/// Returns `true` when the line, after trimming surrounding whitespace, begins
/// with a control-flow keyword.
fn is_control_flow(line: &str) -> bool {
    const KEYWORD_PREFIXES: &[&str] = &[
        "if ", "else ", "match ", "for ", "while ", "return ", "break", "continue", "yield",
        "await ",
    ];

    let body = line.trim();
    KEYWORD_PREFIXES
        .iter()
        .any(|keyword| body.starts_with(keyword))
}
778
/// Returns `true` when the line consists solely of a block terminator:
/// `}`, `};`, `})` or `});`, optionally surrounded by whitespace.
fn is_closing_brace(line: &str) -> bool {
    matches!(line.trim(), "}" | "};" | "})" | "});")
}
783
#[cfg(test)]
mod tests {
    use super::*;

    /// A path-like word is reported as a file, a plain word as a keyword.
    #[test]
    fn parse_task_finds_files_and_keywords() {
        let description = "Fix the authentication bug in src/auth.rs and update tests";
        let (files, keywords) = parse_task_hints(description);

        let mentions_file = files.iter().any(|f| f.contains("auth.rs"));
        assert!(mentions_file);

        let mentions_keyword = keywords
            .iter()
            .any(|k| k.to_lowercase().contains("authentication"));
        assert!(mentions_keyword);
    }

    /// One representative score per mode tier.
    #[test]
    fn recommend_mode_by_score() {
        let expectations = [
            (1.0, "full"),
            (0.6, "signatures"),
            (0.3, "map"),
            (0.1, "reference"),
        ];
        for (score, mode) in expectations {
            assert_eq!(recommend_mode(score), mode);
        }
    }

    /// Definition lines survive and the task header is emitted.
    #[test]
    fn info_bottleneck_preserves_definitions() {
        let content = "fn main() {\n    let x = 42;\n    // boring comment\n    println!(x);\n}\n";
        let keywords = ["main".to_string()];
        let filtered = information_bottleneck_filter(content, &keywords, 0.6);
        assert!(filtered.contains("fn main"), "definitions must be preserved");
        assert!(filtered.contains("[task: main]"), "should have task summary");
    }

    /// Error-handling lines are kept under a tight budget.
    #[test]
    fn info_bottleneck_error_handling_priority() {
        let content = "fn validate() {\n    let data = parse()?;\n    return Err(\"invalid\");\n    let x = 1;\n    let y = 2;\n}\n";
        let keywords = ["validate".to_string()];
        let filtered = information_bottleneck_filter(content, &keywords, 0.5);
        assert!(
            filtered.contains("return Err"),
            "error handling should survive filtering"
        );
    }

    /// Output is ordered by score: a definition precedes a closing brace
    /// whenever both are present.
    #[test]
    fn info_bottleneck_score_sorted() {
        let content = "fn important() {\n    let x = 1;\n    let y = 2;\n    let z = 3;\n}\n}\n";
        let filtered = information_bottleneck_filter(content, &[], 0.6);
        let out: Vec<&str> = filtered.lines().collect();

        let definition_at = out.iter().position(|l| l.contains("fn important"));
        let brace_at = out.iter().position(|l| l.trim() == "}");
        if let Some(d) = definition_at {
            if let Some(b) = brace_at {
                assert!(
                    d < b,
                    "definitions should appear before closing braces in score-sorted output"
                );
            }
        }
    }

    /// Repetitive content gets a smaller budget than diverse content.
    #[test]
    fn adaptive_budget_reduces_for_repetitive() {
        let repetitive = "let x = 1;\n".repeat(50);
        let diverse_lines: Vec<String> = (0..50)
            .map(|i| format!("let var_{i} = func_{i}(arg_{i});"))
            .collect();
        let diverse = diverse_lines.join("\n");

        let budget_rep = adaptive_ib_budget(&repetitive, 0.7);
        let budget_div = adaptive_ib_budget(&diverse, 0.7);
        assert!(
            budget_rep < budget_div,
            "repetitive content should get lower budget"
        );
    }
}