Skip to main content

diskann_benchmark_runner/
benchmark.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6use serde::{Deserialize, Serialize};
7
8use crate::{
9    dispatcher::{FailureScore, MatchScore},
10    Any, Checkpoint, Input, Output,
11};
12
/// A registered benchmark.
///
/// Benchmarks consist of an [`Input`] and a corresponding serialized `Output`. Inputs will
/// first be validated with the benchmark using [`try_match`](Self::try_match). Only
/// successful matches will be passed to [`run`](Self::run).
pub trait Benchmark: 'static {
    /// The [`Input`] type this benchmark matches against.
    type Input: Input + 'static;

    /// The concrete type of the results generated by this benchmark.
    ///
    /// The type must implement [`Serialize`]: the crate-internal type-erased wrapper
    /// converts the value returned by [`run`](Self::run) into JSON.
    type Output: Serialize;

    /// Return whether or not this benchmark is compatible with `input`.
    ///
    /// On success, returns `Ok(MatchScore)`. [`MatchScore`]s of all benchmarks will be
    /// collected and the benchmark with the lowest final score will be selected.
    ///
    /// In the case of ties, the winner is chosen using an unspecified tie-breaking procedure.
    ///
    /// On failure, returns `Err(FailureScore)`. In the [`crate::registry::Benchmarks`]
    /// registry, [`FailureScore`]s will be used to rank the "nearest misses". Implementations
    /// are encouraged to generate ranked [`FailureScore`]s to assist in user-level debugging.
    fn try_match(&self, input: &Self::Input) -> Result<MatchScore, FailureScore>;

    /// Return descriptive information about the benchmark.
    ///
    /// The description is written to the formatter `f`.
    ///
    /// If `input` is `None`, then high level information about the benchmark should be relayed.
    /// If `input` is `Some`, and is an unsuccessful match, diagnostic information about what
    /// was expected should be generated to help users.
    fn description(
        &self,
        f: &mut std::fmt::Formatter<'_>,
        input: Option<&Self::Input>,
    ) -> std::fmt::Result;

    /// Run the benchmark with `input`.
    ///
    /// All prints should be directed to `output`. The `checkpoint` is provided so
    /// long-running benchmarks can periodically save output to prevent data loss due to
    /// an early error.
    ///
    /// Implementors may assume that [`Self::try_match`] returned `Ok` on `input`.
    ///
    /// On success the generated [`Self::Output`] is returned; failures are reported to
    /// the caller as an [`anyhow::Error`].
    fn run(
        &self,
        input: &Self::Input,
        checkpoint: Checkpoint<'_>,
        output: &mut dyn Output,
    ) -> anyhow::Result<Self::Output>;
}
62
/// A refinement of [`Benchmark`], that supports before/after comparison of generated results.
///
/// Benchmarks are associated with a "tolerance" input, which may contain runtime values
/// controlling the amount of slack a benchmark is allowed to have between runs before failing.
///
/// The semantics of pass or failure are left solely to the discretion of the [`Regression`]
/// implementation.
///
/// The benchmark's `Output` must additionally be deserializable: the crate-internal
/// type-erased adapter rebuilds the `before` and `after` outputs from JSON before calling
/// [`check`](Self::check).
///
/// See: [`register_regression`](crate::registry::Benchmarks::register_regression).
pub trait Regression: Benchmark<Output: for<'a> Deserialize<'a>> {
    /// The tolerance [`Input`] associated with this regression check.
    type Tolerances: Input + 'static;

    /// The report summary used to describe a successful regression check.
    type Pass: Serialize + std::fmt::Display + 'static;

    /// The report summary used to describe an unsuccessful regression check.
    type Fail: Serialize + std::fmt::Display + 'static;

    /// Run any regression checks necessary for two benchmark runs `before` and `after`.
    /// Argument `tolerances` contains any tuned runtime tolerances to use when determining
    /// whether or not a regression is detected.
    ///
    /// The `input` is the raw input that would have been provided to [`Benchmark::run`]
    /// when generating the `before` and `after` outputs.
    ///
    /// Implementations of `check` should not attempt to print to `stdout` or any other
    /// stream. Instead, all diagnostics should be encoded in the returned [`PassFail`] type
    /// for reporting upstream.
    ///
    /// Returning `Err` signals that the check itself could not be carried out; a detected
    /// regression is reported as `Ok(PassFail::Fail(..))`.
    fn check(
        &self,
        tolerances: &Self::Tolerances,
        input: &Self::Input,
        before: &Self::Output,
        after: &Self::Output,
    ) -> anyhow::Result<PassFail<Self::Pass, Self::Fail>>;
}
100
/// Describe whether or not a [`Regression`] passed or failed.
///
/// `P` is the payload carried by a passing result and `F` the payload carried by a failing
/// one. [`Regression::check`] instantiates them with [`Regression::Pass`] and
/// [`Regression::Fail`] respectively.
#[derive(Debug, Clone, Copy)]
pub enum PassFail<P, F> {
    /// The check passed; carries the passing summary.
    Pass(P),
    /// The check failed; carries the failing summary.
    Fail(F),
}
107
108//////////////
109// Internal //
110//////////////
111
112pub(crate) mod internal {
113    use super::*;
114
115    use anyhow::Context;
116    use thiserror::Error;
117
    /// Object-safe trait for type-erased benchmarks stored in the registry.
    ///
    /// This mirrors [`super::Benchmark`], with the concrete input replaced by [`Any`] and
    /// the concrete output replaced by [`serde_json::Value`].
    pub(crate) trait Benchmark {
        /// Score how well this benchmark matches the type-erased `input`.
        fn try_match(&self, input: &Any) -> Result<MatchScore, FailureScore>;

        /// Write descriptive information about this benchmark to `f`.
        ///
        /// `input`, when present, is the type-erased input the description is generated for.
        fn description(
            &self,
            f: &mut std::fmt::Formatter<'_>,
            input: Option<&Any>,
        ) -> std::fmt::Result;

        /// Run the benchmark on the type-erased `input` and return its results as JSON.
        fn run(
            &self,
            input: &Any,
            checkpoint: Checkpoint<'_>,
            output: &mut dyn Output,
        ) -> anyhow::Result<serde_json::Value>;

        /// If supported, return an object capable of running regression checks on this benchmark.
        fn as_regression(&self) -> Option<&dyn Regression>;
    }
138
139    pub(crate) struct Checked {
140        pub(crate) json: serde_json::Value,
141        pub(crate) display: Box<dyn std::fmt::Display>,
142    }
143
144    impl Checked {
145        /// Serialize `value` to `serde_json::Value` and box it for future display.
146        fn new<T>(value: T) -> Result<Self, serde_json::Error>
147        where
148            T: Serialize + std::fmt::Display + 'static,
149        {
150            Ok(Self {
151                json: serde_json::to_value(&value)?,
152                display: Box::new(value),
153            })
154        }
155    }
156
    /// The type-erased outcome of a regression check: both the passing and the failing
    /// summaries are stored as [`Checked`] values.
    pub(crate) type CheckedPassFail = PassFail<Checked, Checked>;
158
    /// Type-erased counterpart of [`super::Regression`].
    ///
    /// Inputs are passed as [`Any`] and benchmark outputs as [`serde_json::Value`]. A
    /// blanket implementation in this module provides this trait for every
    /// [`super::Regression`].
    pub(crate) trait Regression {
        /// Return the dynamic input object associated with the tolerance type.
        fn tolerance(&self) -> &dyn crate::input::DynInput;

        /// Return the tag of the benchmark's input type.
        fn input_tag(&self) -> &'static str;

        /// Run the regression check on type-erased arguments.
        ///
        /// `tolerances` and `input` are expected to hold the concrete tolerance and input
        /// types of the underlying regression; `before` and `after` are the JSON-encoded
        /// outputs of the two runs being compared.
        fn check(
            &self,
            tolerances: &Any,
            input: &Any,
            before: &serde_json::Value,
            after: &serde_json::Value,
        ) -> anyhow::Result<CheckedPassFail>;
    }
170
171    impl std::fmt::Debug for dyn Regression + '_ {
172        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
173            f.debug_struct("dyn Regression")
174                .field("tolerance", &self.tolerance().tag())
175                .field("input_tag", &self.input_tag())
176                .finish()
177        }
178    }
179
    /// Compile-time selector deciding whether a benchmark of type `T` exposes regression
    /// checks.
    ///
    /// Implemented in this module by the marker types [`NoRegression`] and
    /// [`WithRegression`].
    pub(crate) trait AsRegression<T> {
        /// Return `benchmark` as a type-erased [`Regression`], or `None` when this selector
        /// does not provide one.
        fn as_regression(benchmark: &T) -> Option<&dyn Regression>;
    }
183
    /// Marker selecting "no regression support".
    #[derive(Debug, Clone, Copy)]
    pub(crate) struct NoRegression;

    impl<T> AsRegression<T> for NoRegression {
        /// Always returns `None`, for any benchmark type `T`.
        fn as_regression(_benchmark: &T) -> Option<&dyn Regression> {
            None
        }
    }
192
    /// Marker selecting "regression support available".
    ///
    /// Only usable with benchmark types that implement [`super::Regression`].
    #[derive(Debug, Clone, Copy)]
    pub(crate) struct WithRegression;

    impl<T> AsRegression<T> for WithRegression
    where
        T: super::Regression,
    {
        /// Always returns `Some`, exposing `benchmark` as a type-erased [`Regression`].
        fn as_regression(benchmark: &T) -> Option<&dyn Regression> {
            Some(benchmark)
        }
    }
204
205    impl<T> Regression for T
206    where
207        T: super::Regression,
208    {
209        fn tolerance(&self) -> &dyn crate::input::DynInput {
210            &crate::input::Wrapper::<T::Tolerances>::INSTANCE
211        }
212
213        fn input_tag(&self) -> &'static str {
214            T::Input::tag()
215        }
216
217        fn check(
218            &self,
219            tolerance: &Any,
220            input: &Any,
221            before: &serde_json::Value,
222            after: &serde_json::Value,
223        ) -> anyhow::Result<CheckedPassFail> {
224            let tolerance = tolerance
225                .downcast_ref::<T::Tolerances>()
226                .ok_or_else(|| BadDownCast::new(T::Tolerances::tag(), tolerance.tag()))
227                .context("failed to obtain tolerance")?;
228
229            let input = input
230                .downcast_ref::<T::Input>()
231                .ok_or_else(|| BadDownCast::new(T::Input::tag(), input.tag()))
232                .context("failed to obtain input")?;
233
234            let before = T::Output::deserialize(before)
235                .map_err(|err| DeserializationError::new(Kind::Before, err))?;
236
237            let after = T::Output::deserialize(after)
238                .map_err(|err| DeserializationError::new(Kind::After, err))?;
239
240            let passfail = match self.check(tolerance, input, &before, &after)? {
241                PassFail::Pass(pass) => PassFail::Pass(Checked::new(pass)?),
242                PassFail::Fail(fail) => PassFail::Fail(Checked::new(fail)?),
243            };
244
245            Ok(passfail)
246        }
247    }
248
    /// Pairs a concrete benchmark with a marker `R` that selects its regression support.
    ///
    /// `R` defaults to [`NoRegression`].
    #[derive(Debug, Clone, Copy)]
    pub(crate) struct Wrapper<T, R = NoRegression> {
        /// The wrapped benchmark that calls are forwarded to.
        benchmark: T,
        /// Marker value; within this module only its type is consulted (through
        /// [`AsRegression`]).
        _regression: R,
    }

    impl<T, R> Wrapper<T, R> {
        /// Wrap `benchmark` together with the `regression` marker.
        pub(crate) const fn new(benchmark: T, regression: R) -> Self {
            Self {
                benchmark,
                _regression: regression,
            }
        }
    }
263
    /// The score given to unsuccessful downcasts in [`Benchmark::try_match`].
    ///
    /// It is returned when the type-erased input does not hold the wrapped benchmark's
    /// input type, in which case the wrapped benchmark is not consulted at all.
    const MATCH_FAIL: FailureScore = FailureScore(10_000);
266
267    impl<T, R> Benchmark for Wrapper<T, R>
268    where
269        T: super::Benchmark,
270        R: AsRegression<T>,
271    {
272        fn try_match(&self, input: &Any) -> Result<MatchScore, FailureScore> {
273            if let Some(cast) = input.downcast_ref::<T::Input>() {
274                self.benchmark.try_match(cast)
275            } else {
276                Err(MATCH_FAIL)
277            }
278        }
279
280        fn description(
281            &self,
282            f: &mut std::fmt::Formatter<'_>,
283            input: Option<&Any>,
284        ) -> std::fmt::Result {
285            match input {
286                Some(input) => match input.downcast_ref::<T::Input>() {
287                    Some(cast) => self.benchmark.description(f, Some(cast)),
288                    None => write!(
289                        f,
290                        "expected tag \"{}\" - instead got \"{}\"",
291                        T::Input::tag(),
292                        input.tag(),
293                    ),
294                },
295                None => {
296                    writeln!(f, "tag \"{}\"", <T::Input as Input>::tag())?;
297                    self.benchmark.description(f, None)
298                }
299            }
300        }
301
302        fn run(
303            &self,
304            input: &Any,
305            checkpoint: Checkpoint<'_>,
306            output: &mut dyn Output,
307        ) -> anyhow::Result<serde_json::Value> {
308            match input.downcast_ref::<T::Input>() {
309                Some(input) => {
310                    let result = self.benchmark.run(input, checkpoint, output)?;
311                    Ok(serde_json::to_value(result)?)
312                }
313                None => Err(BadDownCast::new(T::Input::tag(), input.tag()).into()),
314            }
315        }
316
317        // Extensions
318        fn as_regression(&self) -> Option<&dyn Regression> {
319            R::as_regression(&self.benchmark)
320        }
321    }
322
323    //--------//
324    // Errors //
325    //--------//
326
    /// Error raised when a type-erased value does not hold the expected concrete type.
    ///
    /// The message is prefixed with "INTERNAL ERROR" and reports both tags.
    #[derive(Debug, Clone, Copy, Error)]
    #[error(
        "INTERNAL ERROR: bad downcast - expected \"{}\" but got \"{}\"",
        self.expected,
        self.got
    )]
    struct BadDownCast {
        /// Tag of the type the caller tried to downcast to.
        expected: &'static str,
        /// Tag reported by the value that was actually supplied.
        got: &'static str,
    }

    impl BadDownCast {
        /// Create the error from the `expected` and actually received (`got`) tags.
        fn new(expected: &'static str, got: &'static str) -> Self {
            Self { expected, got }
        }
    }
343
    /// Error raised when saved results cannot be deserialized into a benchmark's output
    /// type.
    #[derive(Debug, Error)]
    #[error(
        "the \"{}\" results do not match the output schema expected by this benchmark",
        self.kind
    )]
    struct DeserializationError {
        /// Which of the two result sets failed to deserialize.
        kind: Kind,
        /// The underlying JSON error; `thiserror` exposes a field named `source` as the
        /// error's source.
        source: serde_json::Error,
    }

    impl DeserializationError {
        /// Create the error for the result set identified by `kind`, wrapping `source`.
        fn new(kind: Kind, source: serde_json::Error) -> Self {
            Self { kind, source }
        }
    }
359
360    #[derive(Debug, Clone, Copy)]
361    enum Kind {
362        Before,
363        After,
364    }
365
366    impl std::fmt::Display for Kind {
367        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
368            let as_str = match self {
369                Self::Before => "before",
370                Self::After => "after",
371            };
372
373            write!(f, "{}", as_str)
374        }
375    }
376}