use std::path::{Path, PathBuf};
use thiserror::Error;
/// Number of bytes in one mebibyte; used to spell out the default limits below.
const BYTES_PER_MIB: usize = 1024 * 1024;

/// Default upper bound, in bytes, on the size of a single context file (1 MiB).
pub const DEFAULT_MAX_FILE_SIZE: usize = BYTES_PER_MIB;

/// Default upper bound, in bytes, on the combined size of all loaded context (10 MiB).
pub const DEFAULT_MAX_TOTAL_SIZE: usize = 10 * BYTES_PER_MIB;
/// One declared origin of context text, as consumed by `resolve_context`.
#[derive(Debug, Clone)]
pub enum ContextSource {
    /// Literal text supplied inline; nothing is read from disk.
    Content {
        /// The context text itself.
        content: String,
    },
    /// A single file on disk.
    File {
        /// Path to load. `~`, `$HOME` and `$CWD` are expanded before the file is opened.
        path: String,
        /// When `true`, a missing file aborts resolution with `FileNotFound`;
        /// when `false`, the expanded path is recorded as skipped instead.
        required: bool,
    },
    /// Several files, loaded in the order listed.
    Files {
        /// Paths to load; each one is expanded the same way as `File::path`.
        paths: Vec<String>,
        /// Applies to every entry in `paths`, with the same meaning as `File::required`.
        required: bool,
    },
    /// Every file matching a glob pattern.
    Glob {
        /// Pattern to match after variable expansion. Only matches for which
        /// `Path::is_file` is true are loaded, in sorted path order.
        pattern: String,
    },
}
/// A single piece of context text after resolution.
#[derive(Debug, Clone)]
pub struct ResolvedContext {
    /// Where the entry came from as declared: the literal `"inline content"` for
    /// inline text, the unexpanded path for file sources, or the unexpanded
    /// pattern for glob sources.
    pub source: String,
    /// The path that was actually read, or `None` for inline content.
    pub resolved_path: Option<PathBuf>,
    /// The context text.
    pub content: String,
}
/// Outcome of resolving a list of context sources.
#[derive(Debug, Clone, Default)]
pub struct ContextLoadResult {
    /// Loaded entries, in source declaration order (glob matches in sorted path order).
    pub files: Vec<ResolvedContext>,
    /// Expanded paths of optional files that did not exist and were therefore skipped.
    pub skipped: Vec<String>,
    /// Combined size, in bytes, of all inline content and loaded files.
    pub total_bytes: usize,
}
/// Size limits applied while loading context.
#[derive(Debug, Clone)]
pub struct ContextConfig {
    /// Largest size, in bytes, accepted for any single context file.
    pub max_file_size: usize,
    /// Largest combined size, in bytes, accepted across all context entries.
    pub max_total_size: usize,
}

impl Default for ContextConfig {
    /// Builds a configuration from the module-level default limits.
    fn default() -> Self {
        let max_file_size = DEFAULT_MAX_FILE_SIZE;
        let max_total_size = DEFAULT_MAX_TOTAL_SIZE;
        ContextConfig {
            max_file_size,
            max_total_size,
        }
    }
}
/// Values substituted into context paths before they are opened.
#[derive(Debug, Clone)]
pub struct PathVariables {
    /// Replacement for `$CWD`.
    pub cwd: PathBuf,
    /// Replacement for `$HOME` and for a leading `~`.
    pub home: PathBuf,
}

impl PathVariables {
    /// Captures the process's current working directory and the user's home directory.
    ///
    /// A value that cannot be determined is stored as an empty path rather than
    /// reported as an error.
    pub fn current() -> Self {
        let cwd = std::env::current_dir().unwrap_or_else(|_| PathBuf::new());
        let home = dirs::home_dir().unwrap_or_else(PathBuf::new);
        PathVariables { cwd, home }
    }
}
/// Errors produced while resolving and loading context.
#[derive(Debug, Error)]
pub enum ContextError {
    /// A source marked `required` points at a path that does not exist.
    /// Carries the path after variable expansion.
    #[error("required context file not found: {0}")]
    FileNotFound(String),
    /// The file's bytes could not be decoded as UTF-8.
    #[error("context file is not valid UTF-8: {path}")]
    InvalidUtf8 {
        /// Display form of the offending path.
        path: String,
    },
    /// A single file is larger than the per-file limit.
    #[error("context file exceeds size limit ({size} bytes > {limit} bytes): {path}")]
    FileTooLarge {
        /// Display form of the offending path.
        path: String,
        /// Size of the file, in bytes.
        size: usize,
        /// The configured per-file limit, in bytes.
        limit: usize,
    },
    /// Adding an entry would push the running total over the overall limit.
    #[error("total context size exceeds limit ({size} bytes > {limit} bytes)")]
    TotalSizeTooLarge {
        /// The total, in bytes, that would have been reached.
        size: usize,
        /// The configured overall limit, in bytes.
        limit: usize,
    },
    /// Reading the file's metadata or contents failed for a reason other than
    /// invalid UTF-8.
    #[error("failed to read context file {path}: {message}")]
    IoError {
        /// Display form of the path being read.
        path: String,
        /// Text of the underlying I/O error.
        message: String,
    },
    /// The glob pattern could not be parsed. Carries the parser's error text.
    #[error("invalid glob pattern: {0}")]
    InvalidPattern(String),
}
/// Expands `~`, `$HOME` and `$CWD` in `path` using the values in `vars`.
///
/// * A path that is exactly `~`, or that starts with `~/`, has the leading `~`
///   replaced by the home directory; no further substitution is applied.
/// * Otherwise every occurrence of `$HOME` and `$CWD` is replaced by the home
///   and working directory respectively.
/// * Anything else (including plain relative paths) is returned unchanged.
///
/// Directories that are not valid UTF-8 are converted lossily instead of being
/// treated as empty. When a directory is empty (i.e. unknown), the token that
/// refers to it is left untouched: substituting an empty string would silently
/// turn `~/notes.md` or `$HOME/notes.md` into `/notes.md` at the filesystem
/// root, which is never what the caller meant.
fn expand_path(path: &str, vars: &PathVariables) -> String {
    let home = vars.home.to_string_lossy();
    let cwd = vars.cwd.to_string_lossy();

    // `~` only counts as the home directory when it is the whole path or is
    // immediately followed by a separator; `~user/...` is not handled here.
    if let Some(rest) = path.strip_prefix('~') {
        if rest.is_empty() || rest.starts_with('/') {
            return if home.is_empty() {
                path.to_string()
            } else {
                format!("{}{}", home, rest)
            };
        }
    }

    let mut expanded = path.to_string();
    if !home.is_empty() {
        expanded = expanded.replace("$HOME", &home);
    }
    if !cwd.is_empty() {
        expanded = expanded.replace("$CWD", &cwd);
    }
    expanded
}
fn load_file(
path: &Path,
config: &ContextConfig,
total_bytes: &mut usize,
) -> Result<String, ContextError> {
let metadata = std::fs::metadata(path).map_err(|e| ContextError::IoError {
path: path.display().to_string(),
message: e.to_string(),
})?;
let size = metadata.len() as usize;
if size > config.max_file_size {
return Err(ContextError::FileTooLarge {
path: path.display().to_string(),
size,
limit: config.max_file_size,
});
}
if *total_bytes + size > config.max_total_size {
return Err(ContextError::TotalSizeTooLarge {
size: *total_bytes + size,
limit: config.max_total_size,
});
}
let content = std::fs::read_to_string(path).map_err(|e| {
if e.kind() == std::io::ErrorKind::InvalidData {
ContextError::InvalidUtf8 {
path: path.display().to_string(),
}
} else {
ContextError::IoError {
path: path.display().to_string(),
message: e.to_string(),
}
}
})?;
*total_bytes += size;
Ok(content)
}
pub fn resolve_context(
sources: &[ContextSource],
vars: &PathVariables,
config: &ContextConfig,
) -> Result<ContextLoadResult, ContextError> {
let mut files = Vec::new();
let mut skipped = Vec::new();
let mut total_bytes = 0usize;
for source in sources {
match source {
ContextSource::Content { content } => {
let size = content.len();
if total_bytes + size > config.max_total_size {
return Err(ContextError::TotalSizeTooLarge {
size: total_bytes + size,
limit: config.max_total_size,
});
}
total_bytes += size;
files.push(ResolvedContext {
source: "inline content".to_string(),
resolved_path: None,
content: content.clone(),
});
}
ContextSource::File { path, required } => {
let expanded = expand_path(path, vars);
let resolved = PathBuf::from(&expanded);
if !resolved.exists() {
if *required {
return Err(ContextError::FileNotFound(expanded));
}
skipped.push(expanded);
continue;
}
let content = load_file(&resolved, config, &mut total_bytes)?;
files.push(ResolvedContext {
source: path.clone(),
resolved_path: Some(resolved),
content,
});
}
ContextSource::Files { paths, required } => {
for path in paths {
let expanded = expand_path(path, vars);
let resolved = PathBuf::from(&expanded);
if !resolved.exists() {
if *required {
return Err(ContextError::FileNotFound(expanded));
}
skipped.push(expanded);
continue;
}
let content = load_file(&resolved, config, &mut total_bytes)?;
files.push(ResolvedContext {
source: path.clone(),
resolved_path: Some(resolved),
content,
});
}
}
ContextSource::Glob { pattern } => {
let expanded = expand_path(pattern, vars);
let matches = glob::glob(&expanded)
.map_err(|e| ContextError::InvalidPattern(e.to_string()))?;
let mut pattern_files: Vec<PathBuf> = matches
.filter_map(|r| r.ok())
.filter(|p| p.is_file())
.collect();
pattern_files.sort();
for resolved in pattern_files {
let content = load_file(&resolved, config, &mut total_bytes)?;
files.push(ResolvedContext {
source: pattern.clone(),
resolved_path: Some(resolved),
content,
});
}
}
}
}
Ok(ContextLoadResult {
files,
skipped,
total_bytes,
})
}
/// Combines an optional system prompt with all loaded context into one prompt.
///
/// The system prompt, when present, comes first. Each context entry follows as
/// its own section: a `---` separator, an HTML comment naming the file it came
/// from (or marking it as inline), then the entry's text. Sections are joined
/// with a newline.
///
/// Returns `None` when there is neither a system prompt nor any context entry.
pub fn build_effective_prompt(
    system_prompt: Option<&str>,
    context: &ContextLoadResult,
) -> Option<String> {
    let context_sections = context.files.iter().map(|entry| {
        let header = entry.resolved_path.as_ref().map_or_else(
            || "<!-- Inline context -->".to_string(),
            |origin| format!("<!-- Context from: {} -->", origin.display()),
        );
        format!("\n---\n{}\n{}", header, entry.content)
    });

    let sections: Vec<String> = system_prompt
        .map(String::from)
        .into_iter()
        .chain(context_sections)
        .collect();

    if sections.is_empty() {
        return None;
    }
    Some(sections.join("\n"))
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::TempDir;

    // ---- shared fixtures -------------------------------------------------

    /// Fixed variables so the expansion tests do not depend on the host environment.
    fn fixed_vars() -> PathVariables {
        PathVariables {
            cwd: PathBuf::from("/workspace"),
            home: PathBuf::from("/home/user"),
        }
    }

    /// Resolves `sources` with the process's real variables and the given limits.
    fn resolve_with(
        sources: &[ContextSource],
        config: &ContextConfig,
    ) -> Result<ContextLoadResult, ContextError> {
        resolve_context(sources, &PathVariables::current(), config)
    }

    /// Resolves `sources` with the process's real variables and the default limits.
    fn resolve(sources: &[ContextSource]) -> Result<ContextLoadResult, ContextError> {
        resolve_with(sources, &ContextConfig::default())
    }

    /// Returns the path of `name` inside `dir` as a string, without creating the file.
    fn path_in(dir: &TempDir, name: &str) -> String {
        dir.path().join(name).to_str().unwrap().to_string()
    }

    /// Writes `content` to `name` inside `dir` and returns the file's path as a string.
    fn write_in(dir: &TempDir, name: &str, content: &str) -> String {
        let target = dir.path().join(name);
        fs::write(&target, content).unwrap();
        target.to_str().unwrap().to_string()
    }

    /// A glob source matching every `.md` file directly inside `dir`.
    fn markdown_glob(dir: &TempDir) -> ContextSource {
        ContextSource::Glob {
            pattern: format!("{}/*.md", dir.path().display()),
        }
    }

    /// A load result holding exactly one entry.
    fn single_entry(source: &str, resolved_path: Option<&str>, content: &str) -> ContextLoadResult {
        ContextLoadResult {
            files: vec![ResolvedContext {
                source: source.to_string(),
                resolved_path: resolved_path.map(PathBuf::from),
                content: content.to_string(),
            }],
            skipped: vec![],
            total_bytes: content.len(),
        }
    }

    // ---- expand_path -----------------------------------------------------

    #[test]
    fn test_expand_path_cwd() {
        let expanded = expand_path("$CWD/AGENTS.md", &fixed_vars());
        assert_eq!(expanded, "/workspace/AGENTS.md");
    }

    #[test]
    fn test_expand_path_home_var() {
        let expanded = expand_path("$HOME/.config/agent.md", &fixed_vars());
        assert_eq!(expanded, "/home/user/.config/agent.md");
    }

    #[test]
    fn test_expand_path_tilde() {
        let expanded = expand_path("~/.config/agent.md", &fixed_vars());
        assert_eq!(expanded, "/home/user/.config/agent.md");
    }

    #[test]
    fn test_expand_path_tilde_alone() {
        let expanded = expand_path("~", &fixed_vars());
        assert_eq!(expanded, "/home/user");
    }

    #[test]
    fn test_expand_path_relative() {
        let expanded = expand_path("AGENTS.md", &fixed_vars());
        assert_eq!(expanded, "AGENTS.md");
    }

    // ---- resolve_context -------------------------------------------------

    #[test]
    fn test_resolve_context_content() {
        let sources = [ContextSource::Content {
            content: "# Rules\nBe helpful.".to_string(),
        }];

        let loaded = resolve(&sources).unwrap();

        assert_eq!(loaded.files.len(), 1);
        let only = &loaded.files[0];
        assert_eq!(only.content, "# Rules\nBe helpful.");
        assert!(only.resolved_path.is_none());
        assert_eq!(only.source, "inline content");
    }

    #[test]
    fn test_resolve_context_single_file() {
        let temp = TempDir::new().unwrap();
        let path = write_in(&temp, "AGENTS.md", "# Agent Instructions\nBe helpful.");
        let sources = [ContextSource::File {
            path,
            required: true,
        }];

        let loaded = resolve(&sources).unwrap();

        assert_eq!(loaded.files.len(), 1);
        assert_eq!(loaded.files[0].content, "# Agent Instructions\nBe helpful.");
        assert!(loaded.skipped.is_empty());
    }

    #[test]
    fn test_resolve_context_optional_missing() {
        let sources = [ContextSource::File {
            path: "/nonexistent/file.md".to_string(),
            required: false,
        }];

        let loaded = resolve(&sources).unwrap();

        assert!(loaded.files.is_empty());
        assert_eq!(loaded.skipped.len(), 1);
        assert_eq!(loaded.skipped[0], "/nonexistent/file.md");
    }

    #[test]
    fn test_resolve_context_required_missing() {
        let sources = [ContextSource::File {
            path: "/nonexistent/file.md".to_string(),
            required: true,
        }];

        let outcome = resolve(&sources);

        assert!(matches!(outcome, Err(ContextError::FileNotFound(_))));
    }

    #[test]
    fn test_resolve_context_files_all_exist() {
        let temp = TempDir::new().unwrap();
        let first = write_in(&temp, "a.md", "File A");
        let second = write_in(&temp, "b.md", "File B");
        let sources = [ContextSource::Files {
            paths: vec![first, second],
            required: true,
        }];

        let loaded = resolve(&sources).unwrap();

        assert_eq!(loaded.files.len(), 2);
        assert_eq!(loaded.files[0].content, "File A");
        assert_eq!(loaded.files[1].content, "File B");
    }

    #[test]
    fn test_resolve_context_files_required_one_missing() {
        let temp = TempDir::new().unwrap();
        let present = write_in(&temp, "a.md", "File A");
        let absent = path_in(&temp, "missing.md");
        let sources = [ContextSource::Files {
            paths: vec![present, absent],
            required: true,
        }];

        let outcome = resolve(&sources);

        assert!(matches!(outcome, Err(ContextError::FileNotFound(_))));
    }

    #[test]
    fn test_resolve_context_files_optional_one_missing() {
        let temp = TempDir::new().unwrap();
        let present = write_in(&temp, "a.md", "File A");
        let absent = path_in(&temp, "missing.md");
        let sources = [ContextSource::Files {
            paths: vec![present, absent],
            required: false,
        }];

        let loaded = resolve(&sources).unwrap();

        assert_eq!(loaded.files.len(), 1);
        assert_eq!(loaded.files[0].content, "File A");
        assert_eq!(loaded.skipped.len(), 1);
    }

    #[test]
    fn test_resolve_context_glob() {
        let temp = TempDir::new().unwrap();
        write_in(&temp, "a.md", "File A");
        write_in(&temp, "b.md", "File B");
        write_in(&temp, "c.txt", "Not markdown");
        let sources = [markdown_glob(&temp)];

        let loaded = resolve(&sources).unwrap();

        assert_eq!(loaded.files.len(), 2);
        let first_path = loaded.files[0].resolved_path.as_ref().unwrap();
        let second_path = loaded.files[1].resolved_path.as_ref().unwrap();
        assert!(first_path.ends_with("a.md"));
        assert!(second_path.ends_with("b.md"));
    }

    #[test]
    fn test_resolve_context_glob_no_matches() {
        let temp = TempDir::new().unwrap();
        let sources = [markdown_glob(&temp)];

        let loaded = resolve(&sources).unwrap();

        assert!(loaded.files.is_empty());
    }

    #[test]
    fn test_resolve_context_file_too_large() {
        let temp = TempDir::new().unwrap();
        let path = write_in(&temp, "large.md", &"x".repeat(1000));
        let sources = [ContextSource::File {
            path,
            required: true,
        }];
        // The per-file limit is far below the 1000 bytes just written.
        let config = ContextConfig {
            max_file_size: 100,
            max_total_size: DEFAULT_MAX_TOTAL_SIZE,
        };

        let outcome = resolve_with(&sources, &config);

        assert!(matches!(outcome, Err(ContextError::FileTooLarge { .. })));
    }

    #[test]
    fn test_resolve_context_total_too_large() {
        let temp = TempDir::new().unwrap();
        write_in(&temp, "a.md", &"x".repeat(60));
        write_in(&temp, "b.md", &"x".repeat(60));
        let sources = [markdown_glob(&temp)];
        // Each file fits on its own, but the two together exceed the total.
        let config = ContextConfig {
            max_file_size: 100,
            max_total_size: 100,
        };

        let outcome = resolve_with(&sources, &config);

        assert!(matches!(
            outcome,
            Err(ContextError::TotalSizeTooLarge { .. })
        ));
    }

    #[test]
    fn test_resolve_context_declaration_order() {
        let temp = TempDir::new().unwrap();
        let first = write_in(&temp, "first.md", "First");
        let second = write_in(&temp, "second.md", "Second");
        // Declared in the opposite order to how the files were created.
        let sources = [
            ContextSource::File {
                path: second,
                required: true,
            },
            ContextSource::File {
                path: first,
                required: true,
            },
        ];

        let loaded = resolve(&sources).unwrap();

        assert_eq!(loaded.files.len(), 2);
        assert_eq!(loaded.files[0].content, "Second");
        assert_eq!(loaded.files[1].content, "First");
    }

    // ---- build_effective_prompt -----------------------------------------

    #[test]
    fn test_build_effective_prompt_system_only() {
        let context = ContextLoadResult::default();

        let prompt = build_effective_prompt(Some("You are helpful."), &context);

        assert_eq!(prompt, Some("You are helpful.".to_string()));
    }

    #[test]
    fn test_build_effective_prompt_context_only() {
        let context = single_entry("test.md", Some("/path/to/test.md"), "Context content");

        let built = build_effective_prompt(None, &context);

        assert!(built.is_some());
        let prompt = built.unwrap();
        assert!(prompt.contains("Context content"));
        assert!(prompt.contains("/path/to/test.md"));
    }

    #[test]
    fn test_build_effective_prompt_inline_content() {
        let context = single_entry("inline content", None, "Inline rules");

        let built = build_effective_prompt(None, &context);

        assert!(built.is_some());
        let prompt = built.unwrap();
        assert!(prompt.contains("Inline rules"));
        assert!(prompt.contains("Inline context"));
    }

    #[test]
    fn test_build_effective_prompt_combined() {
        let context = single_entry("test.md", Some("/path/to/test.md"), "Context content");

        let built = build_effective_prompt(Some("System prompt"), &context);

        assert!(built.is_some());
        let prompt = built.unwrap();
        assert!(prompt.starts_with("System prompt"));
        assert!(prompt.contains("Context content"));
    }

    #[test]
    fn test_build_effective_prompt_empty() {
        let context = ContextLoadResult::default();

        let built = build_effective_prompt(None, &context);

        assert!(built.is_none());
    }

    // ---- ContextConfig ---------------------------------------------------

    #[test]
    fn test_context_config_default() {
        let config = ContextConfig::default();

        assert_eq!(config.max_file_size, DEFAULT_MAX_FILE_SIZE);
        assert_eq!(config.max_total_size, DEFAULT_MAX_TOTAL_SIZE);
    }
}