1use std::collections::HashSet;
13use std::path::{Path, PathBuf};
14
15use imp_llm::message::{ContentBlock, Message, UserMessage};
16
/// Selects which part of a file is pulled into the prefilled context.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FileMode {
    /// The whole file.
    Full,
    /// Only the last `n` lines of the file.
    Tail(usize),
    /// The lines from `start` through `end`, where both bounds are 1-based and
    /// inclusive.
    Range(usize, usize),
}
31
32#[derive(Debug, Clone)]
34pub struct FileSpec {
35 pub path: PathBuf,
36 pub mode: FileMode,
37}
38
/// Size limits applied while assembling the prefilled context.
///
/// Both limits are expressed in estimated tokens; the assembler converts them
/// to character counts before comparing them with file contents.
#[derive(Debug, Clone)]
pub struct PrefillConfig {
    /// Ceiling for the whole assembled context.
    pub budget_tokens: usize,
    /// Ceiling for any single file section.
    pub per_file_tokens: usize,
}
47
48impl Default for PrefillConfig {
49 fn default() -> Self {
50 Self {
51 budget_tokens: 50_000,
52 per_file_tokens: 10_000,
53 }
54 }
55}
56
57#[derive(Debug)]
59pub struct AssembledContext {
60 pub messages: Vec<Message>,
62 pub included_files: Vec<PathBuf>,
64 pub warnings: Vec<String>,
66 pub estimated_tokens: usize,
68}
69
70impl AssembledContext {
71 pub fn empty() -> Self {
73 Self {
74 messages: Vec::new(),
75 included_files: Vec::new(),
76 warnings: Vec::new(),
77 estimated_tokens: 0,
78 }
79 }
80}
81
/// Rough token estimate for `text`: one token per four bytes, rounded down.
fn estimate_tokens(text: &str) -> usize {
    const BYTES_PER_TOKEN: usize = 4;
    let byte_len = text.len();
    byte_len / BYTES_PER_TOKEN
}
90
/// Converts a token budget into a character budget at four characters per
/// token.
///
/// The multiplication saturates at `usize::MAX`, so an extremely large budget
/// (for example `usize::MAX` used to mean "unlimited") stays effectively
/// unlimited instead of panicking in debug builds or wrapping to a small value
/// in release builds.
fn chars_from_tokens(tokens: usize) -> usize {
    tokens.saturating_mul(4)
}
95
96fn read_file_with_mode(path: &Path, mode: &FileMode) -> Result<String, std::io::Error> {
102 let content = std::fs::read_to_string(path)?;
103 Ok(match mode {
104 FileMode::Full => content,
105 FileMode::Tail(n) => {
106 let lines: Vec<&str> = content.lines().collect();
107 let start = lines.len().saturating_sub(*n);
108 lines[start..].join("\n")
109 }
110 FileMode::Range(start, end) => {
111 let lines: Vec<&str> = content.lines().collect();
112 let s = start.saturating_sub(1); let e = (*end).min(lines.len());
114 if s >= lines.len() {
115 String::new()
116 } else {
117 lines[s..e].join("\n")
118 }
119 }
120 })
121}
122
/// Shortens `content` so that the kept text is at most `max_chars` bytes long.
///
/// Returns the text together with a flag telling whether anything was cut.
/// Content that already fits is returned unchanged. Otherwise the cut is moved
/// back to a character boundary and then, when the kept prefix contains a
/// newline, to the last newline so that only whole lines remain. A trailer
/// stating how many lines are shown is appended after the kept text, so the
/// returned string can be longer than `max_chars`.
fn truncate_to_budget(content: &str, max_chars: usize) -> (String, bool) {
    if content.len() <= max_chars {
        return (content.to_string(), false);
    }

    // `max_chars` is strictly inside the string here; step back until the cut
    // does not split a multi-byte character. Index 0 is always a boundary.
    let mut cut = max_chars;
    while !content.is_char_boundary(cut) {
        cut -= 1;
    }

    // Prefer ending on a complete line when the prefix has a line break.
    let kept = match content[..cut].rfind('\n') {
        Some(newline_at) => &content[..newline_at],
        None => &content[..cut],
    };

    let total_lines = content.lines().count();
    let truncated_lines = kept.lines().count();
    let trailer = format!(
        "\n[... truncated: showing {truncated_lines} of {total_lines} lines]"
    );

    let mut shortened = String::with_capacity(kept.len() + trailer.len());
    shortened.push_str(kept);
    shortened.push_str(&trailer);
    (shortened, true)
}
148
149pub fn assemble_context(
158 specs: &[FileSpec],
159 cwd: &Path,
160 config: &PrefillConfig,
161) -> AssembledContext {
162 if specs.is_empty() {
163 return AssembledContext::empty();
164 }
165
166 let mut included_files = Vec::new();
167 let mut warnings = Vec::new();
168 let mut file_sections = Vec::new();
169 let mut total_chars: usize = 0;
170 let char_budget = chars_from_tokens(config.budget_tokens);
171 let per_file_char_budget = chars_from_tokens(config.per_file_tokens);
172
173 let wrapper_overhead = "<context>\n</context>".len();
175 total_chars += wrapper_overhead;
176
177 for spec in specs {
178 let resolved = if spec.path.is_absolute() {
179 spec.path.clone()
180 } else {
181 cwd.join(&spec.path)
182 };
183
184 let content = match read_file_with_mode(&resolved, &spec.mode) {
186 Ok(c) => c,
187 Err(e) => {
188 warnings.push(format!("{}: {e}", spec.path.display()));
189 continue;
190 }
191 };
192
193 if content.is_empty() {
194 continue;
195 }
196
197 let mode_note = match &spec.mode {
199 FileMode::Full => String::new(),
200 FileMode::Tail(n) => format!(r#" note="last {n} lines""#),
201 FileMode::Range(s, e) => format!(r#" note="lines {s}-{e}""#),
202 };
203 let header = format!(r#"<file path="{}"{}>"#, spec.path.display(), mode_note);
204 let footer = "</file>";
205 let section_overhead = header.len() + footer.len() + 2; let (file_content, was_truncated) = truncate_to_budget(
209 &content,
210 per_file_char_budget.saturating_sub(section_overhead),
211 );
212 if was_truncated {
213 warnings.push(format!(
214 "{}: truncated to ~{} tokens (per-file budget)",
215 spec.path.display(),
216 config.per_file_tokens,
217 ));
218 }
219
220 let section = format!("{header}\n{file_content}\n{footer}");
221 let section_chars = section.len();
222
223 if total_chars + section_chars > char_budget {
225 warnings.push(format!(
226 "{}: skipped (total budget of ~{} tokens exceeded)",
227 spec.path.display(),
228 config.budget_tokens,
229 ));
230 for remaining in specs.iter().skip(included_files.len() + warnings.len()) {
232 if !included_files.contains(&remaining.path) {
234 warnings.push(format!(
235 "{}: skipped (total budget exceeded)",
236 remaining.path.display(),
237 ));
238 }
239 }
240 break;
241 }
242
243 total_chars += section_chars;
244 file_sections.push(section);
245 included_files.push(spec.path.clone());
246 }
247
248 if file_sections.is_empty() {
249 return AssembledContext {
250 messages: Vec::new(),
251 included_files,
252 warnings,
253 estimated_tokens: 0,
254 };
255 }
256
257 let xml = format!("<context>\n{}\n</context>", file_sections.join("\n"));
258 let estimated_tokens = estimate_tokens(&xml);
259
260 let message = Message::User(UserMessage {
261 content: vec![ContentBlock::Text { text: xml }],
262 timestamp: imp_llm::now(),
263 });
264
265 AssembledContext {
266 messages: vec![message],
267 included_files,
268 warnings,
269 estimated_tokens,
270 }
271}
272
273pub fn detect_file_paths(text: &str) -> Vec<FileSpec> {
286 let extensions = [
290 "rs", "ts", "tsx", "py", "go", "js", "jsx", "toml", "yaml", "yml", "json", "md", "sh",
291 "sql", "zig", "c", "cpp", "h",
292 ];
293 let ext_pattern = extensions.join("|");
294 let pattern = format!(
295 r#"(?:^|[\s(`"'(])((?:[a-zA-Z_./])[a-zA-Z0-9_./-]*\.(?:{ext_pattern}))(?::([^\s)}}"'`]*))?"#,
296 );
297 let re = regex::Regex::new(&pattern).expect("valid regex");
298
299 let mut seen = HashSet::new();
300 let mut specs = Vec::new();
301
302 for cap in re.captures_iter(text) {
303 let path_str = cap.get(1).map(|m| m.as_str()).unwrap_or("");
304 if path_str.is_empty() {
305 continue;
306 }
307
308 let path = PathBuf::from(path_str);
309 if seen.contains(&path) {
310 continue;
311 }
312 seen.insert(path.clone());
313
314 let mode = cap
315 .get(2)
316 .map(|m| parse_mode_suffix(m.as_str()))
317 .unwrap_or(FileMode::Full);
318
319 specs.push(FileSpec { path, mode });
320 }
321
322 specs
323}
324
325fn parse_mode_suffix(suffix: &str) -> FileMode {
327 if let Some(n_str) = suffix.strip_prefix("tail:") {
329 if let Ok(n) = n_str.parse::<usize>() {
330 return FileMode::Tail(n);
331 }
332 }
333 if let Some(dash_pos) = suffix.find('-') {
335 let start_str = &suffix[..dash_pos];
336 let end_str = &suffix[dash_pos + 1..];
337 if let (Ok(start), Ok(end)) = (start_str.parse::<usize>(), end_str.parse::<usize>()) {
338 return FileMode::Range(start, end);
339 }
340 }
341 FileMode::Full
342}
343
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;

    /// Creates a temporary directory and writes every `(relative path, content)`
    /// pair into it, creating parent directories as needed.
    fn temp_dir_with_files(files: &[(&str, &str)]) -> tempfile::TempDir {
        let root = tempfile::tempdir().unwrap();
        for &(relative, body) in files {
            let target = root.path().join(relative);
            if let Some(parent) = target.parent() {
                fs::create_dir_all(parent).unwrap();
            }
            fs::write(target, body).unwrap();
        }
        root
    }

    /// Concatenates the text blocks of a user message; any other message kind
    /// yields an empty string.
    fn message_text(msg: &Message) -> String {
        let mut combined = String::new();
        if let Message::User(user) = msg {
            for block in &user.content {
                if let ContentBlock::Text { text } = block {
                    combined.push_str(text);
                }
            }
        }
        combined
    }

    /// Shorthand for building a [`FileSpec`] from a string path.
    fn spec(path: &str, mode: FileMode) -> FileSpec {
        FileSpec {
            path: PathBuf::from(path),
            mode,
        }
    }

    // --- assemble_context -------------------------------------------------

    #[test]
    fn test_context_prefill_assembles_single_file() {
        let workspace =
            temp_dir_with_files(&[("src/main.rs", "fn main() {\n println!(\"hello\");\n}")]);
        let requested = [spec("src/main.rs", FileMode::Full)];

        let ctx = assemble_context(&requested, workspace.path(), &PrefillConfig::default());

        assert_eq!(ctx.included_files.len(), 1);
        assert!(ctx.warnings.is_empty());
        assert!(!ctx.messages.is_empty());

        let rendered = message_text(&ctx.messages[0]);
        let expected_fragments = [
            "<context>",
            r#"<file path="src/main.rs">"#,
            "fn main()",
            "</file>",
            "</context>",
        ];
        for fragment in &expected_fragments {
            assert!(rendered.contains(*fragment));
        }
    }

    #[test]
    fn test_context_prefill_multiple_files() {
        let workspace =
            temp_dir_with_files(&[("src/a.rs", "struct A;"), ("src/b.rs", "struct B;")]);
        let requested = [
            spec("src/a.rs", FileMode::Full),
            spec("src/b.rs", FileMode::Full),
        ];

        let ctx = assemble_context(&requested, workspace.path(), &PrefillConfig::default());

        assert_eq!(ctx.included_files.len(), 2);
        let rendered = message_text(&ctx.messages[0]);
        assert!(rendered.contains("struct A"));
        assert!(rendered.contains("struct B"));
    }

    #[test]
    fn test_context_prefill_missing_file_warning() {
        let workspace = temp_dir_with_files(&[("src/exists.rs", "exists")]);
        let requested = [
            spec("src/missing.rs", FileMode::Full),
            spec("src/exists.rs", FileMode::Full),
        ];

        let ctx = assemble_context(&requested, workspace.path(), &PrefillConfig::default());

        assert_eq!(ctx.included_files.len(), 1);
        assert_eq!(ctx.included_files[0], PathBuf::from("src/exists.rs"));
        assert!(ctx.warnings.iter().any(|w| w.contains("missing.rs")));
    }

    #[test]
    fn test_context_prefill_per_file_budget() {
        let mut big_content = String::new();
        for i in 0..200 {
            big_content.push_str(&format!("line {i}: some content here\n"));
        }
        let workspace = temp_dir_with_files(&[("big.rs", big_content.as_str())]);
        let requested = [spec("big.rs", FileMode::Full)];
        // A tiny per-file allowance forces truncation while the total budget
        // stays generous.
        let limits = PrefillConfig {
            budget_tokens: 100_000,
            per_file_tokens: 100,
        };

        let ctx = assemble_context(&requested, workspace.path(), &limits);

        assert_eq!(ctx.included_files.len(), 1);
        assert!(ctx.warnings.iter().any(|w| w.contains("truncated")));
        let rendered = message_text(&ctx.messages[0]);
        assert!(rendered.contains("[... truncated:"));
    }

    #[test]
    fn test_context_prefill_total_budget() {
        let mut content_a = String::new();
        let mut content_b = String::new();
        for i in 0..200 {
            content_a.push_str(&format!("line_a_{i}: some padding content here\n"));
            content_b.push_str(&format!("line_b_{i}: some padding content here\n"));
        }
        let workspace = temp_dir_with_files(&[
            ("a.rs", content_a.as_str()),
            ("b.rs", content_b.as_str()),
        ]);
        let requested = [spec("a.rs", FileMode::Full), spec("b.rs", FileMode::Full)];
        // Room for the first file only.
        let limits = PrefillConfig {
            budget_tokens: 2500,
            per_file_tokens: 50_000,
        };

        let ctx = assemble_context(&requested, workspace.path(), &limits);

        assert_eq!(
            ctx.included_files.len(),
            1,
            "included: {:?}, warnings: {:?}",
            ctx.included_files,
            ctx.warnings
        );
        let b_was_reported = ctx
            .warnings
            .iter()
            .any(|w| w.contains("b.rs") && w.contains("budget"));
        assert!(b_was_reported);
    }

    #[test]
    fn test_context_prefill_tail_mode() {
        let workspace =
            temp_dir_with_files(&[("f.rs", "line 1\nline 2\nline 3\nline 4\nline 5\n")]);
        let requested = [spec("f.rs", FileMode::Tail(3))];

        let ctx = assemble_context(&requested, workspace.path(), &PrefillConfig::default());
        let rendered = message_text(&ctx.messages[0]);

        for absent in &["line 1", "line 2"] {
            assert!(!rendered.contains(*absent));
        }
        for present in &["line 3", "line 4", "line 5"] {
            assert!(rendered.contains(*present));
        }
    }

    #[test]
    fn test_context_prefill_range_mode() {
        let workspace =
            temp_dir_with_files(&[("f.rs", "line 1\nline 2\nline 3\nline 4\nline 5\n")]);
        let requested = [spec("f.rs", FileMode::Range(2, 4))];

        let ctx = assemble_context(&requested, workspace.path(), &PrefillConfig::default());
        let rendered = message_text(&ctx.messages[0]);

        for absent in &["line 1", "line 5"] {
            assert!(!rendered.contains(*absent));
        }
        for present in &["line 2", "line 3", "line 4"] {
            assert!(rendered.contains(*present));
        }
    }

    #[test]
    fn test_context_prefill_empty_specs() {
        let workspace = tempfile::tempdir().unwrap();

        let ctx = assemble_context(&[], workspace.path(), &PrefillConfig::default());

        assert!(ctx.messages.is_empty());
        assert!(ctx.included_files.is_empty());
        assert_eq!(ctx.estimated_tokens, 0);
    }

    // --- detect_file_paths ------------------------------------------------

    #[test]
    fn test_context_prefill_detect_paths() {
        let found = detect_file_paths(
            "Modify src/auth.rs and read crates/imp-llm/src/provider.rs for context.",
        );
        let paths: Vec<&str> = found.iter().map(|s| s.path.to_str().unwrap()).collect();
        assert!(paths.contains(&"src/auth.rs"));
        assert!(paths.contains(&"crates/imp-llm/src/provider.rs"));
    }

    #[test]
    fn test_context_prefill_detect_deduplicates() {
        let found =
            detect_file_paths("Read src/foo.rs first, then modify src/foo.rs to add the function.");
        let wanted = PathBuf::from("src/foo.rs");
        let foo_count = found.iter().filter(|s| s.path == wanted).count();
        assert_eq!(foo_count, 1);
    }

    #[test]
    fn test_context_prefill_detect_ignores_non_paths() {
        let specs =
            detect_file_paths("Handle errors gracefully. The users table needs updating.");
        assert!(specs.is_empty(), "got: {:?}", specs);
    }

    #[test]
    fn test_context_prefill_detect_tail_suffix() {
        let found =
            detect_file_paths("Check patterns in tests/auth_test.rs:tail:50 for reference.");
        assert_eq!(found.len(), 1);
        let only = &found[0];
        assert_eq!(only.path, PathBuf::from("tests/auth_test.rs"));
        assert_eq!(only.mode, FileMode::Tail(50));
    }

    #[test]
    fn test_context_prefill_detect_range_suffix() {
        let found = detect_file_paths("See src/lib.rs:10-50 for the relevant types.");
        assert_eq!(found.len(), 1);
        let only = &found[0];
        assert_eq!(only.path, PathBuf::from("src/lib.rs"));
        assert_eq!(only.mode, FileMode::Range(10, 50));
    }
}