Skip to main content

bn/
ctx_assembler.rs

1use regex::Regex;
2use std::collections::HashSet;
3use std::fs;
4use std::io;
5use std::path::{Component, Path};
6use std::sync::LazyLock;
7
8// Compiled once, reused across all calls
9static PATH_REGEX: LazyLock<Regex> = LazyLock::new(|| {
10    // Match file paths with supported extensions (tsx and yml added)
11    Regex::new(r"([a-zA-Z0-9_.][a-zA-Z0-9_./\-]*\.(rs|tsx?|py|md|json|toml|ya?ml|sh|go|java))\b")
12        .expect("Invalid regex pattern")
13});
14
15/// Extracts file paths from a bean description using regex pattern matching.
16///
17/// Matches relative file paths with the following extensions:
18/// .rs, .ts, .py, .md, .json, .toml, .yaml, .sh, .go, .java
19///
20/// Examples:
21/// - "Modify src/main.rs" → ["src/main.rs"]
22/// - "See src/foo.rs and tests/bar.rs" → ["src/foo.rs", "tests/bar.rs"]
23/// - "File: src/main.rs." → ["src/main.rs"]
24///
25/// # Arguments
26/// * `description` - The description text to search for file paths
27///
28/// # Returns
29/// A Vec of deduplicated file paths in order of appearance
30pub fn extract_paths(description: &str) -> Vec<String> {
31    let mut result = Vec::new();
32    let mut seen = HashSet::new();
33
34    for cap in PATH_REGEX.captures_iter(description) {
35        if let Some(path) = cap.get(1) {
36            let path_str = path.as_str();
37            let path_start = path.start();
38
39            // Filter out absolute paths: if preceded directly by /
40            // Use byte access (O(1)) since '/' is ASCII
41            if path_start > 0 && description.as_bytes()[path_start - 1] == b'/' {
42                continue;
43            }
44
45            // Filter out URLs (check if preceded by :// in the description)
46            let before = &description[path_start.saturating_sub(3)..path_start];
47            if before.ends_with("://") {
48                continue;
49            }
50
51            // Reject path traversal: any path containing ".." components
52            // could escape the project directory
53            if Path::new(path_str)
54                .components()
55                .any(|c| matches!(c, Component::ParentDir))
56            {
57                continue;
58            }
59
60            // Deduplicate and add to result
61            if seen.insert(path_str.to_string()) {
62                result.push(path_str.to_string());
63            }
64        }
65    }
66
67    result
68}
69
/// Maximum file size to read (1 MB). Files referenced in bean descriptions
/// are embedded into LLM prompts, so reading very large files is wasteful
/// and risks unbounded memory usage.
const MAX_FILE_SIZE: u64 = 1_024 * 1_024;

/// Reads a file from disk and returns its contents as a string.
///
/// # Arguments
/// * `path` - The file path to read
///
/// # Returns
/// * `Ok(String)` - The file contents
/// * `Err` - If the file doesn't exist, is too large, is binary, or is not valid UTF-8
///
/// # Behavior
/// - Rejects files larger than 1 MB, both by reported size and by capping the
///   number of bytes actually read
/// - Reads raw bytes first, then checks for binary content (null bytes);
///   a warning is written to stderr when a binary file is skipped
/// - Converts to UTF-8 only after binary check passes
pub fn read_file(path: &Path) -> io::Result<String> {
    // Cheap early rejection based on the size the filesystem reports.
    let metadata = fs::metadata(path)?;
    if metadata.len() > MAX_FILE_SIZE {
        return Err(io::Error::new(
            io::ErrorKind::InvalidData,
            format!(
                "File too large ({} bytes, max {})",
                metadata.len(),
                MAX_FILE_SIZE
            ),
        ));
    }

    // The reported size is only a hint: the file may grow between the check
    // and the read, and some special files report a length of 0. Cap the read
    // itself so that at most one byte beyond the limit is ever buffered.
    let file = fs::File::open(path)?;
    let mut limited = io::Read::take(file, MAX_FILE_SIZE + 1);
    let mut bytes = Vec::new();
    io::Read::read_to_end(&mut limited, &mut bytes)?;

    let exceeds_limit = u64::try_from(bytes.len()).map_or(true, |len| len > MAX_FILE_SIZE);
    if exceeds_limit {
        return Err(io::Error::new(
            io::ErrorKind::InvalidData,
            format!("File too large (more than {} bytes)", MAX_FILE_SIZE),
        ));
    }

    // Raw bytes are inspected before decoding so binary files are reported as
    // binary rather than as a UTF-8 failure.
    if bytes.contains(&0) {
        eprintln!("Warning: Skipping binary file: {}", path.display());
        return Err(io::Error::new(
            io::ErrorKind::InvalidData,
            "File appears to be binary (contains null bytes)",
        ));
    }

    String::from_utf8(bytes)
        .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "File is not valid UTF-8"))
}
115
/// Detects the code-fence language tag from a file's extension.
///
/// Supports: rs, ts, tsx, py, go, java, json, yaml, yml, toml, sh, md.
/// Anything else — including a name with no extension at all, such as
/// `Makefile` or a file literally called `py` — maps to "text".
fn detect_language(path: &str) -> &'static str {
    // `Path::extension` only looks at the final component, so a dot in a
    // directory name is never mistaken for an extension separator.
    match Path::new(path).extension().and_then(|ext| ext.to_str()) {
        Some("rs") => "rust",
        Some("ts" | "tsx") => "typescript",
        Some("py") => "python",
        Some("go") => "go",
        Some("java") => "java",
        Some("json") => "json",
        Some("yaml" | "yml") => "yaml",
        Some("toml") => "toml",
        Some("sh") => "sh",
        Some("md") => "markdown",
        _ => "text",
    }
}

/// Formats a file's content as a markdown code block.
///
/// # Arguments
/// * `path` - The file path (used for display and language detection)
/// * `content` - The file contents
///
/// # Returns
/// A markdown-formatted string with the file header and code fence
///
/// # Format
/// ````text
/// ## File: {path}
/// ```{lang}
/// {content}
/// ```
/// ````
///
/// The fence is three backticks unless `content` itself contains a run of
/// three or more backticks; in that case the fence is made one backtick longer
/// than the longest run so the embedded text cannot close the block early.
pub fn format_file_block(path: &str, content: &str) -> String {
    let language = detect_language(path);

    // Splitting on every non-backtick character leaves only the backtick runs
    // (and empty pieces); the longest one decides the fence width.
    let longest_backtick_run = content
        .split(|c: char| c != '`')
        .map(str::len)
        .max()
        .unwrap_or(0);
    let fence = "`".repeat(longest_backtick_run.max(2) + 1);

    format!(
        "## File: {}\n{}{}\n{}\n{}\n",
        path, fence, language, content, fence
    )
}
156
157/// Assembles context from multiple files into a single markdown document.
158///
159/// # Arguments
160/// * `paths` - File paths to include
161/// * `base_dir` - The base directory to resolve relative paths against
162///
163/// # Returns
164/// * `Ok(String)` - Markdown containing all readable files (empty if none succeed)
165/// * `Err` - If `base_dir` cannot be canonicalized
166///
167/// # Behavior
168/// - Validates each resolved path stays within `base_dir` (prevents directory traversal)
169/// - Skips files that escape the project directory, can't be read, or are binary/too large
170/// - Continues even if some files fail
171pub fn assemble_context(paths: Vec<String>, base_dir: &Path) -> io::Result<String> {
172    let canonical_base = base_dir.canonicalize().map_err(|e| {
173        io::Error::new(
174            e.kind(),
175            format!(
176                "Cannot canonicalize base directory {}: {}",
177                base_dir.display(),
178                e
179            ),
180        )
181    })?;
182
183    let mut output = String::new();
184
185    for path_str in paths {
186        let full_path = base_dir.join(&path_str);
187
188        // Canonicalize the resolved path and verify it stays within the project.
189        // This catches symlinks and any traversal that survived extract_paths filtering.
190        let canonical = match full_path.canonicalize() {
191            Ok(p) => p,
192            Err(_) => {
193                // File doesn't exist or can't be resolved — skip silently
194                eprintln!("Warning: Could not read file {}: not found", path_str);
195                continue;
196            }
197        };
198
199        if !canonical.starts_with(&canonical_base) {
200            eprintln!(
201                "Warning: Skipping file outside project directory: {}",
202                path_str
203            );
204            continue;
205        }
206
207        match read_file(&canonical) {
208            Ok(content) => {
209                output.push_str(&format_file_block(&path_str, &content));
210                output.push('\n');
211            }
212            Err(e) => {
213                eprintln!("Warning: Could not read file {}: {}", path_str, e);
214            }
215        }
216    }
217
218    Ok(output)
219}
220
#[cfg(test)]
mod tests {
    use super::{assemble_context, detect_language, extract_paths, format_file_block, read_file};
    use std::fs;
    use tempfile::TempDir;

    /// Declares a test asserting that `extract_paths` returns exactly the
    /// listed paths, in order, for the given description.
    macro_rules! extracts {
        ($name:ident, $description:expr, [$($expected:expr),* $(,)?]) => {
            #[test]
            fn $name() {
                let expected: &[&str] = &[$($expected),*];
                assert_eq!(extract_paths($description), expected);
            }
        };
    }

    /// Declares a test asserting the language tag detected for a path.
    macro_rules! language {
        ($name:ident, $path:expr, $expected:expr) => {
            #[test]
            fn $name() {
                assert_eq!(detect_language($path), $expected);
            }
        };
    }

    /// Creates a temp directory holding one file with the given bytes.
    /// The directory handle is returned so it outlives the test body.
    fn scratch_file(name: &str, bytes: impl AsRef<[u8]>) -> (TempDir, std::path::PathBuf) {
        let dir = TempDir::new().unwrap();
        let file = dir.path().join(name);
        fs::write(&file, bytes).unwrap();
        (dir, file)
    }

    /// Runs `assemble_context` for the named files relative to `base`.
    fn context_for(base: &std::path::Path, names: &[&str]) -> String {
        let paths = names.iter().map(|name| name.to_string()).collect();
        assemble_context(paths, base).unwrap()
    }

    /// Asserts that every needle occurs somewhere in the haystack.
    fn assert_contains_all(haystack: &str, needles: &[&str]) {
        for needle in needles {
            assert!(
                haystack.contains(needle),
                "expected output to contain {:?}",
                needle
            );
        }
    }

    // ---- extract_paths: basic matching ----

    extracts!(test_single_path, "Modify src/main.rs", ["src/main.rs"]);

    extracts!(
        test_multiple_paths,
        "See src/foo.rs and tests/bar.rs",
        ["src/foo.rs", "tests/bar.rs"]
    );

    extracts!(
        test_deduplicate_paths,
        "Update src/main.rs to fix src/main.rs",
        ["src/main.rs"]
    );

    extracts!(test_with_punctuation, "File: src/main.rs.", ["src/main.rs"]);

    extracts!(test_no_paths, "No files mentioned here", []);

    extracts!(
        test_various_extensions,
        "Check src/config.rs, tests/test.ts, docs/guide.md, package.json, and Cargo.toml",
        [
            "src/config.rs",
            "tests/test.ts",
            "docs/guide.md",
            "package.json",
            "Cargo.toml",
        ]
    );

    extracts!(
        test_paths_with_hyphens,
        "See src/my-module.rs and tests/integration-test.rs",
        ["src/my-module.rs", "tests/integration-test.rs"]
    );

    extracts!(
        test_paths_with_underscores,
        "Update src/my_module.rs in tests/my_test.rs",
        ["src/my_module.rs", "tests/my_test.rs"]
    );

    extracts!(
        test_deeply_nested_paths,
        "Modify deeply/nested/path/to/src/main.rs",
        ["deeply/nested/path/to/src/main.rs"]
    );

    // Absolute paths starting with / should not match.
    extracts!(
        test_ignores_absolute_paths,
        "Do not match /absolute/path/file.rs",
        []
    );

    // URLs should not match due to :// and domain patterns.
    extracts!(
        test_ignores_urls,
        "See https://example.com/file.rs for details",
        []
    );

    extracts!(
        test_mixed_valid_and_invalid,
        "Check src/main.rs at https://example.com/file.rs and tests/test.ts",
        ["src/main.rs", "tests/test.ts"]
    );

    extracts!(
        test_order_of_appearance,
        "Start with z/file.rs, then a/file.rs, then m/file.rs",
        ["z/file.rs", "a/file.rs", "m/file.rs"]
    );

    extracts!(
        test_yaml_and_json_extensions,
        "Update config.yaml and settings.json",
        ["config.yaml", "settings.json"]
    );

    extracts!(
        test_go_and_java_extensions,
        "Implement src/main.go and src/Main.java",
        ["src/main.go", "src/Main.java"]
    );

    extracts!(
        test_tsx_extension,
        "Update components/Button.tsx and pages/Home.tsx",
        ["components/Button.tsx", "pages/Home.tsx"]
    );

    extracts!(
        test_yml_extension,
        "Edit .github/workflows/ci.yml and docker-compose.yml",
        [".github/workflows/ci.yml", "docker-compose.yml"]
    );

    extracts!(
        test_shell_script_extension,
        "Run scripts/deploy.sh for deployment",
        ["scripts/deploy.sh"]
    );

    extracts!(test_empty_string, "", []);

    extracts!(
        test_path_in_middle_of_sentence,
        "The file src/config.rs needs updating because reasons",
        ["src/config.rs"]
    );

    extracts!(
        test_path_at_start_of_string,
        "src/main.rs is the entry point",
        ["src/main.rs"]
    );

    extracts!(
        test_path_at_end_of_string,
        "Please modify src/main.rs",
        ["src/main.rs"]
    );

    extracts!(
        test_adjacent_paths,
        "src/foo.rs src/bar.rs",
        ["src/foo.rs", "src/bar.rs"]
    );

    extracts!(
        test_paths_with_numbers,
        "Update src/v2/main.rs and test_1.rs",
        ["src/v2/main.rs", "test_1.rs"]
    );

    // ---- extract_paths: path traversal rejection ----

    extracts!(
        test_rejects_parent_traversal,
        "Read ../../etc/shadow.md for secrets",
        []
    );

    extracts!(
        test_rejects_mid_path_traversal,
        "Check src/../../../.ssh/config.json",
        []
    );

    extracts!(
        test_rejects_traversal_keeps_valid,
        "Check src/main.rs and ../../etc/passwd.yaml",
        ["src/main.rs"]
    );

    // ".." as a path component is rejected, but dots in filenames are fine.
    extracts!(
        test_allows_dots_in_filenames,
        "Check src/my.module.rs",
        ["src/my.module.rs"]
    );

    // ---- read_file ----

    #[test]
    fn test_read_file_success() {
        let content = "fn main() {\n    println!(\"Hello\");\n}\n";
        let (_dir, file) = scratch_file("test.rs", content);

        assert_eq!(read_file(&file).unwrap(), content);
    }

    #[test]
    fn test_read_file_missing() {
        let dir = TempDir::new().unwrap();

        assert!(read_file(&dir.path().join("nonexistent.rs")).is_err());
    }

    #[test]
    fn test_read_file_binary() {
        let (_dir, file) = scratch_file("binary.bin", [0u8, 1, 2, 3, 0, 255]);

        assert!(read_file(&file).is_err());
    }

    #[test]
    fn test_read_file_rejects_oversized() {
        let (_dir, file) = scratch_file("huge.rs", "x".repeat(1_024 * 1_024 + 1));

        let outcome = read_file(&file);
        assert!(outcome.is_err());
        assert!(
            outcome.unwrap_err().to_string().contains("too large"),
            "Error message should mention size"
        );
    }

    #[test]
    fn test_read_file_rejects_non_utf8() {
        // Invalid UTF-8 sequence without null bytes.
        let (_dir, file) = scratch_file("bad.rs", [0xFFu8, 0xFE, 0x41, 0x42]);

        assert!(read_file(&file).is_err());
    }

    // ---- detect_language ----

    language!(test_detect_language_rust, "src/main.rs", "rust");
    language!(test_detect_language_python, "script.py", "python");
    language!(test_detect_language_json, "config.json", "json");
    language!(test_detect_language_yaml, "config.yaml", "yaml");
    language!(test_detect_language_yml, "config.yml", "yaml");
    language!(test_detect_language_typescript, "index.ts", "typescript");
    language!(test_detect_language_tsx, "component.tsx", "typescript");
    language!(test_detect_language_go, "main.go", "go");
    language!(test_detect_language_java, "Main.java", "java");
    language!(test_detect_language_shell, "deploy.sh", "sh");
    language!(test_detect_language_markdown, "README.md", "markdown");
    language!(test_detect_language_toml, "Cargo.toml", "toml");
    language!(test_detect_language_unknown, "file.unknown", "text");

    // ---- format_file_block ----

    #[test]
    fn test_format_file_block_rust() {
        let block = format_file_block("src/main.rs", "fn main() {}");

        assert_contains_all(
            &block,
            &["## File: src/main.rs", "```rust", "fn main() {}", "```"],
        );
    }

    #[test]
    fn test_format_file_block_python() {
        let block = format_file_block("script.py", "print('hello')");

        assert_contains_all(
            &block,
            &["## File: script.py", "```python", "print('hello')"],
        );
    }

    #[test]
    fn test_format_file_block_json() {
        let body = r#"{"key": "value"}"#;
        let block = format_file_block("config.json", body);

        assert_contains_all(&block, &["## File: config.json", "```json", body]);
    }

    #[test]
    fn test_format_file_block_multiline() {
        let body = "pub fn foo() {\n    // comment\n    return 42;\n}";
        let block = format_file_block("src/lib.rs", body);

        assert_contains_all(
            &block,
            &[
                "## File: src/lib.rs",
                "```rust",
                "pub fn foo()",
                "// comment",
                "return 42;",
            ],
        );
    }

    // ---- assemble_context ----

    #[test]
    fn test_assemble_context_single_file() {
        let (dir, _file) = scratch_file("test.rs", "fn main() {}");

        let context = context_for(dir.path(), &["test.rs"]);

        assert_contains_all(&context, &["## File: test.rs", "```rust", "fn main() {}"]);
    }

    #[test]
    fn test_assemble_context_multiple_files() {
        let dir = TempDir::new().unwrap();
        fs::write(dir.path().join("file1.rs"), "// file 1").unwrap();
        fs::write(dir.path().join("file2.py"), "# file 2").unwrap();

        let context = context_for(dir.path(), &["file1.rs", "file2.py"]);

        assert_contains_all(&context, &["## File: file1.rs", "```rust", "// file 1"]);
        assert_contains_all(&context, &["## File: file2.py", "```python", "# file 2"]);
    }

    #[test]
    fn test_assemble_context_skips_missing_files() {
        let (dir, _file) = scratch_file("exists.rs", "fn hello() {}");

        let context = context_for(dir.path(), &["exists.rs", "missing.rs"]);

        // The existing file is present...
        assert_contains_all(&context, &["## File: exists.rs", "fn hello() {}"]);
        // ...and the missing one leaves no trace.
        assert!(!context.contains("missing.rs"));
    }

    #[test]
    fn test_assemble_context_empty_paths() {
        let dir = TempDir::new().unwrap();

        let context = context_for(dir.path(), &[]);

        assert!(context.trim().is_empty());
    }

    #[test]
    fn test_assemble_context_rejects_symlink_escape() {
        let dir = TempDir::new().unwrap();
        let project = dir.path().join("project");
        fs::create_dir(&project).unwrap();

        // A secret file that lives outside the project directory.
        let secret = dir.path().join("secret.json");
        fs::write(&secret, r#"{"api_key": "leaked"}"#).unwrap();

        // A symlink inside the project pointing at the outside file.
        #[cfg(unix)]
        {
            std::os::unix::fs::symlink(&secret, project.join("secret.json")).unwrap();
            let context = context_for(&project, &["secret.json"]);
            assert!(
                !context.contains("leaked"),
                "Symlink escape should be blocked"
            );
        }
    }

    #[test]
    fn test_assemble_context_preserves_content() {
        let body = r#"{
  "key": "value",
  "nested": {
    "inner": 42
  }
}"#;
        let (dir, _file) = scratch_file("test.json", body);

        let context = context_for(dir.path(), &["test.json"]);

        assert!(context.contains(body));
    }
}