use crate::spec::types::ParityFailure;
use std::collections::BTreeMap;
#[cfg(test)]
mod tests {
    //! Unit tests for `generator_name` and for the summaries produced by
    //! `ConsoleReporter` and `MetricsReporter`.

    use super::*;

    /// A label of the form `generator/case` is expected to yield the part
    /// before the slash.
    #[test]
    fn generator_name_extracts_prefix() {
        let name = generator_name("edge_cases/max_value");
        assert_eq!(name, "edge_cases");
    }

    /// A label without any slash is expected to map to `"unknown"`.
    #[test]
    fn generator_name_no_slash_returns_unknown() {
        let name = generator_name("bare_label");
        assert_eq!(name, "unknown");
    }

    /// With several slashes, only the segment before the first one is expected.
    #[test]
    fn generator_name_multiple_slashes_takes_first() {
        let name = generator_name("a/b/c");
        assert_eq!(name, "a");
    }

    /// Two recorded passes and no failures must show up in the console summary.
    #[test]
    fn console_reporter_counts_passes() {
        let mut console = ConsoleReporter::default();

        console.on_op_start("test.op", 3);
        for case in ["case1", "case2"] {
            console.on_pass("test.op", case);
        }
        console.on_op_done("test.op", 2, 0);

        let report = console.summary();
        for expected in ["2 passed", "0 failed"] {
            assert!(report.contains(expected), "got: {report}");
        }
    }

    /// A single recorded failure and no passes must show up in the console summary.
    #[test]
    fn console_reporter_counts_failures() {
        let mut console = ConsoleReporter::default();
        console.on_op_start("test.op", 1);

        // GPU and CPU outputs deliberately differ to model a parity mismatch.
        let mismatch = ParityFailure {
            op_id: String::from("test.op"),
            generator: String::from("edge_cases"),
            input_label: String::from("max"),
            input: vec![0xFF; 4],
            gpu_output: vec![0x00; 4],
            cpu_output: vec![0xFF; 4],
            message: String::from("mismatch"),
            spec_version: 1,
            workgroup_size: 1,
        };
        console.on_fail(&mismatch);
        console.on_op_done("test.op", 0, 1);

        let report = console.summary();
        for expected in ["0 passed", "1 failed"] {
            assert!(report.contains(expected), "got: {report}");
        }
    }

    /// The metrics summary must contain one pass/fail line per operation.
    #[test]
    fn metrics_reporter_tracks_per_op() {
        let mut metrics = MetricsReporter::default();

        // First operation: two passing cases.
        metrics.on_op_start("op.a", 2);
        for label in ["gen/case1", "gen/case2"] {
            metrics.on_pass("op.a", label);
        }
        metrics.on_op_done("op.a", 2, 0);

        // Second operation: a single failing case with empty payloads.
        metrics.on_op_start("op.b", 1);
        let mismatch = ParityFailure {
            op_id: String::from("op.b"),
            generator: String::from("edge_cases"),
            input_label: String::from("x"),
            input: Vec::new(),
            gpu_output: Vec::new(),
            cpu_output: Vec::new(),
            message: String::from("fail"),
            spec_version: 1,
            workgroup_size: 1,
        };
        metrics.on_fail(&mismatch);
        metrics.on_op_done("op.b", 0, 1);

        let report = metrics.summary();
        for expected in ["op.a: 2 passed, 0 failed", "op.b: 0 passed, 1 failed"] {
            assert!(report.contains(expected), "got: {report}");
        }
    }

    /// Passes must be tallied per generator, keyed by the label prefix.
    #[test]
    fn metrics_reporter_tracks_per_generator() {
        let mut metrics = MetricsReporter::default();

        for label in ["random/case1", "random/case2", "edge_cases/max"] {
            metrics.on_pass("op.a", label);
        }

        let report = metrics.summary();
        for expected in ["generator random: 2 passed", "generator edge_cases: 1 passed"] {
            assert!(report.contains(expected), "got: {report}");
        }
    }
}