diskann_benchmark_runner/
registry.rs1use std::collections::{hash_map::Entry, HashMap};
7
8use thiserror::Error;
9
10use crate::{
11 benchmark::{self, Benchmark, Regression},
12 dispatcher::{FailureScore, MatchScore},
13 input, Any, Checkpoint, Input, Output,
14};
15
16pub struct Inputs {
18 inputs: HashMap<&'static str, Box<dyn input::DynInput>>,
20}
21
22impl Inputs {
23 pub fn new() -> Self {
25 Self {
26 inputs: HashMap::new(),
27 }
28 }
29
30 pub fn get(&self, tag: &str) -> Option<input::Registered<'_>> {
32 self.inputs.get(tag).map(|v| input::Registered(&**v))
33 }
34
35 pub fn register<T>(&mut self) -> anyhow::Result<()>
40 where
41 T: Input + 'static,
42 {
43 let tag = T::tag();
44 match self.inputs.entry(tag) {
45 Entry::Vacant(entry) => {
46 entry.insert(Box::new(crate::input::Wrapper::<T>::new()));
47 Ok(())
48 }
49 Entry::Occupied(_) => {
50 #[derive(Debug, Error)]
51 #[error("An input with the tag \"{}\" already exists", self.0)]
52 struct AlreadyExists(&'static str);
53
54 Err(anyhow::anyhow!(AlreadyExists(tag)))
55 }
56 }
57 }
58
59 pub fn tags(&self) -> impl ExactSizeIterator<Item = &'static str> + use<'_> {
61 self.inputs.keys().copied()
62 }
63}
64
65impl Default for Inputs {
66 fn default() -> Self {
67 Self::new()
68 }
69}
70
71pub(crate) struct RegisteredBenchmark {
73 name: String,
74 benchmark: Box<dyn benchmark::internal::Benchmark>,
75}
76
77impl std::fmt::Debug for RegisteredBenchmark {
78 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
79 let benchmark = Capture(&*self.benchmark, None);
80 f.debug_struct("RegisteredBenchmark")
81 .field("name", &self.name)
82 .field("benchmark", &benchmark)
83 .finish()
84 }
85}
86
87impl RegisteredBenchmark {
88 pub(crate) fn name(&self) -> &str {
89 &self.name
90 }
91
92 pub(crate) fn benchmark(&self) -> &dyn benchmark::internal::Benchmark {
93 &*self.benchmark
94 }
95}
96
97pub struct Benchmarks {
99 benchmarks: Vec<RegisteredBenchmark>,
100}
101
102impl Benchmarks {
103 pub fn new() -> Self {
105 Self {
106 benchmarks: Vec::new(),
107 }
108 }
109
110 pub fn register<T>(&mut self, name: impl Into<String>, benchmark: T)
112 where
113 T: Benchmark,
114 {
115 self.benchmarks.push(RegisteredBenchmark {
116 name: name.into(),
117 benchmark: Box::new(benchmark::internal::Wrapper::<T, _>::new(
118 benchmark,
119 benchmark::internal::NoRegression,
120 )),
121 });
122 }
123
124 pub(crate) fn names(&self) -> impl ExactSizeIterator<Item = (&str, String)> {
126 self.benchmarks.iter().map(|entry| {
127 (
128 entry.name.as_str(),
129 Capture(&*entry.benchmark, None).to_string(),
130 )
131 })
132 }
133
134 pub fn has_match(&self, job: &Any) -> bool {
136 self.find_best_match(job).is_some()
137 }
138
139 pub fn call(
145 &self,
146 job: &Any,
147 checkpoint: Checkpoint<'_>,
148 output: &mut dyn Output,
149 ) -> anyhow::Result<serde_json::Value> {
150 match self.find_best_match(job) {
151 Some(entry) => entry.benchmark.run(job, checkpoint, output),
152 None => Err(anyhow::Error::msg(
153 "could not find a matching benchmark for the given input",
154 )),
155 }
156 }
157
158 pub fn debug(&self, job: &Any, max_methods: usize) -> Result<(), Vec<Mismatch>> {
163 if self.has_match(job) {
164 return Ok(());
165 }
166
167 let mut failures: Vec<(&RegisteredBenchmark, FailureScore)> = self
169 .benchmarks
170 .iter()
171 .filter_map(|entry| match entry.benchmark.try_match(job) {
172 Ok(_) => None,
173 Err(score) => Some((entry, score)),
174 })
175 .collect();
176
177 failures.sort_by_key(|(_, score)| *score);
178 failures.truncate(max_methods);
179
180 let mismatches = failures
181 .into_iter()
182 .map(|(entry, _)| {
183 let reason = Capture(&*entry.benchmark, Some(job)).to_string();
184
185 Mismatch {
186 method: entry.name.clone(),
187 reason,
188 }
189 })
190 .collect();
191
192 Err(mismatches)
193 }
194
195 fn find_best_match(&self, job: &Any) -> Option<&RegisteredBenchmark> {
197 self.benchmarks
198 .iter()
199 .filter_map(|entry| {
200 entry
201 .benchmark
202 .try_match(job)
203 .ok()
204 .map(|score| (entry, score))
205 })
206 .min_by_key(|(_, score)| *score)
207 .map(|(entry, _)| entry)
208 }
209
210 pub fn register_regression<T>(&mut self, name: impl Into<String>, benchmark: T)
219 where
220 T: Regression,
221 {
222 let registered = benchmark::internal::Wrapper::<T, _>::new(
223 benchmark,
224 benchmark::internal::WithRegression,
225 );
226 self.benchmarks.push(RegisteredBenchmark {
227 name: name.into(),
228 benchmark: Box::new(registered),
229 });
230 }
231
232 pub(crate) fn tolerances(&self) -> HashMap<&'static str, RegisteredTolerance<'_>> {
235 let mut tolerances = HashMap::<&'static str, RegisteredTolerance<'_>>::new();
236 for b in self.benchmarks.iter() {
237 if let Some(regression) = b.benchmark.as_regression() {
238 let t = regression.tolerance();
243 let packaged = RegressionBenchmark {
244 benchmark: b,
245 regression,
246 };
247
248 match tolerances.entry(t.tag()) {
249 Entry::Occupied(occupied) => occupied.into_mut().regressions.push(packaged),
250 Entry::Vacant(vacant) => {
251 vacant.insert(RegisteredTolerance {
252 tolerance: input::Registered(t),
253 regressions: vec![packaged],
254 });
255 }
256 }
257 }
258 }
259
260 tolerances
261 }
262}
263
264impl Default for Benchmarks {
265 fn default() -> Self {
266 Self::new()
267 }
268}
269
/// A benchmark that rejected a job, as reported by `Benchmarks::debug`.
///
/// `Debug` is derived so that the `Result<(), Vec<Mismatch>>` returned by
/// `Benchmarks::debug` can be `unwrap`ped, `expect`ed, or printed with `{:?}`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Mismatch {
    /// Name the rejecting benchmark was registered under.
    method: String,
    /// The benchmark's description, rendered for the rejected job.
    reason: String,
}

impl Mismatch {
    /// Name of the benchmark that rejected the job.
    pub fn method(&self) -> &str {
        &self.method
    }

    /// The rejecting benchmark's description of the job, as text.
    pub fn reason(&self) -> &str {
        &self.reason
    }
}
287
288#[derive(Debug, Clone, Copy)]
293pub(crate) struct RegressionBenchmark<'a> {
294 benchmark: &'a RegisteredBenchmark,
295 regression: &'a dyn benchmark::internal::Regression,
296}
297
298impl RegressionBenchmark<'_> {
299 pub(crate) fn name(&self) -> &str {
300 self.benchmark.name()
301 }
302
303 pub(crate) fn input_tag(&self) -> &'static str {
304 self.regression.input_tag()
305 }
306
307 pub(crate) fn try_match(&self, input: &Any) -> Result<MatchScore, FailureScore> {
308 self.benchmark.benchmark().try_match(input)
309 }
310
311 pub(crate) fn check(
312 &self,
313 tolerance: &Any,
314 input: &Any,
315 before: &serde_json::Value,
316 after: &serde_json::Value,
317 ) -> anyhow::Result<benchmark::internal::CheckedPassFail> {
318 self.regression.check(tolerance, input, before, after)
319 }
320}
321
322#[derive(Debug)]
323pub(crate) struct RegisteredTolerance<'a> {
324 pub(crate) tolerance: input::Registered<'a>,
326
327 pub(crate) regressions: Vec<RegressionBenchmark<'a>>,
330}
331
332struct Capture<'a>(&'a dyn benchmark::internal::Benchmark, Option<&'a Any>);
334
335impl std::fmt::Display for Capture<'_> {
336 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
337 self.0.description(f, self.1)
338 }
339}
340
341impl std::fmt::Debug for Capture<'_> {
342 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
343 self.0.description(f, self.1)
344 }
345}