diskann_benchmark_runner/
result.rs1use std::path::Path;
9
10use serde::{ser::SerializeSeq, Serialize, Serializer};
11
12#[derive(Debug, Clone, Copy)]
17pub struct Checkpoint<'a> {
18 inner: Option<CheckpointInner<'a>>,
19}
20
21impl<'a> Checkpoint<'a> {
22 pub(crate) fn new(
33 input: &'a [serde_json::Value],
34 results: &'a [serde_json::Value],
35 path: &'a Path,
36 ) -> anyhow::Result<Self> {
37 if results.len() > input.len() {
38 Err(anyhow::Error::msg(format!(
39 "internal error - results len ({}) is greater than input len ({})",
40 results.len(),
41 input.len(),
42 )))
43 } else {
44 Ok(Self {
45 inner: Some(CheckpointInner {
46 input,
47 results,
48 path,
49 }),
50 })
51 }
52 }
53
54 pub(crate) fn empty() -> Self {
56 Self { inner: None }
57 }
58
59 pub fn save(&self) -> anyhow::Result<()> {
61 if let Some(inner) = &self.inner {
62 atomic_save(inner.path, &inner)
63 } else {
64 Ok(())
65 }
66 }
67
68 pub fn checkpoint<T: Serialize + ?Sized>(&self, partial: &T) -> anyhow::Result<()> {
76 if let Some(inner) = &self.inner {
77 if inner.results.len() == inner.input.len() {
78 Err(anyhow::Error::msg("internal error - checkpoint is full"))
79 } else {
80 let appended = Appended {
81 checkpoint: *inner,
82 partial: serde_json::to_value(partial)?,
83 };
84 atomic_save(inner.path, &appended)
85 }
86 } else {
87 Ok(())
88 }
89 }
90}
91
92pub(crate) fn atomic_save<T>(path: &Path, object: &T) -> anyhow::Result<()>
104where
105 T: Serialize + ?Sized,
106{
107 let temp = format!("{}.temp", path.display());
108 if Path::new(&temp).exists() {
109 return Err(anyhow::Error::msg(format!(
110 "Temporary file {} already exists. Aborting!",
111 temp
112 )));
113 }
114
115 let buffer = std::fs::File::create(&temp)?;
116 serde_json::to_writer_pretty(buffer, object)?;
117 std::fs::rename(&temp, path)?;
118 Ok(())
119}
120
121#[derive(Debug, Clone, Copy)]
126struct CheckpointInner<'a> {
127 input: &'a [serde_json::Value],
128 results: &'a [serde_json::Value],
129 path: &'a Path,
130}
131
132#[derive(Debug, Serialize)]
136struct SingleResult<'a> {
137 input: &'a serde_json::Value,
138 results: &'a serde_json::Value,
139}
140impl Serialize for CheckpointInner<'_> {
141 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
143 where
144 S: Serializer,
145 {
146 let mut seq = serializer.serialize_seq(Some(self.results.len()))?;
147 for (input, results) in std::iter::zip(self.input.iter(), self.results.iter()) {
148 seq.serialize_element(&SingleResult { input, results })?;
149 }
150 seq.end()
151 }
152}
153
154struct Appended<'a> {
159 checkpoint: CheckpointInner<'a>,
160 partial: serde_json::Value,
161}
162
163impl Serialize for Appended<'_> {
164 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
165 where
166 S: Serializer,
167 {
168 let mut seq = serializer.serialize_seq(Some(self.checkpoint.results.len() + 1))?;
169 std::iter::zip(
170 self.checkpoint.input.iter(),
171 self.checkpoint
172 .results
173 .iter()
174 .chain(std::iter::once(&self.partial)),
175 )
176 .try_for_each(|(input, results)| seq.serialize_element(&SingleResult { input, results }))?;
177 seq.end()
178 }
179}
180
#[cfg(test)]
mod tests {
    use serde::Deserialize;
    use serde_json::value::Value;

    use super::*;
    use crate::{test::TypeInput, utils::datatype::DataType};

    /// Read the JSON file at `path` and deserialize it as `T`, panicking on any failure.
    fn load_from_file<T>(path: &std::path::Path) -> T
    where
        T: for<'a> Deserialize<'a>,
    {
        let reader = std::io::BufReader::new(std::fs::File::open(path).unwrap());
        serde_json::from_reader(reader).unwrap()
    }

    /// Build a list of JSON string values from string literals.
    fn string_values(items: &[&str]) -> Vec<Value> {
        items.iter().map(|s| Value::String((*s).into())).collect()
    }

    /// Assert that `results` is a list of `{ "input": ..., "results": ... }` objects
    /// whose entries match `inputs` and `expected` position by position.
    fn check_results(results: &[Value], inputs: &[TypeInput], expected: &[Value]) {
        assert_eq!(results.len(), inputs.len());
        assert_eq!(results.len(), expected.len());

        let triples = results.iter().zip(inputs).zip(expected);
        for (i, ((result, expected_input), expected_result)) in triples.enumerate() {
            let map = match result {
                Value::Object(map) => map,
                _ => panic!("incorrect formatting for output {}", i),
            };
            assert_eq!(
                map.len(),
                2,
                "Each serialized result should only have two top level entries"
            );
            let input = TypeInput::deserialize(&map["input"]).unwrap();
            assert_eq!(&input, expected_input);
            assert_eq!(&map["results"], expected_result);
        }
    }

    #[test]
    fn test_atomic_save() {
        let dir = tempfile::tempdir().unwrap();
        let target = dir.path().join("file.json");
        let payload: &str = "hello world";

        // A first save creates the file with the serialized payload.
        assert!(!target.exists());
        assert!(atomic_save(&target, payload).is_ok());
        assert!(target.exists());

        let roundtrip: String = load_from_file(&target);
        assert_eq!(roundtrip, payload);

        // A pre-existing temporary file must make the save fail.
        std::fs::File::create(dir.path().join("file.json.temp")).unwrap();

        let err = atomic_save(&target, payload).unwrap_err();
        let rendered = format!("{:?}", err);
        assert!(rendered.contains("Temporary file"));
        assert!(rendered.contains("already exists"));
    }

    #[test]
    fn test_empty() {
        let checkpoint = Checkpoint::empty();

        // Both operations are no-ops on an empty checkpoint.
        assert!(checkpoint.save().is_ok());
        assert!(checkpoint.checkpoint("hello world").is_ok());
    }

    #[test]
    fn test_checkpoint() {
        let dir = tempfile::tempdir().unwrap();
        let savepath = dir.path().join("output.json");

        let inputs = [
            TypeInput::new(DataType::Float32, 1, false),
            TypeInput::new(DataType::Float16, 2, false),
            TypeInput::new(DataType::Float64, 3, false),
        ];

        let serialized: Vec<_> = inputs
            .iter()
            .map(|i| serde_json::to_value(i).unwrap())
            .collect();

        // No results yet.
        {
            let checkpoint = Checkpoint::new(&serialized, &[], &savepath).unwrap();
            assert!(!savepath.exists());
            checkpoint.save().unwrap();
            assert!(savepath.exists());
            let reloaded: Vec<Value> = load_from_file(&savepath);
            assert!(reloaded.is_empty());

            checkpoint.checkpoint("some string").unwrap();
            let reloaded: Vec<Value> = load_from_file(&savepath);
            check_results(&reloaded, &inputs[0..1], &string_values(&["some string"]));
        }

        // One existing result.
        {
            let values = vec![serde_json::to_value("some result").unwrap()];
            let checkpoint = Checkpoint::new(&serialized, &values, &savepath).unwrap();

            checkpoint.save().unwrap();
            {
                let reloaded: Vec<Value> = load_from_file(&savepath);
                check_results(&reloaded, &inputs[0..1], &string_values(&["some result"]));
            }

            checkpoint.checkpoint("another result").unwrap();
            {
                let reloaded: Vec<Value> = load_from_file(&savepath);
                check_results(
                    &reloaded,
                    &inputs[0..2],
                    &string_values(&["some result", "another result"]),
                );
            }
        }

        // As many results as inputs: saving works, appending does not.
        {
            let values = vec![
                serde_json::to_value("a").unwrap(),
                serde_json::to_value("b").unwrap(),
                serde_json::to_value("c").unwrap(),
            ];
            let checkpoint = Checkpoint::new(&serialized, &values, &savepath).unwrap();
            checkpoint.save().unwrap();
            let reloaded: Vec<Value> = load_from_file(&savepath);

            check_results(&reloaded, &inputs, &string_values(&["a", "b", "c"]));

            let err = checkpoint.checkpoint("too full").unwrap_err();
            let message = err.to_string();
            assert!(message.contains("internal error - checkpoint is full"));
        }

        // More results than inputs: construction is rejected.
        {
            let values = vec![
                serde_json::to_value("a").unwrap(),
                serde_json::to_value("b").unwrap(),
                serde_json::to_value("c").unwrap(),
                serde_json::to_value("d").unwrap(),
            ];
            let err = Checkpoint::new(&serialized, &values, &savepath).unwrap_err();
            let message = err.to_string();
            assert!(message.contains("internal error - results len"));
        }
    }
}