use std::{
    ffi::{OsStr, OsString},
    fs::Metadata,
    os::{fd::OwnedFd, unix::ffi::OsStrExt},
    path::PathBuf,
    pin::Pin,
};
use rand::{Rng, distr::Alphanumeric, rng};
use rustix::{
    fd::AsFd,
    fs::{AtFlags, Mode, OFlags},
    io::Errno,
};
use tokio::{
    fs::File,
    io::{AsyncWrite, AsyncWriteExt},
};
use crate::{ErrorKind, Result};
/// A file that is written under a temporary name and then moved into place.
///
/// Data written through [`AsyncWrite`] goes to a temporary file created in the
/// target's directory. The target name is only assigned when
/// [`commit`](AtomicFile::commit) or [`commit_new`](AtomicFile::commit_new)
/// is called.
#[derive(Debug)]
pub struct AtomicFile {
    /// The open temporary file that receives all writes.
    tempfile: File,
    /// Directory holding both the temporary and the final name. All name
    /// operations are performed relative to this handle (`*at` syscalls).
    dir: OwnedFd,
    /// Name of the temporary file inside `dir`.
    temp_name: OsString,
    /// Name the file receives inside `dir` when committed.
    final_name: OsString,
}
impl AtomicFile {
pub fn new(path: impl Into<PathBuf>) -> Result<AtomicFile> {
let path = path.into();
let dirpath = path
.parent()
.ok_or(ErrorKind::InvalidInput.error("path requires a parent"))?;
let final_name = path
.file_name()
.ok_or(ErrorKind::InvalidInput.error("path requires a filename"))?
.to_os_string();
let dir = if dirpath.as_os_str().is_empty() {
rustix::fs::open(".", OFlags::DIRECTORY | OFlags::CLOEXEC, Mode::empty())
} else {
rustix::fs::open(dirpath, OFlags::DIRECTORY | OFlags::CLOEXEC, Mode::empty())
}
.map_err(|e| ErrorKind::Io.error(e))?;
let temp_name = {
let mut rng = rng();
let mut buf = *b"123456.tmp";
for c in buf.iter_mut().take(6) {
*c = rng.sample(Alphanumeric);
}
OsStr::from_bytes(&buf).to_os_string()
};
let tempfile = rustix::fs::openat(
dir.as_fd(),
&temp_name,
OFlags::WRONLY | OFlags::CREATE | OFlags::EXCL | OFlags::CLOEXEC,
Mode::from(0o600),
)
.map(|fd| File::from(std::fs::File::from(fd)))
.map_err(|e| ErrorKind::Io.error(e))?;
Ok(AtomicFile {
tempfile,
dir,
temp_name,
final_name,
})
}
pub async fn commit(self) -> Result<Metadata> {
let meta = self
.tempfile
.metadata()
.await
.map_err(|e| ErrorKind::Io.error(e))?;
rustix::fs::renameat(&self.dir, self.temp_name, &self.dir, self.final_name)
.map_err(|e| ErrorKind::Io.error(e))?;
Ok(meta)
}
pub async fn commit_new(self) -> Result<Metadata> {
let meta = self
.tempfile
.metadata()
.await
.map_err(|e| ErrorKind::Io.error(e))?;
rustix::fs::linkat(
&self.dir,
&self.temp_name,
&self.dir,
&self.final_name,
AtFlags::empty(),
)
.map_err(|e| {
if e == Errno::EXIST {
ErrorKind::Exists.error(e)
} else {
ErrorKind::Io.error(e)
}
})?;
rustix::fs::unlinkat(self.dir, self.temp_name, AtFlags::empty())
.map_err(|e| ErrorKind::Io.error(e))?;
Ok(meta)
}
}
/// Every write, flush and shutdown is forwarded to the temporary file. The
/// data only appears under the final name once the `AtomicFile` is committed.
impl AsyncWrite for AtomicFile {
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &[u8],
    ) -> std::task::Poll<Result<usize, std::io::Error>> {
        // `AtomicFile` is `Unpin`, so the field can be re-pinned directly.
        let file = &mut self.get_mut().tempfile;
        Pin::new(file).poll_write(cx, buf)
    }

    fn poll_flush(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), std::io::Error>> {
        let file = &mut self.get_mut().tempfile;
        Pin::new(file).poll_flush(cx)
    }

    fn poll_shutdown(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), std::io::Error>> {
        let file = &mut self.get_mut().tempfile;
        Pin::new(file).poll_shutdown(cx)
    }
}
#[cfg(test)]
mod tests {
    use std::os::unix::fs::MetadataExt;
    use tempfile::tempdir;
    use tokio::io::AsyncWriteExt;
    use super::AtomicFile;

    /// The metadata returned by `commit` describes the file now at the final path.
    #[tokio::test]
    async fn metadata_preserved_after_commit() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("test.txt");
        let mut file = AtomicFile::new(&path).unwrap();
        file.write_all(b"hello world").await.unwrap();
        let meta_before = file.commit().await.unwrap();
        let meta_after = tokio::fs::metadata(&path).await.unwrap();
        assert_eq!(meta_before.ino(), meta_after.ino());
        assert_eq!(meta_before.mtime(), meta_after.mtime());
    }

    /// The metadata returned by `commit_new` describes the file now at the final path.
    #[tokio::test]
    async fn metadata_preserved_after_commit_new() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("test.txt");
        let mut file = AtomicFile::new(&path).unwrap();
        file.write_all(b"hello world").await.unwrap();
        let meta_before = file.commit_new().await.unwrap();
        let meta_after = tokio::fs::metadata(&path).await.unwrap();
        assert_eq!(meta_before.ino(), meta_after.ino());
        assert_eq!(meta_before.mtime(), meta_after.mtime());
    }

    /// `commit_new` must fail rather than replace an existing file, and must
    /// leave that file's contents untouched.
    #[tokio::test]
    async fn commit_new_does_not_clobber_existing_file() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("test.txt");
        std::fs::write(&path, b"original").unwrap();
        let mut file = AtomicFile::new(&path).unwrap();
        file.write_all(b"replacement").await.unwrap();
        assert!(file.commit_new().await.is_err());
        assert_eq!(std::fs::read(&path).unwrap(), b"original");
    }

    /// A path without a parent directory has nowhere to place the temp file.
    #[test]
    fn new_rejects_path_without_parent() {
        assert!(AtomicFile::new("/").is_err());
    }
}