Skip to main content

imp_core/
context_prefill.rs

1//! Context prefill assembly for mana dispatch.
2//!
3//! When `imp run <unit_id>` dispatches an agent, the unit description often
4//! references files the agent will need. Instead of making the agent spend
5//! turns reading those files, we assemble them into a cached prefix message
6//! that precedes the task prompt.
7//!
8//! The assembled context gets `cache_control` breakpoints so every subsequent
9//! turn in the agent's session gets `cache_read` on the file contents — no
10//! re-transmission cost.
11
12use std::collections::HashSet;
13use std::path::{Path, PathBuf};
14
15use imp_llm::message::{ContentBlock, Message, UserMessage};
16
17use crate::trust::{Provenance, TrustedContext};
18
19// ---------------------------------------------------------------------------
20// Types
21// ---------------------------------------------------------------------------
22
/// Selects which portion of a file is pulled into the prefill context.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum FileMode {
    /// The whole file, subject to the per-file budget.
    Full,
    /// Only the trailing N lines (the whole file when it has fewer than N lines).
    Tail(usize),
    /// An inclusive, 1-indexed `(start, end)` line range. An `end` past the
    /// last line is clamped to the file length.
    Range(usize, usize),
}
33
34/// A file to include in the prefill context.
35#[derive(Debug, Clone)]
36pub struct FileSpec {
37    pub path: PathBuf,
38    pub mode: FileMode,
39}
40
/// Tunables that bound how much file content is assembled.
#[derive(Clone, Debug)]
pub struct PrefillConfig {
    /// Ceiling on the estimated tokens of the whole assembled context.
    /// Defaults to 50_000.
    pub budget_tokens: usize,
    /// Ceiling on the estimated tokens contributed by any single file.
    /// Defaults to 10_000.
    pub per_file_tokens: usize,
    /// When set, each file block carries a compact trust/provenance label.
    pub annotate_trust: bool,
}
51
52impl Default for PrefillConfig {
53    fn default() -> Self {
54        Self {
55            budget_tokens: 50_000,
56            per_file_tokens: 10_000,
57            annotate_trust: false,
58        }
59    }
60}
61
62/// Result of context assembly.
63#[derive(Debug)]
64pub struct AssembledContext {
65    /// Messages to inject before the first prompt.
66    pub messages: Vec<Message>,
67    /// Files that were successfully included.
68    pub included_files: Vec<PathBuf>,
69    /// Warnings (missing files, truncations, budget exceeded).
70    pub warnings: Vec<String>,
71    /// Provenance for successfully included files. Prompt text is unchanged; metadata is internal.
72    pub provenance: Vec<TrustedContext<PathBuf>>,
73    /// Estimated token count of assembled context.
74    pub estimated_tokens: usize,
75}
76
77impl AssembledContext {
78    /// An empty context (no files, no messages).
79    pub fn empty() -> Self {
80        Self {
81            messages: Vec::new(),
82            included_files: Vec::new(),
83            warnings: Vec::new(),
84            provenance: Vec::new(),
85            estimated_tokens: 0,
86        }
87    }
88}
89
90// ---------------------------------------------------------------------------
91// Token estimation
92// ---------------------------------------------------------------------------
93
/// Coarse token estimate that treats every 4 bytes of text as one token
/// (rounded down).
fn estimate_tokens(text: &str) -> usize {
    const BYTES_PER_TOKEN: usize = 4;
    let byte_len = text.len();
    byte_len / BYTES_PER_TOKEN
}
98
/// Converts a token budget into a character budget (4 characters per token).
///
/// Saturates at `usize::MAX` instead of overflowing, so an "unlimited" budget
/// such as `usize::MAX` tokens stays unlimited rather than panicking (debug)
/// or wrapping to a smaller value (release).
fn chars_from_tokens(tokens: usize) -> usize {
    tokens.saturating_mul(4)
}
103
104// ---------------------------------------------------------------------------
105// File reading with mode application
106// ---------------------------------------------------------------------------
107
108/// Read a file and apply the extraction mode.
109fn read_file_with_mode(path: &Path, mode: &FileMode) -> Result<String, std::io::Error> {
110    let content = std::fs::read_to_string(path)?;
111    Ok(match mode {
112        FileMode::Full => content,
113        FileMode::Tail(n) => {
114            let lines: Vec<&str> = content.lines().collect();
115            let start = lines.len().saturating_sub(*n);
116            lines[start..].join("\n")
117        }
118        FileMode::Range(start, end) => {
119            let lines: Vec<&str> = content.lines().collect();
120            let s = start.saturating_sub(1); // 1-indexed → 0-indexed
121            let e = (*end).min(lines.len());
122            if s >= lines.len() {
123                String::new()
124            } else {
125                lines[s..e].join("\n")
126            }
127        }
128    })
129}
130
/// Cut `content` down to at most `max_chars` bytes of original text and append
/// a note saying how many lines survived.
///
/// Returns the (possibly shortened) text and whether truncation happened.
/// The cut is moved back to a UTF-8 character boundary and then, if the kept
/// prefix contains a newline, to the last newline so no partial line is shown.
/// The appended note is not counted against `max_chars`.
fn truncate_to_budget(content: &str, max_chars: usize) -> (String, bool) {
    if content.len() <= max_chars {
        return (content.to_owned(), false);
    }

    // `max_chars < content.len()` here, so it is a valid byte offset to probe.
    // Walk back until it lands on a character boundary (offset 0 always is).
    let mut cut = max_chars;
    while !content.is_char_boundary(cut) {
        cut -= 1;
    }

    // Prefer ending on a whole line when the prefix has any line break.
    let prefix = &content[..cut];
    let kept = prefix.rfind('\n').map_or(prefix, |nl| &prefix[..nl]);

    let truncated_lines = kept.lines().count();
    let total_lines = content.lines().count();
    let shortened =
        format!("{kept}\n[... truncated: showing {truncated_lines} of {total_lines} lines]");
    (shortened, true)
}
156
157// ---------------------------------------------------------------------------
158// Assembly
159// ---------------------------------------------------------------------------
160
161/// Assemble context from file specs, reading from disk and respecting budgets.
162///
163/// Produces a single `Message::User` containing all file contents in an XML
164/// structure. Returns an empty context if no files were successfully read.
165pub fn assemble_context(
166    specs: &[FileSpec],
167    cwd: &Path,
168    config: &PrefillConfig,
169) -> AssembledContext {
170    if specs.is_empty() {
171        return AssembledContext::empty();
172    }
173
174    let mut included_files = Vec::new();
175    let mut warnings = Vec::new();
176    let mut file_sections = Vec::new();
177    let mut total_chars: usize = 0;
178    let char_budget = chars_from_tokens(config.budget_tokens);
179    let per_file_char_budget = chars_from_tokens(config.per_file_tokens);
180
181    // Overhead for XML wrapper: <context>\n...\n</context>
182    let wrapper_overhead = "<context>\n</context>".len();
183    total_chars += wrapper_overhead;
184
185    for spec in specs {
186        let resolved = if spec.path.is_absolute() {
187            spec.path.clone()
188        } else {
189            cwd.join(&spec.path)
190        };
191
192        // Read the file
193        let content = match read_file_with_mode(&resolved, &spec.mode) {
194            Ok(c) => c,
195            Err(e) => {
196                warnings.push(format!("{}: {e}", spec.path.display()));
197                continue;
198            }
199        };
200
201        if content.is_empty() {
202            continue;
203        }
204
205        // Build the section XML
206        let mode_note = match &spec.mode {
207            FileMode::Full => String::new(),
208            FileMode::Tail(n) => format!(r#" note="last {n} lines""#),
209            FileMode::Range(s, e) => format!(r#" note="lines {s}-{e}""#),
210        };
211        let trust_attr = if config.annotate_trust {
212            r#" trust="workspace:file""#
213        } else {
214            ""
215        };
216        let header = format!(
217            r#"<file path="{}"{}{}>"#,
218            spec.path.display(),
219            trust_attr,
220            mode_note
221        );
222        let footer = "</file>";
223        let section_overhead = header.len() + footer.len() + 2; // newlines
224
225        // Check per-file budget
226        let (file_content, was_truncated) = truncate_to_budget(
227            &content,
228            per_file_char_budget.saturating_sub(section_overhead),
229        );
230        if was_truncated {
231            warnings.push(format!(
232                "{}: truncated to ~{} tokens (per-file budget)",
233                spec.path.display(),
234                config.per_file_tokens,
235            ));
236        }
237
238        let section = format!("{header}\n{file_content}\n{footer}");
239        let section_chars = section.len();
240
241        // Check total budget
242        if total_chars + section_chars > char_budget {
243            warnings.push(format!(
244                "{}: skipped (total budget of ~{} tokens exceeded)",
245                spec.path.display(),
246                config.budget_tokens,
247            ));
248            // Skip remaining files too
249            for remaining in specs.iter().skip(included_files.len() + warnings.len()) {
250                // Only warn for specs we haven't processed yet
251                if !included_files.contains(&remaining.path) {
252                    warnings.push(format!(
253                        "{}: skipped (total budget exceeded)",
254                        remaining.path.display(),
255                    ));
256                }
257            }
258            break;
259        }
260
261        total_chars += section_chars;
262        file_sections.push(section);
263        included_files.push(spec.path.clone());
264    }
265
266    if file_sections.is_empty() {
267        return AssembledContext {
268            messages: Vec::new(),
269            included_files,
270            warnings,
271            provenance: Vec::new(),
272            estimated_tokens: 0,
273        };
274    }
275
276    let xml = format!("<context>\n{}\n</context>", file_sections.join("\n"));
277    let estimated_tokens = estimate_tokens(&xml);
278
279    let message = Message::User(UserMessage {
280        content: vec![ContentBlock::Text { text: xml }],
281        timestamp: imp_llm::now(),
282    });
283
284    let provenance = included_files
285        .iter()
286        .cloned()
287        .map(|path| TrustedContext::new(path.clone(), Provenance::workspace_file(path)))
288        .collect();
289
290    AssembledContext {
291        messages: vec![message],
292        included_files,
293        warnings,
294        provenance,
295        estimated_tokens,
296    }
297}
298
299// ---------------------------------------------------------------------------
300// File path detection
301// ---------------------------------------------------------------------------
302
303/// Auto-detect file paths from a unit description string.
304///
305/// Scans for patterns that look like source file paths (e.g., `src/foo.rs`,
306/// `crates/bar/baz.ts`). Supports optional mode suffixes:
307/// - `path.rs:tail:50` → `Tail(50)`
308/// - `path.rs:10-50` → `Range(10, 50)`
309///
310/// Deduplicates by path (first occurrence wins).
311pub fn detect_file_paths(text: &str) -> Vec<FileSpec> {
312    // Match sequences that look like file paths with known extensions.
313    // The negative lookbehind-like logic is handled by checking the char
314    // before the match.
315    let extensions = [
316        "rs", "ts", "tsx", "py", "go", "js", "jsx", "toml", "yaml", "yml", "json", "md", "sh",
317        "sql", "zig", "c", "cpp", "h",
318    ];
319    let ext_pattern = extensions.join("|");
320    let pattern = format!(
321        r#"(?:^|[\s(`"'(])((?:[a-zA-Z_./])[a-zA-Z0-9_./-]*\.(?:{ext_pattern}))(?::([^\s)}}"'`]*))?"#,
322    );
323    let re = regex::Regex::new(&pattern).expect("valid regex");
324
325    let mut seen = HashSet::new();
326    let mut specs = Vec::new();
327
328    for cap in re.captures_iter(text) {
329        let path_str = cap.get(1).map(|m| m.as_str()).unwrap_or("");
330        if path_str.is_empty() {
331            continue;
332        }
333
334        let path = PathBuf::from(path_str);
335        if seen.contains(&path) {
336            continue;
337        }
338        seen.insert(path.clone());
339
340        let mode = cap
341            .get(2)
342            .map(|m| parse_mode_suffix(m.as_str()))
343            .unwrap_or(FileMode::Full);
344
345        specs.push(FileSpec { path, mode });
346    }
347
348    specs
349}
350
351/// Parse a mode suffix string into a FileMode.
352fn parse_mode_suffix(suffix: &str) -> FileMode {
353    // tail:N
354    if let Some(n_str) = suffix.strip_prefix("tail:") {
355        if let Ok(n) = n_str.parse::<usize>() {
356            return FileMode::Tail(n);
357        }
358    }
359    // N-M (line range)
360    if let Some(dash_pos) = suffix.find('-') {
361        let start_str = &suffix[..dash_pos];
362        let end_str = &suffix[dash_pos + 1..];
363        if let (Ok(start), Ok(end)) = (start_str.parse::<usize>(), end_str.parse::<usize>()) {
364            return FileMode::Range(start, end);
365        }
366    }
367    FileMode::Full
368}
369
370// ---------------------------------------------------------------------------
371// Tests
372// ---------------------------------------------------------------------------
373
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;

    // -- Helpers --

    /// Creates a temp directory holding the given `(relative path, contents)`
    /// files, creating parent directories as needed.
    fn temp_dir_with_files(files: &[(&str, &str)]) -> tempfile::TempDir {
        let dir = tempfile::tempdir().unwrap();
        for &(name, content) in files {
            let target = dir.path().join(name);
            if let Some(parent) = target.parent() {
                fs::create_dir_all(parent).unwrap();
            }
            fs::write(target, content).unwrap();
        }
        dir
    }

    /// Shorthand for building a `FileSpec` from a string path.
    fn spec(path: &str, mode: FileMode) -> FileSpec {
        FileSpec {
            path: PathBuf::from(path),
            mode,
        }
    }

    /// Runs assembly with the default configuration.
    fn assemble_default(specs: &[FileSpec], root: &Path) -> AssembledContext {
        assemble_context(specs, root, &PrefillConfig::default())
    }

    /// Concatenates the text blocks of a user message; other messages yield "".
    fn message_text(msg: &Message) -> String {
        if let Message::User(user) = msg {
            user.content
                .iter()
                .filter_map(|block| match block {
                    ContentBlock::Text { text } => Some(text.as_str()),
                    _ => None,
                })
                .collect()
        } else {
            String::new()
        }
    }

    // -- Assembly tests --

    #[test]
    fn prompt_context_trust_annotations_are_configurable() {
        let dir = temp_dir_with_files(&[("README.md", "hello")]);
        let specs = [spec("README.md", FileMode::Full)];

        let plain = assemble_default(&specs, dir.path());
        let plain_text = message_text(&plain.messages[0]);
        assert!(plain_text.contains(r#"<file path="README.md">"#));
        assert!(!plain_text.contains("trust="));

        let with_trust = PrefillConfig {
            annotate_trust: true,
            ..PrefillConfig::default()
        };
        let labelled = assemble_context(&specs, dir.path(), &with_trust);
        let labelled_text = message_text(&labelled.messages[0]);
        assert!(labelled_text.contains(r#"<file path="README.md" trust="workspace:file">"#));
    }

    #[test]
    fn test_context_prefill_records_workspace_file_provenance() {
        let dir = temp_dir_with_files(&[("README.md", "hello")]);
        let specs = [spec("README.md", FileMode::Full)];

        let assembled = assemble_default(&specs, dir.path());

        assert_eq!(assembled.provenance.len(), 1);
        let record = &assembled.provenance[0];
        assert_eq!(record.value, PathBuf::from("README.md"));
        assert_eq!(
            record.provenance.trust,
            crate::trust::TrustLabel::ProjectTrusted
        );
        assert!(matches!(
            record.provenance.source,
            crate::trust::ProvenanceSource::WorkspaceFile { .. }
        ));
    }

    #[test]
    fn test_context_prefill_assembles_single_file() {
        let source = "fn main() {\n    println!(\"hello\");\n}";
        let dir = temp_dir_with_files(&[("src/main.rs", source)]);
        let specs = [spec("src/main.rs", FileMode::Full)];

        let ctx = assemble_default(&specs, dir.path());
        assert_eq!(ctx.included_files.len(), 1);
        assert!(ctx.warnings.is_empty());
        assert!(!ctx.messages.is_empty());

        let text = message_text(&ctx.messages[0]);
        for expected in [
            "<context>",
            r#"<file path="src/main.rs">"#,
            "fn main()",
            "</file>",
            "</context>",
        ] {
            assert!(text.contains(expected));
        }
    }

    #[test]
    fn test_context_prefill_multiple_files() {
        let dir = temp_dir_with_files(&[("src/a.rs", "struct A;"), ("src/b.rs", "struct B;")]);
        let specs = [
            spec("src/a.rs", FileMode::Full),
            spec("src/b.rs", FileMode::Full),
        ];

        let ctx = assemble_default(&specs, dir.path());
        assert_eq!(ctx.included_files.len(), 2);

        let text = message_text(&ctx.messages[0]);
        assert!(text.contains("struct A"));
        assert!(text.contains("struct B"));
    }

    #[test]
    fn test_context_prefill_missing_file_warning() {
        let dir = temp_dir_with_files(&[("src/exists.rs", "exists")]);
        let specs = [
            spec("src/missing.rs", FileMode::Full),
            spec("src/exists.rs", FileMode::Full),
        ];

        let ctx = assemble_default(&specs, dir.path());
        assert_eq!(ctx.included_files.len(), 1);
        assert_eq!(ctx.included_files[0], PathBuf::from("src/exists.rs"));
        assert!(ctx.warnings.iter().any(|w| w.contains("missing.rs")));
    }

    #[test]
    fn test_context_prefill_per_file_budget() {
        // 200 lines is far more than the 100-token (~400 char) per-file cap.
        let mut big_content = String::new();
        for i in 0..200 {
            big_content.push_str(&format!("line {i}: some content here\n"));
        }
        let dir = temp_dir_with_files(&[("big.rs", &big_content)]);
        let specs = [spec("big.rs", FileMode::Full)];
        let config = PrefillConfig {
            per_file_tokens: 100, // ~400 chars — file will be truncated
            budget_tokens: 100_000,
            ..PrefillConfig::default()
        };

        let ctx = assemble_context(&specs, dir.path(), &config);
        assert_eq!(ctx.included_files.len(), 1);
        assert!(ctx.warnings.iter().any(|w| w.contains("truncated")));

        let text = message_text(&ctx.messages[0]);
        assert!(text.contains("[... truncated:"));
    }

    #[test]
    fn test_context_prefill_total_budget() {
        // Each file is ~7_500 chars; a 2_500-token (~10_000 char) total budget
        // therefore has room for the first file but not for both.
        let padded = |tag: &str| -> String {
            (0..200)
                .map(|i| format!("line_{tag}_{i}: some padding content here\n"))
                .collect()
        };
        let content_a = padded("a");
        let content_b = padded("b");
        let dir = temp_dir_with_files(&[("a.rs", &content_a), ("b.rs", &content_b)]);
        let specs = [spec("a.rs", FileMode::Full), spec("b.rs", FileMode::Full)];
        let config = PrefillConfig {
            budget_tokens: 2500,
            per_file_tokens: 50_000,
            ..PrefillConfig::default()
        };

        let ctx = assemble_context(&specs, dir.path(), &config);

        // First file should be included, second skipped
        assert_eq!(
            ctx.included_files.len(),
            1,
            "included: {:?}, warnings: {:?}",
            ctx.included_files,
            ctx.warnings
        );
        let b_skipped = ctx
            .warnings
            .iter()
            .any(|w| w.contains("b.rs") && w.contains("budget"));
        assert!(b_skipped);
    }

    #[test]
    fn test_context_prefill_tail_mode() {
        let dir = temp_dir_with_files(&[("f.rs", "line 1\nline 2\nline 3\nline 4\nline 5\n")]);
        let specs = [spec("f.rs", FileMode::Tail(3))];

        let ctx = assemble_default(&specs, dir.path());
        let text = message_text(&ctx.messages[0]);

        assert!(!text.contains("line 1"));
        assert!(!text.contains("line 2"));
        assert!(text.contains("line 3"));
        assert!(text.contains("line 4"));
        assert!(text.contains("line 5"));
    }

    #[test]
    fn test_context_prefill_range_mode() {
        let dir = temp_dir_with_files(&[("f.rs", "line 1\nline 2\nline 3\nline 4\nline 5\n")]);
        let specs = [spec("f.rs", FileMode::Range(2, 4))];

        let ctx = assemble_default(&specs, dir.path());
        let text = message_text(&ctx.messages[0]);

        assert!(!text.contains("line 1"));
        assert!(text.contains("line 2"));
        assert!(text.contains("line 3"));
        assert!(text.contains("line 4"));
        assert!(!text.contains("line 5"));
    }

    #[test]
    fn test_context_prefill_empty_specs() {
        let dir = tempfile::tempdir().unwrap();
        let ctx = assemble_default(&[], dir.path());
        assert!(ctx.messages.is_empty());
        assert!(ctx.included_files.is_empty());
        assert_eq!(ctx.estimated_tokens, 0);
    }

    // -- Detection tests --

    #[test]
    fn test_context_prefill_detect_paths() {
        let specs = detect_file_paths(
            "Modify src/auth.rs and read crates/imp-llm/src/provider.rs for context.",
        );
        let paths: Vec<&str> = specs
            .iter()
            .map(|found| found.path.to_str().unwrap())
            .collect();
        assert!(paths.contains(&"src/auth.rs"));
        assert!(paths.contains(&"crates/imp-llm/src/provider.rs"));
    }

    #[test]
    fn test_context_prefill_detect_deduplicates() {
        let specs =
            detect_file_paths("Read src/foo.rs first, then modify src/foo.rs to add the function.");
        let foo_count = specs
            .iter()
            .filter(|found| found.path == Path::new("src/foo.rs"))
            .count();
        assert_eq!(foo_count, 1);
    }

    #[test]
    fn test_context_prefill_detect_ignores_non_paths() {
        // Plain words such as "errors" and "users" have no path-like shape or
        // known extension, so nothing should be detected.
        let specs = detect_file_paths("Handle errors gracefully. The users table needs updating.");
        assert!(specs.is_empty(), "got: {:?}", specs);
    }

    #[test]
    fn test_context_prefill_detect_tail_suffix() {
        let specs = detect_file_paths("Check patterns in tests/auth_test.rs:tail:50 for reference.");
        assert_eq!(specs.len(), 1);
        let found = &specs[0];
        assert_eq!(found.path, PathBuf::from("tests/auth_test.rs"));
        assert_eq!(found.mode, FileMode::Tail(50));
    }

    #[test]
    fn test_context_prefill_detect_range_suffix() {
        let specs = detect_file_paths("See src/lib.rs:10-50 for the relevant types.");
        assert_eq!(specs.len(), 1);
        let found = &specs[0];
        assert_eq!(found.path, PathBuf::from("src/lib.rs"));
        assert_eq!(found.mode, FileMode::Range(10, 50));
    }
}