use crate::error::IncludeError;
use crate::lex::{tokenize, ShellMode};
use crate::parse::{parse, Stmt};
use std::collections::HashSet;
use std::path::{Path, PathBuf};
/// Bookkeeping shared across `include` resolution.
///
/// `chain` holds the files currently being included and is used to report circular includes.
/// `seen` remembers files that were already processed, so they are not read a second time.
#[derive(Clone, Debug)]
pub struct IncludeContext {
    /// Canonicalized paths of the files currently being included, outermost first.
    pub chain: Vec<PathBuf>,
    /// Canonicalized paths of files already processed by `include_file`; re-including one
    /// of these yields no statements.
    pub seen: HashSet<PathBuf>,
}
impl IncludeContext {
pub fn new() -> Self {
IncludeContext {
chain: Vec::new(),
seen: HashSet::new(),
}
}
pub fn include_file(
&mut self,
path: &str,
base_dir: &Path,
) -> Result<Vec<Stmt>, IncludeError> {
let resolved = if path.starts_with('/') {
PathBuf::from(path)
} else {
base_dir.join(path)
};
let canonical = resolved.canonicalize().unwrap_or(resolved);
if self.chain.iter().any(|p| p == &canonical) {
let chain_str = self
.chain
.iter()
.map(|p| p.display().to_string())
.chain(std::iter::once(canonical.display().to_string()))
.collect::<Vec<_>>()
.join(" -> ");
return Err(IncludeError::CircularInclude { chain: chain_str });
}
if self.seen.contains(&canonical) {
return Ok(Vec::new());
}
let content = std::fs::read_to_string(&canonical).map_err(|e| {
if e.kind() == std::io::ErrorKind::NotFound {
IncludeError::FileNotFound {
path: path.to_string(),
}
} else {
IncludeError::Io(e)
}
})?;
self.chain.push(canonical.clone());
let result = (|| {
let tokens = tokenize(&content, ShellMode::Sh).map_err(|e| {
IncludeError::Io(std::io::Error::new(
std::io::ErrorKind::InvalidData,
format!("{}: {e}", canonical.display()),
))
})?;
parse(&tokens).map_err(|e| {
IncludeError::Io(std::io::Error::new(
std::io::ErrorKind::InvalidData,
format!("{}: {e}", canonical.display()),
))
})
})();
self.chain.pop();
self.seen.insert(canonical);
result
}
pub fn include_command(
&mut self,
command: &str,
base_dir: &Path,
) -> Result<Vec<Stmt>, IncludeError> {
let output = std::process::Command::new("sh")
.arg("-c")
.arg(command)
.current_dir(base_dir)
.output()
.map_err(IncludeError::Io)?;
if !output.status.success() {
return Err(IncludeError::CommandFailed {
command: command.to_string(),
});
}
let stdout = String::from_utf8_lossy(&output.stdout);
let tokens = tokenize(&stdout, ShellMode::Sh).map_err(|e| {
IncludeError::Io(std::io::Error::new(
std::io::ErrorKind::InvalidData,
e.to_string(),
))
})?;
parse(&tokens).map_err(|e| {
IncludeError::Io(std::io::Error::new(
std::io::ErrorKind::InvalidData,
e.to_string(),
))
})
}
}
impl Default for IncludeContext {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
// Unit tests for `IncludeContext`. Fixtures are written under the system temp
// directory. Every test uses a distinct file name (or its own sub-directory),
// so tests running in parallel do not overwrite each other's fixtures.
use super::*;
/// Writes `content` to `<temp_dir>/mk_test_include/<name>`, creating the
/// directory if needed, and returns the full path of the written file.
fn write_temp_mkfile(name: &str, content: &str) -> PathBuf {
let dir = std::env::temp_dir().join("mk_test_include");
std::fs::create_dir_all(&dir).unwrap();
let path = dir.join(name);
std::fs::write(&path, content).unwrap();
path
}
// A file with a single assignment line yields exactly one statement.
#[test]
fn include_simple_file() {
let included = write_temp_mkfile("common.mk", "CC = gcc\n");
let mut ctx = IncludeContext::new();
let stmts = ctx
.include_file(included.to_str().unwrap(), &std::env::temp_dir())
.unwrap();
assert_eq!(stmts.len(), 1);
}
// A relative path that does not exist is reported as FileNotFound.
#[test]
fn include_file_not_found() {
let mut ctx = IncludeContext::new();
let result = ctx.include_file("nonexistent.mk", &PathBuf::from("."));
assert!(matches!(result, Err(IncludeError::FileNotFound { .. })));
}
// Simulates already being inside circular.mk by pushing its canonical path
// onto the chain, then includes the same file again.
#[test]
fn circular_include_detected() {
let path = write_temp_mkfile("circular.mk", "CC = gcc\n");
let canonical = path.canonicalize().unwrap();
let dir = path.parent().unwrap().to_path_buf();
let mut ctx = IncludeContext::new();
ctx.chain.push(canonical);
let result = ctx.include_file(path.to_str().unwrap(), &dir);
assert!(matches!(result, Err(IncludeError::CircularInclude { .. })));
}
// The file pushed onto the chain during include_file is popped again on success.
#[test]
fn chain_cleared_after_successful_include() {
let path = write_temp_mkfile("chain_test.mk", "CC = gcc\n");
let dir = std::env::temp_dir().join("mk_test_include");
let mut ctx = IncludeContext::new();
ctx.include_file(path.to_str().unwrap(), &dir).unwrap();
assert!(ctx.chain.is_empty());
}
// The recipe contains an unterminated single quote, which is expected to make
// the include fail. The chain must still be popped on that error path.
#[test]
fn chain_cleaned_on_lex_error() {
let bad = write_temp_mkfile("bad_lex.mk", "TARGET: prereq\n\tcmd 'oops\n");
let dir = std::env::temp_dir().join("mk_test_include");
let mut ctx = IncludeContext::new();
let result = ctx.include_file(bad.to_str().unwrap(), &dir);
assert!(result.is_err());
assert!(ctx.chain.is_empty());
}
// Including the same file twice returns its statements only the first time,
// as would happen when two files both include a shared "D".
#[test]
fn diamond_include_skipped_on_second_encounter() {
let d = write_temp_mkfile("diamond_d.mk", "VAR = from_d\n");
let dir = std::env::temp_dir().join("mk_test_include");
let mut ctx = IncludeContext::new();
let stmts1 = ctx.include_file(d.to_str().unwrap(), &dir).unwrap();
assert_eq!(stmts1.len(), 1, "first include of D should return statement");
assert!(ctx.seen.len() == 1, "D should be in seen set");
let stmts2 = ctx.include_file(d.to_str().unwrap(), &dir).unwrap();
assert!(stmts2.is_empty(), "second include of D should be empty");
}
// After one successful include, the chain is empty and the file is recorded in `seen`.
#[test]
fn diamond_include_chain_cleared() {
let d = write_temp_mkfile("diamond_chain_d.mk", "VAR = d_val\n");
let dir = std::env::temp_dir().join("mk_test_include");
let mut ctx = IncludeContext::new();
ctx.include_file(d.to_str().unwrap(), &dir).unwrap();
assert!(ctx.chain.is_empty(), "chain should be empty after include");
assert!(ctx.seen.len() == 1, "seen set should contain D");
}
// An absolute path is used as-is; the (nonexistent) base directory is ignored.
#[test]
fn absolute_path() {
let path = write_temp_mkfile("absolute_test.mk", "TARGET = foo\n");
let mut ctx = IncludeContext::new();
let stmts = ctx
.include_file(path.to_str().unwrap(), &PathBuf::from("/unused"))
.unwrap();
assert_eq!(stmts.len(), 1);
}
// An empty file parses to no statements.
#[test]
fn include_empty_file() {
let path = write_temp_mkfile("empty.mk", "");
let mut ctx = IncludeContext::new();
let stmts = ctx
.include_file(path.to_str().unwrap(), &std::env::temp_dir())
.unwrap();
assert!(stmts.is_empty());
}
// A rule with a tab-indented recipe line parses into one Stmt::Rule with
// its targets, prereqs and recipe text.
#[test]
fn include_with_rule_and_recipe() {
let path = write_temp_mkfile("recipe_test.mk", "target: prereq\n\techo hello\n");
let mut ctx = IncludeContext::new();
let stmts = ctx
.include_file(path.to_str().unwrap(), &std::env::temp_dir())
.unwrap();
assert_eq!(stmts.len(), 1);
match &stmts[0] {
Stmt::Rule(r) => {
assert_eq!(r.targets, vec!["target"]);
assert_eq!(r.prereqs, vec!["prereq"]);
assert_eq!(r.recipe, Some("echo hello".into()));
}
_ => panic!("expected Rule"),
}
}
// Two assignments plus one rule (with recipe) give three statements; the
// blank line between them is expected to produce none.
#[test]
fn include_with_multiple_statements() {
let path = write_temp_mkfile(
"multi.mk",
"CC = gcc\nCFLAGS = -Wall\n\nprog: main.o\n\t$(CC) -o $target $prereq\n",
);
let mut ctx = IncludeContext::new();
let stmts = ctx
.include_file(path.to_str().unwrap(), &std::env::temp_dir())
.unwrap();
assert_eq!(stmts.len(), 3);
}
// A relative include path ("sub/child.mk") is resolved against base_dir.
#[test]
fn relative_path_resolution() {
let parent_dir = std::env::temp_dir().join("mk_test_parent");
let child_dir = parent_dir.join("sub");
std::fs::create_dir_all(&child_dir).unwrap();
let sub_path = child_dir.join("child.mk");
std::fs::write(&sub_path, "VAR = child_value\n").unwrap();
let mut ctx = IncludeContext::new();
let stmts = ctx.include_file("sub/child.mk", &parent_dir).unwrap();
assert_eq!(stmts.len(), 1);
}
// Default starts with an empty chain.
#[test]
fn include_context_default() {
let ctx = IncludeContext::default();
assert!(ctx.chain.is_empty());
}
// With chain [a, b], including a again reports a cycle. The message
// mentions both files and joins them with " -> ".
#[test]
fn circular_include_chain_message() {
let dir = std::env::temp_dir().join("mk_test_chain_msg");
std::fs::create_dir_all(&dir).unwrap();
let a_path = dir.join("a.mk");
let b_path = dir.join("b.mk");
std::fs::write(&a_path, "CC = gcc\n").unwrap();
std::fs::write(&b_path, "CXX = g++\n").unwrap();
let mut ctx = IncludeContext::new();
let canonical_a = a_path.canonicalize().unwrap();
let canonical_b = b_path.canonicalize().unwrap();
ctx.chain.push(canonical_a.clone());
ctx.chain.push(canonical_b.clone());
let result = ctx.include_file(a_path.to_str().unwrap(), &dir);
match result {
Err(IncludeError::CircularInclude { chain }) => {
assert!(chain.contains("a.mk"));
assert!(chain.contains("b.mk"));
assert!(chain.contains(" -> "));
}
other => panic!("expected CircularInclude, got {other:?}"),
}
}
// The command runs under `sh -c`; its stdout is parsed as an assignment.
#[test]
fn include_command_simple() {
let mut ctx = IncludeContext::new();
let stmts = ctx
.include_command("echo 'TARGET = value'", &std::env::current_dir().unwrap())
.unwrap();
assert_eq!(stmts.len(), 1);
match &stmts[0] {
Stmt::Assign(a) => {
assert_eq!(a.name, "TARGET");
assert_eq!(a.value, "value");
}
_ => panic!("expected Assign"),
}
}
// A non-zero exit status is reported as CommandFailed.
#[test]
fn include_command_failed() {
let mut ctx = IncludeContext::new();
let result =
ctx.include_command("exit 1", &std::env::current_dir().unwrap());
assert!(matches!(result, Err(IncludeError::CommandFailed { .. })));
}
// The Rust literal embeds a real newline and tab inside the single-quoted
// printf argument, so the command emits a rule line followed by a
// tab-indented recipe line.
#[test]
fn include_command_rule_with_recipe() {
let mut ctx = IncludeContext::new();
let stmts = ctx
.include_command(
"printf 'target: prereq\n\techo hello\n'",
&std::env::current_dir().unwrap(),
)
.unwrap();
assert_eq!(stmts.len(), 1);
match &stmts[0] {
Stmt::Rule(r) => {
assert_eq!(r.targets, vec!["target"]);
assert_eq!(r.prereqs, vec!["prereq"]);
assert_eq!(r.recipe, Some("echo hello".into()));
}
_ => panic!("expected Rule"),
}
}
}