use std::{
fs::{File, OpenOptions},
io::{self, Write},
path::{Path, PathBuf},
};
use crate::error::SnapshotError;
/// Suffix appended to a destination path to form its staging temp path
/// (see [`derive_temp_path`]).
pub const TEMP_SUFFIX: &str = ".tmp";
/// Checks that the parent directories of `dest` and `temp` live on the same device.
///
/// An atomic commit renames the temp file onto the destination, and a rename cannot move a
/// file across filesystems. This is therefore checked up front rather than surfacing as a
/// failure at commit time.
///
/// A path with no directory component (e.g. `"x.snap"`), or with no parent at all, is
/// resolved against the current directory (`"."`).
///
/// # Errors
///
/// - [`SnapshotError::Io`] if either parent directory cannot be stat'ed (for example,
///   because it does not exist).
/// - [`SnapshotError::AtomicCommitCrossFs`] if the two parent directories report different
///   device ids.
///
/// On non-Unix targets device ids are not compared, and this always returns `Ok(())`.
pub fn check_same_filesystem(dest: &Path, temp: &Path) -> Result<(), SnapshotError> {
    #[cfg(unix)]
    {
        use std::os::unix::fs::MetadataExt;

        /// Parent directory of `path`, falling back to `"."` when there is none or when it
        /// is empty. `Path::parent` yields `Some("")` for bare file names like `"x.snap"`,
        /// and stat'ing `""` fails.
        fn dir_of(path: &Path) -> &Path {
            match path.parent() {
                Some(parent) if !parent.as_os_str().is_empty() => parent,
                _ => Path::new("."),
            }
        }

        let dest_dir = dir_of(dest);
        let temp_dir = dir_of(temp);
        let dest_dev = std::fs::metadata(dest_dir)
            .map_err(SnapshotError::Io)?
            .dev();
        let temp_dev = std::fs::metadata(temp_dir)
            .map_err(SnapshotError::Io)?
            .dev();
        if dest_dev != temp_dev {
            return Err(SnapshotError::AtomicCommitCrossFs {
                dest: dest.to_path_buf(),
                temp_dir: temp_dir.to_path_buf(),
            });
        }
        Ok(())
    }
    #[cfg(not(unix))]
    {
        let _ = (dest, temp);
        Ok(())
    }
}
/// Returns the staging path for `dest`: the full path with [`TEMP_SUFFIX`] appended.
///
/// The suffix is appended to the whole path rather than replacing the extension, so
/// `/tmp/x.snap` becomes `/tmp/x.snap.tmp`.
#[must_use]
pub fn derive_temp_path(dest: &Path) -> PathBuf {
    let mut staged = dest.as_os_str().to_os_string();
    staged.push(TEMP_SUFFIX);
    staged.into()
}
/// Guard that deletes a file when dropped, unless it has been disarmed first.
///
/// Removal is best-effort: any error from [`std::fs::remove_file`] (including the file
/// already being gone) is ignored.
#[derive(Debug)]
pub struct UnlinkOnDrop {
    /// Path to delete on drop; `None` once the guard is disarmed.
    path: Option<PathBuf>,
}

impl UnlinkOnDrop {
    /// Arms a guard that will remove `path` when it goes out of scope.
    #[must_use]
    pub fn new(path: PathBuf) -> Self {
        let armed = Some(path);
        Self { path: armed }
    }

    /// Cancels the pending removal; dropping the guard afterwards leaves the file alone.
    pub fn disarm(&mut self) {
        self.path = None;
    }

    /// Returns the path that will be removed on drop, or `None` if the guard is disarmed.
    #[must_use]
    pub fn path(&self) -> Option<&Path> {
        self.path.as_deref()
    }
}

impl Drop for UnlinkOnDrop {
    fn drop(&mut self) {
        let pending = self.path.take();
        // Best-effort cleanup: a failed removal must never panic inside `drop`.
        let _ = pending.map(std::fs::remove_file);
    }
}
#[derive(Debug)]
pub struct AtomicWriter {
file: File,
temp_path: PathBuf,
dest_path: PathBuf,
guard: UnlinkOnDrop,
}
impl AtomicWriter {
pub fn open(dest: &Path) -> Result<Self, SnapshotError> {
let temp = derive_temp_path(dest);
check_same_filesystem(dest, &temp)?;
let file = OpenOptions::new()
.write(true)
.create(true)
.truncate(true)
.open(&temp)?;
let guard = UnlinkOnDrop::new(temp.clone());
Ok(Self {
file,
temp_path: temp,
dest_path: dest.to_path_buf(),
guard,
})
}
#[must_use]
pub fn temp_path(&self) -> &Path {
&self.temp_path
}
#[must_use]
pub fn dest_path(&self) -> &Path {
&self.dest_path
}
pub fn file_mut(&mut self) -> &mut File {
&mut self.file
}
pub fn commit(mut self) -> Result<(), SnapshotError> {
self.file.flush()?;
self.file.sync_all()?;
drop(self.file);
match std::fs::rename(&self.temp_path, &self.dest_path) {
Ok(()) => {
self.guard.disarm();
Ok(())
}
Err(e) => Err(SnapshotError::AtomicCommitFailed(e)),
}
}
}
/// Streams bytes into the temp file. Nothing reaches the destination until
/// [`AtomicWriter::commit`] renames the temp file into place.
impl Write for AtomicWriter {
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        let staging = &mut self.file;
        staging.write(buf)
    }

    fn flush(&mut self) -> io::Result<()> {
        Write::flush(&mut self.file)
    }
}
#[cfg(test)]
mod tests {
    use tempfile::TempDir;

    use super::*;

    /// Builds a destination path named `name` inside `dir`.
    fn dest_in(dir: &Path, name: &str) -> PathBuf {
        dir.join(name)
    }

    #[test]
    fn test_should_derive_temp_path_with_tmp_suffix() {
        let p = derive_temp_path(Path::new("/tmp/x.snap"));
        assert_eq!(p, PathBuf::from("/tmp/x.snap.tmp"));
    }

    #[test]
    fn test_should_commit_temp_file_onto_destination() {
        let dir = TempDir::new().unwrap();
        let dest = dest_in(dir.path(), "x.snap");
        let mut w = AtomicWriter::open(&dest).unwrap();
        w.write_all(b"hello world").unwrap();
        w.commit().unwrap();
        let s = std::fs::read_to_string(&dest).unwrap();
        assert_eq!(s, "hello world");
        assert!(!derive_temp_path(&dest).exists());
    }

    #[test]
    fn test_should_unlink_temp_when_writer_dropped_without_commit() {
        let dir = TempDir::new().unwrap();
        let dest = dest_in(dir.path(), "x.snap");
        {
            let mut w = AtomicWriter::open(&dest).unwrap();
            w.write_all(b"partial").unwrap();
        }
        assert!(!dest.exists(), "no commit happened");
        assert!(!derive_temp_path(&dest).exists(), "temp file leaked");
    }

    #[test]
    fn test_should_leave_existing_destination_alone_when_writer_dropped() {
        let dir = TempDir::new().unwrap();
        let dest = dest_in(dir.path(), "x.snap");
        std::fs::write(&dest, b"prior good").unwrap();
        {
            let mut w = AtomicWriter::open(&dest).unwrap();
            w.write_all(b"new bad").unwrap();
        }
        let s = std::fs::read_to_string(&dest).unwrap();
        assert_eq!(s, "prior good", "previous good pair clobbered");
    }

    #[test]
    fn test_should_handle_back_to_back_atomic_writes() {
        let dir = TempDir::new().unwrap();
        let dest = dest_in(dir.path(), "x.snap");
        for i in 0..3 {
            let mut w = AtomicWriter::open(&dest).unwrap();
            let payload = format!("snapshot {i}");
            w.write_all(payload.as_bytes()).unwrap();
            w.commit().unwrap();
            assert_eq!(std::fs::read_to_string(&dest).unwrap(), payload);
        }
    }

    #[test]
    fn test_should_truncate_stale_temp_file_from_previous_run() {
        let dir = TempDir::new().unwrap();
        let dest = dest_in(dir.path(), "x.snap");
        // Simulate an interrupted run that left a longer temp file behind.
        std::fs::write(derive_temp_path(&dest), b"stale leftover bytes").unwrap();
        let mut w = AtomicWriter::open(&dest).unwrap();
        w.write_all(b"fresh").unwrap();
        w.commit().unwrap();
        assert_eq!(std::fs::read_to_string(&dest).unwrap(), "fresh");
    }

    /// A destination whose parent directory does not exist must be rejected by `open`.
    #[test]
    fn test_should_reject_destination_in_missing_directory() {
        let dir = TempDir::new().unwrap();
        let nonexistent = dir.path().join("does-not-exist").join("inner.snap");
        let res = AtomicWriter::open(&nonexistent);
        assert!(res.is_err());
    }

    #[test]
    fn test_should_disarm_guard_to_avoid_post_commit_unlink() {
        let dir = TempDir::new().unwrap();
        let dest = dest_in(dir.path(), "x.snap");
        let mut g = UnlinkOnDrop::new(dest.clone());
        std::fs::write(&dest, b"persistent").unwrap();
        g.disarm();
        drop(g);
        assert!(dest.exists());
    }
}