Skip to main content

diskann_benchmark_runner/
result.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6//! A utility for providing incremental saving of results.
7
8use std::path::Path;
9
10use serde::{ser::SerializeSeq, Deserialize, Serialize, Serializer};
11
12/// A helper to generate incremental snapshots of data while a benchmark is progressing.
13///
14/// Benchmark implementations may use this to save results as they become available rather
15/// than waiting until the end.
16#[derive(Debug, Clone, Copy)]
17pub struct Checkpoint<'a> {
18    inner: Option<CheckpointInner<'a>>,
19}
20
21impl<'a> Checkpoint<'a> {
22    /// Create a new check-point that serializes the zip-combination of `input` and `results`
23    /// to `path`.
24    ///
25    /// This is meant to be used in context where we wish to incrementally save new results
26    /// along with all the results generated so far. As such, this requires
27    /// ```text
28    /// inputs.len() <= results.len()
29    /// ```
30    /// Subsequent calls to `checkpoint` will be assumed to belong to the input at
31    /// `results.len() + 1` and will be saved at that position.
32    pub(crate) fn new(
33        input: &'a [serde_json::Value],
34        results: &'a [serde_json::Value],
35        path: &'a Path,
36    ) -> anyhow::Result<Self> {
37        if results.len() > input.len() {
38            Err(anyhow::Error::msg(format!(
39                "internal error - results len ({}) is greater than input len ({})",
40                results.len(),
41                input.len(),
42            )))
43        } else {
44            Ok(Self {
45                inner: Some(CheckpointInner {
46                    input,
47                    results,
48                    path,
49                }),
50            })
51        }
52    }
53
54    /// Atomically save the zip of the inputs and results to the configured path.
55    pub fn save(&self) -> anyhow::Result<()> {
56        if let Some(inner) = &self.inner {
57            atomic_save(inner.path, &inner)
58        } else {
59            Ok(())
60        }
61    }
62
63    /// Treat `partial` as a new partial result for the current contents of the checkpoint.
64    ///
65    /// All previously generated results will be saved and `partial` will be grouped at
66    /// the input at `self.inner.results.len() + 1`.
67    ///
68    /// This function should only be called if `self` is not full (as in, there is at least
69    /// one input that does not have a corresponding result.
70    pub fn checkpoint<T: Serialize + ?Sized>(&self, partial: &T) -> anyhow::Result<()> {
71        if let Some(inner) = &self.inner {
72            if inner.results.len() == inner.input.len() {
73                Err(anyhow::Error::msg("internal error - checkpoint is full"))
74            } else {
75                let appended = Appended {
76                    checkpoint: *inner,
77                    partial: serde_json::to_value(partial)?,
78                };
79                atomic_save(inner.path, &appended)
80            }
81        } else {
82            Ok(())
83        }
84    }
85}
86
87/// Atomically save the serializable `object` to a JSON file at `path`.
88///
89/// This function works by first serializing to `format!("{}.temp", path)` and then using
90/// `std::fs::rename`, making the operation safe from interrupts.
91///
92/// This can fail for a number of reasons:
93///
94/// 1. `path` is not an valid file path.
95/// 2. The temporary file `format!("{}.temp", path)` already exists.
96/// 3. Serialization fails.
97/// 4. Renaming fails.
98pub(crate) fn atomic_save<T>(path: &Path, object: &T) -> anyhow::Result<()>
99where
100    T: Serialize + ?Sized,
101{
102    let temp = format!("{}.temp", path.display());
103    if Path::new(&temp).exists() {
104        return Err(anyhow::Error::msg(format!(
105            "Temporary file {} already exists. Aborting!",
106            temp
107        )));
108    }
109
110    let buffer = std::fs::File::create(&temp)?;
111    serde_json::to_writer_pretty(buffer, object)?;
112    std::fs::rename(&temp, path)?;
113    Ok(())
114}
115
116/// A utility for loading results previously saved via [`Checkpoint::checkpoint`].
117#[derive(Debug, Serialize, Deserialize)]
118pub(crate) struct RawResult {
119    pub(crate) input: serde_json::Value,
120    pub(crate) results: serde_json::Value,
121}
122
123impl RawResult {
124    pub(crate) fn load(path: &Path) -> anyhow::Result<Vec<Self>> {
125        crate::internal::load_from_disk(path)
126    }
127}
128
129////////////////////////////
130// Implementation Details //
131////////////////////////////
132
133#[derive(Debug, Clone, Copy)]
134struct CheckpointInner<'a> {
135    input: &'a [serde_json::Value],
136    results: &'a [serde_json::Value],
137    path: &'a Path,
138}
139
140// This applies the "zip" like behavior between pairs of `input` and `results` in
141// `CheckpointInner`, so the data structure can act as a vector of pairs rather than as a
142// pair of vectors.
143#[derive(Debug, Serialize)]
144struct SingleResult<'a> {
145    input: &'a serde_json::Value,
146    results: &'a serde_json::Value,
147}
148impl Serialize for CheckpointInner<'_> {
149    /// Serialize up to `self.results.len()` pairs of inputs and results.
150    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
151    where
152        S: Serializer,
153    {
154        let mut seq = serializer.serialize_seq(Some(self.results.len()))?;
155        for (input, results) in std::iter::zip(self.input.iter(), self.results.iter()) {
156            seq.serialize_element(&SingleResult { input, results })?;
157        }
158        seq.end()
159    }
160}
161
162/// A lazily appended partial data result.
163///
164/// NOTE: The associated `Checkpoint` must "have room" for an additional value, That is,
165/// `checkpoint.results.len() < checkpoint.input.len()`.
166struct Appended<'a> {
167    checkpoint: CheckpointInner<'a>,
168    partial: serde_json::Value,
169}
170
171impl Serialize for Appended<'_> {
172    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
173    where
174        S: Serializer,
175    {
176        let mut seq = serializer.serialize_seq(Some(self.checkpoint.results.len() + 1))?;
177        std::iter::zip(
178            self.checkpoint.input.iter(),
179            self.checkpoint
180                .results
181                .iter()
182                .chain(std::iter::once(&self.partial)),
183        )
184        .try_for_each(|(input, results)| seq.serialize_element(&SingleResult { input, results }))?;
185        seq.end()
186    }
187}
188
189///////////
190// Tests //
191///////////
192
#[cfg(test)]
mod tests {
    use serde::Deserialize;
    use serde_json::value::Value;

    use super::*;
    use crate::{test::TypeInput, utils::datatype::DataType};

    /// Read the JSON file at `path` and deserialize it into a `T`, panicking on any failure.
    fn load_from_file<T>(path: &std::path::Path) -> T
    where
        T: for<'a> Deserialize<'a>,
    {
        let reader = std::fs::File::open(path)
            .map(std::io::BufReader::new)
            .unwrap();
        serde_json::from_reader(reader).unwrap()
    }

    /// Check that every entry of `results` is an object with exactly two keys:
    ///
    /// * `input`: deserializes to the matching entry of `inputs`.
    /// * `results`: equals the matching entry of `expected`.
    fn check_results(results: &[Value], inputs: &[TypeInput], expected: &[Value]) {
        assert_eq!(results.len(), inputs.len());
        assert_eq!(results.len(), expected.len());

        let wanted = inputs.iter().zip(expected.iter());
        for (i, (result, (want_input, want_results))) in results.iter().zip(wanted).enumerate() {
            let map = match result {
                Value::Object(map) => map,
                _ => panic!("incorrect formatting for output {}", i),
            };

            assert_eq!(
                map.len(),
                2,
                "Each serialized result should only have two top level entries"
            );
            let got_input = TypeInput::deserialize(&map["input"]).unwrap();
            assert_eq!(got_input, want_input.clone());
            assert_eq!(map["results"], *want_results);
        }
    }

    #[test]
    fn test_atomic_save() {
        let scratch = tempfile::tempdir().unwrap();
        let target = scratch.path().join("file.json");
        let payload: &str = "hello world";

        // The first save creates the file.
        assert!(!target.exists());
        assert!(atomic_save(&target, payload).is_ok());
        assert!(target.exists());

        // What was written can be read back.
        let roundtrip: String = load_from_file(&target);
        assert_eq!(roundtrip, payload);

        // A pre-existing temp file must make the save fail.
        std::fs::File::create(scratch.path().join("file.json.temp")).unwrap();

        let rendered = format!("{:?}", atomic_save(&target, payload).unwrap_err());
        assert!(rendered.contains("Temporary file"));
        assert!(rendered.contains("already exists"));
    }

    #[test]
    fn test_checkpoint() {
        let scratch = tempfile::tempdir().unwrap();
        let savepath = scratch.path().join("output.json");

        let inputs = [
            TypeInput::new(DataType::Float32, 1, false),
            TypeInput::new(DataType::Float16, 2, false),
            TypeInput::new(DataType::Float64, 3, false),
        ];

        let mut serialized = Vec::with_capacity(inputs.len());
        for input in &inputs {
            serialized.push(serde_json::to_value(input).unwrap());
        }

        // Scenario: nothing has been recorded yet.
        {
            let checkpoint = Checkpoint::new(&serialized, &[], &savepath).unwrap();
            assert!(!savepath.exists());
            checkpoint.save().unwrap();
            assert!(savepath.exists());
            let on_disk: Vec<Value> = load_from_file(&savepath);
            assert!(on_disk.is_empty());

            // The first partial result lands next to the first input.
            checkpoint.checkpoint("some string").unwrap();
            let on_disk: Vec<Value> = load_from_file(&savepath);
            let want = [Value::String("some string".into())];
            check_results(&on_disk, &inputs[0..1], &want);
        }

        // Scenario: one result has already been recorded.
        {
            let recorded = vec![serde_json::to_value("some result").unwrap()];
            let checkpoint = Checkpoint::new(&serialized, &recorded, &savepath).unwrap();

            checkpoint.save().unwrap();
            {
                let on_disk: Vec<Value> = load_from_file(&savepath);
                let want = [Value::String("some result".into())];
                check_results(&on_disk, &inputs[0..1], &want);
            }

            // A partial result on top of that yields two entries.
            checkpoint.checkpoint("another result").unwrap();
            {
                let on_disk: Vec<Value> = load_from_file(&savepath);
                let want = [
                    Value::String("some result".into()),
                    Value::String("another result".into()),
                ];
                check_results(&on_disk, &inputs[0..2], &want);
            }
        }

        // Scenario: every input already has a result.
        {
            let recorded = vec![
                serde_json::to_value("a").unwrap(),
                serde_json::to_value("b").unwrap(),
                serde_json::to_value("c").unwrap(),
            ];
            let checkpoint = Checkpoint::new(&serialized, &recorded, &savepath).unwrap();
            checkpoint.save().unwrap();
            let on_disk: Vec<Value> = load_from_file(&savepath);

            let want = [
                Value::String("a".into()),
                Value::String("b".into()),
                Value::String("c".into()),
            ];
            check_results(&on_disk, &inputs, &want);

            // There is no room left, so checkpointing has to fail.
            let err = checkpoint.checkpoint("too full").unwrap_err();
            assert!(err
                .to_string()
                .contains("internal error - checkpoint is full"));
        }

        // Scenario: more results than inputs is rejected at construction.
        {
            let recorded = vec![
                serde_json::to_value("a").unwrap(),
                serde_json::to_value("b").unwrap(),
                serde_json::to_value("c").unwrap(),
                serde_json::to_value("d").unwrap(),
            ];
            let err = Checkpoint::new(&serialized, &recorded, &savepath).unwrap_err();
            assert!(err.to_string().contains("internal error - results len"));
        }
    }
}