1use crate::core::cache::SessionCache;
2use crate::core::graph_index::ProjectIndex;
3use crate::core::protocol;
4use crate::core::task_relevance::{compute_relevance, parse_task_hints, RelevanceScore};
5use crate::core::tokens::count_tokens;
6use crate::tools::CrpMode;
7
/// Maximum number of files surfaced in the CRITICAL preload section.
const MAX_PRELOAD_FILES: usize = 8;
/// Cap on keyword/error-hit lines quoted per preloaded file.
const MAX_CRITICAL_LINES: usize = 15;
/// Maximum number of `pub` signatures listed per file.
const SIGNATURES_BUDGET: usize = 10;
/// Overall token budget split across candidates by `boltzmann_allocate`.
const TOTAL_TOKEN_BUDGET: usize = 4000;
12
/// Build a token-budgeted context preload for a task description.
///
/// Pipeline: parse task hints → score file relevance → re-rank with graph
/// "heat" → allocate `TOTAL_TOKEN_BUDGET` across candidates via a Boltzmann
/// distribution → emit a briefing block, per-file CRITICAL summaries (hot
/// lines, imports, public signatures), RELATED files, and graph edges.
///
/// * `cache` - session cache; every preloaded file is stored so later tool
///   calls can address it by its short file ref.
/// * `task` - free-form task description; empty input yields an ERROR string.
/// * `path` - optional project root; defaults to the current directory.
/// * `crp_mode` - in TDD mode the trailing "Next: ctx_read(...)" hint is
///   omitted.
///
/// Returns the fully formatted preload text.
pub fn handle(
    cache: &mut SessionCache,
    task: &str,
    path: Option<&str>,
    crp_mode: CrpMode,
) -> String {
    if task.trim().is_empty() {
        return "ERROR: ctx_preload requires a task description".to_string();
    }

    let project_root = path
        .map(|p| p.to_string())
        .unwrap_or_else(|| ".".to_string());

    let index = crate::core::graph_index::load_or_build(&project_root);

    // Hints are explicit file mentions plus keywords mined from the task text.
    let (task_files, task_keywords) = parse_task_hints(task);
    let relevance = compute_relevance(&index, &task_files, &task_keywords);

    // Over-fetch beyond MAX_PRELOAD_FILES so heat re-ranking has slack to
    // promote well-connected files; the CRITICAL loop below enforces the cap.
    let mut scored: Vec<_> = relevance
        .iter()
        .filter(|r| r.score >= 0.1)
        .take(MAX_PRELOAD_FILES + 10)
        .collect();

    apply_heat_ranking(&mut scored, &index, &project_root);

    let candidates = scored;

    if candidates.is_empty() {
        return format!(
            "[task: {task}]\nNo directly relevant files found. Use ctx_overview for project map."
        );
    }

    // More explicit hints → lower temperature → the Boltzmann allocation
    // concentrates budget on top-ranked files; vague tasks spread it evenly.
    let task_specificity =
        (task_files.len() as f64 * 0.3 + task_keywords.len() as f64 * 0.1).clamp(0.0, 1.0);
    let temperature = 0.8 - task_specificity * 0.6;
    let temperature = temperature.max(0.1);

    let allocations = boltzmann_allocate(&candidates, TOTAL_TOKEN_BUDGET, temperature);

    // (path, line count) pairs for readable candidates; unreadable files are
    // silently skipped here and again in the CRITICAL loop below.
    let file_context: Vec<(String, usize)> = candidates
        .iter()
        .filter_map(|c| {
            std::fs::read_to_string(&c.path)
                .ok()
                .map(|content| (c.path.clone(), content.lines().count()))
        })
        .collect();
    let briefing = crate::core::task_briefing::build_briefing(task, &file_context);
    let briefing_block = crate::core::task_briefing::format_briefing(&briefing);

    // NOTE(review): indexing assumes detect_multi_intent always returns at
    // least one intent — confirm against intent_engine.
    let multi_intents = crate::core::intent_engine::detect_multi_intent(task);
    let primary = &multi_intents[0];
    let complexity = crate::core::intent_engine::classify_complexity(task, primary);

    let mut output = Vec::new();
    output.push(briefing_block);

    // Only the first line of the complexity suffix is used as a short label.
    let complexity_label = complexity.instruction_suffix().lines().next().unwrap_or("");
    if multi_intents.len() > 1 {
        output.push(format!(
            "[task: {task}] | {} | {} sub-intents",
            complexity_label,
            multi_intents.len()
        ));
        for (i, sub) in multi_intents.iter().enumerate() {
            output.push(format!(
                "  {}. {} ({:.0}%)",
                i + 1,
                sub.task_type.as_str(),
                sub.confidence * 100.0
            ));
        }
    } else {
        output.push(format!("[task: {task}] | {complexity_label}"));
    }

    // "Saved" tokens = full-file token counts the caller avoids reading raw.
    let mut total_estimated_saved = 0usize;
    let mut critical_count = 0usize;

    for (rel, token_budget) in candidates.iter().zip(allocations.iter()) {
        // Budgets under 20 tokens are too small to summarize usefully.
        if *token_budget < 20 {
            continue;
        }
        critical_count += 1;
        if critical_count > MAX_PRELOAD_FILES {
            break;
        }

        let content = match std::fs::read_to_string(&rel.path) {
            Ok(c) => c,
            Err(_) => continue,
        };

        let file_ref = cache.get_file_ref(&rel.path);
        let short = protocol::shorten_path(&rel.path);
        let line_count = content.lines().count();
        let file_tokens = count_tokens(&content);

        // Store for later ctx_read calls; the returned entry is not needed.
        let (entry, _) = cache.store(&rel.path, content.clone());
        let _ = entry;

        let mode = budget_to_mode(*token_budget, file_tokens);

        let critical_lines = extract_critical_lines(&content, &task_keywords, MAX_CRITICAL_LINES);
        let sigs = extract_key_signatures(&content, SIGNATURES_BUDGET);
        let imports = extract_imports(&content);

        output.push(format!(
            "\nCRITICAL: {file_ref}={short} {line_count}L score={:.1} budget={token_budget}tok mode={mode}",
            rel.score
        ));

        if !critical_lines.is_empty() {
            for (line_no, line) in &critical_lines {
                output.push(format!("  :{line_no} {line}"));
            }
        }

        if !imports.is_empty() {
            output.push(format!("  imports: {}", imports.join(", ")));
        }

        if !sigs.is_empty() {
            for sig in &sigs {
                output.push(format!("  {sig}"));
            }
        }

        total_estimated_saved += file_tokens;
    }

    // Lower-scoring files (0.1..0.3) listed without content; may overlap the
    // CRITICAL set since this re-filters the raw relevance list.
    let context_files: Vec<_> = relevance
        .iter()
        .filter(|r| r.score >= 0.1 && r.score < 0.3)
        .take(10)
        .collect();

    if !context_files.is_empty() {
        output.push("\nRELATED:".to_string());
        for rel in &context_files {
            let short = protocol::shorten_path(&rel.path);
            output.push(format!(
                "  {} mode={} score={:.1}",
                short, rel.recommended_mode, rel.score
            ));
        }
    }

    // Up to 10 dependency edges touching any candidate, for orientation.
    let graph_edges: Vec<_> = index
        .edges
        .iter()
        .filter(|e| {
            candidates
                .iter()
                .any(|c| c.path == e.from || c.path == e.to)
        })
        .take(10)
        .collect();

    if !graph_edges.is_empty() {
        output.push("\nGRAPH:".to_string());
        for edge in &graph_edges {
            let from_short = protocol::shorten_path(&edge.from);
            let to_short = protocol::shorten_path(&edge.to);
            output.push(format!("  {from_short} -> {to_short}"));
        }
    }

    let preload_result = output.join("\n");
    let preload_tokens = count_tokens(&preload_result);
    let savings = protocol::format_savings(total_estimated_saved, preload_tokens);

    if crp_mode.is_tdd() {
        format!("{preload_result}\n{savings}")
    } else {
        format!("{preload_result}\n\nNext: ctx_read(path, mode=\"full\") for any file above.\n{savings}")
    }
}
198
199fn boltzmann_allocate(
202 candidates: &[&crate::core::task_relevance::RelevanceScore],
203 total_budget: usize,
204 temperature: f64,
205) -> Vec<usize> {
206 if candidates.is_empty() {
207 return Vec::new();
208 }
209
210 let t = temperature.max(0.01);
211
212 let log_weights: Vec<f64> = candidates.iter().map(|c| c.score / t).collect();
214 let max_log = log_weights
215 .iter()
216 .cloned()
217 .fold(f64::NEG_INFINITY, f64::max);
218 let exp_weights: Vec<f64> = log_weights.iter().map(|&lw| (lw - max_log).exp()).collect();
219 let z: f64 = exp_weights.iter().sum();
220
221 if z <= 0.0 {
222 return vec![total_budget / candidates.len().max(1); candidates.len()];
223 }
224
225 let mut allocations: Vec<usize> = exp_weights
226 .iter()
227 .map(|&w| ((w / z) * total_budget as f64).round() as usize)
228 .collect();
229
230 let sum: usize = allocations.iter().sum();
232 if sum > total_budget {
233 let overflow = sum - total_budget;
234 if let Some(last) = allocations.last_mut() {
235 *last = last.saturating_sub(overflow);
236 }
237 }
238
239 allocations
240}
241
/// Map a per-file token budget to a read mode, based on how much of the
/// file's total token count the budget covers.
fn budget_to_mode(budget: usize, file_tokens: usize) -> &'static str {
    // Fraction of the file that fits in the allocated budget; guard the
    // divisor so an empty file does not divide by zero.
    let coverage = budget as f64 / file_tokens.max(1) as f64;
    match coverage {
        c if c >= 0.8 => "full",
        c if c >= 0.4 => "signatures",
        c if c >= 0.15 => "map",
        _ => "reference",
    }
}
255
/// Pick up to `max` "hot" lines: lines containing a task keyword
/// (case-insensitive) or error-handling markers. Returns `(1-based line
/// number, trimmed text)` pairs, highest priority first; ties keep source
/// order (the sort is stable).
fn extract_critical_lines(content: &str, keywords: &[String], max: usize) -> Vec<(usize, String)> {
    let lowered: Vec<String> = keywords.iter().map(|k| k.to_lowercase()).collect();

    let mut scored: Vec<(usize, String, usize)> = Vec::new();
    for (idx, raw) in content.lines().enumerate() {
        let line = raw.trim();
        if line.is_empty() {
            continue;
        }

        let haystack = line.to_lowercase();
        let keyword_hits = lowered
            .iter()
            .filter(|kw| haystack.contains(kw.as_str()))
            .count();

        // Error-path markers are matched case-sensitively on the raw text.
        let error_markers = ["Error", "Err(", "panic!", "unwrap()"];
        let looks_like_error = error_markers.iter().any(|m| line.contains(m))
            || line.starts_with("return Err");

        if keyword_hits == 0 && !looks_like_error {
            continue;
        }

        // Error lines get a flat +2 boost over pure keyword hits.
        let priority = keyword_hits + if looks_like_error { 2 } else { 0 };
        scored.push((idx + 1, line.to_string(), priority));
    }

    scored.sort_by(|a, b| b.2.cmp(&a.2));
    scored.truncate(max);
    scored.into_iter().map(|(n, l, _)| (n, l)).collect()
}
292
/// Collect up to `max` public Rust item signatures (`pub fn`, `pub struct`,
/// etc.), trimmed, with long lines truncated to roughly 120 bytes plus an
/// ellipsis.
fn extract_key_signatures(content: &str, max: usize) -> Vec<String> {
    const MAX_SIG_LEN: usize = 120;
    let sig_starters = [
        "pub fn ",
        "pub async fn ",
        "pub struct ",
        "pub enum ",
        "pub trait ",
        "pub type ",
        "pub const ",
    ];

    content
        .lines()
        .filter_map(|line| {
            let trimmed = line.trim();
            if !sig_starters.iter().any(|s| trimmed.starts_with(s)) {
                return None;
            }
            if trimmed.len() > MAX_SIG_LEN {
                // Back off to a char boundary: slicing at a fixed byte index
                // panics when it lands inside a multi-byte UTF-8 character.
                let mut cut = MAX_SIG_LEN - 3;
                while !trimmed.is_char_boundary(cut) {
                    cut -= 1;
                }
                Some(format!("{}...", &trimmed[..cut]))
            } else {
                Some(trimmed.to_string())
            }
        })
        .take(max)
        .collect()
}
321
/// Gather up to 8 import-like lines (Rust `use`, JS/Python `import`/`from`).
/// Rust `use` lines are normalized to the bare path (prefix and trailing
/// semicolon stripped); other styles pass through verbatim.
fn extract_imports(content: &str) -> Vec<String> {
    let prefixes = ["use ", "import ", "from "];
    let mut found = Vec::new();

    for line in content.lines() {
        let t = line.trim();
        if !prefixes.iter().any(|p| t.starts_with(p)) {
            continue;
        }
        let entry = match t.strip_prefix("use ") {
            Some(rest) => rest.trim_end_matches(';').to_string(),
            None => t.to_string(),
        };
        found.push(entry);
        if found.len() == 8 {
            break;
        }
    }

    found
}
340
341fn apply_heat_ranking(candidates: &mut [&RelevanceScore], index: &ProjectIndex, root: &str) {
342 if index.files.is_empty() {
343 return;
344 }
345
346 let mut connection_counts: std::collections::HashMap<String, usize> =
347 std::collections::HashMap::new();
348 for edge in &index.edges {
349 *connection_counts.entry(edge.from.clone()).or_default() += 1;
350 *connection_counts.entry(edge.to.clone()).or_default() += 1;
351 }
352
353 let max_tokens = index
354 .files
355 .values()
356 .map(|f| f.token_count)
357 .max()
358 .unwrap_or(1) as f64;
359 let max_conn = connection_counts.values().max().copied().unwrap_or(1) as f64;
360
361 candidates.sort_by(|a, b| {
362 let heat_a = compute_heat(
363 &a.path,
364 root,
365 index,
366 &connection_counts,
367 max_tokens,
368 max_conn,
369 );
370 let heat_b = compute_heat(
371 &b.path,
372 root,
373 index,
374 &connection_counts,
375 max_tokens,
376 max_conn,
377 );
378 let combined_a = a.score * 0.6 + heat_a * 0.4;
379 let combined_b = b.score * 0.6 + heat_b * 0.4;
380 combined_b
381 .partial_cmp(&combined_a)
382 .unwrap_or(std::cmp::Ordering::Equal)
383 });
384}
385
386fn compute_heat(
387 path: &str,
388 root: &str,
389 index: &ProjectIndex,
390 connections: &std::collections::HashMap<String, usize>,
391 max_tokens: f64,
392 max_conn: f64,
393) -> f64 {
394 let rel = path
395 .strip_prefix(root)
396 .unwrap_or(path)
397 .trim_start_matches('/');
398
399 if let Some(entry) = index.files.get(rel) {
400 let conn = connections.get(rel).copied().unwrap_or(0);
401 let token_norm = entry.token_count as f64 / max_tokens;
402 let conn_norm = conn as f64 / max_conn;
403 token_norm * 0.4 + conn_norm * 0.6
404 } else {
405 0.0
406 }
407}
408
#[cfg(test)]
mod tests {
    use super::*;

    // Keyword-matching lines are surfaced with their 1-based line numbers.
    #[test]
    fn extract_critical_lines_finds_keywords() {
        let content = "fn main() {\n    let token = validate();\n    return Err(e);\n}\n";
        let result = extract_critical_lines(content, &["validate".to_string()], 5);
        assert!(!result.is_empty());
        assert!(result.iter().any(|(_, l)| l.contains("validate")));
    }

    // Error-marker lines carry a +2 priority, so they sort before plain
    // keyword hits.
    #[test]
    fn extract_critical_lines_prioritizes_errors() {
        let content = "fn main() {\n    let x = 1;\n    return Err(\"bad\");\n    let token = validate();\n}\n";
        let result = extract_critical_lines(content, &["validate".to_string()], 5);
        assert!(result.len() >= 2);
        assert!(result[0].1.contains("Err"), "errors should be first");
    }

    // Only `pub` items are collected; private fns and imports are skipped.
    #[test]
    fn extract_key_signatures_finds_pub() {
        let content = "use std::io;\nfn private() {}\npub fn public_one() {}\npub struct Foo {}\n";
        let sigs = extract_key_signatures(content, 10);
        assert_eq!(sigs.len(), 2);
        assert!(sigs[0].contains("pub fn public_one"));
        assert!(sigs[1].contains("pub struct Foo"));
    }

    // Rust `use` lines are normalized to the bare path.
    #[test]
    fn extract_imports_works() {
        let content = "use std::io;\nuse crate::core::cache;\nfn main() {}\n";
        let imports = extract_imports(content);
        assert_eq!(imports.len(), 2);
        assert!(imports[0].contains("std::io"));
    }
}
445}