use crate::spec::types::ParityFailure;
use std::collections::BTreeMap;
/// Sink for conformance-run progress events.
///
/// The tests in this file drive a reporter as: `on_op_start`, then one
/// `on_pass`/`on_fail` per input, then `on_op_done`, and finally `summary`.
/// NOTE(review): this call order is inferred from the tests — confirm against the runner.
///
/// Implementors must be `Send + Sync`; presumably so the runner can share a
/// reporter across threads — TODO confirm.
pub trait Reporter: Send + Sync {
    /// Called when `op_id` begins, with the number of inputs it will run.
    fn on_op_start(&mut self, op_id: &str, input_count: usize);
    /// Called for each input of `op_id` that passed. `input_label` names the
    /// input; labels of the form `"generator/case"` carry the generator prefix.
    fn on_pass(&mut self, op_id: &str, input_label: &str);
    /// Called for each input that failed. `failure` carries the op id, generator,
    /// input label, raw input, GPU/CPU outputs and a message.
    fn on_fail(&mut self, failure: &ParityFailure);
    /// Called when `op_id` finishes, with that op's pass and fail totals.
    fn on_op_done(&mut self, op_id: &str, pass_count: usize, fail_count: usize);
    /// Renders a human-readable summary of everything reported so far.
    fn summary(&self) -> String;
}
/// Reporter that prints per-op progress and each failure to stderr, while
/// keeping running pass/fail totals across all ops for `summary`.
#[derive(Debug, Default)]
pub struct ConsoleReporter {
    /// Number of passing inputs reported so far, across all ops.
    pass_count: usize,
    /// Number of failing inputs reported so far, across all ops.
    fail_count: usize,
}
/// Reporter that prints nothing while running and aggregates pass/fail counts
/// per op and per input generator, rendering them only in `summary`.
#[derive(Debug, Default)]
pub struct MetricsReporter {
    /// Tallies keyed by op id; `BTreeMap` keeps the summary ordered by key.
    per_op: BTreeMap<String, Counts>,
    /// Tallies keyed by generator name.
    per_generator: BTreeMap<String, Counts>,
}

/// Pass/fail tally for a single op or generator.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
struct Counts {
    pass: usize,
    fail: usize,
}
impl Reporter for ConsoleReporter {
fn on_op_start(&mut self, op_id: &str, input_count: usize) {
eprintln!("vyre-conform: {op_id}: running {input_count} inputs");
}
fn on_pass(&mut self, _op_id: &str, _input_label: &str) {
self.pass_count += 1;
}
fn on_fail(&mut self, failure: &ParityFailure) {
self.fail_count += 1;
eprintln!(
"vyre-conform: FAIL {} v{} wg{} {}: {}",
failure.op_id,
failure.spec_version,
failure.workgroup_size,
failure.input_label,
failure.message
);
}
fn on_op_done(&mut self, op_id: &str, pass_count: usize, fail_count: usize) {
eprintln!("vyre-conform: {op_id}: {pass_count} passed, {fail_count} failed");
}
fn summary(&self) -> String {
format!("{} passed, {} failed", self.pass_count, self.fail_count)
}
}
/// Aggregating reporter: prints nothing and keeps per-op and per-generator
/// tallies. `summary` lists all op lines (sorted by op id), then all generator
/// lines (sorted by generator name), joined with newlines.
impl Reporter for MetricsReporter {
    fn on_op_start(&mut self, op_id: &str, _input_count: usize) {
        // Register the op so it shows up in the summary even with zero results.
        self.per_op.entry(op_id.to_owned()).or_default();
    }

    fn on_pass(&mut self, op_id: &str, input_label: &str) {
        // NOTE(review): passes are bucketed by the label's prefix before '/',
        // while failures use `failure.generator`. A label without '/' is
        // counted under "unknown" even when its generator is known elsewhere.
        let generator = generator_name(input_label);

        let op_counts = self.per_op.entry(op_id.to_owned()).or_default();
        op_counts.pass += 1;

        let generator_counts = self.per_generator.entry(generator.to_owned()).or_default();
        generator_counts.pass += 1;
    }

    fn on_fail(&mut self, failure: &ParityFailure) {
        let op_counts = self.per_op.entry(failure.op_id.clone()).or_default();
        op_counts.fail += 1;

        let generator_counts = self
            .per_generator
            .entry(failure.generator.clone())
            .or_default();
        generator_counts.fail += 1;
    }

    fn on_op_done(&mut self, _op_id: &str, _pass_count: usize, _fail_count: usize) {
        // Totals are already accumulated incrementally; nothing to do here.
    }

    fn summary(&self) -> String {
        let op_lines = self.per_op.iter().map(|(op, counts)| {
            format!("{op}: {} passed, {} failed", counts.pass, counts.fail)
        });
        let generator_lines = self.per_generator.iter().map(|(generator, counts)| {
            format!(
                "generator {generator}: {} passed, {} failed",
                counts.pass, counts.fail
            )
        });
        op_lines
            .chain(generator_lines)
            .collect::<Vec<_>>()
            .join("\n")
    }
}
/// Extracts the generator prefix from an input label: everything before the
/// first `/`. A label with no `/` yields `"unknown"`; a leading `/` yields `""`.
fn generator_name(input_label: &str) -> &str {
    input_label
        .split_once('/')
        .map_or("unknown", |(generator, _rest)| generator)
}
#[cfg(test)]
mod tests {
    // `super::*` already brings in the reporters, `generator_name`, the
    // `Reporter` trait and the parent's `ParityFailure` import.
    use super::*;

    /// Builds a failure record for tests. The byte payloads are left empty
    /// because neither reporter reads `input`, `gpu_output` or `cpu_output`.
    fn failure(op_id: &str, generator: &str, input_label: &str, message: &str) -> ParityFailure {
        ParityFailure {
            op_id: op_id.to_string(),
            generator: generator.to_string(),
            input_label: input_label.to_string(),
            input: vec![],
            gpu_output: vec![],
            cpu_output: vec![],
            message: message.to_string(),
            spec_version: 1,
            workgroup_size: 1,
        }
    }

    #[test]
    fn generator_name_extracts_prefix() {
        assert_eq!(generator_name("edge_cases/max_value"), "edge_cases");
    }

    #[test]
    fn generator_name_no_slash_returns_unknown() {
        assert_eq!(generator_name("bare_label"), "unknown");
    }

    #[test]
    fn generator_name_multiple_slashes_takes_first() {
        assert_eq!(generator_name("a/b/c"), "a");
    }

    #[test]
    fn console_reporter_counts_passes() {
        let mut reporter = ConsoleReporter::default();
        reporter.on_op_start("test.op", 3);
        reporter.on_pass("test.op", "case1");
        reporter.on_pass("test.op", "case2");
        reporter.on_op_done("test.op", 2, 0);
        let summary = reporter.summary();
        assert!(summary.contains("2 passed"), "got: {summary}");
        assert!(summary.contains("0 failed"), "got: {summary}");
    }

    #[test]
    fn console_reporter_counts_failures() {
        let mut reporter = ConsoleReporter::default();
        reporter.on_op_start("test.op", 1);
        reporter.on_fail(&failure("test.op", "edge_cases", "max", "mismatch"));
        reporter.on_op_done("test.op", 0, 1);
        let summary = reporter.summary();
        assert!(summary.contains("0 passed"), "got: {summary}");
        assert!(summary.contains("1 failed"), "got: {summary}");
    }

    #[test]
    fn metrics_reporter_tracks_per_op() {
        let mut reporter = MetricsReporter::default();
        reporter.on_op_start("op.a", 2);
        reporter.on_pass("op.a", "gen/case1");
        reporter.on_pass("op.a", "gen/case2");
        reporter.on_op_done("op.a", 2, 0);
        reporter.on_op_start("op.b", 1);
        reporter.on_fail(&failure("op.b", "edge_cases", "x", "fail"));
        reporter.on_op_done("op.b", 0, 1);
        let summary = reporter.summary();
        assert!(
            summary.contains("op.a: 2 passed, 0 failed"),
            "got: {summary}"
        );
        assert!(
            summary.contains("op.b: 0 passed, 1 failed"),
            "got: {summary}"
        );
    }

    #[test]
    fn metrics_reporter_tracks_per_generator() {
        let mut reporter = MetricsReporter::default();
        reporter.on_pass("op.a", "random/case1");
        reporter.on_pass("op.a", "random/case2");
        reporter.on_pass("op.a", "edge_cases/max");
        let summary = reporter.summary();
        assert!(
            summary.contains("generator random: 2 passed"),
            "got: {summary}"
        );
        assert!(
            summary.contains("generator edge_cases: 1 passed"),
            "got: {summary}"
        );
    }

    #[test]
    fn metrics_reporter_attributes_failures_to_failure_generator() {
        let mut reporter = MetricsReporter::default();
        // The label has no '/', but failures use `failure.generator`, not the label.
        reporter.on_fail(&failure("op.a", "edge_cases", "no_slash_label", "fail"));
        let summary = reporter.summary();
        assert!(
            summary.contains("generator edge_cases: 0 passed, 1 failed"),
            "got: {summary}"
        );
        assert!(!summary.contains("generator unknown"), "got: {summary}");
    }

    #[test]
    fn metrics_reporter_unlabeled_pass_goes_to_unknown() {
        let mut reporter = MetricsReporter::default();
        reporter.on_pass("op.a", "bare");
        let summary = reporter.summary();
        assert!(
            summary.contains("generator unknown: 1 passed, 0 failed"),
            "got: {summary}"
        );
    }

    #[test]
    fn metrics_reporter_lists_started_op_without_results() {
        let mut reporter = MetricsReporter::default();
        reporter.on_op_start("op.idle", 0);
        assert_eq!(reporter.summary(), "op.idle: 0 passed, 0 failed");
    }

    #[test]
    fn metrics_reporter_summary_lists_ops_before_generators() {
        let mut reporter = MetricsReporter::default();
        reporter.on_pass("op.a", "gen/x");
        assert_eq!(
            reporter.summary(),
            "op.a: 1 passed, 0 failed\ngenerator gen: 1 passed, 0 failed"
        );
    }

    #[test]
    fn metrics_reporter_empty_summary_is_empty() {
        assert_eq!(MetricsReporter::default().summary(), "");
    }
}