Skip to main content

diskann_benchmark_runner/
benchmark.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6use serde::{Deserialize, Serialize};
7
8use crate::{Checkpoint, Input, Output};
9
10/// A registered benchmark.
11///
12/// Benchmarks consist of an [`Input`] and a corresponding serialized `Output`. Inputs will
13/// first be validated with the benchmark using [`try_match`](Self::try_match). Only
14/// successful matches will be passed to [`run`](Self::run).
15pub trait Benchmark: 'static {
16    /// The [`Input`] type this benchmark matches against.
17    type Input: Input + 'static;
18
19    /// The concrete type of the results generated by this benchmark.
20    type Output: Serialize;
21
22    /// Return whether or not this benchmark is compatible with `input`.
23    ///
24    /// On success, returns `Ok(MatchScore)`. [`MatchScore`]s of all benchmarks will be
25    /// collected and the benchmark with the lowest final score will be selected.
26    ///
27    /// In the case of ties, the winner is chosen using an unspecified tie-breaking procedure.
28    ///
29    /// On failure, returns `Err(FailureScore)`. In the [`crate::Registry`]
30    /// registry, [`FailureScore`]s will be used to rank the "nearest misses". Implementations
31    /// are encouraged to generate ranked [`FailureScore`]s to assist in user level debugging.
32    fn try_match(&self, input: &Self::Input) -> Result<MatchScore, FailureScore>;
33
34    /// Return descriptive information about the benchmark.
35    ///
36    /// If `input` is `None`, then high level information about the benchmark should be relayed.
37    /// If `input` is `Some`, and is an unsuccessful match, diagnostic information about what
38    /// was expected should be generated to help users.
39    fn description(
40        &self,
41        f: &mut std::fmt::Formatter<'_>,
42        input: Option<&Self::Input>,
43    ) -> std::fmt::Result;
44
45    /// Run the benchmark with `input`.
46    ///
47    /// All prints should be directed to `output`. The `checkpoint` is provided so
48    /// long-running benchmarks can periodically save output to prevent data loss due to
49    /// an early error.
50    ///
51    /// Implementors may assume that [`Self::try_match`] returned `Ok` on `input`.
52    fn run(
53        &self,
54        input: &Self::Input,
55        checkpoint: Checkpoint<'_>,
56        output: &mut dyn Output,
57    ) -> anyhow::Result<Self::Output>;
58}
59
/// The score reported by a successful [`Benchmark::try_match`].
///
/// Scores are compared numerically during overload resolution: the smaller the value, the
/// better the match.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub struct MatchScore(pub u32);

impl std::fmt::Display for MatchScore {
    /// Renders the score as `success (<score>)`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self(score) = *self;
        write!(f, "success ({score})")
    }
}
71
/// Unsuccessful matches from [`Benchmark::try_match`] will return `FailureScore`s.
///
/// A lower numerical value indicates a closer miss. This lets the list of considered and
/// rejected candidates (the "nearest misses") be ordered from most to least promising,
/// which helps users understand why no benchmark matched their input.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub struct FailureScore(pub u32);

impl std::fmt::Display for FailureScore {
    /// Renders the score as `fail (<score>)`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "fail ({})", self.0)
    }
}
84
85/// A refinement of [`Benchmark`], that supports before/after comparison of generated results.
86///
87/// Benchmarks are associated with a "tolerance" input, which may contain runtime values
88/// controlling the amount of slack a benchmark is allowed to have between runs before failing.
89///
90/// The semantics of pass or failure are left solely to the discretion of the [`Regression`]
91/// implementation.
92///
93/// See: [`register_regression`](crate::Registry::register_regression).
94pub trait Regression: Benchmark<Output: for<'a> Deserialize<'a>> {
95    /// The tolerance [`Input`] associated with this regression check.
96    type Tolerances: Input + 'static;
97
98    /// The report summary used to describe a successful regression check.
99    type Pass: Serialize + std::fmt::Display + 'static;
100
101    /// The report summary used to describe an unsuccessful regression check.
102    type Fail: Serialize + std::fmt::Display + 'static;
103
104    /// Run any regression checks necessary for two benchmark runs `before` and `after`.
105    /// Argument `tolerances` contain any tuned runtime tolerances to use when determining
106    /// whether or not a regression is detected.
107    ///
108    /// The `input` is the raw input that would have been provided to [`Benchmark::run`]
109    /// when generating the `before` and `after` outputs.
110    ///
111    /// Implementations of `check` should not attempt to print to `stdout` or any other
112    /// stream. Instead, all diagnostics should be encoded in the returned [`PassFail`] type
113    /// for reporting upstream.
114    fn check(
115        &self,
116        tolerances: &Self::Tolerances,
117        input: &Self::Input,
118        before: &Self::Output,
119        after: &Self::Output,
120    ) -> anyhow::Result<PassFail<Self::Pass, Self::Fail>>;
121}
122
/// The outcome of a [`Regression`] check: either a passing or a failing report.
#[derive(Debug, Clone, Copy)]
pub enum PassFail<P, F> {
    /// The check passed; holds the passing summary (for [`Regression::check`], a
    /// `Regression::Pass`).
    Pass(P),
    /// The check failed; holds the failing summary (for [`Regression::check`], a
    /// `Regression::Fail`).
    Fail(F),
}
129
130//////////////
131// Internal //
132//////////////
133
134pub(crate) mod internal {
135    use super::*;
136
137    use crate::input::internal::Any;
138
139    use anyhow::Context;
140    use thiserror::Error;
141
142    /// Object-safe trait for type-erased benchmarks stored in the registry.
143    pub(crate) trait Benchmark {
144        fn try_match(&self, input: &Any) -> Result<MatchScore, FailureScore>;
145
146        fn description(
147            &self,
148            f: &mut std::fmt::Formatter<'_>,
149            input: Option<&Any>,
150        ) -> std::fmt::Result;
151
152        fn run(
153            &self,
154            input: &Any,
155            checkpoint: Checkpoint<'_>,
156            output: &mut dyn Output,
157        ) -> anyhow::Result<serde_json::Value>;
158
159        /// If supported, return an object capable of running regression checks on this benchmark.
160        fn as_regression(&self) -> Option<&dyn Regression>;
161    }
162
163    pub(crate) struct Checked {
164        pub(crate) json: serde_json::Value,
165        pub(crate) display: Box<dyn std::fmt::Display>,
166    }
167
168    impl Checked {
169        /// Serialize `value` to `serde_json::Value` and box it for future display.
170        fn new<T>(value: T) -> Result<Self, serde_json::Error>
171        where
172            T: Serialize + std::fmt::Display + 'static,
173        {
174            Ok(Self {
175                json: serde_json::to_value(&value)?,
176                display: Box::new(value),
177            })
178        }
179    }
180
181    pub(crate) type CheckedPassFail = PassFail<Checked, Checked>;
182
183    pub(crate) trait Regression {
184        fn tolerance(&self) -> &dyn crate::input::internal::DynInput;
185        fn input_tag(&self) -> &'static str;
186        fn check(
187            &self,
188            tolerances: &Any,
189            input: &Any,
190            before: &serde_json::Value,
191            after: &serde_json::Value,
192        ) -> anyhow::Result<CheckedPassFail>;
193    }
194
195    impl std::fmt::Debug for dyn Regression + '_ {
196        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
197            f.debug_struct("dyn Regression")
198                .field("tolerance", &self.tolerance().tag())
199                .field("input_tag", &self.input_tag())
200                .finish()
201        }
202    }
203
204    pub(crate) trait AsRegression<T> {
205        fn as_regression(benchmark: &T) -> Option<&dyn Regression>;
206    }
207
208    #[derive(Debug, Clone, Copy)]
209    pub(crate) struct NoRegression;
210
211    impl<T> AsRegression<T> for NoRegression {
212        fn as_regression(_benchmark: &T) -> Option<&dyn Regression> {
213            None
214        }
215    }
216
217    #[derive(Debug, Clone, Copy)]
218    pub(crate) struct WithRegression;
219
220    impl<T> AsRegression<T> for WithRegression
221    where
222        T: super::Regression,
223    {
224        fn as_regression(benchmark: &T) -> Option<&dyn Regression> {
225            Some(benchmark)
226        }
227    }
228
229    impl<T> Regression for T
230    where
231        T: super::Regression,
232    {
233        fn tolerance(&self) -> &dyn crate::input::internal::DynInput {
234            &crate::input::internal::Wrapper::<T::Tolerances>::INSTANCE
235        }
236
237        fn input_tag(&self) -> &'static str {
238            T::Input::tag()
239        }
240
241        fn check(
242            &self,
243            tolerance: &Any,
244            input: &Any,
245            before: &serde_json::Value,
246            after: &serde_json::Value,
247        ) -> anyhow::Result<CheckedPassFail> {
248            let tolerance = tolerance
249                .downcast_ref::<T::Tolerances>()
250                .ok_or_else(|| BadDownCast::new(T::Tolerances::tag(), tolerance.tag()))
251                .context("failed to obtain tolerance")?;
252
253            let input = input
254                .downcast_ref::<T::Input>()
255                .ok_or_else(|| BadDownCast::new(T::Input::tag(), input.tag()))
256                .context("failed to obtain input")?;
257
258            let before = T::Output::deserialize(before)
259                .map_err(|err| DeserializationError::new(Kind::Before, err))?;
260
261            let after = T::Output::deserialize(after)
262                .map_err(|err| DeserializationError::new(Kind::After, err))?;
263
264            let passfail = match self.check(tolerance, input, &before, &after)? {
265                PassFail::Pass(pass) => PassFail::Pass(Checked::new(pass)?),
266                PassFail::Fail(fail) => PassFail::Fail(Checked::new(fail)?),
267            };
268
269            Ok(passfail)
270        }
271    }
272
273    #[derive(Debug, Clone, Copy)]
274    pub(crate) struct Wrapper<T, R = NoRegression> {
275        benchmark: T,
276        _regression: R,
277    }
278
279    impl<T, R> Wrapper<T, R> {
280        pub(crate) const fn new(benchmark: T, regression: R) -> Self {
281            Self {
282                benchmark,
283                _regression: regression,
284            }
285        }
286    }
287
288    /// The score given to unsuccessful downcasts in [`Benchmark::try_match`].
289    const MATCH_FAIL: FailureScore = FailureScore(10_000);
290
291    impl<T, R> Benchmark for Wrapper<T, R>
292    where
293        T: super::Benchmark,
294        R: AsRegression<T>,
295    {
296        fn try_match(&self, input: &Any) -> Result<MatchScore, FailureScore> {
297            if let Some(cast) = input.downcast_ref::<T::Input>() {
298                self.benchmark.try_match(cast)
299            } else {
300                Err(MATCH_FAIL)
301            }
302        }
303
304        fn description(
305            &self,
306            f: &mut std::fmt::Formatter<'_>,
307            input: Option<&Any>,
308        ) -> std::fmt::Result {
309            match input {
310                Some(input) => match input.downcast_ref::<T::Input>() {
311                    Some(cast) => self.benchmark.description(f, Some(cast)),
312                    None => write!(
313                        f,
314                        "expected tag \"{}\" - instead got \"{}\"",
315                        T::Input::tag(),
316                        input.tag(),
317                    ),
318                },
319                None => {
320                    writeln!(f, "tag \"{}\"", <T::Input as Input>::tag())?;
321                    self.benchmark.description(f, None)
322                }
323            }
324        }
325
326        fn run(
327            &self,
328            input: &Any,
329            checkpoint: Checkpoint<'_>,
330            output: &mut dyn Output,
331        ) -> anyhow::Result<serde_json::Value> {
332            match input.downcast_ref::<T::Input>() {
333                Some(input) => {
334                    let result = self.benchmark.run(input, checkpoint, output)?;
335                    Ok(serde_json::to_value(result)?)
336                }
337                None => Err(BadDownCast::new(T::Input::tag(), input.tag()).into()),
338            }
339        }
340
341        // Extensions
342        fn as_regression(&self) -> Option<&dyn Regression> {
343            R::as_regression(&self.benchmark)
344        }
345    }
346
347    //--------//
348    // Errors //
349    //--------//
350
351    #[derive(Debug, Clone, Copy, Error)]
352    #[error(
353        "INTERNAL ERROR: bad downcast - expected \"{}\" but got \"{}\"",
354        self.expected,
355        self.got
356    )]
357    struct BadDownCast {
358        expected: &'static str,
359        got: &'static str,
360    }
361
362    impl BadDownCast {
363        fn new(expected: &'static str, got: &'static str) -> Self {
364            Self { expected, got }
365        }
366    }
367
368    #[derive(Debug, Error)]
369    #[error(
370        "the \"{}\" results do not match the output schema expected by this benchmark",
371        self.kind
372    )]
373    struct DeserializationError {
374        kind: Kind,
375        source: serde_json::Error,
376    }
377
378    impl DeserializationError {
379        fn new(kind: Kind, source: serde_json::Error) -> Self {
380            Self { kind, source }
381        }
382    }
383
384    #[derive(Debug, Clone, Copy)]
385    enum Kind {
386        Before,
387        After,
388    }
389
390    impl std::fmt::Display for Kind {
391        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
392            let as_str = match self {
393                Self::Before => "before",
394                Self::After => "after",
395            };
396
397            write!(f, "{}", as_str)
398        }
399    }
400}