1use std::collections::HashSet;
13use std::path::{Path, PathBuf};
14
15use imp_llm::message::{ContentBlock, Message, UserMessage};
16
17use crate::trust::{Provenance, TrustedContext};
18
/// Which portion of a file should be pulled into the prompt context.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum FileMode {
    /// The whole file, byte-for-byte.
    Full,
    /// Only the last `n` lines.
    Tail(usize),
    /// Lines `start` through `end`, 1-based and inclusive (out-of-range ends are clamped).
    Range(usize, usize),
}

/// A single file the caller wants included, plus how much of it to include.
#[derive(Clone, Debug)]
pub struct FileSpec {
    /// Path as written by the user; relative paths are resolved against the working dir.
    pub path: PathBuf,
    /// Portion of the file to read.
    pub mode: FileMode,
}
40
/// Knobs controlling how much file content `assemble_context` may inject.
#[derive(Clone, Debug)]
pub struct PrefillConfig {
    /// Upper bound, in estimated tokens, for the whole `<context>` message.
    pub budget_tokens: usize,
    /// Upper bound, in estimated tokens, for any single `<file>` section.
    pub per_file_tokens: usize,
    /// When true, each `<file>` tag carries a `trust="workspace:file"` attribute.
    pub annotate_trust: bool,
}

impl Default for PrefillConfig {
    /// 50k tokens overall, 10k tokens per file, no trust annotations.
    fn default() -> Self {
        let budget_tokens = 50_000;
        let per_file_tokens = 10_000;
        PrefillConfig {
            budget_tokens,
            per_file_tokens,
            annotate_trust: false,
        }
    }
}
61
62#[derive(Debug)]
64pub struct AssembledContext {
65 pub messages: Vec<Message>,
67 pub included_files: Vec<PathBuf>,
69 pub warnings: Vec<String>,
71 pub provenance: Vec<TrustedContext<PathBuf>>,
73 pub estimated_tokens: usize,
75}
76
77impl AssembledContext {
78 pub fn empty() -> Self {
80 Self {
81 messages: Vec::new(),
82 included_files: Vec::new(),
83 warnings: Vec::new(),
84 provenance: Vec::new(),
85 estimated_tokens: 0,
86 }
87 }
88}
89
/// Cheap token estimate: one token per four bytes of UTF-8 text (rounded down).
fn estimate_tokens(text: &str) -> usize {
    const BYTES_PER_TOKEN: usize = 4;
    let byte_len = text.len();
    byte_len / BYTES_PER_TOKEN
}
98
/// Converts a token budget into an approximate byte budget (4 bytes per token),
/// mirroring `estimate_tokens`.
///
/// Saturates at `usize::MAX` instead of overflowing, so an "effectively unlimited"
/// budget such as `usize::MAX` neither panics in debug builds nor wraps around to a
/// tiny budget in release builds.
fn chars_from_tokens(tokens: usize) -> usize {
    tokens.saturating_mul(4)
}
103
104fn read_file_with_mode(path: &Path, mode: &FileMode) -> Result<String, std::io::Error> {
110 let content = std::fs::read_to_string(path)?;
111 Ok(match mode {
112 FileMode::Full => content,
113 FileMode::Tail(n) => {
114 let lines: Vec<&str> = content.lines().collect();
115 let start = lines.len().saturating_sub(*n);
116 lines[start..].join("\n")
117 }
118 FileMode::Range(start, end) => {
119 let lines: Vec<&str> = content.lines().collect();
120 let s = start.saturating_sub(1); let e = (*end).min(lines.len());
122 if s >= lines.len() {
123 String::new()
124 } else {
125 lines[s..e].join("\n")
126 }
127 }
128 })
129}
130
/// Shrinks `content` to at most `max_chars` bytes, preferring to stop at a line break.
///
/// Returns the resulting text and whether anything was cut. Text that already fits
/// is returned unchanged. Otherwise the cut lands on the last UTF-8 character
/// boundary within the budget, then backs off to the last `\n` before it (if any).
/// A `[... truncated: showing X of Y lines]` marker is appended on its own line; the
/// marker itself is not counted against `max_chars`.
fn truncate_to_budget(content: &str, max_chars: usize) -> (String, bool) {
    if content.len() <= max_chars {
        return (content.to_owned(), false);
    }

    // Here max_chars < content.len(), so every index in 0..=max_chars is inside the
    // string; index 0 is always a boundary, so the fallback is never actually used.
    let boundary = (0..=max_chars)
        .rev()
        .find(|&idx| content.is_char_boundary(idx))
        .unwrap_or(0);
    let head = &content[..boundary];
    let kept = match head.rfind('\n') {
        Some(newline) => &head[..newline],
        None => head,
    };

    let shown = kept.lines().count();
    let total = content.lines().count();
    let rendered = format!("{kept}\n[... truncated: showing {shown} of {total} lines]");
    (rendered, true)
}
156
157pub fn assemble_context(
166 specs: &[FileSpec],
167 cwd: &Path,
168 config: &PrefillConfig,
169) -> AssembledContext {
170 if specs.is_empty() {
171 return AssembledContext::empty();
172 }
173
174 let mut included_files = Vec::new();
175 let mut warnings = Vec::new();
176 let mut file_sections = Vec::new();
177 let mut total_chars: usize = 0;
178 let char_budget = chars_from_tokens(config.budget_tokens);
179 let per_file_char_budget = chars_from_tokens(config.per_file_tokens);
180
181 let wrapper_overhead = "<context>\n</context>".len();
183 total_chars += wrapper_overhead;
184
185 for spec in specs {
186 let resolved = if spec.path.is_absolute() {
187 spec.path.clone()
188 } else {
189 cwd.join(&spec.path)
190 };
191
192 let content = match read_file_with_mode(&resolved, &spec.mode) {
194 Ok(c) => c,
195 Err(e) => {
196 warnings.push(format!("{}: {e}", spec.path.display()));
197 continue;
198 }
199 };
200
201 if content.is_empty() {
202 continue;
203 }
204
205 let mode_note = match &spec.mode {
207 FileMode::Full => String::new(),
208 FileMode::Tail(n) => format!(r#" note="last {n} lines""#),
209 FileMode::Range(s, e) => format!(r#" note="lines {s}-{e}""#),
210 };
211 let trust_attr = if config.annotate_trust {
212 r#" trust="workspace:file""#
213 } else {
214 ""
215 };
216 let header = format!(
217 r#"<file path="{}"{}{}>"#,
218 spec.path.display(),
219 trust_attr,
220 mode_note
221 );
222 let footer = "</file>";
223 let section_overhead = header.len() + footer.len() + 2; let (file_content, was_truncated) = truncate_to_budget(
227 &content,
228 per_file_char_budget.saturating_sub(section_overhead),
229 );
230 if was_truncated {
231 warnings.push(format!(
232 "{}: truncated to ~{} tokens (per-file budget)",
233 spec.path.display(),
234 config.per_file_tokens,
235 ));
236 }
237
238 let section = format!("{header}\n{file_content}\n{footer}");
239 let section_chars = section.len();
240
241 if total_chars + section_chars > char_budget {
243 warnings.push(format!(
244 "{}: skipped (total budget of ~{} tokens exceeded)",
245 spec.path.display(),
246 config.budget_tokens,
247 ));
248 for remaining in specs.iter().skip(included_files.len() + warnings.len()) {
250 if !included_files.contains(&remaining.path) {
252 warnings.push(format!(
253 "{}: skipped (total budget exceeded)",
254 remaining.path.display(),
255 ));
256 }
257 }
258 break;
259 }
260
261 total_chars += section_chars;
262 file_sections.push(section);
263 included_files.push(spec.path.clone());
264 }
265
266 if file_sections.is_empty() {
267 return AssembledContext {
268 messages: Vec::new(),
269 included_files,
270 warnings,
271 provenance: Vec::new(),
272 estimated_tokens: 0,
273 };
274 }
275
276 let xml = format!("<context>\n{}\n</context>", file_sections.join("\n"));
277 let estimated_tokens = estimate_tokens(&xml);
278
279 let message = Message::User(UserMessage {
280 content: vec![ContentBlock::Text { text: xml }],
281 timestamp: imp_llm::now(),
282 });
283
284 let provenance = included_files
285 .iter()
286 .cloned()
287 .map(|path| TrustedContext::new(path.clone(), Provenance::workspace_file(path)))
288 .collect();
289
290 AssembledContext {
291 messages: vec![message],
292 included_files,
293 warnings,
294 provenance,
295 estimated_tokens,
296 }
297}
298
299pub fn detect_file_paths(text: &str) -> Vec<FileSpec> {
312 let extensions = [
316 "rs", "ts", "tsx", "py", "go", "js", "jsx", "toml", "yaml", "yml", "json", "md", "sh",
317 "sql", "zig", "c", "cpp", "h",
318 ];
319 let ext_pattern = extensions.join("|");
320 let pattern = format!(
321 r#"(?:^|[\s(`"'(])((?:[a-zA-Z_./])[a-zA-Z0-9_./-]*\.(?:{ext_pattern}))(?::([^\s)}}"'`]*))?"#,
322 );
323 let re = regex::Regex::new(&pattern).expect("valid regex");
324
325 let mut seen = HashSet::new();
326 let mut specs = Vec::new();
327
328 for cap in re.captures_iter(text) {
329 let path_str = cap.get(1).map(|m| m.as_str()).unwrap_or("");
330 if path_str.is_empty() {
331 continue;
332 }
333
334 let path = PathBuf::from(path_str);
335 if seen.contains(&path) {
336 continue;
337 }
338 seen.insert(path.clone());
339
340 let mode = cap
341 .get(2)
342 .map(|m| parse_mode_suffix(m.as_str()))
343 .unwrap_or(FileMode::Full);
344
345 specs.push(FileSpec { path, mode });
346 }
347
348 specs
349}
350
351fn parse_mode_suffix(suffix: &str) -> FileMode {
353 if let Some(n_str) = suffix.strip_prefix("tail:") {
355 if let Ok(n) = n_str.parse::<usize>() {
356 return FileMode::Tail(n);
357 }
358 }
359 if let Some(dash_pos) = suffix.find('-') {
361 let start_str = &suffix[..dash_pos];
362 let end_str = &suffix[dash_pos + 1..];
363 if let (Ok(start), Ok(end)) = (start_str.parse::<usize>(), end_str.parse::<usize>()) {
364 return FileMode::Range(start, end);
365 }
366 }
367 FileMode::Full
368}
369
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;

    /// Five numbered lines used by the tail/range tests.
    const FIVE_LINES: &str = "line 1\nline 2\nline 3\nline 4\nline 5\n";

    /// Creates a temp dir holding each `(relative_path, contents)` pair, creating
    /// parent directories as needed.
    fn temp_dir_with_files(files: &[(&str, &str)]) -> tempfile::TempDir {
        let dir = tempfile::tempdir().unwrap();
        for (name, content) in files {
            let target = dir.path().join(name);
            if let Some(parent) = target.parent() {
                fs::create_dir_all(parent).unwrap();
            }
            fs::write(target, content).unwrap();
        }
        dir
    }

    /// Builds a spec for `path` with the given `mode`.
    fn spec(path: &str, mode: FileMode) -> FileSpec {
        FileSpec {
            path: PathBuf::from(path),
            mode,
        }
    }

    /// Builds a spec that includes the whole of `path`.
    fn full(path: &str) -> FileSpec {
        spec(path, FileMode::Full)
    }

    /// Concatenates the text blocks of a user message; other message kinds yield "".
    fn message_text(msg: &Message) -> String {
        match msg {
            Message::User(user) => {
                let parts: Vec<&str> = user
                    .content
                    .iter()
                    .filter_map(|block| match block {
                        ContentBlock::Text { text } => Some(text.as_str()),
                        _ => None,
                    })
                    .collect();
                parts.join("")
            }
            _ => String::new(),
        }
    }

    #[test]
    fn prompt_context_trust_annotations_are_configurable() {
        let dir = temp_dir_with_files(&[("README.md", "hello")]);
        let specs = [full("README.md")];

        let plain = assemble_context(&specs, dir.path(), &PrefillConfig::default());
        let plain_text = message_text(&plain.messages[0]);
        assert!(plain_text.contains(r#"<file path="README.md">"#));
        assert!(!plain_text.contains("trust="));

        let config = PrefillConfig {
            annotate_trust: true,
            ..PrefillConfig::default()
        };
        let annotated = assemble_context(&specs, dir.path(), &config);
        let annotated_text = message_text(&annotated.messages[0]);
        assert!(annotated_text.contains(r#"<file path="README.md" trust="workspace:file">"#));
    }

    #[test]
    fn test_context_prefill_records_workspace_file_provenance() {
        let dir = temp_dir_with_files(&[("README.md", "hello")]);
        let specs = [full("README.md")];
        let assembled = assemble_context(&specs, dir.path(), &PrefillConfig::default());

        assert_eq!(assembled.provenance.len(), 1);
        let record = &assembled.provenance[0];
        assert_eq!(record.value, PathBuf::from("README.md"));
        assert_eq!(
            record.provenance.trust,
            crate::trust::TrustLabel::ProjectTrusted
        );
        assert!(matches!(
            record.provenance.source,
            crate::trust::ProvenanceSource::WorkspaceFile { .. }
        ));
    }

    #[test]
    fn test_context_prefill_assembles_single_file() {
        let dir =
            temp_dir_with_files(&[("src/main.rs", "fn main() {\n println!(\"hello\");\n}")]);
        let specs = [full("src/main.rs")];
        let ctx = assemble_context(&specs, dir.path(), &PrefillConfig::default());
        assert_eq!(ctx.included_files.len(), 1);
        assert!(ctx.warnings.is_empty());
        assert!(!ctx.messages.is_empty());

        let text = message_text(&ctx.messages[0]);
        let expected = [
            "<context>",
            r#"<file path="src/main.rs">"#,
            "fn main()",
            "</file>",
            "</context>",
        ];
        for needle in expected {
            assert!(text.contains(needle), "missing {needle:?}");
        }
    }

    #[test]
    fn test_context_prefill_multiple_files() {
        let dir = temp_dir_with_files(&[("src/a.rs", "struct A;"), ("src/b.rs", "struct B;")]);
        let specs = [full("src/a.rs"), full("src/b.rs")];
        let ctx = assemble_context(&specs, dir.path(), &PrefillConfig::default());
        assert_eq!(ctx.included_files.len(), 2);

        let text = message_text(&ctx.messages[0]);
        assert!(text.contains("struct A"));
        assert!(text.contains("struct B"));
    }

    #[test]
    fn test_context_prefill_missing_file_warning() {
        let dir = temp_dir_with_files(&[("src/exists.rs", "exists")]);
        let specs = [full("src/missing.rs"), full("src/exists.rs")];
        let ctx = assemble_context(&specs, dir.path(), &PrefillConfig::default());
        assert_eq!(ctx.included_files, vec![PathBuf::from("src/exists.rs")]);
        assert!(ctx.warnings.iter().any(|w| w.contains("missing.rs")));
    }

    #[test]
    fn test_context_prefill_per_file_budget() {
        let big_content = (0..200)
            .map(|i| format!("line {i}: some content here\n"))
            .collect::<String>();
        let dir = temp_dir_with_files(&[("big.rs", big_content.as_str())]);
        // Plenty of total budget, but only ~100 tokens per file.
        let config = PrefillConfig {
            budget_tokens: 100_000,
            per_file_tokens: 100,
            ..PrefillConfig::default()
        };
        let ctx = assemble_context(&[full("big.rs")], dir.path(), &config);
        assert_eq!(ctx.included_files.len(), 1);
        assert!(ctx.warnings.iter().any(|w| w.contains("truncated")));
        assert!(message_text(&ctx.messages[0]).contains("[... truncated:"));
    }

    #[test]
    fn test_context_prefill_total_budget() {
        let padded = |tag: &str| -> String {
            (0..200)
                .map(|i| format!("line_{tag}_{i}: some padding content here\n"))
                .collect()
        };
        let (content_a, content_b) = (padded("a"), padded("b"));
        let dir = temp_dir_with_files(&[
            ("a.rs", content_a.as_str()),
            ("b.rs", content_b.as_str()),
        ]);
        // Each file fits on its own, but both together exceed ~2500 tokens.
        let config = PrefillConfig {
            budget_tokens: 2500,
            per_file_tokens: 50_000,
            ..PrefillConfig::default()
        };
        let ctx = assemble_context(&[full("a.rs"), full("b.rs")], dir.path(), &config);
        assert_eq!(
            ctx.included_files.len(),
            1,
            "included: {:?}, warnings: {:?}",
            ctx.included_files,
            ctx.warnings
        );
        assert!(ctx
            .warnings
            .iter()
            .any(|w| w.contains("b.rs") && w.contains("budget")));
    }

    #[test]
    fn test_context_prefill_tail_mode() {
        let dir = temp_dir_with_files(&[("f.rs", FIVE_LINES)]);
        let specs = [spec("f.rs", FileMode::Tail(3))];
        let ctx = assemble_context(&specs, dir.path(), &PrefillConfig::default());
        let text = message_text(&ctx.messages[0]);
        for absent in ["line 1", "line 2"] {
            assert!(!text.contains(absent), "unexpected {absent:?}");
        }
        for present in ["line 3", "line 4", "line 5"] {
            assert!(text.contains(present), "missing {present:?}");
        }
    }

    #[test]
    fn test_context_prefill_range_mode() {
        let dir = temp_dir_with_files(&[("f.rs", FIVE_LINES)]);
        let specs = [spec("f.rs", FileMode::Range(2, 4))];
        let ctx = assemble_context(&specs, dir.path(), &PrefillConfig::default());
        let text = message_text(&ctx.messages[0]);
        for absent in ["line 1", "line 5"] {
            assert!(!text.contains(absent), "unexpected {absent:?}");
        }
        for present in ["line 2", "line 3", "line 4"] {
            assert!(text.contains(present), "missing {present:?}");
        }
    }

    #[test]
    fn test_context_prefill_empty_specs() {
        let dir = tempfile::tempdir().unwrap();
        let ctx = assemble_context(&[], dir.path(), &PrefillConfig::default());
        assert!(ctx.messages.is_empty());
        assert!(ctx.included_files.is_empty());
        assert_eq!(ctx.estimated_tokens, 0);
    }

    #[test]
    fn test_context_prefill_detect_paths() {
        let specs =
            detect_file_paths("Modify src/auth.rs and read crates/imp-llm/src/provider.rs for context.");
        let paths: Vec<&str> = specs.iter().map(|s| s.path.to_str().unwrap()).collect();
        assert!(paths.contains(&"src/auth.rs"));
        assert!(paths.contains(&"crates/imp-llm/src/provider.rs"));
    }

    #[test]
    fn test_context_prefill_detect_deduplicates() {
        let specs = detect_file_paths("Read src/foo.rs first, then modify src/foo.rs to add the function.");
        let target = std::path::Path::new("src/foo.rs");
        let occurrences = specs.iter().filter(|s| s.path == target).count();
        assert_eq!(occurrences, 1);
    }

    #[test]
    fn test_context_prefill_detect_ignores_non_paths() {
        let specs = detect_file_paths("Handle errors gracefully. The users table needs updating.");
        assert!(specs.is_empty(), "got: {specs:?}");
    }

    #[test]
    fn test_context_prefill_detect_tail_suffix() {
        let specs = detect_file_paths("Check patterns in tests/auth_test.rs:tail:50 for reference.");
        assert_eq!(specs.len(), 1);
        let only = &specs[0];
        assert_eq!(only.path, PathBuf::from("tests/auth_test.rs"));
        assert_eq!(only.mode, FileMode::Tail(50));
    }

    #[test]
    fn test_context_prefill_detect_range_suffix() {
        let specs = detect_file_paths("See src/lib.rs:10-50 for the relevant types.");
        assert_eq!(specs.len(), 1);
        let only = &specs[0];
        assert_eq!(only.path, PathBuf::from("src/lib.rs"));
        assert_eq!(only.mode, FileMode::Range(10, 50));
    }
}