use std::{collections::HashMap, io::Write, path::Path, rc::Rc};
use anyhow::Context;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use crate::{
benchmark::{internal::CheckedPassFail, PassFail},
internal::load_from_disk,
jobs, registry, result, Any, Checker,
};
/// One fully resolved check: a concrete benchmark, the tolerance to apply,
/// and the parsed input it will be evaluated against.
struct Check<'a> {
    // Benchmark selected as the best `try_match` for this input
    // (see `Checks::match_all`).
    regression: registry::RegressionBenchmark<'a>,
    // Parsed tolerance value; `Rc` because it is cloned out of the shared
    // `ParsedInner` entry.
    tolerance: Rc<Any>,
    // Parsed input job for this check.
    input: Any,
}
/// All matched checks, in input order; produced by `Checks::new`.
pub(crate) struct Checks<'a> {
    checks: Vec<Check<'a>>,
}
impl<'a> Checks<'a> {
    /// Load and parse the tolerance file and the input file, then match
    /// every tolerance entry to exactly one input job.
    ///
    /// # Errors
    /// Fails when either file cannot be loaded or parsed, when a tolerance
    /// tag is unknown, or when the tolerance-to-input match is not
    /// one-to-one (see `check_matches`).
    pub(crate) fn new(
        tolerances: &Path,
        input_file: &Path,
        // Fixed: was mojibake `®istry::Inputs` (HTML-entity corruption of
        // `&reg` -> `®`), which does not compile.
        inputs: &registry::Inputs,
        entries: &'a HashMap<&'static str, registry::RegisteredTolerance<'a>>,
    ) -> anyhow::Result<Self> {
        let partial = jobs::Partial::load(input_file)?;
        let inputs = jobs::Jobs::parse(&partial, inputs)?;
        let parsed = Raw::load(tolerances)?.parse(entries)?;
        Self::match_all(parsed, partial, inputs)
    }
    /// Pair every check with its "before" and "after" raw results.
    ///
    /// # Errors
    /// Fails when either results file cannot be loaded, or when either file
    /// does not contain exactly one entry per check.
    pub(crate) fn jobs(self, before: &Path, after: &Path) -> anyhow::Result<Jobs<'a>> {
        let (before_path, after_path) = (before, after);
        let before = result::RawResult::load(before_path)?;
        let after = result::RawResult::load(after_path)?;
        let expected = self.checks.len();
        anyhow::ensure!(
            before.len() == expected,
            "\"before\" file \"{}\" has {} entries but expected {}",
            before_path.display(),
            before.len(),
            expected,
        );
        anyhow::ensure!(
            after.len() == expected,
            "\"after\" file \"{}\" has {} entries but expected {}",
            after_path.display(),
            after.len(),
            expected,
        );
        // Results are positional: entry i of each results file belongs to
        // check i.
        let jobs = std::iter::zip(self.checks, std::iter::zip(before, after))
            .map(|(check, (before, after))| {
                let Check {
                    regression,
                    tolerance,
                    input,
                } = check;
                Job {
                    regression,
                    tolerance,
                    input,
                    before,
                    after,
                }
            })
            .collect();
        Ok(Jobs { jobs })
    }
    /// Build the bipartite tolerance/input match, validate it, and resolve
    /// each input to the benchmark that matches it best.
    fn match_all(
        parsed: Parsed<'a>,
        partial: jobs::Partial,
        inputs: jobs::Jobs,
    ) -> anyhow::Result<Self> {
        debug_assert_eq!(
            partial.jobs().len(),
            inputs.jobs().len(),
            "expected \"inputs\" to be the parsed representation of \"partial\""
        );
        // Adjacency lists in both directions: tolerance i matches input j
        // when the tags agree and the tolerance's input pattern is a subset
        // of the raw input content.
        let mut parsed_to_input: Vec<Vec<usize>> = vec![Vec::default(); parsed.inner.len()];
        let mut input_to_parsed: Vec<Vec<usize>> = vec![Vec::default(); inputs.jobs().len()];
        parsed.inner.iter().enumerate().for_each(|(i, t)| {
            partial.jobs().iter().enumerate().for_each(|(j, raw)| {
                if raw.tag == t.input.tag && is_subset(&raw.content, &t.input.content) {
                    parsed_to_input[i].push(j);
                    input_to_parsed[j].push(i);
                }
            })
        });
        // Rejects orphaned tolerances, uncovered inputs, and ambiguous
        // inputs; on success yields one tolerance index per input.
        let input_to_parsed = check_matches(parsed_to_input, input_to_parsed)?;
        debug_assert_eq!(input_to_parsed.len(), inputs.jobs().len());
        // `zip` takes `IntoIterator`, so the explicit `.into_iter()` the
        // original carried was redundant.
        let checks = std::iter::zip(inputs.into_inner(), input_to_parsed)
            .map(|(input, index)| {
                let inner = &parsed.inner[index];
                assert_eq!(inner.input.tag, input.tag());
                // Pick the regression with the lowest `try_match` score
                // (best match) among those that accept this input.
                let regression = inner
                    .entry
                    .regressions
                    .iter()
                    .filter_map(|r| r.try_match(&input).ok().map(|score| (*r, score)))
                    .min_by_key(|(_, score)| *score)
                    .map(|(r, _)| r)
                    .ok_or_else(|| {
                        anyhow::anyhow!(
                            "Could not match input tag \"{}\" and tolerance tag \"{}\" to \
                            a valid benchmark. This likely means file or code changes \
                            between when the input file was last used. If the normal \
                            benchmark flow succeeds, please report this issue.",
                            inner.input.tag,
                            inner.tolerance.tag(),
                        )
                    })?;
                Ok(Check {
                    regression,
                    // Cheap `Rc` clone; the parsed tolerance is shared.
                    tolerance: inner.tolerance.clone(),
                    input,
                })
            })
            .collect::<anyhow::Result<Vec<_>>>()?;
        Ok(Self { checks })
    }
}
/// One raw entry of the tolerance file: an unprocessed input descriptor
/// paired with an unprocessed tolerance descriptor.
#[derive(Debug, Serialize, Deserialize)]
pub(crate) struct RawInner {
    input: jobs::Unprocessed,
    tolerance: jobs::Unprocessed,
}
impl RawInner {
pub(crate) fn new(input: jobs::Unprocessed, tolerance: jobs::Unprocessed) -> Self {
Self { input, tolerance }
}
}
/// On-disk representation of the tolerance file: an ordered list of raw
/// input/tolerance pairs. `Default` (empty list) backs `Raw::example`.
#[derive(Debug, Default, Serialize, Deserialize)]
pub(crate) struct Raw {
    checks: Vec<RawInner>,
}
impl Raw {
    /// Deserialize a tolerance file from `path`.
    pub(crate) fn load(path: &Path) -> anyhow::Result<Self> {
        load_from_disk(path)
    }
    /// Resolve each raw check against the registered tolerance `entries`:
    /// look up its tolerance tag, verify the input tag is accepted by at
    /// least one of the entry's regressions, and deserialize the tolerance
    /// content.
    ///
    /// # Errors
    /// Fails on an unknown tolerance tag, an input tag no regression of the
    /// entry accepts, or a tolerance payload that does not deserialize; each
    /// error names the 1-based position of the offending entry.
    fn parse<'a>(
        self,
        entries: &'a HashMap<&'static str, registry::RegisteredTolerance<'a>>,
    ) -> anyhow::Result<Parsed<'a>> {
        let num_checks = self.checks.len();
        // One checker reused across all entries; re-tagged per entry below.
        // NOTE(review): the meaning of `Checker::new(vec![], None)`'s
        // arguments is not visible here — confirm against its definition.
        let mut checker = Checker::new(vec![], None);
        let inner = self
            .checks
            .into_iter()
            .enumerate()
            .map(|(i, unprocessed)| {
                // Human-friendly 1-based position for error messages.
                let context = || {
                    format!(
                        "while processing tolerance input {} of {}",
                        i.wrapping_add(1),
                        num_checks,
                    )
                };
                let entry = entries
                    .get(&*unprocessed.tolerance.tag)
                    .ok_or_else(|| {
                        anyhow::anyhow!(
                            "Unrecognized tolerance tag: \"{}\"",
                            unprocessed.tolerance.tag
                        )
                    })
                    .with_context(context)?;
                // The input tag must be one the entry's regressions accept;
                // otherwise list the valid tags in the error.
                if !entry
                    .regressions
                    .iter()
                    .any(|r| r.input_tag() == unprocessed.input.tag)
                {
                    let valid: Vec<_> = entry
                        .regressions
                        .iter()
                        .map(|pair| pair.input_tag())
                        .collect();
                    return Err(anyhow::anyhow!(
                        "input tag \"{}\" is not compatible with tolerance tag \"{}\". \
                        Valid input tags are: {:?}",
                        unprocessed.input.tag,
                        unprocessed.tolerance.tag,
                        valid,
                    ))
                    .with_context(context);
                }
                checker.set_tag(entry.tolerance.tag());
                let tolerance = entry
                    .tolerance
                    .try_deserialize(&unprocessed.tolerance.content, &mut checker)
                    .with_context(context)?;
                Ok(ParsedInner {
                    entry,
                    tolerance: Rc::new(tolerance),
                    input: unprocessed.input,
                })
            })
            .collect::<anyhow::Result<_>>()?;
        Ok(Parsed { inner })
    }
    /// Pretty-printed JSON of an empty tolerance file, usable as a template.
    pub(crate) fn example() -> String {
        #[expect(
            clippy::expect_used,
            reason = "we control the concrete struct and its serialization implementation"
        )]
        serde_json::to_string_pretty(&Self::default())
            .expect("built-in serialization should succeed")
    }
}
/// A tolerance entry after parsing: the registry entry it resolved to, the
/// deserialized tolerance value, and the raw input pattern it applies to.
#[derive(Debug)]
struct ParsedInner<'a> {
    entry: &'a registry::RegisteredTolerance<'a>,
    // `Rc` so it can be cloned cheaply into each matched `Check`.
    tolerance: Rc<Any>,
    input: jobs::Unprocessed,
}
/// All tolerance entries from the file, in file order; see `Raw::parse`.
#[derive(Debug)]
struct Parsed<'a> {
    inner: Vec<ParsedInner<'a>>,
}
/// Returns `true` when `needle` is structurally contained in `haystack`:
/// scalars must be equal kind-for-kind, every key of a needle object must be
/// present in the haystack object with a matching value, and a needle array
/// must match the haystack array element-wise as a prefix.
#[must_use]
pub(crate) fn is_subset(haystack: &Value, needle: &Value) -> bool {
    // Explicit worklist instead of recursion, so deeply nested JSON cannot
    // overflow the call stack.
    let mut pending = vec![(haystack, needle)];
    while let Some((h, n)) = pending.pop() {
        match (h, n) {
            (Value::Null, Value::Null) => {}
            (Value::Bool(a), Value::Bool(b)) if a == b => {}
            (Value::Number(a), Value::Number(b)) if a == b => {}
            (Value::String(a), Value::String(b)) if a == b => {}
            (Value::Array(a), Value::Array(b)) => {
                // The needle may be shorter than the haystack; compare it
                // element-wise against the haystack's prefix.
                if a.len() < b.len() {
                    return false;
                }
                pending.extend(a.iter().zip(b.iter()));
            }
            (Value::Object(a), Value::Object(b)) => {
                // Every needle key must exist in the haystack; extra
                // haystack keys are ignored.
                for (key, expected) in b {
                    match a.get(key) {
                        Some(actual) => pending.push((actual, expected)),
                        None => return false,
                    }
                }
            }
            // Mismatched kinds, or same-kind scalars that failed the guard.
            _ => return false,
        }
    }
    true
}
/// A single defect found while matching tolerances to inputs. Indices are
/// zero-based here; `AmbiguousMatch`'s `Display` renders them one-based.
#[derive(Debug, PartialEq)]
enum MatchProblem {
    // Tolerance entry at this index matched no input job.
    OrphanedTolerance(usize),
    // Input job at this index matched no tolerance entry.
    UncoveredInput(usize),
    // Input job at this index matched several tolerance entries (listed).
    AmbiguousInput(usize, Vec<usize>),
}
/// Error collecting every problem found by `check_matches`; `Display`
/// renders them one per line.
#[derive(Debug)]
struct AmbiguousMatch(Vec<MatchProblem>);
impl std::fmt::Display for AmbiguousMatch {
    /// Render a header line followed by one line per problem, with all
    /// indices converted to 1-based for human consumption.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("tolerance matching failed:")?;
        for problem in &self.0 {
            match problem {
                MatchProblem::OrphanedTolerance(i) => {
                    write!(f, "\n tolerance {} matched no inputs", i + 1)?;
                }
                MatchProblem::UncoveredInput(i) => {
                    write!(f, "\n input {} matched no tolerances", i + 1)?;
                }
                MatchProblem::AmbiguousInput(i, tolerances) => {
                    // Comma-separated 1-based list of every tolerance the
                    // input matched.
                    let list = tolerances
                        .iter()
                        .map(|t| (t + 1).to_string())
                        .collect::<Vec<_>>()
                        .join(", ");
                    write!(f, "\n input {} matched tolerances {}", i + 1, list)?;
                }
            }
        }
        Ok(())
    }
}
// Marker impl so `AmbiguousMatch` can propagate through `?` into
// `anyhow::Result` (see `Checks::match_all`).
impl std::error::Error for AmbiguousMatch {}
fn check_matches(
parsed_to_input: Vec<Vec<usize>>,
input_to_parsed: Vec<Vec<usize>>,
) -> Result<Vec<usize>, AmbiguousMatch> {
let mut problems = Vec::new();
for (i, matches) in parsed_to_input.iter().enumerate() {
if matches.is_empty() {
problems.push(MatchProblem::OrphanedTolerance(i));
}
}
let mut result = Vec::with_capacity(input_to_parsed.len());
for (i, matches) in input_to_parsed.into_iter().enumerate() {
match matches.len() {
0 => problems.push(MatchProblem::UncoveredInput(i)),
1 => result.push(matches[0]),
_ => problems.push(MatchProblem::AmbiguousInput(i, matches)),
}
}
if problems.is_empty() {
Ok(result)
} else {
Err(AmbiguousMatch(problems))
}
}
/// A runnable check: a matched benchmark plus its tolerance, its input, and
/// the positional "before"/"after" raw results.
#[derive(Debug)]
pub(crate) struct Job<'a> {
    regression: registry::RegressionBenchmark<'a>,
    tolerance: Rc<Any>,
    input: Any,
    before: result::RawResult,
    after: result::RawResult,
}
impl Job<'_> {
    /// Evaluate this check: ask the matched benchmark to compare the
    /// "before" and "after" results under the configured tolerance.
    fn run(&self) -> anyhow::Result<CheckedPassFail> {
        let Self {
            regression,
            tolerance,
            input,
            before,
            after,
        } = self;
        regression.check(tolerance, input, &before.results, &after.results)
    }
}
/// All runnable checks, in input order; produced by `Checks::jobs`.
#[derive(Debug)]
pub(crate) struct Jobs<'a> {
    jobs: Vec<Job<'a>>,
}
impl Jobs<'_> {
    /// Run every check, optionally write a JSON report, and print results in
    /// three passes: all errors first (then abort), all failures next (then
    /// abort), and finally all passes.
    ///
    /// # Errors
    /// Fails when serializing a tolerance or writing the report fails, when
    /// any check errored, or when any check failed its tolerance.
    pub(crate) fn run(
        &self,
        mut output: &mut dyn crate::output::Output,
        output_file: Option<&Path>,
    ) -> anyhow::Result<()> {
        // Evaluate every job up front so the JSON report can cover all of
        // them even when some error out.
        let results: Vec<_> = self.jobs.iter().map(|job| job.run()).collect();
        let check_outputs: Vec<CheckOutput<'_>> = std::iter::zip(self.jobs.iter(), results.iter())
            .map(|(job, result)| -> anyhow::Result<_> {
                let tolerance = job.tolerance.serialize()?;
                let o = match result {
                    Ok(PassFail::Pass(checked)) => CheckOutput::pass(tolerance, &checked.json),
                    Ok(PassFail::Fail(checked)) => CheckOutput::fail(tolerance, &checked.json),
                    Err(err) => CheckOutput::error(tolerance, err),
                };
                Ok(o)
            })
            .collect::<anyhow::Result<_>>()?;
        // Write the machine-readable report before any early return below,
        // so it is produced even when checks errored or failed.
        if let Some(path) = output_file {
            let json = serde_json::to_string_pretty(&check_outputs)?;
            std::fs::write(path, json)
                .with_context(|| format!("failed to write output to \"{}\"", path.display()))?;
        }
        // Pass 1: report every error, then abort if any occurred.
        let mut has_errors = false;
        for (i, result) in results.iter().enumerate() {
            if let Err(err) = result {
                let job = &self.jobs[i];
                writeln!(
                    output,
                    "Check {} of {} ({:?}) encountered an error:\n{:?}\n",
                    i + 1,
                    self.jobs.len(),
                    job.regression.name(),
                    err,
                )?;
                has_errors = true;
            }
        }
        if has_errors {
            return Err(anyhow::anyhow!("one or more checks failed with errors"));
        }
        // Pass 2: report every tolerance failure, then abort if any occurred.
        let mut has_failures = false;
        for (i, result) in results.iter().enumerate() {
            #[expect(
                clippy::expect_used,
                // Fixed typo: "ready returned" -> "already returned".
                reason = "we would have already returned if errors were present"
            )]
            let outcome = result
                .as_ref()
                .expect("no errors should be present any more");
            if let PassFail::Fail(checked) = outcome {
                let job = &self.jobs[i];
                writeln!(
                    output,
                    "Check {} of {} ({:?}) FAILED:",
                    i + 1,
                    self.jobs.len(),
                    job.regression.name(),
                )?;
                writeln!(output, "{}", checked.display)?;
                writeln!(output)?;
                has_failures = true;
            }
        }
        if has_failures {
            return Err(anyhow::anyhow!("one or more regression checks failed"));
        }
        // Pass 3: everything remaining passed; report each one.
        for (i, result) in results.iter().enumerate() {
            #[expect(
                clippy::expect_used,
                reason = "we would have already returned if errors were present"
            )]
            let outcome = result
                .as_ref()
                .expect("no errors should be present any more");
            let PassFail::Pass(checked) = outcome else {
                unreachable!("all failures handled above");
            };
            let job = &self.jobs[i];
            writeln!(
                output,
                "Check {} of {} ({:?}) PASSED:",
                i + 1,
                self.jobs.len(),
                job.regression.name(),
            )?;
            writeln!(output, "{}", checked.display)?;
            writeln!(output)?;
        }
        Ok(())
    }
}
/// JSON-serializable record of one check's outcome, written to the optional
/// report file by `Jobs::run`.
#[derive(Serialize)]
struct CheckOutput<'a> {
    // "pass", "fail", or "error".
    status: &'static str,
    tolerance: Value,
    // Present for pass/fail outcomes only.
    #[serde(skip_serializing_if = "Option::is_none")]
    result: Option<&'a Value>,
    // Present for error outcomes only; the flattened error chain.
    #[serde(skip_serializing_if = "Option::is_none")]
    error: Option<String>,
}
impl<'a> CheckOutput<'a> {
    // Shared constructor for the two outcomes that carry a result payload.
    fn with_result(status: &'static str, tolerance: Value, result: &'a Value) -> Self {
        Self {
            status,
            tolerance,
            result: Some(result),
            error: None,
        }
    }
    /// A check that passed; serialized with `"status": "pass"`.
    fn pass(tolerance: Value, result: &'a Value) -> Self {
        Self::with_result("pass", tolerance, result)
    }
    /// A check that failed its tolerance; serialized with `"status": "fail"`.
    fn fail(tolerance: Value, result: &'a Value) -> Self {
        Self::with_result("fail", tolerance, result)
    }
    /// A check that could not be evaluated; the whole error chain is
    /// flattened into one `": "`-separated string.
    fn error(tolerance: Value, err: &anyhow::Error) -> Self {
        let error = err
            .chain()
            .map(ToString::to_string)
            .collect::<Vec<_>>()
            .join(": ");
        Self {
            status: "error",
            tolerance,
            result: None,
            error: Some(error),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    // One "empty" value of each JSON kind, used to confirm that subset
    // matching never crosses kinds.
    fn empty_values() -> Vec<Value> {
        vec![
            Value::Null,
            Value::Bool(false),
            Value::Number(serde_json::Number::from_f64(0.0).unwrap()),
            Value::String(String::new()),
            Value::Array(Vec::new()),
            Value::Object(serde_json::Map::new()),
        ]
    }
    #[test]
    fn test_is_subset() {
        // Null only matches Null, never another kind's "empty" value.
        for v in empty_values() {
            if matches!(v, Value::Null) {
                assert!(is_subset(&Value::Null, &v));
            } else {
                assert!(!is_subset(&Value::Null, &v));
            }
        }
        // Scalars: equality within a kind, never across kinds.
        assert!(is_subset(&json!(true), &json!(true)));
        assert!(!is_subset(&json!(true), &json!(false)));
        assert!(!is_subset(&json!(true), &json!(0)));
        assert!(is_subset(&json!(7), &json!(7)));
        assert!(!is_subset(&json!(7), &json!(8)));
        assert!(!is_subset(&json!(7), &json!("7")));
        assert!(is_subset(&json!("abc"), &json!("abc")));
        assert!(!is_subset(&json!("abc"), &json!("def")));
        // Arrays: the needle must match the haystack element-wise as a
        // prefix; it may be shorter but not longer.
        assert!(is_subset(&json!([1, 2, 3]), &json!([])));
        assert!(is_subset(&json!([1, 2, 3]), &json!([1])));
        assert!(is_subset(&json!([1, 2, 3]), &json!([1, 2])));
        assert!(is_subset(&json!([1, 2, 3]), &json!([1, 2, 3])));
        assert!(!is_subset(&json!([1, 2]), &json!([1, 2, 3])));
        assert!(!is_subset(&json!([1, 2, 3]), &json!([1, 3])));
        // Objects: every needle key must be present with a matching value;
        // extra haystack keys are ignored, and matching recurses.
        assert!(is_subset(&json!({"a": 1, "b": 2}), &json!({"a": 1})));
        assert!(is_subset(&json!({"a": 1, "b": 2}), &json!({})));
        assert!(is_subset(
            &json!({"a": {"b": 1, "c": 2}, "d": 3}),
            &json!({"a": {"b": 1}}),
        ));
        assert!(!is_subset(&json!({"a": 1}), &json!({"a": 1, "b": 2}),));
        assert!(!is_subset(&json!({"a": {"b": 1}}), &json!({"a": {"b": 2}}),));
        // Mixed nesting: arrays of objects combine both rules.
        assert!(is_subset(
            &json!({"ops": [{"kind": "l2", "dim": 128}, {"kind": "cosine", "dim": 256}]}),
            &json!({"ops": [{"kind": "l2"}]}),
        ));
        assert!(is_subset(
            &json!({"ops": [{"kind": "l2", "dim": 128}, {"kind": "cosine", "dim": 256}]}),
            &json!({"ops": [{"kind": "l2", "dim": 128}, {"kind": "cosine"}]}),
        ));
        assert!(!is_subset(
            &json!({"ops": [{"kind": "l2", "dim": 128}, {"kind": "cosine", "dim": 256}]}),
            &json!({"ops": [{"kind": "cosine"}]}),
        ));
    }
    #[test]
    fn test_check_matches_success() {
        // A clean one-to-one match yields one tolerance index per input.
        let result = check_matches(vec![vec![0], vec![1]], vec![vec![0], vec![1]]).unwrap();
        assert_eq!(result, vec![0, 1]);
    }
    #[test]
    fn test_check_matches_reports_problems_in_stable_order() {
        // Problems must come out orphaned-tolerances-first, then per-input
        // problems, each in index order.
        let err = check_matches(
            vec![vec![0], vec![], vec![2, 3]],
            vec![vec![0], vec![], vec![2, 3]],
        )
        .unwrap_err();
        assert_eq!(
            &err.0,
            &[
                MatchProblem::OrphanedTolerance(1),
                MatchProblem::UncoveredInput(1),
                MatchProblem::AmbiguousInput(2, vec![2, 3]),
            ]
        )
    }
}