diskann_benchmark_runner/
benchmark.rs1use serde::{Deserialize, Serialize};
7
8use crate::{
9 dispatcher::{FailureScore, MatchScore},
10 Any, Checkpoint, Input, Output,
11};
12
13pub trait Benchmark: 'static {
19 type Input: Input + 'static;
21
22 type Output: Serialize;
24
25 fn try_match(&self, input: &Self::Input) -> Result<MatchScore, FailureScore>;
36
37 fn description(
43 &self,
44 f: &mut std::fmt::Formatter<'_>,
45 input: Option<&Self::Input>,
46 ) -> std::fmt::Result;
47
48 fn run(
56 &self,
57 input: &Self::Input,
58 checkpoint: Checkpoint<'_>,
59 output: &mut dyn Output,
60 ) -> anyhow::Result<Self::Output>;
61}
62
63pub trait Regression: Benchmark<Output: for<'a> Deserialize<'a>> {
73 type Tolerances: Input + 'static;
75
76 type Pass: Serialize + std::fmt::Display + 'static;
78
79 type Fail: Serialize + std::fmt::Display + 'static;
81
82 fn check(
93 &self,
94 tolerances: &Self::Tolerances,
95 input: &Self::Input,
96 before: &Self::Output,
97 after: &Self::Output,
98 ) -> anyhow::Result<PassFail<Self::Pass, Self::Fail>>;
99}
100
/// Two-state outcome of a check, carrying a payload of type `P` on success and `F` on failure.
#[derive(Debug, Clone, Copy)]
pub enum PassFail<P, F> {
    /// The check succeeded; holds the success payload.
    Pass(P),
    /// The check did not succeed; holds the failure payload.
    Fail(F),
}
107
108pub(crate) mod internal {
113 use super::*;
114
115 use anyhow::Context;
116 use thiserror::Error;
117
118 pub(crate) trait Benchmark {
120 fn try_match(&self, input: &Any) -> Result<MatchScore, FailureScore>;
121
122 fn description(
123 &self,
124 f: &mut std::fmt::Formatter<'_>,
125 input: Option<&Any>,
126 ) -> std::fmt::Result;
127
128 fn run(
129 &self,
130 input: &Any,
131 checkpoint: Checkpoint<'_>,
132 output: &mut dyn Output,
133 ) -> anyhow::Result<serde_json::Value>;
134
135 fn as_regression(&self) -> Option<&dyn Regression>;
137 }
138
139 pub(crate) struct Checked {
140 pub(crate) json: serde_json::Value,
141 pub(crate) display: Box<dyn std::fmt::Display>,
142 }
143
144 impl Checked {
145 fn new<T>(value: T) -> Result<Self, serde_json::Error>
147 where
148 T: Serialize + std::fmt::Display + 'static,
149 {
150 Ok(Self {
151 json: serde_json::to_value(&value)?,
152 display: Box::new(value),
153 })
154 }
155 }
156
157 pub(crate) type CheckedPassFail = PassFail<Checked, Checked>;
158
159 pub(crate) trait Regression {
160 fn tolerance(&self) -> &dyn crate::input::DynInput;
161 fn input_tag(&self) -> &'static str;
162 fn check(
163 &self,
164 tolerances: &Any,
165 input: &Any,
166 before: &serde_json::Value,
167 after: &serde_json::Value,
168 ) -> anyhow::Result<CheckedPassFail>;
169 }
170
171 impl std::fmt::Debug for dyn Regression + '_ {
172 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
173 f.debug_struct("dyn Regression")
174 .field("tolerance", &self.tolerance().tag())
175 .field("input_tag", &self.input_tag())
176 .finish()
177 }
178 }
179
180 pub(crate) trait AsRegression<T> {
181 fn as_regression(benchmark: &T) -> Option<&dyn Regression>;
182 }
183
184 #[derive(Debug, Clone, Copy)]
185 pub(crate) struct NoRegression;
186
187 impl<T> AsRegression<T> for NoRegression {
188 fn as_regression(_benchmark: &T) -> Option<&dyn Regression> {
189 None
190 }
191 }
192
193 #[derive(Debug, Clone, Copy)]
194 pub(crate) struct WithRegression;
195
196 impl<T> AsRegression<T> for WithRegression
197 where
198 T: super::Regression,
199 {
200 fn as_regression(benchmark: &T) -> Option<&dyn Regression> {
201 Some(benchmark)
202 }
203 }
204
205 impl<T> Regression for T
206 where
207 T: super::Regression,
208 {
209 fn tolerance(&self) -> &dyn crate::input::DynInput {
210 &crate::input::Wrapper::<T::Tolerances>::INSTANCE
211 }
212
213 fn input_tag(&self) -> &'static str {
214 T::Input::tag()
215 }
216
217 fn check(
218 &self,
219 tolerance: &Any,
220 input: &Any,
221 before: &serde_json::Value,
222 after: &serde_json::Value,
223 ) -> anyhow::Result<CheckedPassFail> {
224 let tolerance = tolerance
225 .downcast_ref::<T::Tolerances>()
226 .ok_or_else(|| BadDownCast::new(T::Tolerances::tag(), tolerance.tag()))
227 .context("failed to obtain tolerance")?;
228
229 let input = input
230 .downcast_ref::<T::Input>()
231 .ok_or_else(|| BadDownCast::new(T::Input::tag(), input.tag()))
232 .context("failed to obtain input")?;
233
234 let before = T::Output::deserialize(before)
235 .map_err(|err| DeserializationError::new(Kind::Before, err))?;
236
237 let after = T::Output::deserialize(after)
238 .map_err(|err| DeserializationError::new(Kind::After, err))?;
239
240 let passfail = match self.check(tolerance, input, &before, &after)? {
241 PassFail::Pass(pass) => PassFail::Pass(Checked::new(pass)?),
242 PassFail::Fail(fail) => PassFail::Fail(Checked::new(fail)?),
243 };
244
245 Ok(passfail)
246 }
247 }
248
249 #[derive(Debug, Clone, Copy)]
250 pub(crate) struct Wrapper<T, R = NoRegression> {
251 benchmark: T,
252 _regression: R,
253 }
254
255 impl<T, R> Wrapper<T, R> {
256 pub(crate) const fn new(benchmark: T, regression: R) -> Self {
257 Self {
258 benchmark,
259 _regression: regression,
260 }
261 }
262 }
263
264 const MATCH_FAIL: FailureScore = FailureScore(10_000);
266
267 impl<T, R> Benchmark for Wrapper<T, R>
268 where
269 T: super::Benchmark,
270 R: AsRegression<T>,
271 {
272 fn try_match(&self, input: &Any) -> Result<MatchScore, FailureScore> {
273 if let Some(cast) = input.downcast_ref::<T::Input>() {
274 self.benchmark.try_match(cast)
275 } else {
276 Err(MATCH_FAIL)
277 }
278 }
279
280 fn description(
281 &self,
282 f: &mut std::fmt::Formatter<'_>,
283 input: Option<&Any>,
284 ) -> std::fmt::Result {
285 match input {
286 Some(input) => match input.downcast_ref::<T::Input>() {
287 Some(cast) => self.benchmark.description(f, Some(cast)),
288 None => write!(
289 f,
290 "expected tag \"{}\" - instead got \"{}\"",
291 T::Input::tag(),
292 input.tag(),
293 ),
294 },
295 None => {
296 writeln!(f, "tag \"{}\"", <T::Input as Input>::tag())?;
297 self.benchmark.description(f, None)
298 }
299 }
300 }
301
302 fn run(
303 &self,
304 input: &Any,
305 checkpoint: Checkpoint<'_>,
306 output: &mut dyn Output,
307 ) -> anyhow::Result<serde_json::Value> {
308 match input.downcast_ref::<T::Input>() {
309 Some(input) => {
310 let result = self.benchmark.run(input, checkpoint, output)?;
311 Ok(serde_json::to_value(result)?)
312 }
313 None => Err(BadDownCast::new(T::Input::tag(), input.tag()).into()),
314 }
315 }
316
317 fn as_regression(&self) -> Option<&dyn Regression> {
319 R::as_regression(&self.benchmark)
320 }
321 }
322
323 #[derive(Debug, Clone, Copy, Error)]
328 #[error(
329 "INTERNAL ERROR: bad downcast - expected \"{}\" but got \"{}\"",
330 self.expected,
331 self.got
332 )]
333 struct BadDownCast {
334 expected: &'static str,
335 got: &'static str,
336 }
337
338 impl BadDownCast {
339 fn new(expected: &'static str, got: &'static str) -> Self {
340 Self { expected, got }
341 }
342 }
343
344 #[derive(Debug, Error)]
345 #[error(
346 "the \"{}\" results do not match the output schema expected by this benchmark",
347 self.kind
348 )]
349 struct DeserializationError {
350 kind: Kind,
351 source: serde_json::Error,
352 }
353
354 impl DeserializationError {
355 fn new(kind: Kind, source: serde_json::Error) -> Self {
356 Self { kind, source }
357 }
358 }
359
/// Identifies one of the two results taking part in a comparison.
#[derive(Debug, Clone, Copy)]
enum Kind {
    /// The `before` result.
    Before,
    /// The `after` result.
    After,
}

/// Renders the variant as lowercase text: `before` or `after`.
impl std::fmt::Display for Kind {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `write_str` emits the text verbatim, without applying width or padding flags.
        f.write_str(match self {
            Self::Before => "before",
            Self::After => "after",
        })
    }
}
376}