use anyhow::{Context, Result};
use std::collections::HashSet;
use std::path::{Path, PathBuf};
mod resolver;
mod validation;
pub use resolver::{parse_include_line, resolve_include_pattern};
#[allow(unused_imports)]
pub use validation::{validate_glob_pattern, validate_include_path};
/// Maximum nesting depth of `Include` directives: entering another file is refused once this
/// many included files are open at the same time.
const MAX_INCLUDE_DEPTH: usize = 10;
/// Maximum number of files that may be entered over one whole resolution (a running total
/// that is never decremented, not a per-level limit).
const MAX_INCLUDED_FILES: usize = 100;
/// Mutable state threaded through one run of `Include` expansion.
///
/// Tracks the current nesting depth and the total number of files entered (bounded by
/// `MAX_INCLUDE_DEPTH` / `MAX_INCLUDED_FILES`), the set of files already processed (used to
/// reject cycles), and the directory currently stored in `base_dir`.
#[derive(Debug, Clone)]
pub struct IncludeContext {
/// Current include nesting depth: incremented by `enter_include`, decremented by `exit_include`.
depth: usize,
/// Canonical paths (lossy UTF-8 strings) of every file entered so far, including the root
/// config. Nothing in this file removes entries, so entering the same file a second time is
/// reported as a cycle even when it is reached through an unrelated `Include`.
visited: HashSet<String>,
/// Total number of files entered during this resolution; never decremented.
file_count: usize,
/// Starts as the directory containing the root config (or `/`). `enter_include` replaces it
/// with the parent of each entered file and joins non-existent relative paths onto it.
/// NOTE(review): presumably also consulted by `resolve_include_pattern` for relative
/// patterns — confirm in the resolver module.
pub base_dir: PathBuf,
/// Memoized `path -> path.canonicalize()` for paths that existed when first seen; cleared
/// wholesale once it grows past 100 entries.
canonical_cache: std::collections::HashMap<PathBuf, PathBuf>,
}
impl IncludeContext {
pub fn new(config_path: &Path) -> Self {
let base_dir = config_path
.parent()
.unwrap_or_else(|| Path::new("/"))
.to_path_buf();
Self {
depth: 0,
visited: HashSet::with_capacity(16), file_count: 0,
base_dir,
canonical_cache: std::collections::HashMap::with_capacity(16),
}
}
fn can_include(&self) -> Result<()> {
if self.depth >= MAX_INCLUDE_DEPTH {
anyhow::bail!(
"Maximum include depth ({MAX_INCLUDE_DEPTH}) exceeded. This may indicate an include cycle or misconfiguration."
);
}
if self.file_count >= MAX_INCLUDED_FILES {
anyhow::bail!(
"Maximum number of included files ({MAX_INCLUDED_FILES}) exceeded. This limit exists to prevent DoS attacks."
);
}
Ok(())
}
fn enter_include(&mut self, path: &Path) -> Result<()> {
self.can_include()?;
let canonical = if let Some(cached) = self.canonical_cache.get(path) {
cached.clone()
} else if path.exists() {
let canonical = path
.canonicalize()
.with_context(|| format!("Failed to canonicalize path: {}", path.display()))?;
self.canonical_cache
.insert(path.to_path_buf(), canonical.clone());
canonical
} else {
if path.is_absolute() {
path.to_path_buf()
} else {
self.base_dir.join(path)
}
};
let canonical_str = canonical.to_string_lossy().into_owned();
if self.visited.contains(&canonical_str) {
anyhow::bail!(
"Include cycle detected: {} has already been processed",
path.display()
);
}
self.visited.insert(canonical_str);
self.depth += 1;
self.file_count += 1;
if let Some(parent) = canonical.parent() {
self.base_dir = parent.to_path_buf();
}
if self.canonical_cache.len() > 100 {
self.canonical_cache.clear();
}
Ok(())
}
fn exit_include(&mut self) {
if self.depth > 0 {
self.depth -= 1;
}
}
}
/// A contiguous run of configuration text taken from a single source file.
///
/// `process_file_with_includes` produces these in reading order, and
/// `combine_included_files` joins them back into one string.
#[derive(Debug, Clone)]
pub struct IncludedFile {
/// File the text was taken from.
pub path: PathBuf,
/// Newline-terminated lines of `path` that are not `Include` directives. In the fallback case
/// where a file yields no other chunk, this is the file's full text verbatim.
pub content: String,
/// Line-position bookkeeping for this chunk; not read anywhere in this file (hence the
/// `dead_code` allowance).
#[allow(dead_code)]
pub line_offset: usize,
}
/// Expand every `Include` directive reachable from the configuration at `config_path`.
///
/// `content` is the already-loaded text of `config_path`. The root file is marked as visited
/// before expansion begins, so an included file that includes the root again is rejected as a
/// cycle rather than processed twice.
///
/// # Errors
///
/// Returns an error if `config_path` exists but cannot be canonicalized, or if expanding any
/// `Include` directive fails.
pub async fn resolve_includes(config_path: &Path, content: &str) -> Result<Vec<IncludedFile>> {
    // Key under which the root config is recorded in the visited set.
    let root_key = if !config_path.exists() {
        config_path.to_path_buf()
    } else {
        config_path.canonicalize().with_context(|| {
            format!(
                "Failed to canonicalize main config path: {}",
                config_path.display()
            )
        })?
    };

    let mut context = IncludeContext::new(config_path);
    context
        .visited
        .insert(root_key.to_string_lossy().into_owned());

    process_file_with_includes(config_path, content, &mut context).await
}
/// Recursively expand `Include` directives in `content`, the text of `file_path`.
///
/// Lines of `file_path` that are not `Include` directives are accumulated into chunks. When an
/// `Include` line is reached, the pending chunk is emitted. Then every file matched by each of
/// the directive's patterns is read and processed recursively, and its chunks are appended in
/// place. The returned order therefore mirrors the order in which the configuration text is
/// read.
///
/// Each chunk's `line_offset` is the zero-based index of the line in its own source file where
/// the chunk begins, so line `i` of a chunk's content is line `line_offset + i + 1` of its
/// `path`.
///
/// If nothing else is produced (e.g. `content` is empty, or its only lines are `Include`
/// directives that matched no files), a single chunk holding `content` verbatim is returned.
///
/// # Errors
///
/// Fails if a pattern cannot be resolved, if `enter_include` rejects a file (include limits or
/// an already-processed file), or if an included file cannot be read within five seconds.
async fn process_file_with_includes(
    file_path: &Path,
    content: &str,
    context: &mut IncludeContext,
) -> Result<Vec<IncludedFile>> {
    let mut result = Vec::new();
    let mut current_content = String::new();
    // Zero-based index, within `file_path`, of the first line held in `current_content`.
    let mut chunk_start = 0;

    for (index, line) in content.lines().enumerate() {
        let line_number = index + 1;
        let trimmed = line.trim();
        if let Some(patterns) = parse_include_line(trimmed) {
            // Emit the text preceding this directive so included chunks land after it.
            if !current_content.is_empty() {
                result.push(IncludedFile {
                    path: file_path.to_path_buf(),
                    content: std::mem::take(&mut current_content),
                    line_offset: chunk_start,
                });
            }
            for pattern in patterns {
                let resolved_files = resolve_include_pattern(pattern, context)
                    .await
                    .with_context(|| {
                        format!(
                            "Failed to resolve Include pattern '{}' at line {} in {}",
                            pattern,
                            line_number,
                            file_path.display()
                        )
                    })?;
                for include_path in resolved_files {
                    // `enter_include` re-points `base_dir` at the included file's directory.
                    // Keep this file's value so it can be put back afterwards; otherwise any
                    // later pattern in this file would be resolved relative to the directory
                    // of whichever file happened to be included last.
                    let saved_base_dir = context.base_dir.clone();
                    context.enter_include(&include_path).with_context(|| {
                        format!("Failed to include file: {}", include_path.display())
                    })?;
                    let include_content = tokio::time::timeout(
                        std::time::Duration::from_secs(5),
                        tokio::fs::read_to_string(&include_path),
                    )
                    .await
                    .map_err(|_| {
                        anyhow::anyhow!("Timeout reading include file: {}", include_path.display())
                    })?
                    .with_context(|| {
                        format!("Failed to read include file: {}", include_path.display())
                    })?;
                    // A recursive `async fn` call needs indirection, hence the `Box::pin`.
                    let mut included_files = Box::pin(process_file_with_includes(
                        &include_path,
                        &include_content,
                        context,
                    ))
                    .await?;
                    result.append(&mut included_files);
                    context.exit_include();
                    context.base_dir = saved_base_dir;
                }
            }
        } else {
            if current_content.is_empty() {
                chunk_start = index;
            }
            current_content.push_str(line);
            current_content.push('\n');
        }
    }

    if !current_content.is_empty() {
        result.push(IncludedFile {
            path: file_path.to_path_buf(),
            content: current_content,
            line_offset: chunk_start,
        });
    }
    if result.is_empty() {
        result.push(IncludedFile {
            path: file_path.to_path_buf(),
            content: content.to_string(),
            line_offset: 0,
        });
    }
    Ok(result)
}
/// Concatenate chunks back into a single configuration string.
///
/// Every chunk is preceded by a `# Source: <path>` comment line naming the file it came from,
/// and consecutive chunks are separated by one extra `\n`. An empty slice yields an empty string.
pub fn combine_included_files(files: &[IncludedFile]) -> String {
    files
        .iter()
        .map(|file| format!("# Source: {}\n{}", file.path.display(), file.content))
        .collect::<Vec<_>>()
        .join("\n")
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::TempDir;

    /// File name (final path component) of a chunk's source file.
    fn file_name_of(file: &IncludedFile) -> &str {
        file.path.file_name().unwrap().to_str().unwrap()
    }

    /// A config without `Include` lines comes back as a single chunk equal to the input.
    #[tokio::test]
    async fn test_resolve_includes_simple() {
        let temp_dir = TempDir::new().unwrap();
        let main_config = temp_dir.path().join("config");
        let main_content = "Host example.com\n User mainuser\n";
        fs::write(&main_config, main_content).unwrap();

        let result = resolve_includes(&main_config, main_content).await.unwrap();

        assert_eq!(result.len(), 1);
        assert_eq!(result[0].path, main_config);
        assert_eq!(result[0].content, main_content);
    }

    /// An `Include` at the top splices the included file in before the rest of the main config.
    #[tokio::test]
    async fn test_resolve_includes_with_include() {
        let temp_dir = TempDir::new().unwrap();
        let include_dir = temp_dir.path().join("config.d");
        fs::create_dir(&include_dir).unwrap();
        let included_file = include_dir.join("extra.conf");
        let included_content = "Host included.com\n User includeduser\n";
        fs::write(&included_file, included_content).unwrap();

        let main_config = temp_dir.path().join("config");
        let main_content = format!(
            "Include {}\n\nHost example.com\n User mainuser\n",
            included_file.display()
        );
        fs::write(&main_config, &main_content).unwrap();

        let result = resolve_includes(&main_config, &main_content).await.unwrap();

        assert_eq!(result.len(), 2, "Should have 2 file chunks");
        assert_eq!(result[0].path, included_file, "First should be included file");
        assert_eq!(result[0].content, included_content);
        assert_eq!(
            result[1].path, main_config,
            "Second should be rest of main config"
        );
        assert!(
            result[1].content.contains("Host example.com"),
            "Should contain main config content"
        );
    }

    /// a.conf -> b.conf -> a.conf must be rejected rather than recursing.
    #[tokio::test]
    async fn test_include_cycle_detection() {
        let temp_dir = TempDir::new().unwrap();
        let file_a = temp_dir.path().join("a.conf");
        let file_b = temp_dir.path().join("b.conf");
        let content_a = format!("Include {}\n", file_b.display());
        let content_b = format!("Include {}\n", file_a.display());
        fs::write(&file_a, &content_a).unwrap();
        fs::write(&file_b, content_b).unwrap();

        let error = resolve_includes(&file_a, &content_a)
            .await
            .expect_err("an include cycle must be reported as an error");
        // `{:?}` on an anyhow error renders the full context chain, where the cycle message lives.
        let err_chain = format!("{error:?}");
        assert!(
            err_chain.contains("cycle")
                || err_chain.contains("already been processed")
                || err_chain.contains("Include cycle"),
            "Expected cycle detection in error chain but got: {err_chain}"
        );
    }

    /// A linear chain deeper than `MAX_INCLUDE_DEPTH` is rejected.
    #[tokio::test]
    async fn test_max_depth_limit() {
        let temp_dir = TempDir::new().unwrap();
        // level0 is a leaf; every later level includes the one before it.
        let mut prev_file = temp_dir.path().join("config");
        let mut prev_content = String::new();
        for level in 0..=MAX_INCLUDE_DEPTH + 1 {
            let file = temp_dir.path().join(format!("level{level}.conf"));
            let content = match level {
                0 => "Host start\n".to_string(),
                _ => format!("Include {}\n", prev_file.display()),
            };
            fs::write(&file, &content).unwrap();
            prev_file = file;
            prev_content = content;
        }

        let result = resolve_includes(&prev_file, &prev_content).await;

        assert!(result.is_err());
        let err_chain = format!("{:?}", result.unwrap_err());
        assert!(err_chain.contains("depth") || err_chain.contains("Maximum include depth"));
    }

    /// A glob pattern expands to every matching file, in lexical order.
    #[tokio::test]
    async fn test_glob_pattern_expansion() {
        let temp_dir = TempDir::new().unwrap();
        let config_dir = temp_dir.path().join("config.d");
        fs::create_dir(&config_dir).unwrap();
        for (name, body) in [
            ("01-first.conf", "Host first\n"),
            ("02-second.conf", "Host second\n"),
            ("03-third.conf", "Host third\n"),
        ] {
            fs::write(config_dir.join(name), body).unwrap();
        }
        let main_config = temp_dir.path().join("config");
        let main_content = format!("Include {}/*.conf\n", config_dir.display());
        fs::write(&main_config, &main_content).unwrap();

        let result = resolve_includes(&main_config, &main_content).await.unwrap();

        assert_eq!(result.len(), 3);
        let names: Vec<&str> = result.iter().map(file_name_of).collect();
        assert!(names[0].contains("01-first"));
        assert!(names[1].contains("02-second"));
        assert!(names[2].contains("03-third"));
    }

    /// Several patterns on one `Include` line are all expanded.
    #[tokio::test]
    async fn test_multiple_patterns_in_include() {
        let temp_dir = TempDir::new().unwrap();
        let dir1 = temp_dir.path().join("dir1");
        let dir2 = temp_dir.path().join("dir2");
        fs::create_dir(&dir1).unwrap();
        fs::create_dir(&dir2).unwrap();
        fs::write(dir1.join("config1.conf"), "Host host1\n").unwrap();
        fs::write(dir2.join("config2.conf"), "Host host2\n").unwrap();
        let main_config = temp_dir.path().join("config");
        let main_content = format!(
            "Include {} {}\n",
            dir1.join("*.conf").display(),
            dir2.join("*.conf").display()
        );
        fs::write(&main_config, &main_content).unwrap();

        let result = resolve_includes(&main_config, &main_content).await.unwrap();

        assert_eq!(result.len(), 2);
        assert!(result.iter().any(|f| f.content.contains("Host host1")));
        assert!(result.iter().any(|f| f.content.contains("Host host2")));
    }

    /// Including a path that does not exist contributes nothing and is not an error.
    #[tokio::test]
    async fn test_include_nonexistent_file_skipped() {
        let temp_dir = TempDir::new().unwrap();
        let main_config = temp_dir.path().join("config");
        let nonexistent_path = temp_dir.path().join("nonexistent.conf");
        let main_content = format!(
            "Include {}\nHost example.com\n User testuser\n",
            nonexistent_path.display()
        );
        fs::write(&main_config, &main_content).unwrap();

        let result = resolve_includes(&main_config, &main_content).await.unwrap();

        assert_eq!(result.len(), 1);
        assert!(result[0].content.contains("Host example.com"));
    }

    /// Included text appears exactly where its `Include` line was in the main config.
    #[tokio::test]
    async fn test_include_order_preservation() {
        let temp_dir = TempDir::new().unwrap();
        let include_dir = temp_dir.path().join("includes");
        fs::create_dir(&include_dir).unwrap();
        for (name, body) in [
            ("first.conf", "Host first\n Port 1111\n"),
            ("second.conf", "Host second\n Port 2222\n"),
            ("third.conf", "Host third\n Port 3333\n"),
        ] {
            fs::write(include_dir.join(name), body).unwrap();
        }
        let main_config = temp_dir.path().join("config");
        let main_content = format!(
            "Host start\n Port 9999\n\nInclude {}\n\nHost middle\n Port 5555\n\nInclude {}\n\nHost end\n Port 1\n",
            include_dir.join("first.conf").display(),
            include_dir.join("second.conf").display()
        );
        fs::write(&main_config, &main_content).unwrap();

        let result = resolve_includes(&main_config, &main_content).await.unwrap();
        let combined = combine_included_files(&result);
        let position = |needle: &str| combined.find(needle).unwrap();

        let start_pos = position("Host start");
        let first_pos = position("Host first");
        let middle_pos = position("Host middle");
        let second_pos = position("Host second");
        let end_pos = position("Host end");
        assert!(start_pos < first_pos, "start should come before first");
        assert!(first_pos < middle_pos, "first should come before middle");
        assert!(middle_pos < second_pos, "middle should come before second");
        assert!(second_pos < end_pos, "second should come before end");
    }

    /// A glob that matches nothing leaves the main config untouched.
    #[tokio::test]
    async fn test_empty_glob_pattern() {
        let temp_dir = TempDir::new().unwrap();
        let main_config = temp_dir.path().join("config");
        let main_content = format!(
            "Include {}\nHost example.com\n",
            temp_dir.path().join("nonexistent/*.conf").display()
        );
        fs::write(&main_config, &main_content).unwrap();

        let result = resolve_includes(&main_config, &main_content).await.unwrap();

        assert_eq!(result.len(), 1);
        assert!(result[0].content.contains("Host example.com"));
    }

    /// Combined output prefixes each chunk with a `# Source:` header, one newline between chunks.
    #[test]
    fn test_combine_included_files_headers() {
        let files = vec![
            IncludedFile {
                path: PathBuf::from("/a.conf"),
                content: "Host a\n".to_string(),
                line_offset: 0,
            },
            IncludedFile {
                path: PathBuf::from("/b.conf"),
                content: "Host b\n".to_string(),
                line_offset: 0,
            },
        ];
        assert_eq!(
            combine_included_files(&files),
            "# Source: /a.conf\nHost a\n\n# Source: /b.conf\nHost b\n"
        );
        assert_eq!(combine_included_files(&[]), "");
    }
}