1use std::path::{Path, PathBuf};
30use thiserror::Error;
31
/// Default per-file size limit in bytes (1 MiB); used by `ContextConfig::default()`.
pub const DEFAULT_MAX_FILE_SIZE: usize = 1 << 20;

/// Default limit in bytes (10 MiB) on the combined size of all loaded context;
/// used by `ContextConfig::default()`.
pub const DEFAULT_MAX_TOTAL_SIZE: usize = 10 << 20;
37
/// One entry in the list of context sources to load; sources are processed in declaration order.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ContextSource {
    /// Inline text used verbatim; it has no backing file.
    Content {
        /// The text to include.
        content: String,
    },
    /// A single file path. `~`, `$HOME` and `$CWD` are expanded before the lookup.
    File {
        /// Path as declared, before expansion.
        path: String,
        /// When `true` a missing file is an error; otherwise it is recorded as skipped.
        required: bool,
    },
    /// Several file paths that share one `required` flag, loaded in list order.
    Files {
        /// Paths as declared, before expansion.
        paths: Vec<String>,
        /// Applies to every entry in `paths`.
        required: bool,
    },
    /// A glob pattern (expanded like `File`). Matching regular files are loaded in sorted
    /// order; a pattern with no matches loads nothing.
    Glob {
        /// The pattern as declared, before expansion.
        pattern: String,
    },
}
66
/// A single piece of loaded context together with where it came from.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ResolvedContext {
    /// The source as declared: the original path or glob pattern, or `"inline content"`
    /// for inline text.
    pub source: String,
    /// Concrete file that was read; `None` for inline content.
    pub resolved_path: Option<PathBuf>,
    /// The loaded text.
    pub content: String,
}

/// Outcome of resolving a list of context sources.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ContextLoadResult {
    /// Loaded context entries, in declaration order.
    pub files: Vec<ResolvedContext>,
    /// Expanded paths of optional files that did not exist.
    pub skipped: Vec<String>,
    /// Bytes accepted so far, counted against the configured total limit.
    pub total_bytes: usize,
}
88
/// Size limits applied while loading context.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ContextConfig {
    /// Maximum size in bytes of any single file loaded from disk.
    pub max_file_size: usize,
    /// Maximum combined size in bytes of everything loaded, including inline content.
    pub max_total_size: usize,
}
97
98impl Default for ContextConfig {
99 fn default() -> Self {
100 Self {
101 max_file_size: DEFAULT_MAX_FILE_SIZE,
102 max_total_size: DEFAULT_MAX_TOTAL_SIZE,
103 }
104 }
105}
106
/// Directories substituted for `~`, `$HOME` and `$CWD` when expanding context paths.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PathVariables {
    /// Value used for `$CWD`.
    pub cwd: PathBuf,
    /// Value used for `~` and `$HOME`.
    pub home: PathBuf,
}
115
116impl PathVariables {
117 pub fn current() -> Self {
119 Self {
120 cwd: std::env::current_dir().unwrap_or_default(),
121 home: dirs::home_dir().unwrap_or_default(),
122 }
123 }
124}
125
126#[derive(Debug, Error)]
128pub enum ContextError {
129 #[error("required context file not found: {0}")]
131 FileNotFound(String),
132
133 #[error("context file is not valid UTF-8: {path}")]
135 InvalidUtf8 {
136 path: String,
138 },
139
140 #[error("context file exceeds size limit ({size} bytes > {limit} bytes): {path}")]
142 FileTooLarge {
143 path: String,
145 size: usize,
147 limit: usize,
149 },
150
151 #[error("total context size exceeds limit ({size} bytes > {limit} bytes)")]
153 TotalSizeTooLarge {
154 size: usize,
156 limit: usize,
158 },
159
160 #[error("failed to read context file {path}: {message}")]
162 IoError {
163 path: String,
165 message: String,
167 },
168
169 #[error("invalid glob pattern: {0}")]
171 InvalidPattern(String),
172}
173
174fn expand_path(path: &str, vars: &PathVariables) -> String {
182 let home_str = vars.home.to_str().unwrap_or("");
183 let cwd_str = vars.cwd.to_str().unwrap_or("");
184
185 if path == "~" {
187 return home_str.to_string();
188 }
189 if let Some(rest) = path.strip_prefix("~/") {
190 return format!("{}/{}", home_str, rest);
191 }
192
193 let mut result = path.to_string();
195 result = result.replace("$HOME", home_str);
196 result = result.replace("$CWD", cwd_str);
197
198 result
199}
200
201fn load_file(
203 path: &Path,
204 config: &ContextConfig,
205 total_bytes: &mut usize,
206) -> Result<String, ContextError> {
207 let metadata = std::fs::metadata(path).map_err(|e| ContextError::IoError {
208 path: path.display().to_string(),
209 message: e.to_string(),
210 })?;
211
212 let size = metadata.len() as usize;
213
214 if size > config.max_file_size {
216 return Err(ContextError::FileTooLarge {
217 path: path.display().to_string(),
218 size,
219 limit: config.max_file_size,
220 });
221 }
222
223 if *total_bytes + size > config.max_total_size {
225 return Err(ContextError::TotalSizeTooLarge {
226 size: *total_bytes + size,
227 limit: config.max_total_size,
228 });
229 }
230
231 let content = std::fs::read_to_string(path).map_err(|e| {
233 if e.kind() == std::io::ErrorKind::InvalidData {
234 ContextError::InvalidUtf8 {
235 path: path.display().to_string(),
236 }
237 } else {
238 ContextError::IoError {
239 path: path.display().to_string(),
240 message: e.to_string(),
241 }
242 }
243 })?;
244
245 *total_bytes += size;
246 Ok(content)
247}
248
249pub fn resolve_context(
254 sources: &[ContextSource],
255 vars: &PathVariables,
256 config: &ContextConfig,
257) -> Result<ContextLoadResult, ContextError> {
258 let mut files = Vec::new();
259 let mut skipped = Vec::new();
260 let mut total_bytes = 0usize;
261
262 for source in sources {
263 match source {
264 ContextSource::Content { content } => {
265 let size = content.len();
266 if total_bytes + size > config.max_total_size {
267 return Err(ContextError::TotalSizeTooLarge {
268 size: total_bytes + size,
269 limit: config.max_total_size,
270 });
271 }
272 total_bytes += size;
273 files.push(ResolvedContext {
274 source: "inline content".to_string(),
275 resolved_path: None,
276 content: content.clone(),
277 });
278 }
279
280 ContextSource::File { path, required } => {
281 let expanded = expand_path(path, vars);
282 let resolved = PathBuf::from(&expanded);
283
284 if !resolved.exists() {
285 if *required {
286 return Err(ContextError::FileNotFound(expanded));
287 }
288 skipped.push(expanded);
289 continue;
290 }
291
292 let content = load_file(&resolved, config, &mut total_bytes)?;
293 files.push(ResolvedContext {
294 source: path.clone(),
295 resolved_path: Some(resolved),
296 content,
297 });
298 }
299
300 ContextSource::Files { paths, required } => {
301 for path in paths {
302 let expanded = expand_path(path, vars);
303 let resolved = PathBuf::from(&expanded);
304
305 if !resolved.exists() {
306 if *required {
307 return Err(ContextError::FileNotFound(expanded));
308 }
309 skipped.push(expanded);
310 continue;
311 }
312
313 let content = load_file(&resolved, config, &mut total_bytes)?;
314 files.push(ResolvedContext {
315 source: path.clone(),
316 resolved_path: Some(resolved),
317 content,
318 });
319 }
320 }
321
322 ContextSource::Glob { pattern } => {
323 let expanded = expand_path(pattern, vars);
324 let matches = glob::glob(&expanded)
325 .map_err(|e| ContextError::InvalidPattern(e.to_string()))?;
326
327 let mut pattern_files: Vec<PathBuf> = matches
328 .filter_map(|r| r.ok())
329 .filter(|p| p.is_file())
330 .collect();
331
332 pattern_files.sort();
334
335 for resolved in pattern_files {
337 let content = load_file(&resolved, config, &mut total_bytes)?;
338 files.push(ResolvedContext {
339 source: pattern.clone(),
340 resolved_path: Some(resolved),
341 content,
342 });
343 }
344 }
345 }
346 }
347
348 Ok(ContextLoadResult {
349 files,
350 skipped,
351 total_bytes,
352 })
353}
354
355pub fn build_effective_prompt(
363 system_prompt: Option<&str>,
364 context: &ContextLoadResult,
365) -> Option<String> {
366 let mut parts = Vec::new();
367
368 if let Some(prompt) = system_prompt {
370 parts.push(prompt.to_string());
371 }
372
373 for ctx in &context.files {
375 let header = match &ctx.resolved_path {
376 Some(path) => format!("<!-- Context from: {} -->", path.display()),
377 None => "<!-- Inline context -->".to_string(),
378 };
379 parts.push(format!("\n---\n{}\n{}", header, ctx.content));
380 }
381
382 if parts.is_empty() {
383 None
384 } else {
385 Some(parts.join("\n"))
386 }
387}
388
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::TempDir;

    /// Fixed directories so expansion tests don't depend on the real environment.
    fn fixed_vars() -> PathVariables {
        PathVariables {
            cwd: PathBuf::from("/workspace"),
            home: PathBuf::from("/home/user"),
        }
    }

    /// `dir/name` as an owned UTF-8 string, the form `ContextSource` path fields use.
    fn path_in(dir: &TempDir, name: &str) -> String {
        dir.path().join(name).to_str().unwrap().to_string()
    }

    /// Writes `contents` to `dir/name`.
    fn write_in(dir: &TempDir, name: &str, contents: &str) {
        fs::write(dir.path().join(name), contents).unwrap();
    }

    /// A `*.md` glob source rooted at `dir`.
    fn glob_md(dir: &TempDir) -> ContextSource {
        ContextSource::Glob {
            pattern: format!("{}/*.md", dir.path().display()),
        }
    }

    /// Resolves `sources` against the real environment with default limits.
    fn resolve_default(sources: &[ContextSource]) -> Result<ContextLoadResult, ContextError> {
        resolve_context(sources, &PathVariables::current(), &ContextConfig::default())
    }

    /// A load result holding exactly one entry.
    fn one_entry(resolved_path: Option<&str>, source: &str, content: &str) -> ContextLoadResult {
        ContextLoadResult {
            files: vec![ResolvedContext {
                source: source.to_string(),
                resolved_path: resolved_path.map(PathBuf::from),
                content: content.to_string(),
            }],
            skipped: Vec::new(),
            total_bytes: content.len(),
        }
    }

    #[test]
    fn test_expand_path_cwd() {
        assert_eq!(expand_path("$CWD/AGENTS.md", &fixed_vars()), "/workspace/AGENTS.md");
    }

    #[test]
    fn test_expand_path_home_var() {
        assert_eq!(
            expand_path("$HOME/.config/agent.md", &fixed_vars()),
            "/home/user/.config/agent.md"
        );
    }

    #[test]
    fn test_expand_path_tilde() {
        assert_eq!(
            expand_path("~/.config/agent.md", &fixed_vars()),
            "/home/user/.config/agent.md"
        );
    }

    #[test]
    fn test_expand_path_tilde_alone() {
        assert_eq!(expand_path("~", &fixed_vars()), "/home/user");
    }

    #[test]
    fn test_expand_path_relative() {
        assert_eq!(expand_path("AGENTS.md", &fixed_vars()), "AGENTS.md");
    }

    #[test]
    fn test_resolve_context_content() {
        let sources = [ContextSource::Content {
            content: "# Rules\nBe helpful.".to_string(),
        }];

        let result = resolve_default(&sources).unwrap();

        assert_eq!(result.files.len(), 1);
        let only = &result.files[0];
        assert_eq!(only.content, "# Rules\nBe helpful.");
        assert!(only.resolved_path.is_none());
        assert_eq!(only.source, "inline content");
    }

    #[test]
    fn test_resolve_context_single_file() {
        let dir = TempDir::new().unwrap();
        write_in(&dir, "AGENTS.md", "# Agent Instructions\nBe helpful.");
        let sources = [ContextSource::File {
            path: path_in(&dir, "AGENTS.md"),
            required: true,
        }];

        let result = resolve_default(&sources).unwrap();

        assert_eq!(result.files.len(), 1);
        assert_eq!(result.files[0].content, "# Agent Instructions\nBe helpful.");
        assert!(result.skipped.is_empty());
    }

    #[test]
    fn test_resolve_context_optional_missing() {
        let sources = [ContextSource::File {
            path: "/nonexistent/file.md".to_string(),
            required: false,
        }];

        let result = resolve_default(&sources).unwrap();

        assert!(result.files.is_empty());
        assert_eq!(result.skipped, vec!["/nonexistent/file.md".to_string()]);
    }

    #[test]
    fn test_resolve_context_required_missing() {
        let sources = [ContextSource::File {
            path: "/nonexistent/file.md".to_string(),
            required: true,
        }];

        assert!(matches!(
            resolve_default(&sources),
            Err(ContextError::FileNotFound(_))
        ));
    }

    #[test]
    fn test_resolve_context_files_all_exist() {
        let dir = TempDir::new().unwrap();
        write_in(&dir, "a.md", "File A");
        write_in(&dir, "b.md", "File B");
        let sources = [ContextSource::Files {
            paths: vec![path_in(&dir, "a.md"), path_in(&dir, "b.md")],
            required: true,
        }];

        let result = resolve_default(&sources).unwrap();

        let contents: Vec<&str> = result.files.iter().map(|f| f.content.as_str()).collect();
        assert_eq!(contents, ["File A", "File B"]);
    }

    #[test]
    fn test_resolve_context_files_required_one_missing() {
        let dir = TempDir::new().unwrap();
        write_in(&dir, "a.md", "File A");
        let sources = [ContextSource::Files {
            paths: vec![path_in(&dir, "a.md"), path_in(&dir, "missing.md")],
            required: true,
        }];

        assert!(matches!(
            resolve_default(&sources),
            Err(ContextError::FileNotFound(_))
        ));
    }

    #[test]
    fn test_resolve_context_files_optional_one_missing() {
        let dir = TempDir::new().unwrap();
        write_in(&dir, "a.md", "File A");
        let sources = [ContextSource::Files {
            paths: vec![path_in(&dir, "a.md"), path_in(&dir, "missing.md")],
            required: false,
        }];

        let result = resolve_default(&sources).unwrap();

        assert_eq!(result.files.len(), 1);
        assert_eq!(result.files[0].content, "File A");
        assert_eq!(result.skipped.len(), 1);
    }

    #[test]
    fn test_resolve_context_glob() {
        let dir = TempDir::new().unwrap();
        write_in(&dir, "a.md", "File A");
        write_in(&dir, "b.md", "File B");
        write_in(&dir, "c.txt", "Not markdown");

        let result = resolve_default(&[glob_md(&dir)]).unwrap();

        assert_eq!(result.files.len(), 2);
        for (ctx, expected) in result.files.iter().zip(["a.md", "b.md"]) {
            assert!(ctx.resolved_path.as_ref().unwrap().ends_with(expected));
        }
    }

    #[test]
    fn test_resolve_context_glob_no_matches() {
        let dir = TempDir::new().unwrap();

        let result = resolve_default(&[glob_md(&dir)]).unwrap();

        assert!(result.files.is_empty());
    }

    #[test]
    fn test_resolve_context_file_too_large() {
        let dir = TempDir::new().unwrap();
        write_in(&dir, "large.md", &"x".repeat(1000));
        let sources = [ContextSource::File {
            path: path_in(&dir, "large.md"),
            required: true,
        }];
        let config = ContextConfig {
            max_file_size: 100,
            max_total_size: DEFAULT_MAX_TOTAL_SIZE,
        };

        let result = resolve_context(&sources, &PathVariables::current(), &config);

        assert!(matches!(result, Err(ContextError::FileTooLarge { .. })));
    }

    #[test]
    fn test_resolve_context_total_too_large() {
        let dir = TempDir::new().unwrap();
        // Each file fits the per-file limit; together they exceed the total limit.
        write_in(&dir, "a.md", &"x".repeat(60));
        write_in(&dir, "b.md", &"x".repeat(60));
        let config = ContextConfig {
            max_file_size: 100,
            max_total_size: 100,
        };

        let result = resolve_context(&[glob_md(&dir)], &PathVariables::current(), &config);

        assert!(matches!(
            result,
            Err(ContextError::TotalSizeTooLarge { .. })
        ));
    }

    #[test]
    fn test_resolve_context_declaration_order() {
        let dir = TempDir::new().unwrap();
        write_in(&dir, "first.md", "First");
        write_in(&dir, "second.md", "Second");
        // Declared in reverse name order: output must follow declaration, not names.
        let sources = [
            ContextSource::File {
                path: path_in(&dir, "second.md"),
                required: true,
            },
            ContextSource::File {
                path: path_in(&dir, "first.md"),
                required: true,
            },
        ];

        let result = resolve_default(&sources).unwrap();

        let contents: Vec<&str> = result.files.iter().map(|f| f.content.as_str()).collect();
        assert_eq!(contents, ["Second", "First"]);
    }

    #[test]
    fn test_build_effective_prompt_system_only() {
        let prompt = build_effective_prompt(Some("You are helpful."), &ContextLoadResult::default());

        assert_eq!(prompt.as_deref(), Some("You are helpful."));
    }

    #[test]
    fn test_build_effective_prompt_context_only() {
        let context = one_entry(Some("/path/to/test.md"), "test.md", "Context content");

        let prompt = build_effective_prompt(None, &context).expect("context alone yields a prompt");

        assert!(prompt.contains("Context content"));
        assert!(prompt.contains("/path/to/test.md"));
    }

    #[test]
    fn test_build_effective_prompt_inline_content() {
        let context = one_entry(None, "inline content", "Inline rules");

        let prompt = build_effective_prompt(None, &context).expect("context alone yields a prompt");

        assert!(prompt.contains("Inline rules"));
        assert!(prompt.contains("Inline context"));
    }

    #[test]
    fn test_build_effective_prompt_combined() {
        let context = one_entry(Some("/path/to/test.md"), "test.md", "Context content");

        let prompt = build_effective_prompt(Some("System prompt"), &context)
            .expect("system prompt plus context yields a prompt");

        assert!(prompt.starts_with("System prompt"));
        assert!(prompt.contains("Context content"));
    }

    #[test]
    fn test_build_effective_prompt_empty() {
        assert!(build_effective_prompt(None, &ContextLoadResult::default()).is_none());
    }

    #[test]
    fn test_context_config_default() {
        let config = ContextConfig::default();

        assert_eq!(
            (config.max_file_size, config.max_total_size),
            (DEFAULT_MAX_FILE_SIZE, DEFAULT_MAX_TOTAL_SIZE)
        );
    }
}
801}