#![allow(clippy::unwrap_used)]
use core::fmt::Debug;
use std::{
fs::{self, File},
io::{self, Write},
path::{Path, PathBuf},
sync::Arc,
};
use parking_lot::Mutex;
use serde::{Serialize, de::DeserializeOwned};
use serde_json::{from_reader, to_vec_pretty};
use thiserror::Error;
/// Errors returned when loading or persisting a store.
///
/// Both variants are transparent wrappers: their `Display` output and `source()` are those of
/// the wrapped error.
#[derive(Debug, Error)]
pub enum PersistError {
/// A filesystem operation failed.
#[error(transparent)]
Io(#[from] io::Error),
/// The value could not be serialized to, or deserialized from, JSON.
#[error(transparent)]
Json(#[from] serde_json::Error),
}
/// Module-local result alias whose error type is [`PersistError`].
type Result<T> = std::result::Result<T, PersistError>;
/// Returns the path of the temporary file used while rewriting `path`.
///
/// The temporary file lives in the same directory as `path` and is named after it with a
/// `.tmp` suffix appended (`data.json` -> `data.json.tmp`). When `path` has no final file name
/// component (for example `..`), the name `persist.tmp` is used instead.
fn tmp_path(path: &Path) -> PathBuf {
    // Build the name as an `OsString` rather than going through `&str`: a file name that is not
    // valid UTF-8 then still gets its own temp file instead of every such name collapsing onto
    // the shared "persist.tmp" fallback.
    let mut name = path
        .file_name()
        .map_or_else(|| std::ffi::OsString::from("persist"), |n| n.to_os_string());
    name.push(".tmp");
    path.with_file_name(name)
}
/// Syncs the directory that contains `path`, so that a preceding create or rename of `path`
/// is itself flushed to disk.
///
/// Returns `Ok(())` without doing anything when `path` has no parent at all (a root or an empty
/// path).
///
/// # Errors
///
/// Returns the I/O error from opening or syncing the directory.
#[cfg(unix)]
fn fsync_parent(path: &Path) -> io::Result<()> {
    let dir = match path.parent() {
        None => return Ok(()),
        // A bare relative file name such as `data.json` has an empty parent; the entry lives in
        // the current directory, which still has to be synced for the rename to be durable.
        Some(parent) if parent.as_os_str().is_empty() => Path::new("."),
        Some(parent) => parent,
    };
    File::open(dir)?.sync_all()
}

/// No-op counterpart of the Unix implementation for other platforms; always returns `Ok(())`.
#[cfg(not(unix))]
fn fsync_parent(_path: &Path) -> io::Result<()> {
    Ok(())
}
/// State shared by every clone of a [`Persist`] handle.
///
/// The bounds are repeated on the type itself because the `Drop` implementation needs them,
/// and a `Drop` impl may not be more constrained than the type it is written for.
#[derive(Debug)]
struct Inner<T>
where
    T: DeserializeOwned + Serialize + Default + Clone,
{
    /// The in-memory value, guarded by a mutex.
    data: Mutex<T>,
    /// Location of the JSON file the value is written to.
    path: PathBuf,
    /// When `true`, flushes fsync the written file and its parent directory.
    sync: bool,
}
impl<T> Inner<T>
where
T: DeserializeOwned + Serialize + Default + Clone,
{
fn open(path: impl AsRef<Path>, sync: bool) -> Result<Self> {
let path = path.as_ref().to_path_buf();
let tmp = tmp_path(&path);
if tmp.exists() {
fs::remove_file(&tmp)?;
}
let data = if !path.exists() {
T::default()
} else {
let file = File::open(&path)?;
if file.metadata()?.len() == 0 {
T::default()
} else {
from_reader(file)?
}
};
let inner = Self {
data: Mutex::new(data),
path,
sync,
};
if !inner.path.exists() || inner.path.metadata()?.len() == 0 {
let data = inner.data.lock();
inner.flush(&data)?;
}
Ok(inner)
}
fn replace(&self, data: T) -> Result<()> {
let mut inner = self.data.lock();
let old = inner.clone();
*inner = data;
if let Err(e) = self.flush(&inner) {
*inner = old;
return Err(e);
}
Ok(())
}
fn update(&self, f: impl FnOnce(&mut T)) -> Result<()> {
let mut inner = self.data.lock();
let old = inner.clone();
f(&mut inner);
if let Err(e) = self.flush(&inner) {
*inner = old;
return Err(e);
}
Ok(())
}
fn save(&self) -> Result<()> {
let data = self.data.lock();
self.flush(&data)?;
Ok(())
}
fn flush(&self, data: &T) -> Result<()> {
let tmp = tmp_path(&self.path);
let bytes = to_vec_pretty(data)?;
{
let mut file = File::create(&tmp)?;
file.write_all(&bytes)?;
if self.sync {
file.sync_all()?;
}
}
fs::rename(&tmp, &self.path)?;
if self.sync {
fsync_parent(&self.path)?;
}
Ok(())
}
}
impl<T> Drop for Inner<T>
where
    T: DeserializeOwned + Serialize + Default + Clone,
{
    /// Makes a final attempt to save the value. A failure cannot be returned from `drop`, so
    /// it is logged through `tracing` instead.
    fn drop(&mut self) {
        let Err(e) = self.save() else {
            return;
        };
        tracing::error!("Failed to save data: {}", e);
    }
}
/// Handle to a value of type `T` that is kept in memory and mirrored to a JSON file.
///
/// Cloning the handle clones the inner `Arc`, so all clones share the same value and file.
#[derive(Debug, Clone)]
pub struct Persist<T>
where
    T: DeserializeOwned + Serialize + Default + Clone,
{
    /// Shared state: the value, the file path and the sync flag.
    inner: Arc<Inner<T>>,
}
impl<T> Persist<T>
where
T: DeserializeOwned + Serialize + Default + Clone,
{
pub fn open(path: impl AsRef<Path>) -> Result<Self> {
let inner = Inner::open(path, true)?;
Ok(Self {
inner: Arc::new(inner),
})
}
pub fn open_no_sync(path: impl AsRef<Path>) -> Result<Self> {
let inner = Inner::open(path, false)?;
Ok(Self {
inner: Arc::new(inner),
})
}
pub fn replace(&self, data: T) -> Result<()> {
self.inner.replace(data)
}
pub fn update(&self, f: impl FnOnce(&mut T)) -> Result<()> {
self.inner.update(f)
}
#[tracing::instrument(skip(self))]
pub fn save(&self) -> Result<()> {
self.inner.save()
}
pub fn cloned(&self) -> T
where
T: Clone,
{
self.inner.data.lock().clone()
}
}
/// Tests for the persistence layer; each test works in its own temporary directory.
#[cfg(test)]
mod tests {
    use super::*;
    use serde::{Deserialize, Serialize};
    use tempfile::tempdir;

    /// Minimal payload used by all tests.
    #[derive(Debug, Default, Clone, PartialEq, Eq, Serialize, Deserialize)]
    struct TestData {
        n: u32,
    }

    /// Reads the store file back independently of `Persist`.
    fn read_back(path: &Path) -> TestData {
        let file = File::open(path).unwrap();
        from_reader(file).unwrap()
    }

    #[test]
    fn flush_produces_valid_json() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("data.json");
        let p = Persist::open(&path).unwrap();
        p.replace(TestData { n: 42 }).unwrap();
        assert_eq!(read_back(&path), TestData { n: 42 });
    }

    #[test]
    fn replace_update_save_return_ok() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("data.json");
        let p = Persist::open(&path).unwrap();
        p.replace(TestData { n: 1 }).unwrap();
        p.update(|d| d.n = 2).unwrap();
        p.save().unwrap();
        assert_eq!(p.cloned(), TestData { n: 2 });
    }

    #[test]
    fn stale_tmp_cleaned_on_open() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("data.json");
        let tmp = tmp_path(&path);
        fs::write(&path, r#"{"n":1}"#).unwrap();
        fs::write(&tmp, r#"{"n":99}"#).unwrap();
        assert!(tmp.exists());
        let p: Persist<TestData> = Persist::open(&path).unwrap();
        assert!(!tmp.exists());
        assert_eq!(p.cloned(), TestData { n: 1 });
    }

    #[test]
    fn drop_flushes_best_effort() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("data.json");
        {
            let p = Persist::open(&path).unwrap();
            p.replace(TestData { n: 7 }).unwrap();
            // `replace` has already written the file. Delete it so that only the flush
            // performed when the handle is dropped can bring it back.
            fs::remove_file(&path).unwrap();
            assert!(!path.exists());
        }
        assert_eq!(read_back(&path), TestData { n: 7 });
    }
}