vyre-conform 0.1.0

Conformance suite for vyre backends — proves byte-identical output to CPU reference
Documentation
use super::ops::is_unsupported_by_backend_message;
use super::ParitySummary;
use crate::generate::generators::default_generators;
use crate::pipeline::execution::{execute_op, regression_inputs, seed_from};
use crate::pipeline::streaming::StreamingRunner;
use crate::proof::comparator::{Comparator, ExactMatch};
use crate::spec::types::{OpSpec, ParityFailure};
use crate::verify::budget::{budget_for_op, with_exec_budget, Archetype, BudgetTracker};

pub(super) fn run_parity(backend: &dyn vyre::VyreBackend, spec: &OpSpec) -> ParitySummary {
    let inputs = parity_inputs(spec);
    let workgroup_sizes = workgroup_sizes(spec.workgroup_size);
    let mut cases_tested = 0_u64;
    let mut failures = Vec::new();
    let mut unsupported = false;

    run_parity_for_wgsl(
        backend,
        spec,
        &inputs,
        &workgroup_sizes,
        &mut cases_tested,
        &mut failures,
        &mut unsupported,
    );
    if unsupported {
        return ParitySummary {
            cases_tested,
            failures,
            unsupported,
        };
    }
    for (_label, alt_wgsl_fn) in &spec.alt_wgsl_fns {
        let alt_spec = OpSpec {
            wgsl_fn: *alt_wgsl_fn,
            alt_wgsl_fns: Vec::new(),
            version_history: spec.version_history.clone(),
            ..spec.clone()
        };
        run_parity_for_wgsl(
            backend,
            &alt_spec,
            &inputs,
            &workgroup_sizes,
            &mut cases_tested,
            &mut failures,
            &mut unsupported,
        );
        if unsupported {
            break;
        }
    }

    ParitySummary {
        cases_tested,
        failures,
        unsupported,
    }
}

/// Runs parity for a single op through the streaming runner (batch size 1024).
///
/// The summary is marked `unsupported` when any reported failure carries the backend's
/// "unsupported" message.
pub(super) fn run_streaming_parity(
    backend: &dyn vyre::VyreBackend,
    spec: &OpSpec,
) -> ParitySummary {
    let (summary, failures) = StreamingRunner::new()
        .batch_size(1024)
        .run_with_summary(backend, std::slice::from_ref(spec));

    let cases_tested = summary.tested;
    let unsupported = failures
        .iter()
        .map(|failure| &failure.message)
        .any(|message| is_unsupported_by_backend_message(message));

    ParitySummary {
        cases_tested,
        failures,
        unsupported,
    }
}

fn run_parity_for_wgsl(
    backend: &dyn vyre::VyreBackend,
    spec: &OpSpec,
    inputs: &[(String, String, Vec<u8>)],
    workgroup_sizes: &[u32],
    cases_tested: &mut u64,
    failures: &mut Vec<ParityFailure>,
    unsupported: &mut bool,
) {
    let archetype = infer_archetype(spec.id);
    let budget = budget_for_op(spec.id, &archetype);
    let tracker = BudgetTracker::new(budget, spec.id);

    let result = with_exec_budget(tracker, || {
        for (generator, label, bytes) in inputs {
            for workgroup_size in workgroup_sizes {
                *cases_tested = cases_tested.saturating_add(2);
                match execute_deterministic_case(backend, spec, bytes, *workgroup_size) {
                    Ok(()) => {}
                    Err(error) => {
                        if is_unsupported_by_backend_message(&error.message) {
                            *unsupported = true;
                            return;
                        }
                        failures.push(parity_failure(
                            spec,
                            generator,
                            label,
                            bytes,
                            error.gpu_output,
                            error.cpu_output,
                            error.message,
                            *workgroup_size,
                        ));
                    }
                }
            }
        }
    });

    if let Err(bomb) = result {
        failures.push(parity_failure(
            spec,
            "budget",
            "reference_bomb",
            &[],
            Vec::new(),
            Vec::new(),
            bomb.to_string(),
            0,
        ));
    }
}

/// Failure details for one deterministic parity case.
///
/// The output buffers are empty when no dispatch produced data to report.
struct ParityCaseError {
    /// Bytes attached as the reference side of the failure. For a GPU-nondeterminism failure,
    /// this slot holds the first GPU dispatch's output instead of CPU data.
    cpu_output: Vec<u8>,
    /// GPU bytes associated with the failure.
    gpu_output: Vec<u8>,
    /// Human-readable explanation: a backend error or a comparator mismatch.
    message: String,
}

/// Dispatches the op twice on identical input and validates both runs.
///
/// Checks, in order:
/// 1. Both dispatches succeed. A second-dispatch error reports the first dispatch's outputs.
/// 2. The two GPU outputs are byte-identical (`ExactMatch`).
/// 3. Each dispatch's GPU output matches its CPU output under `spec.comparator`.
///
/// The first failing check is returned as a [`ParityCaseError`].
fn execute_deterministic_case(
    backend: &dyn vyre::VyreBackend,
    spec: &OpSpec,
    bytes: &[u8],
    workgroup_size: u32,
) -> Result<(), ParityCaseError> {
    let (gpu_a, cpu_a) = match execute_op(backend, spec, bytes, workgroup_size) {
        Ok(outputs) => outputs,
        Err(message) => {
            return Err(ParityCaseError {
                message,
                gpu_output: Vec::new(),
                cpu_output: Vec::new(),
            })
        }
    };

    let (gpu_b, cpu_b) = match execute_op(backend, spec, bytes, workgroup_size) {
        Ok(outputs) => outputs,
        Err(message) => {
            // Report the first dispatch's outputs for context.
            return Err(ParityCaseError {
                message,
                gpu_output: gpu_a,
                cpu_output: cpu_a,
            });
        }
    };

    if let Err(message) = ExactMatch.compare(&gpu_a, &gpu_b) {
        return Err(ParityCaseError {
            message: format!(
                "nondeterministic GPU output between identical dispatches: {message} \
                 Fix: remove data races, uninitialized reads, or workgroup-order dependence."
            ),
            gpu_output: gpu_b,
            // Deliberately reports the first GPU run here, so both diverging outputs are shown.
            cpu_output: gpu_a,
        });
    }

    for (gpu, cpu) in [(gpu_a, cpu_a), (gpu_b, cpu_b)] {
        if let Err(message) = spec.comparator.compare(&gpu, &cpu) {
            return Err(ParityCaseError {
                message,
                gpu_output: gpu,
                cpu_output: cpu,
            });
        }
    }

    Ok(())
}

/// Builds the `(generator, label, bytes)` inputs used for parity checking of `spec`.
///
/// Stored regression cases come first. They are followed by the output of every default
/// generator that handles the op's signature, each seeded deterministically from the op id.
fn parity_inputs(spec: &OpSpec) -> Vec<(String, String, Vec<u8>)> {
    let mut collected: Vec<(String, String, Vec<u8>)> = Vec::new();
    for case in regression_inputs(spec.id) {
        collected.push((case.generator, case.label, case.bytes));
    }

    let seed = seed_from(spec.id);
    let applicable = default_generators()
        .into_iter()
        .filter(|generator| generator.handles(&spec.signature));
    for generator in applicable {
        collected.extend(
            generator
                .generate_for_op(spec.id, &spec.signature, seed)
                .into_iter()
                .map(|(label, bytes)| (generator.name().to_string(), label, bytes)),
        );
    }

    collected
}

/// Packages the details of one failed parity case into an owned [`ParityFailure`] record,
/// tagged with the op id and spec version from `spec`.
fn parity_failure(
    spec: &OpSpec,
    generator: &str,
    label: &str,
    input: &[u8],
    gpu_output: Vec<u8>,
    cpu_output: Vec<u8>,
    message: String,
    workgroup_size: u32,
) -> ParityFailure {
    let op_id = spec.id.to_string();
    let spec_version = spec.version;
    ParityFailure {
        op_id,
        spec_version,
        generator: generator.to_owned(),
        input_label: label.to_owned(),
        input: input.to_owned(),
        message,
        gpu_output,
        cpu_output,
        workgroup_size,
    }
}

/// Returns the sorted, de-duplicated workgroup sizes to exercise for an op.
///
/// Sizes 1 and 64 are always included. `preferred` is added when it is `Some` and non-zero.
fn workgroup_sizes(preferred: Option<u32>) -> Vec<u32> {
    let mut candidates: Vec<u32> = [1_u32, 64]
        .iter()
        .copied()
        .chain(preferred.filter(|&size| size != 0))
        .collect();
    candidates.sort_unstable();
    candidates.dedup();
    candidates
}

/// Guesses a budget archetype from substrings of the op id.
///
/// Rules are tried in order, and the first rule whose substrings all occur in `op_id` wins.
/// When no rule matches, the archetype is `"unknown"`.
fn infer_archetype(op_id: &str) -> Archetype {
    const RULES: &[(&[&str], &str)] = &[
        (&["hash", "u32"], "hash-bytes-to-u32"),
        (&["hash", "u64"], "hash-bytes-to-u64"),
        (&["decode"], "decode-bytes-to-bytes"),
        (&["compress"], "compression-bytes-to-bytes"),
    ];
    let name = RULES
        .iter()
        .find(|(needles, _)| needles.iter().all(|needle| op_id.contains(needle)))
        .map_or("unknown", |&(_, archetype)| archetype);
    Archetype(name)
}