//! armour-core 0.2.0
//!
//! Core types for the armour ecosystem.
#![allow(clippy::unwrap_used)]
use core::fmt::Debug;
use std::{
    fs::{self, File},
    io::{self, Write},
    path::{Path, PathBuf},
    sync::Arc,
};

use parking_lot::Mutex;
use serde::{Serialize, de::DeserializeOwned};
use serde_json::{from_reader, to_vec_pretty};
use thiserror::Error;

#[derive(Debug, Error)]
pub enum PersistError {
    #[error(transparent)]
    Io(#[from] io::Error),
    #[error(transparent)]
    Json(#[from] serde_json::Error),
}

type Result<T> = std::result::Result<T, PersistError>;

/// Returns the sibling temp-file path used for atomic writes: `<file name>.tmp` next to `path`.
///
/// The suffix is appended to the raw `OsStr` file name, so non-UTF-8 names keep their own
/// distinct temp file instead of all collapsing onto a shared `persist.tmp`. A path with no
/// file-name component (for example one ending in `..`) falls back to `persist.tmp`.
fn tmp_path(path: &Path) -> PathBuf {
    let mut name = path
        .file_name()
        .map_or_else(|| std::ffi::OsString::from("persist"), std::ffi::OsStr::to_os_string);
    name.push(".tmp");
    path.with_file_name(name)
}

/// `fsync`s the directory that contains `path`, so that a `rename` into that directory
/// is flushed to disk as well as the file contents.
///
/// A path with no parent (such as `/`) or an empty parent (a bare relative file name)
/// is a no-op.
#[cfg(unix)]
fn fsync_parent(path: &Path) -> io::Result<()> {
    match path.parent() {
        Some(dir) if !dir.as_os_str().is_empty() => File::open(dir)?.sync_all(),
        _ => Ok(()),
    }
}

/// Directory `fsync` is only attempted on Unix; on other platforms this does nothing.
#[cfg(not(unix))]
fn fsync_parent(_path: &Path) -> io::Result<()> {
    Ok(())
}

#[derive(Debug)]
struct Inner<T: DeserializeOwned + Serialize + Default + Clone> {
    data: Mutex<T>,
    path: PathBuf,
    sync: bool,
}

impl<T> Inner<T>
where
    T: DeserializeOwned + Serialize + Default + Clone,
{
    fn open(path: impl AsRef<Path>, sync: bool) -> Result<Self> {
        let path = path.as_ref().to_path_buf();
        let tmp = tmp_path(&path);

        if tmp.exists() {
            fs::remove_file(&tmp)?;
        }

        let data = if !path.exists() {
            T::default()
        } else {
            let file = File::open(&path)?;
            if file.metadata()?.len() == 0 {
                T::default()
            } else {
                from_reader(file)?
            }
        };

        let inner = Self {
            data: Mutex::new(data),
            path,
            sync,
        };

        if !inner.path.exists() || inner.path.metadata()?.len() == 0 {
            let data = inner.data.lock();
            inner.flush(&data)?;
        }

        Ok(inner)
    }

    fn replace(&self, data: T) -> Result<()> {
        let mut inner = self.data.lock();
        let old = inner.clone();
        *inner = data;
        if let Err(e) = self.flush(&inner) {
            *inner = old;
            return Err(e);
        }
        Ok(())
    }

    fn update(&self, f: impl FnOnce(&mut T)) -> Result<()> {
        let mut inner = self.data.lock();
        let old = inner.clone();
        f(&mut inner);
        if let Err(e) = self.flush(&inner) {
            *inner = old;
            return Err(e);
        }
        Ok(())
    }

    fn save(&self) -> Result<()> {
        let data = self.data.lock();
        self.flush(&data)?;
        Ok(())
    }

    fn flush(&self, data: &T) -> Result<()> {
        let tmp = tmp_path(&self.path);
        let bytes = to_vec_pretty(data)?;

        {
            let mut file = File::create(&tmp)?;
            file.write_all(&bytes)?;
            if self.sync {
                file.sync_all()?;
            }
        }

        fs::rename(&tmp, &self.path)?;

        if self.sync {
            fsync_parent(&self.path)?;
        }

        Ok(())
    }
}

impl<T> Drop for Inner<T>
where
    T: DeserializeOwned + Serialize + Default + Clone,
{
    fn drop(&mut self) {
        if let Err(e) = self.save() {
            tracing::error!("Failed to save data: {}", e);
        }
    }
}

/// A value of type `T` mirrored to a JSON file on disk.
///
/// Every write serializes the whole value to a temporary sibling file and renames it over
/// the target, so the target is replaced in one rename rather than rewritten in place.
///
/// Cloning a `Persist` is cheap: all clones share one in-memory value. The value is written
/// once more when the last clone is dropped (best effort; failures are only logged).
#[derive(Debug, Clone)]
pub struct Persist<T: DeserializeOwned + Serialize + Default + Clone> {
    inner: Arc<Inner<T>>,
}

impl<T> Persist<T>
where
    T: DeserializeOwned + Serialize + Default + Clone,
{
    /// Opens the store at `path` with durable writes.
    ///
    /// With durable writes, every write `fsync`s the file, and on Unix also its parent
    /// directory. A missing or empty file is created containing `T::default()`.
    ///
    /// # Errors
    /// Fails if the file cannot be read or created, or if its contents do not deserialize
    /// into `T`.
    pub fn open(path: impl AsRef<Path>) -> Result<Self> {
        Self::open_with(path, true)
    }

    /// Like [`Persist::open`], but never `fsync`s.
    ///
    /// Writes are faster, but data may be lost on a crash or power failure.
    ///
    /// # Errors
    /// Same as [`Persist::open`].
    pub fn open_no_sync(path: impl AsRef<Path>) -> Result<Self> {
        Self::open_with(path, false)
    }

    /// Shared constructor behind [`Persist::open`] and [`Persist::open_no_sync`].
    fn open_with(path: impl AsRef<Path>, sync: bool) -> Result<Self> {
        Inner::open(path, sync).map(|inner| Self {
            inner: Arc::new(inner),
        })
    }

    /// Replaces the stored value with `data` and writes it to disk.
    ///
    /// # Errors
    /// If the write fails, the error is returned and the previous in-memory value is kept.
    pub fn replace(&self, data: T) -> Result<()> {
        self.inner.replace(data)
    }

    /// Modifies the stored value with `f` and writes the result to disk.
    ///
    /// # Errors
    /// If the write fails, the error is returned and the in-memory value is rolled back to
    /// what it was before `f` ran.
    pub fn update(&self, f: impl FnOnce(&mut T)) -> Result<()> {
        self.inner.update(f)
    }

    /// Rewrites the file from the current in-memory value.
    ///
    /// `replace` and `update` already write on success. This is only needed to rewrite the
    /// file explicitly, for example after it was changed outside this handle.
    ///
    /// # Errors
    /// Fails if serializing or writing the file fails.
    #[tracing::instrument(skip(self))]
    pub fn save(&self) -> Result<()> {
        self.inner.save()
    }

    /// Returns a copy of the current in-memory value (the file is not re-read).
    pub fn cloned(&self) -> T {
        T::clone(&self.inner.data.lock())
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use serde::{Deserialize, Serialize};
    use tempfile::tempdir;

    #[derive(Debug, Default, Clone, PartialEq, Eq, Serialize, Deserialize)]
    struct TestData {
        n: u32,
    }

    /// Deserializes the file at `path` directly, bypassing `Persist`.
    fn read_back(path: &Path) -> TestData {
        from_reader(File::open(path).unwrap()).unwrap()
    }

    #[test]
    fn flush_produces_valid_json() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("data.json");
        let p = Persist::open(&path).unwrap();
        p.replace(TestData { n: 42 }).unwrap();

        assert_eq!(read_back(&path), TestData { n: 42 });
    }

    #[test]
    fn replace_update_save_return_ok() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("data.json");
        let p = Persist::open(&path).unwrap();

        p.replace(TestData { n: 1 }).unwrap();
        p.update(|d| d.n = 2).unwrap();
        p.save().unwrap();

        assert_eq!(p.cloned(), TestData { n: 2 });
        assert_eq!(read_back(&path), TestData { n: 2 });
    }

    #[test]
    fn missing_or_empty_file_is_initialised_with_default() {
        let dir = tempdir().unwrap();

        let missing = dir.path().join("missing.json");
        let p: Persist<TestData> = Persist::open(&missing).unwrap();
        assert_eq!(p.cloned(), TestData::default());
        assert_eq!(read_back(&missing), TestData::default());

        let empty = dir.path().join("empty.json");
        fs::write(&empty, "").unwrap();
        let q: Persist<TestData> = Persist::open(&empty).unwrap();
        assert_eq!(q.cloned(), TestData::default());
        assert_eq!(read_back(&empty), TestData::default());
    }

    #[test]
    fn stale_tmp_cleaned_on_open() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("data.json");
        let tmp = tmp_path(&path);

        fs::write(&path, r#"{"n":1}"#).unwrap();
        fs::write(&tmp, r#"{"n":99}"#).unwrap();
        assert!(tmp.exists());

        let p: Persist<TestData> = Persist::open(&path).unwrap();
        assert!(!tmp.exists());
        assert_eq!(p.cloned(), TestData { n: 1 });
    }

    #[test]
    fn open_fails_when_parent_dir_missing() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("missing").join("data.json");
        assert!(Persist::<TestData>::open(&path).is_err());
    }

    #[test]
    fn failed_write_keeps_previous_value() {
        let dir = tempdir().unwrap();
        let sub = dir.path().join("sub");
        fs::create_dir(&sub).unwrap();
        let path = sub.join("data.json");
        let p = Persist::open(&path).unwrap();
        p.replace(TestData { n: 1 }).unwrap();

        // Removing the directory makes every subsequent write fail.
        fs::remove_dir_all(&sub).unwrap();
        assert!(p.replace(TestData { n: 2 }).is_err());
        assert!(p.update(|d| d.n = 3).is_err());
        assert_eq!(p.cloned(), TestData { n: 1 });
    }

    #[test]
    fn drop_of_last_handle_rewrites_file_from_memory() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("data.json");

        let p = Persist::open(&path).unwrap();
        p.replace(TestData { n: 7 }).unwrap();
        let clone = p.clone();

        // Clobber the file behind the handles' backs.
        fs::write(&path, r#"{"n":1}"#).unwrap();

        // A surviving clone keeps the shared state alive, so nothing is written yet.
        drop(p);
        assert_eq!(read_back(&path), TestData { n: 1 });

        // Dropping the last handle performs the best-effort final save.
        drop(clone);
        assert_eq!(read_back(&path), TestData { n: 7 });
    }
}