use std::fs;
use std::path::Path;
use crate::constants::TMP_SUFFIX;
use rustix::fd::OwnedFd;
use rustix::fs::{openat, renameat, symlinkat, unlinkat, AtFlags, Mode, OFlags, CWD};
use rustix::io::Errno;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::Instant;
/// Converts a `rustix` [`Errno`] into the equivalent [`std::io::Error`],
/// carrying over the raw OS error code unchanged.
fn errno_to_io(e: Errno) -> std::io::Error {
    let raw = e.raw_os_error();
    std::io::Error::from_raw_os_error(raw)
}
/// Per-process counter mixed (together with the process id) into temporary
/// symlink names so that concurrent swaps in one process pick distinct names.
static NEXT_TMP_COUNTER: AtomicU64 = AtomicU64::new(0);
pub fn open_dir_nofollow(dir: &Path) -> std::io::Result<OwnedFd> {
use std::os::unix::ffi::OsStrExt;
let c = std::ffi::CString::new(dir.as_os_str().as_bytes())
.map_err(|_| std::io::Error::new(std::io::ErrorKind::InvalidInput, "invalid path"))?;
openat(
CWD,
c.as_c_str(),
OFlags::RDONLY | OFlags::DIRECTORY | OFlags::CLOEXEC | OFlags::NOFOLLOW,
Mode::empty(),
)
.map_err(errno_to_io)
}
/// Flushes the directory that contains `path` to stable storage, making a
/// preceding create/rename/unlink of `path` durable.
///
/// If `path` has no parent (e.g. `/` or the empty path), nothing is done.
/// A bare relative name such as `"file"` lives in the current working
/// directory, so `.` is synced in that case.
///
/// # Errors
///
/// Returns any error from opening the parent directory or from `fsync`.
pub fn fsync_parent_dir(path: &Path) -> std::io::Result<()> {
    match path.parent() {
        None => Ok(()),
        Some(parent) => {
            // `Path::parent` yields `Some("")` for single-component relative
            // paths; opening "" fails with NotFound, so use the CWD instead.
            let dir = if parent.as_os_str().is_empty() {
                Path::new(".")
            } else {
                parent
            };
            fs::File::open(dir)?.sync_all()
        }
    }
}
fn fsync_dirfd(dirfd: &OwnedFd) -> std::io::Result<()> {
rustix::fs::fsync(dirfd).map_err(errno_to_io)
}
pub fn atomic_symlink_swap(
source: &Path,
target: &Path,
allow_degraded: bool,
force_exdev: Option<bool>,
) -> std::io::Result<(bool, u64)> {
use std::os::unix::ffi::OsStrExt;
let parent = target.parent().unwrap_or_else(|| Path::new("."));
let fname = target.file_name().ok_or_else(|| {
std::io::Error::new(
std::io::ErrorKind::InvalidInput,
"target must not end with a slash",
)
})?;
let pid = std::process::id();
let ctr = NEXT_TMP_COUNTER.fetch_add(1, Ordering::Relaxed);
let tmp_name = format!(".{}.{}.{}{}", fname.to_string_lossy(), pid, ctr, TMP_SUFFIX);
let dirfd = open_dir_nofollow(parent)?;
let tmp_c = std::ffi::CString::new(tmp_name.as_str()).map_err(|_| {
std::io::Error::new(std::io::ErrorKind::InvalidInput, "invalid tmp cstring")
})?;
let new_c = std::ffi::CString::new(fname.as_bytes()).map_err(|_| {
std::io::Error::new(std::io::ErrorKind::InvalidInput, "invalid target name")
})?;
let src_c = std::ffi::CString::new(source.as_os_str().as_bytes()).map_err(|_| {
std::io::Error::new(std::io::ErrorKind::InvalidInput, "invalid source path")
})?;
match unlinkat(&dirfd, tmp_c.as_c_str(), AtFlags::empty()) {
Ok(()) | Err(Errno::NOENT) => {}
Err(e) => return Err(errno_to_io(e)),
}
symlinkat(src_c.as_c_str(), &dirfd, tmp_c.as_c_str()).map_err(errno_to_io)?;
let rename_res = renameat(&dirfd, tmp_c.as_c_str(), &dirfd, new_c.as_c_str());
let allow_env_overrides = std::env::var_os("SWITCHYARD_TEST_ALLOW_ENV_OVERRIDES")
== Some(std::ffi::OsString::from("1"));
let inject_exdev = match force_exdev {
Some(b) => b,
None => {
allow_env_overrides
&& std::env::var_os("SWITCHYARD_FORCE_EXDEV") == Some(std::ffi::OsString::from("1"))
}
};
let rename_res = if inject_exdev {
match rename_res {
Ok(()) => Err(Errno::XDEV),
Err(e) => Err(e),
}
} else {
rename_res
};
match rename_res {
Ok(()) => {
let t_fsync = Instant::now();
let res = fsync_dirfd(&dirfd);
let fsync_ms = u64::try_from(t_fsync.elapsed().as_millis()).unwrap_or(u64::MAX);
if let Err(e) = res {
let _ = e;
}
Ok((false, fsync_ms))
}
Err(e) if e == Errno::XDEV && allow_degraded => {
match unlinkat(&dirfd, new_c.as_c_str(), AtFlags::empty()) {
Ok(()) | Err(Errno::NOENT) => {}
Err(e) => {
let _ = unlinkat(&dirfd, tmp_c.as_c_str(), AtFlags::empty());
return Err(errno_to_io(e));
}
}
if let Err(e) =
symlinkat(src_c.as_c_str(), &dirfd, new_c.as_c_str()).map_err(errno_to_io)
{
let _ = unlinkat(&dirfd, tmp_c.as_c_str(), AtFlags::empty());
return Err(e);
}
let _ = unlinkat(&dirfd, tmp_c.as_c_str(), AtFlags::empty());
let t_fsync = Instant::now();
let res = fsync_dirfd(&dirfd);
let fsync_ms = u64::try_from(t_fsync.elapsed().as_millis()).unwrap_or(u64::MAX);
if let Err(e) = res {
let _ = e;
}
Ok((true, fsync_ms))
}
Err(e) => {
let _ = unlinkat(&dirfd, tmp_c.as_c_str(), AtFlags::empty());
Err(errno_to_io(e))
}
}
}