// lean_ctx/tools/ctx_preload.rs
1use crate::core::cache::SessionCache;
2use crate::core::graph_index::ProjectIndex;
3use crate::core::protocol;
4use crate::core::task_relevance::{compute_relevance, parse_task_hints, RelevanceScore};
5use crate::core::tokens::count_tokens;
6use crate::tools::CrpMode;
7
/// Hard cap on files surfaced in the CRITICAL section of the output.
const MAX_PRELOAD_FILES: usize = 8;
/// Max keyword/error-matching lines quoted per preloaded file.
const MAX_CRITICAL_LINES: usize = 15;
/// Max public signatures listed per preloaded file.
const SIGNATURES_BUDGET: usize = 10;
/// Total token budget split across preloaded files via Boltzmann allocation.
const TOTAL_TOKEN_BUDGET: usize = 4000;
12
/// Entry point for the `ctx_preload` tool: given a free-form `task`
/// description, select the most relevant project files and emit a
/// token-budgeted context briefing.
///
/// Pipeline (each stage narrows or annotates the candidate set):
///   1. relevance scoring from task hints (mentioned files + keywords),
///   2. heat re-ranking (graph connectivity + file size),
///   3. POP pruning (module-level exclusions),
///   4. Boltzmann token allocation across survivors,
///   5. per-file extraction of critical lines / signatures / imports.
///
/// `path` is the project root (defaults to `"."`); `crp_mode` only
/// affects the trailing hint text. Returns the formatted report, or an
/// `ERROR:` string when `task` is blank.
pub fn handle(
    cache: &mut SessionCache,
    task: &str,
    path: Option<&str>,
    crp_mode: CrpMode,
) -> String {
    // An empty task gives relevance scoring nothing to work with.
    if task.trim().is_empty() {
        return "ERROR: ctx_preload requires a task description".to_string();
    }

    let project_root = path
        .map(|p| p.to_string())
        .unwrap_or_else(|| ".".to_string());

    let index = crate::core::graph_index::load_or_build(&project_root);

    // Stage 1: score every indexed file against the task hints.
    let (task_files, task_keywords) = parse_task_hints(task);
    let relevance = compute_relevance(&index, &task_files, &task_keywords);

    // Keep a few more than we will show so heat ranking and POP pruning
    // have slack to reorder and drop candidates.
    let mut scored: Vec<_> = relevance
        .iter()
        .filter(|r| r.score >= 0.1)
        .take(MAX_PRELOAD_FILES + 10)
        .collect();

    // Stage 2: re-rank by blended relevance + graph "heat".
    apply_heat_ranking(&mut scored, &index, &project_root);

    // Stage 3: POP pruning may exclude whole modules.
    let pop = crate::core::pop_pruning::decide_for_candidates(task, &project_root, &scored);
    let candidates =
        crate::core::pop_pruning::filter_candidates_by_pop(&project_root, &scored, &pop);

    if candidates.is_empty() {
        return format!(
            "[task: {task}]\nNo directly relevant files found. Use ctx_overview for project map."
        );
    }

    // Boltzmann allocation: p(file_i) = exp(score_i / T) / Z
    // Temperature T is derived from task specificity:
    //   - Many keywords / specific file mentions → low T → concentrate budget
    //   - Few keywords / broad task → high T → spread budget evenly
    let task_specificity =
        (task_files.len() as f64 * 0.3 + task_keywords.len() as f64 * 0.1).clamp(0.0, 1.0);
    let temperature = 0.8 - task_specificity * 0.6; // range [0.2, 0.8]
    let temperature = temperature.max(0.1);

    let allocations = boltzmann_allocate(&candidates, TOTAL_TOKEN_BUDGET, temperature);

    // Line counts feed the briefing; unreadable files are skipped silently.
    let file_context: Vec<(String, usize)> = candidates
        .iter()
        .filter_map(|c| {
            std::fs::read_to_string(&c.path)
                .ok()
                .map(|content| (c.path.clone(), content.lines().count()))
        })
        .collect();
    let briefing = crate::core::task_briefing::build_briefing(task, &file_context);
    let briefing_block = crate::core::task_briefing::format_briefing(&briefing);

    let multi_intents = crate::core::intent_engine::detect_multi_intent(task);
    // NOTE(review): indexing assumes detect_multi_intent never returns an
    // empty Vec — confirm that invariant in intent_engine, else this panics.
    let primary = &multi_intents[0];
    let complexity = crate::core::intent_engine::classify_complexity(task, primary);

    let mut output = Vec::new();
    output.push(briefing_block);

    // First line of the complexity instruction doubles as a short label.
    let complexity_label = complexity.instruction_suffix().lines().next().unwrap_or("");
    if multi_intents.len() > 1 {
        output.push(format!(
            "[task: {task}] | {} | {} sub-intents",
            complexity_label,
            multi_intents.len()
        ));
        for (i, sub) in multi_intents.iter().enumerate() {
            output.push(format!(
                "  {}. {} ({:.0}%)",
                i + 1,
                sub.task_type.as_str(),
                sub.confidence * 100.0
            ));
        }
    } else {
        output.push(format!("[task: {task}] | {complexity_label}"));
    }

    for r in crate::core::prospective_memory::reminders_for_task(&project_root, task) {
        output.push(r);
    }

    // Surface which modules POP excluded and why, so the caller can override.
    if !pop.excluded_modules.is_empty() {
        output.push("POP:".to_string());
        for ex in &pop.excluded_modules {
            output.push(format!(
                "  - exclude {}/ ({} candidates) — {}",
                ex.module, ex.candidate_files, ex.reason
            ));
        }
    }

    let mut total_estimated_saved = 0usize;
    let mut critical_count = 0usize;

    // Stages 4-5: emit one CRITICAL entry per budgeted file.
    for (rel, token_budget) in candidates.iter().zip(allocations.iter()) {
        // Budgets under 20 tokens cannot show anything useful; skip.
        if *token_budget < 20 {
            continue;
        }
        critical_count += 1;
        if critical_count > MAX_PRELOAD_FILES {
            break;
        }

        let content = match std::fs::read_to_string(&rel.path) {
            Ok(c) => c,
            Err(_) => continue,
        };

        let file_ref = cache.get_file_ref(&rel.path);
        let short = protocol::shorten_path(&rel.path);
        let line_count = content.lines().count();
        let file_tokens = count_tokens(&content);

        // Cache the content so a follow-up ctx_read is cheap; failures are
        // deliberately ignored (best-effort cache).
        let _ = cache.store(&rel.path, content.clone());

        let mode = budget_to_mode(*token_budget, file_tokens);

        let critical_lines = extract_critical_lines(&content, &task_keywords, MAX_CRITICAL_LINES);
        let sigs = extract_key_signatures(&content, SIGNATURES_BUDGET);
        let imports = extract_imports(&content);

        output.push(format!(
            "\nCRITICAL: {file_ref}={short} {line_count}L score={:.1} budget={token_budget}tok mode={mode}",
            rel.score
        ));

        if !critical_lines.is_empty() {
            for (line_no, line) in &critical_lines {
                output.push(format!("  :{line_no} {line}"));
            }
        }

        if !imports.is_empty() {
            output.push(format!("  imports: {}", imports.join(", ")));
        }

        if !sigs.is_empty() {
            for sig in &sigs {
                output.push(format!("  {sig}"));
            }
        }

        total_estimated_saved += file_tokens;
    }

    // Mid-relevance files are listed but not preloaded.
    let context_files: Vec<_> = relevance
        .iter()
        .filter(|r| r.score >= 0.1 && r.score < 0.3)
        .take(10)
        .collect();

    if !context_files.is_empty() {
        output.push("\nRELATED:".to_string());
        for rel in &context_files {
            let short = protocol::shorten_path(&rel.path);
            output.push(format!(
                "  {} mode={} score={:.1}",
                short, rel.recommended_mode, rel.score
            ));
        }
    }

    // Show up to 10 dependency edges touching the chosen candidates.
    let graph_edges: Vec<_> = index
        .edges
        .iter()
        .filter(|e| {
            candidates
                .iter()
                .any(|c| c.path == e.from || c.path == e.to)
        })
        .take(10)
        .collect();

    if !graph_edges.is_empty() {
        output.push("\nGRAPH:".to_string());
        for edge in &graph_edges {
            let from_short = protocol::shorten_path(&edge.from);
            let to_short = protocol::shorten_path(&edge.to);
            output.push(format!("  {from_short} -> {to_short}"));
        }
    }

    let preload_result = output.join("\n");
    let preload_tokens = count_tokens(&preload_result);
    let savings = protocol::format_savings(total_estimated_saved, preload_tokens);

    // TDD mode suppresses the "Next:" usage hint.
    if crp_mode.is_tdd() {
        format!("{preload_result}\n{savings}")
    } else {
        format!("{preload_result}\n\nNext: ctx_read(path, mode=\"full\") for any file above.\n{savings}")
    }
}
213
214/// Boltzmann distribution for token budget allocation across files.
215/// p(file_i) = exp(score_i / T) / Z, then budget_i = total * p(file_i)
216fn boltzmann_allocate(
217    candidates: &[&crate::core::task_relevance::RelevanceScore],
218    total_budget: usize,
219    temperature: f64,
220) -> Vec<usize> {
221    if candidates.is_empty() {
222        return Vec::new();
223    }
224
225    let t = temperature.max(0.01);
226
227    // Compute exp(score / T) for each candidate, using log-sum-exp for numerical stability
228    let log_weights: Vec<f64> = candidates.iter().map(|c| c.score / t).collect();
229    let max_log = log_weights
230        .iter()
231        .cloned()
232        .fold(f64::NEG_INFINITY, f64::max);
233    let exp_weights: Vec<f64> = log_weights.iter().map(|&lw| (lw - max_log).exp()).collect();
234    let z: f64 = exp_weights.iter().sum();
235
236    if z <= 0.0 {
237        return vec![total_budget / candidates.len().max(1); candidates.len()];
238    }
239
240    let mut allocations: Vec<usize> = exp_weights
241        .iter()
242        .map(|&w| ((w / z) * total_budget as f64).round() as usize)
243        .collect();
244
245    // Ensure total doesn't exceed budget
246    let sum: usize = allocations.iter().sum();
247    if sum > total_budget {
248        let overflow = sum - total_budget;
249        if let Some(last) = allocations.last_mut() {
250            *last = last.saturating_sub(overflow);
251        }
252    }
253
254    allocations
255}
256
/// Translate an allocated token budget into the compression mode a
/// caller should request: the larger the budget relative to the file's
/// real token count, the less compression is needed.
fn budget_to_mode(budget: usize, file_tokens: usize) -> &'static str {
    // max(1) guards against division by zero for empty files.
    let fraction = budget as f64 / file_tokens.max(1) as f64;
    match fraction {
        f if f >= 0.8 => "full",
        f if f >= 0.4 => "signatures",
        f if f >= 0.15 => "map",
        _ => "reference",
    }
}
270
/// Pick up to `max` lines that either mention a task keyword
/// (case-insensitive substring) or look like error-handling code.
/// Error-looking lines get +2 priority; results come back highest
/// priority first, with original line order preserved for ties
/// (the sort is stable). Line numbers are 1-based.
fn extract_critical_lines(content: &str, keywords: &[String], max: usize) -> Vec<(usize, String)> {
    const ERROR_MARKERS: [&str; 4] = ["Error", "Err(", "panic!", "unwrap()"];

    let needles: Vec<String> = keywords.iter().map(|k| k.to_lowercase()).collect();

    let mut ranked: Vec<(usize, String, usize)> = Vec::new();
    for (idx, raw) in content.lines().enumerate() {
        let line = raw.trim();
        if line.is_empty() {
            continue;
        }

        let lowered = line.to_lowercase();
        let keyword_hits = needles
            .iter()
            .filter(|n| lowered.contains(n.as_str()))
            .count();
        let looks_like_error = ERROR_MARKERS.iter().any(|m| line.contains(m))
            || line.starts_with("return Err");

        if keyword_hits == 0 && !looks_like_error {
            continue;
        }
        let priority = keyword_hits + if looks_like_error { 2 } else { 0 };
        ranked.push((idx + 1, line.to_string(), priority));
    }

    // Stable descending sort on priority, then drop the priority column.
    ranked.sort_by_key(|entry| std::cmp::Reverse(entry.2));
    ranked.truncate(max);
    ranked.into_iter().map(|(no, text, _)| (no, text)).collect()
}
307
/// Collect up to `max` public item signatures (`pub fn` / `struct` /
/// `enum` / `trait` / `type` / `const`), truncating long lines so the
/// preload output stays compact.
///
/// Fix: the previous byte slice `&trimmed[..117]` panicked whenever
/// byte 117 fell inside a multi-byte UTF-8 character (e.g. a non-ASCII
/// identifier or doc text on the signature line). Truncation now backs
/// off to the nearest char boundary at or below the cut point.
fn extract_key_signatures(content: &str, max: usize) -> Vec<String> {
    // Lines longer than this (in bytes) are truncated to TRUNCATED_LEN + "...".
    const MAX_SIG_LEN: usize = 120;
    const TRUNCATED_LEN: usize = 117;

    let sig_starters = [
        "pub fn ",
        "pub async fn ",
        "pub struct ",
        "pub enum ",
        "pub trait ",
        "pub type ",
        "pub const ",
    ];

    content
        .lines()
        .map(str::trim)
        .filter(|trimmed| sig_starters.iter().any(|s| trimmed.starts_with(s)))
        .take(max)
        .map(|trimmed| {
            if trimmed.len() > MAX_SIG_LEN {
                // Largest valid char boundary <= TRUNCATED_LEN; for pure
                // ASCII this is exactly TRUNCATED_LEN (same output as before).
                let cut = (0..=TRUNCATED_LEN)
                    .rev()
                    .find(|&i| trimmed.is_char_boundary(i))
                    .unwrap_or(0);
                format!("{}...", &trimmed[..cut])
            } else {
                trimmed.to_string()
            }
        })
        .collect()
}
336
/// Grab up to eight dependency lines (Rust `use`, Python/JS-style
/// `import` / `from`). Rust `use` lines are reduced to the bare path
/// (prefix and trailing `;` removed); other styles are kept verbatim.
fn extract_imports(content: &str) -> Vec<String> {
    let mut found = Vec::new();
    for line in content.lines() {
        let trimmed = line.trim();
        let is_dep = ["use ", "import ", "from "]
            .iter()
            .any(|prefix| trimmed.starts_with(prefix));
        if !is_dep {
            continue;
        }

        let rendered = match trimmed.strip_prefix("use ") {
            Some(path) => path.trim_end_matches(';').to_string(),
            None => trimmed.to_string(),
        };
        found.push(rendered);

        // Cap the list so one file can't flood the output.
        if found.len() == 8 {
            break;
        }
    }
    found
}
355
356fn apply_heat_ranking(candidates: &mut [&RelevanceScore], index: &ProjectIndex, root: &str) {
357    if index.files.is_empty() {
358        return;
359    }
360
361    let mut connection_counts: std::collections::HashMap<String, usize> =
362        std::collections::HashMap::new();
363    for edge in &index.edges {
364        *connection_counts.entry(edge.from.clone()).or_default() += 1;
365        *connection_counts.entry(edge.to.clone()).or_default() += 1;
366    }
367
368    let max_tokens = index
369        .files
370        .values()
371        .map(|f| f.token_count)
372        .max()
373        .unwrap_or(1) as f64;
374    let max_conn = connection_counts.values().max().copied().unwrap_or(1) as f64;
375
376    candidates.sort_by(|a, b| {
377        let heat_a = compute_heat(
378            &a.path,
379            root,
380            index,
381            &connection_counts,
382            max_tokens,
383            max_conn,
384        );
385        let heat_b = compute_heat(
386            &b.path,
387            root,
388            index,
389            &connection_counts,
390            max_tokens,
391            max_conn,
392        );
393        let combined_a = a.score * 0.6 + heat_a * 0.4;
394        let combined_b = b.score * 0.6 + heat_b * 0.4;
395        combined_b
396            .partial_cmp(&combined_a)
397            .unwrap_or(std::cmp::Ordering::Equal)
398    });
399}
400
401fn compute_heat(
402    path: &str,
403    root: &str,
404    index: &ProjectIndex,
405    connections: &std::collections::HashMap<String, usize>,
406    max_tokens: f64,
407    max_conn: f64,
408) -> f64 {
409    let rel = path
410        .strip_prefix(root)
411        .unwrap_or(path)
412        .trim_start_matches('/');
413
414    if let Some(entry) = index.files.get(rel) {
415        let conn = connections.get(rel).copied().unwrap_or(0);
416        let token_norm = entry.token_count as f64 / max_tokens;
417        let conn_norm = conn as f64 / max_conn;
418        token_norm * 0.4 + conn_norm * 0.6
419    } else {
420        0.0
421    }
422}
423
#[cfg(test)]
mod tests {
    use super::*;

    // Keyword matching is case-insensitive substring search on trimmed lines.
    #[test]
    fn extract_critical_lines_finds_keywords() {
        let content = "fn main() {\n    let token = validate();\n    return Err(e);\n}\n";
        let result = extract_critical_lines(content, &["validate".to_string()], 5);
        assert!(!result.is_empty());
        assert!(result.iter().any(|(_, l)| l.contains("validate")));
    }

    // Error-looking lines carry +2 priority, so they sort above pure
    // keyword hits even when they appear later in the file.
    #[test]
    fn extract_critical_lines_prioritizes_errors() {
        let content = "fn main() {\n    let x = 1;\n    return Err(\"bad\");\n    let token = validate();\n}\n";
        let result = extract_critical_lines(content, &["validate".to_string()], 5);
        assert!(result.len() >= 2);
        assert!(result[0].1.contains("Err"), "errors should be first");
    }

    // Only `pub` items are reported; private fns and imports are skipped.
    #[test]
    fn extract_key_signatures_finds_pub() {
        let content = "use std::io;\nfn private() {}\npub fn public_one() {}\npub struct Foo {}\n";
        let sigs = extract_key_signatures(content, 10);
        assert_eq!(sigs.len(), 2);
        assert!(sigs[0].contains("pub fn public_one"));
        assert!(sigs[1].contains("pub struct Foo"));
    }

    // Rust `use` lines are stripped to the bare path (no prefix, no `;`).
    #[test]
    fn extract_imports_works() {
        let content = "use std::io;\nuse crate::core::cache;\nfn main() {}\n";
        let imports = extract_imports(content);
        assert_eq!(imports.len(), 2);
        assert!(imports[0].contains("std::io"));
    }
}
460}