#[cfg(feature = "include")]
use crate::input_source::{
IncludeRequest, IncludeResolveError, InputSource, ResolveProblem, ResolvedInclude,
};
#[cfg(feature = "include")]
use encoding_rs_io::DecodeReaderBytesBuilder;
#[cfg(feature = "include")]
use std::fs;
#[cfg(feature = "include")]
use std::io::{self, Read};
#[cfg(feature = "include")]
use std::path::{Component, Path, PathBuf};
#[cfg(feature = "include")]
/// How [`SafeFileResolver`] hands the contents of a resolved file back to the caller.
///
/// Includes that carry a `#fragment` are always read into memory, whichever mode is chosen.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum SafeFileReadMode {
    /// Open the file and hand it over as a reader (the default).
    #[default]
    Reader,
    /// Read the whole file into a string first, letting a byte-order mark pick the decoding.
    Text,
}
#[cfg(feature = "include")]
/// Whether [`SafeFileResolver`] may traverse symbolic links while resolving an include.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum SymlinkPolicy {
    /// Follow links, but only accept targets whose canonical path stays inside the allow root.
    FollowWithinRoot,
    /// Refuse any include whose spec path walks through a symlink (the default).
    #[default]
    Reject,
}
#[cfg(feature = "include")]
/// Include resolver that serves YAML files from a single, sandboxed directory tree.
///
/// Every resolved path is canonicalized and must stay under `allow_root`. Hidden segments,
/// absolute specs and files without a `.yml`/`.yaml` name are refused. Build one with
/// [`SafeFileResolver::new`] and refine it with the `with_*` methods.
#[derive(Debug, Clone)]
pub struct SafeFileResolver {
    /// Canonical directory that every resolved file must live under.
    allow_root: PathBuf,
    /// Directory that includes issued by the root document are resolved against.
    root_base_dir: PathBuf,
    /// Canonical path of the root document, when known, so it cannot include itself.
    root_source_id: Option<String>,
    /// How resolved file contents are handed back.
    read_mode: SafeFileReadMode,
    /// Whether symlinks may be traversed while resolving.
    symlink_policy: SymlinkPolicy,
}
#[cfg(feature = "include")]
impl SafeFileResolver {
pub fn new<P>(allow_root: P) -> io::Result<Self>
where
P: AsRef<Path>,
{
let allow_root = canonicalize_existing_dir(allow_root.as_ref())?;
Ok(Self {
root_base_dir: allow_root.clone(),
allow_root,
root_source_id: None,
read_mode: SafeFileReadMode::Reader,
symlink_policy: SymlinkPolicy::Reject,
})
}
pub fn with_root_base_dir<P>(mut self, root_base_dir: P) -> io::Result<Self>
where
P: AsRef<Path>,
{
let root_base_dir = canonicalize_existing_dir(root_base_dir.as_ref())?;
ensure_inside_root_io(&self.allow_root, &root_base_dir, "root base directory")?;
self.root_base_dir = root_base_dir;
self.root_source_id = None;
Ok(self)
}
pub fn with_root_file<P>(mut self, root_file: P) -> io::Result<Self>
where
P: AsRef<Path>,
{
let root_file = canonicalize_existing_file(root_file.as_ref())?;
ensure_inside_root_io(&self.allow_root, &root_file, "root file")?;
let Some(parent) = root_file.parent() else {
return Err(invalid_input("root file does not have a parent directory"));
};
self.root_base_dir = parent.to_path_buf();
self.root_source_id = Some(path_to_string(&root_file));
Ok(self)
}
pub fn with_read_mode(mut self, read_mode: SafeFileReadMode) -> Self {
self.read_mode = read_mode;
self
}
pub fn with_symlink_policy(mut self, symlink_policy: SymlinkPolicy) -> Self {
self.symlink_policy = symlink_policy;
self
}
pub fn resolve(&self, req: IncludeRequest<'_>) -> Result<ResolvedInclude, IncludeResolveError> {
let (path_spec, fragment) = split_include_spec(req.spec)?;
let spec_path = Path::new(path_spec);
validate_relative_include_spec(spec_path, req.spec)?;
let base_dir = self.base_dir_for_request(&req)?;
if self.symlink_policy == SymlinkPolicy::Reject {
self.reject_symlinks_in_spec_path(&base_dir, spec_path, path_spec)?;
}
let joined = base_dir.join(spec_path);
let canonical_target = fs::canonicalize(&joined).map_err(|e| {
IncludeResolveError::FileInclude(Box::new(ResolveProblem::ResolveFailed {
spec: req.spec.to_string(),
base_dir: base_dir.display().to_string(),
err: e,
}))
})?;
self.ensure_inside_root(&canonical_target, req.spec)?;
let metadata = fs::metadata(&canonical_target)?;
if !metadata.is_file() {
return Err(IncludeResolveError::FileInclude(Box::new(
ResolveProblem::TargetNotRegularFile {
target: canonical_target.display().to_string(),
},
)));
}
if let Some(remaining) = req.size_remaining {
let size = usize::try_from(metadata.len()).unwrap_or(usize::MAX);
if size > remaining {
return Err(IncludeResolveError::SizeLimitExceeded(size, remaining));
}
}
let id = path_to_string(&canonical_target);
if req.from_id.is_none()
&& let Some(root_source_id) = &self.root_source_id
&& root_source_id == &id
{
return Err(IncludeResolveError::FileInclude(Box::new(
ResolveProblem::TargetIsRootFile {
spec: req.spec.to_string(),
},
)));
}
let base_name = display_name(&self.allow_root, &canonical_target);
let name = match fragment {
Some(fragment) => format!("{base_name}#{fragment}"),
None => base_name,
};
let source = match (self.read_mode, fragment) {
(_, Some(fragment)) => InputSource::AnchoredText {
text: read_decoded_file(&canonical_target)?,
anchor: fragment.to_string(),
},
(SafeFileReadMode::Text, None) => {
InputSource::from_string(read_decoded_file(&canonical_target)?)
}
(SafeFileReadMode::Reader, None) => {
InputSource::from_reader(fs::File::open(&canonical_target)?)
}
};
Ok(ResolvedInclude { id, name, source })
}
pub fn into_callback(
self,
) -> impl for<'req> FnMut(IncludeRequest<'req>) -> Result<ResolvedInclude, IncludeResolveError>
{
move |req| self.resolve(req)
}
fn base_dir_for_request(
&self,
req: &IncludeRequest<'_>,
) -> Result<PathBuf, IncludeResolveError> {
let Some(from_id) = req.from_id else {
return Ok(self.root_base_dir.clone());
};
let from_id_path = Path::new(from_id);
if !from_id_path.is_absolute() {
return Err(IncludeResolveError::FileInclude(Box::new(
ResolveProblem::ParentIdNotAbsoluteCanonical {
parent_id: from_id.to_string(),
},
)));
}
let from_path = fs::canonicalize(from_id_path).map_err(|e| {
IncludeResolveError::FileInclude(Box::new(ResolveProblem::ParentResolveFailed {
parent_id: from_id.to_string(),
from_name: req.from_name.to_string(),
err: e,
}))
})?;
self.ensure_inside_root(&from_path, req.spec)?;
let metadata = fs::metadata(&from_path)?;
if !metadata.is_file() {
return Err(IncludeResolveError::FileInclude(Box::new(
ResolveProblem::ParentNotRegularFile {
parent: from_path.display().to_string(),
},
)));
}
let Some(parent) = from_path.parent() else {
return Err(IncludeResolveError::FileInclude(Box::new(
ResolveProblem::ParentHasNoDirectory {
parent: from_path.display().to_string(),
},
)));
};
Ok(parent.to_path_buf())
}
fn ensure_inside_root(
&self,
canonical_path: &Path,
spec: &str,
) -> Result<(), IncludeResolveError> {
if canonical_path.starts_with(&self.allow_root) {
let relative = canonical_path.strip_prefix(&self.allow_root).unwrap();
if relative.components().any(|component| {
component.as_os_str().to_str().is_some_and(|segment| {
segment.starts_with('.') && segment != "." && segment != ".."
})
}) {
return Err(IncludeResolveError::FileInclude(Box::new(
ResolveProblem::HiddenFile {
spec: spec.to_string(),
},
)));
}
Ok(())
} else {
Err(IncludeResolveError::FileInclude(Box::new(
ResolveProblem::ResolvesOutsideRoot {
spec: spec.to_string(),
root: self.allow_root.display().to_string(),
},
)))
}
}
fn reject_symlinks_in_spec_path(
&self,
base_dir: &Path,
spec_path: &Path,
spec_display: &str,
) -> Result<(), IncludeResolveError> {
let mut current = base_dir.to_path_buf();
for component in spec_path.components() {
match component {
Component::CurDir => {}
Component::ParentDir => {
current.pop();
if !current.starts_with(&self.allow_root) {
return Err(IncludeResolveError::FileInclude(Box::new(
ResolveProblem::ResolvesOutsideRoot {
spec: spec_display.to_string(),
root: self.allow_root.display().to_string(),
},
)));
}
}
Component::Normal(part) => {
current.push(part);
match fs::symlink_metadata(¤t) {
Ok(meta) if meta.file_type().is_symlink() => {
return Err(IncludeResolveError::FileInclude(Box::new(
ResolveProblem::TraversesSymlink {
spec: spec_display.to_string(),
},
)));
}
Ok(_) => {}
Err(err) if err.kind() == io::ErrorKind::NotFound => {}
Err(err) => return Err(err.into()),
}
}
Component::RootDir | Component::Prefix(_) => {
return Err(IncludeResolveError::FileInclude(Box::new(
ResolveProblem::AbsolutePathNotAllowed {
spec: spec_display.to_string(),
},
)));
}
}
}
Ok(())
}
}
#[cfg(feature = "include")]
/// Reads the whole file at `path` into a `String`.
///
/// `bom_override(true)` lets a byte-order mark choose the decoding. No explicit encoding is
/// configured, so per the `encoding_rs_io` docs the bytes pass through untouched when there is
/// no BOM, and `read_to_string` then requires them to be valid UTF-8.
fn read_decoded_file(path: &Path) -> Result<String, IncludeResolveError> {
    let mut text = String::new();
    DecodeReaderBytesBuilder::new()
        .encoding(None)
        .bom_override(true)
        .build(fs::File::open(path)?)
        .read_to_string(&mut text)?;
    Ok(text)
}
#[cfg(feature = "include")]
/// Splits an include spec of the form `path` or `path#fragment`.
///
/// Without a `#`, the whole spec is the path. With a `#`, both sides must be non-empty, and
/// the fragment may not contain a further `#`.
fn split_include_spec(raw_spec: &str) -> Result<(&str, Option<&str>), IncludeResolveError> {
    match raw_spec.split_once('#') {
        None => Ok((raw_spec, None)),
        Some(("", _)) => Err(IncludeResolveError::FileInclude(Box::new(
            ResolveProblem::EmptyPath,
        ))),
        Some((_, "")) => Err(IncludeResolveError::FileInclude(Box::new(
            ResolveProblem::EmptyFragment,
        ))),
        Some((_, fragment)) if fragment.contains('#') => Err(IncludeResolveError::FileInclude(
            Box::new(ResolveProblem::FragmentContainsHash {
                spec: raw_spec.to_string(),
            }),
        )),
        Some((path, fragment)) => Ok((path, Some(fragment))),
    }
}
#[cfg(feature = "include")]
/// Canonicalizes `path` and checks that it names an existing directory.
///
/// Any canonicalization/stat failure is returned unchanged. A non-directory target yields an
/// `InvalidInput` error.
fn canonicalize_existing_dir(path: &Path) -> io::Result<PathBuf> {
    let canonical = fs::canonicalize(path)?;
    if !fs::metadata(&canonical)?.is_dir() {
        return Err(invalid_input(format!(
            "expected a directory, got '{}'",
            canonical.display()
        )));
    }
    Ok(canonical)
}
#[cfg(feature = "include")]
/// Canonicalizes `path` and checks that it names an existing regular file.
///
/// Any canonicalization/stat failure is returned unchanged. A non-file target yields an
/// `InvalidInput` error.
fn canonicalize_existing_file(path: &Path) -> io::Result<PathBuf> {
    let canonical = fs::canonicalize(path)?;
    if !fs::metadata(&canonical)?.is_file() {
        return Err(invalid_input(format!(
            "expected a file, got '{}'",
            canonical.display()
        )));
    }
    Ok(canonical)
}
#[cfg(feature = "include")]
fn validate_relative_include_spec(
spec_path: &Path,
raw_spec: &str,
) -> Result<(), IncludeResolveError> {
if raw_spec.is_empty() {
return Err(IncludeResolveError::FileInclude(Box::new(
ResolveProblem::EmptyPath,
)));
}
if spec_path.components().any(|component| {
component
.as_os_str()
.to_str()
.is_some_and(|segment| segment.starts_with('.') && segment != "." && segment != "..")
}) {
return Err(IncludeResolveError::FileInclude(Box::new(
ResolveProblem::HiddenFile {
spec: raw_spec.to_string(),
},
)));
}
if let Some(filename) = spec_path.file_name().and_then(|n| n.to_str())
&& !filename.ends_with(".yml")
&& !filename.ends_with(".yaml")
{
return Err(IncludeResolveError::FileInclude(Box::new(
ResolveProblem::InvalidExtension {
spec: raw_spec.to_string(),
},
)));
}
if spec_path.is_absolute() {
return Err(IncludeResolveError::FileInclude(Box::new(
ResolveProblem::AbsolutePathNotAllowed {
spec: raw_spec.to_string(),
},
)));
}
if spec_path
.components()
.any(|component| matches!(component, Component::RootDir | Component::Prefix(_)))
{
return Err(IncludeResolveError::FileInclude(Box::new(
ResolveProblem::AbsolutePathNotAllowed {
spec: raw_spec.to_string(),
},
)));
}
Ok(())
}
#[cfg(feature = "include")]
/// Returns `canonical_target` relative to `allow_root` for use in diagnostics.
///
/// Falls back to the full path when the target is not under the root, or when it *is* the
/// root (the relative path would be empty).
fn display_name(allow_root: &Path, canonical_target: &Path) -> String {
    match canonical_target.strip_prefix(allow_root) {
        Ok(relative) if !relative.as_os_str().is_empty() => relative.display().to_string(),
        _ => canonical_target.display().to_string(),
    }
}
#[cfg(feature = "include")]
/// Builder-time counterpart of the resolver's root check, reported as an [`io::Error`].
///
/// `what` names the offending path in the message. A path outside `allow_root` yields
/// [`io::ErrorKind::PermissionDenied`].
fn ensure_inside_root_io(allow_root: &Path, canonical_path: &Path, what: &str) -> io::Result<()> {
    if !canonical_path.starts_with(allow_root) {
        let message = format!(
            "{what} '{}' is outside the configured root '{}'",
            canonical_path.display(),
            allow_root.display()
        );
        return Err(io::Error::new(io::ErrorKind::PermissionDenied, message));
    }
    Ok(())
}
#[cfg(feature = "include")]
/// Creates an [`io::ErrorKind::InvalidInput`] error carrying `message`.
fn invalid_input(message: impl Into<String>) -> io::Error {
    let message: String = message.into();
    io::Error::new(io::ErrorKind::InvalidInput, message)
}
#[cfg(feature = "include")]
/// Converts a path to an owned `String`, replacing any non-Unicode data with U+FFFD.
fn path_to_string(path: &Path) -> String {
    String::from(path.to_string_lossy())
}
#[cfg(all(test, feature = "include_fs", not(miri), not(target_os = "wasi")))]
mod tests {
use super::*;
use tempfile::tempdir;
#[test]
#[cfg_attr(miri, ignore)]
fn safe_file_resolver_streams_regular_files_by_default() {
let dir = tempdir().unwrap();
let root = dir.path();
std::fs::write(root.join("child.yaml"), "value: 1\n").unwrap();
let resolver = SafeFileResolver::new(root).unwrap();
let resolved = resolver
.resolve(IncludeRequest {
spec: "child.yaml",
from_name: "<input>",
from_id: None,
stack: vec!["<input>".to_string()],
size_remaining: None,
location: crate::Location::UNKNOWN,
})
.unwrap();
assert!(matches!(resolved.source, InputSource::Reader(_)));
}
#[test]
#[cfg_attr(miri, ignore)]
fn safe_file_resolver_keeps_fragment_includes_text_backed() {
let dir = tempdir().unwrap();
let root = dir.path();
std::fs::write(root.join("child.yaml"), "defaults: &defaults\n value: 1\n").unwrap();
let resolver = SafeFileResolver::new(root).unwrap();
let resolved = resolver
.resolve(IncludeRequest {
spec: "child.yaml#defaults",
from_name: "<input>",
from_id: None,
stack: vec!["<input>".to_string()],
size_remaining: None,
location: crate::Location::UNKNOWN,
})
.unwrap();
assert!(matches!(
resolved.source,
InputSource::AnchoredText { ref anchor, .. } if anchor == "defaults"
));
}
#[test]
#[cfg_attr(miri, ignore)]
fn safe_file_resolver_rejects_files_larger_than_remaining_quota() {
let dir = tempdir().unwrap();
let root = dir.path();
std::fs::write(root.join("child.yaml"), "value: 123\n").unwrap();
let resolver = SafeFileResolver::new(root).unwrap();
let error = resolver
.resolve(IncludeRequest {
spec: "child.yaml",
from_name: "<input>",
from_id: None,
stack: vec!["<input>".to_string()],
size_remaining: Some(4),
location: crate::Location::UNKNOWN,
})
.unwrap_err();
assert!(matches!(
error,
IncludeResolveError::SizeLimitExceeded(size, 4) if size > 4
));
}
}