use std::path::Path;
use serde::{ser::SerializeSeq, Deserialize, Serialize, Serializer};
/// Writes the inputs processed so far, paired with their results, to a JSON file on disk.
///
/// Cheap to copy: it only borrows the input slice, the result slice and the output path.
/// When `inner` is `None`, both [`Checkpoint::save`] and [`Checkpoint::checkpoint`] are
/// no-ops (no constructor producing that state is visible in this module).
#[derive(Clone, Copy, Debug)]
pub struct Checkpoint<'a> {
    /// Borrowed checkpoint state; `None` disables all writes.
    inner: Option<CheckpointInner<'a>>,
}

impl<'a> Checkpoint<'a> {
    /// Creates a checkpoint over `input` where the first `results.len()` inputs are done.
    ///
    /// # Errors
    ///
    /// Fails when `results` is longer than `input`, because every result must pair with
    /// an input.
    pub(crate) fn new(
        input: &'a [serde_json::Value],
        results: &'a [serde_json::Value],
        path: &'a Path,
    ) -> anyhow::Result<Self> {
        anyhow::ensure!(
            results.len() <= input.len(),
            "internal error - results len ({}) is greater than input len ({})",
            results.len(),
            input.len()
        );
        let state = CheckpointInner {
            input,
            results,
            path,
        };
        Ok(Self { inner: Some(state) })
    }

    /// Atomically writes every completed `(input, result)` pair to the checkpoint path.
    ///
    /// # Errors
    ///
    /// Propagates any failure reported by [`atomic_save`].
    pub fn save(&self) -> anyhow::Result<()> {
        match &self.inner {
            Some(state) => atomic_save(state.path, state),
            None => Ok(()),
        }
    }

    /// Writes the completed pairs plus `partial` as the result for the next pending input.
    ///
    /// Only the file on disk gains the extra entry; this value is not modified.
    ///
    /// # Errors
    ///
    /// Fails when every input already has a result, when `partial` cannot be converted
    /// to a JSON value, or when [`atomic_save`] fails.
    pub fn checkpoint<T: Serialize + ?Sized>(&self, partial: &T) -> anyhow::Result<()> {
        let state = match self.inner {
            Some(state) => state,
            None => return Ok(()),
        };
        anyhow::ensure!(
            state.results.len() != state.input.len(),
            "internal error - checkpoint is full"
        );
        let appended = Appended {
            checkpoint: state,
            partial: serde_json::to_value(partial)?,
        };
        atomic_save(state.path, &appended)
    }
}
pub(crate) fn atomic_save<T>(path: &Path, object: &T) -> anyhow::Result<()>
where
T: Serialize + ?Sized,
{
let temp = format!("{}.temp", path.display());
if Path::new(&temp).exists() {
return Err(anyhow::Error::msg(format!(
"Temporary file {} already exists. Aborting!",
temp
)));
}
let buffer = std::fs::File::create(&temp)?;
serde_json::to_writer_pretty(buffer, object)?;
std::fs::rename(&temp, path)?;
Ok(())
}
/// One `{ "input": ..., "results": ... }` record, the same shape that `Checkpoint` writes.
#[derive(Serialize, Deserialize, Debug)]
pub(crate) struct RawResult {
    /// The input this record was produced for.
    pub(crate) input: serde_json::Value,
    /// The result recorded for `input`.
    pub(crate) results: serde_json::Value,
}

impl RawResult {
    /// Loads the records stored at `path` by delegating to `crate::internal::load_from_disk`.
    ///
    /// # Errors
    ///
    /// Whatever `crate::internal::load_from_disk` reports (not visible from here).
    pub(crate) fn load(path: &Path) -> anyhow::Result<Vec<RawResult>> {
        crate::internal::load_from_disk(path)
    }
}
/// Borrowed view of all inputs, the results completed so far, and the output path.
///
/// `Checkpoint::new` only builds this when `results.len() <= input.len()`.
#[derive(Clone, Copy, Debug)]
struct CheckpointInner<'a> {
    input: &'a [serde_json::Value],
    results: &'a [serde_json::Value],
    path: &'a Path,
}

/// Serialized form of one pair: `{ "input": ..., "results": ... }` (field order matters).
#[derive(Serialize, Debug)]
struct SingleResult<'a> {
    input: &'a serde_json::Value,
    results: &'a serde_json::Value,
}

impl Serialize for CheckpointInner<'_> {
    /// Emits a sequence of [`SingleResult`] records, one per completed result, each
    /// paired with the input at the same position.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let Self {
            input: inputs,
            results: outputs,
            ..
        } = *self;
        let mut seq = serializer.serialize_seq(Some(outputs.len()))?;
        inputs.iter().zip(outputs).try_for_each(|(input, results)| {
            seq.serialize_element(&SingleResult { input, results })
        })?;
        seq.end()
    }
}
/// A checkpoint plus one extra result for the next pending input, written as one sequence.
struct Appended<'a> {
    checkpoint: CheckpointInner<'a>,
    partial: serde_json::Value,
}

impl Serialize for Appended<'_> {
    /// Same record layout as the plain checkpoint. `partial` is paired with the input at
    /// index `checkpoint.results.len()`. `zip` stops at the shorter side, and
    /// `Checkpoint::checkpoint` refuses to build an `Appended` when no input is pending.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let Self {
            checkpoint,
            partial,
        } = self;
        let mut seq = serializer.serialize_seq(Some(checkpoint.results.len() + 1))?;
        let outputs = checkpoint
            .results
            .iter()
            .chain(std::iter::once(partial));
        for (input, results) in checkpoint.input.iter().zip(outputs) {
            seq.serialize_element(&SingleResult { input, results })?;
        }
        seq.end()
    }
}
#[cfg(test)]
mod tests {
    use serde::Deserialize;
    use serde_json::value::Value;

    use super::*;
    use crate::{test::TypeInput, utils::datatype::DataType};

    /// Reads and deserializes the JSON file at `path`, panicking on any failure.
    fn load_from_file<T>(path: &std::path::Path) -> T
    where
        T: serde::de::DeserializeOwned,
    {
        let file = std::fs::File::open(path).unwrap();
        let reader = std::io::BufReader::new(file);
        serde_json::from_reader(reader).unwrap()
    }

    /// Asserts that `results` holds one `{input, results}` object per entry of `inputs`.
    ///
    /// Each `input` must deserialize to the matching `inputs` entry, and each `results`
    /// must equal the matching `expected` entry.
    fn check_results(results: &[Value], inputs: &[TypeInput], expected: &[Value]) {
        assert_eq!(results.len(), inputs.len());
        assert_eq!(results.len(), expected.len());
        let rows = results.iter().zip(inputs).zip(expected).enumerate();
        for (i, ((result, expected_input), expected_result)) in rows {
            match result {
                Value::Object(map) => {
                    assert_eq!(
                        map.len(),
                        2,
                        "Each serialized result should only have two top level entries"
                    );
                    let input = TypeInput::deserialize(&map["input"]).unwrap();
                    assert_eq!(&input, expected_input);
                    assert_eq!(&map["results"], expected_result);
                }
                _ => panic!("incorrect formatting for output {}", i),
            }
        }
    }

    #[test]
    fn test_atomic_save() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path();
        let message: &str = "hello world";
        let full = path.join("file.json");
        assert!(!full.exists());
        assert!(atomic_save(&full, message).is_ok());
        assert!(full.exists());
        let deserialized: String = load_from_file(&full);
        assert_eq!(deserialized, message);

        // A pre-existing temp file must make the save refuse to run.
        std::fs::File::create(path.join("file.json.temp")).unwrap();
        let err = atomic_save(&full, message).unwrap_err();
        let report = format!("{:?}", err);
        assert!(report.contains("Temporary file"));
        assert!(report.contains("already exists"));
    }

    #[test]
    fn test_checkpoint() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path();
        let savepath = path.join("output.json");
        let inputs = [
            TypeInput::new(DataType::Float32, 1, false),
            TypeInput::new(DataType::Float16, 2, false),
            TypeInput::new(DataType::Float64, 3, false),
        ];
        let serialized: Vec<_> = inputs
            .iter()
            .map(|i| serde_json::to_value(i).unwrap())
            .collect();

        // No results yet: save writes an empty list, checkpoint adds the first entry.
        {
            let checkpoint = Checkpoint::new(&serialized, &[], &savepath).unwrap();
            assert!(!savepath.exists());
            checkpoint.save().unwrap();
            assert!(savepath.exists());
            let reloaded: Vec<Value> = load_from_file(&savepath);
            assert!(reloaded.is_empty());
            checkpoint.checkpoint("some string").unwrap();
            let reloaded: Vec<Value> = load_from_file(&savepath);
            check_results(
                &reloaded,
                &inputs[0..1],
                &[Value::String("some string".into())],
            );
        }

        // One result: checkpoint appends after the existing entry.
        {
            let values = vec![serde_json::to_value("some result").unwrap()];
            let checkpoint = Checkpoint::new(&serialized, &values, &savepath).unwrap();
            checkpoint.save().unwrap();
            {
                let reloaded: Vec<Value> = load_from_file(&savepath);
                check_results(
                    &reloaded,
                    &inputs[0..1],
                    &[Value::String("some result".into())],
                );
            }
            checkpoint.checkpoint("another result").unwrap();
            {
                let reloaded: Vec<Value> = load_from_file(&savepath);
                check_results(
                    &reloaded,
                    &inputs[0..2],
                    &[
                        Value::String("some result".into()),
                        Value::String("another result".into()),
                    ],
                );
            }
        }

        // Every input has a result: save works, checkpoint reports "full".
        {
            let values = vec![
                serde_json::to_value("a").unwrap(),
                serde_json::to_value("b").unwrap(),
                serde_json::to_value("c").unwrap(),
            ];
            let checkpoint = Checkpoint::new(&serialized, &values, &savepath).unwrap();
            checkpoint.save().unwrap();
            let reloaded: Vec<Value> = load_from_file(&savepath);
            check_results(
                &reloaded,
                &inputs,
                &[
                    Value::String("a".into()),
                    Value::String("b".into()),
                    Value::String("c".into()),
                ],
            );
            let err = checkpoint.checkpoint("too full").unwrap_err();
            let message = err.to_string();
            assert!(message.contains("internal error - checkpoint is full"));
        }

        // More results than inputs: construction is rejected.
        {
            let values = vec![
                serde_json::to_value("a").unwrap(),
                serde_json::to_value("b").unwrap(),
                serde_json::to_value("c").unwrap(),
                serde_json::to_value("d").unwrap(),
            ];
            let err = Checkpoint::new(&serialized, &values, &savepath).unwrap_err();
            let message = err.to_string();
            assert!(message.contains("internal error - results len"));
        }
    }
}