extern crate proc_macro;
use std::path::{Path, PathBuf};
use proc_macro::TokenStream;
use quote::quote;
use syn::{Error, LitStr, parse_macro_input};
#[proc_macro]
pub fn include_absolute_path(input: TokenStream) -> TokenStream {
let lit_str = parse_macro_input!(input as LitStr);
let path = lit_str.value();
let span = lit_str.span();
let caller_file_str = proc_macro::Span::call_site()
.local_file()
.unwrap_or_else(|| {
panic!(
"Failed to get the source file location. \
This should not happen on stable Rust."
)
});
let caller_file = Path::new(&caller_file_str);
let expanded_path = match shellexpand::env(&path) {
Ok(expanded) => expanded,
Err(e) => panic!(
"Failed to expand environment variable in path '{path}': {e}. \
Make sure the environment variable exists and is valid."
),
};
let path_buf = PathBuf::from(expanded_path.as_ref());
if contains_suspicious_patterns(&path_buf) {
return Error::new(
span,
format!(
"Path '{path}' contains suspicious traversal patterns. \
For security reasons, paths with excessive '..' segments are not allowed."
),
)
.to_compile_error()
.into();
}
let raw_path = if path_buf.is_absolute() {
path_buf
} else {
let parent = caller_file.parent().unwrap_or_else(|| {
panic!(
"Failed to get parent directory of the source file '{}'. \
The file appears to be in the root directory.",
caller_file.display()
)
});
parent.join(&path_buf)
};
let absolute_path = match raw_path.canonicalize() {
Ok(path) => path,
Err(e) => {
let cwd = std::env::current_dir()
.map(|p| p.display().to_string())
.unwrap_or_else(|_| "<unknown>".to_string());
panic!(
"Failed to resolve path '{}': {e}. \
Make sure the file or directory exists and is accessible. \
Current working directory: {cwd}",
raw_path.display()
)
}
};
let absolute_path_str = absolute_path.to_str().unwrap_or_else(|| {
panic!(
"Path '{}' contains invalid UTF-8 characters. \
This is common on systems with non-UTF-8 file paths. \
Consider using ASCII-only paths.",
absolute_path.display()
)
});
let expanded = quote! {
#absolute_path_str
};
TokenStream::from(expanded)
}
/// Heuristically flags a path that climbs out of its starting directory too
/// aggressively.
///
/// Components are counted as yielded by `Path::components`, so a root (`/`)
/// counts as a component too.
///
/// Returns `true` when either:
/// - the path has more than three `..` components; or
/// - the number of `..` components exceeds half the total component count
///   (rounded down).
///
/// For example, `..` and `../../x` are flagged, while `../x` and `/a/../b`
/// are not.
fn contains_suspicious_patterns(path: &Path) -> bool {
    let (parent_hops, component_count) =
        path.components()
            .fold((0usize, 0usize), |(hops, count), component| {
                let hop = usize::from(matches!(component, std::path::Component::ParentDir));
                (hops + hop, count + 1)
            });

    if parent_hops > 3 {
        return true;
    }
    component_count != 0 && parent_hops > component_count / 2
}