lib_contra/
persistent.rs

//! Saving values to, and loading them back from, files on disk (see [`Persistent`]).
2use std::{
3    fs::File,
4    io::{self, BufReader, Cursor, Read, Write},
5    path::Path,
6    str::from_utf8,
7};
8
9use crate::{
10    deserialize::{
11        json::{FromJson, JsonDeserializer},
12        Deserialize,
13    },
14    error::{AnyError, IoResult},
15    serialize::{
16        json::{IntoJson, JsonSerializer, PrettyJsonFormatter},
17        Serialize,
18    },
19};
20
/// Save values to, and load them back from, files on disk.
///
/// Automatically implemented for every type that implements both [Serialize] and
/// [Deserialize] (via a blanket implementation), so it is never implemented by hand.
///
/// The on-disk format is picked from the file extension, matched case-sensitively.
/// A path whose extension is exactly `json` is written with [IntoJson] and read with
/// [FromJson]. Any other path, including one with no extension, uses the crate's
/// default JSON serializer/deserializer: a `PrettyJsonFormatter` configured with a
/// `"\t"` indent string.
pub trait Persistent: Serialize + Deserialize {
    /// Serializes `self` and writes the resulting bytes to the file at `path`.
    ///
    /// The file is created if it does not exist and truncated if it does.
    ///
    /// # Errors
    ///
    /// Returns an error if serialization fails or if the file cannot be created or
    /// written.
    fn save(&self, path: &str) -> Result<(), AnyError>;

    /// Reads the whole file at `path` and deserializes a new value of this type from it.
    ///
    /// # Errors
    ///
    /// Returns an error if the file cannot be opened or read, or if deserialization
    /// fails.
    fn load(path: &str) -> Result<Self, AnyError>;
}
28
29fn serialize_with_default<S: Serialize>(value: &S) -> Result<Vec<u8>, AnyError> {
30    let mut buffer = Vec::with_capacity(128);
31    let mut ser = DefaultSerializer::new(PrettyJsonFormatter::new("\t".to_string()), &mut buffer);
32    value.serialize(&mut ser, &crate::position::Position::Closing)?;
33    Ok(buffer)
34}
35
36fn deserialize_with_default<D: Deserialize>(value: &[u8]) -> Result<D, AnyError> {
37    let cursor = Cursor::new(value);
38    let mut des = DefaultDeserializer::new(cursor);
39    D::deserialize(&mut des)
40}
41
42fn serialize_factory<S: Serialize>(value: &S, path: &Path) -> Result<Vec<u8>, AnyError> {
43    if let Some(ending) = path.extension() {
44        if ending == "json" {
45            return IntoJson::to_json(value).map(|json| json.into_bytes());
46        }
47    }
48    serialize_with_default(value)
49}
50
51fn deserializer_factory<D: Deserialize>(value: &[u8], path: &Path) -> Result<D, AnyError> {
52    if let Some(ending) = path.extension() {
53        if ending == "json" {
54            return FromJson::from_json(
55                &from_utf8(value).expect("failed to convert content to utf8"),
56            );
57        }
58    }
59    deserialize_with_default(value)
60}
61
62type DefaultSerializer<'w> = JsonSerializer<'w, Vec<u8>, PrettyJsonFormatter>;
63type DefaultDeserializer<'w> = JsonDeserializer<Cursor<&'w [u8]>>;
64
65impl<T: Sized + Serialize + Deserialize> Persistent for T {
66    fn save(&self, path: &str) -> Result<(), AnyError> {
67        let path = Path::new(path);
68        let buffer = serialize_factory(self, path)?;
69        write_bytes_file(buffer.as_slice(), path).map_err(|e| e.into())
70    }
71
72    fn load(path: &str) -> Result<Self, AnyError> {
73        let path = Path::new(path);
74        let content = read_bytes_file(path)?;
75        deserializer_factory(content.as_slice(), path)
76    }
77}
78
79fn write_bytes_file(bytes: &[u8], path: &Path) -> IoResult {
80    let mut f = File::create(path)?;
81    f.write_all(bytes)?;
82    Ok(())
83}
84
/// Reads the entire file at `path` into memory and returns its raw bytes.
///
/// # Errors
///
/// Propagates any I/O error from opening or reading the file.
fn read_bytes_file(path: &Path) -> Result<Vec<u8>, io::Error> {
    let mut contents = Vec::new();
    BufReader::new(File::open(path)?).read_to_end(&mut contents)?;
    Ok(contents)
}
94
#[cfg(test)]
mod test {
    use std::{env, fs, path::PathBuf, process, thread};

    use super::Persistent;

    /// Owns a scratch path in the system temp directory and deletes whatever is at it
    /// when dropped, so tests clean up even when an assertion fails.
    struct FileLifetime {
        path: PathBuf,
    }

    impl FileLifetime {
        /// Reserves a scratch path named after `file_name`. The process id is included
        /// so concurrent runs of the test binary do not clobber each other's files.
        fn new(file_name: &str) -> Self {
            let path = env::temp_dir().join(format!("contra_{}_{}", process::id(), file_name));
            Self { path }
        }

        /// The scratch path as `&str`, which is what [`Persistent`] accepts.
        fn path_str(&self) -> &str {
            self.path
                .to_str()
                .expect("temporary directory path is not valid UTF-8")
        }
    }

    impl Drop for FileLifetime {
        fn drop(&mut self) {
            if !self.path.exists() {
                return;
            }
            let removed = if self.path.is_dir() {
                fs::remove_dir_all(&self.path)
            } else {
                fs::remove_file(&self.path)
            };
            // Panicking while already unwinding from a failed assertion would abort the
            // whole test binary and hide the original failure, so only report cleanup
            // errors when the test itself succeeded.
            if let Err(err) = removed {
                if !thread::panicking() {
                    panic!("failed to delete {}: {}", self.path.display(), err);
                }
            }
        }
    }

    #[test]
    fn save_and_then_load_works() {
        let file = FileLifetime::new("save_i32.json");
        let data = 32i32;

        data.save(file.path_str())
            .expect("saving to a .json path should succeed");
        assert!(file.path.exists());

        let loaded = i32::load(file.path_str()).expect("loading the saved .json file should succeed");
        assert_eq!(data, loaded);
    }

    #[test]
    fn save_and_then_load_works_with_default_format() {
        // A non-`json` extension exercises the default (pretty JSON) serializer path.
        let file = FileLifetime::new("save_i32.dat");
        let data = 1234i32;

        data.save(file.path_str())
            .expect("saving to a non-.json path should succeed");
        assert!(file.path.exists());

        let loaded = i32::load(file.path_str()).expect("loading the saved file should succeed");
        assert_eq!(data, loaded);
    }

    #[test]
    fn load_missing_file_is_an_error() {
        let file = FileLifetime::new("does_not_exist.json");
        assert!(i32::load(file.path_str()).is_err());
    }
}