vyre-conform 0.1.0

Conformance suite for vyre backends — proves byte-identical output to CPU reference
Documentation
//! GPU parity tests for all primitive operations.
//!
//! These tests dispatch each primitive's WGSL shader on a real GPU and compare
//! the output byte-for-byte against the CPU reference function. If ANY byte
//! differs, the test fails. This is the test that makes vyre's promise real.
//!
//! Without these tests, vyre-conform is a CPU-only self-consistency check
//! that proves nothing about GPU behavior.

#[cfg(test)]
mod tests {
    use super::*;
    use crate::pipeline::backend::{require_gpu, ConformDispatchConfig, WgslBackend};
    use crate::wgpu_backend::WgpuBackend;

    /// Workgroup sizes that every parity check is dispatched at, in this order.
    const WORKGROUP_SWEEP: &[u32] = &[1, 8, 64, 256];

    /// Dispatch `op_wgsl` once per entry of `WORKGROUP_SWEEP` and require every
    /// GPU result to equal `cpu_result`.
    ///
    /// * `backend` – GPU backend used for the dispatches.
    /// * `op_wgsl` – WGSL source of the operation; it is wrapped with
    ///   `wrap_shader` for each workgroup size.
    /// * `input` – raw input bytes handed to the shader.
    /// * `output_size` – number of output bytes requested from the dispatch.
    /// * `base_config` – dispatch configuration; cloned per sweep entry with
    ///   only `workgroup_size` overridden.
    /// * `cpu_result` – expected output bytes from the CPU reference.
    /// * `context` – prefix for every failure message.
    ///
    /// # Panics
    ///
    /// Panics if the backend reports a workgroup-invocation limit smaller than
    /// any sweep entry, if a dispatch returns an error, or if a GPU result
    /// differs from `cpu_result`.
    fn run_at_all_workgroups(
        backend: &WgpuBackend,
        op_wgsl: &str,
        input: &[u8],
        output_size: usize,
        base_config: &ConformDispatchConfig,
        cpu_result: &[u8],
        context: &str,
    ) {
        // The limit does not depend on the sweep entry: query it once and
        // reject an adapter that cannot cover the sweep before issuing any
        // dispatch. A backend that reports no limit skips the check.
        if let Some(max) = backend.max_workgroup_invocations() {
            for &workgroup_size in WORKGROUP_SWEEP {
                assert!(
                    workgroup_size <= max,
                    "{context}: GPU max workgroup invocations {max} cannot cover required sweep size {workgroup_size}. \
                     Fix: run on a GPU that supports the vyre conformance workgroup sweep."
                );
            }
        }

        for &workgroup_size in WORKGROUP_SWEEP {
            let mut config = base_config.clone();
            config.workgroup_size = workgroup_size;
            let shader = crate::pipeline::backend::wrap_shader(op_wgsl, &config);
            let gpu_result = backend
                .dispatch(&shader, input, output_size, config)
                .unwrap_or_else(|e| {
                    panic!("{context}: GPU dispatch failed at workgroup_size={workgroup_size}: {e}")
                });

            // Byte-for-byte comparison; any difference fails the test.
            assert_eq!(
                gpu_result, cpu_result,
                "{context}: GPU/CPU parity failure at workgroup_size={workgroup_size}: gpu={:?} cpu={:?}",
                gpu_result, cpu_result
            );
        }
    }

    /// Check a two-operand u32 primitive on the GPU against its CPU reference.
    ///
    /// For each `init` value that `run_with_all_init_policies` passes to the
    /// closure, and for each `(a, b)` in `test_pairs`, the operands are packed
    /// as two little-endian u32 words (first operand first). Those bytes go to
    /// `cpu_fn`, and the same bytes are swept across every workgroup size with
    /// four output bytes requested.
    ///
    /// # Panics
    ///
    /// Panics if `test_pairs` is empty, or if any sweep step fails.
    fn verify_binary_parity(
        backend: &WgpuBackend,
        op_wgsl: &str,
        cpu_fn: fn(&[u8]) -> Vec<u8>,
        test_pairs: &[(u32, u32)],
        op_name: &str,
    ) {
        // One u32 result word is read back per dispatch.
        const RESULT_BYTES: usize = std::mem::size_of::<u32>();

        assert!(
            !test_pairs.is_empty(),
            "Fix: {op_name} must define at least one boundary pair"
        );

        crate::pipeline::backend::run_with_all_init_policies(|init| {
            let mut config = ConformDispatchConfig::default();
            config.buffer_init = init;

            for (a, b) in test_pairs.iter().copied() {
                let input = [a.to_le_bytes(), b.to_le_bytes()].concat();
                let expected = cpu_fn(&input);
                let context = format!("{op_name} ({a:#010X}, {b:#010X}) with {init:?}");

                run_at_all_workgroups(
                    backend,
                    op_wgsl,
                    &input,
                    RESULT_BYTES,
                    &config,
                    &expected,
                    &context,
                );
            }
        });
    }

    /// Check a one-operand u32 primitive on the GPU against its CPU reference.
    ///
    /// For each `init` value that `run_with_all_init_policies` passes to the
    /// closure, and for each value in `test_values`, the operand is encoded as
    /// one little-endian u32 word. Those bytes go to `cpu_fn`, and the same
    /// bytes are swept across every workgroup size with four output bytes
    /// requested.
    ///
    /// # Panics
    ///
    /// Panics if `test_values` is empty, or if any sweep step fails.
    fn verify_unary_parity(
        backend: &WgpuBackend,
        op_wgsl: &str,
        cpu_fn: fn(&[u8]) -> Vec<u8>,
        test_values: &[u32],
        op_name: &str,
    ) {
        // One u32 result word is read back per dispatch.
        const RESULT_BYTES: usize = std::mem::size_of::<u32>();

        assert!(
            !test_values.is_empty(),
            "Fix: {op_name} must define at least one boundary value"
        );

        crate::pipeline::backend::run_with_all_init_policies(|init| {
            let mut config = ConformDispatchConfig::default();
            config.buffer_init = init;

            for a in test_values.iter().copied() {
                let input = a.to_le_bytes();
                let expected = cpu_fn(&input);
                let context = format!("{op_name} {a:#010X} with {init:?}");

                run_at_all_workgroups(
                    backend,
                    op_wgsl,
                    &input,
                    RESULT_BYTES,
                    &config,
                    &expected,
                    &context,
                );
            }
        });
    }

    /// Operand pairs fed to every binary primitive's boundary parity test.
    ///
    /// The same table is used for all binary ops, so e.g. `div` and `mod` also
    /// receive the pairs whose second operand is zero.
    const BINARY_BOUNDARIES: &[(u32, u32)] = &[
        // Zero and one in every combination.
        (0, 0),
        (0, 1),
        (1, 0),
        (1, 1),
        // All-ones against zero, itself and one.
        (0, u32::MAX),
        (u32::MAX, 0),
        (u32::MAX, u32::MAX),
        (u32::MAX, 1),
        (1, u32::MAX),
        // Only the highest bit set.
        (0x8000_0000, 0),
        (0, 0x8000_0000),
        (0x8000_0000, 0x8000_0000),
        // Mixed bit patterns, including complementary masks.
        (0xDEAD_BEEF, 0xCAFE_BABE),
        (0x5555_5555, 0xAAAA_AAAA),
        (0xF0F0_F0F0, 0x0F0F_0F0F),
        (0x1234_5678, 0x9ABC_DEF0),
        // 31, 32 and 33 (either side of the u32 bit width) in each operand slot.
        (31, 1),
        (32, 1),
        (33, 1),
        (1, 31),
        (1, 32),
        (1, 33),
    ];

    /// Operands fed to every unary primitive's boundary parity test.
    const UNARY_BOUNDARIES: &[u32] = &[
        // Smallest values.
        0,
        1,
        2,
        // One below the u32 bit width, and the bit width itself.
        31,
        32,
        // Either side of the 8-bit boundary.
        0xFF,
        0x100,
        // Bit 15 only.
        0x8000,
        // Highest bit only, every bit but the highest, and all bits.
        0x8000_0000,
        0x7FFF_FFFF,
        u32::MAX,
        // Mixed bit patterns.
        0xDEAD_BEEF,
        0x5555_5555,
        0xAAAA_AAAA,
        0xF0F0_F0F0,
    ];

    // ── Exhaustive u8 domain verification ──────────────────────────────

    /// Check a binary primitive for every operand pair in `0..=255 x 0..=255`
    /// (65,536 pairs), with `a` as the outer loop and `b` as the inner one.
    ///
    /// Operands are widened to u32 and packed as two little-endian words. For
    /// each `init` value that `run_with_all_init_policies` passes to the
    /// closure, every pair goes to `cpu_fn` and is then swept across every
    /// workgroup size with four output bytes requested.
    ///
    /// # Panics
    ///
    /// Panics if any sweep step fails.
    fn verify_binary_exhaustive_u8(
        backend: &WgpuBackend,
        op_wgsl: &str,
        cpu_fn: fn(&[u8]) -> Vec<u8>,
        op_name: &str,
    ) {
        // One u32 result word is read back per dispatch.
        const RESULT_BYTES: usize = std::mem::size_of::<u32>();

        // No emptiness check needed: the domain always has 256 * 256 entries.
        crate::pipeline::backend::run_with_all_init_policies(|init| {
            let mut config = ConformDispatchConfig::default();
            config.buffer_init = init;

            let all_pairs = (0u32..=255).flat_map(|a| (0u32..=255).map(move |b| (a, b)));
            for (a, b) in all_pairs {
                let input = [a.to_le_bytes(), b.to_le_bytes()].concat();
                let expected = cpu_fn(&input);
                let context = format!("{op_name} exhaustive ({a}, {b}) with {init:?}");

                run_at_all_workgroups(
                    backend,
                    op_wgsl,
                    &input,
                    RESULT_BYTES,
                    &config,
                    &expected,
                    &context,
                );
            }
        });
    }

    /// Check a unary primitive for every operand in `0..=255` (256 values).
    ///
    /// Each operand is widened to u32 and encoded as one little-endian word.
    /// For each `init` value that `run_with_all_init_policies` passes to the
    /// closure, every operand goes to `cpu_fn` and is then swept across every
    /// workgroup size with four output bytes requested.
    ///
    /// # Panics
    ///
    /// Panics if any sweep step fails.
    fn verify_unary_exhaustive_u8(
        backend: &WgpuBackend,
        op_wgsl: &str,
        cpu_fn: fn(&[u8]) -> Vec<u8>,
        op_name: &str,
    ) {
        // One u32 result word is read back per dispatch.
        const RESULT_BYTES: usize = std::mem::size_of::<u32>();

        // No emptiness check needed: the domain always has 256 entries.
        crate::pipeline::backend::run_with_all_init_policies(|init| {
            let mut config = ConformDispatchConfig::default();
            config.buffer_init = init;

            for byte in u8::MIN..=u8::MAX {
                let a = u32::from(byte);
                let input = a.to_le_bytes();
                let expected = cpu_fn(&input);
                let context = format!("{op_name} exhaustive {a} with {init:?}");

                run_at_all_workgroups(
                    backend,
                    op_wgsl,
                    &input,
                    RESULT_BYTES,
                    &config,
                    &expected,
                    &context,
                );
            }
        });
    }

    // ── GPU parity tests: one per primitive ─────────────────────────────

    use crate::spec::primitive;

    /// Generate a `#[test]` that checks one binary primitive against
    /// `BINARY_BOUNDARIES`.
    ///
    /// * `$test_fn` – name of the generated test function.
    /// * `$op_mod` – module under `primitive` whose `spec()` provides the WGSL
    ///   source (`wgsl_fn`) and the CPU reference (`cpu_fn`).
    /// * `$op_label` – operation name used in failure messages.
    ///
    /// The generated test panics with "vyre-conform test needs a GPU adapter"
    /// when `require_gpu()` does not yield a backend.
    macro_rules! binary_gpu_parity {
        ($test_fn:ident, $op_mod:ident, $op_label:expr) => {
            #[test]
            fn $test_fn() {
                let backend = require_gpu().expect("vyre-conform test needs a GPU adapter");
                let spec = primitive::$op_mod::spec();
                let wgsl = (spec.wgsl_fn)();
                verify_binary_parity(&backend, &wgsl, spec.cpu_fn, BINARY_BOUNDARIES, $op_label);
            }
        };
    }

    /// Generate a `#[test]` that checks one unary primitive against
    /// `UNARY_BOUNDARIES`.
    ///
    /// * `$test_fn` – name of the generated test function.
    /// * `$op_mod` – module under `primitive` whose `spec()` provides the WGSL
    ///   source (`wgsl_fn`) and the CPU reference (`cpu_fn`).
    /// * `$op_label` – operation name used in failure messages.
    ///
    /// The generated test panics with "vyre-conform test needs a GPU adapter"
    /// when `require_gpu()` does not yield a backend.
    macro_rules! unary_gpu_parity {
        ($test_fn:ident, $op_mod:ident, $op_label:expr) => {
            #[test]
            fn $test_fn() {
                let backend = require_gpu().expect("vyre-conform test needs a GPU adapter");
                let spec = primitive::$op_mod::spec();
                let wgsl = (spec.wgsl_fn)();
                verify_unary_parity(&backend, &wgsl, spec.cpu_fn, UNARY_BOUNDARIES, $op_label);
            }
        };
    }

    /// Generate a `#[test]` that runs one binary primitive through
    /// `verify_binary_exhaustive_u8`.
    ///
    /// * `$test_fn` – name of the generated test function.
    /// * `$op_mod` – module under `primitive` whose `spec()` provides the WGSL
    ///   source (`wgsl_fn`) and the CPU reference (`cpu_fn`).
    /// * `$op_label` – operation name used in failure messages.
    ///
    /// The generated test panics with "vyre-conform test needs a GPU adapter"
    /// when `require_gpu()` does not yield a backend.
    macro_rules! binary_exhaustive_u8 {
        ($test_fn:ident, $op_mod:ident, $op_label:expr) => {
            #[test]
            fn $test_fn() {
                let backend = require_gpu().expect("vyre-conform test needs a GPU adapter");
                let spec = primitive::$op_mod::spec();
                let wgsl = (spec.wgsl_fn)();
                verify_binary_exhaustive_u8(&backend, &wgsl, spec.cpu_fn, $op_label);
            }
        };
    }

    /// Generate a `#[test]` that runs one unary primitive through
    /// `verify_unary_exhaustive_u8`.
    ///
    /// * `$test_fn` – name of the generated test function.
    /// * `$op_mod` – module under `primitive` whose `spec()` provides the WGSL
    ///   source (`wgsl_fn`) and the CPU reference (`cpu_fn`).
    /// * `$op_label` – operation name used in failure messages.
    ///
    /// The generated test panics with "vyre-conform test needs a GPU adapter"
    /// when `require_gpu()` does not yield a backend.
    macro_rules! unary_exhaustive_u8 {
        ($test_fn:ident, $op_mod:ident, $op_label:expr) => {
            #[test]
            fn $test_fn() {
                let backend = require_gpu().expect("vyre-conform test needs a GPU adapter");
                let spec = primitive::$op_mod::spec();
                let wgsl = (spec.wgsl_fn)();
                verify_unary_exhaustive_u8(&backend, &wgsl, spec.cpu_fn, $op_label);
            }
        };
    }

    // ── Boundary parity (fast, runs on every commit) ────────────────────
    //
    // Each line expands to a `#[test]` named by the first argument. The test
    // loads `primitive::<second argument>::spec()` and checks it against
    // `BINARY_BOUNDARIES` (binary ops) or `UNARY_BOUNDARIES` (unary ops). The
    // third argument is the label shown in failure messages.

    // Bitwise
    binary_gpu_parity!(gpu_parity_xor, xor, "xor");
    binary_gpu_parity!(gpu_parity_and, and, "and");
    binary_gpu_parity!(gpu_parity_or, or, "or");
    unary_gpu_parity!(gpu_parity_not, not, "not");
    binary_gpu_parity!(gpu_parity_shl, shl, "shl");
    binary_gpu_parity!(gpu_parity_shr, shr, "shr");
    binary_gpu_parity!(gpu_parity_rotl, rotl, "rotl");
    binary_gpu_parity!(gpu_parity_rotr, rotr, "rotr");
    unary_gpu_parity!(gpu_parity_popcount, popcount, "popcount");
    unary_gpu_parity!(gpu_parity_clz, clz, "clz");
    unary_gpu_parity!(gpu_parity_ctz, ctz, "ctz");
    unary_gpu_parity!(gpu_parity_reverse_bits, reverse_bits, "reverse_bits");
    binary_gpu_parity!(gpu_parity_extract_bits, extract_bits, "extract_bits");
    binary_gpu_parity!(gpu_parity_insert_bits, insert_bits, "insert_bits");

    // Arithmetic
    binary_gpu_parity!(gpu_parity_add, add, "add");
    binary_gpu_parity!(gpu_parity_sub, sub, "sub");
    binary_gpu_parity!(gpu_parity_mul, mul, "mul");
    binary_gpu_parity!(gpu_parity_div, div, "div");
    // The module is `mod_op` because `mod` is a Rust keyword; the label stays "mod".
    binary_gpu_parity!(gpu_parity_mod, mod_op, "mod");
    binary_gpu_parity!(gpu_parity_min, min, "min");
    binary_gpu_parity!(gpu_parity_max, max, "max");
    binary_gpu_parity!(gpu_parity_clamp, clamp, "clamp");
    unary_gpu_parity!(gpu_parity_abs, abs, "abs");
    unary_gpu_parity!(gpu_parity_negate, negate, "negate");

    // Comparison
    binary_gpu_parity!(gpu_parity_eq, eq, "eq");
    binary_gpu_parity!(gpu_parity_ne, ne, "ne");
    binary_gpu_parity!(gpu_parity_lt, lt, "lt");
    binary_gpu_parity!(gpu_parity_gt, gt, "gt");
    binary_gpu_parity!(gpu_parity_le, le, "le");
    binary_gpu_parity!(gpu_parity_ge, ge, "ge");
    binary_gpu_parity!(gpu_parity_select, select, "select");
    unary_gpu_parity!(gpu_parity_logical_not, logical_not, "logical_not");

    // ── Exhaustive u8 (slower, runs nightly) ────────────────────────────
    //
    // Each line expands to a `#[test]` that sweeps every operand in 0..=255
    // (all 65,536 pairs for binary ops) for the named primitive.
    //
    // NOTE(review): the generated tests carry no `#[ignore]`, so nothing in
    // this file keeps them out of a plain `cargo test` run. The nightly-only
    // split presumably relies on a name filter on the `exhaustive_u8_` prefix —
    // confirm against the CI configuration.
    // NOTE(review): rotl, rotr, extract_bits, insert_bits, min, max, clamp and
    // select have a boundary parity test but no exhaustive counterpart here —
    // confirm this is intentional.

    binary_exhaustive_u8!(exhaustive_u8_xor, xor, "xor");
    binary_exhaustive_u8!(exhaustive_u8_and, and, "and");
    binary_exhaustive_u8!(exhaustive_u8_or, or, "or");
    unary_exhaustive_u8!(exhaustive_u8_not, not, "not");
    binary_exhaustive_u8!(exhaustive_u8_shl, shl, "shl");
    binary_exhaustive_u8!(exhaustive_u8_shr, shr, "shr");
    binary_exhaustive_u8!(exhaustive_u8_add, add, "add");
    binary_exhaustive_u8!(exhaustive_u8_sub, sub, "sub");
    binary_exhaustive_u8!(exhaustive_u8_mul, mul, "mul");
    binary_exhaustive_u8!(exhaustive_u8_div, div, "div");
    binary_exhaustive_u8!(exhaustive_u8_mod, mod_op, "mod");
    binary_exhaustive_u8!(exhaustive_u8_eq, eq, "eq");
    binary_exhaustive_u8!(exhaustive_u8_ne, ne, "ne");
    binary_exhaustive_u8!(exhaustive_u8_lt, lt, "lt");
    binary_exhaustive_u8!(exhaustive_u8_gt, gt, "gt");
    binary_exhaustive_u8!(exhaustive_u8_le, le, "le");
    binary_exhaustive_u8!(exhaustive_u8_ge, ge, "ge");
    unary_exhaustive_u8!(exhaustive_u8_popcount, popcount, "popcount");
    unary_exhaustive_u8!(exhaustive_u8_clz, clz, "clz");
    unary_exhaustive_u8!(exhaustive_u8_ctz, ctz, "ctz");
    unary_exhaustive_u8!(exhaustive_u8_reverse_bits, reverse_bits, "reverse_bits");
    unary_exhaustive_u8!(exhaustive_u8_negate, negate, "negate");
    unary_exhaustive_u8!(exhaustive_u8_abs, abs, "abs");
    unary_exhaustive_u8!(exhaustive_u8_logical_not, logical_not, "logical_not");
}