// diskann_benchmark_runner/app.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6use std::{io::Write, path::PathBuf};
7
8use clap::{Parser, Subcommand};
9
10use crate::{
11    jobs::{self, Jobs},
12    output::Output,
13    registry,
14    result::Checkpoint,
15    utils::fmt::Banner,
16};
17
/// Parsed command line options.
//
// NOTE: the `///` doc comments on the variants and fields below double as the
// CLI help text rendered by `clap` — edit them with user-facing output in mind.
#[derive(Debug, Subcommand)]
pub enum Commands {
    /// List the kinds of input formats available for ingestion.
    Inputs {
        /// Describe the layout of the named input kind.
        // `None` lists all registered input kinds; `Some(name)` prints an
        // example JSON layout for the named kind instead.
        describe: Option<String>,
    },
    /// List the available benchmarks.
    Benchmarks {},
    /// Provide a skeleton JSON file for running a set of benchmarks.
    Skeleton,
    /// Run a list of benchmarks.
    Run {
        /// The input file to run.
        #[arg(long = "input-file")]
        input_file: PathBuf,
        /// The path where the output file should reside.
        #[arg(long = "output-file")]
        output_file: PathBuf,
        /// Parse an input file and perform all validation checks, but don't actually run any
        /// benchmarks.
        #[arg(long, action)]
        dry_run: bool,
    },
}
44
/// The CLI used to drive a benchmark application.
#[derive(Debug, Parser)]
pub struct App {
    // The subcommand selected on the command line; dispatched in [`App::run`].
    #[command(subcommand)]
    command: Commands,
}
51
52impl App {
53    /// Construct [`Self`] by parsing commandline arguments from [`std::env::args]`.
54    ///
55    /// This simply redirects to [`clap::Parser::parse`] and is provided to allow parsing
56    /// without the [`clap::Parser`] trait in scope.
57    pub fn parse() -> Self {
58        <Self as clap::Parser>::parse()
59    }
60
61    /// Construct [`Self`] by parsing command line arguments from the iterator.
62    ///
63    /// This simply redirects to [`clap::Parser::try_parse_from`] and is provided to allow
64    /// parsing without the [`clap::Parser`] trait in scope.
65    pub fn try_parse_from<I, T>(itr: I) -> anyhow::Result<Self>
66    where
67        I: IntoIterator<Item = T>,
68        T: Into<std::ffi::OsString> + Clone,
69    {
70        Ok(<Self as clap::Parser>::try_parse_from(itr)?)
71    }
72
73    /// Construct [`Self`] directly from a [`Commands`] enum.
74    pub fn from_commands(command: Commands) -> Self {
75        Self { command }
76    }
77
78    /// Run the application using the registered `inputs` and `outputs`.
79    pub fn run(
80        &self,
81        inputs: &registry::Inputs,
82        benchmarks: &registry::Benchmarks,
83        mut output: &mut dyn Output,
84    ) -> anyhow::Result<()> {
85        match &self.command {
86            // If a named benchmark isn't given, then list the available benchmarks.
87            Commands::Inputs { describe } => {
88                if let Some(describe) = describe {
89                    if let Some(input) = inputs.get(describe) {
90                        let repr = jobs::Unprocessed::format_input(input)?;
91                        writeln!(
92                            output,
93                            "The example JSON representation for \"{}\" is:",
94                            describe
95                        )?;
96                        writeln!(output, "{}", serde_json::to_string_pretty(&repr)?)?;
97                        return Ok(());
98                    } else {
99                        writeln!(output, "No input found for \"{}\"", describe)?;
100                    }
101
102                    return Ok(());
103                }
104
105                writeln!(output, "Available input kinds are listed below:")?;
106                let mut tags: Vec<_> = inputs.tags().collect();
107                tags.sort();
108                for i in tags.iter() {
109                    writeln!(output, "    {}", i)?;
110                }
111            }
112            // List the available benchmarks.
113            Commands::Benchmarks {} => {
114                writeln!(output, "Registered Benchmarks:")?;
115                for (name, method) in benchmarks.methods() {
116                    writeln!(output, "    {}: {}", name, method.signatures()[0])?;
117                }
118            }
119            Commands::Skeleton => {
120                writeln!(output, "Skeleton input file:")?;
121                writeln!(output, "{}", Jobs::example()?)?;
122            }
123            // Run the benchmarks
124            Commands::Run {
125                input_file,
126                output_file,
127                dry_run,
128            } => {
129                // Parse and validate the input.
130                let run = Jobs::load(input_file, inputs)?;
131                // Check if we have a match for each benchmark.
132                for job in run.jobs().iter() {
133                    if !benchmarks.has_match(job) {
134                        let repr = serde_json::to_string_pretty(&job.serialize()?)?;
135
136                        const MAX_METHODS: usize = 3;
137                        let mismatches = match benchmarks.debug(job, MAX_METHODS) {
138                            // Debug should return `Err` if there is not a match.
139                            // Returning `Ok(())` here indicates an internal error with the
140                            // dispatcher.
141                            Ok(()) => {
142                                return Err(anyhow::Error::msg(format!(
143                                    "experienced internal error while debugging:\n{}",
144                                    repr
145                                )))
146                            }
147                            Err(m) => m,
148                        };
149
150                        writeln!(
151                            output,
152                            "Could not find a match for the following input:\n\n{}\n",
153                            repr
154                        )?;
155                        writeln!(output, "Closest matches:\n")?;
156                        for (i, mismatch) in mismatches.into_iter().enumerate() {
157                            writeln!(
158                                output,
159                                "    {}. \"{}\": {}",
160                                i + 1,
161                                mismatch.method(),
162                                mismatch.reason(),
163                            )?;
164                        }
165                        writeln!(output)?;
166
167                        return Err(anyhow::Error::msg(
168                            "could not find find a benchmark for all inputs",
169                        ));
170                    }
171                }
172
173                if *dry_run {
174                    writeln!(
175                        output,
176                        "Success - skipping running benchmarks because \"--dry-run\" was used."
177                    )?;
178                    return Ok(());
179                }
180
181                // The collection of output results for each run.
182                let mut results = Vec::<serde_json::Value>::new();
183
184                // Now - we've verified the integrity of all the jobs we want to run and that
185                // each job can match an associated benchmark.
186                //
187                // All that's left is to actually run the benchmarks.
188                let jobs = run.jobs();
189                let serialized = jobs
190                    .iter()
191                    .map(|job| {
192                        serde_json::to_value(jobs::Unprocessed::new(
193                            job.tag().into(),
194                            job.serialize()?,
195                        ))
196                    })
197                    .collect::<Result<Vec<_>, serde_json::Error>>()?;
198                for (i, job) in jobs.iter().enumerate() {
199                    let prefix: &str = if i != 0 { "\n\n" } else { "" };
200                    writeln!(
201                        output,
202                        "{}{}",
203                        prefix,
204                        Banner::new(&format!("Running Job {} of {}", i + 1, jobs.len()))
205                    )?;
206
207                    // Run the specified job.
208                    let checkpoint = Checkpoint::new(&serialized, &results, output_file)?;
209                    let r = benchmarks.call(job, checkpoint, output)?;
210
211                    // Collect the results
212                    results.push(r);
213
214                    // Save everything.
215                    Checkpoint::new(&serialized, &results, output_file)?.save()?;
216                }
217            }
218        };
219        Ok(())
220    }
221}
222
223///////////
224// Tests //
225///////////
226
/// The integration tests below look inside the `tests` directory for folders.
///
/// ## Input Files
///
/// Each folder should have at least a `stdin.txt` file specifying the command line to give
/// to the `App` parser.
///
/// Within the `stdin.txt` command line, there are several special symbols:
///
/// * $INPUT - Resolves to `input.json` in the same directory as the `stdin.txt` file.
/// * $OUTPUT - Resolves to `output.json` in a temporary directory.
///
/// As mentioned - an input JSON file can be included and must be named "input.json" to be
/// discoverable.
///
/// ## Output Files
///
/// Tests should have at least a `stdout.txt` file with the expected outputs for running the
/// command in `stdin.txt`. If an output JSON file is expected, it should be named `output.json`.
///
/// ## Test Discovery and Running
///
/// The unit test will visit each folder in `tests` and run the outlined scenario. The
/// `stdout.txt` expected output is compared to the actual output and if they do not match,
/// the test fails.
///
/// Additionally, if `output.json` is present, the unit test will verify that (1) the command
/// did in fact produce an output JSON file and (2) the generated file matches the expected file.
///
/// ## Regenerating Expected Results
///
/// The benchmark output will naturally change over time. Running the unit tests with the
/// environment variable
/// ```text
/// POCKETBENCH_TEST=overwrite
/// ```
/// will replace the `stdout.txt` (and `output.json` if one was generated) for each test
/// scenario. Developers should then consult `git diff` to ensure that major regressions
/// to the output did not occur.
#[cfg(test)]
mod tests {
    use super::*;

    use std::{
        ffi::OsString,
        path::{Path, PathBuf},
        sync::LazyLock,
    };

    use crate::registry;

    // Environment variable used to toggle golden-file regeneration.
    const ENV: &str = "POCKETBENCH_TEST";

    // Expected I/O files.
    const STDIN: &str = "stdin.txt";
    const STDOUT: &str = "stdout.txt";
    const INPUT_FILE: &str = "input.json";
    const OUTPUT_FILE: &str = "output.json";

    // Normalize a string for comparison.
    //
    // Steps taken:
    //
    // 1. All leading/trailing whitespace is removed.
    // 2. Windows line-endings `\r\n` are replaced with `\n`.
    fn normalize(s: String) -> String {
        let trimmed = s.trim().to_string();
        trimmed.replace("\r\n", "\n")
    }

    // Read the entire contents of a file to a normalized string, panicking with the
    // human-readable context `ctx` on failure.
    fn read_to_string<P: AsRef<Path>>(path: P, ctx: &str) -> String {
        match std::fs::read_to_string(path.as_ref()) {
            Ok(s) => normalize(s),
            Err(err) => panic!(
                "failed to read {} {:?} with error: {}",
                ctx,
                path.as_ref(),
                err
            ),
        }
    }

    // Check if `POCKETBENCH_TEST=overwrite` is configured. Return `true` if so - otherwise
    // return `false`.
    //
    // If `POCKETBENCH_TEST` is set but its value is not `overwrite` - panic.
    fn overwrite() -> bool {
        match std::env::var(ENV) {
            Ok(v) => {
                if v == "overwrite" {
                    true
                } else {
                    panic!(
                        "Unknown value for {}: \"{}\". Expected \"overwrite\"",
                        ENV, v
                    );
                }
            }
            Err(std::env::VarError::NotPresent) => false,
            Err(std::env::VarError::NotUnicode(_)) => {
                panic!("Value for {} is not unicode", ENV);
            }
        }
    }

    // There does not appear to be a supported way of checking whether backtraces are
    // enabled without first actually capturing a backtrace.
    static BACKTRACE_ENABLED: LazyLock<bool> = LazyLock::new(|| {
        use std::backtrace::{Backtrace, BacktraceStatus};
        Backtrace::capture().status() == BacktraceStatus::Captured
    });

    // Strip the backtrace from `stdout` if running with backtraces enabled.
    fn strip_backtrace(s: String) -> String {
        if !*BACKTRACE_ENABLED {
            return s;
        }

        // Split into lines until we see `Stack backtrace:`, then drop the empty
        // separator line that precedes it.
        //
        // Prints with stack traces will look something like
        // ```
        // while processing input 2 of 2
        //
        // Caused by:
        //     unknown variant `f32`, expected one of `float64`, `float32`, <snip>
        //
        // Stack backtrace:
        //    0:
        // ```
        // This works by splitting the output into lines - looking for the keyword
        // `Stack backtrace` and taking all lines up to that point.
        let mut stacktrace_found = false;
        let lines: Vec<_> = s
            .lines()
            .take_while(|l| {
                stacktrace_found = *l == "Stack backtrace:";
                !stacktrace_found
            })
            .collect();

        if lines.is_empty() {
            String::new()
        } else if stacktrace_found {
            // When `anyhow` inserts a backtrace - it separates the body of the error from
            // the stack trace with a newline. This strips that newline.
            //
            // Indexing is okay because we've already handled the empty case.
            lines[..lines.len() - 1].join("\n")
        } else {
            // No stacktrace found - do not strip a trailing empty line.
            lines.join("\n")
        }
    }

    // Test Runner for a single scenario directory.
    struct Test {
        // Directory containing `stdin.txt` and the expected outputs.
        dir: PathBuf,
        // Whether golden files should be regenerated instead of compared.
        overwrite: bool,
    }

    impl Test {
        fn new(dir: &Path) -> Self {
            Self {
                dir: dir.into(),
                overwrite: overwrite(),
            }
        }

        // Read `stdin.txt`, resolve the special `$INPUT`/`$OUTPUT` symbols, and parse
        // the result as command line arguments for `App`.
        fn parse_stdin(&self, tempdir: &Path) -> App {
            let path = self.dir.join(STDIN);

            // Read the standard input file to a string.
            let stdin = read_to_string(&path, "standard input");

            let args: Vec<OsString> = stdin
                .split_whitespace()
                .map(|v| -> OsString { self.resolve(v, tempdir).into() })
                .collect();

            // Prepend a dummy program name since clap expects `argv[0]`.
            App::try_parse_from(std::iter::once(OsString::from("test-app")).chain(args)).unwrap()
        }

        // Resolve the special symbols `$INPUT` and `$OUTPUT`; anything else passes through.
        fn resolve(&self, s: &str, tempdir: &Path) -> PathBuf {
            if s == "$INPUT" {
                self.dir.join(INPUT_FILE)
            } else if s == "$OUTPUT" {
                tempdir.join(OUTPUT_FILE)
            } else {
                s.into()
            }
        }

        fn run(&self, tempdir: &Path) {
            let app = self.parse_stdin(tempdir);

            // Register inputs
            let mut inputs = registry::Inputs::new();
            crate::test::register_inputs(&mut inputs).unwrap();

            // Register outputs
            let mut benchmarks = registry::Benchmarks::new();
            crate::test::register_benchmarks(&mut benchmarks);

            // Run app - collecting output into a buffer.
            //
            // If the app returns an error - format the error to the output buffer as well
            // using the debug formatting option.
            let mut buffer = crate::output::Memory::new();
            if let Err(err) = app.run(&inputs, &benchmarks, &mut buffer) {
                let mut b: &mut dyn crate::Output = &mut buffer;
                write!(b, "{:?}", err).unwrap();
            }

            // Check that `stdout` matches
            let stdout: String =
                normalize(strip_backtrace(buffer.into_inner().try_into().unwrap()));
            let output = self.dir.join(STDOUT);
            if self.overwrite {
                std::fs::write(output, stdout).unwrap();
            } else {
                let expected = read_to_string(&output, "expected standard output");
                if stdout != expected {
                    panic!("Got:\n--\n{}\n--\nExpected:\n--\n{}\n--", stdout, expected);
                }
            }

            // Check that the output files match.
            let output_path = tempdir.join(OUTPUT_FILE);
            let was_output_generated = output_path.is_file();

            let expected_output_path = self.dir.join(OUTPUT_FILE);
            let is_output_expected = expected_output_path.is_file();

            if self.overwrite {
                // Copy the output file to the destination.
                if was_output_generated {
                    println!(
                        "Moving generated output file {:?} to {:?}",
                        output_path, expected_output_path
                    );

                    if let Err(err) = std::fs::rename(&output_path, &expected_output_path) {
                        panic!(
                            "Moving generated output file {:?} to expected location {:?} failed: {}",
                            output_path, expected_output_path, err
                        );
                    }
                } else if is_output_expected {
                    println!("Removing outdated output file {:?}", expected_output_path);
                    if let Err(err) = std::fs::remove_file(&expected_output_path) {
                        panic!(
                            "Failed removing outdated output file {:?}: {}",
                            expected_output_path, err
                        );
                    }
                }
            } else {
                match (was_output_generated, is_output_expected) {
                    (true, true) => {
                        let output_contents = read_to_string(output_path, "generated output JSON");

                        let expected_contents =
                            read_to_string(expected_output_path, "expected output JSON");

                        if output_contents != expected_contents {
                            panic!(
                                "Got:\n\n{}\n\nExpected:\n\n{}\n",
                                output_contents, expected_contents
                            );
                        }
                    }
                    (true, false) => {
                        let output_contents = read_to_string(output_path, "generated output JSON");

                        panic!(
                            "An output JSON was generated when none was expected. Contents:\n\n{}",
                            output_contents
                        );
                    }
                    (false, true) => {
                        panic!("No output JSON was generated when one was expected");
                    }
                    (false, false) => { /* this is okay */ }
                }
            }
        }
    }

    fn run_specific_test(test_dir: &Path) {
        println!("running test in {:?}", test_dir);
        let temp_dir = tempfile::tempdir().unwrap();
        Test::new(test_dir).run(temp_dir.path());
    }

    fn run_all_tests_in(dir: &str) {
        let dir: PathBuf = format!("{}/tests/{}", env!("CARGO_MANIFEST_DIR"), dir).into();
        for entry in std::fs::read_dir(dir).unwrap() {
            let entry = entry.unwrap();
            if let Ok(file_type) = entry.file_type() {
                if file_type.is_dir() {
                    run_specific_test(&entry.path());
                }
            } else {
                panic!("couldn't get file type for {:?}", entry.path());
            }
        }
    }

    #[test]
    fn top_level_tests() {
        run_all_tests_in("");
    }
}