Skip to main content

diskann_benchmark_runner/
registry.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6use std::collections::{hash_map::Entry, HashMap};
7
8use thiserror::Error;
9
10use crate::{
11    benchmark::{self, Benchmark, Regression},
12    dispatcher::{FailureScore, MatchScore},
13    input, Any, Checkpoint, Input, Output,
14};
15
16/// A collection of [`crate::Input`].
17pub struct Inputs {
18    // Inputs keyed by their tag type.
19    inputs: HashMap<&'static str, Box<dyn input::DynInput>>,
20}
21
22impl Inputs {
23    /// Construct a new empty [`Inputs`] registry.
24    pub fn new() -> Self {
25        Self {
26            inputs: HashMap::new(),
27        }
28    }
29
30    /// Return the input with the registered `tag` if present. Otherwise, return `None`.
31    pub fn get(&self, tag: &str) -> Option<input::Registered<'_>> {
32        self.inputs.get(tag).map(|v| input::Registered(&**v))
33    }
34
35    /// Register the [`Input`] `T` in the registry.
36    ///
37    /// Returns an error if any other input with the same [`Input::tag()`] has been registered
38    /// while leaving the underlying registry unchanged.
39    pub fn register<T>(&mut self) -> anyhow::Result<()>
40    where
41        T: Input + 'static,
42    {
43        let tag = T::tag();
44        match self.inputs.entry(tag) {
45            Entry::Vacant(entry) => {
46                entry.insert(Box::new(crate::input::Wrapper::<T>::new()));
47                Ok(())
48            }
49            Entry::Occupied(_) => {
50                #[derive(Debug, Error)]
51                #[error("An input with the tag \"{}\" already exists", self.0)]
52                struct AlreadyExists(&'static str);
53
54                Err(anyhow::anyhow!(AlreadyExists(tag)))
55            }
56        }
57    }
58
59    /// Return an iterator over all registered input tags in an unspecified order.
60    pub fn tags(&self) -> impl ExactSizeIterator<Item = &'static str> + use<'_> {
61        self.inputs.keys().copied()
62    }
63}
64
65impl Default for Inputs {
66    fn default() -> Self {
67        Self::new()
68    }
69}
70
71/// A registered benchmark entry: a name paired with a type-erased benchmark.
72pub(crate) struct RegisteredBenchmark {
73    name: String,
74    benchmark: Box<dyn benchmark::internal::Benchmark>,
75}
76
77impl std::fmt::Debug for RegisteredBenchmark {
78    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
79        let benchmark = Capture(&*self.benchmark, None);
80        f.debug_struct("RegisteredBenchmark")
81            .field("name", &self.name)
82            .field("benchmark", &benchmark)
83            .finish()
84    }
85}
86
87impl RegisteredBenchmark {
88    pub(crate) fn name(&self) -> &str {
89        &self.name
90    }
91
92    pub(crate) fn benchmark(&self) -> &dyn benchmark::internal::Benchmark {
93        &*self.benchmark
94    }
95}
96
97/// A collection of registered benchmarks.
98pub struct Benchmarks {
99    benchmarks: Vec<RegisteredBenchmark>,
100}
101
102impl Benchmarks {
103    /// Return a new empty registry.
104    pub fn new() -> Self {
105        Self {
106            benchmarks: Vec::new(),
107        }
108    }
109
110    /// Register a new benchmark with the given name.
111    pub fn register<T>(&mut self, name: impl Into<String>)
112    where
113        T: Benchmark + 'static,
114    {
115        self.benchmarks.push(RegisteredBenchmark {
116            name: name.into(),
117            benchmark: Box::new(benchmark::internal::Wrapper::<T>::new()),
118        });
119    }
120
121    /// Return an iterator over registered benchmark names and their descriptions.
122    pub(crate) fn names(&self) -> impl ExactSizeIterator<Item = (&str, String)> {
123        self.benchmarks.iter().map(|entry| {
124            (
125                entry.name.as_str(),
126                Capture(&*entry.benchmark, None).to_string(),
127            )
128        })
129    }
130
131    /// Return `true` if `job` matches with any registered benchmark. Otherwise, return `false`.
132    pub fn has_match(&self, job: &Any) -> bool {
133        self.find_best_match(job).is_some()
134    }
135
136    /// Attempt to run the best matching benchmark for `job`.
137    ///
138    /// Returns the results of the benchmark if successful.
139    ///
140    /// Errors if a suitable method could not be found or if the invoked benchmark failed.
141    pub fn call(
142        &self,
143        job: &Any,
144        checkpoint: Checkpoint<'_>,
145        output: &mut dyn Output,
146    ) -> anyhow::Result<serde_json::Value> {
147        match self.find_best_match(job) {
148            Some(entry) => entry.benchmark.run(job, checkpoint, output),
149            None => Err(anyhow::Error::msg(
150                "could not find a matching benchmark for the given input",
151            )),
152        }
153    }
154
155    /// Attempt to debug reasons for a missed dispatch, returning at most `max_methods`
156    /// reasons.
157    ///
158    /// Returns `Ok(())` if a match was found.
159    pub fn debug(&self, job: &Any, max_methods: usize) -> Result<(), Vec<Mismatch>> {
160        if self.has_match(job) {
161            return Ok(());
162        }
163
164        // Collect all failures with their scores, sorted by score (best near-misses first).
165        let mut failures: Vec<(&RegisteredBenchmark, FailureScore)> = self
166            .benchmarks
167            .iter()
168            .filter_map(|entry| match entry.benchmark.try_match(job) {
169                Ok(_) => None,
170                Err(score) => Some((entry, score)),
171            })
172            .collect();
173
174        failures.sort_by_key(|(_, score)| *score);
175        failures.truncate(max_methods);
176
177        let mismatches = failures
178            .into_iter()
179            .map(|(entry, _)| {
180                let reason = Capture(&*entry.benchmark, Some(job)).to_string();
181
182                Mismatch {
183                    method: entry.name.clone(),
184                    reason,
185                }
186            })
187            .collect();
188
189        Err(mismatches)
190    }
191
192    /// Find the best matching benchmark for `job` by score.
193    fn find_best_match(&self, job: &Any) -> Option<&RegisteredBenchmark> {
194        self.benchmarks
195            .iter()
196            .filter_map(|entry| {
197                entry
198                    .benchmark
199                    .try_match(job)
200                    .ok()
201                    .map(|score| (entry, score))
202            })
203            .min_by_key(|(_, score)| *score)
204            .map(|(entry, _)| entry)
205    }
206
207    //-------------------//
208    // Regression Checks //
209    //-------------------//
210
211    /// Register a regression-checkable benchmark with the associated name.
212    ///
213    /// Upon registration, the associated [`Regression::Tolerances`] input and the benchmark
214    /// itself will be reachable via [`Check`](crate::app::Check).
215    pub fn register_regression<T>(&mut self, name: impl Into<String>)
216    where
217        T: Regression + 'static,
218    {
219        let registered = benchmark::internal::Wrapper::<T, _>::new_with(
220            benchmark::internal::WithRegression::<T>::new(),
221        );
222        self.benchmarks.push(RegisteredBenchmark {
223            name: name.into(),
224            benchmark: Box::new(registered),
225        });
226    }
227
228    /// Return a collection of all tolerance related inputs, keyed by the input tag type
229    /// of the tolerance.
230    pub(crate) fn tolerances(&self) -> HashMap<&'static str, RegisteredTolerance<'_>> {
231        let mut tolerances = HashMap::<&'static str, RegisteredTolerance<'_>>::new();
232        for b in self.benchmarks.iter() {
233            if let Some(regression) = b.benchmark.as_regression() {
234                // If a tolerance input already exists - then simply add this benchmark
235                // to the list of benchmarks associated with the tolerance.
236                //
237                // Otherwise, create a new entry.
238                let t = regression.tolerance();
239                let packaged = RegressionBenchmark {
240                    benchmark: b,
241                    regression,
242                };
243
244                match tolerances.entry(t.tag()) {
245                    Entry::Occupied(occupied) => occupied.into_mut().regressions.push(packaged),
246                    Entry::Vacant(vacant) => {
247                        vacant.insert(RegisteredTolerance {
248                            tolerance: input::Registered(t),
249                            regressions: vec![packaged],
250                        });
251                    }
252                }
253            }
254        }
255
256        tolerances
257    }
258}
259
260impl Default for Benchmarks {
261    fn default() -> Self {
262        Self::new()
263    }
264}
265
/// Document the reason for a method matching failure.
///
/// Produced by [`Benchmarks::debug`] for each benchmark that rejected a job. The derived
/// `Debug` lets callers `unwrap`/`expect` or print the `Result<(), Vec<Mismatch>>`
/// returned there.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Mismatch {
    /// Registered name of the benchmark that did not match.
    method: String,
    /// Explanation rendered from the benchmark's description for the rejected job.
    reason: String,
}
271
272impl Mismatch {
273    /// Return the name of the benchmark that we failed to match.
274    pub fn method(&self) -> &str {
275        &self.method
276    }
277
278    /// Return the reason why this method was not a match.
279    pub fn reason(&self) -> &str {
280        &self.reason
281    }
282}
283
284//----------//
285// Internal //
286//----------//
287
288#[derive(Debug, Clone, Copy)]
289pub(crate) struct RegressionBenchmark<'a> {
290    benchmark: &'a RegisteredBenchmark,
291    regression: &'a dyn benchmark::internal::Regression,
292}
293
294impl RegressionBenchmark<'_> {
295    pub(crate) fn name(&self) -> &str {
296        self.benchmark.name()
297    }
298
299    pub(crate) fn input_tag(&self) -> &'static str {
300        self.regression.input_tag()
301    }
302
303    pub(crate) fn try_match(&self, input: &Any) -> Result<MatchScore, FailureScore> {
304        self.benchmark.benchmark().try_match(input)
305    }
306
307    pub(crate) fn check(
308        &self,
309        tolerance: &Any,
310        input: &Any,
311        before: &serde_json::Value,
312        after: &serde_json::Value,
313    ) -> anyhow::Result<benchmark::internal::CheckedPassFail> {
314        self.regression.check(tolerance, input, before, after)
315    }
316}
317
318#[derive(Debug)]
319pub(crate) struct RegisteredTolerance<'a> {
320    /// The tolerance parser.
321    pub(crate) tolerance: input::Registered<'a>,
322
323    /// A single tolerance input can apply to multiple benchmarks. This field records all
324    /// such benchmarks that are available in the registry that use this tolerance.
325    pub(crate) regressions: Vec<RegressionBenchmark<'a>>,
326}
327
328/// Helper to capture a `Benchmark::description` call into a `String` via `Display`.
329struct Capture<'a>(&'a dyn benchmark::internal::Benchmark, Option<&'a Any>);
330
331impl std::fmt::Display for Capture<'_> {
332    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
333        self.0.description(f, self.1)
334    }
335}
336
337impl std::fmt::Debug for Capture<'_> {
338    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
339        self.0.description(f, self.1)
340    }
341}