diskann_benchmark_runner/
benchmark.rs1use serde::{Deserialize, Serialize};
7
8use crate::{Checkpoint, Input, Output};
9
10pub trait Benchmark: 'static {
16 type Input: Input + 'static;
18
19 type Output: Serialize;
21
22 fn try_match(&self, input: &Self::Input) -> Result<MatchScore, FailureScore>;
33
34 fn description(
40 &self,
41 f: &mut std::fmt::Formatter<'_>,
42 input: Option<&Self::Input>,
43 ) -> std::fmt::Result;
44
45 fn run(
53 &self,
54 input: &Self::Input,
55 checkpoint: Checkpoint<'_>,
56 output: &mut dyn Output,
57 ) -> anyhow::Result<Self::Output>;
58}
59
/// Score reported when a benchmark accepts an input; a thin wrapper around a `u32`.
///
/// Equality and ordering are derived from the wrapped value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub struct MatchScore(pub u32);

impl std::fmt::Display for MatchScore {
    /// Renders the score as `success (<value>)`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self(score) = self;
        f.write_fmt(format_args!("success ({})", score))
    }
}
71
/// Score reported when a benchmark rejects an input; a thin wrapper around a `u32`.
///
/// Equality and ordering are derived from the wrapped value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub struct FailureScore(pub u32);

impl std::fmt::Display for FailureScore {
    /// Renders the score as `fail (<value>)`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self(score) = self;
        f.write_fmt(format_args!("fail ({})", score))
    }
}
84
85pub trait Regression: Benchmark<Output: for<'a> Deserialize<'a>> {
95 type Tolerances: Input + 'static;
97
98 type Pass: Serialize + std::fmt::Display + 'static;
100
101 type Fail: Serialize + std::fmt::Display + 'static;
103
104 fn check(
115 &self,
116 tolerances: &Self::Tolerances,
117 input: &Self::Input,
118 before: &Self::Output,
119 after: &Self::Output,
120 ) -> anyhow::Result<PassFail<Self::Pass, Self::Fail>>;
121}
122
/// Outcome of a regression check: either a passing payload `P` or a failing payload `F`.
///
/// `PartialEq` and `Eq` are available whenever both payload types implement them, so
/// outcomes can be compared directly (for example with `assert_eq!`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PassFail<P, F> {
    /// The check passed; carries the pass payload.
    Pass(P),
    /// The check failed; carries the fail payload.
    Fail(F),
}
129
130pub(crate) mod internal {
135 use super::*;
136
137 use crate::input::internal::Any;
138
139 use anyhow::Context;
140 use thiserror::Error;
141
142 pub(crate) trait Benchmark {
144 fn try_match(&self, input: &Any) -> Result<MatchScore, FailureScore>;
145
146 fn description(
147 &self,
148 f: &mut std::fmt::Formatter<'_>,
149 input: Option<&Any>,
150 ) -> std::fmt::Result;
151
152 fn run(
153 &self,
154 input: &Any,
155 checkpoint: Checkpoint<'_>,
156 output: &mut dyn Output,
157 ) -> anyhow::Result<serde_json::Value>;
158
159 fn as_regression(&self) -> Option<&dyn Regression>;
161 }
162
163 pub(crate) struct Checked {
164 pub(crate) json: serde_json::Value,
165 pub(crate) display: Box<dyn std::fmt::Display>,
166 }
167
168 impl Checked {
169 fn new<T>(value: T) -> Result<Self, serde_json::Error>
171 where
172 T: Serialize + std::fmt::Display + 'static,
173 {
174 Ok(Self {
175 json: serde_json::to_value(&value)?,
176 display: Box::new(value),
177 })
178 }
179 }
180
181 pub(crate) type CheckedPassFail = PassFail<Checked, Checked>;
182
183 pub(crate) trait Regression {
184 fn tolerance(&self) -> &dyn crate::input::internal::DynInput;
185 fn input_tag(&self) -> &'static str;
186 fn check(
187 &self,
188 tolerances: &Any,
189 input: &Any,
190 before: &serde_json::Value,
191 after: &serde_json::Value,
192 ) -> anyhow::Result<CheckedPassFail>;
193 }
194
195 impl std::fmt::Debug for dyn Regression + '_ {
196 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
197 f.debug_struct("dyn Regression")
198 .field("tolerance", &self.tolerance().tag())
199 .field("input_tag", &self.input_tag())
200 .finish()
201 }
202 }
203
204 pub(crate) trait AsRegression<T> {
205 fn as_regression(benchmark: &T) -> Option<&dyn Regression>;
206 }
207
208 #[derive(Debug, Clone, Copy)]
209 pub(crate) struct NoRegression;
210
211 impl<T> AsRegression<T> for NoRegression {
212 fn as_regression(_benchmark: &T) -> Option<&dyn Regression> {
213 None
214 }
215 }
216
217 #[derive(Debug, Clone, Copy)]
218 pub(crate) struct WithRegression;
219
220 impl<T> AsRegression<T> for WithRegression
221 where
222 T: super::Regression,
223 {
224 fn as_regression(benchmark: &T) -> Option<&dyn Regression> {
225 Some(benchmark)
226 }
227 }
228
229 impl<T> Regression for T
230 where
231 T: super::Regression,
232 {
233 fn tolerance(&self) -> &dyn crate::input::internal::DynInput {
234 &crate::input::internal::Wrapper::<T::Tolerances>::INSTANCE
235 }
236
237 fn input_tag(&self) -> &'static str {
238 T::Input::tag()
239 }
240
241 fn check(
242 &self,
243 tolerance: &Any,
244 input: &Any,
245 before: &serde_json::Value,
246 after: &serde_json::Value,
247 ) -> anyhow::Result<CheckedPassFail> {
248 let tolerance = tolerance
249 .downcast_ref::<T::Tolerances>()
250 .ok_or_else(|| BadDownCast::new(T::Tolerances::tag(), tolerance.tag()))
251 .context("failed to obtain tolerance")?;
252
253 let input = input
254 .downcast_ref::<T::Input>()
255 .ok_or_else(|| BadDownCast::new(T::Input::tag(), input.tag()))
256 .context("failed to obtain input")?;
257
258 let before = T::Output::deserialize(before)
259 .map_err(|err| DeserializationError::new(Kind::Before, err))?;
260
261 let after = T::Output::deserialize(after)
262 .map_err(|err| DeserializationError::new(Kind::After, err))?;
263
264 let passfail = match self.check(tolerance, input, &before, &after)? {
265 PassFail::Pass(pass) => PassFail::Pass(Checked::new(pass)?),
266 PassFail::Fail(fail) => PassFail::Fail(Checked::new(fail)?),
267 };
268
269 Ok(passfail)
270 }
271 }
272
273 #[derive(Debug, Clone, Copy)]
274 pub(crate) struct Wrapper<T, R = NoRegression> {
275 benchmark: T,
276 _regression: R,
277 }
278
279 impl<T, R> Wrapper<T, R> {
280 pub(crate) const fn new(benchmark: T, regression: R) -> Self {
281 Self {
282 benchmark,
283 _regression: regression,
284 }
285 }
286 }
287
288 const MATCH_FAIL: FailureScore = FailureScore(10_000);
290
291 impl<T, R> Benchmark for Wrapper<T, R>
292 where
293 T: super::Benchmark,
294 R: AsRegression<T>,
295 {
296 fn try_match(&self, input: &Any) -> Result<MatchScore, FailureScore> {
297 if let Some(cast) = input.downcast_ref::<T::Input>() {
298 self.benchmark.try_match(cast)
299 } else {
300 Err(MATCH_FAIL)
301 }
302 }
303
304 fn description(
305 &self,
306 f: &mut std::fmt::Formatter<'_>,
307 input: Option<&Any>,
308 ) -> std::fmt::Result {
309 match input {
310 Some(input) => match input.downcast_ref::<T::Input>() {
311 Some(cast) => self.benchmark.description(f, Some(cast)),
312 None => write!(
313 f,
314 "expected tag \"{}\" - instead got \"{}\"",
315 T::Input::tag(),
316 input.tag(),
317 ),
318 },
319 None => {
320 writeln!(f, "tag \"{}\"", <T::Input as Input>::tag())?;
321 self.benchmark.description(f, None)
322 }
323 }
324 }
325
326 fn run(
327 &self,
328 input: &Any,
329 checkpoint: Checkpoint<'_>,
330 output: &mut dyn Output,
331 ) -> anyhow::Result<serde_json::Value> {
332 match input.downcast_ref::<T::Input>() {
333 Some(input) => {
334 let result = self.benchmark.run(input, checkpoint, output)?;
335 Ok(serde_json::to_value(result)?)
336 }
337 None => Err(BadDownCast::new(T::Input::tag(), input.tag()).into()),
338 }
339 }
340
341 fn as_regression(&self) -> Option<&dyn Regression> {
343 R::as_regression(&self.benchmark)
344 }
345 }
346
347 #[derive(Debug, Clone, Copy, Error)]
352 #[error(
353 "INTERNAL ERROR: bad downcast - expected \"{}\" but got \"{}\"",
354 self.expected,
355 self.got
356 )]
357 struct BadDownCast {
358 expected: &'static str,
359 got: &'static str,
360 }
361
362 impl BadDownCast {
363 fn new(expected: &'static str, got: &'static str) -> Self {
364 Self { expected, got }
365 }
366 }
367
368 #[derive(Debug, Error)]
369 #[error(
370 "the \"{}\" results do not match the output schema expected by this benchmark",
371 self.kind
372 )]
373 struct DeserializationError {
374 kind: Kind,
375 source: serde_json::Error,
376 }
377
378 impl DeserializationError {
379 fn new(kind: Kind, source: serde_json::Error) -> Self {
380 Self { kind, source }
381 }
382 }
383
/// Identifies which of the two compared result sets is meant.
#[derive(Debug, Clone, Copy)]
enum Kind {
    /// The `before` results.
    Before,
    /// The `after` results.
    After,
}

impl std::fmt::Display for Kind {
    /// Writes `before` or `after`; width and alignment flags are not applied.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(match self {
            Kind::Before => "before",
            Kind::After => "after",
        })
    }
}
400}