1use crate::core::cache::SessionCache;
2use crate::core::graph_provider::{self, GraphProvider};
3use crate::core::protocol;
4use crate::core::task_relevance::{compute_relevance, parse_task_hints, RelevanceScore};
5use crate::core::tokens::count_tokens;
6use crate::tools::CrpMode;
7
// Upper bound on the number of `CRITICAL:` file sections `handle` emits; the
// scored candidate pool is capped at this value plus 10.
const MAX_PRELOAD_FILES: usize = 8;
// Passed to `extract_critical_lines` as the per-file cap on returned lines.
const MAX_CRITICAL_LINES: usize = 15;
// Passed to `extract_key_signatures` as the per-file cap on returned signatures.
const SIGNATURES_BUDGET: usize = 10;
// Total token budget that `boltzmann_allocate` splits across the candidates.
const TOTAL_TOKEN_BUDGET: usize = 4000;
12
/// Builds the `ctx_preload` response for `task`.
///
/// Ranks files from the project graph by relevance, splits
/// `TOTAL_TOKEN_BUDGET` across the surviving candidates and renders a text
/// report made of a task briefing, a task header, optional `POP:` exclusions,
/// per-file `CRITICAL:` sections, a `RELATED:` list, a `GRAPH:` edge list and
/// a savings line.
///
/// * `cache` - session cache; every file that gets a `CRITICAL:` section is
///   looked up with `get_file_ref` and then stored (the store result is
///   ignored).
/// * `task` - free-text task description; empty or whitespace-only input
///   yields an `ERROR:` string.
/// * `path` - project root, `"."` when `None`; also used as the jail root for
///   every file read.
/// * `crp_mode` - when `is_tdd()` is true the trailing `Next: ...` hint is
///   left out.
///
/// All failures are reported in-band as the returned string; this function
/// does not return a `Result`.
pub fn handle(
    cache: &mut SessionCache,
    task: &str,
    path: Option<&str>,
    crp_mode: CrpMode,
) -> String {
    if task.trim().is_empty() {
        return "ERROR: ctx_preload requires a task description".to_string();
    }

    // The project root doubles as the jail root for the file reads below.
    let project_root = path.map_or_else(|| ".".to_string(), std::string::ToString::to_string);
    let jail_root = std::path::Path::new(&project_root);

    let Some(open) = graph_provider::open_or_build(&project_root) else {
        return format!("[task: {task}]\nNo graph available. Use ctx_overview for project map.");
    };
    let gp = &open.provider;

    let session_intent =
        crate::core::session::SessionState::load_latest().and_then(|s| s.active_structured_intent);

    // A structured intent from the latest session, when present, takes
    // precedence over the file/keyword hints parsed from the task text.
    let (task_files, task_keywords) = parse_task_hints(task);
    let relevance = if let Some(ref intent) = session_intent {
        crate::core::task_relevance::compute_relevance_from_intent(gp, intent)
    } else {
        compute_relevance(gp, &task_files, &task_keywords)
    };

    // Keep a pool slightly larger than the final cap so later filtering
    // (heat re-ranking, POP pruning, unreadable files) has something to drop.
    let mut scored: Vec<_> = relevance
        .iter()
        .filter(|r| r.score >= 0.1)
        .take(MAX_PRELOAD_FILES + 10)
        .collect();

    apply_heat_ranking(&mut scored, gp, &project_root);

    let pop = crate::core::pop_pruning::decide_for_candidates(task, &project_root, &scored);
    let candidates =
        crate::core::pop_pruning::filter_candidates_by_pop(&project_root, &scored, &pop);

    if candidates.is_empty() {
        return format!(
            "[task: {task}]\nNo directly relevant files found. Use ctx_overview for project map."
        );
    }

    // More file/keyword hints -> higher specificity -> lower temperature, which
    // makes `boltzmann_allocate` concentrate the budget on the top scores.
    // With specificity in 0.0..=1.0 the raw value stays within 0.2..=0.8, so
    // the 0.1 floor is only a safety net.
    let task_specificity =
        (task_files.len() as f64 * 0.3 + task_keywords.len() as f64 * 0.1).clamp(0.0, 1.0);
    let temperature = 0.8 - task_specificity * 0.6; let temperature = temperature.max(0.1);

    let allocations = boltzmann_allocate(&candidates, TOTAL_TOKEN_BUDGET, temperature);

    // (path, line count) for every candidate that is inside the jail, carries
    // no warning and can be read; this feeds the task briefing.
    let file_context: Vec<(String, usize)> = candidates
        .iter()
        .filter_map(|c| {
            let Ok((jailed, warning)) = crate::core::io_boundary::jail_and_check_path(
                "ctx_preload",
                std::path::Path::new(&c.path),
                jail_root,
            ) else {
                return None;
            };
            // Paths that pass the jail check but come back with a warning are skipped too.
            if warning.is_some() {
                return None;
            }
            std::fs::read_to_string(&jailed)
                .ok()
                .map(|content| (c.path.clone(), content.lines().count()))
        })
        .collect();
    let briefing = crate::core::task_briefing::build_briefing(task, &file_context);
    let briefing_block = crate::core::task_briefing::format_briefing(&briefing);

    let multi_intents = crate::core::intent_engine::detect_multi_intent(task);
    // NOTE(review): indexing panics if `detect_multi_intent` can return an
    // empty list - confirm it always yields at least one intent.
    let primary = &multi_intents[0];
    let complexity = crate::core::intent_engine::classify_complexity(task, primary);

    // The report is assembled as separate lines and joined once at the end.
    let mut output = Vec::new();
    output.push(briefing_block);

    // Only the first line of the complexity suffix is used as the header label.
    let complexity_label = complexity.instruction_suffix().lines().next().unwrap_or("");
    if multi_intents.len() > 1 {
        output.push(format!(
            "[task: {task}] | {} | {} sub-intents",
            complexity_label,
            multi_intents.len()
        ));
        for (i, sub) in multi_intents.iter().enumerate() {
            output.push(format!(
                " {}. {} ({:.0}%)",
                i + 1,
                sub.task_type.as_str(),
                sub.confidence * 100.0
            ));
        }
    } else {
        output.push(format!("[task: {task}] | {complexity_label}"));
    }

    for r in crate::core::prospective_memory::reminders_for_task(&project_root, task) {
        output.push(r);
    }

    if !pop.excluded_modules.is_empty() {
        output.push("POP:".to_string());
        for ex in &pop.excluded_modules {
            output.push(format!(
                " - exclude {}/ ({} candidates) — {}",
                ex.module, ex.candidate_files, ex.reason
            ));
        }
    }

    // Sum of the full token counts of every file summarised below; compared
    // against the size of the rendered report for the savings line.
    let mut total_estimated_saved = 0usize;
    let mut critical_count = 0usize;

    for (rel, token_budget) in candidates.iter().zip(allocations.iter()) {
        // Files whose share of the budget is negligible get no section.
        if *token_budget < 20 {
            continue;
        }
        // NOTE(review): the slot is counted before the jail/read checks below,
        // so a file that is later skipped still uses up one of the
        // MAX_PRELOAD_FILES slots - confirm this is intended.
        critical_count += 1;
        if critical_count > MAX_PRELOAD_FILES {
            break;
        }

        let Ok((jailed, warning)) = crate::core::io_boundary::jail_and_check_path(
            "ctx_preload",
            std::path::Path::new(&rel.path),
            jail_root,
        ) else {
            continue;
        };
        if warning.is_some() {
            continue;
        }

        let jailed_s = jailed.to_string_lossy().to_string();
        let Ok(content) = std::fs::read_to_string(&jailed) else {
            continue;
        };

        // The file reference is obtained before the content is stored in the cache.
        let file_ref = cache.get_file_ref(&jailed_s);
        let short = protocol::shorten_path(&jailed_s);
        let line_count = content.lines().count();
        let file_tokens = count_tokens(&content);

        let _ = cache.store(&jailed_s, &content);

        // `mode` is only reported in the header line as a suggested read mode;
        // the excerpts below are produced the same way regardless of it.
        let mode = budget_to_mode(*token_budget, file_tokens);

        let critical_lines = extract_critical_lines(&content, &task_keywords, MAX_CRITICAL_LINES);
        let sigs = extract_key_signatures(&content, SIGNATURES_BUDGET);
        let imports = extract_imports(&content);

        output.push(format!(
            "\nCRITICAL: {file_ref}={short} {line_count}L score={:.1} budget={token_budget}tok mode={mode}",
            rel.score
        ));

        if !critical_lines.is_empty() {
            for (line_no, line) in &critical_lines {
                output.push(format!(" :{line_no} {line}"));
            }
        }

        if !imports.is_empty() {
            output.push(format!(" imports: {}", imports.join(", ")));
        }

        if !sigs.is_empty() {
            for sig in &sigs {
                output.push(format!(" {sig}"));
            }
        }

        total_estimated_saved += file_tokens;
    }

    // Lower-scoring files (0.1 <= score < 0.3) are listed by name only. This
    // list is drawn from the full relevance list, not from `candidates`.
    let context_files: Vec<_> = relevance
        .iter()
        .filter(|r| r.score >= 0.1 && r.score < 0.3)
        .take(10)
        .collect();

    if !context_files.is_empty() {
        output.push("\nRELATED:".to_string());
        for rel in &context_files {
            let short = protocol::shorten_path(&rel.path);
            output.push(format!(
                " {} mode={} score={:.1}",
                short, rel.recommended_mode, rel.score
            ));
        }
    }

    // Up to 10 graph edges that touch at least one candidate file.
    let all_edges = gp.edges();
    let graph_edges: Vec<_> = all_edges
        .iter()
        .filter(|e| {
            candidates
                .iter()
                .any(|c| c.path == e.from || c.path == e.to)
        })
        .take(10)
        .collect();

    if !graph_edges.is_empty() {
        output.push("\nGRAPH:".to_string());
        for edge in &graph_edges {
            let from_short = protocol::shorten_path(&edge.from);
            let to_short = protocol::shorten_path(&edge.to);
            output.push(format!(" {from_short} -> {to_short}"));
        }
    }

    let preload_result = output.join("\n");
    let preload_tokens = count_tokens(&preload_result);
    let savings = protocol::format_savings(total_estimated_saved, preload_tokens);

    // The "Next:" hint is omitted when the CRP mode reports `is_tdd()`.
    if crp_mode.is_tdd() {
        format!("{preload_result}\n{savings}")
    } else {
        format!("{preload_result}\n\nNext: ctx_read(path, mode=\"full\") for any file above.\n{savings}")
    }
}
243
244fn boltzmann_allocate(
247 candidates: &[&crate::core::task_relevance::RelevanceScore],
248 total_budget: usize,
249 temperature: f64,
250) -> Vec<usize> {
251 if candidates.is_empty() {
252 return Vec::new();
253 }
254
255 let t = temperature.max(0.01);
256
257 let log_weights: Vec<f64> = candidates.iter().map(|c| c.score / t).collect();
259 let max_log = log_weights
260 .iter()
261 .copied()
262 .fold(f64::NEG_INFINITY, f64::max);
263 let exp_weights: Vec<f64> = log_weights.iter().map(|&lw| (lw - max_log).exp()).collect();
264 let z: f64 = exp_weights.iter().sum();
265
266 if z <= 0.0 {
267 return vec![total_budget / candidates.len().max(1); candidates.len()];
268 }
269
270 let mut allocations: Vec<usize> = exp_weights
271 .iter()
272 .map(|&w| ((w / z) * total_budget as f64).round() as usize)
273 .collect();
274
275 let sum: usize = allocations.iter().sum();
277 if sum > total_budget {
278 let overflow = sum - total_budget;
279 if let Some(last) = allocations.last_mut() {
280 *last = last.saturating_sub(overflow);
281 }
282 }
283
284 allocations
285}
286
/// Maps a token budget to a read-mode label based on how much of the file it
/// would cover.
///
/// The coverage ratio is `budget / file_tokens`, with `file_tokens` treated as
/// at least 1. Ratios of 0.8 and above give `"full"`, 0.4 and above
/// `"signatures"`, 0.15 and above `"map"`, and anything lower `"reference"`.
fn budget_to_mode(budget: usize, file_tokens: usize) -> &'static str {
    // Ordered from the most to the least generous mode; the first tier whose
    // floor is met wins.
    const TIERS: [(f64, &str); 3] = [(0.8, "full"), (0.4, "signatures"), (0.15, "map")];

    let coverage = budget as f64 / file_tokens.max(1) as f64;
    TIERS
        .iter()
        .find(|(floor, _)| coverage >= *floor)
        .map_or("reference", |(_, mode)| *mode)
}
300
/// Picks the lines of `content` that matter most for the task.
///
/// A non-blank line qualifies when it contains at least one of `keywords`
/// (case-insensitive substring match) or looks error-related (contains
/// `Error`, `Err(`, `panic!` or `unwrap()`, or starts with `return Err`).
/// Its priority is the number of matching keywords, plus 2 if it is
/// error-related.
///
/// Returns at most `max` `(1-based line number, trimmed text)` pairs, highest
/// priority first; lines of equal priority keep their file order.
///
/// Blank keywords are ignored: an empty string is a substring of every line
/// and would otherwise mark the whole file as critical.
fn extract_critical_lines(content: &str, keywords: &[String], max: usize) -> Vec<(usize, String)> {
    let kw_lower: Vec<String> = keywords
        .iter()
        .filter(|k| !k.trim().is_empty())
        .map(|k| k.to_lowercase())
        .collect();

    let mut hits: Vec<(usize, String, usize)> = content
        .lines()
        .enumerate()
        .filter_map(|(i, line)| {
            let trimmed = line.trim();
            if trimmed.is_empty() {
                return None;
            }
            let line_lower = trimmed.to_lowercase();
            let hit_count = kw_lower
                .iter()
                .filter(|kw| line_lower.contains(kw.as_str()))
                .count();

            // Error markers are matched case-sensitively on the original text.
            let is_error = trimmed.contains("Error")
                || trimmed.contains("Err(")
                || trimmed.contains("panic!")
                || trimmed.contains("unwrap()")
                || trimmed.starts_with("return Err");

            let priority = hit_count + if is_error { 2 } else { 0 };
            (priority > 0).then(|| (i + 1, trimmed.to_string(), priority))
        })
        .collect();

    // Stable sort: equal priorities stay in source order.
    hits.sort_by_key(|hit| std::cmp::Reverse(hit.2));
    hits.truncate(max);
    hits.into_iter().map(|(n, l, _)| (n, l)).collect()
}
337
/// Collects up to `max` public item declarations from `content`.
///
/// A line counts when, after trimming, it starts with one of `pub fn`,
/// `pub async fn`, `pub struct`, `pub enum`, `pub trait`, `pub type` or
/// `pub const` (each followed by a space). Lines longer than 120 bytes are cut
/// at the last character boundary at or before byte 117 and suffixed with
/// `...`. Lines are returned in file order.
fn extract_key_signatures(content: &str, max: usize) -> Vec<String> {
    const PUBLIC_ITEM_PREFIXES: [&str; 7] = [
        "pub fn ",
        "pub async fn ",
        "pub struct ",
        "pub enum ",
        "pub trait ",
        "pub type ",
        "pub const ",
    ];
    const MAX_LEN: usize = 120;
    const KEEP: usize = 117;

    let mut signatures = Vec::new();
    for raw in content.lines() {
        if signatures.len() >= max {
            break;
        }
        let decl = raw.trim();
        if !PUBLIC_ITEM_PREFIXES.iter().any(|p| decl.starts_with(p)) {
            continue;
        }
        if decl.len() <= MAX_LEN {
            signatures.push(decl.to_string());
            continue;
        }
        // Never split a multi-byte character: back off to the nearest
        // boundary at or below KEEP (index 0 is always a boundary).
        let cut = (0..=KEEP)
            .rev()
            .find(|&i| decl.is_char_boundary(i))
            .unwrap_or(0);
        signatures.push(format!("{}...", &decl[..cut]));
    }
    signatures
}
366
/// Returns the first 8 import-like lines of `content`.
///
/// A line counts when, after trimming, it starts with `use `, `import ` or
/// `from `. For `use ` lines the keyword and any trailing semicolons are
/// removed; other lines are returned trimmed but otherwise unchanged.
fn extract_imports(content: &str) -> Vec<String> {
    const IMPORT_PREFIXES: [&str; 3] = ["use ", "import ", "from "];
    const MAX_IMPORTS: usize = 8;

    let mut found = Vec::new();
    for raw in content.lines() {
        if found.len() == MAX_IMPORTS {
            break;
        }
        let stmt = raw.trim();
        if !IMPORT_PREFIXES.iter().any(|p| stmt.starts_with(p)) {
            continue;
        }
        let entry = match stmt.strip_prefix("use ") {
            Some(rest) => rest.trim_end_matches(';'),
            None => stmt,
        };
        found.push(entry.to_string());
    }
    found
}
385
386fn apply_heat_ranking(candidates: &mut [&RelevanceScore], gp: &GraphProvider, root: &str) {
387 if gp.file_count() == 0 {
388 return;
389 }
390
391 let all_edges = gp.edges();
392 let mut connection_counts: std::collections::HashMap<String, usize> =
393 std::collections::HashMap::new();
394 for edge in &all_edges {
395 *connection_counts.entry(edge.from.clone()).or_default() += 1;
396 *connection_counts.entry(edge.to.clone()).or_default() += 1;
397 }
398
399 let mut max_tokens = 1usize;
400 for path in gp.file_paths() {
401 if let Some(entry) = gp.get_file_entry(&path) {
402 max_tokens = max_tokens.max(entry.token_count);
403 }
404 }
405 let max_tokens = max_tokens as f64;
406 let max_conn = connection_counts.values().max().copied().unwrap_or(1) as f64;
407
408 candidates.sort_by(|a, b| {
409 let heat_a = compute_heat(&a.path, root, gp, &connection_counts, max_tokens, max_conn);
410 let heat_b = compute_heat(&b.path, root, gp, &connection_counts, max_tokens, max_conn);
411 let combined_a = a.score * 0.6 + heat_a * 0.4;
412 let combined_b = b.score * 0.6 + heat_b * 0.4;
413 combined_b
414 .partial_cmp(&combined_a)
415 .unwrap_or(std::cmp::Ordering::Equal)
416 });
417}
418
419fn compute_heat(
420 path: &str,
421 root: &str,
422 gp: &GraphProvider,
423 connections: &std::collections::HashMap<String, usize>,
424 max_tokens: f64,
425 max_conn: f64,
426) -> f64 {
427 let rel = path
428 .strip_prefix(root)
429 .unwrap_or(path)
430 .trim_start_matches('/');
431
432 if let Some(entry) = gp.get_file_entry(rel) {
433 let conn = connections.get(rel).copied().unwrap_or(0);
434 let token_norm = entry.token_count as f64 / max_tokens;
435 let conn_norm = conn as f64 / max_conn;
436 token_norm * 0.4 + conn_norm * 0.6
437 } else {
438 0.0
439 }
440}
441
// Unit tests for the pure text-extraction helpers in this file.
#[cfg(test)]
mod tests {
    use super::*;

    // A line containing the keyword must show up in the result.
    #[test]
    fn extract_critical_lines_finds_keywords() {
        let content = "fn main() {\n let token = validate();\n return Err(e);\n}\n";
        let result = extract_critical_lines(content, &["validate".to_string()], 5);
        assert!(!result.is_empty());
        assert!(result.iter().any(|(_, l)| l.contains("validate")));
    }

    // An error-looking line (+2 priority) outranks a single keyword hit, even
    // though it appears earlier in the file.
    #[test]
    fn extract_critical_lines_prioritizes_errors() {
        let content = "fn main() {\n let x = 1;\n return Err(\"bad\");\n let token = validate();\n}\n";
        let result = extract_critical_lines(content, &["validate".to_string()], 5);
        assert!(result.len() >= 2);
        assert!(result[0].1.contains("Err"), "errors should be first");
    }

    // Only `pub` items are reported; private functions and `use` lines are not.
    #[test]
    fn extract_key_signatures_finds_pub() {
        let content = "use std::io;\nfn private() {}\npub fn public_one() {}\npub struct Foo {}\n";
        let sigs = extract_key_signatures(content, 10);
        assert_eq!(sigs.len(), 2);
        assert!(sigs[0].contains("pub fn public_one"));
        assert!(sigs[1].contains("pub struct Foo"));
    }

    // `use` lines are collected; ordinary code lines are ignored.
    #[test]
    fn extract_imports_works() {
        let content = "use std::io;\nuse crate::core::cache;\nfn main() {}\n";
        let imports = extract_imports(content);
        assert_eq!(imports.len(), 2);
        assert!(imports[0].contains("std::io"));
    }
}