Skip to main content

lean_ctx/tools/
ctx_preload.rs

1use crate::core::cache::SessionCache;
2use crate::core::graph_provider::{self, GraphProvider};
3use crate::core::protocol;
4use crate::core::task_relevance::{compute_relevance, parse_task_hints, RelevanceScore};
5use crate::core::tokens::count_tokens;
6use crate::tools::CrpMode;
7
/// Most files summarized in the CRITICAL section of a preload; the initial
/// candidate pool is capped at this value plus 10.
const MAX_PRELOAD_FILES: usize = 8;
/// Most keyword / error-related lines quoted for each critical file.
const MAX_CRITICAL_LINES: usize = 15;
/// Most public signatures listed for each critical file.
const SIGNATURES_BUDGET: usize = 10;
/// Total token budget distributed across candidates by the Boltzmann allocation.
const TOTAL_TOKEN_BUDGET: usize = 4_000;
12
13pub fn handle(
14    cache: &mut SessionCache,
15    task: &str,
16    path: Option<&str>,
17    crp_mode: CrpMode,
18) -> String {
19    if task.trim().is_empty() {
20        return "ERROR: ctx_preload requires a task description".to_string();
21    }
22
23    let project_root = path.map_or_else(|| ".".to_string(), std::string::ToString::to_string);
24    let jail_root = std::path::Path::new(&project_root);
25
26    let Some(open) = graph_provider::open_or_build(&project_root) else {
27        return format!("[task: {task}]\nNo graph available. Use ctx_overview for project map.");
28    };
29    let gp = &open.provider;
30
31    let session_intent =
32        crate::core::session::SessionState::load_latest().and_then(|s| s.active_structured_intent);
33
34    let (task_files, task_keywords) = parse_task_hints(task);
35    let relevance = if let Some(ref intent) = session_intent {
36        crate::core::task_relevance::compute_relevance_from_intent(gp, intent)
37    } else {
38        compute_relevance(gp, &task_files, &task_keywords)
39    };
40
41    let mut scored: Vec<_> = relevance
42        .iter()
43        .filter(|r| r.score >= 0.1)
44        .take(MAX_PRELOAD_FILES + 10)
45        .collect();
46
47    apply_heat_ranking(&mut scored, gp, &project_root);
48
49    let pop = crate::core::pop_pruning::decide_for_candidates(task, &project_root, &scored);
50    let candidates =
51        crate::core::pop_pruning::filter_candidates_by_pop(&project_root, &scored, &pop);
52
53    if candidates.is_empty() {
54        return format!(
55            "[task: {task}]\nNo directly relevant files found. Use ctx_overview for project map."
56        );
57    }
58
59    // Boltzmann allocation: p(file_i) = exp(score_i / T) / Z
60    // Temperature T is derived from task specificity:
61    //   - Many keywords / specific file mentions → low T → concentrate budget
62    //   - Few keywords / broad task → high T → spread budget evenly
63    let task_specificity =
64        (task_files.len() as f64 * 0.3 + task_keywords.len() as f64 * 0.1).clamp(0.0, 1.0);
65    let temperature = 0.8 - task_specificity * 0.6; // range [0.2, 0.8]
66    let temperature = temperature.max(0.1);
67
68    let allocations = boltzmann_allocate(&candidates, TOTAL_TOKEN_BUDGET, temperature);
69
70    let file_context: Vec<(String, usize)> = candidates
71        .iter()
72        .filter_map(|c| {
73            let Ok((jailed, warning)) = crate::core::io_boundary::jail_and_check_path(
74                "ctx_preload",
75                std::path::Path::new(&c.path),
76                jail_root,
77            ) else {
78                return None;
79            };
80            if warning.is_some() {
81                return None;
82            }
83            // Don't hydrate cloud placeholders during automatic preload (#363).
84            if crate::core::cloud_files::is_cloud_placeholder(&jailed) {
85                return None;
86            }
87            std::fs::read_to_string(&jailed)
88                .ok()
89                .map(|content| (c.path.clone(), content.lines().count()))
90        })
91        .collect();
92    let briefing = crate::core::task_briefing::build_briefing(task, &file_context);
93    let briefing_block = crate::core::task_briefing::format_briefing(&briefing);
94
95    let multi_intents = crate::core::intent_engine::detect_multi_intent(task);
96    let primary = &multi_intents[0];
97    let complexity = crate::core::intent_engine::classify_complexity(task, primary);
98
99    let mut output = Vec::new();
100    output.push(briefing_block);
101
102    let complexity_label = complexity.instruction_suffix().lines().next().unwrap_or("");
103    if multi_intents.len() > 1 {
104        output.push(format!(
105            "[task: {task}] | {} | {} sub-intents",
106            complexity_label,
107            multi_intents.len()
108        ));
109        for (i, sub) in multi_intents.iter().enumerate() {
110            output.push(format!(
111                "  {}. {} ({:.0}%)",
112                i + 1,
113                sub.task_type.as_str(),
114                sub.confidence * 100.0
115            ));
116        }
117    } else {
118        output.push(format!("[task: {task}] | {complexity_label}"));
119    }
120
121    for r in crate::core::prospective_memory::reminders_for_task(&project_root, task) {
122        output.push(r);
123    }
124
125    if !pop.excluded_modules.is_empty() {
126        output.push("POP:".to_string());
127        for ex in &pop.excluded_modules {
128            output.push(format!(
129                "  - exclude {}/ ({} candidates) — {}",
130                ex.module, ex.candidate_files, ex.reason
131            ));
132        }
133    }
134
135    let mut total_estimated_saved = 0usize;
136    let mut critical_count = 0usize;
137
138    for (rel, token_budget) in candidates.iter().zip(allocations.iter()) {
139        if *token_budget < 20 {
140            continue;
141        }
142        critical_count += 1;
143        if critical_count > MAX_PRELOAD_FILES {
144            break;
145        }
146
147        let Ok((jailed, warning)) = crate::core::io_boundary::jail_and_check_path(
148            "ctx_preload",
149            std::path::Path::new(&rel.path),
150            jail_root,
151        ) else {
152            continue;
153        };
154        if warning.is_some() {
155            continue;
156        }
157
158        let jailed_s = jailed.to_string_lossy().to_string();
159        let Ok(content) = std::fs::read_to_string(&jailed) else {
160            continue;
161        };
162
163        let file_ref = cache.get_file_ref(&jailed_s);
164        let short = protocol::shorten_path(&jailed_s);
165        let line_count = content.lines().count();
166        let file_tokens = count_tokens(&content);
167
168        let _ = cache.store(&jailed_s, &content);
169
170        let mode = budget_to_mode(*token_budget, file_tokens);
171
172        let critical_lines = extract_critical_lines(&content, &task_keywords, MAX_CRITICAL_LINES);
173        let sigs = extract_key_signatures(&content, SIGNATURES_BUDGET);
174        let imports = extract_imports(&content);
175
176        output.push(format!(
177            "\nCRITICAL: {file_ref}={short} {line_count}L score={:.1} budget={token_budget}tok mode={mode}",
178            rel.score
179        ));
180
181        if !critical_lines.is_empty() {
182            for (line_no, line) in &critical_lines {
183                output.push(format!("  :{line_no} {line}"));
184            }
185        }
186
187        if !imports.is_empty() {
188            output.push(format!("  imports: {}", imports.join(", ")));
189        }
190
191        if !sigs.is_empty() {
192            for sig in &sigs {
193                output.push(format!("  {sig}"));
194            }
195        }
196
197        total_estimated_saved += file_tokens;
198    }
199
200    let context_files: Vec<_> = relevance
201        .iter()
202        .filter(|r| r.score >= 0.1 && r.score < 0.3)
203        .take(10)
204        .collect();
205
206    if !context_files.is_empty() {
207        output.push("\nRELATED:".to_string());
208        for rel in &context_files {
209            let short = protocol::shorten_path(&rel.path);
210            output.push(format!(
211                "  {} mode={} score={:.1}",
212                short, rel.recommended_mode, rel.score
213            ));
214        }
215    }
216
217    let all_edges = gp.edges();
218    let graph_edges: Vec<_> = all_edges
219        .iter()
220        .filter(|e| {
221            candidates
222                .iter()
223                .any(|c| c.path == e.from || c.path == e.to)
224        })
225        .take(10)
226        .collect();
227
228    if !graph_edges.is_empty() {
229        output.push("\nGRAPH:".to_string());
230        for edge in &graph_edges {
231            let from_short = protocol::shorten_path(&edge.from);
232            let to_short = protocol::shorten_path(&edge.to);
233            output.push(format!("  {from_short} -> {to_short}"));
234        }
235    }
236
237    let preload_result = output.join("\n");
238    let preload_tokens = count_tokens(&preload_result);
239    let savings = protocol::format_savings(total_estimated_saved, preload_tokens);
240
241    if crp_mode.is_tdd() {
242        format!("{preload_result}\n{savings}")
243    } else {
244        format!("{preload_result}\n\nNext: ctx_read(path, mode=\"full\") for any file above.\n{savings}")
245    }
246}
247
248/// Boltzmann distribution for token budget allocation across files.
249/// p(file_i) = exp(score_i / T) / Z, then budget_i = total * p(file_i)
250fn boltzmann_allocate(
251    candidates: &[&crate::core::task_relevance::RelevanceScore],
252    total_budget: usize,
253    temperature: f64,
254) -> Vec<usize> {
255    if candidates.is_empty() {
256        return Vec::new();
257    }
258
259    let t = temperature.max(0.01);
260
261    // Compute exp(score / T) for each candidate, using log-sum-exp for numerical stability
262    let log_weights: Vec<f64> = candidates.iter().map(|c| c.score / t).collect();
263    let max_log = log_weights
264        .iter()
265        .copied()
266        .fold(f64::NEG_INFINITY, f64::max);
267    let exp_weights: Vec<f64> = log_weights.iter().map(|&lw| (lw - max_log).exp()).collect();
268    let z: f64 = exp_weights.iter().sum();
269
270    if z <= 0.0 {
271        return vec![total_budget / candidates.len().max(1); candidates.len()];
272    }
273
274    let mut allocations: Vec<usize> = exp_weights
275        .iter()
276        .map(|&w| ((w / z) * total_budget as f64).round() as usize)
277        .collect();
278
279    // Ensure total doesn't exceed budget
280    let sum: usize = allocations.iter().sum();
281    if sum > total_budget {
282        let overflow = sum - total_budget;
283        if let Some(last) = allocations.last_mut() {
284            *last = last.saturating_sub(overflow);
285        }
286    }
287
288    allocations
289}
290
/// Picks the compression mode a token budget can afford for a file.
///
/// The decision is based on `budget / file_tokens`, with `file_tokens` treated
/// as at least 1:
/// - `>= 0.8` gives `"full"`;
/// - `>= 0.4` gives `"signatures"`;
/// - `>= 0.15` gives `"map"`;
/// - anything lower gives `"reference"`.
fn budget_to_mode(budget: usize, file_tokens: usize) -> &'static str {
    // Thresholds in descending order; the first one reached wins.
    const TIERS: [(f64, &str); 3] = [(0.8, "full"), (0.4, "signatures"), (0.15, "map")];

    let coverage = budget as f64 / file_tokens.max(1) as f64;
    TIERS
        .iter()
        .find(|(threshold, _)| coverage >= *threshold)
        .map_or("reference", |&(_, mode)| mode)
}
304
/// Selects up to `max` "critical" lines from `content`.
///
/// A non-blank line is critical when it meets either condition:
/// - it contains one of `keywords` (case-insensitive);
/// - it looks error-related: it contains `Error`, `Err(`, `panic!` or
///   `unwrap()`, or starts with `return Err`.
///
/// Priority is the number of keyword hits plus 2 for an error-related line.
/// Results are ordered highest priority first; ties keep file order.
///
/// Returns `(1-based line number, trimmed line)` pairs. Blank or
/// whitespace-only keywords are ignored, because every line "contains" them.
fn extract_critical_lines(content: &str, keywords: &[String], max: usize) -> Vec<(usize, String)> {
    // An empty keyword would match every line (every string contains ""), which
    // would flood the result with arbitrary code, so drop blanks up front.
    let kw_lower: Vec<String> = keywords
        .iter()
        .filter(|k| !k.trim().is_empty())
        .map(|k| k.to_lowercase())
        .collect();

    let mut hits: Vec<(usize, String, usize)> = content
        .lines()
        .enumerate()
        .filter_map(|(i, line)| {
            let trimmed = line.trim();
            if trimmed.is_empty() {
                return None;
            }
            let line_lower = trimmed.to_lowercase();
            let hit_count = kw_lower
                .iter()
                .filter(|kw| line_lower.contains(kw.as_str()))
                .count();

            let is_error = trimmed.contains("Error")
                || trimmed.contains("Err(")
                || trimmed.contains("panic!")
                || trimmed.contains("unwrap()")
                || trimmed.starts_with("return Err");

            if hit_count > 0 || is_error {
                // Error-related lines get +2 so they outrank a single keyword hit.
                let priority = hit_count + if is_error { 2 } else { 0 };
                Some((i + 1, trimmed.to_string(), priority))
            } else {
                None
            }
        })
        .collect();

    // `sort_by_key` is stable, so equal priorities stay in file order.
    hits.sort_by_key(|&(_, _, priority)| std::cmp::Reverse(priority));
    hits.truncate(max);
    hits.into_iter()
        .map(|(line_no, line, _)| (line_no, line))
        .collect()
}
341
/// Collects up to `max` public item headers from `content`, in file order.
///
/// Recognized prefixes (checked after trimming) are `pub fn`, `pub async fn`,
/// `pub struct`, `pub enum`, `pub trait`, `pub type` and `pub const`. Headers
/// longer than 120 bytes are cut at the last UTF-8 character boundary at or
/// before byte 117 and suffixed with `"..."`.
fn extract_key_signatures(content: &str, max: usize) -> Vec<String> {
    const STARTERS: [&str; 7] = [
        "pub fn ",
        "pub async fn ",
        "pub struct ",
        "pub enum ",
        "pub trait ",
        "pub type ",
        "pub const ",
    ];
    const MAX_LEN: usize = 120;
    const KEEP: usize = 117;

    let mut signatures = Vec::new();
    for line in content.lines().map(str::trim) {
        if signatures.len() >= max {
            break;
        }
        if !STARTERS.iter().any(|starter| line.starts_with(starter)) {
            continue;
        }
        if line.len() <= MAX_LEN {
            signatures.push(line.to_string());
        } else {
            // Largest char boundary <= KEEP (index 0 always qualifies), so the
            // slice never splits a multi-byte character.
            let cut = (0..=KEEP)
                .rev()
                .find(|&idx| line.is_char_boundary(idx))
                .unwrap_or(0);
            signatures.push(format!("{}...", &line[..cut]));
        }
    }
    signatures
}
370
/// Returns up to eight import-like lines from `content`, in file order.
///
/// A line qualifies when, after trimming, it starts with `use `, `import ` or
/// `from `. For `use` lines the keyword and any trailing semicolons are
/// removed; other matches are returned trimmed but otherwise verbatim.
fn extract_imports(content: &str) -> Vec<String> {
    const MAX_IMPORTS: usize = 8;
    const PREFIXES: [&str; 3] = ["use ", "import ", "from "];

    content
        .lines()
        .map(str::trim)
        .filter(|t| PREFIXES.iter().any(|p| t.starts_with(p)))
        .take(MAX_IMPORTS)
        .map(|t| match t.strip_prefix("use ") {
            Some(path) => path.trim_end_matches(';').to_owned(),
            None => t.to_owned(),
        })
        .collect()
}
389
390fn apply_heat_ranking(candidates: &mut [&RelevanceScore], gp: &GraphProvider, root: &str) {
391    if gp.file_count() == 0 {
392        return;
393    }
394
395    let all_edges = gp.edges();
396    let mut connection_counts: std::collections::HashMap<String, usize> =
397        std::collections::HashMap::new();
398    for edge in &all_edges {
399        *connection_counts.entry(edge.from.clone()).or_default() += 1;
400        *connection_counts.entry(edge.to.clone()).or_default() += 1;
401    }
402
403    let mut max_tokens = 1usize;
404    for path in gp.file_paths() {
405        if let Some(entry) = gp.get_file_entry(&path) {
406            max_tokens = max_tokens.max(entry.token_count);
407        }
408    }
409    let max_tokens = max_tokens as f64;
410    let max_conn = connection_counts.values().max().copied().unwrap_or(1) as f64;
411
412    candidates.sort_by(|a, b| {
413        let heat_a = compute_heat(&a.path, root, gp, &connection_counts, max_tokens, max_conn);
414        let heat_b = compute_heat(&b.path, root, gp, &connection_counts, max_tokens, max_conn);
415        let combined_a = a.score * 0.6 + heat_a * 0.4;
416        let combined_b = b.score * 0.6 + heat_b * 0.4;
417        combined_b
418            .partial_cmp(&combined_a)
419            .unwrap_or(std::cmp::Ordering::Equal)
420    });
421}
422
423fn compute_heat(
424    path: &str,
425    root: &str,
426    gp: &GraphProvider,
427    connections: &std::collections::HashMap<String, usize>,
428    max_tokens: f64,
429    max_conn: f64,
430) -> f64 {
431    let rel = path
432        .strip_prefix(root)
433        .unwrap_or(path)
434        .trim_start_matches('/');
435
436    if let Some(entry) = gp.get_file_entry(rel) {
437        let conn = connections.get(rel).copied().unwrap_or(0);
438        let token_norm = entry.token_count as f64 / max_tokens;
439        let conn_norm = conn as f64 / max_conn;
440        token_norm * 0.4 + conn_norm * 0.6
441    } else {
442        0.0
443    }
444}
445
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn extract_critical_lines_finds_keywords() {
        let content = "fn main() {\n    let token = validate();\n    return Err(e);\n}\n";
        let result = extract_critical_lines(content, &["validate".to_string()], 5);
        assert!(!result.is_empty());
        assert!(result.iter().any(|(_, l)| l.contains("validate")));
    }

    #[test]
    fn extract_critical_lines_prioritizes_errors() {
        let content = "fn main() {\n    let x = 1;\n    return Err(\"bad\");\n    let token = validate();\n}\n";
        let result = extract_critical_lines(content, &["validate".to_string()], 5);
        assert!(result.len() >= 2);
        assert!(result[0].1.contains("Err"), "errors should be first");
    }

    #[test]
    fn extract_key_signatures_finds_pub() {
        let content = "use std::io;\nfn private() {}\npub fn public_one() {}\npub struct Foo {}\n";
        let sigs = extract_key_signatures(content, 10);
        assert_eq!(sigs.len(), 2);
        assert!(sigs[0].contains("pub fn public_one"));
        assert!(sigs[1].contains("pub struct Foo"));
    }

    #[test]
    fn extract_key_signatures_respects_max() {
        let content = "pub fn a() {}\npub fn b() {}\npub fn c() {}\n";
        assert_eq!(extract_key_signatures(content, 2).len(), 2);
        assert!(extract_key_signatures(content, 0).is_empty());
    }

    #[test]
    fn extract_key_signatures_truncates_on_char_boundary() {
        // 8 ASCII bytes followed by 2-byte chars: byte 117 is mid-character,
        // so the cut must fall back to byte 116.
        let line = format!("pub fn x{}", "é".repeat(100));
        let sigs = extract_key_signatures(&line, 10);
        assert_eq!(sigs, vec![format!("{}...", &line[..116])]);
    }

    #[test]
    fn extract_imports_works() {
        let content = "use std::io;\nuse crate::core::cache;\nfn main() {}\n";
        let imports = extract_imports(content);
        assert_eq!(imports.len(), 2);
        assert!(imports[0].contains("std::io"));
    }

    #[test]
    fn extract_imports_caps_at_eight() {
        let content: String = (0..12).map(|i| format!("use m{i};\n")).collect();
        assert_eq!(extract_imports(&content).len(), 8);
    }

    #[test]
    fn budget_to_mode_thresholds() {
        assert_eq!(budget_to_mode(80, 100), "full");
        assert_eq!(budget_to_mode(40, 100), "signatures");
        assert_eq!(budget_to_mode(15, 100), "map");
        assert_eq!(budget_to_mode(14, 100), "reference");
        // Zero-token files are treated as one token long.
        assert_eq!(budget_to_mode(1, 0), "full");
    }
}