diskann_benchmark_runner/
benchmark.rs1use serde::{Deserialize, Serialize};
7
8use crate::{
9 dispatcher::{FailureScore, MatchScore},
10 Any, Checkpoint, Input, Output,
11};
12
13pub trait Benchmark {
19 type Input: Input + 'static;
21
22 type Output: Serialize;
24
25 fn try_match(input: &Self::Input) -> Result<MatchScore, FailureScore>;
36
37 fn description(
43 f: &mut std::fmt::Formatter<'_>,
44 input: Option<&Self::Input>,
45 ) -> std::fmt::Result;
46
47 fn run(
55 input: &Self::Input,
56 checkpoint: Checkpoint<'_>,
57 output: &mut dyn Output,
58 ) -> anyhow::Result<Self::Output>;
59}
60
61pub trait Regression: Benchmark<Output: for<'a> Deserialize<'a>> {
71 type Tolerances: Input + 'static;
73
74 type Pass: Serialize + std::fmt::Display + 'static;
76
77 type Fail: Serialize + std::fmt::Display + 'static;
79
80 fn check(
91 tolerances: &Self::Tolerances,
92 input: &Self::Input,
93 before: &Self::Output,
94 after: &Self::Output,
95 ) -> anyhow::Result<PassFail<Self::Pass, Self::Fail>>;
96}
97
/// Two-state outcome of a check: success carrying a `P`, or failure carrying an `F`.
#[derive(Clone, Copy, Debug)]
pub enum PassFail<P, F> {
    /// The check succeeded; holds the success payload.
    Pass(P),
    /// The check did not succeed; holds the failure payload.
    Fail(F),
}
104
105pub(crate) mod internal {
110 use super::*;
111
112 use std::marker::PhantomData;
113
114 use anyhow::Context;
115 use thiserror::Error;
116
117 pub(crate) trait Benchmark {
119 fn try_match(&self, input: &Any) -> Result<MatchScore, FailureScore>;
120
121 fn description(
122 &self,
123 f: &mut std::fmt::Formatter<'_>,
124 input: Option<&Any>,
125 ) -> std::fmt::Result;
126
127 fn run(
128 &self,
129 input: &Any,
130 checkpoint: Checkpoint<'_>,
131 output: &mut dyn Output,
132 ) -> anyhow::Result<serde_json::Value>;
133
134 fn as_regression(&self) -> Option<&dyn Regression>;
136 }
137
138 pub(crate) struct Checked {
139 pub(crate) json: serde_json::Value,
140 pub(crate) display: Box<dyn std::fmt::Display>,
141 }
142
143 impl Checked {
144 fn new<T>(value: T) -> Result<Self, serde_json::Error>
146 where
147 T: Serialize + std::fmt::Display + 'static,
148 {
149 Ok(Self {
150 json: serde_json::to_value(&value)?,
151 display: Box::new(value),
152 })
153 }
154 }
155
156 pub(crate) type CheckedPassFail = PassFail<Checked, Checked>;
157
158 pub(crate) trait Regression {
159 fn tolerance(&self) -> &dyn crate::input::DynInput;
160 fn input_tag(&self) -> &'static str;
161 fn check(
162 &self,
163 tolerances: &Any,
164 input: &Any,
165 before: &serde_json::Value,
166 after: &serde_json::Value,
167 ) -> anyhow::Result<CheckedPassFail>;
168 }
169
170 impl std::fmt::Debug for dyn Regression + '_ {
171 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
172 f.debug_struct("dyn Regression")
173 .field("tolerance", &self.tolerance().tag())
174 .field("input_tag", &self.input_tag())
175 .finish()
176 }
177 }
178
179 pub(crate) trait AsRegression {
180 fn as_regression(&self) -> Option<&dyn Regression>;
181 }
182
/// Marker for benchmarks without regression support.
///
/// This is a field-less unit struct, so it derives `Copy` like the other marker
/// types in this module. Without it, a `Wrapper` using the default
/// `R = NoRegression` could never satisfy the `R: Copy` bound required by its own
/// derived `Copy` implementation.
#[derive(Debug, Clone, Copy)]
pub(crate) struct NoRegression;
185
186 impl AsRegression for NoRegression {
187 fn as_regression(&self) -> Option<&dyn Regression> {
188 None
189 }
190 }
191
/// Zero-sized marker that carries the type `T` without storing a value of it.
#[derive(Clone, Copy, Debug)]
pub(crate) struct WithRegression<T>(PhantomData<T>);

impl<T> WithRegression<T> {
    /// Create the marker. Usable in `const` contexts.
    pub(crate) const fn new() -> Self {
        WithRegression(PhantomData)
    }
}
200
201 impl<T> AsRegression for WithRegression<T>
202 where
203 T: super::Regression,
204 {
205 fn as_regression(&self) -> Option<&dyn Regression> {
206 Some(self)
207 }
208 }
209
210 impl<T> Regression for WithRegression<T>
211 where
212 T: super::Regression,
213 {
214 fn tolerance(&self) -> &dyn crate::input::DynInput {
215 &crate::input::Wrapper::<T::Tolerances>::INSTANCE
216 }
217
218 fn input_tag(&self) -> &'static str {
219 T::Input::tag()
220 }
221
222 fn check(
223 &self,
224 tolerance: &Any,
225 input: &Any,
226 before: &serde_json::Value,
227 after: &serde_json::Value,
228 ) -> anyhow::Result<CheckedPassFail> {
229 let tolerance = tolerance
230 .downcast_ref::<T::Tolerances>()
231 .ok_or_else(|| BadDownCast::new(T::Tolerances::tag(), tolerance.tag()))
232 .context("failed to obtain tolerance")?;
233
234 let input = input
235 .downcast_ref::<T::Input>()
236 .ok_or_else(|| BadDownCast::new(T::Input::tag(), input.tag()))
237 .context("failed to obtain input")?;
238
239 let before = T::Output::deserialize(before)
240 .map_err(|err| DeserializationError::new(Kind::Before, err))?;
241
242 let after = T::Output::deserialize(after)
243 .map_err(|err| DeserializationError::new(Kind::After, err))?;
244
245 let passfail = match T::check(tolerance, input, &before, &after)? {
246 PassFail::Pass(pass) => PassFail::Pass(Checked::new(pass)?),
247 PassFail::Fail(fail) => PassFail::Fail(Checked::new(fail)?),
248 };
249
250 Ok(passfail)
251 }
252 }
253
254 #[derive(Debug, Clone, Copy)]
255 pub(crate) struct Wrapper<T, R = NoRegression> {
256 regression: R,
257 _type: PhantomData<T>,
258 }
259
260 impl<T> Wrapper<T, NoRegression> {
261 pub(crate) const fn new() -> Self {
262 Self::new_with(NoRegression)
263 }
264 }
265
266 impl<T, R> Wrapper<T, R> {
267 pub(crate) const fn new_with(regression: R) -> Self {
268 Self {
269 regression,
270 _type: PhantomData,
271 }
272 }
273 }
274
275 const MATCH_FAIL: FailureScore = FailureScore(10_000);
277
278 impl<T, R> Benchmark for Wrapper<T, R>
279 where
280 T: super::Benchmark,
281 R: AsRegression,
282 {
283 fn try_match(&self, input: &Any) -> Result<MatchScore, FailureScore> {
284 if let Some(cast) = input.downcast_ref::<T::Input>() {
285 T::try_match(cast)
286 } else {
287 Err(MATCH_FAIL)
288 }
289 }
290
291 fn description(
292 &self,
293 f: &mut std::fmt::Formatter<'_>,
294 input: Option<&Any>,
295 ) -> std::fmt::Result {
296 match input {
297 Some(input) => match input.downcast_ref::<T::Input>() {
298 Some(cast) => T::description(f, Some(cast)),
299 None => write!(
300 f,
301 "expected tag \"{}\" - instead got \"{}\"",
302 T::Input::tag(),
303 input.tag(),
304 ),
305 },
306 None => {
307 writeln!(f, "tag \"{}\"", <T::Input as Input>::tag())?;
308 T::description(f, None)
309 }
310 }
311 }
312
313 fn run(
314 &self,
315 input: &Any,
316 checkpoint: Checkpoint<'_>,
317 output: &mut dyn Output,
318 ) -> anyhow::Result<serde_json::Value> {
319 match input.downcast_ref::<T::Input>() {
320 Some(input) => {
321 let result = T::run(input, checkpoint, output)?;
322 Ok(serde_json::to_value(result)?)
323 }
324 None => Err(BadDownCast::new(T::Input::tag(), input.tag()).into()),
325 }
326 }
327
328 fn as_regression(&self) -> Option<&dyn Regression> {
330 self.regression.as_regression()
331 }
332 }
333
334 #[derive(Debug, Clone, Copy, Error)]
339 #[error(
340 "INTERNAL ERROR: bad downcast - expected \"{}\" but got \"{}\"",
341 self.expected,
342 self.got
343 )]
344 struct BadDownCast {
345 expected: &'static str,
346 got: &'static str,
347 }
348
349 impl BadDownCast {
350 fn new(expected: &'static str, got: &'static str) -> Self {
351 Self { expected, got }
352 }
353 }
354
355 #[derive(Debug, Error)]
356 #[error(
357 "the \"{}\" results do not match the output schema expected by this benchmark",
358 self.kind
359 )]
360 struct DeserializationError {
361 kind: Kind,
362 source: serde_json::Error,
363 }
364
365 impl DeserializationError {
366 fn new(kind: Kind, source: serde_json::Error) -> Self {
367 Self { kind, source }
368 }
369 }
370
/// Identifies which of the two compared result sets is meant.
#[derive(Debug, Clone, Copy)]
enum Kind {
    /// The earlier result set.
    Before,
    /// The later result set.
    After,
}

/// Renders the variant as the lowercase word "before" or "after".
impl std::fmt::Display for Kind {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Written verbatim: width and alignment flags on the caller's formatter are
        // not applied.
        f.write_str(match *self {
            Kind::Before => "before",
            Kind::After => "after",
        })
    }
}
387}