use std::{collections::HashMap, hash::Hash, marker::PhantomData};
use serde::{Deserialize, Serialize};
/// Persists a single value of type `T` as JSON in the file at `file_path`.
///
/// `T` is only a type marker: no `T` is held in memory. `Clone` and `Debug` are therefore
/// implemented by hand, because deriving them would needlessly require `T: Clone` / `T: Debug`.
pub struct StateSaver<T> {
    /// Path of the JSON file the state is read from and written to.
    file_path: String,
    // `fn() -> T` instead of `T`: the saver never owns a `T`, so it stays `Send`/`Sync`
    // regardless of `T` while remaining covariant in `T`.
    phantom: PhantomData<fn() -> T>,
}

impl<T> Clone for StateSaver<T> {
    fn clone(&self) -> Self {
        Self {
            file_path: self.file_path.clone(),
            phantom: PhantomData,
        }
    }
}

impl<T> std::fmt::Debug for StateSaver<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("StateSaver")
            .field("file_path", &self.file_path)
            .finish_non_exhaustive()
    }
}
impl<T> StateSaver<T> {
    /// Creates a saver backed by the file at `file_path`.
    ///
    /// Nothing is read or written until [`state`](Self::state) or [`save`](Self::save) is called.
    pub fn new(file_path: &str) -> Self {
        Self {
            file_path: file_path.to_owned(),
            phantom: PhantomData,
        }
    }

    /// Loads the last saved state, or returns `None` if nothing has been saved yet
    /// (the file does not exist).
    ///
    /// # Panics
    ///
    /// Panics if the file exists but cannot be opened, or if its contents are not valid JSON
    /// for `T`.
    pub fn state(&self) -> Option<T>
    where
        T: for<'a> Deserialize<'a>,
    {
        let file = match std::fs::File::open(&self.file_path) {
            Ok(file) => file,
            Err(err) if err.kind() == std::io::ErrorKind::NotFound => return None,
            // Any other error (permissions, too many open files, ...) must not be mistaken for
            // "no state". Callers such as `ProgressSaver::save` would otherwise overwrite the
            // saved data with a fresh value.
            Err(err) => panic!("failed to open state file {:?}: {err}", self.file_path),
        };
        // Deserialize `T` itself, not `Option<T>`. Otherwise a saved value that serializes to
        // JSON `null` (e.g. `()` or `None`) would read back as "no state".
        let state: T = serde_json::from_reader(std::io::BufReader::new(file))
            .expect("serialization should be bidirectional");
        Some(state)
    }

    /// Serializes `state` as JSON and replaces the file's contents with it.
    ///
    /// The JSON is first written to a sibling `<file_path>.tmp` file, which is then renamed over
    /// `file_path`. A process that dies mid-write therefore leaves the previous contents in
    /// place rather than a truncated file that [`state`](Self::state) could not parse. The
    /// replacement is atomic on POSIX filesystems. Surviving power loss would additionally
    /// require an fsync, which is not done here.
    ///
    /// # Panics
    ///
    /// Panics if serialization fails or if the file cannot be written or renamed.
    pub fn save(&self, state: &T)
    where
        T: Serialize,
    {
        let json = serde_json::to_string(state).expect("serialization must not fail");
        let tmp_path = format!("{}.tmp", self.file_path);
        std::fs::write(&tmp_path, json).expect("writing must not fail");
        std::fs::rename(&tmp_path, &self.file_path)
            .expect("replacing the state file must not fail");
    }
}
/// Persists a `HashMap<K, V>` of results keyed by id as JSON, built on top of a
/// [`StateSaver`] for the whole map.
#[derive(Clone, Debug)]
pub struct ProgressSaver<K, V>(StateSaver<HashMap<K, V>>);
impl<K, V> ProgressSaver<K, V> {
    /// Creates a progress saver backed by the JSON file at `file_path`.
    pub fn new(file_path: &str) -> Self {
        let inner = StateSaver::new(file_path);
        Self(inner)
    }

    /// Returns every saved `(id, result)` pair.
    ///
    /// Returns an empty map when the underlying saver has no state.
    pub fn state(&self) -> HashMap<K, V>
    where
        K: for<'a> Deserialize<'a> + Eq + Hash,
        V: for<'b> Deserialize<'b>,
    {
        match self.0.state() {
            Some(saved) => saved,
            None => HashMap::default(),
        }
    }

    /// Records `result` under `id`, replacing any earlier result for that id, and writes the
    /// whole map back through the underlying saver.
    ///
    /// Each call re-reads and rewrites the entire map.
    pub fn save(&self, id: K, result: V)
    where
        K: for<'a> Deserialize<'a> + Serialize + Eq + Hash,
        V: for<'b> Deserialize<'b> + Serialize,
    {
        let mut progress = self.state();
        progress.insert(id, result);
        self.0.save(&progress)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns a path in the system temp directory for a test's scratch file.
    ///
    /// Any copy left behind by an earlier run is removed first, so the test starts without
    /// saved state. The process id is part of the name so concurrent test runs do not share a
    /// file.
    fn fresh_path(name: &str) -> String {
        let path = std::env::temp_dir().join(format!("{}-{name}", std::process::id()));
        // Usually there is nothing to remove, so a "not found" error is expected and ignored.
        let _ = std::fs::remove_file(&path);
        path.to_str()
            .expect("temp dir path should be valid UTF-8")
            .to_owned()
    }

    #[test]
    fn state_saver_works() {
        let path = fresh_path("state_test.json");
        let state_saver: StateSaver<i32> = StateSaver::new(&path);
        assert_eq!(state_saver.state(), None);
        state_saver.save(&1);
        assert_eq!(state_saver.state(), Some(1));
        state_saver.save(&2);
        assert_eq!(state_saver.state(), Some(2));
        let _ = std::fs::remove_file(&path);
    }

    #[test]
    fn progress_saver_works() {
        let path = fresh_path("progress_test.json");
        let progress_saver: ProgressSaver<i32, i32> = ProgressSaver::new(&path);
        assert!(progress_saver.state().is_empty());
        progress_saver.save(0, 1);
        assert_eq!(HashMap::from([(0, 1)]), progress_saver.state());
        progress_saver.save(2, 3);
        assert_eq!(HashMap::from([(0, 1), (2, 3)]), progress_saver.state());
        progress_saver.save(2, 5);
        assert_eq!(HashMap::from([(0, 1), (2, 5)]), progress_saver.state());
        let _ = std::fs::remove_file(&path);
    }
}