// lean_ctx/tools/ctx_preload.rs
use crate::core::cache::SessionCache;
use crate::core::protocol;
use crate::core::task_relevance::{compute_relevance, parse_task_hints};
use crate::core::tokens::count_tokens;
use crate::tools::CrpMode;

/// Hard cap on files emitted in the CRITICAL section of the report.
const MAX_PRELOAD_FILES: usize = 8;
/// Max keyword-/error-matching lines quoted per critical file.
const MAX_CRITICAL_LINES: usize = 15;
/// Max public item signatures listed per critical file.
const SIGNATURES_BUDGET: usize = 10;
/// Total token budget split across candidates by `boltzmann_allocate`.
const TOTAL_TOKEN_BUDGET: usize = 4000;

/// Entry point for the `ctx_preload` tool: given a free-form task
/// description, select the most task-relevant project files and emit a
/// compact plain-text briefing (critical lines, public signatures, imports,
/// related files, dependency-graph edges) sized by `TOTAL_TOKEN_BUDGET`.
///
/// `path` is the project root (defaults to `"."`). `crp_mode` only changes
/// the trailing follow-up hint. Returns an `ERROR:` string when `task` is
/// empty; otherwise the formatted report.
pub fn handle(
    cache: &mut SessionCache,
    task: &str,
    path: Option<&str>,
    crp_mode: CrpMode,
) -> String {
    if task.trim().is_empty() {
        return "ERROR: ctx_preload requires a task description".to_string();
    }

    let project_root = path
        .map(|p| p.to_string())
        .unwrap_or_else(|| ".".to_string());

    let index = crate::core::graph_index::load_or_build(&project_root);

    // Mine the task text for explicit file mentions and keywords, then
    // score every indexed file against them.
    let (task_files, task_keywords) = parse_task_hints(task);
    let relevance = compute_relevance(&index, &task_files, &task_keywords);

    // Over-select beyond the emit cap so candidates skipped below (tiny
    // budget or unreadable file) still leave up to MAX_PRELOAD_FILES
    // emittable ones.
    // NOTE(review): assumes `relevance` is sorted best-first — confirm in
    // compute_relevance.
    let candidates: Vec<_> = relevance
        .iter()
        .filter(|r| r.score >= 0.1)
        .take(MAX_PRELOAD_FILES + 10)
        .collect();

    if candidates.is_empty() {
        return format!(
            "[task: {task}]\nNo directly relevant files found. Use ctx_overview for project map."
        );
    }

    // Tasks with explicit file hints are "specific": a lower temperature
    // concentrates the token budget on the top candidates, vaguer tasks
    // spread it more evenly. Temperature is floored at 0.1.
    let task_specificity =
        (task_files.len() as f64 * 0.3 + task_keywords.len() as f64 * 0.1).clamp(0.0, 1.0);
    let temperature = 0.8 - task_specificity * 0.6;
    let temperature = temperature.max(0.1);

    let allocations = boltzmann_allocate(&candidates, TOTAL_TOKEN_BUDGET, temperature);

    // (path, line count) pairs feed the task briefing; unreadable files are
    // silently omitted. Readable candidates are read again in the loop
    // below.
    let file_context: Vec<(String, usize)> = candidates
        .iter()
        .filter_map(|c| {
            std::fs::read_to_string(&c.path)
                .ok()
                .map(|content| (c.path.clone(), content.lines().count()))
        })
        .collect();
    let briefing = crate::core::task_briefing::build_briefing(task, &file_context);
    let briefing_block = crate::core::task_briefing::format_briefing(&briefing);

    let mut output = Vec::new();
    output.push(briefing_block);
    output.push(format!("[task: {task}]"));

    // Sum of full-file token counts for summarized files — reported at the
    // end as the context the caller avoided loading verbatim.
    let mut total_estimated_saved = 0usize;
    let mut critical_count = 0usize;

    for (rel, token_budget) in candidates.iter().zip(allocations.iter()) {
        // A file allotted fewer than 20 tokens isn't worth summarizing.
        if *token_budget < 20 {
            continue;
        }
        critical_count += 1;
        if critical_count > MAX_PRELOAD_FILES {
            break;
        }

        let content = match std::fs::read_to_string(&rel.path) {
            Ok(c) => c,
            Err(_) => continue,
        };

        let file_ref = cache.get_file_ref(&rel.path);
        let short = protocol::shorten_path(&rel.path);
        let line_count = content.lines().count();
        let file_tokens = count_tokens(&content);

        // Warm the session cache so later reads can hit it; the returned
        // entry itself is unused here.
        let (entry, _) = cache.store(&rel.path, content.clone());
        let _ = entry;

        let mode = budget_to_mode(*token_budget, file_tokens);

        let critical_lines = extract_critical_lines(&content, &task_keywords, MAX_CRITICAL_LINES);
        let sigs = extract_key_signatures(&content, SIGNATURES_BUDGET);
        let imports = extract_imports(&content);

        output.push(format!(
            "\nCRITICAL: {file_ref}={short} {line_count}L score={:.1} budget={token_budget}tok mode={mode}",
            rel.score
        ));

        if !critical_lines.is_empty() {
            for (line_no, line) in &critical_lines {
                output.push(format!(" :{line_no} {line}"));
            }
        }

        if !imports.is_empty() {
            output.push(format!(" imports: {}", imports.join(", ")));
        }

        if !sigs.is_empty() {
            for sig in &sigs {
                output.push(format!(" {sig}"));
            }
        }

        total_estimated_saved += file_tokens;
    }

    // Lower-scoring files are listed by name only.
    // NOTE(review): the [0.1, 0.3) band overlaps the CRITICAL candidates
    // (also filtered at score >= 0.1), so a file can appear in both
    // sections — confirm the duplication is intended.
    let context_files: Vec<_> = relevance
        .iter()
        .filter(|r| r.score >= 0.1 && r.score < 0.3)
        .take(10)
        .collect();

    if !context_files.is_empty() {
        output.push("\nRELATED:".to_string());
        for rel in &context_files {
            let short = protocol::shorten_path(&rel.path);
            output.push(format!(
                " {} mode={} score={:.1}",
                short, rel.recommended_mode, rel.score
            ));
        }
    }

    // Up to 10 dependency edges touching any critical candidate.
    let graph_edges: Vec<_> = index
        .edges
        .iter()
        .filter(|e| {
            candidates
                .iter()
                .any(|c| c.path == e.from || c.path == e.to)
        })
        .take(10)
        .collect();

    if !graph_edges.is_empty() {
        output.push("\nGRAPH:".to_string());
        for edge in &graph_edges {
            let from_short = protocol::shorten_path(&edge.from);
            let to_short = protocol::shorten_path(&edge.to);
            output.push(format!(" {from_short} -> {to_short}"));
        }
    }

    let preload_result = output.join("\n");
    let preload_tokens = count_tokens(&preload_result);
    let savings = protocol::format_savings(total_estimated_saved, preload_tokens);

    // TDD mode omits the follow-up hint line.
    if crp_mode.is_tdd() {
        format!("{preload_result}\n{savings}")
    } else {
        format!("{preload_result}\n\nNext: ctx_read(path, mode=\"full\") for any file above.\n{savings}")
    }
}
171
172fn boltzmann_allocate(
175 candidates: &[&crate::core::task_relevance::RelevanceScore],
176 total_budget: usize,
177 temperature: f64,
178) -> Vec<usize> {
179 if candidates.is_empty() {
180 return Vec::new();
181 }
182
183 let t = temperature.max(0.01);
184
185 let log_weights: Vec<f64> = candidates.iter().map(|c| c.score / t).collect();
187 let max_log = log_weights
188 .iter()
189 .cloned()
190 .fold(f64::NEG_INFINITY, f64::max);
191 let exp_weights: Vec<f64> = log_weights.iter().map(|&lw| (lw - max_log).exp()).collect();
192 let z: f64 = exp_weights.iter().sum();
193
194 if z <= 0.0 {
195 return vec![total_budget / candidates.len().max(1); candidates.len()];
196 }
197
198 let mut allocations: Vec<usize> = exp_weights
199 .iter()
200 .map(|&w| ((w / z) * total_budget as f64).round() as usize)
201 .collect();
202
203 let sum: usize = allocations.iter().sum();
205 if sum > total_budget {
206 let overflow = sum - total_budget;
207 if let Some(last) = allocations.last_mut() {
208 *last = last.saturating_sub(overflow);
209 }
210 }
211
212 allocations
213}
214
/// Map a token allocation to a read mode by how much of the file it covers:
/// >=80% "full", >=40% "signatures", >=15% "map", otherwise "reference".
fn budget_to_mode(budget: usize, file_tokens: usize) -> &'static str {
    // Guard against empty files: treat the denominator as at least 1 token.
    let coverage = budget as f64 / file_tokens.max(1) as f64;
    match coverage {
        c if c >= 0.8 => "full",
        c if c >= 0.4 => "signatures",
        c if c >= 0.15 => "map",
        _ => "reference",
    }
}
228
/// Pick up to `max` lines of `content` that either mention a task keyword
/// (case-insensitive) or look like error handling.
///
/// Returns `(1-based line number, trimmed line)` pairs ordered by priority
/// (keyword-hit count, +2 for error-ish lines); ties keep file order.
fn extract_critical_lines(content: &str, keywords: &[String], max: usize) -> Vec<(usize, String)> {
    let kw_lower: Vec<String> = keywords.iter().map(|k| k.to_lowercase()).collect();

    // (line number, trimmed text, priority)
    let mut hits: Vec<(usize, String, usize)> = content
        .lines()
        .enumerate()
        .filter_map(|(i, line)| {
            let trimmed = line.trim();
            if trimmed.is_empty() {
                return None;
            }
            let line_lower = trimmed.to_lowercase();
            let hit_count = kw_lower
                .iter()
                .filter(|kw| line_lower.contains(kw.as_str()))
                .count();

            // Heuristic for error-handling lines worth surfacing.
            let is_error = trimmed.contains("Error")
                || trimmed.contains("Err(")
                || trimmed.contains("panic!")
                || trimmed.contains("unwrap()")
                || trimmed.starts_with("return Err");

            if hit_count > 0 || is_error {
                let priority = hit_count + if is_error { 2 } else { 0 };
                Some((i + 1, trimmed.to_string(), priority))
            } else {
                None
            }
        })
        .collect();

    // Stable sort: equal priorities keep their file order.
    hits.sort_by_key(|&(_, _, priority)| std::cmp::Reverse(priority));
    // Consume `hits` directly instead of cloning each retained String
    // (the old `.iter().map(|(n, l, _)| (*n, l.clone()))` allocated a copy
    // of every line it kept).
    hits.into_iter()
        .take(max)
        .map(|(n, l, _)| (n, l))
        .collect()
}
265
/// Collect up to `max` public Rust item signatures (fn / struct / enum /
/// trait / type / const) from `content`, truncating long lines to roughly
/// 120 characters with a `...` suffix.
fn extract_key_signatures(content: &str, max: usize) -> Vec<String> {
    // Byte position where over-long signatures are cut before appending "...".
    const TRUNCATE_AT: usize = 117;

    let sig_starters = [
        "pub fn ",
        "pub async fn ",
        "pub struct ",
        "pub enum ",
        "pub trait ",
        "pub type ",
        "pub const ",
    ];

    content
        .lines()
        .filter(|line| {
            let trimmed = line.trim();
            sig_starters.iter().any(|s| trimmed.starts_with(s))
        })
        .take(max)
        .map(|line| {
            let trimmed = line.trim();
            if trimmed.len() > 120 {
                // BUG FIX: slicing at a fixed byte offset (`&trimmed[..117]`)
                // panics when byte 117 falls inside a multi-byte UTF-8
                // character; back up to the nearest char boundary first.
                let mut cut = TRUNCATE_AT;
                while !trimmed.is_char_boundary(cut) {
                    cut -= 1;
                }
                format!("{}...", &trimmed[..cut])
            } else {
                trimmed.to_string()
            }
        })
        .collect()
}
294
/// Gather up to 8 import-like lines (Rust `use`, Python `import`/`from`)
/// from `content`; Rust `use` lines are reduced to the bare path.
fn extract_imports(content: &str) -> Vec<String> {
    let mut found: Vec<String> = Vec::new();
    for raw in content.lines() {
        if found.len() == 8 {
            break;
        }
        let line = raw.trim();
        let looks_like_import =
            line.starts_with("use ") || line.starts_with("import ") || line.starts_with("from ");
        if !looks_like_import {
            continue;
        }
        // Rust `use` statements: keep only the path, drop the trailing ';'.
        match line.strip_prefix("use ") {
            Some(path) => found.push(path.trim_end_matches(';').to_string()),
            None => found.push(line.to_string()),
        }
    }
    found
}
313
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn extract_critical_lines_finds_keywords() {
        let src = "fn main() {\n let token = validate();\n return Err(e);\n}\n";
        let hits = extract_critical_lines(src, &["validate".to_string()], 5);
        assert!(!hits.is_empty());
        let mentions_keyword = hits.iter().any(|(_, text)| text.contains("validate"));
        assert!(mentions_keyword);
    }

    #[test]
    fn extract_critical_lines_prioritizes_errors() {
        let src = "fn main() {\n let x = 1;\n return Err(\"bad\");\n let token = validate();\n}\n";
        let hits = extract_critical_lines(src, &["validate".to_string()], 5);
        assert!(hits.len() >= 2);
        // the Err line carries the error bonus and must sort first
        assert!(hits[0].1.contains("Err"), "errors should be first");
    }

    #[test]
    fn extract_key_signatures_finds_pub() {
        let src = "use std::io;\nfn private() {}\npub fn public_one() {}\npub struct Foo {}\n";
        let found = extract_key_signatures(src, 10);
        assert_eq!(found.len(), 2);
        assert!(found[0].contains("pub fn public_one"));
        assert!(found[1].contains("pub struct Foo"));
    }

    #[test]
    fn extract_imports_works() {
        let src = "use std::io;\nuse crate::core::cache;\nfn main() {}\n";
        let found = extract_imports(src);
        assert_eq!(found.len(), 2);
        assert!(found[0].contains("std::io"));
    }
}