Skip to main content

lean_ctx/tools/
ctx_preload.rs

1use crate::core::cache::SessionCache;
2use crate::core::graph_index::ProjectIndex;
3use crate::core::protocol;
4use crate::core::task_relevance::{compute_relevance, parse_task_hints, RelevanceScore};
5use crate::core::tokens::count_tokens;
6use crate::tools::CrpMode;
7
/// Upper bound on the number of `CRITICAL` file entries emitted by `handle`.
const MAX_PRELOAD_FILES: usize = 8;
/// Maximum number of keyword/error lines quoted per preloaded file.
const MAX_CRITICAL_LINES: usize = 15;
/// Maximum number of public signatures listed per preloaded file.
const SIGNATURES_BUDGET: usize = 10;
/// Total token budget that is split across the candidate files.
const TOTAL_TOKEN_BUDGET: usize = 4000;
12
13pub fn handle(
14    cache: &mut SessionCache,
15    task: &str,
16    path: Option<&str>,
17    crp_mode: CrpMode,
18) -> String {
19    if task.trim().is_empty() {
20        return "ERROR: ctx_preload requires a task description".to_string();
21    }
22
23    let project_root = path.map_or_else(|| ".".to_string(), std::string::ToString::to_string);
24    let jail_root = std::path::Path::new(&project_root);
25
26    let index = crate::core::graph_index::load_or_build(&project_root);
27
28    let session_intent =
29        crate::core::session::SessionState::load_latest().and_then(|s| s.active_structured_intent);
30
31    let (task_files, task_keywords) = parse_task_hints(task);
32    let relevance = if let Some(ref intent) = session_intent {
33        crate::core::task_relevance::compute_relevance_from_intent(&index, intent)
34    } else {
35        compute_relevance(&index, &task_files, &task_keywords)
36    };
37
38    let mut scored: Vec<_> = relevance
39        .iter()
40        .filter(|r| r.score >= 0.1)
41        .take(MAX_PRELOAD_FILES + 10)
42        .collect();
43
44    apply_heat_ranking(&mut scored, &index, &project_root);
45
46    let pop = crate::core::pop_pruning::decide_for_candidates(task, &project_root, &scored);
47    let candidates =
48        crate::core::pop_pruning::filter_candidates_by_pop(&project_root, &scored, &pop);
49
50    if candidates.is_empty() {
51        return format!(
52            "[task: {task}]\nNo directly relevant files found. Use ctx_overview for project map."
53        );
54    }
55
56    // Boltzmann allocation: p(file_i) = exp(score_i / T) / Z
57    // Temperature T is derived from task specificity:
58    //   - Many keywords / specific file mentions → low T → concentrate budget
59    //   - Few keywords / broad task → high T → spread budget evenly
60    let task_specificity =
61        (task_files.len() as f64 * 0.3 + task_keywords.len() as f64 * 0.1).clamp(0.0, 1.0);
62    let temperature = 0.8 - task_specificity * 0.6; // range [0.2, 0.8]
63    let temperature = temperature.max(0.1);
64
65    let allocations = boltzmann_allocate(&candidates, TOTAL_TOKEN_BUDGET, temperature);
66
67    let file_context: Vec<(String, usize)> = candidates
68        .iter()
69        .filter_map(|c| {
70            let Ok((jailed, warning)) = crate::core::io_boundary::jail_and_check_path(
71                "ctx_preload",
72                std::path::Path::new(&c.path),
73                jail_root,
74            ) else {
75                return None;
76            };
77            if warning.is_some() {
78                return None;
79            }
80            std::fs::read_to_string(&jailed)
81                .ok()
82                .map(|content| (c.path.clone(), content.lines().count()))
83        })
84        .collect();
85    let briefing = crate::core::task_briefing::build_briefing(task, &file_context);
86    let briefing_block = crate::core::task_briefing::format_briefing(&briefing);
87
88    let multi_intents = crate::core::intent_engine::detect_multi_intent(task);
89    let primary = &multi_intents[0];
90    let complexity = crate::core::intent_engine::classify_complexity(task, primary);
91
92    let mut output = Vec::new();
93    output.push(briefing_block);
94
95    let complexity_label = complexity.instruction_suffix().lines().next().unwrap_or("");
96    if multi_intents.len() > 1 {
97        output.push(format!(
98            "[task: {task}] | {} | {} sub-intents",
99            complexity_label,
100            multi_intents.len()
101        ));
102        for (i, sub) in multi_intents.iter().enumerate() {
103            output.push(format!(
104                "  {}. {} ({:.0}%)",
105                i + 1,
106                sub.task_type.as_str(),
107                sub.confidence * 100.0
108            ));
109        }
110    } else {
111        output.push(format!("[task: {task}] | {complexity_label}"));
112    }
113
114    for r in crate::core::prospective_memory::reminders_for_task(&project_root, task) {
115        output.push(r);
116    }
117
118    if !pop.excluded_modules.is_empty() {
119        output.push("POP:".to_string());
120        for ex in &pop.excluded_modules {
121            output.push(format!(
122                "  - exclude {}/ ({} candidates) — {}",
123                ex.module, ex.candidate_files, ex.reason
124            ));
125        }
126    }
127
128    let mut total_estimated_saved = 0usize;
129    let mut critical_count = 0usize;
130
131    for (rel, token_budget) in candidates.iter().zip(allocations.iter()) {
132        if *token_budget < 20 {
133            continue;
134        }
135        critical_count += 1;
136        if critical_count > MAX_PRELOAD_FILES {
137            break;
138        }
139
140        let Ok((jailed, warning)) = crate::core::io_boundary::jail_and_check_path(
141            "ctx_preload",
142            std::path::Path::new(&rel.path),
143            jail_root,
144        ) else {
145            continue;
146        };
147        if warning.is_some() {
148            continue;
149        }
150
151        let jailed_s = jailed.to_string_lossy().to_string();
152        let Ok(content) = std::fs::read_to_string(&jailed) else {
153            continue;
154        };
155
156        let file_ref = cache.get_file_ref(&jailed_s);
157        let short = protocol::shorten_path(&jailed_s);
158        let line_count = content.lines().count();
159        let file_tokens = count_tokens(&content);
160
161        let _ = cache.store(&jailed_s, &content);
162
163        let mode = budget_to_mode(*token_budget, file_tokens);
164
165        let critical_lines = extract_critical_lines(&content, &task_keywords, MAX_CRITICAL_LINES);
166        let sigs = extract_key_signatures(&content, SIGNATURES_BUDGET);
167        let imports = extract_imports(&content);
168
169        output.push(format!(
170            "\nCRITICAL: {file_ref}={short} {line_count}L score={:.1} budget={token_budget}tok mode={mode}",
171            rel.score
172        ));
173
174        if !critical_lines.is_empty() {
175            for (line_no, line) in &critical_lines {
176                output.push(format!("  :{line_no} {line}"));
177            }
178        }
179
180        if !imports.is_empty() {
181            output.push(format!("  imports: {}", imports.join(", ")));
182        }
183
184        if !sigs.is_empty() {
185            for sig in &sigs {
186                output.push(format!("  {sig}"));
187            }
188        }
189
190        total_estimated_saved += file_tokens;
191    }
192
193    let context_files: Vec<_> = relevance
194        .iter()
195        .filter(|r| r.score >= 0.1 && r.score < 0.3)
196        .take(10)
197        .collect();
198
199    if !context_files.is_empty() {
200        output.push("\nRELATED:".to_string());
201        for rel in &context_files {
202            let short = protocol::shorten_path(&rel.path);
203            output.push(format!(
204                "  {} mode={} score={:.1}",
205                short, rel.recommended_mode, rel.score
206            ));
207        }
208    }
209
210    let graph_edges: Vec<_> = index
211        .edges
212        .iter()
213        .filter(|e| {
214            candidates
215                .iter()
216                .any(|c| c.path == e.from || c.path == e.to)
217        })
218        .take(10)
219        .collect();
220
221    if !graph_edges.is_empty() {
222        output.push("\nGRAPH:".to_string());
223        for edge in &graph_edges {
224            let from_short = protocol::shorten_path(&edge.from);
225            let to_short = protocol::shorten_path(&edge.to);
226            output.push(format!("  {from_short} -> {to_short}"));
227        }
228    }
229
230    let preload_result = output.join("\n");
231    let preload_tokens = count_tokens(&preload_result);
232    let savings = protocol::format_savings(total_estimated_saved, preload_tokens);
233
234    if crp_mode.is_tdd() {
235        format!("{preload_result}\n{savings}")
236    } else {
237        format!("{preload_result}\n\nNext: ctx_read(path, mode=\"full\") for any file above.\n{savings}")
238    }
239}
240
241/// Boltzmann distribution for token budget allocation across files.
242/// p(file_i) = exp(score_i / T) / Z, then budget_i = total * p(file_i)
243fn boltzmann_allocate(
244    candidates: &[&crate::core::task_relevance::RelevanceScore],
245    total_budget: usize,
246    temperature: f64,
247) -> Vec<usize> {
248    if candidates.is_empty() {
249        return Vec::new();
250    }
251
252    let t = temperature.max(0.01);
253
254    // Compute exp(score / T) for each candidate, using log-sum-exp for numerical stability
255    let log_weights: Vec<f64> = candidates.iter().map(|c| c.score / t).collect();
256    let max_log = log_weights
257        .iter()
258        .copied()
259        .fold(f64::NEG_INFINITY, f64::max);
260    let exp_weights: Vec<f64> = log_weights.iter().map(|&lw| (lw - max_log).exp()).collect();
261    let z: f64 = exp_weights.iter().sum();
262
263    if z <= 0.0 {
264        return vec![total_budget / candidates.len().max(1); candidates.len()];
265    }
266
267    let mut allocations: Vec<usize> = exp_weights
268        .iter()
269        .map(|&w| ((w / z) * total_budget as f64).round() as usize)
270        .collect();
271
272    // Ensure total doesn't exceed budget
273    let sum: usize = allocations.iter().sum();
274    if sum > total_budget {
275        let overflow = sum - total_budget;
276        if let Some(last) = allocations.last_mut() {
277            *last = last.saturating_sub(overflow);
278        }
279    }
280
281    allocations
282}
283
/// Maps a file's token budget to the name of a recommended compression mode.
///
/// The decision is based on `budget / file_tokens` (with `file_tokens` treated as at
/// least 1 to avoid dividing by zero):
///
/// | ratio      | mode           |
/// |------------|----------------|
/// | `>= 0.8`   | `"full"`       |
/// | `>= 0.4`   | `"signatures"` |
/// | `>= 0.15`  | `"map"`        |
/// | otherwise  | `"reference"`  |
fn budget_to_mode(budget: usize, file_tokens: usize) -> &'static str {
    let ratio = budget as f64 / file_tokens.max(1) as f64;
    if ratio >= 0.8 {
        "full"
    } else if ratio >= 0.4 {
        "signatures"
    } else if ratio >= 0.15 {
        "map"
    } else {
        "reference"
    }
}
297
/// Picks the lines of `content` most relevant to a task.
///
/// A non-blank line qualifies when it contains at least one keyword
/// (case-insensitively) or looks error-related (`Error`, `Err(`, `panic!`, `unwrap()`,
/// or a leading `return Err`). Its priority is the number of matching keywords, plus 2
/// if it is error-related.
///
/// Blank or whitespace-only keywords are ignored; `str::contains("")` is always true,
/// so they would otherwise match every line.
///
/// # Returns
/// At most `max` pairs of `(1-based line number, trimmed line text)`, ordered by
/// descending priority; lines with equal priority keep their order in the file.
fn extract_critical_lines(content: &str, keywords: &[String], max: usize) -> Vec<(usize, String)> {
    let kw_lower: Vec<String> = keywords
        .iter()
        .filter(|k| !k.trim().is_empty())
        .map(|k| k.to_lowercase())
        .collect();

    let mut hits: Vec<(usize, String, usize)> = content
        .lines()
        .enumerate()
        .filter_map(|(i, line)| {
            let trimmed = line.trim();
            if trimmed.is_empty() {
                return None;
            }
            let line_lower = trimmed.to_lowercase();
            let hit_count = kw_lower
                .iter()
                .filter(|kw| line_lower.contains(kw.as_str()))
                .count();

            // Error markers are matched case-sensitively on the original text.
            let is_error = trimmed.contains("Error")
                || trimmed.contains("Err(")
                || trimmed.contains("panic!")
                || trimmed.contains("unwrap()")
                || trimmed.starts_with("return Err");

            if hit_count > 0 || is_error {
                let priority = hit_count + if is_error { 2 } else { 0 };
                Some((i + 1, trimmed.to_string(), priority))
            } else {
                None
            }
        })
        .collect();

    // Stable sort: ties stay in file order.
    hits.sort_by_key(|hit| std::cmp::Reverse(hit.2));
    hits.truncate(max);
    hits.into_iter().map(|(n, l, _)| (n, l)).collect()
}
334
/// Collects the first `max` public item signatures found in `content`.
///
/// A line counts as a signature when, after trimming, it starts with one of `pub fn`,
/// `pub async fn`, `pub struct`, `pub enum`, `pub trait`, `pub type` or `pub const`.
/// Lines longer than 120 bytes are shortened to at most 117 bytes followed by `...`.
/// The cut is moved back to the nearest UTF-8 character boundary, so multi-byte text
/// cannot cause a slicing panic.
fn extract_key_signatures(content: &str, max: usize) -> Vec<String> {
    const SIG_STARTERS: [&str; 7] = [
        "pub fn ",
        "pub async fn ",
        "pub struct ",
        "pub enum ",
        "pub trait ",
        "pub type ",
        "pub const ",
    ];
    const MAX_LEN: usize = 120;
    const KEEP_LEN: usize = 117;

    content
        .lines()
        .map(str::trim)
        .filter(|line| SIG_STARTERS.iter().any(|s| line.starts_with(s)))
        .take(max)
        .map(|line| {
            if line.len() > MAX_LEN {
                // Index 0 is always a boundary, so this loop terminates.
                let mut end = KEEP_LEN;
                while !line.is_char_boundary(end) {
                    end -= 1;
                }
                format!("{}...", &line[..end])
            } else {
                line.to_string()
            }
        })
        .collect()
}
363
/// Returns up to eight import-like lines from `content`.
///
/// A line qualifies when, after trimming, it starts with `use `, `import ` or `from `.
/// For `use ` lines the keyword and any trailing `;` are removed (`use std::io;`
/// becomes `std::io`); `import `/`from ` lines are returned trimmed but otherwise
/// unchanged.
fn extract_imports(content: &str) -> Vec<String> {
    content
        .lines()
        .map(str::trim)
        .filter(|t| t.starts_with("use ") || t.starts_with("import ") || t.starts_with("from "))
        .take(8)
        .map(|t| match t.strip_prefix("use ") {
            Some(rest) => rest.trim_end_matches(';').to_string(),
            None => t.to_string(),
        })
        .collect()
}
382
383fn apply_heat_ranking(candidates: &mut [&RelevanceScore], index: &ProjectIndex, root: &str) {
384    if index.files.is_empty() {
385        return;
386    }
387
388    let mut connection_counts: std::collections::HashMap<String, usize> =
389        std::collections::HashMap::new();
390    for edge in &index.edges {
391        *connection_counts.entry(edge.from.clone()).or_default() += 1;
392        *connection_counts.entry(edge.to.clone()).or_default() += 1;
393    }
394
395    let max_tokens = index
396        .files
397        .values()
398        .map(|f| f.token_count)
399        .max()
400        .unwrap_or(1) as f64;
401    let max_conn = connection_counts.values().max().copied().unwrap_or(1) as f64;
402
403    candidates.sort_by(|a, b| {
404        let heat_a = compute_heat(
405            &a.path,
406            root,
407            index,
408            &connection_counts,
409            max_tokens,
410            max_conn,
411        );
412        let heat_b = compute_heat(
413            &b.path,
414            root,
415            index,
416            &connection_counts,
417            max_tokens,
418            max_conn,
419        );
420        let combined_a = a.score * 0.6 + heat_a * 0.4;
421        let combined_b = b.score * 0.6 + heat_b * 0.4;
422        combined_b
423            .partial_cmp(&combined_a)
424            .unwrap_or(std::cmp::Ordering::Equal)
425    });
426}
427
428fn compute_heat(
429    path: &str,
430    root: &str,
431    index: &ProjectIndex,
432    connections: &std::collections::HashMap<String, usize>,
433    max_tokens: f64,
434    max_conn: f64,
435) -> f64 {
436    let rel = path
437        .strip_prefix(root)
438        .unwrap_or(path)
439        .trim_start_matches('/');
440
441    if let Some(entry) = index.files.get(rel) {
442        let conn = connections.get(rel).copied().unwrap_or(0);
443        let token_norm = entry.token_count as f64 / max_tokens;
444        let conn_norm = conn as f64 / max_conn;
445        token_norm * 0.4 + conn_norm * 0.6
446    } else {
447        0.0
448    }
449}
450
/// Unit tests for the pure text helpers of this module.
#[cfg(test)]
mod tests {
    use super::*;

    /// A line containing a task keyword is reported.
    #[test]
    fn extract_critical_lines_finds_keywords() {
        let content = "fn main() {\n    let token = validate();\n    return Err(e);\n}\n";
        let result = extract_critical_lines(content, &["validate".to_string()], 5);
        assert!(!result.is_empty());
        assert!(result.iter().any(|(_, l)| l.contains("validate")));
    }

    /// Error-related lines outrank lines that only match a keyword.
    #[test]
    fn extract_critical_lines_prioritizes_errors() {
        let content = "fn main() {\n    let x = 1;\n    return Err(\"bad\");\n    let token = validate();\n}\n";
        let result = extract_critical_lines(content, &["validate".to_string()], 5);
        assert!(result.len() >= 2);
        assert!(result[0].1.contains("Err"), "errors should be first");
    }

    /// Only `pub` items are listed, in file order.
    #[test]
    fn extract_key_signatures_finds_pub() {
        let content = "use std::io;\nfn private() {}\npub fn public_one() {}\npub struct Foo {}\n";
        let sigs = extract_key_signatures(content, 10);
        assert_eq!(sigs.len(), 2);
        assert!(sigs[0].contains("pub fn public_one"));
        assert!(sigs[1].contains("pub struct Foo"));
    }

    /// `use` lines are collected and other code is ignored.
    #[test]
    fn extract_imports_works() {
        let content = "use std::io;\nuse crate::core::cache;\nfn main() {}\n";
        let imports = extract_imports(content);
        assert_eq!(imports.len(), 2);
        assert!(imports[0].contains("std::io"));
    }

    /// `import`/`from` lines are passed through without stripping anything.
    #[test]
    fn extract_imports_keeps_non_rust_lines() {
        let imports = extract_imports("import os\nfrom a import b\n");
        assert_eq!(imports, vec!["import os".to_string(), "from a import b".to_string()]);
    }

    /// Each ratio threshold selects the documented mode, including the zero-token case.
    #[test]
    fn budget_to_mode_thresholds() {
        assert_eq!(budget_to_mode(80, 100), "full");
        assert_eq!(budget_to_mode(40, 100), "signatures");
        assert_eq!(budget_to_mode(15, 100), "map");
        assert_eq!(budget_to_mode(14, 100), "reference");
        assert_eq!(budget_to_mode(0, 0), "reference");
    }
}