#[cfg(feature = "include_fs")]
use crate::error::Error;
use crate::error::Result;
use crate::prelude::*;
/// A single include lookup handed to an [`IncludeResolver`].
///
/// The struct is `#[non_exhaustive]`: code outside this crate can read the
/// fields but cannot construct it with a struct literal.
#[non_exhaustive]
#[derive(Clone, Debug)]
pub struct IncludeRequest<'a> {
    /// The raw include specifier. [`split_fragment`] can separate an optional
    /// `#fragment` suffix from the path part.
    pub spec: &'a str,
    /// NOTE(review): presumably an id for the source that issued the include —
    /// confirm against the callers that build this request.
    pub from_id: usize,
    /// NOTE(review): presumably the include nesting depth at this point —
    /// confirm against the callers that build this request.
    pub depth: usize,
}
/// Source text produced by resolving an include, paired with a display name.
#[non_exhaustive]
#[derive(Clone, Debug)]
pub struct InputSource {
    /// Human-readable name of the source (the file resolver in this module
    /// uses the canonical path).
    pub name: String,
    /// The source text. Despite its name, this is a UTF-8 `String`, not raw bytes.
    pub bytes: String,
}
impl InputSource {
    /// Builds an `InputSource` from anything convertible into `String`.
    #[must_use]
    pub fn new(name: impl Into<String>, bytes: impl Into<String>) -> Self {
        let name: String = name.into();
        let bytes: String = bytes.into();
        Self { name, bytes }
    }
}
/// Type-erased include resolver wrapping a shared `Send + Sync` closure.
///
/// Cloning is cheap: it only clones the inner `Arc`.
#[derive(Clone)]
pub struct IncludeResolver(Arc<dyn Fn(IncludeRequest<'_>) -> Result<InputSource> + Send + Sync>);
impl IncludeResolver {
    /// Wraps `resolver_fn` so it can be shared and cloned.
    #[must_use]
    pub fn new<F>(resolver_fn: F) -> Self
    where
        F: Fn(IncludeRequest<'_>) -> Result<InputSource> + Send + Sync + 'static,
    {
        let shared = Arc::new(resolver_fn);
        Self(shared)
    }

    /// Invokes the wrapped closure with `req` and returns its result unchanged.
    ///
    /// # Errors
    ///
    /// Returns whatever error the wrapped closure produces.
    pub fn resolve(&self, req: IncludeRequest<'_>) -> Result<InputSource> {
        let callback = &*self.0;
        callback(req)
    }
}
/// Prints only the address of the shared closure; the wrapped `dyn Fn` has no
/// `Debug` implementation of its own.
impl core::fmt::Debug for IncludeResolver {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let closure_addr = Arc::as_ptr(&self.0);
        let mut out = f.debug_struct("IncludeResolver");
        out.field("ptr", &closure_addr);
        out.finish()
    }
}
/// How the sandboxed file resolver treats symbolic links.
#[non_exhaustive]
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
pub enum SymlinkPolicy {
    /// Follow symlinks, as long as the fully resolved path still lies inside
    /// the sandbox root. This is the default.
    #[default]
    FollowWithinRoot,
    /// Reject symlinks instead of following them.
    Reject,
}
/// Include resolver that reads files confined to a sandbox directory.
///
/// Resolved paths are canonicalised and must remain beneath the canonical
/// sandbox root; symlink handling is governed by a [`SymlinkPolicy`].
#[cfg(feature = "include_fs")]
#[cfg_attr(docsrs, doc(cfg(feature = "include_fs")))]
#[derive(Clone, Debug)]
pub struct SafeFileResolver {
    /// Sandbox directory that include specs are joined onto.
    root: std::path::PathBuf,
    /// Symlink handling applied during resolution.
    symlink_policy: SymlinkPolicy,
}
#[cfg(feature = "include_fs")]
impl SafeFileResolver {
    /// Creates a resolver sandboxed to `root`, using the default
    /// [`SymlinkPolicy`] ([`SymlinkPolicy::FollowWithinRoot`]).
    ///
    /// `root` is not inspected here; it is canonicalised each time an include
    /// is resolved.
    #[must_use]
    pub fn new(root: impl Into<std::path::PathBuf>) -> Self {
        Self {
            root: root.into(),
            symlink_policy: SymlinkPolicy::default(),
        }
    }

    /// Builder-style setter for the symlink policy.
    #[must_use]
    pub fn symlink_policy(mut self, policy: SymlinkPolicy) -> Self {
        self.symlink_policy = policy;
        self
    }

    /// Consumes this resolver and wraps it in a type-erased [`IncludeResolver`].
    #[must_use]
    pub fn into_resolver(self) -> IncludeResolver {
        // `self` is already owned, so move it into the closure; no clone needed.
        IncludeResolver::new(move |req: IncludeRequest<'_>| self.resolve(req))
    }

    /// Resolves `req.spec` to a file inside the sandbox root and reads it as text.
    ///
    /// Any `#fragment` suffix is split off and ignored. `req.from_id` and
    /// `req.depth` are not used by this resolver. The path part is joined onto
    /// the root. Both the joined path and the root are canonicalised, and the
    /// result must lie beneath the canonical root. Under
    /// [`SymlinkPolicy::Reject`], every path component below the root is
    /// additionally required not to be a symlink. The returned
    /// [`InputSource`] is named after the canonical path.
    ///
    /// NOTE(review): the containment and symlink checks run before the final
    /// read and are not atomic with it. A concurrent filesystem change between
    /// check and read is not guarded against.
    ///
    /// # Errors
    ///
    /// Returns `Error::Custom` if the root or the candidate cannot be
    /// canonicalised, if the candidate escapes the root, if a symlink is found
    /// under the `Reject` policy (or a component cannot be stat'ed), or if the
    /// file cannot be read as UTF-8 text.
    fn resolve(&self, req: IncludeRequest<'_>) -> Result<InputSource> {
        use std::fs;
        let (path_part, _frag) = split_fragment(req.spec);
        let candidate = self.root.join(path_part);
        let canon_root = fs::canonicalize(&self.root).map_err(|e| {
            Error::Custom(format!("include resolver: cannot canonicalise root: {e}"))
        })?;
        let canon = fs::canonicalize(&candidate).map_err(|e| {
            Error::Custom(format!(
                "include resolver: cannot canonicalise `{}`: {e}",
                candidate.display()
            ))
        })?;
        // `Path::starts_with` compares whole components, so `/root2` does not
        // count as being inside `/root`.
        if !canon.starts_with(&canon_root) {
            return Err(Error::Custom(format!(
                "include resolver: `{}` escapes sandbox root `{}`",
                canon.display(),
                canon_root.display()
            )));
        }
        if self.symlink_policy == SymlinkPolicy::Reject {
            self.reject_symlinks(&candidate)?;
        }
        let bytes = fs::read_to_string(&canon).map_err(|e| {
            Error::Custom(format!(
                "include resolver: cannot read `{}`: {e}",
                canon.display()
            ))
        })?;
        Ok(InputSource::new(canon.display().to_string(), bytes))
    }

    /// Fails if any component of `candidate` below the sandbox root is a symlink.
    ///
    /// Checking only the final component would miss symlinked intermediate
    /// directories (`root/link_dir/file`). So each prefix is inspected with
    /// `fs::symlink_metadata`, which does not follow the link. The root itself
    /// is not checked. If `candidate` is not lexically under `self.root`
    /// (e.g. an absolute spec spelled differently), every component of it is
    /// checked.
    fn reject_symlinks(&self, candidate: &std::path::Path) -> Result<()> {
        use std::fs;
        let (mut current, relative) = match candidate.strip_prefix(&self.root) {
            Ok(rel) => (self.root.clone(), rel),
            Err(_) => (std::path::PathBuf::new(), candidate),
        };
        for component in relative.components() {
            current.push(component);
            let meta = fs::symlink_metadata(&current).map_err(|e| {
                Error::Custom(format!(
                    "include resolver: cannot stat `{}`: {e}",
                    current.display()
                ))
            })?;
            if meta.file_type().is_symlink() {
                return Err(Error::Custom(format!(
                    "include resolver: symlink rejected by policy: `{}`",
                    current.display()
                )));
            }
        }
        Ok(())
    }
}
/// Splits an include spec into its path part and optional fragment.
///
/// The split happens at the first `#`. Everything after it, including any
/// further `#` characters, is the fragment. A spec without `#` is returned
/// whole as the path with no fragment. A trailing `#` yields `Some("")`.
#[must_use]
pub fn split_fragment(spec: &str) -> (&str, Option<&str>) {
    spec.split_once('#')
        .map_or((spec, None), |(path, fragment)| (path, Some(fragment)))
}