Skip to main content

imp_core/
context_prefill.rs

//! Context prefill assembly for mana dispatch.
//!
//! When `imp run <unit_id>` dispatches an agent, the unit description often
//! references files the agent will need. Instead of making the agent spend
//! turns reading those files, we assemble them into a prefix message that
//! precedes the task prompt.
//!
//! The prefix is meant to be cached: with `cache_control` breakpoints on it,
//! every subsequent turn in the agent's session gets `cache_read` on the file
//! contents instead of re-transmitting them. This module only builds the
//! message; it does not attach `cache_control` itself.
11
12use std::collections::HashSet;
13use std::path::{Path, PathBuf};
14
15use imp_llm::message::{ContentBlock, Message, UserMessage};
16
17// ---------------------------------------------------------------------------
18// Types
19// ---------------------------------------------------------------------------
20
/// Which slice of a file to extract for the prefill.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum FileMode {
    /// The whole file, limited only by the per-file budget.
    Full,
    /// The final `N` lines of the file.
    Tail(usize),
    /// Lines `start..=end`, counted from 1.
    Range(usize, usize),
}
31
32/// A file to include in the prefill context.
33#[derive(Debug, Clone)]
34pub struct FileSpec {
35    pub path: PathBuf,
36    pub mode: FileMode,
37}
38
/// Token budgets that bound how much file content [`assemble_context`] emits.
///
/// Both limits are approximate: tokens are converted to a byte budget at four
/// bytes per token.
#[derive(Clone, Debug)]
pub struct PrefillConfig {
    /// Upper bound on estimated tokens across all assembled files (default 50,000).
    pub budget_tokens: usize,
    /// Upper bound on estimated tokens for any single file (default 10,000).
    pub per_file_tokens: usize,
}

impl PrefillConfig {
    const DEFAULT_BUDGET_TOKENS: usize = 50_000;
    const DEFAULT_PER_FILE_TOKENS: usize = 10_000;
}

impl Default for PrefillConfig {
    fn default() -> Self {
        PrefillConfig {
            budget_tokens: Self::DEFAULT_BUDGET_TOKENS,
            per_file_tokens: Self::DEFAULT_PER_FILE_TOKENS,
        }
    }
}
56
57/// Result of context assembly.
58#[derive(Debug)]
59pub struct AssembledContext {
60    /// Messages to inject before the first prompt.
61    pub messages: Vec<Message>,
62    /// Files that were successfully included.
63    pub included_files: Vec<PathBuf>,
64    /// Warnings (missing files, truncations, budget exceeded).
65    pub warnings: Vec<String>,
66    /// Estimated token count of assembled context.
67    pub estimated_tokens: usize,
68}
69
70impl AssembledContext {
71    /// An empty context (no files, no messages).
72    pub fn empty() -> Self {
73        Self {
74            messages: Vec::new(),
75            included_files: Vec::new(),
76            warnings: Vec::new(),
77            estimated_tokens: 0,
78        }
79    }
80}
81
82// ---------------------------------------------------------------------------
83// Token estimation
84// ---------------------------------------------------------------------------
85
/// Cheap token estimate used for budgeting: one token per four bytes of
/// UTF-8 text (`str::len`, not char count), rounded down.
fn estimate_tokens(text: &str) -> usize {
    const BYTES_PER_TOKEN: usize = 4;
    let byte_len = text.len();
    byte_len / BYTES_PER_TOKEN
}
90
/// Convert a token budget into a byte budget, using the same four-bytes-per-token
/// ratio as [`estimate_tokens`].
///
/// Saturates at `usize::MAX`. A very large budget (e.g. `usize::MAX` used to mean
/// "unlimited") therefore stays huge instead of overflowing: that would panic in
/// debug builds and wrap to an arbitrary, possibly tiny, budget in release.
fn chars_from_tokens(tokens: usize) -> usize {
    tokens.saturating_mul(4)
}
95
96// ---------------------------------------------------------------------------
97// File reading with mode application
98// ---------------------------------------------------------------------------
99
100/// Read a file and apply the extraction mode.
101fn read_file_with_mode(path: &Path, mode: &FileMode) -> Result<String, std::io::Error> {
102    let content = std::fs::read_to_string(path)?;
103    Ok(match mode {
104        FileMode::Full => content,
105        FileMode::Tail(n) => {
106            let lines: Vec<&str> = content.lines().collect();
107            let start = lines.len().saturating_sub(*n);
108            lines[start..].join("\n")
109        }
110        FileMode::Range(start, end) => {
111            let lines: Vec<&str> = content.lines().collect();
112            let s = start.saturating_sub(1); // 1-indexed → 0-indexed
113            let e = (*end).min(lines.len());
114            if s >= lines.len() {
115                String::new()
116            } else {
117                lines[s..e].join("\n")
118            }
119        }
120    })
121}
122
/// Truncate `content` to at most `max_chars` bytes, preferring a line boundary,
/// and append a `[... truncated: showing X of Y lines]` note.
///
/// Returns `(text, was_truncated)`. If `content` already fits, it is returned
/// unchanged with `false`. Otherwise the kept prefix is the longest run of whole
/// lines within the budget. If even the first line does not fit, the cut falls
/// mid-line on the nearest UTF-8 character boundary at or before `max_chars`.
///
/// `max_chars` is compared against `str::len`, i.e. UTF-8 bytes. The appended
/// note is not counted against it, so the result can exceed `max_chars` by the
/// note's length.
fn truncate_to_budget(content: &str, max_chars: usize) -> (String, bool) {
    if content.len() <= max_chars {
        return (content.to_string(), false);
    }
    let total_lines = content.lines().count();

    // Largest char boundary not past the budget. Here max_chars < len, and a
    // UTF-8 char spans at most 4 bytes, so this steps back at most 3 times
    // (index 0 is always a boundary, so the loop terminates).
    let mut cut = max_chars;
    while !content.is_char_boundary(cut) {
        cut -= 1;
    }

    // If the byte right at the cut is a newline, the prefix already ends with a
    // complete line; otherwise back up to the last newline inside the prefix.
    // With no newline at all, keep the mid-line cut.
    let end = if content[cut..].starts_with('\n') {
        cut
    } else {
        content[..cut].rfind('\n').unwrap_or(cut)
    };

    let kept = &content[..end];
    let truncated_lines = kept.lines().count();
    let result =
        format!("{kept}\n[... truncated: showing {truncated_lines} of {total_lines} lines]");
    (result, true)
}
148
149// ---------------------------------------------------------------------------
150// Assembly
151// ---------------------------------------------------------------------------
152
153/// Assemble context from file specs, reading from disk and respecting budgets.
154///
155/// Produces a single `Message::User` containing all file contents in an XML
156/// structure. Returns an empty context if no files were successfully read.
157pub fn assemble_context(
158    specs: &[FileSpec],
159    cwd: &Path,
160    config: &PrefillConfig,
161) -> AssembledContext {
162    if specs.is_empty() {
163        return AssembledContext::empty();
164    }
165
166    let mut included_files = Vec::new();
167    let mut warnings = Vec::new();
168    let mut file_sections = Vec::new();
169    let mut total_chars: usize = 0;
170    let char_budget = chars_from_tokens(config.budget_tokens);
171    let per_file_char_budget = chars_from_tokens(config.per_file_tokens);
172
173    // Overhead for XML wrapper: <context>\n...\n</context>
174    let wrapper_overhead = "<context>\n</context>".len();
175    total_chars += wrapper_overhead;
176
177    for spec in specs {
178        let resolved = if spec.path.is_absolute() {
179            spec.path.clone()
180        } else {
181            cwd.join(&spec.path)
182        };
183
184        // Read the file
185        let content = match read_file_with_mode(&resolved, &spec.mode) {
186            Ok(c) => c,
187            Err(e) => {
188                warnings.push(format!("{}: {e}", spec.path.display()));
189                continue;
190            }
191        };
192
193        if content.is_empty() {
194            continue;
195        }
196
197        // Build the section XML
198        let mode_note = match &spec.mode {
199            FileMode::Full => String::new(),
200            FileMode::Tail(n) => format!(r#" note="last {n} lines""#),
201            FileMode::Range(s, e) => format!(r#" note="lines {s}-{e}""#),
202        };
203        let header = format!(r#"<file path="{}"{}>"#, spec.path.display(), mode_note);
204        let footer = "</file>";
205        let section_overhead = header.len() + footer.len() + 2; // newlines
206
207        // Check per-file budget
208        let (file_content, was_truncated) = truncate_to_budget(
209            &content,
210            per_file_char_budget.saturating_sub(section_overhead),
211        );
212        if was_truncated {
213            warnings.push(format!(
214                "{}: truncated to ~{} tokens (per-file budget)",
215                spec.path.display(),
216                config.per_file_tokens,
217            ));
218        }
219
220        let section = format!("{header}\n{file_content}\n{footer}");
221        let section_chars = section.len();
222
223        // Check total budget
224        if total_chars + section_chars > char_budget {
225            warnings.push(format!(
226                "{}: skipped (total budget of ~{} tokens exceeded)",
227                spec.path.display(),
228                config.budget_tokens,
229            ));
230            // Skip remaining files too
231            for remaining in specs.iter().skip(included_files.len() + warnings.len()) {
232                // Only warn for specs we haven't processed yet
233                if !included_files.contains(&remaining.path) {
234                    warnings.push(format!(
235                        "{}: skipped (total budget exceeded)",
236                        remaining.path.display(),
237                    ));
238                }
239            }
240            break;
241        }
242
243        total_chars += section_chars;
244        file_sections.push(section);
245        included_files.push(spec.path.clone());
246    }
247
248    if file_sections.is_empty() {
249        return AssembledContext {
250            messages: Vec::new(),
251            included_files,
252            warnings,
253            estimated_tokens: 0,
254        };
255    }
256
257    let xml = format!("<context>\n{}\n</context>", file_sections.join("\n"));
258    let estimated_tokens = estimate_tokens(&xml);
259
260    let message = Message::User(UserMessage {
261        content: vec![ContentBlock::Text { text: xml }],
262        timestamp: imp_llm::now(),
263    });
264
265    AssembledContext {
266        messages: vec![message],
267        included_files,
268        warnings,
269        estimated_tokens,
270    }
271}
272
273// ---------------------------------------------------------------------------
274// File path detection
275// ---------------------------------------------------------------------------
276
277/// Auto-detect file paths from a unit description string.
278///
279/// Scans for patterns that look like source file paths (e.g., `src/foo.rs`,
280/// `crates/bar/baz.ts`). Supports optional mode suffixes:
281/// - `path.rs:tail:50` → `Tail(50)`
282/// - `path.rs:10-50` → `Range(10, 50)`
283///
284/// Deduplicates by path (first occurrence wins).
285pub fn detect_file_paths(text: &str) -> Vec<FileSpec> {
286    // Match sequences that look like file paths with known extensions.
287    // The negative lookbehind-like logic is handled by checking the char
288    // before the match.
289    let extensions = [
290        "rs", "ts", "tsx", "py", "go", "js", "jsx", "toml", "yaml", "yml", "json", "md", "sh",
291        "sql", "zig", "c", "cpp", "h",
292    ];
293    let ext_pattern = extensions.join("|");
294    let pattern = format!(
295        r#"(?:^|[\s(`"'(])((?:[a-zA-Z_./])[a-zA-Z0-9_./-]*\.(?:{ext_pattern}))(?::([^\s)}}"'`]*))?"#,
296    );
297    let re = regex::Regex::new(&pattern).expect("valid regex");
298
299    let mut seen = HashSet::new();
300    let mut specs = Vec::new();
301
302    for cap in re.captures_iter(text) {
303        let path_str = cap.get(1).map(|m| m.as_str()).unwrap_or("");
304        if path_str.is_empty() {
305            continue;
306        }
307
308        let path = PathBuf::from(path_str);
309        if seen.contains(&path) {
310            continue;
311        }
312        seen.insert(path.clone());
313
314        let mode = cap
315            .get(2)
316            .map(|m| parse_mode_suffix(m.as_str()))
317            .unwrap_or(FileMode::Full);
318
319        specs.push(FileSpec { path, mode });
320    }
321
322    specs
323}
324
325/// Parse a mode suffix string into a FileMode.
326fn parse_mode_suffix(suffix: &str) -> FileMode {
327    // tail:N
328    if let Some(n_str) = suffix.strip_prefix("tail:") {
329        if let Ok(n) = n_str.parse::<usize>() {
330            return FileMode::Tail(n);
331        }
332    }
333    // N-M (line range)
334    if let Some(dash_pos) = suffix.find('-') {
335        let start_str = &suffix[..dash_pos];
336        let end_str = &suffix[dash_pos + 1..];
337        if let (Ok(start), Ok(end)) = (start_str.parse::<usize>(), end_str.parse::<usize>()) {
338            return FileMode::Range(start, end);
339        }
340    }
341    FileMode::Full
342}
343
344// ---------------------------------------------------------------------------
345// Tests
346// ---------------------------------------------------------------------------
347
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;

    // -- Fixtures & helpers --

    /// Create a temp directory holding each `(relative_path, contents)` pair,
    /// creating parent directories as needed.
    fn temp_dir_with_files(files: &[(&str, &str)]) -> tempfile::TempDir {
        let root = tempfile::tempdir().unwrap();
        files.iter().for_each(|&(rel, body)| {
            let target = root.path().join(rel);
            if let Some(parent_dir) = target.parent() {
                fs::create_dir_all(parent_dir).unwrap();
            }
            fs::write(&target, body).unwrap();
        });
        root
    }

    /// A [`FileSpec`] for `path` with the given extraction mode.
    fn spec(path: &str, mode: FileMode) -> FileSpec {
        FileSpec {
            path: PathBuf::from(path),
            mode,
        }
    }

    /// A whole-file [`FileSpec`] for `path`.
    fn full(path: &str) -> FileSpec {
        spec(path, FileMode::Full)
    }

    /// `count` lines of the form `"{prefix}{i}: {filler}\n"` for `i` in `0..count`.
    fn numbered_lines(prefix: &str, filler: &str, count: usize) -> String {
        (0..count).map(|i| format!("{prefix}{i}: {filler}\n")).collect()
    }

    /// Concatenate the text blocks of a user message; any other message yields "".
    fn message_text(msg: &Message) -> String {
        if let Message::User(user) = msg {
            user.content
                .iter()
                .filter_map(|block| match block {
                    ContentBlock::Text { text } => Some(text.as_str()),
                    _ => None,
                })
                .collect()
        } else {
            String::new()
        }
    }

    /// Five short lines shared by the tail/range extraction tests.
    const FIVE_LINES: &str = "line 1\nline 2\nline 3\nline 4\nline 5\n";

    /// Assemble `f.rs` (containing `FIVE_LINES`) with `mode` and report, for
    /// `line 1` through `line 5`, whether each appears in the rendered text.
    fn visible_lines(mode: FileMode) -> Vec<bool> {
        let tmp = temp_dir_with_files(&[("f.rs", FIVE_LINES)]);
        let ctx = assemble_context(&[spec("f.rs", mode)], tmp.path(), &PrefillConfig::default());
        let rendered = message_text(&ctx.messages[0]);
        (1..=5)
            .map(|n| rendered.contains(&format!("line {n}")))
            .collect()
    }

    /// Paths detected in `text`, in order, rendered as strings.
    fn detected_paths(text: &str) -> Vec<String> {
        detect_file_paths(text)
            .iter()
            .map(|s| s.path.display().to_string())
            .collect()
    }

    // -- Assembly tests --

    #[test]
    fn test_context_prefill_assembles_single_file() {
        let tmp =
            temp_dir_with_files(&[("src/main.rs", "fn main() {\n    println!(\"hello\");\n}")]);
        let ctx = assemble_context(&[full("src/main.rs")], tmp.path(), &PrefillConfig::default());

        assert_eq!(ctx.included_files.len(), 1);
        assert!(ctx.warnings.is_empty());
        assert!(!ctx.messages.is_empty());

        let rendered = message_text(&ctx.messages[0]);
        for needle in [
            "<context>",
            r#"<file path="src/main.rs">"#,
            "fn main()",
            "</file>",
            "</context>",
        ] {
            assert!(rendered.contains(needle), "missing {needle:?} in {rendered}");
        }
    }

    #[test]
    fn test_context_prefill_multiple_files() {
        let tmp = temp_dir_with_files(&[("src/a.rs", "struct A;"), ("src/b.rs", "struct B;")]);
        let specs = [full("src/a.rs"), full("src/b.rs")];
        let ctx = assemble_context(&specs, tmp.path(), &PrefillConfig::default());

        assert_eq!(ctx.included_files.len(), 2);
        let rendered = message_text(&ctx.messages[0]);
        assert!(rendered.contains("struct A"));
        assert!(rendered.contains("struct B"));
    }

    #[test]
    fn test_context_prefill_missing_file_warning() {
        let tmp = temp_dir_with_files(&[("src/exists.rs", "exists")]);
        let specs = [full("src/missing.rs"), full("src/exists.rs")];
        let ctx = assemble_context(&specs, tmp.path(), &PrefillConfig::default());

        assert_eq!(ctx.included_files, vec![PathBuf::from("src/exists.rs")]);
        assert!(ctx.warnings.iter().any(|w| w.contains("missing.rs")));
    }

    #[test]
    fn test_context_prefill_per_file_budget() {
        // 200 lines of ~25 bytes: far more than a 100-token (~400 byte) cap.
        let big = numbered_lines("line ", "some content here", 200);
        let tmp = temp_dir_with_files(&[("big.rs", big.as_str())]);
        let config = PrefillConfig {
            budget_tokens: 100_000,
            per_file_tokens: 100,
        };
        let ctx = assemble_context(&[full("big.rs")], tmp.path(), &config);

        assert_eq!(ctx.included_files.len(), 1);
        assert!(ctx.warnings.iter().any(|w| w.contains("truncated")));
        assert!(message_text(&ctx.messages[0]).contains("[... truncated:"));
    }

    #[test]
    fn test_context_prefill_total_budget() {
        // Each file is ~7.5k bytes; a 2,500-token (~10k byte) budget fits only the first.
        let body_a = numbered_lines("line_a_", "some padding content here", 200);
        let body_b = numbered_lines("line_b_", "some padding content here", 200);
        let tmp = temp_dir_with_files(&[("a.rs", body_a.as_str()), ("b.rs", body_b.as_str())]);
        let config = PrefillConfig {
            budget_tokens: 2500,
            per_file_tokens: 50_000,
        };
        let ctx = assemble_context(&[full("a.rs"), full("b.rs")], tmp.path(), &config);

        assert_eq!(
            ctx.included_files.len(),
            1,
            "included: {:?}, warnings: {:?}",
            ctx.included_files,
            ctx.warnings
        );
        let b_skipped = ctx
            .warnings
            .iter()
            .any(|w| w.contains("b.rs") && w.contains("budget"));
        assert!(b_skipped);
    }

    #[test]
    fn test_context_prefill_tail_mode() {
        assert_eq!(
            visible_lines(FileMode::Tail(3)),
            [false, false, true, true, true]
        );
    }

    #[test]
    fn test_context_prefill_range_mode() {
        assert_eq!(
            visible_lines(FileMode::Range(2, 4)),
            [false, true, true, true, false]
        );
    }

    #[test]
    fn test_context_prefill_empty_specs() {
        let tmp = tempfile::tempdir().unwrap();
        let ctx = assemble_context(&[], tmp.path(), &PrefillConfig::default());
        assert!(ctx.messages.is_empty());
        assert!(ctx.included_files.is_empty());
        assert_eq!(ctx.estimated_tokens, 0);
    }

    // -- Detection tests --

    #[test]
    fn test_context_prefill_detect_paths() {
        let found =
            detected_paths("Modify src/auth.rs and read crates/imp-llm/src/provider.rs for context.");
        assert!(found.iter().any(|p| p.as_str() == "src/auth.rs"));
        assert!(found
            .iter()
            .any(|p| p.as_str() == "crates/imp-llm/src/provider.rs"));
    }

    #[test]
    fn test_context_prefill_detect_deduplicates() {
        let found =
            detected_paths("Read src/foo.rs first, then modify src/foo.rs to add the function.");
        let occurrences = found.iter().filter(|p| p.as_str() == "src/foo.rs").count();
        assert_eq!(occurrences, 1);
    }

    #[test]
    fn test_context_prefill_detect_ignores_non_paths() {
        let specs = detect_file_paths("Handle errors gracefully. The users table needs updating.");
        // Plain prose words lack path-like structure plus a known extension.
        assert!(specs.is_empty(), "got: {:?}", specs);
    }

    #[test]
    fn test_context_prefill_detect_tail_suffix() {
        let specs = detect_file_paths("Check patterns in tests/auth_test.rs:tail:50 for reference.");
        assert_eq!(specs.len(), 1);
        let only = &specs[0];
        assert_eq!(only.path, PathBuf::from("tests/auth_test.rs"));
        assert_eq!(only.mode, FileMode::Tail(50));
    }

    #[test]
    fn test_context_prefill_detect_range_suffix() {
        let specs = detect_file_paths("See src/lib.rs:10-50 for the relevant types.");
        assert_eq!(specs.len(), 1);
        let only = &specs[0];
        assert_eq!(only.path, PathBuf::from("src/lib.rs"));
        assert_eq!(only.mode, FileMode::Range(10, 50));
    }
}