// Source: diskann_benchmark_runner/app.rs

/*
 * Copyright (c) Microsoft Corporation.
 * Licensed under the MIT license.
 */

use std::{io::Write, path::PathBuf};

use clap::{Parser, Subcommand};

use crate::{
    jobs::{self, Jobs},
    output::Output,
    registry,
    result::Checkpoint,
    utils::fmt::Banner,
};

18/// Parsed command line options.
19#[derive(Debug, Subcommand)]
20pub enum Commands {
21    /// List the kinds of input formats available for ingestion.
22    Inputs {
23        /// Describe the layout of the named input kind.
24        describe: Option<String>,
25    },
26    /// List the available benchmarks.
27    Benchmarks {},
28    /// Provide a skeleton JSON file for running a set of benchmarks.
29    Skeleton,
30    /// Run a list of benchmarks.
31    Run {
32        /// The input file to run.
33        #[arg(long = "input-file")]
34        input_file: PathBuf,
35        /// The path where the output file should reside.
36        #[arg(long = "output-file")]
37        output_file: PathBuf,
38        /// Parse an input file and perform all validation checks, but don't actually run any
39        /// benchmarks.
40        #[arg(long, action)]
41        dry_run: bool,
42    },
43}
44
45/// The CLI used to drive a benchmark application.
46#[derive(Debug, Parser)]
47pub struct App {
48    #[command(subcommand)]
49    command: Commands,
50}
51
52impl App {
53    /// Construct [`Self`] by parsing commandline arguments from [`std::env::args]`.
54    ///
55    /// This simply redirects to [`clap::Parser::parse`] and is provided to allow parsing
56    /// without the [`clap::Parser`] trait in scope.
57    pub fn parse() -> Self {
58        <Self as clap::Parser>::parse()
59    }
60
61    /// Construct [`Self`] by parsing command line arguments from the iterator.
62    ///
63    /// This simply redirects to [`clap::Parser::try_parse_from`] and is provided to allow
64    /// parsing without the [`clap::Parser`] trait in scope.
65    pub fn try_parse_from<I, T>(itr: I) -> anyhow::Result<Self>
66    where
67        I: IntoIterator<Item = T>,
68        T: Into<std::ffi::OsString> + Clone,
69    {
70        Ok(<Self as clap::Parser>::try_parse_from(itr)?)
71    }
72
73    /// Construct [`Self`] directly from a [`Commands`] enum.
74    pub fn from_commands(command: Commands) -> Self {
75        Self { command }
76    }
77
78    /// Run the application using the registered `inputs` and `outputs`.
79    pub fn run(
80        &self,
81        inputs: &registry::Inputs,
82        benchmarks: &registry::Benchmarks,
83        mut output: &mut dyn Output,
84    ) -> anyhow::Result<()> {
85        match &self.command {
86            // If a named benchmark isn't given, then list the available benchmarks.
87            Commands::Inputs { describe } => {
88                if let Some(describe) = describe {
89                    if let Some(input) = inputs.get(describe) {
90                        let repr = jobs::Unprocessed::format_input(input)?;
91                        writeln!(
92                            output,
93                            "The example JSON representation for \"{}\" is:",
94                            describe
95                        )?;
96                        writeln!(output, "{}", serde_json::to_string_pretty(&repr)?)?;
97                        return Ok(());
98                    } else {
99                        writeln!(output, "No input found for \"{}\"", describe)?;
100                    }
101
102                    return Ok(());
103                }
104
105                writeln!(output, "Available input kinds are listed below:")?;
106                let mut tags: Vec<_> = inputs.tags().collect();
107                tags.sort();
108                for i in tags.iter() {
109                    writeln!(output, "    {}", i)?;
110                }
111            }
112            // List the available benchmarks.
113            Commands::Benchmarks {} => {
114                writeln!(output, "Registered Benchmarks:")?;
115                for (name, method) in benchmarks.methods() {
116                    writeln!(output, "    {}: {}", name, method.signatures()[0])?;
117                }
118            }
119            Commands::Skeleton => {
120                writeln!(output, "Skeleton input file:")?;
121                writeln!(output, "{}", Jobs::example()?)?;
122            }
123            // Run the benchmarks
124            Commands::Run {
125                input_file,
126                output_file,
127                dry_run,
128            } => {
129                // Parse and validate the input.
130                let run = Jobs::load(input_file, inputs)?;
131                // Check if we have a match for each benchmark.
132                for job in run.jobs().iter() {
133                    if !benchmarks.has_match(job) {
134                        let repr = serde_json::to_string_pretty(&job.serialize()?)?;
135
136                        const MAX_METHODS: usize = 3;
137                        let mismatches = match benchmarks.debug(job, MAX_METHODS) {
138                            // Debug should return `Err` if there is not a match.
139                            // Returning `Ok(())` here indicates an internal error with the
140                            // dispatcher.
141                            Ok(()) => {
142                                return Err(anyhow::Error::msg(format!(
143                                    "experienced internal error while debugging:\n{}",
144                                    repr
145                                )))
146                            }
147                            Err(m) => m,
148                        };
149
150                        writeln!(
151                            output,
152                            "Could not find a match for the following input:\n\n{}\n",
153                            repr
154                        )?;
155                        writeln!(output, "Closest matches:\n")?;
156                        for (i, mismatch) in mismatches.into_iter().enumerate() {
157                            writeln!(
158                                output,
159                                "    {}. \"{}\": {}",
160                                i + 1,
161                                mismatch.method(),
162                                mismatch.reason(),
163                            )?;
164                        }
165                        writeln!(output)?;
166
167                        return Err(anyhow::Error::msg(
168                            "could not find find a benchmark for all inputs",
169                        ));
170                    }
171                }
172
173                if *dry_run {
174                    writeln!(
175                        output,
176                        "Success - skipping running benchmarks because \"--dry-run\" was used."
177                    )?;
178                    return Ok(());
179                }
180
181                // The collection of output results for each run.
182                let mut results = Vec::<serde_json::Value>::new();
183
184                // Now - we've verified the integrity of all the jobs we want to run and that
185                // each job can match an associated benchmark.
186                //
187                // All that's left is to actually run the benchmarks.
188                let jobs = run.jobs();
189                let serialized = jobs
190                    .iter()
191                    .map(|job| {
192                        serde_json::to_value(jobs::Unprocessed::new(
193                            job.tag().into(),
194                            job.serialize()?,
195                        ))
196                    })
197                    .collect::<Result<Vec<_>, serde_json::Error>>()?;
198                for (i, job) in jobs.iter().enumerate() {
199                    let prefix: &str = if i != 0 { "\n\n" } else { "" };
200                    writeln!(
201                        output,
202                        "{}{}",
203                        prefix,
204                        Banner::new(&format!("Running Job {} of {}", i + 1, jobs.len()))
205                    )?;
206
207                    // Run the specified job.
208                    let checkpoint = Checkpoint::new(&serialized, &results, output_file)?;
209                    let r = benchmarks.call(job, checkpoint, output)?;
210
211                    // Collect the results
212                    results.push(r);
213
214                    // Save everything.
215                    Checkpoint::new(&serialized, &results, output_file)?.save()?;
216                }
217            }
218        };
219        Ok(())
220    }
221}
222
///////////
// Tests //
///////////

/// The integration tests below look inside the `tests` directory for folders.
///
/// ## Input Files
///
/// Each folder should have at least a `stdin.txt` file specifying the command line to give
/// to the `App` parser.
///
/// Within the `stdin.txt` command line, there are several special symbols:
///
/// * $INPUT - Resolves to `input.json` in the same directory as the `stdin.txt` file.
/// * $OUTPUT - Resolves to `output.json` in a temporary directory.
///
/// As mentioned - an input JSON file can be included and must be named "input.json" to be
/// discoverable.
///
/// ## Output Files
///
/// Tests should have at least a `stdout.txt` file with the expected outputs for running the
/// command in `stdin.txt`. If an output JSON file is expected, it should be named
/// `output.json`.
///
/// ## Test Discovery and Running
///
/// The unit test will visit each folder in `tests` and run the outlined scenario. The
/// `stdout.txt` expected output is compared to the actual output and if they do not match,
/// the test fails.
///
/// Additionally, if `output.json` is present, the unit test will verify that (1) the command
/// did in fact produce an output JSON file and (2) the generated file matches the expected file.
///
/// ## Regenerating Expected Results
///
/// The benchmark output will naturally change over time. Running the unit tests with the
/// environment variable
/// ```text
/// POCKETBENCH_TEST=overwrite
/// ```
/// will replace the `stdout.txt` (and `output.json` if one was generated) for each test
/// scenario. Developers should then consult `git diff` to ensure that major regressions
/// to the output did not occur.
#[cfg(test)]
mod tests {
    use super::*;

    use std::{
        ffi::OsString,
        path::{Path, PathBuf},
    };

    use crate::{registry, ux};

    // Name of the environment variable that switches the tests into overwrite mode.
    const ENV: &str = "POCKETBENCH_TEST";

    // Expected I/O files.
    const STDIN: &str = "stdin.txt";
    const STDOUT: &str = "stdout.txt";
    const INPUT_FILE: &str = "input.json";
    const OUTPUT_FILE: &str = "output.json";

    // Read the entire contents of a file to a string and pass it through `ux::normalize`.
    //
    // Panics with a message that includes `ctx` and the path if the file cannot be read.
    fn read_to_string<P: AsRef<Path>>(path: P, ctx: &str) -> String {
        match std::fs::read_to_string(path.as_ref()) {
            Ok(s) => ux::normalize(s),
            Err(err) => panic!(
                "failed to read {} {:?} with error: {}",
                ctx,
                path.as_ref(),
                err
            ),
        }
    }

    // Check if `POCKETBENCH_TEST=overwrite` is configured. Return `true` if so - otherwise
    // return `false`.
    //
    // If `POCKETBENCH_TEST` is set but its value is not `overwrite` - panic.
    fn overwrite() -> bool {
        match std::env::var(ENV) {
            Ok(v) if v == "overwrite" => true,
            Ok(v) => panic!(
                "Unknown value for {}: \"{}\". Expected \"overwrite\"",
                ENV, v
            ),
            Err(std::env::VarError::NotPresent) => false,
            Err(std::env::VarError::NotUnicode(_)) => {
                panic!("Value for {} is not unicode", ENV);
            }
        }
    }

    // Test Runner
    //
    // `dir` is the scenario folder holding the expected I/O files. `overwrite` records
    // whether expected results should be regenerated instead of compared.
    struct Test {
        dir: PathBuf,
        overwrite: bool,
    }

    impl Test {
        // Create a runner for the scenario in `dir`, reading the overwrite setting from
        // the environment.
        fn new(dir: &Path) -> Self {
            Self {
                dir: dir.into(),
                overwrite: overwrite(),
            }
        }

        // Build an `App` from the command line stored in the scenario's `stdin.txt`.
        //
        // Panics if the file cannot be read or the arguments fail to parse.
        fn parse_stdin(&self, tempdir: &Path) -> App {
            let path = self.dir.join(STDIN);

            // Read the standard input file to a string.
            let stdin = read_to_string(&path, "standard input");

            // Split and resolve special symbols.
            let args: Vec<OsString> = stdin
                .split_whitespace()
                .map(|v| -> OsString { self.resolve(v, tempdir).into() })
                .collect();

            // Prepend a placeholder program name, since the parser treats the first item
            // as the binary name.
            App::try_parse_from(std::iter::once(OsString::from("test-app")).chain(args)).unwrap()
        }

        // Resolve the special symbols `$INPUT` and `$OUTPUT`; any other token is passed
        // through unchanged.
        fn resolve(&self, s: &str, tempdir: &Path) -> PathBuf {
            if s == "$INPUT" {
                self.dir.join(INPUT_FILE)
            } else if s == "$OUTPUT" {
                tempdir.join(OUTPUT_FILE)
            } else {
                s.into()
            }
        }

        // Run the scenario and either compare against, or regenerate, the expected
        // `stdout.txt` and `output.json` files.
        fn run(&self, tempdir: &Path) {
            let app = self.parse_stdin(tempdir);

            // Register inputs
            let mut inputs = registry::Inputs::new();
            crate::test::register_inputs(&mut inputs).unwrap();

            // Register benchmarks
            let mut benchmarks = registry::Benchmarks::new();
            crate::test::register_benchmarks(&mut benchmarks);

            // Run app - collecting output into a buffer.
            //
            // If the app returns an error - format the error to the output buffer as well
            // using the debug formatting option.
            let mut buffer = crate::output::Memory::new();
            if let Err(err) = app.run(&inputs, &benchmarks, &mut buffer) {
                let mut b: &mut dyn crate::Output = &mut buffer;
                write!(b, "{:?}", err).unwrap();
            }

            // Check that `stdout` matches
            let stdout: String =
                ux::normalize(ux::strip_backtrace(buffer.into_inner().try_into().unwrap()));
            let output = self.dir.join(STDOUT);
            if self.overwrite {
                std::fs::write(output, stdout).unwrap();
            } else {
                let expected = read_to_string(&output, "expected standard output");
                if stdout != expected {
                    panic!("Got:\n--\n{}\n--\nExpected:\n--\n{}\n--", stdout, expected);
                }
            }

            // Check that the output files match.
            let output_path = tempdir.join(OUTPUT_FILE);
            let was_output_generated = output_path.is_file();

            let expected_output_path = self.dir.join(OUTPUT_FILE);
            let is_output_expected = expected_output_path.is_file();

            if self.overwrite {
                // Copy the output file to the destination.
                //
                // A copy is used rather than `std::fs::rename` because the temporary
                // directory and the source tree may live on different mount points, in
                // which case a rename fails.
                if was_output_generated {
                    println!(
                        "Copying generated output file {:?} to {:?}",
                        output_path, expected_output_path
                    );

                    if let Err(err) = std::fs::copy(&output_path, &expected_output_path) {
                        panic!(
                            "Copying generated output file {:?} to expected location {:?} failed: {}",
                            output_path, expected_output_path, err
                        );
                    }
                } else if is_output_expected {
                    println!("Removing outdated output file {:?}", expected_output_path);
                    if let Err(err) = std::fs::remove_file(&expected_output_path) {
                        panic!(
                            "Failed removing outdated output file {:?}: {}",
                            expected_output_path, err
                        );
                    }
                }
            } else {
                match (was_output_generated, is_output_expected) {
                    (true, true) => {
                        let output_contents = read_to_string(output_path, "generated output JSON");

                        let expected_contents =
                            read_to_string(expected_output_path, "expected output JSON");

                        if output_contents != expected_contents {
                            panic!(
                                "Got:\n\n{}\n\nExpected:\n\n{}\n",
                                output_contents, expected_contents
                            );
                        }
                    }
                    (true, false) => {
                        let output_contents = read_to_string(output_path, "generated output JSON");

                        panic!(
                            "An output JSON was generated when none was expected. Contents:\n\n{}",
                            output_contents
                        );
                    }
                    (false, true) => {
                        panic!("No output JSON was generated when one was expected");
                    }
                    (false, false) => { /* this is okay */ }
                }
            }
        }
    }

    // Run the single scenario stored in `test_dir`, using a fresh temporary directory for
    // any generated output file.
    fn run_specific_test(test_dir: &Path) {
        println!("running test in {:?}", test_dir);
        let temp_dir = tempfile::tempdir().unwrap();
        Test::new(test_dir).run(temp_dir.path());
    }

    // Run every scenario folder found under `<crate root>/tests/<dir>`. Entries that are
    // not directories are skipped.
    fn run_all_tests_in(dir: &str) {
        let dir = Path::new(env!("CARGO_MANIFEST_DIR")).join("tests").join(dir);
        let entries = std::fs::read_dir(&dir)
            .unwrap_or_else(|err| panic!("couldn't read test directory {:?}: {}", dir, err));
        for entry in entries {
            let entry = entry.unwrap();
            let file_type = entry.file_type().unwrap_or_else(|err| {
                panic!("couldn't get file type for {:?}: {}", entry.path(), err)
            });
            if file_type.is_dir() {
                run_specific_test(&entry.path());
            }
        }
    }

    #[test]
    fn top_level_tests() {
        run_all_tests_in("");
    }
}