use serde::{Deserialize, Serialize};
use crate::{
dispatcher::{FailureScore, MatchScore},
Any, Checkpoint, Input, Output,
};
/// A benchmark that runs against a concrete, statically typed input.
///
/// Within this file, `internal::Wrapper` adapts implementors to the type-erased
/// `internal::Benchmark` trait by downcasting the dynamic input to `Self::Input` and
/// converting `Self::Output` into a `serde_json::Value`.
pub trait Benchmark: 'static {
/// The concrete input type this benchmark accepts.
type Input: Input + 'static;
/// The value returned by a successful [`run`](Self::run); it must be serializable.
type Output: Serialize;
/// Decides whether this benchmark applies to `input`.
///
/// Returns `Ok` with a [`MatchScore`] when it does and `Err` with a [`FailureScore`]
/// when it does not.
/// NOTE(review): how the scores are compared is up to the dispatcher, which is not
/// visible here - confirm before relying on a particular ordering.
fn try_match(&self, input: &Self::Input) -> Result<MatchScore, FailureScore>;
/// Writes a human-readable description of this benchmark to `f`.
///
/// `input` is optional.
/// NOTE(review): presumably `None` asks for a general description and `Some` for a
/// description relative to that particular input - confirm against implementors.
fn description(
&self,
f: &mut std::fmt::Formatter<'_>,
input: Option<&Self::Input>,
) -> std::fmt::Result;
/// Runs the benchmark on `input` and returns its result.
///
/// `checkpoint` and `output` are supplied by the caller.
/// NOTE(review): presumably used for saving intermediate results and for reporting
/// progress respectively - confirm against the `Checkpoint` and `Output` definitions.
///
/// # Errors
///
/// Any failure is reported through `anyhow::Error`; the failure modes are defined by
/// the implementation.
fn run(
&self,
input: &Self::Input,
checkpoint: Checkpoint<'_>,
output: &mut dyn Output,
) -> anyhow::Result<Self::Output>;
}
/// A [`Benchmark`] whose results from two runs can be compared against each other.
///
/// The supertrait bound additionally requires `Output` to be deserializable: the
/// type-erased adapter in `internal` receives both results as `serde_json::Value` and
/// deserializes them back into `Self::Output` before calling [`check`](Self::check).
pub trait Regression: Benchmark<Output: for<'a> Deserialize<'a>> {
/// Parameters handed to [`check`](Self::check) alongside the two results.
/// NOTE(review): bounded by `Input`, so presumably loaded the same way benchmark
/// inputs are - confirm against the caller.
type Tolerances: Input + 'static;
/// Payload reported when the check passes; must be serializable and displayable.
type Pass: Serialize + std::fmt::Display + 'static;
/// Payload reported when the check fails; must be serializable and displayable.
type Fail: Serialize + std::fmt::Display + 'static;
/// Compares the `before` and `after` results produced for `input` under `tolerances`.
///
/// Returns [`PassFail::Pass`] or [`PassFail::Fail`] with the corresponding payload.
///
/// # Errors
///
/// NOTE(review): presumably `Err` means the comparison itself could not be carried
/// out, as opposed to a comparison that ran and failed - confirm with implementors.
fn check(
&self,
tolerances: &Self::Tolerances,
input: &Self::Input,
before: &Self::Output,
after: &Self::Output,
) -> anyhow::Result<PassFail<Self::Pass, Self::Fail>>;
}
/// The outcome of a regression check, carrying a payload for either result.
#[derive(Debug, Clone, Copy)]
pub enum PassFail<P, F> {
/// The check passed; `P` is the accompanying payload.
Pass(P),
/// The check failed; `F` is the accompanying payload.
Fail(F),
}
pub(crate) mod internal {
use super::*;
use anyhow::Context;
use thiserror::Error;
/// Type-erased counterpart of `super::Benchmark`.
///
/// Inputs are passed as [`Any`] and the result of `run` is a `serde_json::Value`, so the
/// signatures do not mention any benchmark-specific types.
pub(crate) trait Benchmark {
/// Scores how well this benchmark matches the dynamically typed `input`.
fn try_match(&self, input: &Any) -> Result<MatchScore, FailureScore>;
/// Writes a description of this benchmark to `f`, optionally relative to `input`.
fn description(
&self,
f: &mut std::fmt::Formatter<'_>,
input: Option<&Any>,
) -> std::fmt::Result;
/// Runs the benchmark on the dynamically typed `input` and returns its result as JSON.
fn run(
&self,
input: &Any,
checkpoint: Checkpoint<'_>,
output: &mut dyn Output,
) -> anyhow::Result<serde_json::Value>;
/// Returns the regression-check view of this benchmark, or `None` if it has none.
fn as_regression(&self) -> Option<&dyn Regression>;
}
/// A pass or fail payload captured in two forms: as JSON and as a displayable value.
pub(crate) struct Checked {
/// The payload serialized with `serde_json`.
pub(crate) json: serde_json::Value,
/// The original payload, boxed so that it can still be formatted with `Display`.
pub(crate) display: Box<dyn std::fmt::Display>,
}
impl Checked {
    /// Captures `value` both as its JSON serialization and as a boxed `Display` object.
    ///
    /// # Errors
    ///
    /// Returns the `serde_json` error if `value` cannot be serialized.
    fn new<T>(value: T) -> Result<Self, serde_json::Error>
    where
        T: Serialize + std::fmt::Display + 'static,
    {
        // Serialize through a borrow first; `value` is moved into the box afterwards.
        let json = serde_json::to_value(&value)?;
        let display: Box<dyn std::fmt::Display> = Box::new(value);
        Ok(Self { json, display })
    }
}
/// A [`PassFail`] whose pass and fail payloads are both stored as [`Checked`].
pub(crate) type CheckedPassFail = PassFail<Checked, Checked>;
/// Type-erased counterpart of `super::Regression`.
///
/// Tolerances and inputs are passed as [`Any`], and the two results being compared are
/// passed as `serde_json::Value`.
pub(crate) trait Regression {
/// Returns the dynamic-input object associated with the tolerance type.
fn tolerance(&self) -> &dyn crate::input::DynInput;
/// Returns the tag of the benchmark's input type.
fn input_tag(&self) -> &'static str;
/// Compares the `before` and `after` results for `input` under `tolerances`.
///
/// On success the pass or fail payload is returned in its [`Checked`] form.
fn check(
&self,
tolerances: &Any,
input: &Any,
before: &serde_json::Value,
after: &serde_json::Value,
) -> anyhow::Result<CheckedPassFail>;
}
/// Debug output for a type-erased regression check: the tolerance tag and the input tag.
impl std::fmt::Debug for dyn Regression + '_ {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("dyn Regression");
        builder.field("tolerance", &self.tolerance().tag());
        builder.field("input_tag", &self.input_tag());
        builder.finish()
    }
}
/// Strategy for obtaining the optional regression-check view of a benchmark of type `T`.
///
/// Implemented in this module by the marker types `NoRegression` and `WithRegression`.
pub(crate) trait AsRegression<T> {
/// Returns `benchmark` as a type-erased [`Regression`], or `None` if unavailable.
fn as_regression(benchmark: &T) -> Option<&dyn Regression>;
}
/// Marker for benchmarks without regression checks; usable with any `T`.
#[derive(Debug, Clone, Copy)]
pub(crate) struct NoRegression;
impl<T> AsRegression<T> for NoRegression {
/// Always returns `None`, regardless of the benchmark.
fn as_regression(_benchmark: &T) -> Option<&dyn Regression> {
None
}
}
/// Marker for benchmarks that implement `super::Regression`.
#[derive(Debug, Clone, Copy)]
pub(crate) struct WithRegression;

impl<T> AsRegression<T> for WithRegression
where
    T: super::Regression,
{
    /// Always returns `Some`, exposing `benchmark` through the type-erased [`Regression`]
    /// interface.
    fn as_regression(benchmark: &T) -> Option<&dyn Regression> {
        // Spell out the unsizing coercion that the original left implicit.
        let erased: &dyn Regression = benchmark;
        Some(erased)
    }
}
/// Blanket adapter: every typed `super::Regression` also implements the type-erased
/// [`Regression`] trait of this module.
impl<T> Regression for T
where
    T: super::Regression,
{
    /// Returns the `INSTANCE` constant of `crate::input::Wrapper<T::Tolerances>`.
    fn tolerance(&self) -> &dyn crate::input::DynInput {
        &crate::input::Wrapper::<T::Tolerances>::INSTANCE
    }

    /// Returns the tag of `T::Input`.
    fn input_tag(&self) -> &'static str {
        T::Input::tag()
    }

    /// Recovers the concrete types behind the erased arguments, runs the typed check and
    /// converts its payload into [`Checked`] form.
    ///
    /// # Errors
    ///
    /// Fails if `tolerance` or `input` is not of the expected concrete type, if `before`
    /// or `after` does not deserialize into `T::Output`, if the typed check itself
    /// fails, or if the resulting payload cannot be serialized.
    fn check(
        &self,
        tolerance: &Any,
        input: &Any,
        before: &serde_json::Value,
        after: &serde_json::Value,
    ) -> anyhow::Result<CheckedPassFail> {
        // Step 1: downcast the erased tolerance and input, in that order.
        let typed_tolerance = match tolerance.downcast_ref::<T::Tolerances>() {
            Some(typed) => Ok(typed),
            None => Err(BadDownCast::new(T::Tolerances::tag(), tolerance.tag())),
        }
        .context("failed to obtain tolerance")?;

        let typed_input = match input.downcast_ref::<T::Input>() {
            Some(typed) => Ok(typed),
            None => Err(BadDownCast::new(T::Input::tag(), input.tag())),
        }
        .context("failed to obtain input")?;

        // Step 2: rebuild both results from JSON, "before" first.
        let old = T::Output::deserialize(before)
            .map_err(|source| DeserializationError::new(Kind::Before, source))?;
        let new = T::Output::deserialize(after)
            .map_err(|source| DeserializationError::new(Kind::After, source))?;

        // Step 3: run the typed check. The fully qualified call names the typed trait
        // explicitly, since the erased trait in this module has a method of the same name.
        let verdict =
            <T as super::Regression>::check(self, typed_tolerance, typed_input, &old, &new)?;

        Ok(match verdict {
            PassFail::Pass(payload) => PassFail::Pass(Checked::new(payload)?),
            PassFail::Fail(payload) => PassFail::Fail(Checked::new(payload)?),
        })
    }
}
/// Pairs a benchmark `T` with a regression strategy marker `R`.
///
/// `R` defaults to `NoRegression`.
#[derive(Debug, Clone, Copy)]
pub(crate) struct Wrapper<T, R = NoRegression> {
/// The wrapped benchmark.
benchmark: T,
/// The strategy marker; stored but never read in this file (hence the underscore).
_regression: R,
}
impl<T, R> Wrapper<T, R> {
pub(crate) const fn new(benchmark: T, regression: R) -> Self {
Self {
benchmark,
_regression: regression,
}
}
}
/// Failure score returned by `Wrapper::try_match` when the dynamic input is not of the
/// wrapped benchmark's input type.
/// NOTE(review): the reason for the specific value 10_000 is not visible here - confirm
/// against the dispatcher's scoring before changing it.
const MATCH_FAIL: FailureScore = FailureScore(10_000);
/// Adapts a typed `super::Benchmark` to the type-erased [`Benchmark`] interface.
///
/// Every method that receives an input first tries to downcast it to `T::Input`; what
/// happens when the downcast fails depends on the method.
impl<T, R> Benchmark for Wrapper<T, R>
where
    T: super::Benchmark,
    R: AsRegression<T>,
{
    /// Forwards to the wrapped benchmark's `try_match`, or returns `MATCH_FAIL` when
    /// `input` is not a `T::Input`.
    fn try_match(&self, input: &Any) -> Result<MatchScore, FailureScore> {
        match input.downcast_ref::<T::Input>() {
            Some(typed) => self.benchmark.try_match(typed),
            None => Err(MATCH_FAIL),
        }
    }

    /// Writes the benchmark description to `f`.
    ///
    /// * Without an input: writes a line with the input tag, then the wrapped
    ///   benchmark's general description.
    /// * With an input of the right type: forwards to the wrapped benchmark.
    /// * With an input of the wrong type: writes the expected and actual tags.
    fn description(
        &self,
        f: &mut std::fmt::Formatter<'_>,
        input: Option<&Any>,
    ) -> std::fmt::Result {
        let Some(erased) = input else {
            writeln!(f, "tag \"{}\"", <T::Input as Input>::tag())?;
            return self.benchmark.description(f, None);
        };

        let Some(typed) = erased.downcast_ref::<T::Input>() else {
            return write!(
                f,
                "expected tag \"{}\" - instead got \"{}\"",
                T::Input::tag(),
                erased.tag(),
            );
        };

        self.benchmark.description(f, Some(typed))
    }

    /// Runs the wrapped benchmark and converts its result into a `serde_json::Value`.
    ///
    /// # Errors
    ///
    /// Fails with a bad-downcast error if `input` is not a `T::Input`, and otherwise
    /// propagates any error from the benchmark or from serializing its result.
    fn run(
        &self,
        input: &Any,
        checkpoint: Checkpoint<'_>,
        output: &mut dyn Output,
    ) -> anyhow::Result<serde_json::Value> {
        let Some(typed) = input.downcast_ref::<T::Input>() else {
            return Err(BadDownCast::new(T::Input::tag(), input.tag()).into());
        };

        let result = self.benchmark.run(typed, checkpoint, output)?;
        let json = serde_json::to_value(result)?;
        Ok(json)
    }

    /// Delegates to the strategy marker `R` to expose the optional regression view.
    fn as_regression(&self) -> Option<&dyn Regression> {
        R::as_regression(&self.benchmark)
    }
}
/// Error produced when a dynamically typed value is not of the expected concrete type.
///
/// The rendered message is prefixed with "INTERNAL ERROR" and quotes both tags.
#[derive(Debug, Clone, Copy, Error)]
#[error(
"INTERNAL ERROR: bad downcast - expected \"{}\" but got \"{}\"",
self.expected,
self.got
)]
struct BadDownCast {
/// Tag of the type the downcast was targeting.
expected: &'static str,
/// Tag reported by the value that was actually supplied.
got: &'static str,
}
impl BadDownCast {
fn new(expected: &'static str, got: &'static str) -> Self {
Self { expected, got }
}
}
/// Error produced when stored results cannot be deserialized into a benchmark's output
/// type.
///
/// The rendered message names which of the two result sets ("before" or "after") failed.
#[derive(Debug, Error)]
#[error(
"the \"{}\" results do not match the output schema expected by this benchmark",
self.kind
)]
struct DeserializationError {
/// Which of the two result sets failed to deserialize.
kind: Kind,
/// The underlying `serde_json` error; `thiserror` exposes a field named `source` as
/// the error's source.
source: serde_json::Error,
}
impl DeserializationError {
fn new(kind: Kind, source: serde_json::Error) -> Self {
Self { kind, source }
}
}
/// Identifies which of the two compared result sets is being referred to.
#[derive(Debug, Clone, Copy)]
enum Kind {
    /// The results recorded before the change.
    Before,
    /// The results recorded after the change.
    After,
}

/// Renders the variant as its lowercase name: `before` or `after`.
impl std::fmt::Display for Kind {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(match self {
            Self::Before => "before",
            Self::After => "after",
        })
    }
}
}