vyre-conform 0.1.0

Conformance suite for vyre backends — proves byte-identical output to CPU reference
Documentation
//! Progress reporters for streaming conformance suite results.
//!
//! Reporters receive events as the suite runs and produce summaries
//! on completion.

use crate::spec::types::ParityFailure;
use std::collections::BTreeMap;

/// Receives streaming suite progress and produces summaries.
///
/// Every event callback takes `&mut self`, so an implementation can
/// accumulate state as events arrive; [`Reporter::summary`] takes `&self`
/// and reads that state back. Implementors must be `Send + Sync`.
pub trait Reporter: Send + Sync {
    /// Called when an op begins testing with the expected input count.
    ///
    /// `op_id` identifies the op and `input_count` is the number of inputs
    /// expected to be run for it.
    fn on_op_start(&mut self, op_id: &str, input_count: usize);
    /// Called when a single input passes.
    ///
    /// `input_label` names the input that passed for the op `op_id`.
    fn on_pass(&mut self, op_id: &str, input_label: &str);
    /// Called when a single input fails with full failure context.
    ///
    /// The op and input are not passed separately; they are carried by
    /// `failure` itself.
    fn on_fail(&mut self, failure: &ParityFailure);
    /// Called when all inputs for an op are complete.
    ///
    /// `pass_count` and `fail_count` are the totals the caller reports for
    /// `op_id`.
    fn on_op_done(&mut self, op_id: &str, pass_count: usize, fail_count: usize);
    /// Produce a human-readable summary of the entire run.
    fn summary(&self) -> String;
}

/// Simple console reporter that tracks cumulative pass/fail counts.
///
/// The counters start at zero (`Default`) and are cumulative across every op
/// reported to this instance. `Debug` is derived so the reporter can be
/// inspected in logs and assertion messages.
#[derive(Debug, Default)]
pub struct ConsoleReporter {
    /// Number of `on_pass` events received so far.
    pass_count: usize,
    /// Number of `on_fail` events received so far.
    fail_count: usize,
}

/// Detailed metrics reporter that tracks per-op and per-generator statistics.
///
/// Both tables are `BTreeMap`s, so iteration (and therefore the summary built
/// from them) is ordered by key. `Debug` is derived so the collected metrics
/// can be inspected in logs and assertion messages.
#[derive(Debug, Default)]
pub struct MetricsReporter {
    /// Pass/fail tallies keyed by op id.
    per_op: BTreeMap<String, Counts>,
    /// Pass/fail tallies keyed by generator name.
    per_generator: BTreeMap<String, Counts>,
}

/// Pass/fail tally for a single key (one op or one generator).
#[derive(Debug, Default, Clone, Copy)]
struct Counts {
    /// Number of passing inputs recorded.
    pass: usize,
    /// Number of failing inputs recorded.
    fail: usize,
}

impl Reporter for ConsoleReporter {
    fn on_op_start(&mut self, op_id: &str, input_count: usize) {
        eprintln!("vyre-conform: {op_id}: running {input_count} inputs");
    }

    fn on_pass(&mut self, _op_id: &str, _input_label: &str) {
        self.pass_count += 1;
    }

    fn on_fail(&mut self, failure: &ParityFailure) {
        self.fail_count += 1;
        eprintln!(
            "vyre-conform: FAIL {} v{} wg{} {}: {}",
            failure.op_id,
            failure.spec_version,
            failure.workgroup_size,
            failure.input_label,
            failure.message
        );
    }

    fn on_op_done(&mut self, op_id: &str, pass_count: usize, fail_count: usize) {
        eprintln!("vyre-conform: {op_id}: {pass_count} passed, {fail_count} failed");
    }

    fn summary(&self) -> String {
        format!("{} passed, {} failed", self.pass_count, self.fail_count)
    }
}

impl Reporter for MetricsReporter {
    fn on_op_start(&mut self, op_id: &str, _input_count: usize) {
        self.per_op.entry(op_id.to_string()).or_default();
    }

    fn on_pass(&mut self, op_id: &str, input_label: &str) {
        self.per_op.entry(op_id.to_string()).or_default().pass += 1;
        self.per_generator
            .entry(generator_name(input_label).to_string())
            .or_default()
            .pass += 1;
    }

    fn on_fail(&mut self, failure: &ParityFailure) {
        self.per_op.entry(failure.op_id.clone()).or_default().fail += 1;
        self.per_generator
            .entry(failure.generator.clone())
            .or_default()
            .fail += 1;
    }

    fn on_op_done(&mut self, _op_id: &str, _pass_count: usize, _fail_count: usize) {}

    fn summary(&self) -> String {
        let mut lines = Vec::new();
        for (op, counts) in &self.per_op {
            lines.push(format!(
                "{op}: {} passed, {} failed",
                counts.pass, counts.fail
            ));
        }
        for (generator, counts) in &self.per_generator {
            lines.push(format!(
                "generator {generator}: {} passed, {} failed",
                counts.pass, counts.fail
            ));
        }
        lines.join("\n")
    }
}

/// Returns the generator part of an input label such as `"random/case_42"`.
///
/// The generator is everything before the first `/`. Labels that contain no
/// `/` at all yield the placeholder `"unknown"`.
fn generator_name(input_label: &str) -> &str {
    input_label
        .split_once('/')
        .map_or("unknown", |(prefix, _rest)| prefix)
}

#[cfg(test)]
mod tests {
    // The glob already brings `Reporter`, both reporters and `generator_name`
    // into scope, so they are not imported a second time by name.
    use super::*;
    use crate::spec::types::ParityFailure;

    /// Builds a failure record for `op_id` attributed to `generator`.
    ///
    /// The payload bytes, message and versions are arbitrary: the reporters
    /// under test only count failures and key them by op and generator.
    fn sample_failure(op_id: &str, generator: &str, input_label: &str) -> ParityFailure {
        ParityFailure {
            op_id: op_id.to_string(),
            generator: generator.to_string(),
            input_label: input_label.to_string(),
            input: vec![0xFF; 4],
            gpu_output: vec![0x00; 4],
            cpu_output: vec![0xFF; 4],
            message: "mismatch".to_string(),
            spec_version: 1,
            workgroup_size: 1,
        }
    }

    #[test]
    fn generator_name_extracts_prefix() {
        assert_eq!(generator_name("edge_cases/max_value"), "edge_cases");
    }

    #[test]
    fn generator_name_no_slash_returns_unknown() {
        assert_eq!(generator_name("bare_label"), "unknown");
    }

    #[test]
    fn generator_name_multiple_slashes_takes_first() {
        assert_eq!(generator_name("a/b/c"), "a");
    }

    #[test]
    fn generator_name_leading_slash_yields_empty_prefix() {
        // Only a label with no slash at all maps to "unknown".
        assert_eq!(generator_name("/case"), "");
    }

    #[test]
    fn console_reporter_counts_passes() {
        let mut reporter = ConsoleReporter::default();
        reporter.on_op_start("test.op", 3);
        reporter.on_pass("test.op", "case1");
        reporter.on_pass("test.op", "case2");
        reporter.on_op_done("test.op", 2, 0);
        let summary = reporter.summary();
        assert!(summary.contains("2 passed"), "got: {summary}");
        assert!(summary.contains("0 failed"), "got: {summary}");
    }

    #[test]
    fn console_reporter_counts_failures() {
        let mut reporter = ConsoleReporter::default();
        reporter.on_op_start("test.op", 1);
        reporter.on_fail(&sample_failure("test.op", "edge_cases", "max"));
        reporter.on_op_done("test.op", 0, 1);
        let summary = reporter.summary();
        assert!(summary.contains("0 passed"), "got: {summary}");
        assert!(summary.contains("1 failed"), "got: {summary}");
    }

    #[test]
    fn console_reporter_ignores_op_done_totals() {
        // The totals handed to `on_op_done` are logged only; the summary is
        // built from the individual pass/fail events.
        let mut reporter = ConsoleReporter::default();
        reporter.on_pass("test.op", "case1");
        reporter.on_op_done("test.op", 10, 20);
        assert_eq!(reporter.summary(), "1 passed, 0 failed");
    }

    #[test]
    fn metrics_reporter_tracks_per_op() {
        let mut reporter = MetricsReporter::default();
        reporter.on_op_start("op.a", 2);
        reporter.on_pass("op.a", "gen/case1");
        reporter.on_pass("op.a", "gen/case2");
        reporter.on_op_done("op.a", 2, 0);

        reporter.on_op_start("op.b", 1);
        reporter.on_fail(&sample_failure("op.b", "edge_cases", "x"));
        reporter.on_op_done("op.b", 0, 1);

        let summary = reporter.summary();
        assert!(
            summary.contains("op.a: 2 passed, 0 failed"),
            "got: {summary}"
        );
        assert!(
            summary.contains("op.b: 0 passed, 1 failed"),
            "got: {summary}"
        );
    }

    #[test]
    fn metrics_reporter_tracks_per_generator() {
        let mut reporter = MetricsReporter::default();
        reporter.on_pass("op.a", "random/case1");
        reporter.on_pass("op.a", "random/case2");
        reporter.on_pass("op.a", "edge_cases/max");

        let summary = reporter.summary();
        assert!(
            summary.contains("generator random: 2 passed"),
            "got: {summary}"
        );
        assert!(
            summary.contains("generator edge_cases: 1 passed"),
            "got: {summary}"
        );
    }

    #[test]
    fn metrics_reporter_tracks_generator_failures() {
        let mut reporter = MetricsReporter::default();
        reporter.on_fail(&sample_failure("op.a", "edge_cases", "edge_cases/max"));

        let summary = reporter.summary();
        assert!(
            summary.contains("generator edge_cases: 0 passed, 1 failed"),
            "got: {summary}"
        );
    }

    #[test]
    fn metrics_reporter_lists_started_op_with_zero_counts() {
        let mut reporter = MetricsReporter::default();
        reporter.on_op_start("op.idle", 0);
        assert_eq!(reporter.summary(), "op.idle: 0 passed, 0 failed");
    }

    #[test]
    fn metrics_reporter_empty_summary_is_empty() {
        assert_eq!(MetricsReporter::default().summary(), "");
    }
}