Skip to main content

diskann_benchmark_runner/
registry.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6use std::collections::{hash_map::Entry, HashMap};
7
8use thiserror::Error;
9
10use crate::{
11    benchmark::{self, Benchmark, FailureScore, MatchScore, Regression},
12    input, Checkpoint, Input, Output,
13};
14
15/// A registered benchmark entry: a name paired with a type-erased benchmark.
16pub(crate) struct RegisteredBenchmark {
17    name: String,
18    benchmark: Box<dyn benchmark::internal::Benchmark>,
19}
20
21impl std::fmt::Debug for RegisteredBenchmark {
22    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
23        let benchmark = Capture(&*self.benchmark, None);
24        f.debug_struct("RegisteredBenchmark")
25            .field("name", &self.name)
26            .field("benchmark", &benchmark)
27            .finish()
28    }
29}
30
31impl RegisteredBenchmark {
32    pub(crate) fn name(&self) -> &str {
33        &self.name
34    }
35
36    pub(crate) fn benchmark(&self) -> &dyn benchmark::internal::Benchmark {
37        &*self.benchmark
38    }
39}
40
41/// A collection of registered inputs and benchmarks.
42pub struct Registry {
43    // Inputs keyed by their tag type.
44    inputs: HashMap<&'static str, Box<dyn input::internal::DynInput>>,
45    benchmarks: Vec<RegisteredBenchmark>,
46}
47
48impl Registry {
49    /// Return a new empty registry.
50    pub fn new() -> Self {
51        Self {
52            inputs: HashMap::new(),
53            benchmarks: Vec::new(),
54        }
55    }
56
57    /// Return the input with the registered `tag` if present. Otherwise, return `None`.
58    ///
59    /// Inputs are automatically registered as a side-effect of:
60    ///
61    /// * [`register`](Self::register)
62    /// * [`register_regression`](Self::register_regression)
63    pub fn input(&self, tag: &str) -> Option<input::Registered<'_>> {
64        self._input(tag).map(input::Registered)
65    }
66
67    /// Return an iterator over all registered input tags in an unspecified order.
68    pub fn input_tags(&self) -> impl ExactSizeIterator<Item = &'static str> + use<'_> {
69        self.inputs.keys().copied()
70    }
71
72    /// Register a new `benchmark` with the given `name`.
73    ///
74    /// As a side-effect, the benchmark's [`Input`](Benchmark::Input) type is also registered.
75    /// Duplicate registrations of the same tag and type are allowed; mismatched types for the
76    /// same tag return an error.
77    pub fn register<T>(
78        &mut self,
79        name: impl Into<String>,
80        benchmark: T,
81    ) -> Result<(), RegistryError>
82    where
83        T: Benchmark,
84    {
85        self.register_input::<T::Input>()?;
86
87        self.benchmarks.push(RegisteredBenchmark {
88            name: name.into(),
89            benchmark: Box::new(benchmark::internal::Wrapper::<T, _>::new(
90                benchmark,
91                benchmark::internal::NoRegression,
92            )),
93        });
94        Ok(())
95    }
96
97    /// Return an iterator over registered benchmark names and their descriptions.
98    pub(crate) fn names(&self) -> impl ExactSizeIterator<Item = (&str, String)> {
99        self.benchmarks.iter().map(|entry| {
100            (
101                entry.name.as_str(),
102                Capture(&*entry.benchmark, None).to_string(),
103            )
104        })
105    }
106
107    /// Return `true` if `job` matches with any registered benchmark. Otherwise, return `false`.
108    pub(crate) fn has_match(&self, job: &input::internal::Any) -> bool {
109        self.find_best_match(job).is_some()
110    }
111
112    /// Attempt to run the best matching benchmark for `job`.
113    ///
114    /// Returns the results of the benchmark if successful.
115    ///
116    /// Errors if a suitable method could not be found or if the invoked benchmark failed.
117    pub(crate) fn call(
118        &self,
119        job: &input::internal::Any,
120        checkpoint: Checkpoint<'_>,
121        output: &mut dyn Output,
122    ) -> anyhow::Result<serde_json::Value> {
123        match self.find_best_match(job) {
124            Some(entry) => entry.benchmark.run(job, checkpoint, output),
125            None => Err(anyhow::Error::msg(
126                "could not find a matching benchmark for the given input",
127            )),
128        }
129    }
130
131    /// Attempt to debug reasons for a missed dispatch, returning at most `max_methods`
132    /// reasons.
133    ///
134    /// Returns `Ok(())` if a match was found.
135    pub(crate) fn debug(
136        &self,
137        job: &input::internal::Any,
138        max_methods: usize,
139    ) -> Result<(), Vec<Mismatch>> {
140        if self.has_match(job) {
141            return Ok(());
142        }
143
144        // Collect all failures with their scores, sorted by score (best near-misses first).
145        let mut failures: Vec<(&RegisteredBenchmark, FailureScore)> = self
146            .benchmarks
147            .iter()
148            .filter_map(|entry| match entry.benchmark.try_match(job) {
149                Ok(_) => None,
150                Err(score) => Some((entry, score)),
151            })
152            .collect();
153
154        failures.sort_by_key(|(_, score)| *score);
155        failures.truncate(max_methods);
156
157        let mismatches = failures
158            .into_iter()
159            .map(|(entry, _)| {
160                let reason = Capture(&*entry.benchmark, Some(job)).to_string();
161
162                Mismatch {
163                    method: entry.name.clone(),
164                    reason,
165                }
166            })
167            .collect();
168
169        Err(mismatches)
170    }
171
172    /// Find the best matching benchmark for `job` by score.
173    fn find_best_match(&self, job: &input::internal::Any) -> Option<&RegisteredBenchmark> {
174        self.benchmarks
175            .iter()
176            .filter_map(|entry| {
177                entry
178                    .benchmark
179                    .try_match(job)
180                    .ok()
181                    .map(|score| (entry, score))
182            })
183            .min_by_key(|(_, score)| *score)
184            .map(|(entry, _)| entry)
185    }
186
187    fn _input(&self, tag: &str) -> Option<&dyn input::internal::DynInput> {
188        self.inputs.get(tag).map(|v| &**v)
189    }
190
191    fn register_input<T>(&mut self) -> Result<(), RegistryError>
192    where
193        T: Input + 'static,
194    {
195        let tag = T::tag();
196        let wrapper = crate::input::internal::Wrapper::<T>::new();
197        match self.inputs.entry(tag) {
198            Entry::Vacant(v) => {
199                v.insert(Box::new(wrapper));
200                Ok(())
201            }
202            Entry::Occupied(o) => {
203                use input::internal::DynInput;
204
205                if o.get().as_any().is::<crate::input::internal::Wrapper<T>>() {
206                    Ok(())
207                } else {
208                    Err(RegistryError {
209                        tag,
210                        existing: o.get().type_name(),
211                        new: wrapper.type_name(),
212                    })
213                }
214            }
215        }
216    }
217
218    //-------------------//
219    // Regression Checks //
220    //-------------------//
221
222    /// Register a regression-checkable `benchmark` with the given `name`.
223    ///
224    /// As a side-effect, the benchmark's [`Input`](Benchmark::Input) type is also registered.
225    /// Duplicate registrations of the same tag and type are allowed; mismatched types for the
226    /// same tag return an error.
227    ///
228    /// Upon registration, the associated [`Regression::Tolerances`] input and the benchmark
229    /// itself will be reachable via [`Check`](crate::app::Check).
230    pub fn register_regression<T>(
231        &mut self,
232        name: impl Into<String>,
233        benchmark: T,
234    ) -> Result<(), RegistryError>
235    where
236        T: Regression,
237    {
238        self.register_input::<T::Input>()?;
239
240        let registered = benchmark::internal::Wrapper::<T, _>::new(
241            benchmark,
242            benchmark::internal::WithRegression,
243        );
244        self.benchmarks.push(RegisteredBenchmark {
245            name: name.into(),
246            benchmark: Box::new(registered),
247        });
248
249        Ok(())
250    }
251
252    /// Return a collection of all tolerance related inputs, keyed by the input tag type
253    /// of the tolerance.
254    pub(crate) fn tolerances(&self) -> HashMap<&'static str, RegisteredTolerance<'_>> {
255        let mut tolerances = HashMap::<&'static str, RegisteredTolerance<'_>>::new();
256        for b in self.benchmarks.iter() {
257            if let Some(regression) = b.benchmark.as_regression() {
258                // If a tolerance input already exists - then simply add this benchmark
259                // to the list of benchmarks associated with the tolerance.
260                //
261                // Otherwise, create a new entry.
262                let t = regression.tolerance();
263                let packaged = RegressionBenchmark {
264                    benchmark: b,
265                    regression,
266                };
267
268                match tolerances.entry(t.tag()) {
269                    Entry::Occupied(occupied) => occupied.into_mut().regressions.push(packaged),
270                    Entry::Vacant(vacant) => {
271                        vacant.insert(RegisteredTolerance {
272                            tolerance: input::Registered(t),
273                            regressions: vec![packaged],
274                        });
275                    }
276                }
277            }
278        }
279
280        tolerances
281    }
282}
283
284impl Default for Registry {
285    fn default() -> Self {
286        Self::new()
287    }
288}
289
290/// Error for [`Registry::register`] or [`Registry::register_regression`].
291#[derive(Debug, Error)]
292#[error(
293    "A different input with tag \"{}\" was already registered. Existing type: \"{}\". New type: \"{}\"",
294    self.tag,
295    self.existing,
296    self.new,
297)]
298pub struct RegistryError {
299    tag: &'static str,
300    existing: &'static str,
301    new: &'static str,
302}
303
/// Records why a single benchmark failed to match an input.
///
/// Derives `Debug` so that a `Vec<Mismatch>` carried in a `Result` error can be unwrapped
/// and printed with `{:?}`, and `Clone` so that reports can be duplicated.
#[derive(Debug, Clone)]
pub struct Mismatch {
    /// Name of the benchmark that did not match.
    method: String,
    /// Description of why it did not match.
    reason: String,
}

impl Mismatch {
    /// Return the name of the benchmark that we failed to match.
    pub fn method(&self) -> &str {
        &self.method
    }

    /// Return the reason why this method was not a match.
    pub fn reason(&self) -> &str {
        &self.reason
    }
}
321
322//----------//
323// Internal //
324//----------//
325
326#[derive(Debug, Clone, Copy)]
327pub(crate) struct RegressionBenchmark<'a> {
328    benchmark: &'a RegisteredBenchmark,
329    regression: &'a dyn benchmark::internal::Regression,
330}
331
332impl RegressionBenchmark<'_> {
333    pub(crate) fn name(&self) -> &str {
334        self.benchmark.name()
335    }
336
337    pub(crate) fn input_tag(&self) -> &'static str {
338        self.regression.input_tag()
339    }
340
341    pub(crate) fn try_match(
342        &self,
343        input: &input::internal::Any,
344    ) -> Result<MatchScore, FailureScore> {
345        self.benchmark.benchmark().try_match(input)
346    }
347
348    pub(crate) fn check(
349        &self,
350        tolerance: &input::internal::Any,
351        input: &input::internal::Any,
352        before: &serde_json::Value,
353        after: &serde_json::Value,
354    ) -> anyhow::Result<benchmark::internal::CheckedPassFail> {
355        self.regression.check(tolerance, input, before, after)
356    }
357}
358
359#[derive(Debug, Clone)]
360pub(crate) struct RegisteredTolerance<'a> {
361    /// The tolerance parser.
362    pub(crate) tolerance: input::Registered<'a>,
363
364    /// A single tolerance input can apply to multiple benchmarks. This field records all
365    /// such benchmarks that are available in the registry that use this tolerance.
366    pub(crate) regressions: Vec<RegressionBenchmark<'a>>,
367}
368
369/// Helper to capture a `Benchmark::description` call into a `String` via `Display`.
370struct Capture<'a>(
371    &'a dyn benchmark::internal::Benchmark,
372    Option<&'a input::internal::Any>,
373);
374
375impl std::fmt::Display for Capture<'_> {
376    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
377        self.0.description(f, self.1)
378    }
379}
380
381impl std::fmt::Debug for Capture<'_> {
382    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
383        self.0.description(f, self.1)
384    }
385}
386
387///////////
388// Tests //
389///////////
390
#[cfg(test)]
mod tests {
    use super::*;

    use crate::{input, Checker};

    /// Declare a unit struct `$T` that implements [`Input`] with the tag `$tag`.
    ///
    /// Only `tag` is usable; every other method is left unimplemented because these
    /// types exist solely to exercise registration.
    macro_rules! input {
        ($T:ident, $tag:literal) => {
            #[derive(Debug)]
            struct $T;

            impl Input for $T {
                type Raw = ();
                fn tag() -> &'static str {
                    $tag
                }
                fn from_raw(_raw: Self::Raw, _checker: &mut Checker) -> anyhow::Result<$T> {
                    unimplemented!("this struct is for test only");
                }
                fn serialize(&self) -> anyhow::Result<serde_json::Value> {
                    unimplemented!("this struct is for test only");
                }
                fn example() -> Self::Raw {
                    unimplemented!("this struct is for test only");
                }
            }
        };
    }

    // `A` and `B` use distinct tags, while `A2` deliberately reuses the tag of `A`.
    input!(A, "type-a");
    input!(B, "type-b");
    input!(A2, "type-a");

    /// Assert that the input stored under `T`'s tag is a wrapper for `T` and that its
    /// reported type name contains `fragment`.
    fn assert_registered_as<T>(registry: &Registry, fragment: &str)
    where
        T: Input + 'static,
    {
        let found = registry._input(T::tag()).unwrap();
        assert!(found.as_any().is::<input::internal::Wrapper<T>>());

        let name = found.type_name();
        assert!(name.contains(fragment), "{}", name);
    }

    #[test]
    fn test_tag_conflicts() {
        let mut registry = Registry::new();
        registry.register_input::<A>().unwrap();
        registry.register_input::<B>().unwrap();

        // Tag iteration order is unspecified, so sort before comparing.
        let mut tags: Vec<_> = registry.input_tags().collect();
        tags.sort();
        assert_eq!(tags.as_slice(), ["type-a", "type-b"]);

        assert_registered_as::<A>(&registry, "A");
        assert_registered_as::<B>(&registry, "B");

        // A different type claiming an existing tag must be rejected.
        let err = registry.register_input::<A2>().unwrap_err();
        let msg = err.to_string();
        assert!(
            msg.contains("A different input with tag \"type-a\" was already registered"),
            "FAILED: {}",
            msg
        );
    }
}
459}