diskann_benchmark_runner/
result.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6//! A utility for providing incremental saving of results.
7
8use std::path::Path;
9
10use serde::{ser::SerializeSeq, Serialize, Serializer};
11
12/// A helper to generate incremental snapshots of data while a benchmark is progressing.
13///
14/// Benchmark implementations may use this to save results as they become available rather
15/// than waiting until the end.
16#[derive(Debug, Clone, Copy)]
17pub struct Checkpoint<'a> {
18    inner: Option<CheckpointInner<'a>>,
19}
20
21impl<'a> Checkpoint<'a> {
22    /// Create a new check-point that serializes the zip-combination of `input` and `results`
23    /// to `path`.
24    ///
25    /// This is meant to be used in context where we wish to incrementally save new results
26    /// along with all the results generated so far. As such, this requires
27    /// ```text
28    /// inputs.len() <= results.len()
29    /// ```
30    /// Subsequent calls to `checkpoint` will be assumed to belong to the input at
31    /// `results.len() + 1` and will be saved at that position.
32    pub(crate) fn new(
33        input: &'a [serde_json::Value],
34        results: &'a [serde_json::Value],
35        path: &'a Path,
36    ) -> anyhow::Result<Self> {
37        if results.len() > input.len() {
38            Err(anyhow::Error::msg(format!(
39                "internal error - results len ({}) is greater than input len ({})",
40                results.len(),
41                input.len(),
42            )))
43        } else {
44            Ok(Self {
45                inner: Some(CheckpointInner {
46                    input,
47                    results,
48                    path,
49                }),
50            })
51        }
52    }
53
54    /// Create an empty checkpointer that turns calls to `checkpoint` into a no-op.
55    pub(crate) fn empty() -> Self {
56        Self { inner: None }
57    }
58
59    /// Atomically save the zip of the inputs and results to the configured path.
60    pub fn save(&self) -> anyhow::Result<()> {
61        if let Some(inner) = &self.inner {
62            atomic_save(inner.path, &inner)
63        } else {
64            Ok(())
65        }
66    }
67
68    /// Treat `partial` as a new partial result for the current contents of the checkpoint.
69    ///
70    /// All previously generated results will be saved and `partial` will be grouped at
71    /// the input at `self.inner.results.len() + 1`.
72    ///
73    /// This function should only be called if `self` is not full (as in, there is at least
74    /// one input that does not have a corresponding result.
75    pub fn checkpoint<T: Serialize + ?Sized>(&self, partial: &T) -> anyhow::Result<()> {
76        if let Some(inner) = &self.inner {
77            if inner.results.len() == inner.input.len() {
78                Err(anyhow::Error::msg("internal error - checkpoint is full"))
79            } else {
80                let appended = Appended {
81                    checkpoint: *inner,
82                    partial: serde_json::to_value(partial)?,
83                };
84                atomic_save(inner.path, &appended)
85            }
86        } else {
87            Ok(())
88        }
89    }
90}
91
92/// Atomically save the serializable `object` to a JSON file at `path`.
93///
94/// This function works by first serializing to `format!("{}.temp", path)` and then using
95/// `std::fs::rename`, making the operation safe from interrupts.
96///
97/// This can fail for a number of reasons:
98///
99/// 1. `path` is not an valid file path.
100/// 2. The temporary file `format!("{}.temp", path)` already exists.
101/// 3. Serialization fails.
102/// 4. Renaming fails.
103pub(crate) fn atomic_save<T>(path: &Path, object: &T) -> anyhow::Result<()>
104where
105    T: Serialize + ?Sized,
106{
107    let temp = format!("{}.temp", path.display());
108    if Path::new(&temp).exists() {
109        return Err(anyhow::Error::msg(format!(
110            "Temporary file {} already exists. Aborting!",
111            temp
112        )));
113    }
114
115    let buffer = std::fs::File::create(&temp)?;
116    serde_json::to_writer_pretty(buffer, object)?;
117    std::fs::rename(&temp, path)?;
118    Ok(())
119}
120
121////////////////////////////
122// Implementation Details //
123////////////////////////////
124
125#[derive(Debug, Clone, Copy)]
126struct CheckpointInner<'a> {
127    input: &'a [serde_json::Value],
128    results: &'a [serde_json::Value],
129    path: &'a Path,
130}
131
132// This applies the "zip" like behavior between pairs of `input` and `results` in
133// `CheckpointInner`, so the data structure can act as a vector of pairs rather than as a
134// pair of vectors.
135#[derive(Debug, Serialize)]
136struct SingleResult<'a> {
137    input: &'a serde_json::Value,
138    results: &'a serde_json::Value,
139}
140impl Serialize for CheckpointInner<'_> {
141    /// Serialize up to `self.results.len()` pairs of inputs and results.
142    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
143    where
144        S: Serializer,
145    {
146        let mut seq = serializer.serialize_seq(Some(self.results.len()))?;
147        for (input, results) in std::iter::zip(self.input.iter(), self.results.iter()) {
148            seq.serialize_element(&SingleResult { input, results })?;
149        }
150        seq.end()
151    }
152}
153
154/// A lazily appended partial data result.
155///
156/// NOTE: The associated `Checkpoint` must "have room" for an additional value, That is,
157/// `checkpoint.results.len() < checkpoint.input.len()`.
158struct Appended<'a> {
159    checkpoint: CheckpointInner<'a>,
160    partial: serde_json::Value,
161}
162
163impl Serialize for Appended<'_> {
164    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
165    where
166        S: Serializer,
167    {
168        let mut seq = serializer.serialize_seq(Some(self.checkpoint.results.len() + 1))?;
169        std::iter::zip(
170            self.checkpoint.input.iter(),
171            self.checkpoint
172                .results
173                .iter()
174                .chain(std::iter::once(&self.partial)),
175        )
176        .try_for_each(|(input, results)| seq.serialize_element(&SingleResult { input, results }))?;
177        seq.end()
178    }
179}
180
181///////////
182// Tests //
183///////////
184
#[cfg(test)]
mod tests {
    use serde::Deserialize;
    use serde_json::value::Value;

    use super::*;
    use crate::{test::TypeInput, utils::datatype::DataType};

    /// Open the JSON file at `path` and deserialize it as a `T`, panicking on failure.
    fn load_from_file<T>(path: &std::path::Path) -> T
    where
        T: for<'a> Deserialize<'a>,
    {
        let file = std::fs::File::open(path).unwrap();
        let reader = std::io::BufReader::new(file);
        serde_json::from_reader(reader).unwrap()
    }

    // Check that each result has the form:
    // ```
    // {
    //     input: <input-object>,
    //     results: <result-object>,
    // }
    // ```
    // where `input` deserializes to the matching entry of `inputs` and `results` equals the
    // matching entry of `expected`.
    fn check_results(results: &[Value], inputs: &[TypeInput], expected: &[Value]) {
        assert_eq!(results.len(), inputs.len());
        assert_eq!(results.len(), expected.len());

        let rows = results.iter().zip(inputs).zip(expected).enumerate();
        for (i, ((result, input), want)) in rows {
            let map = match result {
                Value::Object(map) => map,
                _ => panic!("incorrect formatting for output {}", i),
            };
            assert_eq!(
                map.len(),
                2,
                "Each serialized result should only have two top level entries"
            );
            let deserialized = TypeInput::deserialize(&map["input"]).unwrap();
            assert_eq!(&deserialized, input);
            assert_eq!(&map["results"], want);
        }
    }

    #[test]
    fn test_atomic_save() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path();

        let message: &str = "hello world";
        let full = path.join("file.json");
        let temp = path.join("file.json.temp");
        assert!(!full.exists());
        assert!(atomic_save(&full, message).is_ok());
        assert!(full.exists());

        // A successful save must not leave the temporary file behind.
        assert!(!temp.exists());

        // Deserialize
        let deserialized: String = load_from_file(&full);
        assert_eq!(deserialized, message);

        // Atomic save should fail if the temp file already exists.
        std::fs::File::create(&temp).unwrap();

        let err = atomic_save(&full, "new contents").unwrap_err();
        let err_text = format!("{:?}", err);
        assert!(err_text.contains("Temporary file"));
        assert!(err_text.contains("already exists"));

        // The rejected save must leave the existing output and the foreign temp file alone.
        let deserialized: String = load_from_file(&full);
        assert_eq!(deserialized, message);
        assert!(temp.exists());
    }

    #[test]
    fn test_empty() {
        let checkpoint = Checkpoint::empty();

        // Make sure we can still call "save" and "checkpoint".
        assert!(checkpoint.save().is_ok());
        assert!(checkpoint.checkpoint("hello world").is_ok());
    }

    #[test]
    fn test_checkpoint() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path();

        let savepath = path.join("output.json");

        let inputs = [
            TypeInput::new(DataType::Float32, 1, false),
            TypeInput::new(DataType::Float16, 2, false),
            TypeInput::new(DataType::Float64, 3, false),
        ];

        let serialized: Vec<_> = inputs
            .iter()
            .map(|i| serde_json::to_value(i).unwrap())
            .collect();

        // No saved values.
        {
            let checkpoint = Checkpoint::new(&serialized, &[], &savepath).unwrap();
            assert!(!savepath.exists());
            checkpoint.save().unwrap();
            assert!(savepath.exists());
            let reloaded: Vec<Value> = load_from_file(&savepath);
            assert!(reloaded.is_empty());

            // Append a new result.
            checkpoint.checkpoint("some string").unwrap();
            let reloaded: Vec<Value> = load_from_file(&savepath);
            check_results(
                &reloaded,
                &inputs[0..1],
                &[Value::String("some string".into())],
            );

            // Checkpointing again replaces the partial result rather than appending another.
            checkpoint.checkpoint("newer string").unwrap();
            let reloaded: Vec<Value> = load_from_file(&savepath);
            check_results(
                &reloaded,
                &inputs[0..1],
                &[Value::String("newer string".into())],
            );
        }

        // One saved value.
        {
            let values = vec![serde_json::to_value("some result").unwrap()];
            let checkpoint = Checkpoint::new(&serialized, &values, &savepath).unwrap();

            checkpoint.save().unwrap();
            {
                let reloaded: Vec<Value> = load_from_file(&savepath);
                check_results(
                    &reloaded,
                    &inputs[0..1],
                    &[Value::String("some result".into())],
                );
            }

            // Checkpointing will now yield 2 elements.
            checkpoint.checkpoint("another result").unwrap();
            {
                let reloaded: Vec<Value> = load_from_file(&savepath);
                check_results(
                    &reloaded,
                    &inputs[0..2],
                    &[
                        Value::String("some result".into()),
                        Value::String("another result".into()),
                    ],
                );
            }
        }

        // Full checkpoint.
        {
            let values = vec![
                serde_json::to_value("a").unwrap(),
                serde_json::to_value("b").unwrap(),
                serde_json::to_value("c").unwrap(),
            ];
            let checkpoint = Checkpoint::new(&serialized, &values, &savepath).unwrap();
            checkpoint.save().unwrap();
            let reloaded: Vec<Value> = load_from_file(&savepath);

            check_results(
                &reloaded,
                &inputs,
                &[
                    Value::String("a".into()),
                    Value::String("b".into()),
                    Value::String("c".into()),
                ],
            );

            // If we try to checkpoint, we should get an error.
            let err = checkpoint.checkpoint("too full").unwrap_err();
            let message = err.to_string();
            assert!(message.contains("internal error - checkpoint is full"));

            // The rejected checkpoint must not have modified the saved file.
            let reloaded: Vec<Value> = load_from_file(&savepath);
            assert_eq!(reloaded.len(), 3);
        }

        // Malformed Input
        {
            let values = vec![
                serde_json::to_value("a").unwrap(),
                serde_json::to_value("b").unwrap(),
                serde_json::to_value("c").unwrap(),
                serde_json::to_value("d").unwrap(),
            ];
            let err = Checkpoint::new(&serialized, &values, &savepath).unwrap_err();
            let message = err.to_string();
            assert!(message.contains("internal error - results len"));
        }
    }
}