1use regex::Regex;
2use std::collections::HashSet;
3use std::fs;
4use std::io;
5use std::path::{Component, Path};
6use std::sync::LazyLock;
7
8static PATH_REGEX: LazyLock<Regex> = LazyLock::new(|| {
10 Regex::new(r"([a-zA-Z0-9_.][a-zA-Z0-9_./\-]*\.(rs|tsx?|py|md|json|toml|ya?ml|sh|go|java))\b")
12 .expect("Invalid regex pattern")
13});
14
15pub fn extract_paths(description: &str) -> Vec<String> {
31 let mut result = Vec::new();
32 let mut seen = HashSet::new();
33
34 for cap in PATH_REGEX.captures_iter(description) {
35 if let Some(path) = cap.get(1) {
36 let path_str = path.as_str();
37 let path_start = path.start();
38
39 if path_start > 0 && description.as_bytes()[path_start - 1] == b'/' {
42 continue;
43 }
44
45 let before = &description[path_start.saturating_sub(3)..path_start];
47 if before.ends_with("://") {
48 continue;
49 }
50
51 if Path::new(path_str)
54 .components()
55 .any(|c| matches!(c, Component::ParentDir))
56 {
57 continue;
58 }
59
60 if seen.insert(path_str.to_string()) {
62 result.push(path_str.to_string());
63 }
64 }
65 }
66
67 result
68}
69
/// Largest file, in bytes, that [`read_file`] will load (1 MiB).
const MAX_FILE_SIZE: u64 = 1_024 * 1_024;

/// Reads the file at `path` and returns its contents as UTF-8 text.
///
/// # Errors
///
/// * Any I/O error raised while querying metadata, opening, or reading the file.
/// * [`io::ErrorKind::InvalidData`] when the file is larger than [`MAX_FILE_SIZE`], judged both
///   by its reported length and by the number of bytes actually read.
/// * [`io::ErrorKind::InvalidData`] when the contents contain a NUL byte; a warning is also
///   printed to stderr in that case.
/// * [`io::ErrorKind::InvalidData`] when the contents are not valid UTF-8.
pub fn read_file(path: &Path) -> io::Result<String> {
    // Cheap early rejection that also yields the exact size for the error message.
    let reported_len = fs::metadata(path)?.len();
    if reported_len > MAX_FILE_SIZE {
        return Err(io::Error::new(
            io::ErrorKind::InvalidData,
            format!(
                "File too large ({} bytes, max {})",
                reported_len, MAX_FILE_SIZE
            ),
        ));
    }

    // The reported length is only a snapshot: the file may grow before it is read, and special
    // files can report a length of zero while producing unbounded data. Cap the read itself at
    // one byte past the limit so an oversized stream is detected without buffering all of it.
    let file = fs::File::open(path)?;
    let mut limited = io::Read::take(file, MAX_FILE_SIZE + 1);
    let mut bytes = Vec::with_capacity(usize::try_from(reported_len).unwrap_or(0));
    io::Read::read_to_end(&mut limited, &mut bytes)?;

    if u64::try_from(bytes.len()).map_or(true, |read| read > MAX_FILE_SIZE) {
        return Err(io::Error::new(
            io::ErrorKind::InvalidData,
            format!(
                "File too large (more than {} bytes read, max {})",
                MAX_FILE_SIZE, MAX_FILE_SIZE
            ),
        ));
    }

    // A NUL byte is treated as the marker of a binary file.
    if bytes.contains(&0) {
        eprintln!("Warning: Skipping binary file: {}", path.display());
        return Err(io::Error::new(
            io::ErrorKind::InvalidData,
            "File appears to be binary (contains null bytes)",
        ));
    }

    String::from_utf8(bytes)
        .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "File is not valid UTF-8"))
}
115
/// Maps a file path to the language tag used for its fenced code block.
///
/// The decision is based on the extension of the final path component. Paths without an
/// extension, or with an extension that is not listed below, map to `"text"`. Going through
/// [`Path::extension`] ensures an extensionless file that happens to be named like an extension
/// (for example a file called `go` or `sh`) is not mistaken for source code in that language.
fn detect_language(path: &str) -> &'static str {
    match Path::new(path).extension().and_then(|ext| ext.to_str()) {
        Some("rs") => "rust",
        Some("ts" | "tsx") => "typescript",
        Some("py") => "python",
        Some("go") => "go",
        Some("java") => "java",
        Some("json") => "json",
        Some("yaml" | "yml") => "yaml",
        Some("toml") => "toml",
        Some("sh") => "sh",
        Some("md") => "markdown",
        _ => "text",
    }
}
135
136pub fn format_file_block(path: &str, content: &str) -> String {
153 let language = detect_language(path);
154 format!("## File: {}\n```{}\n{}\n```\n", path, language, content)
155}
156
157pub fn assemble_context(paths: Vec<String>, base_dir: &Path) -> io::Result<String> {
172 let canonical_base = base_dir.canonicalize().map_err(|e| {
173 io::Error::new(
174 e.kind(),
175 format!(
176 "Cannot canonicalize base directory {}: {}",
177 base_dir.display(),
178 e
179 ),
180 )
181 })?;
182
183 let mut output = String::new();
184
185 for path_str in paths {
186 let full_path = base_dir.join(&path_str);
187
188 let canonical = match full_path.canonicalize() {
191 Ok(p) => p,
192 Err(_) => {
193 eprintln!("Warning: Could not read file {}: not found", path_str);
195 continue;
196 }
197 };
198
199 if !canonical.starts_with(&canonical_base) {
200 eprintln!(
201 "Warning: Skipping file outside project directory: {}",
202 path_str
203 );
204 continue;
205 }
206
207 match read_file(&canonical) {
208 Ok(content) => {
209 output.push_str(&format_file_block(&path_str, &content));
210 output.push('\n');
211 }
212 Err(e) => {
213 eprintln!("Warning: Could not read file {}: {}", path_str, e);
214 }
215 }
216 }
217
218 Ok(output)
219}
220
#[cfg(test)]
mod tests {
    use super::{assemble_context, detect_language, extract_paths, format_file_block, read_file};
    use std::fs;
    use tempfile::TempDir;

    // ---- extract_paths ----

    #[test]
    fn test_single_path() {
        let found = extract_paths("Modify src/main.rs");
        assert_eq!(found, ["src/main.rs"]);
    }

    #[test]
    fn test_multiple_paths() {
        let found = extract_paths("See src/foo.rs and tests/bar.rs");
        assert_eq!(found, ["src/foo.rs", "tests/bar.rs"]);
    }

    #[test]
    fn test_deduplicate_paths() {
        let found = extract_paths("Update src/main.rs to fix src/main.rs");
        assert_eq!(found, ["src/main.rs"]);
    }

    #[test]
    fn test_with_punctuation() {
        let found = extract_paths("File: src/main.rs.");
        assert_eq!(found, ["src/main.rs"]);
    }

    #[test]
    fn test_no_paths() {
        let found = extract_paths("No files mentioned here");
        assert!(found.is_empty());
    }

    #[test]
    fn test_various_extensions() {
        let found = extract_paths(
            "Check src/config.rs, tests/test.ts, docs/guide.md, package.json, and Cargo.toml",
        );
        let expected = [
            "src/config.rs",
            "tests/test.ts",
            "docs/guide.md",
            "package.json",
            "Cargo.toml",
        ];
        assert_eq!(found, expected);
    }

    #[test]
    fn test_paths_with_hyphens() {
        let found = extract_paths("See src/my-module.rs and tests/integration-test.rs");
        assert_eq!(found, ["src/my-module.rs", "tests/integration-test.rs"]);
    }

    #[test]
    fn test_paths_with_underscores() {
        let found = extract_paths("Update src/my_module.rs in tests/my_test.rs");
        assert_eq!(found, ["src/my_module.rs", "tests/my_test.rs"]);
    }

    #[test]
    fn test_deeply_nested_paths() {
        let found = extract_paths("Modify deeply/nested/path/to/src/main.rs");
        assert_eq!(found, ["deeply/nested/path/to/src/main.rs"]);
    }

    #[test]
    fn test_ignores_absolute_paths() {
        let found = extract_paths("Do not match /absolute/path/file.rs");
        assert!(found.is_empty());
    }

    #[test]
    fn test_ignores_urls() {
        let found = extract_paths("See https://example.com/file.rs for details");
        assert!(found.is_empty());
    }

    #[test]
    fn test_mixed_valid_and_invalid() {
        let found =
            extract_paths("Check src/main.rs at https://example.com/file.rs and tests/test.ts");
        assert_eq!(found, ["src/main.rs", "tests/test.ts"]);
    }

    #[test]
    fn test_order_of_appearance() {
        let found = extract_paths("Start with z/file.rs, then a/file.rs, then m/file.rs");
        assert_eq!(found, ["z/file.rs", "a/file.rs", "m/file.rs"]);
    }

    #[test]
    fn test_yaml_and_json_extensions() {
        let found = extract_paths("Update config.yaml and settings.json");
        assert_eq!(found, ["config.yaml", "settings.json"]);
    }

    #[test]
    fn test_go_and_java_extensions() {
        let found = extract_paths("Implement src/main.go and src/Main.java");
        assert_eq!(found, ["src/main.go", "src/Main.java"]);
    }

    #[test]
    fn test_tsx_extension() {
        let found = extract_paths("Update components/Button.tsx and pages/Home.tsx");
        assert_eq!(found, ["components/Button.tsx", "pages/Home.tsx"]);
    }

    #[test]
    fn test_yml_extension() {
        let found = extract_paths("Edit .github/workflows/ci.yml and docker-compose.yml");
        assert_eq!(found, [".github/workflows/ci.yml", "docker-compose.yml"]);
    }

    #[test]
    fn test_shell_script_extension() {
        let found = extract_paths("Run scripts/deploy.sh for deployment");
        assert_eq!(found, ["scripts/deploy.sh"]);
    }

    #[test]
    fn test_empty_string() {
        assert!(extract_paths("").is_empty());
    }

    #[test]
    fn test_path_in_middle_of_sentence() {
        let found = extract_paths("The file src/config.rs needs updating because reasons");
        assert_eq!(found, ["src/config.rs"]);
    }

    #[test]
    fn test_path_at_start_of_string() {
        let found = extract_paths("src/main.rs is the entry point");
        assert_eq!(found, ["src/main.rs"]);
    }

    #[test]
    fn test_path_at_end_of_string() {
        let found = extract_paths("Please modify src/main.rs");
        assert_eq!(found, ["src/main.rs"]);
    }

    #[test]
    fn test_adjacent_paths() {
        let found = extract_paths("src/foo.rs src/bar.rs");
        assert_eq!(found, ["src/foo.rs", "src/bar.rs"]);
    }

    #[test]
    fn test_paths_with_numbers() {
        let found = extract_paths("Update src/v2/main.rs and test_1.rs");
        assert_eq!(found, ["src/v2/main.rs", "test_1.rs"]);
    }

    #[test]
    fn test_rejects_parent_traversal() {
        assert!(extract_paths("Read ../../etc/shadow.md for secrets").is_empty());
    }

    #[test]
    fn test_rejects_mid_path_traversal() {
        assert!(extract_paths("Check src/../../../.ssh/config.json").is_empty());
    }

    #[test]
    fn test_rejects_traversal_keeps_valid() {
        let found = extract_paths("Check src/main.rs and ../../etc/passwd.yaml");
        assert_eq!(found, ["src/main.rs"]);
    }

    #[test]
    fn test_allows_dots_in_filenames() {
        let found = extract_paths("Check src/my.module.rs");
        assert_eq!(found, ["src/my.module.rs"]);
    }

    // ---- read_file ----

    #[test]
    fn test_read_file_success() {
        let dir = TempDir::new().unwrap();
        let file = dir.path().join("test.rs");
        let expected = "fn main() {\n println!(\"Hello\");\n}\n";
        fs::write(&file, expected).unwrap();

        assert_eq!(read_file(&file).unwrap(), expected);
    }

    #[test]
    fn test_read_file_missing() {
        let dir = TempDir::new().unwrap();
        let absent = dir.path().join("nonexistent.rs");

        assert!(read_file(&absent).is_err());
    }

    #[test]
    fn test_read_file_binary() {
        let dir = TempDir::new().unwrap();
        let file = dir.path().join("binary.bin");
        fs::write(&file, [0u8, 1, 2, 3, 0, 255]).unwrap();

        assert!(read_file(&file).is_err());
    }

    #[test]
    fn test_read_file_rejects_oversized() {
        let dir = TempDir::new().unwrap();
        let file = dir.path().join("huge.rs");
        // One byte past the 1 MiB limit.
        fs::write(&file, "x".repeat(1_024 * 1_024 + 1)).unwrap();

        let outcome = read_file(&file);
        assert!(outcome.is_err());
        assert!(
            outcome.unwrap_err().to_string().contains("too large"),
            "Error message should mention size"
        );
    }

    #[test]
    fn test_read_file_rejects_non_utf8() {
        let dir = TempDir::new().unwrap();
        let file = dir.path().join("bad.rs");
        fs::write(&file, [0xFFu8, 0xFE, 0x41, 0x42]).unwrap();

        assert!(read_file(&file).is_err());
    }

    // ---- detect_language ----

    #[test]
    fn test_detect_language_rust() {
        assert_eq!(detect_language("src/main.rs"), "rust");
    }

    #[test]
    fn test_detect_language_python() {
        assert_eq!(detect_language("script.py"), "python");
    }

    #[test]
    fn test_detect_language_json() {
        assert_eq!(detect_language("config.json"), "json");
    }

    #[test]
    fn test_detect_language_yaml() {
        assert_eq!(detect_language("config.yaml"), "yaml");
    }

    #[test]
    fn test_detect_language_yml() {
        assert_eq!(detect_language("config.yml"), "yaml");
    }

    #[test]
    fn test_detect_language_typescript() {
        assert_eq!(detect_language("index.ts"), "typescript");
    }

    #[test]
    fn test_detect_language_tsx() {
        assert_eq!(detect_language("component.tsx"), "typescript");
    }

    #[test]
    fn test_detect_language_go() {
        assert_eq!(detect_language("main.go"), "go");
    }

    #[test]
    fn test_detect_language_java() {
        assert_eq!(detect_language("Main.java"), "java");
    }

    #[test]
    fn test_detect_language_shell() {
        assert_eq!(detect_language("deploy.sh"), "sh");
    }

    #[test]
    fn test_detect_language_markdown() {
        assert_eq!(detect_language("README.md"), "markdown");
    }

    #[test]
    fn test_detect_language_toml() {
        assert_eq!(detect_language("Cargo.toml"), "toml");
    }

    #[test]
    fn test_detect_language_unknown() {
        assert_eq!(detect_language("file.unknown"), "text");
    }

    // ---- format_file_block ----

    #[test]
    fn test_format_file_block_rust() {
        let block = format_file_block("src/main.rs", "fn main() {}");

        assert!(block.contains("## File: src/main.rs"));
        assert!(block.contains("```rust"));
        assert!(block.contains("fn main() {}"));
        assert!(block.contains("```"));
    }

    #[test]
    fn test_format_file_block_python() {
        let block = format_file_block("script.py", "print('hello')");

        assert!(block.contains("## File: script.py"));
        assert!(block.contains("```python"));
        assert!(block.contains("print('hello')"));
    }

    #[test]
    fn test_format_file_block_json() {
        let body = r#"{"key": "value"}"#;
        let block = format_file_block("config.json", body);

        assert!(block.contains("## File: config.json"));
        assert!(block.contains("```json"));
        assert!(block.contains(r#"{"key": "value"}"#));
    }

    #[test]
    fn test_format_file_block_multiline() {
        let body = "pub fn foo() {\n // comment\n return 42;\n}";
        let block = format_file_block("src/lib.rs", body);

        assert!(block.contains("## File: src/lib.rs"));
        assert!(block.contains("```rust"));
        assert!(block.contains("pub fn foo()"));
        assert!(block.contains("// comment"));
        assert!(block.contains("return 42;"));
    }

    // ---- assemble_context ----

    #[test]
    fn test_assemble_context_single_file() {
        let dir = TempDir::new().unwrap();
        fs::write(dir.path().join("test.rs"), "fn main() {}").unwrap();

        let context = assemble_context(vec!["test.rs".to_string()], dir.path()).unwrap();

        assert!(context.contains("## File: test.rs"));
        assert!(context.contains("```rust"));
        assert!(context.contains("fn main() {}"));
    }

    #[test]
    fn test_assemble_context_multiple_files() {
        let dir = TempDir::new().unwrap();
        fs::write(dir.path().join("file1.rs"), "// file 1").unwrap();
        fs::write(dir.path().join("file2.py"), "# file 2").unwrap();

        let requested = vec!["file1.rs".to_string(), "file2.py".to_string()];
        let context = assemble_context(requested, dir.path()).unwrap();

        assert!(context.contains("## File: file1.rs"));
        assert!(context.contains("```rust"));
        assert!(context.contains("// file 1"));

        assert!(context.contains("## File: file2.py"));
        assert!(context.contains("```python"));
        assert!(context.contains("# file 2"));
    }

    #[test]
    fn test_assemble_context_skips_missing_files() {
        let dir = TempDir::new().unwrap();
        fs::write(dir.path().join("exists.rs"), "fn hello() {}").unwrap();

        let requested = vec!["exists.rs".to_string(), "missing.rs".to_string()];
        let context = assemble_context(requested, dir.path()).unwrap();

        // The file that exists is included.
        assert!(context.contains("## File: exists.rs"));
        assert!(context.contains("fn hello() {}"));

        // The missing file leaves no trace in the output.
        assert!(!context.contains("missing.rs"));
    }

    #[test]
    fn test_assemble_context_empty_paths() {
        let dir = TempDir::new().unwrap();

        let context = assemble_context(Vec::new(), dir.path()).unwrap();

        assert_eq!(context.trim(), "");
    }

    #[test]
    fn test_assemble_context_rejects_symlink_escape() {
        let dir = TempDir::new().unwrap();
        let project = dir.path().join("project");
        fs::create_dir(&project).unwrap();

        // The secret lives next to the project directory, not inside it.
        let secret = dir.path().join("secret.json");
        fs::write(&secret, r#"{"api_key": "leaked"}"#).unwrap();

        #[cfg(unix)]
        {
            std::os::unix::fs::symlink(&secret, project.join("secret.json")).unwrap();
            let context = assemble_context(vec!["secret.json".to_string()], &project).unwrap();
            assert!(
                !context.contains("leaked"),
                "Symlink escape should be blocked"
            );
        }
    }

    #[test]
    fn test_assemble_context_preserves_content() {
        let dir = TempDir::new().unwrap();
        let body = r#"{
 "key": "value",
 "nested": {
 "inner": 42
 }
}"#;
        fs::write(dir.path().join("test.json"), body).unwrap();

        let context = assemble_context(vec!["test.json".to_string()], dir.path()).unwrap();

        assert!(context.contains(body));
    }
}