use std::io::Write;
use serde::{Deserialize, Serialize};
use crate::{
benchmark::{PassFail, Regression},
dispatcher::{Description, DispatchRule, FailureScore, MatchScore},
utils::datatype::{DataType, Type},
Any, Benchmark, CheckDeserialization, Checker, Checkpoint, Input, Output,
};
/// Test input that pairs a [`DataType`] with a dimension.
///
/// The two boolean fields exist to exercise the deserialization-check path:
/// `error_when_checked` forces `check_deserialization` to fail, and `checked`
/// records that the check ran successfully.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub(crate) struct TypeInput {
/// Data type the benchmarks in this file match against.
pub(super) data_type: DataType,
/// Dimension; `ExactTypeBench<T, N>` only matches when this equals `N`.
pub(super) dim: usize,
/// When `true`, `check_deserialization` returns an error instead of succeeding.
pub(super) error_when_checked: bool,
/// Set to `true` by a successful `check_deserialization`. Skipped by serde, so it
/// is never written out and is not read from serialized input.
#[serde(skip)]
pub(crate) checked: bool,
}
impl TypeInput {
    /// Creates an input from its three configurable fields.
    ///
    /// `checked` always starts out `false`; it is only flipped by a successful
    /// `check_deserialization`.
    pub(crate) fn new(data_type: DataType, dim: usize, error_when_checked: bool) -> Self {
        Self {
            checked: false,
            error_when_checked,
            dim,
            data_type,
        }
    }

    /// Returns the string form of this input's data type.
    fn run(&self) -> &'static str {
        let Self { data_type, .. } = self;
        data_type.as_str()
    }
}
impl Input for TypeInput {
    /// Tag under which this input type is identified.
    fn tag() -> &'static str {
        "test-input-types"
    }

    /// Deserializes a [`TypeInput`] from JSON and hands it to `checker.any`.
    ///
    /// # Errors
    ///
    /// Fails if `serialized` does not deserialize into a [`TypeInput`], or if
    /// `checker.any` itself returns an error.
    fn try_deserialize(
        serialized: &serde_json::Value,
        checker: &mut Checker,
    ) -> anyhow::Result<Any> {
        let parsed = Self::deserialize(serialized)?;
        checker.any(parsed)
    }

    /// Produces a sample JSON value: `Float32`, dimension 128, no forced error.
    fn example() -> anyhow::Result<serde_json::Value> {
        let sample = Self::new(DataType::Float32, 128, false);
        let value = serde_json::to_value(sample)?;
        Ok(value)
    }
}
impl CheckDeserialization for TypeInput {
    /// Marks the input as checked, unless it was configured to fail.
    ///
    /// # Errors
    ///
    /// Returns an error (leaving `checked` untouched) when `error_when_checked`
    /// is set.
    fn check_deserialization(&mut self, _checker: &mut Checker) -> anyhow::Result<()> {
        if self.error_when_checked {
            return Err(anyhow::anyhow!("test input erroring when checked"));
        }

        self.checked = true;
        Ok(())
    }
}
/// Test tolerance input used as the `Tolerances` type of the regression checks in
/// this file.
///
/// It carries no tolerance values; its fields only drive the
/// deserialization-check path.
#[derive(Debug, Serialize, Deserialize)]
pub(super) struct Tolerance {
/// When `true`, `check_deserialization` returns an error instead of succeeding.
pub(super) error_when_checked: bool,
/// Set to `true` by a successful `check_deserialization`. Skipped by serde, so it
/// is never written out and is not read from serialized input.
#[serde(skip)]
pub(crate) checked: bool,
}
impl Input for Tolerance {
    /// Tag under which this tolerance type is identified.
    fn tag() -> &'static str {
        "test-input-types-tolerance"
    }

    /// Deserializes a [`Tolerance`] from JSON and hands it to `checker.any`.
    ///
    /// # Errors
    ///
    /// Fails if `serialized` does not deserialize into a [`Tolerance`], or if
    /// `checker.any` itself returns an error.
    fn try_deserialize(
        serialized: &serde_json::Value,
        checker: &mut Checker,
    ) -> anyhow::Result<Any> {
        let parsed = Self::deserialize(serialized)?;
        checker.any(parsed)
    }

    /// Produces a sample JSON value with no forced error.
    fn example() -> anyhow::Result<serde_json::Value> {
        let sample = Self {
            checked: false,
            error_when_checked: false,
        };
        let value = serde_json::to_value(sample)?;
        Ok(value)
    }
}
impl CheckDeserialization for Tolerance {
    /// Marks the tolerance as checked, unless it was configured to fail.
    ///
    /// # Errors
    ///
    /// Returns an error (leaving `checked` untouched) when `error_when_checked`
    /// is set.
    fn check_deserialization(&mut self, _checker: &mut Checker) -> anyhow::Result<()> {
        if self.error_when_checked {
            return Err(anyhow::anyhow!("test input erroring when checked"));
        }

        self.checked = true;
        Ok(())
    }
}
/// Test benchmark parameterized on an element type `T`.
///
/// Holds only a `PhantomData<T>`; `T` is used purely at the type level, to select
/// the `Type<T>` dispatch rule the benchmark matches inputs against.
#[derive(Debug)]
pub(super) struct TypeBench<T>(std::marker::PhantomData<T>);
impl<T> Benchmark for TypeBench<T>
where
T: 'static,
Type<T>: DispatchRule<DataType, Error: std::error::Error + Send + Sync + 'static>,
{
type Input = TypeInput;
type Output = String;
fn try_match(input: &TypeInput) -> Result<MatchScore, FailureScore> {
Type::<T>::try_match(&input.data_type).map(|m| MatchScore(m.0 + 10))
}
fn description(f: &mut std::fmt::Formatter<'_>, input: Option<&TypeInput>) -> std::fmt::Result {
Type::<T>::description(f, input.map(|i| &i.data_type))
}
fn run(
input: &TypeInput,
checkpoint: Checkpoint<'_>,
mut output: &mut dyn Output,
) -> anyhow::Result<Self::Output> {
let result = input.run().to_string();
write!(output, "hello: {}", result)?;
checkpoint.checkpoint(&result)?;
Ok(result)
}
}
impl<T> Regression for TypeBench<T>
where
    T: 'static,
    Type<T>: DispatchRule<DataType, Error: std::error::Error + Send + Sync + 'static>,
{
    type Tolerances = Tolerance;
    type Pass = DataType;
    type Fail = DataType;

    /// Verifies that both recorded outputs equal the input's type name and reports
    /// a pass carrying the input's data type.
    ///
    /// The tolerance argument is not consulted.
    ///
    /// # Panics
    ///
    /// Panics if `before` or `after` differs from the expected type name.
    fn check(
        _tolerance: &Tolerance,
        input: &TypeInput,
        before: &String,
        after: &String,
    ) -> anyhow::Result<PassFail<Self::Pass, Self::Fail>> {
        let expected = input.run();
        // `before` is validated first, then `after`.
        for observed in [before, after] {
            assert_eq!(*observed, expected);
        }

        let data_type = input.data_type;
        Ok(PassFail::Pass(data_type))
    }
}
/// Test benchmark parameterized on an element type `T` and an exact dimension `N`.
///
/// Holds only a `PhantomData<T>`; both parameters are used purely at the type
/// level. Its `try_match` additionally requires the input's `dim` to equal `N`.
#[derive(Debug)]
pub(super) struct ExactTypeBench<T, const N: usize>(std::marker::PhantomData<T>);
impl<T, const N: usize> Benchmark for ExactTypeBench<T, N>
where
    T: 'static,
    Type<T>: DispatchRule<DataType, Error: std::error::Error + Send + Sync + 'static>,
{
    type Input = TypeInput;
    type Output = String;

    /// Matches only inputs whose `dim` equals `N`, then defers to the `Type<T>`
    /// rule for the data type.
    ///
    /// A dimension mismatch yields `FailureScore(1000)` without consulting the
    /// data type.
    fn try_match(input: &TypeInput) -> Result<MatchScore, FailureScore> {
        if input.dim != N {
            return Err(FailureScore(1000));
        }
        Type::<T>::try_match(&input.data_type)
    }

    /// Describes this benchmark, or explains how a concrete input matched.
    ///
    /// With no input, prints the `Type<T>` description followed by `dim=N`. With an
    /// input, prints `successful match` or the reason(s) for the mismatch: the
    /// data-type error, the dimension mismatch, or both joined by `; `.
    fn description(f: &mut std::fmt::Formatter<'_>, input: Option<&TypeInput>) -> std::fmt::Result {
        let Some(input) = input else {
            return write!(f, "{}, dim={}", Description::<DataType, Type<T>>::new(), N);
        };

        let type_verdict = Type::<T>::try_match_verbose(&input.data_type);
        let dim_matches = input.dim == N;

        match type_verdict {
            Ok(_) if dim_matches => f.write_str("successful match"),
            Ok(_) => write!(f, "expected dim={}, but found dim={}", N, input.dim),
            Err(err) if dim_matches => write!(f, "{}", err),
            Err(err) => write!(
                f,
                "{}; expected dim={}, but found dim={}",
                err, N, input.dim
            ),
        }
    }

    /// Writes `hello<N>: <type name>` to `output`, checkpoints that same string,
    /// and returns it.
    ///
    /// # Errors
    ///
    /// Propagates failures from writing to `output` or from the checkpoint.
    fn run(
        input: &TypeInput,
        checkpoint: Checkpoint<'_>,
        mut output: &mut dyn Output,
    ) -> anyhow::Result<Self::Output> {
        let type_name = input.data_type.as_str();
        let greeting = format!("hello<{}>: {}", N, type_name);
        write!(output, "{}", greeting)?;
        checkpoint.checkpoint(&greeting)?;
        Ok(greeting)
    }
}
impl<T, const N: usize> Regression for ExactTypeBench<T, N>
where
    T: 'static,
    Type<T>: DispatchRule<DataType, Error: std::error::Error + Send + Sync + 'static>,
{
    type Tolerances = Tolerance;
    type Pass = String;
    type Fail = String;

    /// Verifies that both recorded outputs equal `hello<N>: <type name>` and
    /// reports a pass with a summary of the matched dimension and type.
    ///
    /// The tolerance argument is not consulted.
    ///
    /// # Panics
    ///
    /// Panics if `before` or `after` differs from the expected string.
    fn check(
        _tolerance: &Tolerance,
        input: &TypeInput,
        before: &String,
        after: &String,
    ) -> anyhow::Result<PassFail<Self::Pass, Self::Fail>> {
        let type_name = input.data_type.as_str();
        let expected = format!("hello<{}>: {}", N, type_name);
        // `before` is validated first, then `after`.
        for observed in [before, after] {
            assert_eq!(*observed, expected);
        }

        let summary = format!("exact match dim={} type={}", N, input.data_type);
        Ok(PassFail::Pass(summary))
    }
}