diskann_benchmark_runner/app.rs
1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6use std::{io::Write, path::PathBuf};
7
8use clap::{Parser, Subcommand};
9
10use crate::{
11 jobs::{self, Jobs},
12 output::Output,
13 registry,
14 result::Checkpoint,
15 utils::fmt::Banner,
16};
17
/// Parsed command line options.
//
// NOTE: the `///` doc comments on this enum's variants and fields are consumed by
// clap's derive macro and become the user-visible `--help` text. Do not edit them
// casually — the integration tests compare CLI output against golden files.
#[derive(Debug, Subcommand)]
pub enum Commands {
    /// List the kinds of input formats available for ingestion.
    Inputs {
        /// Describe the layout of the named input kind.
        // When `None`, `App::run` lists all registered input kinds instead.
        describe: Option<String>,
    },
    /// List the available benchmarks.
    Benchmarks {},
    /// Provide a skeleton JSON file for running a set of benchmarks.
    Skeleton,
    /// Run a list of benchmarks.
    Run {
        /// The input file to run.
        #[arg(long = "input-file")]
        input_file: PathBuf,
        /// The path where the output file should reside.
        // Also used as the checkpoint destination after each completed job.
        #[arg(long = "output-file")]
        output_file: PathBuf,
        /// Parse an input file and perform all validation checks, but don't actually run any
        /// benchmarks.
        #[arg(long, action)]
        dry_run: bool,
    },
}
44
/// The CLI used to drive a benchmark application.
//
// The doc comment above is also clap's "about" string for `--help`.
#[derive(Debug, Parser)]
pub struct App {
    // The single subcommand to execute; dispatched by `App::run`.
    #[command(subcommand)]
    command: Commands,
}
51
52impl App {
53 /// Construct [`Self`] by parsing commandline arguments from [`std::env::args]`.
54 ///
55 /// This simply redirects to [`clap::Parser::parse`] and is provided to allow parsing
56 /// without the [`clap::Parser`] trait in scope.
57 pub fn parse() -> Self {
58 <Self as clap::Parser>::parse()
59 }
60
61 /// Construct [`Self`] by parsing command line arguments from the iterator.
62 ///
63 /// This simply redirects to [`clap::Parser::try_parse_from`] and is provided to allow
64 /// parsing without the [`clap::Parser`] trait in scope.
65 pub fn try_parse_from<I, T>(itr: I) -> anyhow::Result<Self>
66 where
67 I: IntoIterator<Item = T>,
68 T: Into<std::ffi::OsString> + Clone,
69 {
70 Ok(<Self as clap::Parser>::try_parse_from(itr)?)
71 }
72
73 /// Construct [`Self`] directly from a [`Commands`] enum.
74 pub fn from_commands(command: Commands) -> Self {
75 Self { command }
76 }
77
78 /// Run the application using the registered `inputs` and `outputs`.
79 pub fn run(
80 &self,
81 inputs: ®istry::Inputs,
82 benchmarks: ®istry::Benchmarks,
83 mut output: &mut dyn Output,
84 ) -> anyhow::Result<()> {
85 match &self.command {
86 // If a named benchmark isn't given, then list the available benchmarks.
87 Commands::Inputs { describe } => {
88 if let Some(describe) = describe {
89 if let Some(input) = inputs.get(describe) {
90 let repr = jobs::Unprocessed::format_input(input)?;
91 writeln!(
92 output,
93 "The example JSON representation for \"{}\" is:",
94 describe
95 )?;
96 writeln!(output, "{}", serde_json::to_string_pretty(&repr)?)?;
97 return Ok(());
98 } else {
99 writeln!(output, "No input found for \"{}\"", describe)?;
100 }
101
102 return Ok(());
103 }
104
105 writeln!(output, "Available input kinds are listed below:")?;
106 let mut tags: Vec<_> = inputs.tags().collect();
107 tags.sort();
108 for i in tags.iter() {
109 writeln!(output, " {}", i)?;
110 }
111 }
112 // List the available benchmarks.
113 Commands::Benchmarks {} => {
114 writeln!(output, "Registered Benchmarks:")?;
115 for (name, method) in benchmarks.methods() {
116 writeln!(output, " {}: {}", name, method.signatures()[0])?;
117 }
118 }
119 Commands::Skeleton => {
120 writeln!(output, "Skeleton input file:")?;
121 writeln!(output, "{}", Jobs::example()?)?;
122 }
123 // Run the benchmarks
124 Commands::Run {
125 input_file,
126 output_file,
127 dry_run,
128 } => {
129 // Parse and validate the input.
130 let run = Jobs::load(input_file, inputs)?;
131 // Check if we have a match for each benchmark.
132 for job in run.jobs().iter() {
133 if !benchmarks.has_match(job) {
134 let repr = serde_json::to_string_pretty(&job.serialize()?)?;
135
136 const MAX_METHODS: usize = 3;
137 let mismatches = match benchmarks.debug(job, MAX_METHODS) {
138 // Debug should return `Err` if there is not a match.
139 // Returning `Ok(())` here indicates an internal error with the
140 // dispatcher.
141 Ok(()) => {
142 return Err(anyhow::Error::msg(format!(
143 "experienced internal error while debugging:\n{}",
144 repr
145 )))
146 }
147 Err(m) => m,
148 };
149
150 writeln!(
151 output,
152 "Could not find a match for the following input:\n\n{}\n",
153 repr
154 )?;
155 writeln!(output, "Closest matches:\n")?;
156 for (i, mismatch) in mismatches.into_iter().enumerate() {
157 writeln!(
158 output,
159 " {}. \"{}\": {}",
160 i + 1,
161 mismatch.method(),
162 mismatch.reason(),
163 )?;
164 }
165 writeln!(output)?;
166
167 return Err(anyhow::Error::msg(
168 "could not find find a benchmark for all inputs",
169 ));
170 }
171 }
172
173 if *dry_run {
174 writeln!(
175 output,
176 "Success - skipping running benchmarks because \"--dry-run\" was used."
177 )?;
178 return Ok(());
179 }
180
181 // The collection of output results for each run.
182 let mut results = Vec::<serde_json::Value>::new();
183
184 // Now - we've verified the integrity of all the jobs we want to run and that
185 // each job can match an associated benchmark.
186 //
187 // All that's left is to actually run the benchmarks.
188 let jobs = run.jobs();
189 let serialized = jobs
190 .iter()
191 .map(|job| {
192 serde_json::to_value(jobs::Unprocessed::new(
193 job.tag().into(),
194 job.serialize()?,
195 ))
196 })
197 .collect::<Result<Vec<_>, serde_json::Error>>()?;
198 for (i, job) in jobs.iter().enumerate() {
199 let prefix: &str = if i != 0 { "\n\n" } else { "" };
200 writeln!(
201 output,
202 "{}{}",
203 prefix,
204 Banner::new(&format!("Running Job {} of {}", i + 1, jobs.len()))
205 )?;
206
207 // Run the specified job.
208 let checkpoint = Checkpoint::new(&serialized, &results, output_file)?;
209 let r = benchmarks.call(job, checkpoint, output)?;
210
211 // Collect the results
212 results.push(r);
213
214 // Save everything.
215 Checkpoint::new(&serialized, &results, output_file)?.save()?;
216 }
217 }
218 };
219 Ok(())
220 }
221}
222
223///////////
224// Tests //
225///////////
226
/// The integration tests below look inside the `tests` directory for folders.
228///
229/// ## Input Files
230///
231/// Each folder should have at least a `stdin.txt` file specifying the command line to give
232/// to the `App` parser.
233///
234/// Within the `stdin.txt` command line, there are several special symbols:
235///
236/// * $INPUT - Resolves to `input.json` in the same directory as the `stdin.txt` file.
237/// * $OUTPUT - Resolves to `output.json` in a temporary directory.
238///
239/// As mentioned - an input JSON file can be included and must be named "input.json" to be
240/// discoverable.
241///
242/// ## Output Files
243///
244/// Tests should have at least a `stdout.txt` file with the expected outputs for running the
/// command in `stdin.txt`. If an output JSON file is expected, it should be named `output.json`.
246///
247/// ## Test Discovery and Running
248///
249/// The unit test will visit each folder in `tests` and run the outlined scenario. The
250/// `stdout.txt` expected output is compared to the actual output and if they do not match,
251/// the test fails.
252///
253/// Additionally, if `output.json` is present, the unit test will verify that (1) the command
254/// did in fact produce an output JSON file and (2) the generated file matches the expected file.
255///
256/// ## Regenerating Expected Results
257///
258/// The benchmark output will naturally change over time. Running the unit tests with the
259/// environment variable
260/// ```text
261/// POCKETBENCH_TEST=overwrite
262/// ```
263/// will replace the `stdout.txt` (and `output.json` if one was generated) for each test
264/// scenario. Developers should then consult `git diff` to ensure that major regressions
265/// to the output did not occur.
266#[cfg(test)]
267mod tests {
268 use super::*;
269
270 use std::{
271 ffi::OsString,
272 path::{Path, PathBuf},
273 sync::LazyLock,
274 };
275
276 use crate::registry;
277
    // Environment variable that switches the tests into "overwrite" mode
    // (see the module documentation above).
    const ENV: &str = "POCKETBENCH_TEST";

    // Expected I/O files found inside (or generated for) each test scenario directory.
    const STDIN: &str = "stdin.txt";
    const STDOUT: &str = "stdout.txt";
    const INPUT_FILE: &str = "input.json";
    const OUTPUT_FILE: &str = "output.json";
285
286 // Normalize a string for comparison.
287 //
288 // Steps taken:
289 //
290 // 1. All leading trailing whitespace is removed.
291 // 2. Windows line-endings `\n\r` are replaced with `\n`.
292 fn normalize(s: String) -> String {
293 let trimmed = s.trim().to_string();
294 trimmed.replace("\r\n", "\n")
295 }
296
297 // Read the entire contents of a file to a string.
298 fn read_to_string<P: AsRef<Path>>(path: P, ctx: &str) -> String {
299 match std::fs::read_to_string(path.as_ref()) {
300 Ok(s) => normalize(s),
301 Err(err) => panic!(
302 "failed to read {} {:?} with error: {}",
303 ctx,
304 path.as_ref(),
305 err
306 ),
307 }
308 }
309
310 // Check if `POCKETBENCH_TEST=overwrite` is configured. Return `true` if so - otherwise
311 // return `false`.
312 //
313 // If `POCKETBENCH_TEST` is set but its value is not `overwrite` - panic.
314 fn overwrite() -> bool {
315 match std::env::var(ENV) {
316 Ok(v) => {
317 if v == "overwrite" {
318 true
319 } else {
320 panic!(
321 "Unknown value for {}: \"{}\". Expected \"overwrite\"",
322 ENV, v
323 );
324 }
325 }
326 Err(std::env::VarError::NotPresent) => false,
327 Err(std::env::VarError::NotUnicode(_)) => {
328 panic!("Value for {} is not unicode", ENV);
329 }
330 }
331 }
332
    // There does not appear to be a supported way of checking whether backtraces are
    // enabled without first actually capturing a backtrace, so capture one lazily and
    // cache the result for the rest of the test run.
    static BACKTRACE_ENABLED: LazyLock<bool> = LazyLock::new(|| {
        use std::backtrace::{Backtrace, BacktraceStatus};
        Backtrace::capture().status() == BacktraceStatus::Captured
    });
339
340 // Strip the backtrace from `stdout` if running with backtraces enabled.
341 fn strip_backtrace(s: String) -> String {
342 println!("pre = {}", s);
343 if !*BACKTRACE_ENABLED {
344 return s;
345 }
346
347 // Split into lines until we see `Stack backtrace`, then drop the empty
348 //
349 // Prints with stack traces will looks something like
350 // ```
351 // while processing input 2 of 2
352 //
353 // Caused by:
354 // unknown variant `f32`, expected one of `float64`, `float32`, <snip>
355 //
356 // Stack backtrace:
357 // 0:
358 // ```
359 // This works by splitting the output into lines - looking for the keyword
360 // `Stack backtrace` and taking all lines up to that point.
361 let mut stacktrace_found = false;
362 let lines: Vec<_> = s
363 .lines()
364 .take_while(|l| {
365 stacktrace_found = *l == "Stack backtrace:";
366 !stacktrace_found
367 })
368 .collect();
369
370 if lines.is_empty() {
371 String::new()
372 } else if stacktrace_found {
373 // When `anyhow` inserts a backtrace - it separates the body of the error from
374 // the stack trace with a newline. This strips that newline.
375 //
376 // Indexing is okay because we've already handled the empty case.
377 lines[..lines.len() - 1].join("\n")
378 } else {
379 // No stacktrace found - do not strip a trailing empty line.
380 lines.join("\n")
381 }
382 }
383
    // Runner for a single test-scenario directory.
    struct Test {
        // Directory containing the scenario's `stdin.txt` and its optional
        // `input.json` / expected `stdout.txt` / `output.json` files.
        dir: PathBuf,
        // Whether to overwrite the expected-output files instead of comparing
        // against them (controlled by `POCKETBENCH_TEST=overwrite`).
        overwrite: bool,
    }
389
    impl Test {
        // Build a runner for the scenario in `dir`, sampling the overwrite
        // setting from the environment once up front.
        fn new(dir: &Path) -> Self {
            Self {
                dir: dir.into(),
                overwrite: overwrite(),
            }
        }

        // Read `stdin.txt`, substitute the `$INPUT`/`$OUTPUT` symbols, and parse
        // the resulting arguments into an `App`.
        //
        // Panics if the file is unreadable or the arguments do not parse.
        fn parse_stdin(&self, tempdir: &Path) -> App {
            let path = self.dir.join(STDIN);

            // Read the standard input file to a string.
            let stdin = read_to_string(&path, "standard input");

            // NOTE: splitting on whitespace means quoted arguments containing
            // spaces are not supported in `stdin.txt`.
            let args: Vec<OsString> = stdin
                .split_whitespace()
                .map(|v| -> OsString { self.resolve(v, tempdir).into() })
                .collect();

            // Split and resolve special symbols
            // (the leading "test-app" stands in for `argv[0]`, which clap expects).
            App::try_parse_from(std::iter::once(OsString::from("test-app")).chain(args)).unwrap()
        }

        // Resolve the special `$INPUT`/`$OUTPUT` symbols; any other token is
        // passed through unchanged as a path/argument.
        fn resolve(&self, s: &str, tempdir: &Path) -> PathBuf {
            if s == "$INPUT" {
                self.dir.join(INPUT_FILE)
            } else if s == "$OUTPUT" {
                tempdir.join(OUTPUT_FILE)
            } else {
                s.into()
            }
        }

        // Execute the scenario: run the app, then compare (or overwrite) the
        // expected `stdout.txt` and optional `output.json` golden files.
        fn run(&self, tempdir: &Path) {
            let app = self.parse_stdin(tempdir);

            // Register inputs
            let mut inputs = registry::Inputs::new();
            crate::test::register_inputs(&mut inputs).unwrap();

            // Register outputs
            let mut benchmarks = registry::Benchmarks::new();
            crate::test::register_benchmarks(&mut benchmarks);

            // Run app - collecting output into a buffer.
            //
            // If the app returns an error - format the error to the output buffer as well
            // using the debug formatting option.
            let mut buffer = crate::output::Memory::new();
            if let Err(err) = app.run(&inputs, &benchmarks, &mut buffer) {
                let mut b: &mut dyn crate::Output = &mut buffer;
                write!(b, "{:?}", err).unwrap();
            }

            // Check that `stdout` matches
            // (any anyhow backtrace is stripped so golden files stay stable).
            let stdout: String =
                normalize(strip_backtrace(buffer.into_inner().try_into().unwrap()));
            let output = self.dir.join(STDOUT);
            if self.overwrite {
                std::fs::write(output, stdout).unwrap();
            } else {
                let expected = read_to_string(&output, "expected standard output");
                if stdout != expected {
                    panic!("Got:\n--\n{}\n--\nExpected:\n--\n{}\n--", stdout, expected);
                }
            }

            // Check that the output files match.
            let output_path = tempdir.join(OUTPUT_FILE);
            let was_output_generated = output_path.is_file();

            let expected_output_path = self.dir.join(OUTPUT_FILE);
            let is_output_expected = expected_output_path.is_file();

            if self.overwrite {
                // Copy the output file to the destination.
                if was_output_generated {
                    println!(
                        "Moving generated output file {:?} to {:?}",
                        output_path, expected_output_path
                    );

                    if let Err(err) = std::fs::rename(&output_path, &expected_output_path) {
                        panic!(
                            "Moving generated output file {:?} to expected location {:?} failed: {}",
                            output_path, expected_output_path, err
                        );
                    }
                } else if is_output_expected {
                    // The scenario no longer produces an output file; drop the stale golden file.
                    println!("Removing outdated output file {:?}", expected_output_path);
                    if let Err(err) = std::fs::remove_file(&expected_output_path) {
                        panic!(
                            "Failed removing outdated output file {:?}: {}",
                            expected_output_path, err
                        );
                    }
                }
            } else {
                // Compare mode: every (generated, expected) combination is handled explicitly.
                match (was_output_generated, is_output_expected) {
                    (true, true) => {
                        let output_contents = read_to_string(output_path, "generated output JSON");

                        let expected_contents =
                            read_to_string(expected_output_path, "expected output JSON");

                        if output_contents != expected_contents {
                            panic!(
                                "Got:\n\n{}\n\nExpected:\n\n{}\n",
                                output_contents, expected_contents
                            );
                        }
                    }
                    (true, false) => {
                        let output_contents = read_to_string(output_path, "generated output JSON");

                        panic!(
                            "An output JSON was generated when none was expected. Contents:\n\n{}",
                            output_contents
                        );
                    }
                    (false, true) => {
                        panic!("No output JSON was generated when one was expected");
                    }
                    (false, false) => { /* this is okay */ }
                }
            }
        }
    }
518
519 fn run_specific_test(test_dir: &Path) {
520 println!("running test in {:?}", test_dir);
521 let temp_dir = tempfile::tempdir().unwrap();
522 Test::new(test_dir).run(temp_dir.path());
523 }
524
525 fn run_all_tests_in(dir: &str) {
526 let dir: PathBuf = format!("{}/tests/{}", env!("CARGO_MANIFEST_DIR"), dir).into();
527 for entry in std::fs::read_dir(dir).unwrap() {
528 let entry = entry.unwrap();
529 if let Ok(file_type) = entry.file_type() {
530 if file_type.is_dir() {
531 run_specific_test(&entry.path());
532 }
533 } else {
534 panic!("couldn't get file type for {:?}", entry.path());
535 }
536 }
537 }
538
    #[test]
    fn top_level_tests() {
        // Run every scenario directory found directly under `tests/`.
        run_all_tests_in("");
    }
543}