Skip to main content

diskann_benchmark_runner/
benchmark.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6use serde::{Deserialize, Serialize};
7
8use crate::{
9    dispatcher::{FailureScore, MatchScore},
10    Any, Checkpoint, Input, Output,
11};
12
13/// A registered benchmark.
14///
15/// Benchmarks consist of an [`Input`] and a corresponding serialized `Output`. Inputs will
16/// first be validated with the benchmark using [`try_match`](Self::try_match). Only
17/// successful matches will be passed to [`run`](Self::run).
18pub trait Benchmark {
19    /// The [`Input`] type this benchmark matches against.
20    type Input: Input + 'static;
21
22    /// The concrete type of the results generated by this benchmark.
23    type Output: Serialize;
24
25    /// Return whether or not this benchmark is compatible with `input`.
26    ///
27    /// On success, returns `Ok(MatchScore)`. [`MatchScore`]s of all benchmarks will be
28    /// collected and the benchmark with the lowest final score will be selected.
29    ///
30    /// In the case of ties, the winner is chosen using an unspecified tie-breaking procedure.
31    ///
32    /// On failure, returns `Err(FailureScore)`. In the [`crate::registry::Benchmarks`]
33    /// registry, [`FailureScore`]s will be used to rank the "nearest misses". Implementations
34    /// are encouraged to generate ranked [`FailureScore`]s to assist in user level debugging.
35    fn try_match(input: &Self::Input) -> Result<MatchScore, FailureScore>;
36
37    /// Return descriptive information about the benchmark.
38    ///
39    /// If `input` is `None`, then high level information about the benchmark should be relayed.
40    /// If `input` is `Some`, and is an unsuccessful match, diagnostic information about what
41    /// was expected should be generated to help users.
42    fn description(
43        f: &mut std::fmt::Formatter<'_>,
44        input: Option<&Self::Input>,
45    ) -> std::fmt::Result;
46
47    /// Run the benchmark with `input`.
48    ///
49    /// All prints should be directed to `output`. The `checkpoint` is provided so
50    /// long-running benchmarks can periodically save output to prevent data loss due to
51    /// an early error.
52    ///
53    /// Implementors may assume that [`Self::try_match`] returned `Ok` on `input`.
54    fn run(
55        input: &Self::Input,
56        checkpoint: Checkpoint<'_>,
57        output: &mut dyn Output,
58    ) -> anyhow::Result<Self::Output>;
59}
60
61/// A refinement of [`Benchmark`], that supports before/after comparison of generated results.
62///
63/// Benchmarks are associated with a "tolerance" input, which may contain runtime values
64/// controlling the amount of slack a benchmark is allowed to have between runs before failing.
65///
66/// The semantics of pass or failure are left solely to the discretion of the [`Regression`]
67/// implementation.
68///
69/// See: [`register_regression`](crate::registry::Benchmarks::register_regression).
70pub trait Regression: Benchmark<Output: for<'a> Deserialize<'a>> {
71    /// The tolerance [`Input`] associated with this regression check.
72    type Tolerances: Input + 'static;
73
74    /// The report summary used to describe a successful regression check.
75    type Pass: Serialize + std::fmt::Display + 'static;
76
77    /// The report summary used to describe an unsuccessful regression check.
78    type Fail: Serialize + std::fmt::Display + 'static;
79
80    /// Run any regression checks necessary for two benchmark runs `before` and `after`.
81    /// Argument `tolerances` contain any tuned runtime tolerances to use when determining
82    /// whether or not a regression is detected.
83    ///
84    /// The `input` is the raw input that would have been provided to [`Benchmark::run`]
85    /// when generating the `before` and `after` outputs.
86    ///
87    /// Implementations of `check` should not attempt to print to `stdout` or any other
88    /// stream. Instead, all diagnostics should be encoded in the returned [`PassFail`] type
89    /// for reporting upstream.
90    fn check(
91        tolerances: &Self::Tolerances,
92        input: &Self::Input,
93        before: &Self::Output,
94        after: &Self::Output,
95    ) -> anyhow::Result<PassFail<Self::Pass, Self::Fail>>;
96}
97
/// Describe whether or not a [`Regression`] check passed or failed.
///
/// `P` is the summary carried on success and `F` the summary carried on failure.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PassFail<P, F> {
    /// The check passed; carries the success summary.
    Pass(P),
    /// The check failed; carries the failure summary.
    Fail(F),
}
104
105//////////////
106// Internal //
107//////////////
108
109pub(crate) mod internal {
110    use super::*;
111
112    use std::marker::PhantomData;
113
114    use anyhow::Context;
115    use thiserror::Error;
116
117    /// Object-safe trait for type-erased benchmarks stored in the registry.
118    pub(crate) trait Benchmark {
119        fn try_match(&self, input: &Any) -> Result<MatchScore, FailureScore>;
120
121        fn description(
122            &self,
123            f: &mut std::fmt::Formatter<'_>,
124            input: Option<&Any>,
125        ) -> std::fmt::Result;
126
127        fn run(
128            &self,
129            input: &Any,
130            checkpoint: Checkpoint<'_>,
131            output: &mut dyn Output,
132        ) -> anyhow::Result<serde_json::Value>;
133
134        /// If supported, return an object capable of running regression checks on this benchmark.
135        fn as_regression(&self) -> Option<&dyn Regression>;
136    }
137
138    pub(crate) struct Checked {
139        pub(crate) json: serde_json::Value,
140        pub(crate) display: Box<dyn std::fmt::Display>,
141    }
142
143    impl Checked {
144        /// Serialize `value` to `serde_json::Value` and box it for future display.
145        fn new<T>(value: T) -> Result<Self, serde_json::Error>
146        where
147            T: Serialize + std::fmt::Display + 'static,
148        {
149            Ok(Self {
150                json: serde_json::to_value(&value)?,
151                display: Box::new(value),
152            })
153        }
154    }
155
156    pub(crate) type CheckedPassFail = PassFail<Checked, Checked>;
157
158    pub(crate) trait Regression {
159        fn tolerance(&self) -> &dyn crate::input::DynInput;
160        fn input_tag(&self) -> &'static str;
161        fn check(
162            &self,
163            tolerances: &Any,
164            input: &Any,
165            before: &serde_json::Value,
166            after: &serde_json::Value,
167        ) -> anyhow::Result<CheckedPassFail>;
168    }
169
170    impl std::fmt::Debug for dyn Regression + '_ {
171        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
172            f.debug_struct("dyn Regression")
173                .field("tolerance", &self.tolerance().tag())
174                .field("input_tag", &self.input_tag())
175                .finish()
176        }
177    }
178
179    pub(crate) trait AsRegression {
180        fn as_regression(&self) -> Option<&dyn Regression>;
181    }
182
183    #[derive(Debug, Clone)]
184    pub(crate) struct NoRegression;
185
186    impl AsRegression for NoRegression {
187        fn as_regression(&self) -> Option<&dyn Regression> {
188            None
189        }
190    }
191
192    #[derive(Debug, Clone, Copy)]
193    pub(crate) struct WithRegression<T>(PhantomData<T>);
194
195    impl<T> WithRegression<T> {
196        pub(crate) const fn new() -> Self {
197            Self(PhantomData)
198        }
199    }
200
201    impl<T> AsRegression for WithRegression<T>
202    where
203        T: super::Regression,
204    {
205        fn as_regression(&self) -> Option<&dyn Regression> {
206            Some(self)
207        }
208    }
209
210    impl<T> Regression for WithRegression<T>
211    where
212        T: super::Regression,
213    {
214        fn tolerance(&self) -> &dyn crate::input::DynInput {
215            &crate::input::Wrapper::<T::Tolerances>::INSTANCE
216        }
217
218        fn input_tag(&self) -> &'static str {
219            T::Input::tag()
220        }
221
222        fn check(
223            &self,
224            tolerance: &Any,
225            input: &Any,
226            before: &serde_json::Value,
227            after: &serde_json::Value,
228        ) -> anyhow::Result<CheckedPassFail> {
229            let tolerance = tolerance
230                .downcast_ref::<T::Tolerances>()
231                .ok_or_else(|| BadDownCast::new(T::Tolerances::tag(), tolerance.tag()))
232                .context("failed to obtain tolerance")?;
233
234            let input = input
235                .downcast_ref::<T::Input>()
236                .ok_or_else(|| BadDownCast::new(T::Input::tag(), input.tag()))
237                .context("failed to obtain input")?;
238
239            let before = T::Output::deserialize(before)
240                .map_err(|err| DeserializationError::new(Kind::Before, err))?;
241
242            let after = T::Output::deserialize(after)
243                .map_err(|err| DeserializationError::new(Kind::After, err))?;
244
245            let passfail = match T::check(tolerance, input, &before, &after)? {
246                PassFail::Pass(pass) => PassFail::Pass(Checked::new(pass)?),
247                PassFail::Fail(fail) => PassFail::Fail(Checked::new(fail)?),
248            };
249
250            Ok(passfail)
251        }
252    }
253
254    #[derive(Debug, Clone, Copy)]
255    pub(crate) struct Wrapper<T, R = NoRegression> {
256        regression: R,
257        _type: PhantomData<T>,
258    }
259
260    impl<T> Wrapper<T, NoRegression> {
261        pub(crate) const fn new() -> Self {
262            Self::new_with(NoRegression)
263        }
264    }
265
266    impl<T, R> Wrapper<T, R> {
267        pub(crate) const fn new_with(regression: R) -> Self {
268            Self {
269                regression,
270                _type: PhantomData,
271            }
272        }
273    }
274
275    /// The score given to unsuccessful downcasts in [`Benchmark::try_match`].
276    const MATCH_FAIL: FailureScore = FailureScore(10_000);
277
278    impl<T, R> Benchmark for Wrapper<T, R>
279    where
280        T: super::Benchmark,
281        R: AsRegression,
282    {
283        fn try_match(&self, input: &Any) -> Result<MatchScore, FailureScore> {
284            if let Some(cast) = input.downcast_ref::<T::Input>() {
285                T::try_match(cast)
286            } else {
287                Err(MATCH_FAIL)
288            }
289        }
290
291        fn description(
292            &self,
293            f: &mut std::fmt::Formatter<'_>,
294            input: Option<&Any>,
295        ) -> std::fmt::Result {
296            match input {
297                Some(input) => match input.downcast_ref::<T::Input>() {
298                    Some(cast) => T::description(f, Some(cast)),
299                    None => write!(
300                        f,
301                        "expected tag \"{}\" - instead got \"{}\"",
302                        T::Input::tag(),
303                        input.tag(),
304                    ),
305                },
306                None => {
307                    writeln!(f, "tag \"{}\"", <T::Input as Input>::tag())?;
308                    T::description(f, None)
309                }
310            }
311        }
312
313        fn run(
314            &self,
315            input: &Any,
316            checkpoint: Checkpoint<'_>,
317            output: &mut dyn Output,
318        ) -> anyhow::Result<serde_json::Value> {
319            match input.downcast_ref::<T::Input>() {
320                Some(input) => {
321                    let result = T::run(input, checkpoint, output)?;
322                    Ok(serde_json::to_value(result)?)
323                }
324                None => Err(BadDownCast::new(T::Input::tag(), input.tag()).into()),
325            }
326        }
327
328        // Extensions
329        fn as_regression(&self) -> Option<&dyn Regression> {
330            self.regression.as_regression()
331        }
332    }
333
334    //--------//
335    // Errors //
336    //--------//
337
338    #[derive(Debug, Clone, Copy, Error)]
339    #[error(
340        "INTERNAL ERROR: bad downcast - expected \"{}\" but got \"{}\"",
341        self.expected,
342        self.got
343    )]
344    struct BadDownCast {
345        expected: &'static str,
346        got: &'static str,
347    }
348
349    impl BadDownCast {
350        fn new(expected: &'static str, got: &'static str) -> Self {
351            Self { expected, got }
352        }
353    }
354
355    #[derive(Debug, Error)]
356    #[error(
357        "the \"{}\" results do not match the output schema expected by this benchmark",
358        self.kind
359    )]
360    struct DeserializationError {
361        kind: Kind,
362        source: serde_json::Error,
363    }
364
365    impl DeserializationError {
366        fn new(kind: Kind, source: serde_json::Error) -> Self {
367            Self { kind, source }
368        }
369    }
370
371    #[derive(Debug, Clone, Copy)]
372    enum Kind {
373        Before,
374        After,
375    }
376
377    impl std::fmt::Display for Kind {
378        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
379            let as_str = match self {
380                Self::Before => "before",
381                Self::After => "after",
382            };
383
384            write!(f, "{}", as_str)
385        }
386    }
387}