Skip to main content

diskann_benchmark_runner/
registry.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6use std::collections::{hash_map::Entry, HashMap};
7
8use thiserror::Error;
9
10use crate::{
11    benchmark::{self, Benchmark, Regression},
12    dispatcher::{FailureScore, MatchScore},
13    input, Any, Checkpoint, Input, Output,
14};
15
16/// A collection of [`crate::Input`].
17pub struct Inputs {
18    // Inputs keyed by their tag type.
19    inputs: HashMap<&'static str, Box<dyn input::DynInput>>,
20}
21
22impl Inputs {
23    /// Construct a new empty [`Inputs`] registry.
24    pub fn new() -> Self {
25        Self {
26            inputs: HashMap::new(),
27        }
28    }
29
30    /// Return the input with the registered `tag` if present. Otherwise, return `None`.
31    pub fn get(&self, tag: &str) -> Option<input::Registered<'_>> {
32        self.inputs.get(tag).map(|v| input::Registered(&**v))
33    }
34
35    /// Register the [`Input`] `T` in the registry.
36    ///
37    /// Returns an error if any other input with the same [`Input::tag()`] has been registered
38    /// while leaving the underlying registry unchanged.
39    pub fn register<T>(&mut self) -> anyhow::Result<()>
40    where
41        T: Input + 'static,
42    {
43        let tag = T::tag();
44        match self.inputs.entry(tag) {
45            Entry::Vacant(entry) => {
46                entry.insert(Box::new(crate::input::Wrapper::<T>::new()));
47                Ok(())
48            }
49            Entry::Occupied(_) => {
50                #[derive(Debug, Error)]
51                #[error("An input with the tag \"{}\" already exists", self.0)]
52                struct AlreadyExists(&'static str);
53
54                Err(anyhow::anyhow!(AlreadyExists(tag)))
55            }
56        }
57    }
58
59    /// Return an iterator over all registered input tags in an unspecified order.
60    pub fn tags(&self) -> impl ExactSizeIterator<Item = &'static str> + use<'_> {
61        self.inputs.keys().copied()
62    }
63}
64
65impl Default for Inputs {
66    fn default() -> Self {
67        Self::new()
68    }
69}
70
71/// A registered benchmark entry: a name paired with a type-erased benchmark.
72pub(crate) struct RegisteredBenchmark {
73    name: String,
74    benchmark: Box<dyn benchmark::internal::Benchmark>,
75}
76
77impl std::fmt::Debug for RegisteredBenchmark {
78    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
79        let benchmark = Capture(&*self.benchmark, None);
80        f.debug_struct("RegisteredBenchmark")
81            .field("name", &self.name)
82            .field("benchmark", &benchmark)
83            .finish()
84    }
85}
86
87impl RegisteredBenchmark {
88    pub(crate) fn name(&self) -> &str {
89        &self.name
90    }
91
92    pub(crate) fn benchmark(&self) -> &dyn benchmark::internal::Benchmark {
93        &*self.benchmark
94    }
95}
96
97/// A collection of registered benchmarks.
98pub struct Benchmarks {
99    benchmarks: Vec<RegisteredBenchmark>,
100}
101
102impl Benchmarks {
103    /// Return a new empty registry.
104    pub fn new() -> Self {
105        Self {
106            benchmarks: Vec::new(),
107        }
108    }
109
110    /// Register a new benchmark with the given name.
111    pub fn register<T>(&mut self, name: impl Into<String>, benchmark: T)
112    where
113        T: Benchmark,
114    {
115        self.benchmarks.push(RegisteredBenchmark {
116            name: name.into(),
117            benchmark: Box::new(benchmark::internal::Wrapper::<T, _>::new(
118                benchmark,
119                benchmark::internal::NoRegression,
120            )),
121        });
122    }
123
124    /// Return an iterator over registered benchmark names and their descriptions.
125    pub(crate) fn names(&self) -> impl ExactSizeIterator<Item = (&str, String)> {
126        self.benchmarks.iter().map(|entry| {
127            (
128                entry.name.as_str(),
129                Capture(&*entry.benchmark, None).to_string(),
130            )
131        })
132    }
133
134    /// Return `true` if `job` matches with any registered benchmark. Otherwise, return `false`.
135    pub fn has_match(&self, job: &Any) -> bool {
136        self.find_best_match(job).is_some()
137    }
138
139    /// Attempt to run the best matching benchmark for `job`.
140    ///
141    /// Returns the results of the benchmark if successful.
142    ///
143    /// Errors if a suitable method could not be found or if the invoked benchmark failed.
144    pub fn call(
145        &self,
146        job: &Any,
147        checkpoint: Checkpoint<'_>,
148        output: &mut dyn Output,
149    ) -> anyhow::Result<serde_json::Value> {
150        match self.find_best_match(job) {
151            Some(entry) => entry.benchmark.run(job, checkpoint, output),
152            None => Err(anyhow::Error::msg(
153                "could not find a matching benchmark for the given input",
154            )),
155        }
156    }
157
158    /// Attempt to debug reasons for a missed dispatch, returning at most `max_methods`
159    /// reasons.
160    ///
161    /// Returns `Ok(())` if a match was found.
162    pub fn debug(&self, job: &Any, max_methods: usize) -> Result<(), Vec<Mismatch>> {
163        if self.has_match(job) {
164            return Ok(());
165        }
166
167        // Collect all failures with their scores, sorted by score (best near-misses first).
168        let mut failures: Vec<(&RegisteredBenchmark, FailureScore)> = self
169            .benchmarks
170            .iter()
171            .filter_map(|entry| match entry.benchmark.try_match(job) {
172                Ok(_) => None,
173                Err(score) => Some((entry, score)),
174            })
175            .collect();
176
177        failures.sort_by_key(|(_, score)| *score);
178        failures.truncate(max_methods);
179
180        let mismatches = failures
181            .into_iter()
182            .map(|(entry, _)| {
183                let reason = Capture(&*entry.benchmark, Some(job)).to_string();
184
185                Mismatch {
186                    method: entry.name.clone(),
187                    reason,
188                }
189            })
190            .collect();
191
192        Err(mismatches)
193    }
194
195    /// Find the best matching benchmark for `job` by score.
196    fn find_best_match(&self, job: &Any) -> Option<&RegisteredBenchmark> {
197        self.benchmarks
198            .iter()
199            .filter_map(|entry| {
200                entry
201                    .benchmark
202                    .try_match(job)
203                    .ok()
204                    .map(|score| (entry, score))
205            })
206            .min_by_key(|(_, score)| *score)
207            .map(|(entry, _)| entry)
208    }
209
210    //-------------------//
211    // Regression Checks //
212    //-------------------//
213
214    /// Register a regression-checkable benchmark with the associated name.
215    ///
216    /// Upon registration, the associated [`Regression::Tolerances`] input and the benchmark
217    /// itself will be reachable via [`Check`](crate::app::Check).
218    pub fn register_regression<T>(&mut self, name: impl Into<String>, benchmark: T)
219    where
220        T: Regression,
221    {
222        let registered = benchmark::internal::Wrapper::<T, _>::new(
223            benchmark,
224            benchmark::internal::WithRegression,
225        );
226        self.benchmarks.push(RegisteredBenchmark {
227            name: name.into(),
228            benchmark: Box::new(registered),
229        });
230    }
231
232    /// Return a collection of all tolerance related inputs, keyed by the input tag type
233    /// of the tolerance.
234    pub(crate) fn tolerances(&self) -> HashMap<&'static str, RegisteredTolerance<'_>> {
235        let mut tolerances = HashMap::<&'static str, RegisteredTolerance<'_>>::new();
236        for b in self.benchmarks.iter() {
237            if let Some(regression) = b.benchmark.as_regression() {
238                // If a tolerance input already exists - then simply add this benchmark
239                // to the list of benchmarks associated with the tolerance.
240                //
241                // Otherwise, create a new entry.
242                let t = regression.tolerance();
243                let packaged = RegressionBenchmark {
244                    benchmark: b,
245                    regression,
246                };
247
248                match tolerances.entry(t.tag()) {
249                    Entry::Occupied(occupied) => occupied.into_mut().regressions.push(packaged),
250                    Entry::Vacant(vacant) => {
251                        vacant.insert(RegisteredTolerance {
252                            tolerance: input::Registered(t),
253                            regressions: vec![packaged],
254                        });
255                    }
256                }
257            }
258        }
259
260        tolerances
261    }
262}
263
264impl Default for Benchmarks {
265    fn default() -> Self {
266        Self::new()
267    }
268}
269
/// Document the reason for a method matching failure.
///
/// Produced in bulk by `Benchmarks::debug`. Deriving `Debug` lets callers `unwrap`,
/// `expect` or `{:?}`-print that method's `Result<(), Vec<Mismatch>>`.
#[derive(Debug, Clone)]
pub struct Mismatch {
    /// Name of the benchmark that failed to match.
    method: String,
    /// Description of why the benchmark did not match.
    reason: String,
}
275
276impl Mismatch {
277    /// Return the name of the benchmark that we failed to match.
278    pub fn method(&self) -> &str {
279        &self.method
280    }
281
282    /// Return the reason why this method was not a match.
283    pub fn reason(&self) -> &str {
284        &self.reason
285    }
286}
287
288//----------//
289// Internal //
290//----------//
291
292#[derive(Debug, Clone, Copy)]
293pub(crate) struct RegressionBenchmark<'a> {
294    benchmark: &'a RegisteredBenchmark,
295    regression: &'a dyn benchmark::internal::Regression,
296}
297
298impl RegressionBenchmark<'_> {
299    pub(crate) fn name(&self) -> &str {
300        self.benchmark.name()
301    }
302
303    pub(crate) fn input_tag(&self) -> &'static str {
304        self.regression.input_tag()
305    }
306
307    pub(crate) fn try_match(&self, input: &Any) -> Result<MatchScore, FailureScore> {
308        self.benchmark.benchmark().try_match(input)
309    }
310
311    pub(crate) fn check(
312        &self,
313        tolerance: &Any,
314        input: &Any,
315        before: &serde_json::Value,
316        after: &serde_json::Value,
317    ) -> anyhow::Result<benchmark::internal::CheckedPassFail> {
318        self.regression.check(tolerance, input, before, after)
319    }
320}
321
322#[derive(Debug)]
323pub(crate) struct RegisteredTolerance<'a> {
324    /// The tolerance parser.
325    pub(crate) tolerance: input::Registered<'a>,
326
327    /// A single tolerance input can apply to multiple benchmarks. This field records all
328    /// such benchmarks that are available in the registry that use this tolerance.
329    pub(crate) regressions: Vec<RegressionBenchmark<'a>>,
330}
331
332/// Helper to capture a `Benchmark::description` call into a `String` via `Display`.
333struct Capture<'a>(&'a dyn benchmark::internal::Benchmark, Option<&'a Any>);
334
335impl std::fmt::Display for Capture<'_> {
336    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
337        self.0.description(f, self.1)
338    }
339}
340
341impl std::fmt::Debug for Capture<'_> {
342    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
343        self.0.description(f, self.1)
344    }
345}