use std::path::{Component, Path, PathBuf};
use crate::diagnostics::{Diagnostic, DiagnosticCode};
use crate::error::{Error, Result};
use crate::fs::atomic_write::write_atomic;
use crate::LinkMode;
pub fn safe_relative_path(name: &str) -> Result<PathBuf> {
let reject = |msg: &str| {
Err(Error::from(Diagnostic::error(
DiagnosticCode::OutputPathTraversal,
format!("unsafe zone/link name {name:?}: {msg}"),
Path::new("<output>"),
0,
)))
};
if name.is_empty() {
return reject("name is empty");
}
if name.contains('\0') {
return reject("name contains a NUL byte");
}
if name.starts_with('/') {
return reject("name is absolute");
}
let mut path = PathBuf::new();
for comp in name.split('/') {
if comp.is_empty() {
return reject("name has an empty path component");
}
if comp == "." || comp == ".." {
return reject("name contains a '.' or '..' component");
}
if comp.starts_with('-') {
return reject("a path component starts with '-'");
}
path.push(comp);
}
Ok(path)
}
/// Lexically checks whether `candidate` is `root` itself or lies beneath it.
///
/// Both paths are normalised without touching the filesystem: `.` components
/// are dropped and a `..` cancels the preceding normal component. Symlinks
/// are therefore *not* resolved; this is a purely textual check.
///
/// `candidate` is considered contained only when its normalised form starts
/// with the normalised `root` **and** everything after that prefix consists
/// of ordinary name components. The second condition matters when `root`
/// normalises to nothing (e.g. `.`) or ends in `..`: without it, paths such
/// as `../escape` or `/etc/passwd` would trivially "match" an empty prefix,
/// and `../../..` would count as being inside `../..`.
pub fn is_contained(root: &Path, candidate: &Path) -> bool {
    /// Collapses `.` and resolvable `..` components; a `..` that has no
    /// normal component to cancel is kept so it can be detected later.
    fn normalise(p: &Path) -> Vec<Component<'_>> {
        let mut out: Vec<Component<'_>> = Vec::new();
        for c in p.components() {
            match c {
                Component::CurDir => {}
                Component::ParentDir => {
                    if matches!(out.last(), Some(Component::Normal(_))) {
                        out.pop();
                    } else {
                        out.push(c);
                    }
                }
                other => out.push(other),
            }
        }
        out
    }

    let root_n = normalise(root);
    let cand_n = normalise(candidate);

    if !cand_n.starts_with(&root_n) {
        return false;
    }

    // Anything other than a plain name after the root prefix (a leftover
    // `..`, a root directory, or a platform prefix) means the candidate
    // climbs out of, or is anchored outside, the root.
    cand_n[root_n.len()..]
        .iter()
        .all(|c| matches!(c, Component::Normal(_)))
}
pub fn write_zone_file(
root: &Path,
name: &str,
bytes: &[u8],
overwrite: bool,
durable: bool,
) -> Result<PathBuf> {
let rel = safe_relative_path(name)?;
let target = root.join(&rel);
if !is_contained(root, &target) {
return Err(Error::from(Diagnostic::error(
DiagnosticCode::OutputPathTraversal,
format!("resolved path for {name:?} escapes the output root"),
Path::new("<output>"),
0,
)));
}
if let Some(parent) = target.parent() {
std::fs::create_dir_all(parent).map_err(|e| Error::io(parent, e))?;
}
write_atomic(&target, bytes, overwrite, durable)?;
Ok(target)
}
pub fn write_link(
root: &Path,
link_name: &str,
target_name: &str,
mode: LinkMode,
overwrite: bool,
durable: bool,
) -> Result<PathBuf> {
let link_rel = safe_relative_path(link_name)?;
let target_rel = safe_relative_path(target_name)?;
let link_path = root.join(&link_rel);
let target_path = root.join(&target_rel);
if !is_contained(root, &link_path) {
return Err(Error::from(Diagnostic::error(
DiagnosticCode::OutputPathTraversal,
format!("link path for {link_name:?} escapes the output root"),
Path::new("<output>"),
0,
)));
}
let parent = link_path.parent().map(|p| p.to_path_buf());
if let Some(parent) = &parent {
std::fs::create_dir_all(parent).map_err(|e| Error::io(parent, e))?;
}
match mode {
LinkMode::Copy => {
let bytes = std::fs::read(&target_path).map_err(|e| Error::io(&target_path, e))?;
write_atomic(&link_path, &bytes, overwrite, durable)?;
}
LinkMode::Symlink => {
let rel_target = relative_link_target(&link_rel, &target_rel);
if overwrite {
let dir = parent.as_deref().unwrap_or(root);
let tmp = symlink_temp_path(dir, &link_rel);
symlink(&rel_target, &tmp)?;
if let Err(e) = std::fs::rename(&tmp, &link_path) {
let _ = std::fs::remove_file(&tmp);
return Err(Error::io(&link_path, e));
}
} else {
symlink_exclusive(&rel_target, &link_path)?;
}
if durable {
if let Some(parent) = &parent {
crate::fs::atomic_write::fsync_dir(parent)?;
}
}
}
}
Ok(link_path)
}
/// Builds a temporary path inside `dir` for staging a symlink.
///
/// The file name has the form `.<base>.symlink.tmp.<pid>.<seq>`, where
/// `<base>` is the final component of `link_rel` (or `link` when there is
/// none or it is not valid UTF-8), `<pid>` is the current process id, and
/// `<seq>` is a process-wide counter that increases on every call.
fn symlink_temp_path(dir: &Path, link_rel: &Path) -> PathBuf {
    use std::sync::atomic::{AtomicU64, Ordering};

    // Process-wide counter so successive calls never share a name.
    static SEQ: AtomicU64 = AtomicU64::new(0);

    let base = match link_rel.file_name().and_then(|name| name.to_str()) {
        Some(text) => text,
        None => "link",
    };
    let seq = SEQ.fetch_add(1, Ordering::Relaxed);
    let file_name = format!(".{base}.symlink.tmp.{}.{seq}", std::process::id());
    dir.join(file_name)
}
/// Expresses `target_rel` relative to the directory that holds `link_rel`.
///
/// Both arguments are paths relative to the same root. The result climbs one
/// `..` for every directory level of `link_rel` (its component count minus
/// one) and then appends `target_rel`.
fn relative_link_target(link_rel: &Path, target_rel: &Path) -> PathBuf {
    // The final component of `link_rel` is the link itself, not a directory.
    let levels_up = link_rel.components().count().saturating_sub(1);

    let mut relative = PathBuf::new();
    relative.extend(std::iter::repeat("..").take(levels_up));
    relative.push(target_rel);
    relative
}
fn symlink(target: &Path, link: &Path) -> Result<()> {
#[cfg(unix)]
{
std::os::unix::fs::symlink(target, link).map_err(|e| Error::io(link, e))
}
#[cfg(not(unix))]
{
let _ = target;
Err(Error::config(format!(
"symlink mode is only supported on Unix; cannot create {}",
link.display()
)))
}
}
fn symlink_exclusive(target: &Path, link: &Path) -> Result<()> {
#[cfg(unix)]
{
match std::os::unix::fs::symlink(target, link) {
Ok(()) => Ok(()),
Err(e) if e.kind() == std::io::ErrorKind::AlreadyExists => Err(Error::config(format!(
"{} already exists (use --force to overwrite)",
link.display()
))),
Err(e) => Err(Error::io(link, e)),
}
}
#[cfg(not(unix))]
{
let _ = target;
Err(Error::config(format!(
"symlink mode is only supported on Unix; cannot create {}",
link.display()
)))
}
}
pub fn set_file_mode(path: &Path, mode: u32) -> Result<()> {
#[cfg(unix)]
{
use std::os::unix::fs::PermissionsExt;
std::fs::set_permissions(path, std::fs::Permissions::from_mode(mode))
.map_err(|e| Error::io(path, e))
}
#[cfg(not(unix))]
{
let _ = (path, mode);
Err(Error::config(
"--mode (file permission bits) is only supported on Unix platforms",
))
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A durable write lands on disk with exactly the bytes supplied.
    #[test]
    fn durable_write_succeeds_and_round_trips() {
        let scratch = tempfile::tempdir().unwrap();
        let written = write_zone_file(scratch.path(), "Zone/D", b"hello", false, true).unwrap();
        let on_disk = std::fs::read(&written).unwrap();
        assert_eq!(on_disk, b"hello");
    }

    /// Without `overwrite`, creating a link that already exists must fail.
    #[cfg(unix)]
    #[test]
    fn symlink_exclusive_create_rejects_preexisting_without_following() {
        let scratch = tempfile::tempdir().unwrap();
        let out = scratch.path();
        write_zone_file(out, "B", b"zone-b", false, true).unwrap();
        write_link(out, "A", "B", LinkMode::Symlink, false, true).unwrap();

        let err = write_link(out, "A", "B", LinkMode::Symlink, false, true).unwrap_err();
        let rendered = err.to_string();
        assert!(
            rendered.contains("already exists"),
            "expected an exclusive-create rejection, got: {err}"
        );
    }

    /// With `overwrite`, an existing link is retargeted to the new zone.
    #[cfg(unix)]
    #[test]
    fn symlink_force_overwrite_replaces_atomically() {
        let scratch = tempfile::tempdir().unwrap();
        let out = scratch.path();
        write_zone_file(out, "B", b"zone-b", false, true).unwrap();
        write_zone_file(out, "C", b"zone-c", false, true).unwrap();
        write_link(out, "A", "B", LinkMode::Symlink, false, true).unwrap();
        write_link(out, "A", "C", LinkMode::Symlink, true, true).unwrap();

        let link = out.join("A");
        assert_eq!(
            std::fs::read_link(&link).unwrap(),
            PathBuf::from("C"),
            "force-overwrite must retarget the link"
        );
        assert_eq!(std::fs::read(&link).unwrap(), b"zone-c");
    }

    /// Ordinary zone names map straight onto the equivalent relative path.
    #[test]
    fn accepts_normal_names() {
        for good in ["Europe/London", "UTC"] {
            assert_eq!(safe_relative_path(good).unwrap(), PathBuf::from(good));
        }
    }

    /// Traversal attempts and other hostile names are all refused.
    #[test]
    fn rejects_traversal_and_tricks() {
        let hostile = [
            "../etc/passwd",
            "/abs/path",
            "a/../../b",
            "a//b",
            "-flag/x",
            "..",
            ".",
            "",
        ];
        for bad in hostile {
            let verdict = safe_relative_path(bad);
            assert!(verdict.is_err(), "should reject {bad:?}");
        }
    }

    /// Containment is decided lexically against an absolute root.
    #[test]
    fn containment_check() {
        let root = Path::new("/out/zone");
        let cases = [
            ("/out/zone/Europe/London", true),
            ("/out/zone/../escape", false),
            ("/elsewhere", false),
        ];
        for (candidate, expected) in cases {
            assert_eq!(is_contained(root, Path::new(candidate)), expected);
        }
    }
}