Skip to main content

lean_ctx/tools/
ctx_preload.rs

1use crate::core::cache::SessionCache;
2use crate::core::graph_index::ProjectIndex;
3use crate::core::protocol;
4use crate::core::task_relevance::{compute_relevance, parse_task_hints, RelevanceScore};
5use crate::core::tokens::count_tokens;
6use crate::tools::CrpMode;
7
/// Upper bound on how many files receive a CRITICAL section in a single preload.
const MAX_PRELOAD_FILES: usize = 8;
/// Upper bound on keyword / error lines quoted for each CRITICAL file.
const MAX_CRITICAL_LINES: usize = 15;
/// Upper bound on public signatures listed for each CRITICAL file.
const SIGNATURES_BUDGET: usize = 10;
/// Total token budget split across candidate files by `boltzmann_allocate`.
const TOTAL_TOKEN_BUDGET: usize = 4_000;
12
13pub fn handle(
14    cache: &mut SessionCache,
15    task: &str,
16    path: Option<&str>,
17    crp_mode: CrpMode,
18) -> String {
19    if task.trim().is_empty() {
20        return "ERROR: ctx_preload requires a task description".to_string();
21    }
22
23    let project_root = path
24        .map(|p| p.to_string())
25        .unwrap_or_else(|| ".".to_string());
26
27    let index = crate::core::graph_index::load_or_build(&project_root);
28
29    let session_intent =
30        crate::core::session::SessionState::load_latest().and_then(|s| s.active_structured_intent);
31
32    let (task_files, task_keywords) = parse_task_hints(task);
33    let relevance = if let Some(ref intent) = session_intent {
34        crate::core::task_relevance::compute_relevance_from_intent(&index, intent)
35    } else {
36        compute_relevance(&index, &task_files, &task_keywords)
37    };
38
39    let mut scored: Vec<_> = relevance
40        .iter()
41        .filter(|r| r.score >= 0.1)
42        .take(MAX_PRELOAD_FILES + 10)
43        .collect();
44
45    apply_heat_ranking(&mut scored, &index, &project_root);
46
47    let pop = crate::core::pop_pruning::decide_for_candidates(task, &project_root, &scored);
48    let candidates =
49        crate::core::pop_pruning::filter_candidates_by_pop(&project_root, &scored, &pop);
50
51    if candidates.is_empty() {
52        return format!(
53            "[task: {task}]\nNo directly relevant files found. Use ctx_overview for project map."
54        );
55    }
56
57    // Boltzmann allocation: p(file_i) = exp(score_i / T) / Z
58    // Temperature T is derived from task specificity:
59    //   - Many keywords / specific file mentions → low T → concentrate budget
60    //   - Few keywords / broad task → high T → spread budget evenly
61    let task_specificity =
62        (task_files.len() as f64 * 0.3 + task_keywords.len() as f64 * 0.1).clamp(0.0, 1.0);
63    let temperature = 0.8 - task_specificity * 0.6; // range [0.2, 0.8]
64    let temperature = temperature.max(0.1);
65
66    let allocations = boltzmann_allocate(&candidates, TOTAL_TOKEN_BUDGET, temperature);
67
68    let file_context: Vec<(String, usize)> = candidates
69        .iter()
70        .filter_map(|c| {
71            std::fs::read_to_string(&c.path)
72                .ok()
73                .map(|content| (c.path.clone(), content.lines().count()))
74        })
75        .collect();
76    let briefing = crate::core::task_briefing::build_briefing(task, &file_context);
77    let briefing_block = crate::core::task_briefing::format_briefing(&briefing);
78
79    let multi_intents = crate::core::intent_engine::detect_multi_intent(task);
80    let primary = &multi_intents[0];
81    let complexity = crate::core::intent_engine::classify_complexity(task, primary);
82
83    let mut output = Vec::new();
84    output.push(briefing_block);
85
86    let complexity_label = complexity.instruction_suffix().lines().next().unwrap_or("");
87    if multi_intents.len() > 1 {
88        output.push(format!(
89            "[task: {task}] | {} | {} sub-intents",
90            complexity_label,
91            multi_intents.len()
92        ));
93        for (i, sub) in multi_intents.iter().enumerate() {
94            output.push(format!(
95                "  {}. {} ({:.0}%)",
96                i + 1,
97                sub.task_type.as_str(),
98                sub.confidence * 100.0
99            ));
100        }
101    } else {
102        output.push(format!("[task: {task}] | {complexity_label}"));
103    }
104
105    for r in crate::core::prospective_memory::reminders_for_task(&project_root, task) {
106        output.push(r);
107    }
108
109    if !pop.excluded_modules.is_empty() {
110        output.push("POP:".to_string());
111        for ex in &pop.excluded_modules {
112            output.push(format!(
113                "  - exclude {}/ ({} candidates) — {}",
114                ex.module, ex.candidate_files, ex.reason
115            ));
116        }
117    }
118
119    let mut total_estimated_saved = 0usize;
120    let mut critical_count = 0usize;
121
122    for (rel, token_budget) in candidates.iter().zip(allocations.iter()) {
123        if *token_budget < 20 {
124            continue;
125        }
126        critical_count += 1;
127        if critical_count > MAX_PRELOAD_FILES {
128            break;
129        }
130
131        let content = match std::fs::read_to_string(&rel.path) {
132            Ok(c) => c,
133            Err(_) => continue,
134        };
135
136        let file_ref = cache.get_file_ref(&rel.path);
137        let short = protocol::shorten_path(&rel.path);
138        let line_count = content.lines().count();
139        let file_tokens = count_tokens(&content);
140
141        let _ = cache.store(&rel.path, content.clone());
142
143        let mode = budget_to_mode(*token_budget, file_tokens);
144
145        let critical_lines = extract_critical_lines(&content, &task_keywords, MAX_CRITICAL_LINES);
146        let sigs = extract_key_signatures(&content, SIGNATURES_BUDGET);
147        let imports = extract_imports(&content);
148
149        output.push(format!(
150            "\nCRITICAL: {file_ref}={short} {line_count}L score={:.1} budget={token_budget}tok mode={mode}",
151            rel.score
152        ));
153
154        if !critical_lines.is_empty() {
155            for (line_no, line) in &critical_lines {
156                output.push(format!("  :{line_no} {line}"));
157            }
158        }
159
160        if !imports.is_empty() {
161            output.push(format!("  imports: {}", imports.join(", ")));
162        }
163
164        if !sigs.is_empty() {
165            for sig in &sigs {
166                output.push(format!("  {sig}"));
167            }
168        }
169
170        total_estimated_saved += file_tokens;
171    }
172
173    let context_files: Vec<_> = relevance
174        .iter()
175        .filter(|r| r.score >= 0.1 && r.score < 0.3)
176        .take(10)
177        .collect();
178
179    if !context_files.is_empty() {
180        output.push("\nRELATED:".to_string());
181        for rel in &context_files {
182            let short = protocol::shorten_path(&rel.path);
183            output.push(format!(
184                "  {} mode={} score={:.1}",
185                short, rel.recommended_mode, rel.score
186            ));
187        }
188    }
189
190    let graph_edges: Vec<_> = index
191        .edges
192        .iter()
193        .filter(|e| {
194            candidates
195                .iter()
196                .any(|c| c.path == e.from || c.path == e.to)
197        })
198        .take(10)
199        .collect();
200
201    if !graph_edges.is_empty() {
202        output.push("\nGRAPH:".to_string());
203        for edge in &graph_edges {
204            let from_short = protocol::shorten_path(&edge.from);
205            let to_short = protocol::shorten_path(&edge.to);
206            output.push(format!("  {from_short} -> {to_short}"));
207        }
208    }
209
210    let preload_result = output.join("\n");
211    let preload_tokens = count_tokens(&preload_result);
212    let savings = protocol::format_savings(total_estimated_saved, preload_tokens);
213
214    if crp_mode.is_tdd() {
215        format!("{preload_result}\n{savings}")
216    } else {
217        format!("{preload_result}\n\nNext: ctx_read(path, mode=\"full\") for any file above.\n{savings}")
218    }
219}
220
221/// Boltzmann distribution for token budget allocation across files.
222/// p(file_i) = exp(score_i / T) / Z, then budget_i = total * p(file_i)
223fn boltzmann_allocate(
224    candidates: &[&crate::core::task_relevance::RelevanceScore],
225    total_budget: usize,
226    temperature: f64,
227) -> Vec<usize> {
228    if candidates.is_empty() {
229        return Vec::new();
230    }
231
232    let t = temperature.max(0.01);
233
234    // Compute exp(score / T) for each candidate, using log-sum-exp for numerical stability
235    let log_weights: Vec<f64> = candidates.iter().map(|c| c.score / t).collect();
236    let max_log = log_weights
237        .iter()
238        .cloned()
239        .fold(f64::NEG_INFINITY, f64::max);
240    let exp_weights: Vec<f64> = log_weights.iter().map(|&lw| (lw - max_log).exp()).collect();
241    let z: f64 = exp_weights.iter().sum();
242
243    if z <= 0.0 {
244        return vec![total_budget / candidates.len().max(1); candidates.len()];
245    }
246
247    let mut allocations: Vec<usize> = exp_weights
248        .iter()
249        .map(|&w| ((w / z) * total_budget as f64).round() as usize)
250        .collect();
251
252    // Ensure total doesn't exceed budget
253    let sum: usize = allocations.iter().sum();
254    if sum > total_budget {
255        let overflow = sum - total_budget;
256        if let Some(last) = allocations.last_mut() {
257            *last = last.saturating_sub(overflow);
258        }
259    }
260
261    allocations
262}
263
/// Chooses the compression mode whose output should fit in `budget` tokens for a file that
/// costs `file_tokens` tokens when read in full.
///
/// Only the coverage ratio `budget / file_tokens` matters (a 0-token file counts as 1 token):
/// `>= 0.8` → `"full"`, `>= 0.4` → `"signatures"`, `>= 0.15` → `"map"`, else `"reference"`.
fn budget_to_mode(budget: usize, file_tokens: usize) -> &'static str {
    // Checked from most to least generous; the first threshold met wins.
    const TIERS: [(f64, &str); 3] = [(0.8, "full"), (0.4, "signatures"), (0.15, "map")];

    let coverage = budget as f64 / std::cmp::max(file_tokens, 1) as f64;
    TIERS
        .iter()
        .find(|&&(min_ratio, _)| coverage >= min_ratio)
        .map_or("reference", |&(_, mode)| mode)
}
277
/// Selects up to `max` "critical" lines from `content`.
///
/// A line is critical if it contains a task keyword (case-insensitive) or looks
/// error-related (`Error`, `Err(`, `panic!`, `unwrap()`, or starts with `return Err`). Each
/// hit scores (number of keywords it contains) + 2 if error-related. Results are ordered by
/// descending score, with ties kept in source order, and returned as
/// `(1-based line number, trimmed line)` pairs.
///
/// Blank lines never match. Blank keywords are ignored, because an empty keyword would
/// otherwise match every line via `str::contains("")`.
fn extract_critical_lines(content: &str, keywords: &[String], max: usize) -> Vec<(usize, String)> {
    let kw_lower: Vec<String> = keywords
        .iter()
        .filter(|k| !k.trim().is_empty())
        .map(|k| k.to_lowercase())
        .collect();

    let mut hits: Vec<(usize, String, usize)> = content
        .lines()
        .enumerate()
        .filter_map(|(i, line)| {
            let trimmed = line.trim();
            if trimmed.is_empty() {
                return None;
            }
            let line_lower = trimmed.to_lowercase();
            let hit_count = kw_lower
                .iter()
                .filter(|kw| line_lower.contains(kw.as_str()))
                .count();

            let is_error = trimmed.contains("Error")
                || trimmed.contains("Err(")
                || trimmed.contains("panic!")
                || trimmed.contains("unwrap()")
                || trimmed.starts_with("return Err");

            if hit_count > 0 || is_error {
                let priority = hit_count + if is_error { 2 } else { 0 };
                Some((i + 1, trimmed.to_string(), priority))
            } else {
                None
            }
        })
        .collect();

    // Stable sort: equal priorities keep their source order.
    hits.sort_by_key(|hit| std::cmp::Reverse(hit.2));
    hits.truncate(max);
    hits.into_iter().map(|(n, l, _)| (n, l)).collect()
}
314
/// Returns up to `max` public item signatures from `content`, in source order.
///
/// A signature is a trimmed line that starts with `pub fn`, `pub async fn`, `pub struct`,
/// `pub enum`, `pub trait`, `pub type` or `pub const`. Lines longer than 120 bytes are cut to at
/// most 117 bytes and given a `...` suffix. The cut is moved back to the nearest UTF-8 character
/// boundary: slicing at a fixed byte offset panicked when that offset fell inside a multi-byte
/// character.
fn extract_key_signatures(content: &str, max: usize) -> Vec<String> {
    const SIG_STARTERS: [&str; 7] = [
        "pub fn ",
        "pub async fn ",
        "pub struct ",
        "pub enum ",
        "pub trait ",
        "pub type ",
        "pub const ",
    ];
    const MAX_LEN: usize = 120;
    const KEEP: usize = 117;

    content
        .lines()
        .map(str::trim)
        .filter(|line| SIG_STARTERS.iter().any(|s| line.starts_with(s)))
        .take(max)
        .map(|line| {
            if line.len() <= MAX_LEN {
                return line.to_string();
            }
            // Byte 0 is always a boundary, so this loop terminates.
            let mut end = KEEP;
            while !line.is_char_boundary(end) {
                end -= 1;
            }
            format!("{}...", &line[..end])
        })
        .collect()
}
343
/// Collects up to eight import lines from `content`.
///
/// A line counts as an import when, after trimming, it starts with `use `, `import ` or `from `.
/// For `use ` lines the keyword and any trailing `;` are removed; all other matches are returned
/// as the trimmed line.
fn extract_imports(content: &str) -> Vec<String> {
    const IMPORT_PREFIXES: [&str; 3] = ["use ", "import ", "from "];
    const LIMIT: usize = 8;

    let mut found = Vec::new();
    for line in content.lines().map(str::trim) {
        if found.len() == LIMIT {
            break;
        }
        if !IMPORT_PREFIXES.iter().any(|p| line.starts_with(p)) {
            continue;
        }
        let entry = match line.strip_prefix("use ") {
            Some(path) => path.trim_end_matches(';'),
            None => line,
        };
        found.push(entry.to_string());
    }
    found
}
362
363fn apply_heat_ranking(candidates: &mut [&RelevanceScore], index: &ProjectIndex, root: &str) {
364    if index.files.is_empty() {
365        return;
366    }
367
368    let mut connection_counts: std::collections::HashMap<String, usize> =
369        std::collections::HashMap::new();
370    for edge in &index.edges {
371        *connection_counts.entry(edge.from.clone()).or_default() += 1;
372        *connection_counts.entry(edge.to.clone()).or_default() += 1;
373    }
374
375    let max_tokens = index
376        .files
377        .values()
378        .map(|f| f.token_count)
379        .max()
380        .unwrap_or(1) as f64;
381    let max_conn = connection_counts.values().max().copied().unwrap_or(1) as f64;
382
383    candidates.sort_by(|a, b| {
384        let heat_a = compute_heat(
385            &a.path,
386            root,
387            index,
388            &connection_counts,
389            max_tokens,
390            max_conn,
391        );
392        let heat_b = compute_heat(
393            &b.path,
394            root,
395            index,
396            &connection_counts,
397            max_tokens,
398            max_conn,
399        );
400        let combined_a = a.score * 0.6 + heat_a * 0.4;
401        let combined_b = b.score * 0.6 + heat_b * 0.4;
402        combined_b
403            .partial_cmp(&combined_a)
404            .unwrap_or(std::cmp::Ordering::Equal)
405    });
406}
407
408fn compute_heat(
409    path: &str,
410    root: &str,
411    index: &ProjectIndex,
412    connections: &std::collections::HashMap<String, usize>,
413    max_tokens: f64,
414    max_conn: f64,
415) -> f64 {
416    let rel = path
417        .strip_prefix(root)
418        .unwrap_or(path)
419        .trim_start_matches('/');
420
421    if let Some(entry) = index.files.get(rel) {
422        let conn = connections.get(rel).copied().unwrap_or(0);
423        let token_norm = entry.token_count as f64 / max_tokens;
424        let conn_norm = conn as f64 / max_conn;
425        token_norm * 0.4 + conn_norm * 0.6
426    } else {
427        0.0
428    }
429}
430
#[cfg(test)]
mod tests {
    use super::*;

    /// Turns string literals into the owned keyword list the extractors expect.
    fn kws(words: &[&str]) -> Vec<String> {
        words.iter().map(|w| w.to_string()).collect()
    }

    #[test]
    fn extract_critical_lines_finds_keywords() {
        let src = concat!(
            "fn main() {\n",
            "    let token = validate();\n",
            "    return Err(e);\n",
            "}\n",
        );
        let lines = extract_critical_lines(src, &kws(&["validate"]), 5);
        assert!(!lines.is_empty());
        assert!(lines.iter().any(|(_, text)| text.contains("validate")));
    }

    #[test]
    fn extract_critical_lines_prioritizes_errors() {
        let src = concat!(
            "fn main() {\n",
            "    let x = 1;\n",
            "    return Err(\"bad\");\n",
            "    let token = validate();\n",
            "}\n",
        );
        let lines = extract_critical_lines(src, &kws(&["validate"]), 5);
        assert!(lines.len() >= 2);
        let (_, top) = &lines[0];
        assert!(top.contains("Err"), "errors should be first");
    }

    #[test]
    fn extract_key_signatures_finds_pub() {
        let src = [
            "use std::io;",
            "fn private() {}",
            "pub fn public_one() {}",
            "pub struct Foo {}",
        ]
        .join("\n");
        let sigs = extract_key_signatures(&src, 10);
        assert_eq!(sigs.len(), 2);
        assert!(sigs[0].contains("pub fn public_one"));
        assert!(sigs[1].contains("pub struct Foo"));
    }

    #[test]
    fn extract_imports_works() {
        let src = ["use std::io;", "use crate::core::cache;", "fn main() {}"].join("\n");
        let imports = extract_imports(&src);
        assert_eq!(imports.len(), 2);
        assert!(imports[0].contains("std::io"));
    }
}