Skip to main content

lean_ctx/tools/
ctx_preload.rs

1use crate::core::cache::SessionCache;
2use crate::core::graph_index::ProjectIndex;
3use crate::core::protocol;
4use crate::core::task_relevance::{compute_relevance, parse_task_hints, RelevanceScore};
5use crate::core::tokens::count_tokens;
6use crate::tools::CrpMode;
7
/// Maximum number of files emitted as `CRITICAL:` sections; the candidate pool
/// pulled from the relevance list is this plus 10.
const MAX_PRELOAD_FILES: usize = 8;
/// Maximum number of keyword/error lines quoted per critical file.
const MAX_CRITICAL_LINES: usize = 15;
/// Maximum number of public signatures listed per critical file.
const SIGNATURES_BUDGET: usize = 10;
/// Total tokens split across candidate files by `boltzmann_allocate`.
const TOTAL_TOKEN_BUDGET: usize = 4_000;
12
/// Builds the `ctx_preload` briefing for `task`.
///
/// The briefing lists the most relevant project files with a per-file token budget
/// and suggested mode. For each of those files it adds key lines, imports and
/// public signatures. It also adds lower-scoring related files and graph edges
/// touching the candidates.
///
/// * `cache` – session cache; every file emitted as `CRITICAL` gets a file
///   reference from `cache.get_file_ref` and its full content stored via `cache.store`.
/// * `task` – free-text task description; a blank task yields an `ERROR:` string.
/// * `path` – project root to index; `"."` when `None`.
/// * `crp_mode` – when `crp_mode.is_tdd()` is true, the trailing
///   `Next: ctx_read(...)` hint is omitted.
///
/// Always returns text; error and "nothing found" cases are reported in the string.
pub fn handle(
    cache: &mut SessionCache,
    task: &str,
    path: Option<&str>,
    crp_mode: CrpMode,
) -> String {
    if task.trim().is_empty() {
        return "ERROR: ctx_preload requires a task description".to_string();
    }

    // Default to the current directory when no project root is given.
    let project_root = path.map_or_else(|| ".".to_string(), std::string::ToString::to_string);

    let index = crate::core::graph_index::load_or_build(&project_root);

    // A structured intent from the latest session, when present, takes precedence
    // over the hints parsed from the task text for relevance scoring. The parsed
    // hints are still used below (temperature, critical-line keywords).
    let session_intent =
        crate::core::session::SessionState::load_latest().and_then(|s| s.active_structured_intent);

    let (task_files, task_keywords) = parse_task_hints(task);
    let relevance = if let Some(ref intent) = session_intent {
        crate::core::task_relevance::compute_relevance_from_intent(&index, intent)
    } else {
        compute_relevance(&index, &task_files, &task_keywords)
    };

    // Candidate pool: files scoring >= 0.1, up to MAX_PRELOAD_FILES + 10 of them,
    // which is then re-ranked by heat and filtered by POP pruning.
    // NOTE(review): `take` keeps the first matches in `relevance` order; this
    // assumes the relevance list is sorted by descending score — confirm.
    let mut scored: Vec<_> = relevance
        .iter()
        .filter(|r| r.score >= 0.1)
        .take(MAX_PRELOAD_FILES + 10)
        .collect();

    apply_heat_ranking(&mut scored, &index, &project_root);

    let pop = crate::core::pop_pruning::decide_for_candidates(task, &project_root, &scored);
    let candidates =
        crate::core::pop_pruning::filter_candidates_by_pop(&project_root, &scored, &pop);

    if candidates.is_empty() {
        return format!(
            "[task: {task}]\nNo directly relevant files found. Use ctx_overview for project map."
        );
    }

    // Boltzmann allocation: p(file_i) = exp(score_i / T) / Z
    // Temperature T is derived from task specificity:
    //   - Many keywords / specific file mentions → low T → concentrate budget
    //   - Few keywords / broad task → high T → spread budget evenly
    let task_specificity =
        (task_files.len() as f64 * 0.3 + task_keywords.len() as f64 * 0.1).clamp(0.0, 1.0);
    let temperature = 0.8 - task_specificity * 0.6; // range [0.2, 0.8]
    // NOTE(review): the clamp above already keeps T within [0.2, 0.8], so this floor never applies.
    let temperature = temperature.max(0.1);

    // One budget per candidate, in candidate order (zipped with `candidates` below).
    let allocations = boltzmann_allocate(&candidates, TOTAL_TOKEN_BUDGET, temperature);

    // (path, line count) for every readable candidate, fed to the task briefing.
    // NOTE(review): the CRITICAL loop below reads these files from disk again;
    // the contents could be read once and reused.
    let file_context: Vec<(String, usize)> = candidates
        .iter()
        .filter_map(|c| {
            std::fs::read_to_string(&c.path)
                .ok()
                .map(|content| (c.path.clone(), content.lines().count()))
        })
        .collect();
    let briefing = crate::core::task_briefing::build_briefing(task, &file_context);
    let briefing_block = crate::core::task_briefing::format_briefing(&briefing);

    // NOTE(review): `[0]` panics if detect_multi_intent ever returns an empty list —
    // confirm it always yields at least one intent.
    let multi_intents = crate::core::intent_engine::detect_multi_intent(task);
    let primary = &multi_intents[0];
    let complexity = crate::core::intent_engine::classify_complexity(task, primary);

    let mut output = Vec::new();
    output.push(briefing_block);

    // Only the first line of the complexity instruction is shown in the header.
    let complexity_label = complexity.instruction_suffix().lines().next().unwrap_or("");
    if multi_intents.len() > 1 {
        output.push(format!(
            "[task: {task}] | {} | {} sub-intents",
            complexity_label,
            multi_intents.len()
        ));
        for (i, sub) in multi_intents.iter().enumerate() {
            output.push(format!(
                "  {}. {} ({:.0}%)",
                i + 1,
                sub.task_type.as_str(),
                sub.confidence * 100.0
            ));
        }
    } else {
        output.push(format!("[task: {task}] | {complexity_label}"));
    }

    for r in crate::core::prospective_memory::reminders_for_task(&project_root, task) {
        output.push(r);
    }

    // Report the modules that POP pruning excluded from the candidates.
    if !pop.excluded_modules.is_empty() {
        output.push("POP:".to_string());
        for ex in &pop.excluded_modules {
            output.push(format!(
                "  - exclude {}/ ({} candidates) — {}",
                ex.module, ex.candidate_files, ex.reason
            ));
        }
    }

    let mut total_estimated_saved = 0usize;
    let mut critical_count = 0usize;

    // CRITICAL sections: candidates whose budget is at least 20 tokens, at most
    // MAX_PRELOAD_FILES of them.
    // NOTE(review): `allocations` split TOTAL_TOKEN_BUDGET across all candidates,
    // not only the ones that end up printed here.
    for (rel, token_budget) in candidates.iter().zip(allocations.iter()) {
        if *token_budget < 20 {
            continue;
        }
        // NOTE(review): the slot is counted before the file is read, so an
        // unreadable file still consumes one of the MAX_PRELOAD_FILES slots.
        critical_count += 1;
        if critical_count > MAX_PRELOAD_FILES {
            break;
        }

        let Ok(content) = std::fs::read_to_string(&rel.path) else {
            continue;
        };

        let file_ref = cache.get_file_ref(&rel.path);
        let short = protocol::shorten_path(&rel.path);
        let line_count = content.lines().count();
        let file_tokens = count_tokens(&content);

        // Keep the full content in the session cache; the store result is ignored.
        let _ = cache.store(&rel.path, content.clone());

        let mode = budget_to_mode(*token_budget, file_tokens);

        let critical_lines = extract_critical_lines(&content, &task_keywords, MAX_CRITICAL_LINES);
        let sigs = extract_key_signatures(&content, SIGNATURES_BUDGET);
        let imports = extract_imports(&content);

        output.push(format!(
            "\nCRITICAL: {file_ref}={short} {line_count}L score={:.1} budget={token_budget}tok mode={mode}",
            rel.score
        ));

        if !critical_lines.is_empty() {
            for (line_no, line) in &critical_lines {
                output.push(format!("  :{line_no} {line}"));
            }
        }

        if !imports.is_empty() {
            output.push(format!("  imports: {}", imports.join(", ")));
        }

        if !sigs.is_empty() {
            for sig in &sigs {
                output.push(format!("  {sig}"));
            }
        }

        // Full-file token count of every emitted file; compared against the size
        // of this briefing by format_savings below.
        total_estimated_saved += file_tokens;
    }

    // RELATED: up to 10 lower-scoring files (0.1 <= score < 0.3) from the full
    // relevance list.
    // NOTE(review): not filtered against the CRITICAL files above, so a file can
    // be listed in both sections.
    let context_files: Vec<_> = relevance
        .iter()
        .filter(|r| r.score >= 0.1 && r.score < 0.3)
        .take(10)
        .collect();

    if !context_files.is_empty() {
        output.push("\nRELATED:".to_string());
        for rel in &context_files {
            let short = protocol::shorten_path(&rel.path);
            output.push(format!(
                "  {} mode={} score={:.1}",
                short, rel.recommended_mode, rel.score
            ));
        }
    }

    // GRAPH: up to 10 index edges with either endpoint among the candidates
    // (all candidates, not only those printed as CRITICAL).
    let graph_edges: Vec<_> = index
        .edges
        .iter()
        .filter(|e| {
            candidates
                .iter()
                .any(|c| c.path == e.from || c.path == e.to)
        })
        .take(10)
        .collect();

    if !graph_edges.is_empty() {
        output.push("\nGRAPH:".to_string());
        for edge in &graph_edges {
            let from_short = protocol::shorten_path(&edge.from);
            let to_short = protocol::shorten_path(&edge.to);
            output.push(format!("  {from_short} -> {to_short}"));
        }
    }

    let preload_result = output.join("\n");
    let preload_tokens = count_tokens(&preload_result);
    let savings = protocol::format_savings(total_estimated_saved, preload_tokens);

    // In TDD CRP mode the trailing "Next:" hint is omitted.
    if crp_mode.is_tdd() {
        format!("{preload_result}\n{savings}")
    } else {
        format!("{preload_result}\n\nNext: ctx_read(path, mode=\"full\") for any file above.\n{savings}")
    }
}
217
218/// Boltzmann distribution for token budget allocation across files.
219/// p(file_i) = exp(score_i / T) / Z, then budget_i = total * p(file_i)
220fn boltzmann_allocate(
221    candidates: &[&crate::core::task_relevance::RelevanceScore],
222    total_budget: usize,
223    temperature: f64,
224) -> Vec<usize> {
225    if candidates.is_empty() {
226        return Vec::new();
227    }
228
229    let t = temperature.max(0.01);
230
231    // Compute exp(score / T) for each candidate, using log-sum-exp for numerical stability
232    let log_weights: Vec<f64> = candidates.iter().map(|c| c.score / t).collect();
233    let max_log = log_weights
234        .iter()
235        .copied()
236        .fold(f64::NEG_INFINITY, f64::max);
237    let exp_weights: Vec<f64> = log_weights.iter().map(|&lw| (lw - max_log).exp()).collect();
238    let z: f64 = exp_weights.iter().sum();
239
240    if z <= 0.0 {
241        return vec![total_budget / candidates.len().max(1); candidates.len()];
242    }
243
244    let mut allocations: Vec<usize> = exp_weights
245        .iter()
246        .map(|&w| ((w / z) * total_budget as f64).round() as usize)
247        .collect();
248
249    // Ensure total doesn't exceed budget
250    let sum: usize = allocations.iter().sum();
251    if sum > total_budget {
252        let overflow = sum - total_budget;
253        if let Some(last) = allocations.last_mut() {
254            *last = last.saturating_sub(overflow);
255        }
256    }
257
258    allocations
259}
260
/// Pick a compression mode from how much of a file its token budget covers.
///
/// Coverage is `budget / file_tokens`, where an empty file counts as one token.
/// The thresholds are checked from the top: `>= 0.8` → `"full"`,
/// `>= 0.4` → `"signatures"`, `>= 0.15` → `"map"`, otherwise `"reference"`.
fn budget_to_mode(budget: usize, file_tokens: usize) -> &'static str {
    const TIERS: [(f64, &str); 3] = [(0.8, "full"), (0.4, "signatures"), (0.15, "map")];

    let coverage = budget as f64 / file_tokens.max(1) as f64;
    TIERS
        .iter()
        .find(|&&(min_ratio, _)| coverage >= min_ratio)
        .map_or("reference", |&(_, mode)| mode)
}
274
/// Collect up to `max` "critical" lines from `content`, highest priority first.
///
/// A non-blank line is critical when either of these holds:
/// * it contains at least one of `keywords` (case-insensitive substring match);
/// * it looks error-related: it contains `Error`, `Err(`, `panic!` or `unwrap()`,
///   or it starts with `return Err`.
///
/// Priority is the number of matching keywords, plus 2 for error-related lines.
/// Lines with equal priority keep their source order. Blank or whitespace-only
/// keywords are ignored, because an empty pattern would match every line.
///
/// Returns `(1-based line number, trimmed line)` pairs.
fn extract_critical_lines(content: &str, keywords: &[String], max: usize) -> Vec<(usize, String)> {
    let kw_lower: Vec<String> = keywords
        .iter()
        .filter(|kw| !kw.trim().is_empty())
        .map(|kw| kw.to_lowercase())
        .collect();

    let mut hits: Vec<(usize, &str, usize)> = content
        .lines()
        .enumerate()
        .filter_map(|(idx, line)| {
            let trimmed = line.trim();
            if trimmed.is_empty() {
                return None;
            }
            let line_lower = trimmed.to_lowercase();
            let hit_count = kw_lower
                .iter()
                .filter(|kw| line_lower.contains(kw.as_str()))
                .count();

            let is_error = trimmed.contains("Error")
                || trimmed.contains("Err(")
                || trimmed.contains("panic!")
                || trimmed.contains("unwrap()")
                || trimmed.starts_with("return Err");

            if hit_count == 0 && !is_error {
                return None;
            }
            let priority = hit_count + if is_error { 2 } else { 0 };
            Some((idx + 1, trimmed, priority))
        })
        .collect();

    // `sort_by_key` is stable, so equal-priority lines stay in source order.
    hits.sort_by_key(|&(_, _, priority)| std::cmp::Reverse(priority));
    hits.truncate(max);
    // Only the surviving lines are copied into owned strings.
    hits.into_iter()
        .map(|(line_no, line, _)| (line_no, line.to_string()))
        .collect()
}
311
/// List up to `max` public item signatures from `content`, in source order.
///
/// A line counts as a signature when, after trimming, it starts with `pub fn`,
/// `pub async fn`, `pub struct`, `pub enum`, `pub trait`, `pub type` or `pub const`.
///
/// Lines longer than 120 bytes are cut to at most 117 bytes plus `"..."`. The cut
/// is moved back to the nearest UTF-8 character boundary, so non-ASCII text (for
/// example inside a `pub const` string literal) cannot cause a slicing panic.
fn extract_key_signatures(content: &str, max: usize) -> Vec<String> {
    const MAX_LEN: usize = 120;
    const CUT_AT: usize = 117;

    let sig_starters = [
        "pub fn ",
        "pub async fn ",
        "pub struct ",
        "pub enum ",
        "pub trait ",
        "pub type ",
        "pub const ",
    ];

    content
        .lines()
        .map(str::trim)
        .filter(|trimmed| sig_starters.iter().any(|s| trimmed.starts_with(s)))
        .take(max)
        .map(|trimmed| {
            if trimmed.len() > MAX_LEN {
                // Byte CUT_AT may fall inside a multi-byte char; back off to a boundary.
                let mut end = CUT_AT;
                while !trimmed.is_char_boundary(end) {
                    end -= 1;
                }
                format!("{}...", &trimmed[..end])
            } else {
                trimmed.to_string()
            }
        })
        .collect()
}
340
/// Collect up to eight import lines from `content`, in source order.
///
/// Recognizes lines that, after trimming, start with `use `, `import ` or `from `.
/// For Rust `use` lines, the `use ` keyword and any trailing `;` characters are
/// removed, leaving just the path. Other import lines are returned trimmed but
/// otherwise verbatim.
fn extract_imports(content: &str) -> Vec<String> {
    const IMPORT_PREFIXES: [&str; 3] = ["use ", "import ", "from "];
    const LIMIT: usize = 8;

    let mut found: Vec<String> = Vec::new();
    for raw in content.lines() {
        if found.len() == LIMIT {
            break;
        }
        let line = raw.trim();
        if !IMPORT_PREFIXES.iter().any(|prefix| line.starts_with(prefix)) {
            continue;
        }
        let entry = match line.strip_prefix("use ") {
            Some(path) => path.trim_end_matches(';'),
            None => line,
        };
        found.push(entry.to_string());
    }
    found
}
359
360fn apply_heat_ranking(candidates: &mut [&RelevanceScore], index: &ProjectIndex, root: &str) {
361    if index.files.is_empty() {
362        return;
363    }
364
365    let mut connection_counts: std::collections::HashMap<String, usize> =
366        std::collections::HashMap::new();
367    for edge in &index.edges {
368        *connection_counts.entry(edge.from.clone()).or_default() += 1;
369        *connection_counts.entry(edge.to.clone()).or_default() += 1;
370    }
371
372    let max_tokens = index
373        .files
374        .values()
375        .map(|f| f.token_count)
376        .max()
377        .unwrap_or(1) as f64;
378    let max_conn = connection_counts.values().max().copied().unwrap_or(1) as f64;
379
380    candidates.sort_by(|a, b| {
381        let heat_a = compute_heat(
382            &a.path,
383            root,
384            index,
385            &connection_counts,
386            max_tokens,
387            max_conn,
388        );
389        let heat_b = compute_heat(
390            &b.path,
391            root,
392            index,
393            &connection_counts,
394            max_tokens,
395            max_conn,
396        );
397        let combined_a = a.score * 0.6 + heat_a * 0.4;
398        let combined_b = b.score * 0.6 + heat_b * 0.4;
399        combined_b
400            .partial_cmp(&combined_a)
401            .unwrap_or(std::cmp::Ordering::Equal)
402    });
403}
404
405fn compute_heat(
406    path: &str,
407    root: &str,
408    index: &ProjectIndex,
409    connections: &std::collections::HashMap<String, usize>,
410    max_tokens: f64,
411    max_conn: f64,
412) -> f64 {
413    let rel = path
414        .strip_prefix(root)
415        .unwrap_or(path)
416        .trim_start_matches('/');
417
418    if let Some(entry) = index.files.get(rel) {
419        let conn = connections.get(rel).copied().unwrap_or(0);
420        let token_norm = entry.token_count as f64 / max_tokens;
421        let conn_norm = conn as f64 / max_conn;
422        token_norm * 0.4 + conn_norm * 0.6
423    } else {
424        0.0
425    }
426}
427
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn extract_critical_lines_finds_keywords() {
        let content = "fn main() {\n    let token = validate();\n    return Err(e);\n}\n";
        let result = extract_critical_lines(content, &["validate".to_string()], 5);
        assert!(!result.is_empty());
        assert!(result.iter().any(|(_, l)| l.contains("validate")));
    }

    #[test]
    fn extract_critical_lines_prioritizes_errors() {
        let content = "fn main() {\n    let x = 1;\n    return Err(\"bad\");\n    let token = validate();\n}\n";
        let result = extract_critical_lines(content, &["validate".to_string()], 5);
        assert!(result.len() >= 2);
        assert!(result[0].1.contains("Err"), "errors should be first");
    }

    #[test]
    fn extract_key_signatures_finds_pub() {
        let content = "use std::io;\nfn private() {}\npub fn public_one() {}\npub struct Foo {}\n";
        let sigs = extract_key_signatures(content, 10);
        assert_eq!(sigs.len(), 2);
        assert!(sigs[0].contains("pub fn public_one"));
        assert!(sigs[1].contains("pub struct Foo"));
    }

    #[test]
    fn extract_key_signatures_truncates_long_lines() {
        let line = format!("pub fn {}() {{}}", "a".repeat(130));
        let sigs = extract_key_signatures(&line, 10);
        assert_eq!(sigs.len(), 1);
        assert_eq!(sigs[0].len(), 120);
        assert!(sigs[0].ends_with("..."));
    }

    #[test]
    fn extract_imports_works() {
        let content = "use std::io;\nuse crate::core::cache;\nfn main() {}\n";
        let imports = extract_imports(content);
        assert_eq!(imports.len(), 2);
        assert!(imports[0].contains("std::io"));
    }

    #[test]
    fn extract_imports_strips_use_and_caps_at_eight() {
        let content: String = (0..10).map(|i| format!("use crate::m{i};\n")).collect();
        let imports = extract_imports(&content);
        assert_eq!(imports.len(), 8);
        assert_eq!(imports[0], "crate::m0");
        assert_eq!(imports[7], "crate::m7");

        let mixed = extract_imports("import os\nfrom x import y\n");
        assert_eq!(mixed, vec!["import os".to_string(), "from x import y".to_string()]);
    }

    #[test]
    fn budget_to_mode_thresholds() {
        assert_eq!(budget_to_mode(100, 100), "full");
        assert_eq!(budget_to_mode(50, 100), "signatures");
        assert_eq!(budget_to_mode(20, 100), "map");
        assert_eq!(budget_to_mode(10, 100), "reference");
        assert_eq!(budget_to_mode(0, 0), "reference");
    }
}