diskann_benchmark_runner/
registry.rs1use std::collections::{hash_map::Entry, HashMap};
7
8use thiserror::Error;
9
10use crate::{
11 benchmark::{self, Benchmark, FailureScore, MatchScore, Regression},
12 input, Checkpoint, Input, Output,
13};
14
15pub(crate) struct RegisteredBenchmark {
17 name: String,
18 benchmark: Box<dyn benchmark::internal::Benchmark>,
19}
20
21impl std::fmt::Debug for RegisteredBenchmark {
22 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
23 let benchmark = Capture(&*self.benchmark, None);
24 f.debug_struct("RegisteredBenchmark")
25 .field("name", &self.name)
26 .field("benchmark", &benchmark)
27 .finish()
28 }
29}
30
31impl RegisteredBenchmark {
32 pub(crate) fn name(&self) -> &str {
33 &self.name
34 }
35
36 pub(crate) fn benchmark(&self) -> &dyn benchmark::internal::Benchmark {
37 &*self.benchmark
38 }
39}
40
41pub struct Registry {
43 inputs: HashMap<&'static str, Box<dyn input::internal::DynInput>>,
45 benchmarks: Vec<RegisteredBenchmark>,
46}
47
48impl Registry {
49 pub fn new() -> Self {
51 Self {
52 inputs: HashMap::new(),
53 benchmarks: Vec::new(),
54 }
55 }
56
57 pub fn input(&self, tag: &str) -> Option<input::Registered<'_>> {
64 self._input(tag).map(input::Registered)
65 }
66
67 pub fn input_tags(&self) -> impl ExactSizeIterator<Item = &'static str> + use<'_> {
69 self.inputs.keys().copied()
70 }
71
72 pub fn register<T>(
78 &mut self,
79 name: impl Into<String>,
80 benchmark: T,
81 ) -> Result<(), RegistryError>
82 where
83 T: Benchmark,
84 {
85 self.register_input::<T::Input>()?;
86
87 self.benchmarks.push(RegisteredBenchmark {
88 name: name.into(),
89 benchmark: Box::new(benchmark::internal::Wrapper::<T, _>::new(
90 benchmark,
91 benchmark::internal::NoRegression,
92 )),
93 });
94 Ok(())
95 }
96
97 pub(crate) fn names(&self) -> impl ExactSizeIterator<Item = (&str, String)> {
99 self.benchmarks.iter().map(|entry| {
100 (
101 entry.name.as_str(),
102 Capture(&*entry.benchmark, None).to_string(),
103 )
104 })
105 }
106
107 pub(crate) fn has_match(&self, job: &input::internal::Any) -> bool {
109 self.find_best_match(job).is_some()
110 }
111
112 pub(crate) fn call(
118 &self,
119 job: &input::internal::Any,
120 checkpoint: Checkpoint<'_>,
121 output: &mut dyn Output,
122 ) -> anyhow::Result<serde_json::Value> {
123 match self.find_best_match(job) {
124 Some(entry) => entry.benchmark.run(job, checkpoint, output),
125 None => Err(anyhow::Error::msg(
126 "could not find a matching benchmark for the given input",
127 )),
128 }
129 }
130
131 pub(crate) fn debug(
136 &self,
137 job: &input::internal::Any,
138 max_methods: usize,
139 ) -> Result<(), Vec<Mismatch>> {
140 if self.has_match(job) {
141 return Ok(());
142 }
143
144 let mut failures: Vec<(&RegisteredBenchmark, FailureScore)> = self
146 .benchmarks
147 .iter()
148 .filter_map(|entry| match entry.benchmark.try_match(job) {
149 Ok(_) => None,
150 Err(score) => Some((entry, score)),
151 })
152 .collect();
153
154 failures.sort_by_key(|(_, score)| *score);
155 failures.truncate(max_methods);
156
157 let mismatches = failures
158 .into_iter()
159 .map(|(entry, _)| {
160 let reason = Capture(&*entry.benchmark, Some(job)).to_string();
161
162 Mismatch {
163 method: entry.name.clone(),
164 reason,
165 }
166 })
167 .collect();
168
169 Err(mismatches)
170 }
171
172 fn find_best_match(&self, job: &input::internal::Any) -> Option<&RegisteredBenchmark> {
174 self.benchmarks
175 .iter()
176 .filter_map(|entry| {
177 entry
178 .benchmark
179 .try_match(job)
180 .ok()
181 .map(|score| (entry, score))
182 })
183 .min_by_key(|(_, score)| *score)
184 .map(|(entry, _)| entry)
185 }
186
187 fn _input(&self, tag: &str) -> Option<&dyn input::internal::DynInput> {
188 self.inputs.get(tag).map(|v| &**v)
189 }
190
191 fn register_input<T>(&mut self) -> Result<(), RegistryError>
192 where
193 T: Input + 'static,
194 {
195 let tag = T::tag();
196 let wrapper = crate::input::internal::Wrapper::<T>::new();
197 match self.inputs.entry(tag) {
198 Entry::Vacant(v) => {
199 v.insert(Box::new(wrapper));
200 Ok(())
201 }
202 Entry::Occupied(o) => {
203 use input::internal::DynInput;
204
205 if o.get().as_any().is::<crate::input::internal::Wrapper<T>>() {
206 Ok(())
207 } else {
208 Err(RegistryError {
209 tag,
210 existing: o.get().type_name(),
211 new: wrapper.type_name(),
212 })
213 }
214 }
215 }
216 }
217
218 pub fn register_regression<T>(
231 &mut self,
232 name: impl Into<String>,
233 benchmark: T,
234 ) -> Result<(), RegistryError>
235 where
236 T: Regression,
237 {
238 self.register_input::<T::Input>()?;
239
240 let registered = benchmark::internal::Wrapper::<T, _>::new(
241 benchmark,
242 benchmark::internal::WithRegression,
243 );
244 self.benchmarks.push(RegisteredBenchmark {
245 name: name.into(),
246 benchmark: Box::new(registered),
247 });
248
249 Ok(())
250 }
251
252 pub(crate) fn tolerances(&self) -> HashMap<&'static str, RegisteredTolerance<'_>> {
255 let mut tolerances = HashMap::<&'static str, RegisteredTolerance<'_>>::new();
256 for b in self.benchmarks.iter() {
257 if let Some(regression) = b.benchmark.as_regression() {
258 let t = regression.tolerance();
263 let packaged = RegressionBenchmark {
264 benchmark: b,
265 regression,
266 };
267
268 match tolerances.entry(t.tag()) {
269 Entry::Occupied(occupied) => occupied.into_mut().regressions.push(packaged),
270 Entry::Vacant(vacant) => {
271 vacant.insert(RegisteredTolerance {
272 tolerance: input::Registered(t),
273 regressions: vec![packaged],
274 });
275 }
276 }
277 }
278 }
279
280 tolerances
281 }
282}
283
284impl Default for Registry {
285 fn default() -> Self {
286 Self::new()
287 }
288}
289
290#[derive(Debug, Error)]
292#[error(
293 "A different input with tag \"{}\" was already registered. Existing type: \"{}\". New type: \"{}\"",
294 self.tag,
295 self.existing,
296 self.new,
297)]
298pub struct RegistryError {
299 tag: &'static str,
300 existing: &'static str,
301 new: &'static str,
302}
303
/// One registered benchmark that rejected an input, plus a human-readable reason.
///
/// Produced by `Registry::debug` when no benchmark accepts a job. Deriving
/// `Debug` lets callers `unwrap`/`expect`/`{:?}` the
/// `Result<(), Vec<Mismatch>>` it returns.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Mismatch {
    /// Name the rejecting benchmark was registered under.
    method: String,
    /// The benchmark's description rendered against the rejected input.
    reason: String,
}
309
310impl Mismatch {
311 pub fn method(&self) -> &str {
313 &self.method
314 }
315
316 pub fn reason(&self) -> &str {
318 &self.reason
319 }
320}
321
322#[derive(Debug, Clone, Copy)]
327pub(crate) struct RegressionBenchmark<'a> {
328 benchmark: &'a RegisteredBenchmark,
329 regression: &'a dyn benchmark::internal::Regression,
330}
331
332impl RegressionBenchmark<'_> {
333 pub(crate) fn name(&self) -> &str {
334 self.benchmark.name()
335 }
336
337 pub(crate) fn input_tag(&self) -> &'static str {
338 self.regression.input_tag()
339 }
340
341 pub(crate) fn try_match(
342 &self,
343 input: &input::internal::Any,
344 ) -> Result<MatchScore, FailureScore> {
345 self.benchmark.benchmark().try_match(input)
346 }
347
348 pub(crate) fn check(
349 &self,
350 tolerance: &input::internal::Any,
351 input: &input::internal::Any,
352 before: &serde_json::Value,
353 after: &serde_json::Value,
354 ) -> anyhow::Result<benchmark::internal::CheckedPassFail> {
355 self.regression.check(tolerance, input, before, after)
356 }
357}
358
359#[derive(Debug, Clone)]
360pub(crate) struct RegisteredTolerance<'a> {
361 pub(crate) tolerance: input::Registered<'a>,
363
364 pub(crate) regressions: Vec<RegressionBenchmark<'a>>,
367}
368
369struct Capture<'a>(
371 &'a dyn benchmark::internal::Benchmark,
372 Option<&'a input::internal::Any>,
373);
374
375impl std::fmt::Display for Capture<'_> {
376 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
377 self.0.description(f, self.1)
378 }
379}
380
381impl std::fmt::Debug for Capture<'_> {
382 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
383 self.0.description(f, self.1)
384 }
385}
386
#[cfg(test)]
mod tests {
    use super::*;

    use crate::{input, Checker};

    /// Declare a unit struct implementing `Input` with the given tag. Only
    /// `tag()` is meaningful; every other method panics if called.
    macro_rules! input {
        ($T:ident, $tag:literal) => {
            #[derive(Debug)]
            struct $T;

            impl Input for $T {
                type Raw = ();
                fn tag() -> &'static str {
                    $tag
                }
                fn from_raw(_raw: Self::Raw, _checker: &mut Checker) -> anyhow::Result<$T> {
                    unimplemented!("this struct is for test only");
                }
                fn serialize(&self) -> anyhow::Result<serde_json::Value> {
                    unimplemented!("this struct is for test only");
                }
                fn example() -> Self::Raw {
                    unimplemented!("this struct is for test only");
                }
            }
        };
    }

    input!(A, "type-a");
    input!(B, "type-b");
    // Same tag as `A` but a distinct type: registering it must be rejected.
    input!(A2, "type-a");

    #[test]
    fn test_tag_conflicts() {
        let mut registry = Registry::new();
        registry.register_input::<A>().unwrap();
        registry.register_input::<B>().unwrap();

        let mut tags: Vec<_> = registry.input_tags().collect();
        tags.sort();
        assert_eq!(tags.as_slice(), ["type-a", "type-b"]);

        {
            let a = registry._input(A::tag()).unwrap();
            assert!(a.as_any().is::<input::internal::Wrapper<A>>());

            let name = a.type_name();
            assert!(name.contains("A"), "{}", name);
        }

        {
            let b = registry._input(B::tag()).unwrap();
            assert!(b.as_any().is::<input::internal::Wrapper<B>>());

            let name = b.type_name();
            assert!(name.contains("B"), "{}", name);
        }

        let err = registry.register_input::<A2>().unwrap_err();
        let msg = err.to_string();
        assert!(
            msg.contains("A different input with tag \"type-a\" was already registered"),
            "FAILED: {}",
            msg
        );

        // A rejected registration must leave the original entry in place.
        assert_eq!(registry.input_tags().len(), 2);
        let a = registry._input(A::tag()).unwrap();
        assert!(a.as_any().is::<input::internal::Wrapper<A>>());
    }

    #[test]
    fn test_reregister_same_input_is_noop() {
        let mut registry = Registry::new();
        registry.register_input::<A>().unwrap();
        // Same concrete type under the same tag: accepted, nothing duplicated.
        registry.register_input::<A>().unwrap();

        let tags: Vec<_> = registry.input_tags().collect();
        assert_eq!(tags.as_slice(), ["type-a"]);

        let a = registry._input(A::tag()).unwrap();
        assert!(a.as_any().is::<input::internal::Wrapper<A>>());
    }

    #[test]
    fn test_unknown_input_tag() {
        let mut registry = Registry::new();
        assert!(registry.input("type-a").is_none());

        registry.register_input::<A>().unwrap();
        assert!(registry.input("type-a").is_some());
        assert!(registry.input("type-missing").is_none());
    }
}