//! ctx_preload tool — lean_ctx/tools/ctx_preload.rs
1use crate::core::cache::SessionCache;
2use crate::core::graph_index::ProjectIndex;
3use crate::core::protocol;
4use crate::core::task_relevance::{compute_relevance, parse_task_hints, RelevanceScore};
5use crate::core::tokens::count_tokens;
6use crate::tools::CrpMode;
7
/// Maximum number of files emitted in the CRITICAL preload section.
const MAX_PRELOAD_FILES: usize = 8;
/// Upper bound on keyword/error-hit lines shown per preloaded file.
const MAX_CRITICAL_LINES: usize = 15;
/// Maximum number of public signatures extracted per file.
const SIGNATURES_BUDGET: usize = 10;
/// Total token budget distributed across preloaded files (Boltzmann allocation).
const TOTAL_TOKEN_BUDGET: usize = 4000;
12
13pub fn handle(
14    cache: &mut SessionCache,
15    task: &str,
16    path: Option<&str>,
17    crp_mode: CrpMode,
18) -> String {
19    if task.trim().is_empty() {
20        return "ERROR: ctx_preload requires a task description".to_string();
21    }
22
23    let project_root = path
24        .map(|p| p.to_string())
25        .unwrap_or_else(|| ".".to_string());
26
27    let index = crate::core::graph_index::load_or_build(&project_root);
28
29    let (task_files, task_keywords) = parse_task_hints(task);
30    let relevance = compute_relevance(&index, &task_files, &task_keywords);
31
32    let mut scored: Vec<_> = relevance
33        .iter()
34        .filter(|r| r.score >= 0.1)
35        .take(MAX_PRELOAD_FILES + 10)
36        .collect();
37
38    apply_heat_ranking(&mut scored, &index, &project_root);
39
40    let candidates = scored;
41
42    if candidates.is_empty() {
43        return format!(
44            "[task: {task}]\nNo directly relevant files found. Use ctx_overview for project map."
45        );
46    }
47
48    // Boltzmann allocation: p(file_i) = exp(score_i / T) / Z
49    // Temperature T is derived from task specificity:
50    //   - Many keywords / specific file mentions → low T → concentrate budget
51    //   - Few keywords / broad task → high T → spread budget evenly
52    let task_specificity =
53        (task_files.len() as f64 * 0.3 + task_keywords.len() as f64 * 0.1).clamp(0.0, 1.0);
54    let temperature = 0.8 - task_specificity * 0.6; // range [0.2, 0.8]
55    let temperature = temperature.max(0.1);
56
57    let allocations = boltzmann_allocate(&candidates, TOTAL_TOKEN_BUDGET, temperature);
58
59    let file_context: Vec<(String, usize)> = candidates
60        .iter()
61        .filter_map(|c| {
62            std::fs::read_to_string(&c.path)
63                .ok()
64                .map(|content| (c.path.clone(), content.lines().count()))
65        })
66        .collect();
67    let briefing = crate::core::task_briefing::build_briefing(task, &file_context);
68    let briefing_block = crate::core::task_briefing::format_briefing(&briefing);
69
70    let multi_intents = crate::core::intent_engine::detect_multi_intent(task);
71    let primary = &multi_intents[0];
72    let complexity = crate::core::intent_engine::classify_complexity(task, primary);
73
74    let mut output = Vec::new();
75    output.push(briefing_block);
76
77    let complexity_label = complexity.instruction_suffix().lines().next().unwrap_or("");
78    if multi_intents.len() > 1 {
79        output.push(format!(
80            "[task: {task}] | {} | {} sub-intents",
81            complexity_label,
82            multi_intents.len()
83        ));
84        for (i, sub) in multi_intents.iter().enumerate() {
85            output.push(format!(
86                "  {}. {} ({:.0}%)",
87                i + 1,
88                sub.task_type.as_str(),
89                sub.confidence * 100.0
90            ));
91        }
92    } else {
93        output.push(format!("[task: {task}] | {complexity_label}"));
94    }
95
96    let mut total_estimated_saved = 0usize;
97    let mut critical_count = 0usize;
98
99    for (rel, token_budget) in candidates.iter().zip(allocations.iter()) {
100        if *token_budget < 20 {
101            continue;
102        }
103        critical_count += 1;
104        if critical_count > MAX_PRELOAD_FILES {
105            break;
106        }
107
108        let content = match std::fs::read_to_string(&rel.path) {
109            Ok(c) => c,
110            Err(_) => continue,
111        };
112
113        let file_ref = cache.get_file_ref(&rel.path);
114        let short = protocol::shorten_path(&rel.path);
115        let line_count = content.lines().count();
116        let file_tokens = count_tokens(&content);
117
118        let (entry, _) = cache.store(&rel.path, content.clone());
119        let _ = entry;
120
121        let mode = budget_to_mode(*token_budget, file_tokens);
122
123        let critical_lines = extract_critical_lines(&content, &task_keywords, MAX_CRITICAL_LINES);
124        let sigs = extract_key_signatures(&content, SIGNATURES_BUDGET);
125        let imports = extract_imports(&content);
126
127        output.push(format!(
128            "\nCRITICAL: {file_ref}={short} {line_count}L score={:.1} budget={token_budget}tok mode={mode}",
129            rel.score
130        ));
131
132        if !critical_lines.is_empty() {
133            for (line_no, line) in &critical_lines {
134                output.push(format!("  :{line_no} {line}"));
135            }
136        }
137
138        if !imports.is_empty() {
139            output.push(format!("  imports: {}", imports.join(", ")));
140        }
141
142        if !sigs.is_empty() {
143            for sig in &sigs {
144                output.push(format!("  {sig}"));
145            }
146        }
147
148        total_estimated_saved += file_tokens;
149    }
150
151    let context_files: Vec<_> = relevance
152        .iter()
153        .filter(|r| r.score >= 0.1 && r.score < 0.3)
154        .take(10)
155        .collect();
156
157    if !context_files.is_empty() {
158        output.push("\nRELATED:".to_string());
159        for rel in &context_files {
160            let short = protocol::shorten_path(&rel.path);
161            output.push(format!(
162                "  {} mode={} score={:.1}",
163                short, rel.recommended_mode, rel.score
164            ));
165        }
166    }
167
168    let graph_edges: Vec<_> = index
169        .edges
170        .iter()
171        .filter(|e| {
172            candidates
173                .iter()
174                .any(|c| c.path == e.from || c.path == e.to)
175        })
176        .take(10)
177        .collect();
178
179    if !graph_edges.is_empty() {
180        output.push("\nGRAPH:".to_string());
181        for edge in &graph_edges {
182            let from_short = protocol::shorten_path(&edge.from);
183            let to_short = protocol::shorten_path(&edge.to);
184            output.push(format!("  {from_short} -> {to_short}"));
185        }
186    }
187
188    let preload_result = output.join("\n");
189    let preload_tokens = count_tokens(&preload_result);
190    let savings = protocol::format_savings(total_estimated_saved, preload_tokens);
191
192    if crp_mode.is_tdd() {
193        format!("{preload_result}\n{savings}")
194    } else {
195        format!("{preload_result}\n\nNext: ctx_read(path, mode=\"full\") for any file above.\n{savings}")
196    }
197}
198
199/// Boltzmann distribution for token budget allocation across files.
200/// p(file_i) = exp(score_i / T) / Z, then budget_i = total * p(file_i)
201fn boltzmann_allocate(
202    candidates: &[&crate::core::task_relevance::RelevanceScore],
203    total_budget: usize,
204    temperature: f64,
205) -> Vec<usize> {
206    if candidates.is_empty() {
207        return Vec::new();
208    }
209
210    let t = temperature.max(0.01);
211
212    // Compute exp(score / T) for each candidate, using log-sum-exp for numerical stability
213    let log_weights: Vec<f64> = candidates.iter().map(|c| c.score / t).collect();
214    let max_log = log_weights
215        .iter()
216        .cloned()
217        .fold(f64::NEG_INFINITY, f64::max);
218    let exp_weights: Vec<f64> = log_weights.iter().map(|&lw| (lw - max_log).exp()).collect();
219    let z: f64 = exp_weights.iter().sum();
220
221    if z <= 0.0 {
222        return vec![total_budget / candidates.len().max(1); candidates.len()];
223    }
224
225    let mut allocations: Vec<usize> = exp_weights
226        .iter()
227        .map(|&w| ((w / z) * total_budget as f64).round() as usize)
228        .collect();
229
230    // Ensure total doesn't exceed budget
231    let sum: usize = allocations.iter().sum();
232    if sum > total_budget {
233        let overflow = sum - total_budget;
234        if let Some(last) = allocations.last_mut() {
235            *last = last.saturating_sub(overflow);
236        }
237    }
238
239    allocations
240}
241
/// Map a token budget to a recommended compression mode.
///
/// The decision is based on how much of the file's total token count the
/// budget covers; larger coverage permits less aggressive compression.
fn budget_to_mode(budget: usize, file_tokens: usize) -> &'static str {
    // max(1) avoids division by zero for empty files.
    let coverage = budget as f64 / file_tokens.max(1) as f64;
    match coverage {
        r if r >= 0.8 => "full",
        r if r >= 0.4 => "signatures",
        r if r >= 0.15 => "map",
        _ => "reference",
    }
}
255
/// Pick the most relevant lines of `content`: lines containing task keywords
/// (case-insensitive) or error-handling markers, ranked by priority.
///
/// Priority = keyword hit count, plus 2 when the line carries an error marker,
/// so error lines rise to the top. The stable sort preserves original file
/// order among ties. Returns up to `max` `(1-based line number, trimmed line)`
/// pairs.
fn extract_critical_lines(content: &str, keywords: &[String], max: usize) -> Vec<(usize, String)> {
    let kw_lower: Vec<String> = keywords.iter().map(|k| k.to_lowercase()).collect();

    let mut hits: Vec<(usize, String, usize)> = content
        .lines()
        .enumerate()
        .filter_map(|(i, line)| {
            let trimmed = line.trim();
            if trimmed.is_empty() {
                return None;
            }
            let line_lower = trimmed.to_lowercase();
            let hit_count = kw_lower
                .iter()
                .filter(|kw| line_lower.contains(kw.as_str()))
                .count();

            // Error-handling markers are interesting regardless of keywords.
            let is_error = trimmed.contains("Error")
                || trimmed.contains("Err(")
                || trimmed.contains("panic!")
                || trimmed.contains("unwrap()")
                || trimmed.starts_with("return Err");

            if hit_count > 0 || is_error {
                let priority = hit_count + if is_error { 2 } else { 0 };
                Some((i + 1, trimmed.to_string(), priority))
            } else {
                None
            }
        })
        .collect();

    // Stable descending sort by priority keeps file order among equal priorities.
    hits.sort_by(|a, b| b.2.cmp(&a.2));
    hits.truncate(max);
    // into_iter moves the stored strings out; the previous iter()+clone()
    // re-allocated every surviving line for no benefit.
    hits.into_iter().map(|(n, l, _)| (n, l)).collect()
}
292
/// Extract up to `max` public item signatures (fn/struct/enum/trait/type/const)
/// from `content`, truncating long lines to at most 120 characters.
fn extract_key_signatures(content: &str, max: usize) -> Vec<String> {
    const MAX_SIG_LEN: usize = 120;

    let sig_starters = [
        "pub fn ",
        "pub async fn ",
        "pub struct ",
        "pub enum ",
        "pub trait ",
        "pub type ",
        "pub const ",
    ];

    content
        .lines()
        .filter(|line| {
            let trimmed = line.trim();
            sig_starters.iter().any(|s| trimmed.starts_with(s))
        })
        .take(max)
        .map(|line| {
            let trimmed = line.trim();
            if trimmed.len() > MAX_SIG_LEN {
                // Truncate on a char boundary: the previous `&trimmed[..117]`
                // panicked when byte 117 fell inside a multi-byte UTF-8 char.
                let mut cut = MAX_SIG_LEN - 3;
                while !trimmed.is_char_boundary(cut) {
                    cut -= 1;
                }
                format!("{}...", &trimmed[..cut])
            } else {
                trimmed.to_string()
            }
        })
        .collect()
}
321
/// Collect up to eight dependency lines (Rust `use`, Python/JS-style
/// `import` / `from`), stripping the `use ` prefix and trailing semicolon
/// from Rust imports.
fn extract_imports(content: &str) -> Vec<String> {
    let mut found = Vec::new();
    for raw in content.lines() {
        if found.len() == 8 {
            break;
        }
        let line = raw.trim();
        let is_import =
            line.starts_with("use ") || line.starts_with("import ") || line.starts_with("from ");
        if !is_import {
            continue;
        }
        match line.strip_prefix("use ") {
            Some(rest) => found.push(rest.trim_end_matches(';').to_string()),
            None => found.push(line.to_string()),
        }
    }
    found
}
340
341fn apply_heat_ranking(candidates: &mut [&RelevanceScore], index: &ProjectIndex, root: &str) {
342    if index.files.is_empty() {
343        return;
344    }
345
346    let mut connection_counts: std::collections::HashMap<String, usize> =
347        std::collections::HashMap::new();
348    for edge in &index.edges {
349        *connection_counts.entry(edge.from.clone()).or_default() += 1;
350        *connection_counts.entry(edge.to.clone()).or_default() += 1;
351    }
352
353    let max_tokens = index
354        .files
355        .values()
356        .map(|f| f.token_count)
357        .max()
358        .unwrap_or(1) as f64;
359    let max_conn = connection_counts.values().max().copied().unwrap_or(1) as f64;
360
361    candidates.sort_by(|a, b| {
362        let heat_a = compute_heat(
363            &a.path,
364            root,
365            index,
366            &connection_counts,
367            max_tokens,
368            max_conn,
369        );
370        let heat_b = compute_heat(
371            &b.path,
372            root,
373            index,
374            &connection_counts,
375            max_tokens,
376            max_conn,
377        );
378        let combined_a = a.score * 0.6 + heat_a * 0.4;
379        let combined_b = b.score * 0.6 + heat_b * 0.4;
380        combined_b
381            .partial_cmp(&combined_a)
382            .unwrap_or(std::cmp::Ordering::Equal)
383    });
384}
385
386fn compute_heat(
387    path: &str,
388    root: &str,
389    index: &ProjectIndex,
390    connections: &std::collections::HashMap<String, usize>,
391    max_tokens: f64,
392    max_conn: f64,
393) -> f64 {
394    let rel = path
395        .strip_prefix(root)
396        .unwrap_or(path)
397        .trim_start_matches('/');
398
399    if let Some(entry) = index.files.get(rel) {
400        let conn = connections.get(rel).copied().unwrap_or(0);
401        let token_norm = entry.token_count as f64 / max_tokens;
402        let conn_norm = conn as f64 / max_conn;
403        token_norm * 0.4 + conn_norm * 0.6
404    } else {
405        0.0
406    }
407}
408
#[cfg(test)]
mod tests {
    use super::*;

    /// Lines containing a task keyword must appear in the extracted set.
    #[test]
    fn extract_critical_lines_finds_keywords() {
        let content = "fn main() {\n    let token = validate();\n    return Err(e);\n}\n";
        let result = extract_critical_lines(content, &["validate".to_string()], 5);
        assert!(!result.is_empty());
        assert!(result.iter().any(|(_, l)| l.contains("validate")));
    }

    /// Error-marker lines get +2 priority, so they outrank keyword-only hits.
    #[test]
    fn extract_critical_lines_prioritizes_errors() {
        let content = "fn main() {\n    let x = 1;\n    return Err(\"bad\");\n    let token = validate();\n}\n";
        let result = extract_critical_lines(content, &["validate".to_string()], 5);
        assert!(result.len() >= 2);
        assert!(result[0].1.contains("Err"), "errors should be first");
    }

    /// Only `pub` items are treated as signatures; private fns and imports are skipped.
    #[test]
    fn extract_key_signatures_finds_pub() {
        let content = "use std::io;\nfn private() {}\npub fn public_one() {}\npub struct Foo {}\n";
        let sigs = extract_key_signatures(content, 10);
        assert_eq!(sigs.len(), 2);
        assert!(sigs[0].contains("pub fn public_one"));
        assert!(sigs[1].contains("pub struct Foo"));
    }

    /// Rust `use` lines are collected with prefix and semicolon stripped.
    #[test]
    fn extract_imports_works() {
        let content = "use std::io;\nuse crate::core::cache;\nfn main() {}\n";
        let imports = extract_imports(content);
        assert_eq!(imports.len(), 2);
        assert!(imports[0].contains("std::io"));
    }
}
445}