diskann_benchmark_runner/
registry.rs1use std::collections::{hash_map::Entry, HashMap};
7
8use thiserror::Error;
9
10use crate::{
11 benchmark::{self, Benchmark, Regression},
12 dispatcher::{FailureScore, MatchScore},
13 input, Any, Checkpoint, Input, Output,
14};
15
/// Registry of benchmark input types, keyed by the tag each type reports via `Input::tag`.
pub struct Inputs {
    /// Type-erased wrapper for each registered input type, keyed by its tag.
    /// `Inputs::register` refuses to insert a second entry for an existing tag.
    inputs: HashMap<&'static str, Box<dyn input::DynInput>>,
}
21
22impl Inputs {
23 pub fn new() -> Self {
25 Self {
26 inputs: HashMap::new(),
27 }
28 }
29
30 pub fn get(&self, tag: &str) -> Option<input::Registered<'_>> {
32 self.inputs.get(tag).map(|v| input::Registered(&**v))
33 }
34
35 pub fn register<T>(&mut self) -> anyhow::Result<()>
40 where
41 T: Input + 'static,
42 {
43 let tag = T::tag();
44 match self.inputs.entry(tag) {
45 Entry::Vacant(entry) => {
46 entry.insert(Box::new(crate::input::Wrapper::<T>::new()));
47 Ok(())
48 }
49 Entry::Occupied(_) => {
50 #[derive(Debug, Error)]
51 #[error("An input with the tag \"{}\" already exists", self.0)]
52 struct AlreadyExists(&'static str);
53
54 Err(anyhow::anyhow!(AlreadyExists(tag)))
55 }
56 }
57 }
58
59 pub fn tags(&self) -> impl ExactSizeIterator<Item = &'static str> + use<'_> {
61 self.inputs.keys().copied()
62 }
63}
64
65impl Default for Inputs {
66 fn default() -> Self {
67 Self::new()
68 }
69}
70
/// A benchmark held by `Benchmarks`, paired with the name it was registered under.
pub(crate) struct RegisteredBenchmark {
    /// Name supplied by the caller at registration time.
    name: String,
    /// Type-erased benchmark implementation.
    benchmark: Box<dyn benchmark::internal::Benchmark>,
}
76
77impl std::fmt::Debug for RegisteredBenchmark {
78 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
79 let benchmark = Capture(&*self.benchmark, None);
80 f.debug_struct("RegisteredBenchmark")
81 .field("name", &self.name)
82 .field("benchmark", &benchmark)
83 .finish()
84 }
85}
86
87impl RegisteredBenchmark {
88 pub(crate) fn name(&self) -> &str {
89 &self.name
90 }
91
92 pub(crate) fn benchmark(&self) -> &dyn benchmark::internal::Benchmark {
93 &*self.benchmark
94 }
95}
96
/// Collection of registered benchmarks that jobs are dispatched against.
///
/// When several benchmarks accept a job, the lowest `MatchScore` wins. Among equal scores, the
/// earliest registered benchmark wins.
pub struct Benchmarks {
    /// Registered benchmarks, in registration order.
    benchmarks: Vec<RegisteredBenchmark>,
}
101
102impl Benchmarks {
103 pub fn new() -> Self {
105 Self {
106 benchmarks: Vec::new(),
107 }
108 }
109
110 pub fn register<T>(&mut self, name: impl Into<String>)
112 where
113 T: Benchmark + 'static,
114 {
115 self.benchmarks.push(RegisteredBenchmark {
116 name: name.into(),
117 benchmark: Box::new(benchmark::internal::Wrapper::<T>::new()),
118 });
119 }
120
121 pub(crate) fn names(&self) -> impl ExactSizeIterator<Item = (&str, String)> {
123 self.benchmarks.iter().map(|entry| {
124 (
125 entry.name.as_str(),
126 Capture(&*entry.benchmark, None).to_string(),
127 )
128 })
129 }
130
131 pub fn has_match(&self, job: &Any) -> bool {
133 self.find_best_match(job).is_some()
134 }
135
136 pub fn call(
142 &self,
143 job: &Any,
144 checkpoint: Checkpoint<'_>,
145 output: &mut dyn Output,
146 ) -> anyhow::Result<serde_json::Value> {
147 match self.find_best_match(job) {
148 Some(entry) => entry.benchmark.run(job, checkpoint, output),
149 None => Err(anyhow::Error::msg(
150 "could not find a matching benchmark for the given input",
151 )),
152 }
153 }
154
155 pub fn debug(&self, job: &Any, max_methods: usize) -> Result<(), Vec<Mismatch>> {
160 if self.has_match(job) {
161 return Ok(());
162 }
163
164 let mut failures: Vec<(&RegisteredBenchmark, FailureScore)> = self
166 .benchmarks
167 .iter()
168 .filter_map(|entry| match entry.benchmark.try_match(job) {
169 Ok(_) => None,
170 Err(score) => Some((entry, score)),
171 })
172 .collect();
173
174 failures.sort_by_key(|(_, score)| *score);
175 failures.truncate(max_methods);
176
177 let mismatches = failures
178 .into_iter()
179 .map(|(entry, _)| {
180 let reason = Capture(&*entry.benchmark, Some(job)).to_string();
181
182 Mismatch {
183 method: entry.name.clone(),
184 reason,
185 }
186 })
187 .collect();
188
189 Err(mismatches)
190 }
191
192 fn find_best_match(&self, job: &Any) -> Option<&RegisteredBenchmark> {
194 self.benchmarks
195 .iter()
196 .filter_map(|entry| {
197 entry
198 .benchmark
199 .try_match(job)
200 .ok()
201 .map(|score| (entry, score))
202 })
203 .min_by_key(|(_, score)| *score)
204 .map(|(entry, _)| entry)
205 }
206
207 pub fn register_regression<T>(&mut self, name: impl Into<String>)
216 where
217 T: Regression + 'static,
218 {
219 let registered = benchmark::internal::Wrapper::<T, _>::new_with(
220 benchmark::internal::WithRegression::<T>::new(),
221 );
222 self.benchmarks.push(RegisteredBenchmark {
223 name: name.into(),
224 benchmark: Box::new(registered),
225 });
226 }
227
228 pub(crate) fn tolerances(&self) -> HashMap<&'static str, RegisteredTolerance<'_>> {
231 let mut tolerances = HashMap::<&'static str, RegisteredTolerance<'_>>::new();
232 for b in self.benchmarks.iter() {
233 if let Some(regression) = b.benchmark.as_regression() {
234 let t = regression.tolerance();
239 let packaged = RegressionBenchmark {
240 benchmark: b,
241 regression,
242 };
243
244 match tolerances.entry(t.tag()) {
245 Entry::Occupied(occupied) => occupied.into_mut().regressions.push(packaged),
246 Entry::Vacant(vacant) => {
247 vacant.insert(RegisteredTolerance {
248 tolerance: input::Registered(t),
249 regressions: vec![packaged],
250 });
251 }
252 }
253 }
254 }
255
256 tolerances
257 }
258}
259
260impl Default for Benchmarks {
261 fn default() -> Self {
262 Self::new()
263 }
264}
265
/// Explanation of why a registered benchmark did not accept a job.
///
/// Values are produced by `Benchmarks::debug`.
#[derive(Debug, Clone)]
pub struct Mismatch {
    /// Name the rejecting benchmark was registered under.
    method: String,
    /// The benchmark's description, rendered in the context of the rejected job.
    reason: String,
}
271
272impl Mismatch {
273 pub fn method(&self) -> &str {
275 &self.method
276 }
277
278 pub fn reason(&self) -> &str {
280 &self.reason
281 }
282}
283
/// Borrowed view of a registered benchmark that also supports regression checking.
///
/// Holds both the registration (for its name and matching) and its regression interface.
#[derive(Debug, Clone, Copy)]
pub(crate) struct RegressionBenchmark<'a> {
    /// The registration this view was built from.
    benchmark: &'a RegisteredBenchmark,
    /// Regression interface obtained from the registration's `as_regression()`.
    regression: &'a dyn benchmark::internal::Regression,
}
293
294impl RegressionBenchmark<'_> {
295 pub(crate) fn name(&self) -> &str {
296 self.benchmark.name()
297 }
298
299 pub(crate) fn input_tag(&self) -> &'static str {
300 self.regression.input_tag()
301 }
302
303 pub(crate) fn try_match(&self, input: &Any) -> Result<MatchScore, FailureScore> {
304 self.benchmark.benchmark().try_match(input)
305 }
306
307 pub(crate) fn check(
308 &self,
309 tolerance: &Any,
310 input: &Any,
311 before: &serde_json::Value,
312 after: &serde_json::Value,
313 ) -> anyhow::Result<benchmark::internal::CheckedPassFail> {
314 self.regression.check(tolerance, input, before, after)
315 }
316}
317
/// A tolerance input together with every regression benchmark that shares its tag.
///
/// Built by `Benchmarks::tolerances`.
#[derive(Debug)]
pub(crate) struct RegisteredTolerance<'a> {
    /// Tolerance input, taken from the first regression benchmark seen with this tag.
    pub(crate) tolerance: input::Registered<'a>,

    /// Regression benchmarks using this tolerance tag, in registration order.
    pub(crate) regressions: Vec<RegressionBenchmark<'a>>,
}
327
/// Adapter that prints a benchmark's `description` via `Display` and `Debug`.
///
/// - `None`: describe the benchmark on its own.
/// - `Some(job)`: describe it in the context of that job.
struct Capture<'a>(&'a dyn benchmark::internal::Benchmark, Option<&'a Any>);
330
331impl std::fmt::Display for Capture<'_> {
332 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
333 self.0.description(f, self.1)
334 }
335}
336
337impl std::fmt::Debug for Capture<'_> {
338 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
339 self.0.description(f, self.1)
340 }
341}