Skip to main content

armour_core/
persist.rs

1#![allow(clippy::unwrap_used)]
2use core::fmt::Debug;
3use std::{fs::File, io, os::unix::fs::FileExt, path::Path, sync::Arc};
4
5use parking_lot::Mutex;
6use serde::{Serialize, de::DeserializeOwned};
7use serde_json::{from_reader, to_vec_pretty, to_writer_pretty};
8use thiserror::Error;
9
10#[derive(Debug, Error)]
11pub enum PersistError {
12    #[error(transparent)]
13    Io(#[from] io::Error),
14    #[error(transparent)]
15    Json(#[from] serde_json::Error),
16}
17
18type Result<T> = std::result::Result<T, PersistError>;
19
20#[derive(Debug)]
21struct Inner<T: DeserializeOwned + Serialize + Default> {
22    data: Mutex<T>,
23    file: File,
24}
25
26impl<T> Inner<T>
27where
28    T: DeserializeOwned + Serialize + Default,
29{
30    fn open(path: impl AsRef<Path>) -> Self {
31        let str_path = path.as_ref().to_str().unwrap_or_default().to_string();
32        let mut file = File::options()
33            .write(true)
34            .read(true)
35            .create(true)
36            .truncate(false)
37            .open(path.as_ref())
38            .expect(&str_path);
39
40        if file.metadata().expect(&str_path).len() == 0 {
41            let data = T::default();
42            to_writer_pretty(&mut file, &data).expect(&str_path);
43            let data = Mutex::new(data);
44            Self { data, file }
45        } else {
46            let data = from_reader(&mut file).expect(&str_path);
47            let data = Mutex::new(data);
48            Self { data, file }
49        }
50    }
51
52    fn replace(&self, data: T) -> Result<()> {
53        let mut inner = self.data.lock();
54        *inner = data;
55        self.flush(&inner)?;
56        Ok(())
57    }
58
59    fn update(&self, f: impl FnOnce(&mut T)) -> Result<()> {
60        let mut inner = self.data.lock();
61        f(&mut inner);
62        self.flush(&inner)?;
63        Ok(())
64    }
65
66    fn save(&self) -> Result<()> {
67        let data = self.data.lock();
68        self.flush(&data)?;
69        Ok(())
70    }
71
72    fn flush(&self, data: &T) -> Result<()> {
73        self.file.sync_all()?;
74        let bytes = to_vec_pretty(data)?;
75        self.file.set_len(0)?;
76        self.file.write_all_at(&bytes, 0)?;
77        self.file.sync_all()?;
78        Ok(())
79    }
80}
81
82impl<T> Drop for Inner<T>
83where
84    T: DeserializeOwned + Serialize + Default,
85{
86    fn drop(&mut self) {
87        if let Err(e) = self.save() {
88            tracing::error!("Failed to save data: {}", e);
89        }
90    }
91}
92
93#[derive(Debug, Clone)]
94pub struct Persist<T: DeserializeOwned + Serialize + Default> {
95    inner: Arc<Inner<T>>,
96    path: String,
97}
98
99impl<T> Persist<T>
100where
101    T: DeserializeOwned + Serialize + Default,
102{
103    pub fn open(path: impl AsRef<Path>) -> Self {
104        let str_path = path.as_ref().to_str().unwrap_or_default().to_string();
105        let inner = Inner::open(path);
106        Self {
107            inner: Arc::new(inner),
108            path: str_path,
109        }
110    }
111
112    /// save the data to the file
113    pub fn replace(&self, data: T) {
114        self.inner.replace(data).expect(&self.path);
115    }
116
117    pub fn update(&self, f: impl FnOnce(&mut T)) {
118        self.inner.update(f).expect(&self.path);
119    }
120
121    #[tracing::instrument(skip(self))]
122    pub fn save(&self) {
123        self.inner.save().expect(&self.path);
124    }
125
126    pub fn cloned(&self) -> T
127    where
128        T: Clone,
129    {
130        self.inner.data.lock().clone()
131    }
132}