diskann_benchmark_runner/
result.rs1use std::path::Path;
9
10use serde::{ser::SerializeSeq, Deserialize, Serialize, Serializer};
11
/// Handle for incrementally persisting results as pretty-printed JSON.
///
/// When `inner` is `None`, [`Checkpoint::save`] and [`Checkpoint::checkpoint`] do nothing
/// and return `Ok(())`.
#[derive(Debug, Clone, Copy)]
pub struct Checkpoint<'a> {
    /// Inputs, already-completed results, and output path; `None` disables saving.
    inner: Option<CheckpointInner<'a>>,
}
20
21impl<'a> Checkpoint<'a> {
22 pub(crate) fn new(
33 input: &'a [serde_json::Value],
34 results: &'a [serde_json::Value],
35 path: &'a Path,
36 ) -> anyhow::Result<Self> {
37 if results.len() > input.len() {
38 Err(anyhow::Error::msg(format!(
39 "internal error - results len ({}) is greater than input len ({})",
40 results.len(),
41 input.len(),
42 )))
43 } else {
44 Ok(Self {
45 inner: Some(CheckpointInner {
46 input,
47 results,
48 path,
49 }),
50 })
51 }
52 }
53
54 pub fn save(&self) -> anyhow::Result<()> {
56 if let Some(inner) = &self.inner {
57 atomic_save(inner.path, &inner)
58 } else {
59 Ok(())
60 }
61 }
62
63 pub fn checkpoint<T: Serialize + ?Sized>(&self, partial: &T) -> anyhow::Result<()> {
71 if let Some(inner) = &self.inner {
72 if inner.results.len() == inner.input.len() {
73 Err(anyhow::Error::msg("internal error - checkpoint is full"))
74 } else {
75 let appended = Appended {
76 checkpoint: *inner,
77 partial: serde_json::to_value(partial)?,
78 };
79 atomic_save(inner.path, &appended)
80 }
81 } else {
82 Ok(())
83 }
84 }
85}
86
87pub(crate) fn atomic_save<T>(path: &Path, object: &T) -> anyhow::Result<()>
99where
100 T: Serialize + ?Sized,
101{
102 let temp = format!("{}.temp", path.display());
103 if Path::new(&temp).exists() {
104 return Err(anyhow::Error::msg(format!(
105 "Temporary file {} already exists. Aborting!",
106 temp
107 )));
108 }
109
110 let buffer = std::fs::File::create(&temp)?;
111 serde_json::to_writer_pretty(buffer, object)?;
112 std::fs::rename(&temp, path)?;
113 Ok(())
114}
115
/// One `{ "input": ..., "results": ... }` record, with the same field names as the records
/// written by checkpoint serialization.
#[derive(Debug, Serialize, Deserialize)]
pub(crate) struct RawResult {
    /// The serialized input.
    pub(crate) input: serde_json::Value,
    /// The serialized results associated with `input`.
    pub(crate) results: serde_json::Value,
}
122
impl RawResult {
    /// Load a list of records from the file at `path`.
    ///
    /// Delegates entirely to `crate::internal::load_from_disk`; see that function for the
    /// expected file format and error conditions.
    pub(crate) fn load(path: &Path) -> anyhow::Result<Vec<Self>> {
        crate::internal::load_from_disk(path)
    }
}
128
/// Borrowed state behind an active [`Checkpoint`].
///
/// `results[i]` is the result for `input[i]`. `Checkpoint::new` rejects construction when
/// `results.len() > input.len()`.
#[derive(Debug, Clone, Copy)]
struct CheckpointInner<'a> {
    /// Every serialized input, whether or not it has a result yet.
    input: &'a [serde_json::Value],
    /// Results for the leading entries of `input`.
    results: &'a [serde_json::Value],
    /// Output file written by `atomic_save`.
    path: &'a Path,
}
139
/// Serialization view that pairs one input with its results; serialized as a struct with
/// `input` and `results` fields (in that order).
#[derive(Debug, Serialize)]
struct SingleResult<'a> {
    input: &'a serde_json::Value,
    results: &'a serde_json::Value,
}
148impl Serialize for CheckpointInner<'_> {
149 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
151 where
152 S: Serializer,
153 {
154 let mut seq = serializer.serialize_seq(Some(self.results.len()))?;
155 for (input, results) in std::iter::zip(self.input.iter(), self.results.iter()) {
156 seq.serialize_element(&SingleResult { input, results })?;
157 }
158 seq.end()
159 }
160}
161
/// A [`CheckpointInner`] plus one additional result, `partial`, which is serialized after the
/// completed results and paired with the next input.
struct Appended<'a> {
    /// The completed state being extended.
    checkpoint: CheckpointInner<'a>,
    /// The extra result, already converted to a JSON value.
    partial: serde_json::Value,
}
170
171impl Serialize for Appended<'_> {
172 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
173 where
174 S: Serializer,
175 {
176 let mut seq = serializer.serialize_seq(Some(self.checkpoint.results.len() + 1))?;
177 std::iter::zip(
178 self.checkpoint.input.iter(),
179 self.checkpoint
180 .results
181 .iter()
182 .chain(std::iter::once(&self.partial)),
183 )
184 .try_for_each(|(input, results)| seq.serialize_element(&SingleResult { input, results }))?;
185 seq.end()
186 }
187}
188
#[cfg(test)]
mod tests {
    use serde::Deserialize;
    use serde_json::value::Value;

    use super::*;
    use crate::{test::TypeInput, utils::datatype::DataType};

    /// Read `path` and deserialize its JSON contents, panicking on any failure.
    fn load_from_file<T>(path: &std::path::Path) -> T
    where
        T: for<'a> Deserialize<'a>,
    {
        let reader = std::io::BufReader::new(std::fs::File::open(path).unwrap());
        serde_json::from_reader(reader).unwrap()
    }

    /// Assert that `results` holds exactly one two-field object per entry of `inputs`. Each
    /// object's `input` must deserialize to the matching `inputs` entry, and its `results`
    /// must equal the matching `expected` entry.
    fn check_results(results: &[Value], inputs: &[TypeInput], expected: &[Value]) {
        assert_eq!(results.len(), inputs.len());
        assert_eq!(results.len(), expected.len());

        let cases = results.iter().zip(inputs).zip(expected).enumerate();
        for (i, ((entry, want_input), want_results)) in cases {
            let map = match entry {
                Value::Object(map) => map,
                _ => panic!("incorrect formatting for output {}", i),
            };
            assert_eq!(
                map.len(),
                2,
                "Each serialized result should only have two top level entries"
            );
            let input = TypeInput::deserialize(&map["input"]).unwrap();
            assert_eq!(&input, want_input);
            assert_eq!(map["results"], *want_results);
        }
    }

    /// Convert string slices into JSON string values.
    fn strings(items: &[&str]) -> Vec<Value> {
        items
            .iter()
            .map(|item| Value::String((*item).to_owned()))
            .collect()
    }

    #[test]
    fn test_atomic_save() {
        let dir = tempfile::tempdir().unwrap();
        let target = dir.path().join("file.json");
        let message: &str = "hello world";

        // A first save creates the file and round-trips the value.
        assert!(!target.exists());
        assert!(atomic_save(&target, message).is_ok());
        assert!(target.exists());
        let deserialized: String = load_from_file(&target);
        assert_eq!(deserialized, message);

        // A stray temporary file makes the next save refuse to proceed.
        std::fs::File::create(dir.path().join("file.json.temp")).unwrap();
        let debug = format!("{:?}", atomic_save(&target, message).unwrap_err());
        assert!(debug.contains("Temporary file"));
        assert!(debug.contains("already exists"));
    }

    #[test]
    fn test_checkpoint() {
        let dir = tempfile::tempdir().unwrap();
        let savepath = dir.path().join("output.json");

        let inputs = [
            TypeInput::new(DataType::Float32, 1, false),
            TypeInput::new(DataType::Float16, 2, false),
            TypeInput::new(DataType::Float64, 3, false),
        ];
        let serialized: Vec<Value> = inputs
            .iter()
            .map(|input| serde_json::to_value(input).unwrap())
            .collect();

        // No prior results: `save` writes an empty list, then a checkpoint appends one entry.
        {
            let checkpoint = Checkpoint::new(&serialized, &[], &savepath).unwrap();
            assert!(!savepath.exists());
            checkpoint.save().unwrap();
            assert!(savepath.exists());
            let reloaded: Vec<Value> = load_from_file(&savepath);
            assert!(reloaded.is_empty());

            checkpoint.checkpoint("some string").unwrap();
            let reloaded: Vec<Value> = load_from_file(&savepath);
            check_results(&reloaded, &inputs[..1], &strings(&["some string"]));
        }

        // One prior result: `save` rewrites it, then a checkpoint appends a second entry.
        {
            let values = strings(&["some result"]);
            let checkpoint = Checkpoint::new(&serialized, &values, &savepath).unwrap();

            checkpoint.save().unwrap();
            let reloaded: Vec<Value> = load_from_file(&savepath);
            check_results(&reloaded, &inputs[..1], &strings(&["some result"]));

            checkpoint.checkpoint("another result").unwrap();
            let reloaded: Vec<Value> = load_from_file(&savepath);
            check_results(
                &reloaded,
                &inputs[..2],
                &strings(&["some result", "another result"]),
            );
        }

        // Every input has a result: `save` works, but appending one more is rejected.
        {
            let values = strings(&["a", "b", "c"]);
            let checkpoint = Checkpoint::new(&serialized, &values, &savepath).unwrap();
            checkpoint.save().unwrap();
            let reloaded: Vec<Value> = load_from_file(&savepath);
            check_results(&reloaded, &inputs, &strings(&["a", "b", "c"]));

            let message = checkpoint.checkpoint("too full").unwrap_err().to_string();
            assert!(message.contains("internal error - checkpoint is full"));
        }

        // More results than inputs is rejected at construction time.
        {
            let values = strings(&["a", "b", "c", "d"]);
            let message = Checkpoint::new(&serialized, &values, &savepath)
                .unwrap_err()
                .to_string();
            assert!(message.contains("internal error - results len"));
        }
    }
}